In [1]:
import sys
sys.path.append('..')

import pickle
import os
from tqdm import tqdm
from importlib import reload

import scanpy as sc

from cell2gsea import models, gnn_config
from cell2gsea.utils import *

from sklearn.model_selection import train_test_split

from torch_geometric.data import Data
from torch_geometric.loader import NodeLoader
from torch_geometric.sampler.neighbor_sampler import NeighborSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the cell2gsea GNN configuration object; later cells read its
# KNN_GRAPH_* and GSAMP_* attributes for graph construction and sampling.
conf = gnn_config()

In [3]:
# load model and dataset --> umap gene set scores according to cell type?
# or maybe plot top N most active gene sets across cell type?

# Every artifact of one dataset lives under a single directory, so switching
# datasets only means editing the two constants below.
DATASET_DIR = "/home/ddz5/scratch/Cell2GSEA_QA_dataset_models/finished_datasets/local(192)"
RUN_NAME = "local(192)_2024_12_27__20_40_56"
# Alternative dataset:
# DATASET_DIR = "/home/ddz5/scratch/Cell2GSEA_QA_dataset_models/sixth_set/local(638)"
# RUN_NAME = "local(638)_2025_01_18__15_56_28"

# Checkpoint written by the training run, plus the pickled training inputs/outputs.
model_path = os.path.join(DATASET_DIR, RUN_NAME, "best_model_checkpoint.pt")
saved_input_path = os.path.join(DATASET_DIR, "training_inputs.pickle")
saved_output_path = os.path.join(DATASET_DIR, "training_output.pickle")

In [4]:
# Load the pickled training inputs (a dict; its keys are listed in the next cell).
# NOTE(review): pickle.load can execute arbitrary code — only open files from a trusted source.
with open(saved_input_path, 'rb') as f:
    inputs = pickle.load(f)

In [5]:
# Show which fields the saved training inputs contain.
inputs.keys()

dict_keys(['input_gene_set_csv_path', 'input_h5ad_file_path', 'gene_set_names', 'gene_set_sizes', 'gene_names', 'cell_ids', 'extra', 'train_x_sparse', 'train_edges', 'progs_sparse', 'prog_clusters'])

In [6]:
# Densify the sparse program matrix and record its dimensions
# (rows -> num_prog, columns -> num_genes).
progs = inputs['progs_sparse'].toarray()
num_prog, num_genes = progs.shape

In [7]:
# Dense copy of the sparse training matrix `train_x_sparse`; its row count is n_cell.
# NOTE(review): densifying can be memory-heavy for large matrices.
input_train_x = inputs['train_x_sparse'].toarray()
n_cell, _ = input_train_x.shape

In [8]:
# Saved training-graph edges as a NumPy array.
# NOTE(review): `np` is presumably provided by the cell2gsea.utils star import — confirm.
edge_list = np.array(inputs['train_edges'])

In [26]:
# with open(saved_output_path, 'rb') as f:
#     outputs = pickle.load(f) 

In [10]:
# outputs['output'].keys()

dict_keys(['training_num_rep', 'training_cell_r2', 'training_gene_r2', 'training_source_epoch', 'training_assigned_prog_scores', 'validation_num_rep', 'validation_cell_r2', 'validation_gene_r2', 'validation_source_epoch', 'validation_assigned_prog_scores'])

### Create data loader to see number of samples (testing)

In [25]:
# Hold out 20% of the cell indices for validation; the fixed seed keeps the
# split reproducible.
train_indices, val_indices = train_test_split(np.arange(n_cell), test_size=0.2, random_state=42)

In [37]:
# Slice the dense matrix into training and validation rows.
train_x, val_x = input_train_x[train_indices], input_train_x[val_indices]

# Number of training rows.
n_train, _ = train_x.shape

