In [2]:
import os 
from os.path import join, abspath, dirname
import sys
sys.path.append('..')
from argparse import ArgumentParser

from torch_geometric.utils import negative_sampling
import lightning.pytorch as pl

from scGraphLLM.data import *
from scGraphLLM.GNN_modules import *
from scGraphLLM.MLP_modules import *
from scGraphLLM._globals import *
#from scGraphLLM.flash_transformer import GDTransformer
from scGraphLLM.config import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
import tqdm
import matplotlib.pyplot as plt
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from pathlib import Path
cell_type_dir = "/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type"

def get_subdirs_with_info_json(root_dir):
    root = Path(root_dir)
    return [
        subdir.name 
        for subdir in root.iterdir() 
        if subdir.is_dir() and any(child.name.startswith("info") and child.suffix == ".json" 
                                   for child in subdir.iterdir())
    ]

# Example usage:
root_directory = "/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type"
subdirs = get_subdirs_with_info_json(root_directory)
subdirs

['cd14_monocytes',
 'cd16_monocytes',
 'cd20_b_cells',
 'cd4_t_cells',
 'cd8_t_cells',
 'erythrocytes',
 'monocyte-derived_dendritic_cells',
 'nk_cells',
 'nkt_cells']

In [3]:
import sys
sys.path.append("/hpc/mydata/rowan.cassius/scGraphLLM/embeddings")
from benchmark import GeneEmbeddingDataset

In [4]:
immune_cell_types = [
  "cd14_monocytes",
  "cd16_monocytes",
  "cd20_b_cells",
  "cd4_t_cells",
  "cd8_t_cells",
  "erythrocytes",
  "monocyte-derived_dendritic_cells",
  "nk_cells",
  "nkt_cells"
]

In [5]:
train_cell_types = [
  "cd14_monocytes",
#   "cd16_monocytes",
#   "cd20_b_cells",
#   "cd4_t_cells",
#   "cd8_t_cells",
#   "erythrocytes",
#   "monocyte-derived_dendritic_cells",
#   "nk_cells",
  "nkt_cells"
]

In [6]:
len(immune_cell_types)

9

In [None]:
scgpt_embedding_path = "/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/{}/embeddings/scgpt/embedding.npz"
# train_dataset = GeneEmbeddingDataset(paths=[scgpt_embedding_path.format(cell_type) for cell_type in train_cell_types])

In [None]:
len(train_dataset)

In [7]:
cd14_monocytes_dataset = GeneEmbeddingDataset(
    paths="/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd14_monocytes/embeddings/scgpt/embedding.npz"
)

In [18]:
len(cd14_monocytes_dataset)

1000

In [20]:
nkt_cells_dataset = GeneEmbeddingDataset(
    paths="/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/nkt_cells/embeddings/scgpt/embedding.npz"
)

: 

In [None]:
len(nkt_cells_dataset)

In [16]:
import sys

def format_size(size):
    """Converts the size in bytes to a human-readable format (KB, MB, GB)."""
    if size < 1024:
        return f"{size} B"
    elif size < 1024**2:
        return f"{size / 1024:.2f} KB"
    elif size < 1024**3:
        return f"{size / 1024**2:.2f} MB"
    else:
        return f"{size / 1024**3:.2f} GB"

def get_object_memory_size(obj, seen=None, format=True):
    """Recursively computes the memory usage of an object, accounting for referenced objects."""
    if seen is None:
        seen = set()

    # Check if we've already seen this object (to avoid infinite recursion in case of circular references)
    if id(obj) in seen:
        return 0
    
    # Mark this object as seen
    seen.add(id(obj))

    size = sys.getsizeof(obj)

    # If the object is a container (e.g., list, dict, etc.), recursively check its contents
    if isinstance(obj, dict):
        size += sum([get_object_memory_size(v, seen, format=False) + get_object_memory_size(k, seen, format=False) for k, v in obj.items()])
    elif isinstance(obj, (list, tuple, set)):
        size += sum([get_object_memory_size(i, seen, format=False) for i in obj])

    # If the 'format' flag is True, format the result into a human-readable form
    if format:
        return format_size(size)
    
    return size

