In [1]:
import os 
from os.path import join, abspath, dirname
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from torch_geometric.utils import negative_sampling
import lightning.pytorch as pl

from scGraphLLM.data import *
from scGraphLLM.GNN_modules import *
from scGraphLLM.MLP_modules import *
from scGraphLLM._globals import *
from scGraphLLM.flash_transformer import GDTransformer
from scGraphLLM.config import *

from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
import tqdm
import matplotlib.pyplot as plt
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
torch.cuda.is_available()

True

In [3]:
gene_to_node = pd.read_csv("/hpc/projects/group.califano/GLM/data/cellxgene_gene2index.csv")

In [4]:
gene_to_node = gene_to_node.set_index("gene_name")["idx"].to_dict()

In [5]:
class GeneEmbeddingDataset(Dataset):
    def __init__(self, path):
        assert os.path.exists(path), f"File not found: {path}"
        self.path = path
        self.embedding = np.load(self.path, allow_pickle=True)
        self.x = self.embedding["x"]
        self.pad_indices = self.embedding["pad_indices"]
        self.edges = self.embedding["edges"].item()

    def __len__(self):
        return len(self.embedding["x"])

    def __getitem__(self, idx):
        
        return {
            "x": torch.tensor(self.x[idx]), 
            "pad_indices": torch.tensor(self.pad_indices[idx]),
            "edges": torch.tensor(self.edges[idx])
        }

def embedding_collate_fn(batch):
    return {
        "x": torch.stack([item["x"] for item in batch]),
        "pad_indices": torch.stack([item["pad_indices"] for item in batch]),
        "edges": [item["edges"] for item in batch]
    }
    
scgpt_embedding_dataset = GeneEmbeddingDataset(
    path="/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd8_t_cells/embeddings/scgpt/embedding.npz"
)


In [6]:
from torch.utils.data import Dataset, DataLoader
scgpt_dataloader = DataLoader(
    scgpt_embedding_dataset, 
    batch_size=32, 
    shuffle=False,
    collate_fn=embedding_collate_fn
)



In [10]:
link_predictor_scgpt = LinkPredictHead(512, 1).to("cuda")

In [11]:
x = scgpt_embedding_dataset[0]["x"].to(device)

In [12]:
x.shape

torch.Size([1048, 512])

In [13]:
x_i, x_j = x[0,:].to(device), x[1,:].to(device)

In [14]:
x_i.shape

torch.Size([512])

In [15]:
link_predictor_scgpt(x, x)