In [None]:
# set up train loader -- I want to see why these datasets are not training well
# Build KNN edges over the training subset only, passing conf.KNN_GRAPH_K and
# conf.KNN_GRAPH_N_PCA (presumably neighbour count and number of PCA components — confirm).
# NOTE(review): get_knn_edges is presumably provided by the cell2gsea.utils star import.
train_edges = get_knn_edges(train_x, conf.KNN_GRAPH_K, conf.KNN_GRAPH_N_PCA)

{'cluster_500': array([113,  37,  37, 290]),
 'cluster_475': array([799,  97,  97,  73]),
 'cluster_450': array([662, 206, 206, 135]),
 'cluster_425': array([856, 366, 366,  98]),
 'cluster_400': array([722, 473, 473, 114]),
 'cluster_375': array([732, 531, 531, 178]),
 'cluster_350': array([736, 561, 561, 267]),
 'cluster_325': array([737, 579, 579, 327]),
 'cluster_300': array([741, 582, 582, 358]),
 'cluster_275': array([741, 579, 579, 378]),
 'cluster_250': array([737, 577, 577, 382]),
 'cluster_225': array([739, 577, 577, 382]),
 'cluster_200': array([750, 576, 576, 382]),
 'cluster_175': array([739, 579, 579, 382]),
 'cluster_150': array([739, 581, 581, 383]),
 'cluster_125': array([737, 578, 578, 386]),
 'cluster_100': array([443, 505, 505, 374]),
 'cluster_75': array([432, 491, 491, 381]),
 'cluster_50': array([459, 442, 442, 388]),
 'cluster_25': array([ 74,  74,  74, 387]),
 'cluster_10': array([ 54,  54,  54, 101]),
 'cluster_1': array([5, 5, 5, 5])}

In [38]:
# Placeholder all-zero labels (one per training row) and the KNN edges as an array.
train_tru_labels = np.zeros(n_train)
train_edge_list = np.array(train_edges)

In [39]:
# Convert the training arrays to tensors and bundle them into one graph Data object.
train_X = torch.tensor(train_x, dtype=torch.float32)
train_labels = torch.tensor(train_tru_labels, dtype=torch.long)
train_edge_list = torch.tensor(train_edge_list, dtype=torch.long)
# Column vector 0..n_train-1 stored as `pos`, i.e. each node's row index.
train_node_pos = torch.arange(n_train).unsqueeze(1)
train_data = Data(
    x=train_X,
    edge_index=train_edge_list,
    y=train_labels,
    pos=train_node_pos,
)

In [None]:
# Two-hop neighbour sampler over the training graph; both hops use the same
# fan-out, conf.GSAMP_NUM_NBR.
train_neighbor_sampler = NeighborSampler(train_data, num_neighbors=[conf.GSAMP_NUM_NBR] * 2)

In [None]:
# Mini-batch loader over the training graph, driven by the neighbour sampler.
train_loader = NodeLoader(
    train_data,
    node_sampler=train_neighbor_sampler,
    num_workers=conf.GSAMP_NUM_WORKER,
    batch_size=conf.GSAMP_BATCH_SIZE,  # number of seed nodes
)

### Inference (continued)

In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
# Instantiate the GNN through cell2gsea's factory, sized by the gene/program
# counts taken from the program matrix shape, then move it to the selected device.
trained_model = models.get_gnn_model(num_genes, num_prog, conf).to(device)

In [13]:
# Display the module repr to inspect the layer layout.
trained_model

gene_program_model_gcn_nonneg(
  (conv1): GCNConv(17635, 128)
  (act1): ReLU()
  (drop1): Dropout(p=0.5, inplace=False)
  (conv2): GCNConv(128, 128)
  (fc1): Linear(in_features=128, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=31709, bias=True)
)

In [14]:
# Restore the best checkpoint into the freshly built model.
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved on a GPU still loads when this notebook runs on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=device)
# Left as the last expression so the key-matching report is displayed.
trained_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
# Full-dataset graph for inference: every row of the dense matrix plus the saved edges.
inference_edge_list = torch.tensor(edge_list, dtype=torch.long)
inference_X = torch.tensor(input_train_x, dtype=torch.float32)
inference_data = Data(edge_index=inference_edge_list, x=inference_X)