# Example usage:
my_list = [i for i in range(1000000)]
print(get_object_memory_size(my_list))

34.76 MB


In [19]:
get_object_memory_size(cd14_monocytes_dataset)

'48 B'

In [None]:
cd14_monocytes_dataset = GeneEmbeddingDataset(
    paths="/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd14_monocytes/embeddings/scgpt/embedding.npz"
)

In [2]:
import scanpy as sc

In [3]:
cd14_monocyte_cells = sc.read_h5ad("/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd14_monocytes/cells.h5ad")

In [4]:
cd14_monocyte_cells

AnnData object with n_obs × n_vars = 6158 × 11971
    obs: 'batch', 'chemistry', 'data_type', 'dpt_pseudotime', 'final_annotation', 'mt_frac', 'n_counts', 'n_genes', 'sample_ID', 'size_factors', 'species', 'study', 'tissue', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'sample_id', 'cluster'
    var: 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'cluster', 'log1p', 'neighbors', 'pca'
    obsm: 'X_pca'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

In [6]:
cd20_b_cells = sc.read_h5ad("/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd20_b_cells/cells.h5ad")

In [8]:
cd4_t_cells = sc.read_h5ad("/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd4_t_cells/cells.h5ad")

In [9]:
cd4_t_cells

AnnData object with n_obs × n_vars = 10460 × 11971
    obs: 'batch', 'chemistry', 'data_type', 'dpt_pseudotime', 'final_annotation', 'mt_frac', 'n_counts', 'n_genes', 'sample_ID', 'size_factors', 'species', 'study', 'tissue', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'sample_id', 'cluster'
    var: 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'cluster', 'log1p', 'neighbors', 'pca'
    obsm: 'X_pca'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

In [18]:
cd16_monocytes_dataset = GeneEmbeddingDataset(
    paths="/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/cd16_monocytes/embeddings/scgpt/embedding.npz"
)

In [None]:
len(cd16_monocytes_dataset)

972

In [20]:
nkt_cells_dataset = GeneEmbeddingDataset(
    paths="/hpc/mydata/rowan.cassius/data/scGPT/human_immune/cell_type/nkt_cells/embeddings/scgpt/embedding.npz"
)

In [None]:
len(nkt_cells_dataset)

In [None]:
for i in dataset:
    
    print(f"x.shape: {i['x'].shape}")
    print(f"seq_length: {i['seq_lengths']}")
    # print(i['edges'])
    print(f"edges.shape: {i['edges'].shape}")


In [None]:
len(dataset)

In [None]:
def predict_and_compare(test_dataloader, pretrain1, pretrain2, model1, model2, max_num_batches=100):
    model1.eval().to("cuda")
    model2.eval().to("cuda")
    pretrain1.eval().to("cuda")
    pretrain2.eval().to("cuda")
    
    all_preds1 = []
    all_preds2 = []
    all_labels = []
    n_b = 0
    for batch in tqdm.tqdm(test_dataloader, leave=False):
        batch = send_to_gpu(batch)
        embedding1, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrain1(batch)
        embedding2, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrain2(batch)
        
        L_g1, preds1, labels = link_pred_loss(model1, embedding1, mask_locs[0], edge_index_list)
        L_g2, preds2, labels = link_pred_loss(model2, embedding2, mask_locs[0], edge_index_list)
        
        all_preds1.extend(preds1.cpu().detach().numpy())
        all_preds2.extend(preds2.cpu().detach().numpy())
        all_labels.extend(labels.cpu().detach().numpy())
        
        n_b += 1
        if n_b >= max_num_batches:
            break
    
    # AUROC
    fpr1, tpr1, _ = roc_curve(all_labels, all_preds1)
    fpr2, tpr2, _ = roc_curve(all_labels, all_preds2)
    auc1 = auc(fpr1, tpr1)
    auc2 = auc(fpr2, tpr2)
    
    # PR
    p1, r1, _ = precision_recall_curve(all_labels, all_preds1)
    p2, r2, _ = precision_recall_curve(all_labels, all_preds2)
    apr1 = average_precision_score(all_labels, all_preds1)
    apr2 = average_precision_score(all_labels, all_preds2)
    
    return fpr1, tpr1, auc1, fpr2, tpr2, auc2, p1, r1, apr1, p2, r2, apr2