tensor([[0.4882],
        [0.5073],
        [0.4822],
        ...,
        [0.5088],
        [0.5088],
        [0.5088]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [None]:
def link_pred_loss(predictor, node_embedding, pad_indices, edge_index_list):
    """
    predictor (nn.Module), predictor module
    node_embedding: node embedding matrix
    mask_locs: mask of valid nodes
    edge_index_list:
    """
    pos_out = []
    neg_out = []
    pos_labels = []
    neg_labels = []

    batch_size, num_nodes, embed_dim = node_embedding.shape
    device = node_embedding.device

    for i in range(batch_size):
        # masked_nodes = torch.where(mask_locs[batch])[0]
        # if masked_nodes.numel() == 0:
        #     continue
        # masked_nodes = masked_nodes.to(device)
        # edge_index = edge_index_list[batch].to(device)
        # masked_nodes_bool = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        # masked_nodes_bool[masked_nodes] = True

        # src_nodes = edge_index[0]
        # dst_nodes = edge_index[1]
        # edge_mask = masked_nodes_bool[src_nodes] & masked_nodes_bool[dst_nodes]
        # pos_edge_index = edge_index[:, edge_mask]

        pos_edge_index = edge_index_list[i].to(device)
        num_nodes = pad_indices[i]

        if pos_edge_index.size(1) == 0:
            continue

        num_neg_samples = pos_edge_index.size(1)
        neg_edge_index = negative_sampling(
            # edge_index=edge_index,
            edge_index=pos_edge_index,
            num_nodes=num_nodes,
            num_neg_samples=num_neg_samples,
            method="sparse"
        ).to(device)

        # Positive scores
        src_emb_pos = node_embedding[i, pos_edge_index[0]]
        dst_emb_pos = node_embedding[i, pos_edge_index[1]]
        pos_scores = predictor(src_emb_pos, dst_emb_pos)
        pos_out.append(pos_scores)
        pos_labels.append(torch.ones_like(pos_scores, device=device))  # Positive labels (1)

        # Negative scores
        src_emb_neg = node_embedding[i, neg_edge_index[0]]
        dst_emb_neg = node_embedding[i, neg_edge_index[1]]
        neg_scores = predictor(src_emb_neg, dst_emb_neg)
        neg_out.append(neg_scores)
        neg_labels.append(torch.zeros_like(neg_scores, device=device))  # Negative labels (0)

    if pos_out:
        pos_out = torch.cat(pos_out, dim=0)
        neg_out = torch.cat(neg_out, dim=0)
        pos_labels = torch.cat(pos_labels, dim=0)
        neg_labels = torch.cat(neg_labels, dim=0)

        # Loss calculation
        pos_loss = -torch.log(pos_out + 1e-10).mean()
        neg_loss = -torch.log(1 - neg_out + 1e-10).mean()

        # Concatenate outputs and labels
        all_outputs = torch.cat([pos_out, neg_out], dim=0)
        all_labels = torch.cat([pos_labels, neg_labels], dim=0)

        return pos_loss + neg_loss, all_outputs, all_labels
    else:
        return torch.tensor(0.0, device=device), torch.tensor([], device=device), torch.tensor([], device=device)



In [24]:
for batch in scgpt_dataloader:
    print(f"x.shape: {batch['x'].shape}")
    print(f"pad_indices.shape: {batch['pad_indices'].shape}")
    print(f"len(edges): {len(batch['edges'])}")
    # calculate link prediciton loss for batch
    loss, _, _ = link_pred_loss(
        predictor=link_predictor_scgpt,
        node_embedding=batch["x"], 
        pad_indices=batch["pad_indices"],
        edge_index_list=batch["edges"]
    )
    pass

x.shape: torch.Size([10, 1048, 512])
pad_indices.shape: torch.Size([10])
len(edges): 10


IndexError: index 1 is out of bounds for dimension 0 with size 0

In [None]:
transformer_data_module = GraphTransformerDataModule(
    graph_kernel_attn_4096.data_config, 
    collate_fn=collate_fn
)
train_transformer_dl = transformer_data_module.train_dataloader()
val_transformer_dl = transformer_data_module.val_dataloader()

/hpc/projects/group.califano/GLM/data/cxg_cache_4096/train


/hpc/projects/group.califano/GLM/data/cxg_cache_4096/valSG
/hpc/projects/group.califano/GLM/data/cxg_cache_4096/valHOG


In [5]:
val_hog_data = val_transformer_dl[-1]
val_sg_data = val_transformer_dl[0]

In [6]:
def link_pred_loss(predictor, node_embedding, mask_locs, edge_index_list):
    pos_out = []
    neg_out = []
    pos_labels = []
    neg_labels = []

    batch_size, num_nodes, embed_dim = node_embedding.shape
    device = node_embedding.device

    for batch in range(batch_size):
        masked_nodes = torch.where(mask_locs[batch])[0]
        if masked_nodes.numel() == 0:
            continue
        masked_nodes = masked_nodes.to(device)
        edge_index = edge_index_list[batch].to(device)
        masked_nodes_bool = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        masked_nodes_bool[masked_nodes] = True
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]
        edge_mask = masked_nodes_bool[src_nodes] & masked_nodes_bool[dst_nodes]
        pos_edge_index = edge_index[:, edge_mask]
        if pos_edge_index.size(1) == 0:
            continue

        num_neg_samples = pos_edge_index.size(1)
        neg_edge_index = negative_sampling(
            edge_index=edge_index,
            num_nodes=num_nodes,
            num_neg_samples=num_neg_samples,
            method='sparse'
        ).to(device)

        # Positive scores
        src_emb_pos = node_embedding[batch, pos_edge_index[0]]
        dst_emb_pos = node_embedding[batch, pos_edge_index[1]]
        pos_scores = predictor(src_emb_pos, dst_emb_pos)
        pos_out.append(pos_scores)
        pos_labels.append(torch.ones_like(pos_scores, device=device))  # Positive labels (1)

        # Negative scores
        src_emb_neg = node_embedding[batch, neg_edge_index[0]]
        dst_emb_neg = node_embedding[batch, neg_edge_index[1]]
        neg_scores = predictor(src_emb_neg, dst_emb_neg)
        neg_out.append(neg_scores)
        neg_labels.append(torch.zeros_like(neg_scores, device=device))  # Negative labels (0)

    if pos_out:
        pos_out = torch.cat(pos_out, dim=0)
        neg_out = torch.cat(neg_out, dim=0)
        pos_labels = torch.cat(pos_labels, dim=0)
        neg_labels = torch.cat(neg_labels, dim=0)

        # Loss calculation
        pos_loss = -torch.log(pos_out + 1e-10).mean()
        neg_loss = -torch.log(1 - neg_out + 1e-10).mean()

        # Concatenate outputs and labels
        all_outputs = torch.cat([pos_out, neg_out], dim=0)
        all_labels = torch.cat([pos_labels, neg_labels], dim=0)

        return pos_loss + neg_loss, all_outputs, all_labels
    else:
        return torch.tensor(0.0, device=device), torch.tensor([], device=device), torch.tensor([], device=device)

In [7]:
def send_to_gpu(data):
    if isinstance(data, torch.Tensor):
        return data.to('cuda')  # Send tensor to GPU
    elif isinstance(data, list):
        return [send_to_gpu(item) for item in data]  # Recursively process lists
    elif isinstance(data, dict):
        return {key: send_to_gpu(value) for key, value in data.items()}  # Recursively process dicts
    else:
        return data  # If not a tensor or list/dict, leave unchanged

In [None]:
def fine_tune_train_step(pretrained_model, ft_model, batch, opt):
    pretrained_model.eval()
    ft_model.train()
    batch = send_to_gpu(batch)
    embedding, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrained_model(batch)
    L_g, _, _ = link_pred_loss(ft_model, embedding, mask_locs[0], edge_index_list)
    L_g.backward()
    opt.step()
    return L_g

def fine_tune(train_dataloader, pretrained_model, ft_model, lr=1e-3, num_epochs=100, max_num_batches=200):
    train_losses = []
    opt = torch.optim.Adam(ft_model.parameters(), lr=1e-3, weight_decay=1e-4)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        
        # Training phase
        train_loss_epoch = 0
        train_batches = len(train_dataloader)
        num_batches = 0
        for batch in  tqdm.tqdm(train_dataloader, desc="Training", leave=False):
            train_loss = fine_tune_train_step(pretrained_model, ft_model, batch, opt)
            train_loss_epoch += train_loss.item()
            num_batches += 1
            if num_batches >= max_num_batches:
                break
        train_loss_epoch /= train_batches
        train_losses.append(train_loss_epoch)
        print(f"Train loss: {train_loss_epoch:.4f}")
    return train_losses

def predict_and_compare(test_dataloader, pretrain1, pretrain2, model1, model2, max_num_batches=100):
    model1.eval().to("cuda")
    model2.eval().to("cuda")
    pretrain1.eval().to("cuda")
    pretrain2.eval().to("cuda")
    
    all_preds1 = []
    all_preds2 = []
    all_labels = []
    n_b = 0
    for batch in tqdm.tqdm(test_dataloader, leave=False):
        batch = send_to_gpu(batch)
        embedding1, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrain1(batch)
        embedding2, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrain2(batch)
        
        L_g1, preds1, labels = link_pred_loss(model1, embedding1, mask_locs[0], edge_index_list)
        L_g2, preds2, labels = link_pred_loss(model2, embedding2, mask_locs[0], edge_index_list)
        
        all_preds1.extend(preds1.cpu().detach().numpy())
        all_preds2.extend(preds2.cpu().detach().numpy())
        all_labels.extend(labels.cpu().detach().numpy())
        
        n_b += 1
        if n_b >= max_num_batches:
            break
    
    # AUROC
    fpr1, tpr1, _ = roc_curve(all_labels, all_preds1)
    fpr2, tpr2, _ = roc_curve(all_labels, all_preds2)
    auc1 = auc(fpr1, tpr1)
    auc2 = auc(fpr2, tpr2)
    
    # PR
    p1, r1, _ = precision_recall_curve(all_labels, all_preds1)
    p2, r2, _ = precision_recall_curve(all_labels, all_preds2)
    apr1 = average_precision_score(all_labels, all_preds1)
    apr2 = average_precision_score(all_labels, all_preds2)
    
    return fpr1, tpr1, auc1, fpr2, tpr2, auc2, p1, r1, apr1, p2, r2, apr2

In [9]:
def auroc_curves(fpr1, tpr1, auc1, fpr2, tpr2, auc2):
    plt.figure(figsize=(8, 6))
    plt.plot(fpr1, tpr1, label=f"fine-tune w. vanilla embedding (AUC = {auc1:.3f})")
    plt.plot(fpr2, tpr2, label=f"fine-tune w. GraphDKA embedding (AUC = {auc2:.3f})")
    plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Fine tuning AUROC, link pred")
    plt.legend(loc="best")
    plt.grid()
    plt.show()

In [10]:
def pr_curves(precision1, recall1, ap1, precision2, recall2, ap2):
    plt.figure(figsize=(8, 6))

    # Plot the first PR curve
    plt.plot(recall1, precision1, label=f"fine-tune w. vanilla embedding (Avg. Prec. = {ap1:.3f})", linestyle='-', marker=None)

    # Plot the second PR curve
    plt.plot(recall2, precision2, label=f"fine-tune w. GraphDKA embedding (Avg. Prec. = {ap2:.3f})", linestyle='--', marker=None)

    # Add labels and legend
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Fine tuning Precision-Recall Curve, link pred")
    plt.legend(loc="best")
    plt.grid()

    # Set axis limits for better visualization
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    # Show the plot
    plt.show()

In [11]:
def random_sample_batches(dataloader, n):
    # Convert the dataloader to a list of batches
    batches = list(dataloader)
    # Randomly sample n batches
    sampled_batches = random.sample(batches, n)
    return sampled_batches

In [12]:
vanilla_lp = LinkPredictHead(256, 1).to("cuda")
gdk_lp = LinkPredictHead(256, 1).to("cuda")

In [15]:
link_predictor_geneformer = LinkPredictHead(256, 1).to("cuda")

In [16]:
link_predictor_scf = LinkPredictHead(512, 1).to("cuda")

In [None]:
link_predictor_scgpt = LinkPredictHead(512, 1).to("cuda")

In [13]:
vanilla_loss = fine_tune(val_sg_data, pretrained_model=vanilla_model, ft_model=vanilla_lp, num_epochs=1, max_num_batches=100)

NameError: name 'vanilla_model' is not defined