In [16]:
# Two-hop neighbour sampler over the inference graph; both hops use the same
# fan-out, conf.GSAMP_NUM_NBR.
inference_neighbor_sampler = NeighborSampler(inference_data, num_neighbors=[conf.GSAMP_NUM_NBR] * 2)

In [17]:
# Mini-batch loader over the inference graph, driven by the neighbour sampler.
inference_loader = NodeLoader(
    inference_data,
    node_sampler=inference_neighbor_sampler,
    num_workers=conf.GSAMP_NUM_WORKER,
    batch_size=conf.GSAMP_BATCH_SIZE,  # number of seed nodes
)

In [18]:
# Run the trained model over every node of the inference graph and collect the
# per-cell program scores into `output_scores` (shape: n_cell x n_progs).
n_cell, n_genes = input_train_x.shape

# Prefer CUDA, then Apple MPS, then CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
log_str(f"Using device: {device_name}")
device = torch.device(device_name)

model = trained_model.to(device)
model.eval()  # evaluation mode: the model contains a Dropout layer

# One-node dummy forward pass (a single self-loop edge 0 -> 0), used only to
# discover the width of the model output. No gradients are needed for it.
with torch.no_grad():
    dummy_x = torch.zeros(1, n_genes, device=device)
    dummy_edges = torch.zeros(2, 1, dtype=torch.long, device=device)
    dummy_out = model(dummy_x, dummy_edges)
_, n_progs = dummy_out.shape
output_scores = np.zeros((n_cell, n_progs))

log_str(f"Running inference: {n_cell} cells - {n_genes} genes - {n_progs} programs")
with torch.no_grad():
    # Wrapping the loader itself lets tqdm read its length and show real progress.
    for subgraph in tqdm(inference_loader):
        n_nodes, _ = subgraph.x.shape
        # Indices of this batch's seed nodes; read before moving the batch to
        # `device` so they can index the NumPy output array.
        original_indices = subgraph.input_id
        subgraph = subgraph.to(device)
        # Explicit self-loop edges (i -> i) for every node in the sampled subgraph.
        node_ids = torch.arange(n_nodes, dtype=torch.long, device=device)
        batch_identity_adj = torch.stack([node_ids, node_ids])
        edges = torch.cat([batch_identity_adj, subgraph.edge_index], dim=1)
        batch_program_scores = model(subgraph.x, edges)
        # Keep only the leading rows; this relies on the loader placing the seed
        # nodes first in each sampled subgraph.
        seed_nodes_prog_scores = batch_program_scores[:len(original_indices), :].detach().cpu().numpy()
        output_scores[original_indices, :] = seed_nodes_prog_scores

[2025_01_25__18_13_36] Using device: cuda
[2025_01_25__18_13_36] Running inference: 68458 cells - 17635 genes - 31709 programs


6846it [01:22, 83.32it/s]


In [19]:
# Inspect the score matrix (NumPy abbreviates the repr of large arrays).
output_scores

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 1.05615652, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 1.68741798, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [21]:
# saved output_scores
save_path = "/home/ddz5/scratch/Cell2GSEA_QA_dataset_models/finished_datasets/local(192)/output_scores.pickle"
# save_path = "/home/ddz5/scratch/Cell2GSEA_QA_dataset_models/sixth_set/local(638)/output_scores.pickle"

with open(save_path, 'wb') as f:
    pickle.dump(output_scores, f)

In [22]:
# Count the gene-set names; this should match the number of output programs.
len(inputs['gene_set_names'])

31709

In [23]:
# Show the .h5ad path recorded in the training inputs.
inputs['input_h5ad_file_path']

'/home/sr2464/scratch/C2S_Files/Cell2Sentence_Datasets/hca_cellxgene_cleaned_h5ad/local(192)_cleaned.h5ad'