In [None]:
def predict_and_compare(test_dataloader, pretrain1, pretrain2, model1, model2, max_num_batches=100):
    model1.eval().to("cuda")
    model2.eval().to("cuda")
    pretrain1.eval().to("cuda")
    pretrain2.eval().to("cuda")
    
    all_preds1 = []
    all_preds2 = []
    all_labels = []
    n_b = 0
    for batch in tqdm.tqdm(test_dataloader, leave=False):
        batch = send_to_gpu(batch)
        embedding1, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrain1(batch)
        embedding2, target_gene_ids, target_rank_ids, mask_locs, edge_index_list, num_nodes_list = pretrain2(batch)
        
        L_g1, preds1, labels = link_pred_loss(model1, embedding1, mask_locs[0], edge_index_list)
        L_g2, preds2, labels = link_pred_loss(model2, embedding2, mask_locs[0], edge_index_list)
        
        all_preds1.extend(preds1.cpu().detach().numpy())
        all_preds2.extend(preds2.cpu().detach().numpy())
        all_labels.extend(labels.cpu().detach().numpy())
        
        n_b += 1
        if n_b >= max_num_batches:
            break
    
    # AUROC
    fpr1, tpr1, _ = roc_curve(all_labels, all_preds1)
    fpr2, tpr2, _ = roc_curve(all_labels, all_preds2)
    auc1 = auc(fpr1, tpr1)
    auc2 = auc(fpr2, tpr2)
    
    # PR
    p1, r1, _ = precision_recall_curve(all_labels, all_preds1)
    p2, r2, _ = precision_recall_curve(all_labels, all_preds2)
    apr1 = average_precision_score(all_labels, all_preds1)
    apr2 = average_precision_score(all_labels, all_preds2)
    
    return fpr1, tpr1, auc1, fpr2, tpr2, auc2, p1, r1, apr1, p2, r2, apr2

In [None]:
def auroc_curves(fpr1, tpr1, auc1, fpr2, tpr2, auc2):
    plt.figure(figsize=(8, 6))
    plt.plot(fpr1, tpr1, label=f"fine-tune w. vanilla embedding (AUC = {auc1:.3f})")
    plt.plot(fpr2, tpr2, label=f"fine-tune w. GraphDKA embedding (AUC = {auc2:.3f})")
    plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Fine tuning AUROC, link pred")
    plt.legend(loc="best")
    plt.grid()
    plt.show()

In [None]:
def pr_curves(precision1, recall1, ap1, precision2, recall2, ap2):
    plt.figure(figsize=(8, 6))

    # Plot the first PR curve
    plt.plot(recall1, precision1, label=f"fine-tune w. vanilla embedding (Avg. Prec. = {ap1:.3f})", linestyle='-', marker=None)

    # Plot the second PR curve
    plt.plot(recall2, precision2, label=f"fine-tune w. GraphDKA embedding (Avg. Prec. = {ap2:.3f})", linestyle='--', marker=None)

    # Add labels and legend
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Fine tuning Precision-Recall Curve, link pred")
    plt.legend(loc="best")
    plt.grid()

    # Set axis limits for better visualization
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    # Show the plot
    plt.show()

In [None]:
def random_sample_batches(dataloader, n):
    # Convert the dataloader to a list of batches
    batches = list(dataloader)
    # Randomly sample n batches
    sampled_batches = random.sample(batches, n)
    return sampled_batches

In [None]:
vanilla_lp = LinkPredictHead(256, 1).to("cuda")
gdk_lp = LinkPredictHead(256, 1).to("cuda")

In [None]:
link_predictor_geneformer = LinkPredictHead(256, 1).to("cuda")

In [None]:
link_predictor_scf = LinkPredictHead(512, 1).to("cuda")

In [None]:
link_predictor_scgpt = LinkPredictHead(512, 1).to("cuda")

In [None]:
vanilla_loss = fine_tune(val_sg_data, pretrained_model=vanilla_model, ft_model=vanilla_lp, num_epochs=1, max_num_batches=100)