In [24]:
# Read the AnnData object from the recorded .h5ad path with scanpy.
adata = sc.read_h5ad(inputs['input_h5ad_file_path'])

In [25]:
import pandas as pd
import numpy as np

# Rank gene programs by mean activity within each (disease, cell type) group,
# print the top ones, and save the ranking as a CSV.
#
# output_scores is a numpy array with shape (n_cells, n_gene_programs);
# inputs['gene_set_names'] names its columns.
# NOTE(review): rows of output_scores are assumed to be in the same order as
# adata.obs — confirm against inputs['cell_ids'] before trusting the grouping.

TOP_N = 5  # number of gene programs reported per group

# The CSV destination must be defined before use; previously this assignment
# was commented out, so to_csv() failed with NameError on a fresh kernel.
csv_path = "/home/ddz5/scratch/Cell2GSEA_QA_dataset_models/finished_datasets/local(192)/top_gene_programs.csv"  # Replace with your desired path

gene_program_names = inputs['gene_set_names']
cell_types = adata.obs['cell_type']
diseases = adata.obs['disease']

# One row per cell: program scores plus the two annotation columns.
# `.values` drops the obs index so assignment is positional.
output_scores_df = pd.DataFrame(output_scores, columns=gene_program_names)
output_scores_df['cell_type'] = cell_types.values
output_scores_df['disease'] = diseases.values

# observed=True restricts categorical groupers to combinations that actually
# occur, so no empty (all-NaN) groups are produced.
grouped = output_scores_df.groupby(['disease', 'cell_type'], observed=True)

# Records for the CSV: one row per (group, top program).
results_list = []

for (disease, cell_type), group in grouped:
    # Mean score per program within the group, ranked once.
    top_programs = group[gene_program_names].mean().nlargest(TOP_N)

    print(f"Disease: {disease}, Cell type: {cell_type}")
    print(f"Top {TOP_N} gene programs:")
    for program, score in top_programs.items():
        print(f"  {program}: {score:.3f}")
        results_list.append({
            'Disease': disease,
            'Cell Type': cell_type,
            'Gene Program': program,
            'Mean Activity Score': score
        })

# Save the results to a CSV file
results_df = pd.DataFrame(results_list)
results_df.to_csv(csv_path, index=False)

print(f"\nResults saved to {csv_path}")

Disease: normal, Cell type: B cell
Top 5 gene programs:
  CEBPZ_TARGET_GENES: 108.484
  BARX2_TARGET_GENES: 108.316
  SFMBT1_TARGET_GENES: 100.423
  DODD_NASOPHARYNGEAL_CARCINOMA_UP: 98.542
  AEBP2_TARGET_GENES: 92.064
Disease: normal, Cell type: T cell
Top 5 gene programs:
  BARX2_TARGET_GENES: 110.151
  CEBPZ_TARGET_GENES: 109.900
  SFMBT1_TARGET_GENES: 102.248
  DODD_NASOPHARYNGEAL_CARCINOMA_UP: 100.325
  AEBP2_TARGET_GENES: 93.351
Disease: normal, Cell type: endothelial cell
Top 5 gene programs:
  BARX2_TARGET_GENES: 94.229
  CEBPZ_TARGET_GENES: 93.957
  CUI_TCF21_TARGETS_2_DN: 89.741
  SFMBT1_TARGET_GENES: 88.157
  DODD_NASOPHARYNGEAL_CARCINOMA_UP: 87.383
Disease: normal, Cell type: epithelial cell of proximal tubule
Top 5 gene programs:
  BARX2_TARGET_GENES: 160.901
  CEBPZ_TARGET_GENES: 155.391
  SFMBT1_TARGET_GENES: 149.729
  DODD_NASOPHARYNGEAL_CARCINOMA_UP: 145.281
  FEV_TARGET_GENES: 135.887
Disease: normal, Cell type: fibroblast
Top 5 gene programs:
  BARX2_TARGET_GENES: 11