### Example Dataset and Resource Requirements

The example data is a downsampled version of the real dataset and is intended solely to demonstrate the model's workflow.

- **Memory Requirements**:
  - At least **64 GB** of RAM is recommended for smooth data processing and model training.
  - In memory-limited environments, you can lower the memory requirement by downsizing the data, for example by choosing a smaller `cell_size` or a smaller number of neighbors (`neighbor`). Both options shrink each subgraph.

In [1]:
import os
import time
import math
import gzip
import re
import numpy as np
import pandas as pd
import EmitGCL 
from EmitGCL.utils import *
from EmitGCL.loss_function import *
from EmitGCL.conv import *
from EmitGCL.emitgcl_model import *

In [2]:
# KEGG pathways (human, "hsa") queried by default: display name -> pathway id.
CANCER_METASTASIS_PATHWAYS = {
    'Protein processing in endoplasmic reticulum': 'hsa04141',
    'mTOR signaling pathway': 'hsa04150',
    'NF-kappa B signaling pathway': 'hsa04064',
    'Autophagy': 'hsa04140',
    'p53 signaling pathway': 'hsa04115',
    'Apoptosis': 'hsa04210',
}


def get_cancer_metastasis_genes(pathways=None, kegg=None):
    """Fetch the gene symbols belonging to a set of KEGG pathways.

    Parameters
    ----------
    pathways : dict[str, str], optional
        Mapping of pathway display name -> KEGG pathway id (e.g. ``'hsa04115'``).
        Defaults to ``CANCER_METASTASIS_PATHWAYS``.
    kegg : object, optional
        KEGG client exposing ``get(pathway_id)`` and ``parse(raw_entry)``.
        Defaults to a new ``KEGG()`` instance. That name presumably refers to
        ``bioservices.KEGG`` brought into scope by the star imports of the first
        cell (TODO confirm). Pass a client explicitly to reuse it or to work offline.

    Returns
    -------
    dict[str, list[str]]
        Pathway display name -> gene symbols. Order follows ``pathways`` and the
        order of the entry's ``GENE`` section. Pathways with no ``GENE`` section
        map to an empty list.

    Raises
    ------
    RuntimeError
        If a KEGG entry does not parse into a dict (e.g. the request failed).
    """
    if pathways is None:
        pathways = CANCER_METASTASIS_PATHWAYS
    if kegg is None:
        kegg = KEGG()

    pathway_genes = {}
    for pathway_name, pathway_id in pathways.items():
        parsed_pathway = kegg.parse(kegg.get(pathway_id))
        if not isinstance(parsed_pathway, dict):
            raise RuntimeError(
                f"Could not retrieve/parse KEGG pathway {pathway_id!r} ({pathway_name})"
            )

        # Assumes GENE values look like "SYMBOL; description ..."; keep only the
        # leading symbol (text before the first space, then before any ';').
        genes = parsed_pathway.get('GENE') or {}
        pathway_genes[pathway_name] = [
            gene_info.split(' ')[0].split(';')[0] for gene_info in genes.values()
        ]

    return pathway_genes

# Query KEGG once here (network access required); the training cell passes
# `pathway_genes` to the EmitGCL model.
pathway_genes = get_cancer_metastasis_genes()

### Data Preprocessing

- Construct separate subgraphs for the primary-site (Tumor) and metastatic-site (Lymph Node) samples.

In [3]:
import numpy as np
import pandas as pd
from scipy.sparse import vstack, csr_matrix
# Main execution: load the combined AnnData object, build the genes x cells
# expression matrix, run an initial Scanpy clustering and partition the cells
# into subgraphs by sample type. Every name defined here is used by later cells.
ignore_warnings()
set_random_seed()

input_directory = '../Data'
default_output_dir = os.path.join(input_directory, 'Result', '')

# Parse the training hyper-parameters for the gene-cell graph GNN; modify the
# defaults as needed.
(output_file, labsm, lr, wd, n_hid, nheads, nlayers,
 cell_size, neighbor, egrn) = parse_arguments(egrn=True, output_file=default_output_dir)

# Use the parsed output directory, falling back to the default if it is empty.
# Normalise it to end with a path separator, because file names may be built
# from it by plain string concatenation.
output_file = os.path.join(output_file or default_output_dir, '')
attention_file = os.path.join(output_file, 'Attention', '')

ensure_dir(output_file)
ensure_dir(attention_file)

# Load the saved AnnData object.
input_file = os.path.join(input_directory, 'combined_data.h5ad')
adata_combined = sc.read(input_file)

gene_names = adata_combined.var_names
cell_names = adata_combined.obs_names
labels = adata_combined.obs['label']

# Expression matrix as stored by AnnData: rows are cells (obs), columns are genes (var).
cell_counts_matrix = adata_combined.X
cell_counts_matrix_csr = csr_matrix(cell_counts_matrix)

# Transpose to genes x cells: genes become the rows and cells the columns
# (hence gene_num = shape[0] and cell_num = shape[1] below).
gene_cell = cell_counts_matrix_csr.transpose()

# Attach the names as plain attributes on the sparse matrix. After the
# transpose, `obs_names` (cells) label the COLUMNS and `var_names` (genes)
# label the ROWS. Presumably read by EmitGCL helpers downstream (TODO confirm).
gene_cell.obs_names = adata_combined.obs_names
gene_cell.var_names = adata_combined.var_names

RNA_matrix = gene_cell

cell_num = RNA_matrix.shape[1]
gene_num = RNA_matrix.shape[0]

print(f"Cell number: {cell_num}")
print(f"Gene number: {gene_num}")

# The workflow runs on the CPU by default. Set USE_GPU = True to use CUDA
# when it is available.
USE_GPU = False
device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")

print('You will use : ',device)

# Initial Scanpy clustering on a copy of the matrix. The integer labels
# (ini_p1) are passed to both models in the training cell.
initial_pre = initial_clustering(RNA_matrix.copy())
cluster_ini_num = len(set(initial_pre))
ini_p1 = [int(i) for i in initial_pre]

# Partition the cells into subgraphs by sample type (`labels`). `neighbor`
# and `cell_size` control the subgraph size (see the resource notes at the top).
indices, Node_Ids, dic = batch_select_whole(RNA_matrix, labels, neighbor=[neighbor], cell_size=cell_size)
# Sample type (label) of each node, in Node_Ids order.
sample_type = list(np.array(labels)[Node_Ids])

Cell number: 832
Gene number: 36601
You will use :  cpu
	When the number of cells is less than or equal to 500, it is recommended to set the resolution value to 0.2.
	When the number of cells is within the range of 500 to 5000, the resolution value should be set to 0.5.
	When the number of cells is greater than 5000, the resolution value should be set to 0.8.
         Falling back to preprocessing with `sc.pp.pca` and default params.
Partitioning data into batches based on sample type.


Processing Tumor samples: 100%|██████████| 8/8 [00:00<00:00, 24.37it/s]
Processing Lymph Node samples: 100%|██████████| 21/21 [00:00<00:00, 27.71it/s]


### Model Training

- Train the `NodeDimensionReduction` and `EmitGCL` models. Then obtain the cell clustering results, the cell embeddings and the cell–gene attention matrices.

In [4]:
# Epoch counts used for this demo run.
NODE_DR_EPOCHS = 10   # NodeDimensionReduction pre-training
EMITGCL_EPOCHS = 20   # EmitGCL training

# One batch per subgraph returned by batch_select_whole.
n_batch = len(indices)

# Reduce the dimensionality of the node features of the gene-cell graph.
# num_types=2 presumably means gene and cell nodes (TODO confirm).
node_model = NodeDimensionReduction(
    RNA_matrix, indices, ini_p1,
    n_hid=n_hid, n_heads=nheads, n_layers=nlayers,
    labsm=labsm, lr=lr, wd=wd, device=device,
    num_types=2, num_relations=1, epochs=NODE_DR_EPOCHS,
)
gnn = node_model.train_model(n_batch=n_batch)

# Train the EmitGCL model on top of the pre-trained GNN, using the KEGG
# pathway gene sets and the per-node sample types.
EmitGCL_model = EmitGCL(
    gnn=gnn, labsm=labsm, n_hid=n_hid, n_batch=n_batch, device=device, lr=lr, wd=wd,
    pathway_genes=pathway_genes, gene_names=gene_names, sample_type=sample_type,
    num_epochs=EMITGCL_EPOCHS,
)
EmitGCL_gnn = EmitGCL_model.train_model(
    indices=indices, RNA_matrix=RNA_matrix, ini_p1=ini_p1,
    sample_type=sample_type, nodes_id=Node_Ids,
)

# Predict cluster labels, cell embeddings and cell-gene attention weights.
EmitGCL_result = EmitGCL_pred(
    RNA_matrix, EmitGCL_gnn=EmitGCL_gnn, indices=indices, nheads=nheads,
    nodes_id=Node_Ids, cell_size=cell_size, device=device,
    gene_names=gene_names, node_dim_reduction_model=node_model,
)

The training process for the NodeDimensionReduction model has started. Please wait.


 10%|█         | 1/10 [01:14<11:14, 74.99s/it]

Epoch 1:
  KL Loss: 0.28742086835976305
  Cluster Loss: 3.684804603971284


 20%|██        | 2/10 [02:48<11:27, 85.95s/it]

Epoch 2:
  KL Loss: 0.2490111419866825
  Cluster Loss: 3.5531575433139144


 30%|███       | 3/10 [04:24<10:32, 90.39s/it]

Epoch 3:
  KL Loss: 0.2445886684902783
  Cluster Loss: 3.3537188316213675


 40%|████      | 4/10 [05:40<08:29, 84.86s/it]

Epoch 4:
  KL Loss: 0.2487232726195763
  Cluster Loss: 3.2271730653170883


 50%|█████     | 5/10 [07:19<07:29, 89.95s/it]

Epoch 5:
  KL Loss: 0.22498767900055852
  Cluster Loss: 3.281258591290178


 60%|██████    | 6/10 [08:53<06:05, 91.42s/it]

Epoch 6:
  KL Loss: 0.2523061340225154
  Cluster Loss: 3.1361886139573723


 70%|███████   | 7/10 [10:29<04:38, 92.85s/it]

Epoch 7:
  KL Loss: 0.2215259445124659
  Cluster Loss: 2.9959375200600458


 80%|████████  | 8/10 [11:58<03:03, 91.66s/it]

Epoch 8:
  KL Loss: 0.2106807298701385
  Cluster Loss: 2.916520825747786


 90%|█████████ | 9/10 [13:39<01:34, 94.39s/it]

Epoch 9:
  KL Loss: 0.19879213592101788
  Cluster Loss: 2.812184983286364


100%|██████████| 10/10 [15:12<00:00, 91.29s/it]


Epoch 10:
  KL Loss: 0.188623810122753
  Cluster Loss: 2.8423751140462943
The training for the NodeDimensionReduction model has been completed.
The training process for the EmitGCL model has started. Please wait.


100%|██████████| 29/29 [00:37<00:00,  1.28s/it]


Epoch 1:
  KL Loss: 1.8963366746902466
  Cluster Loss: 2.6857054233551025
  UCell Loss: -2.786606788635254
  Total Contrastive Loss: 0.9229361414909363
  Total Loss: 2.718371629714966


100%|██████████| 29/29 [00:44<00:00,  1.52s/it]


Epoch 2:
  KL Loss: 1.723281741142273
  Cluster Loss: 2.678626298904419
  UCell Loss: -2.205834150314331
  Total Contrastive Loss: 1.0022661685943604
  Total Loss: 3.1983399391174316


100%|██████████| 29/29 [00:37<00:00,  1.29s/it]


Epoch 3:
  KL Loss: 1.4433364868164062
  Cluster Loss: 2.750516176223755
  UCell Loss: -2.448504686355591
  Total Contrastive Loss: 0.9908048510551453
  Total Loss: 2.7361526489257812


100%|██████████| 29/29 [00:36<00:00,  1.27s/it]


Epoch 4:
  KL Loss: 1.336318016052246
  Cluster Loss: 2.7405014038085938
  UCell Loss: -2.2460124492645264
  Total Contrastive Loss: 0.0
  Total Loss: 1.8308069705963135


100%|██████████| 29/29 [00:42<00:00,  1.47s/it]


Epoch 5:
  KL Loss: 1.2646753787994385
  Cluster Loss: 2.7066493034362793
  UCell Loss: -2.565028667449951
  Total Contrastive Loss: 0.0
  Total Loss: 1.4062960147857666


100%|██████████| 29/29 [00:40<00:00,  1.40s/it]


Epoch 6:
  KL Loss: 1.2136815786361694
  Cluster Loss: 2.851593017578125
  UCell Loss: -2.8228402137756348
  Total Contrastive Loss: 0.0
  Total Loss: 1.2424345016479492


100%|██████████| 29/29 [00:39<00:00,  1.36s/it]


Epoch 7:
  KL Loss: 1.1195231676101685
  Cluster Loss: 2.8164398670196533
  UCell Loss: -2.910787582397461
  Total Contrastive Loss: 0.0
  Total Loss: 1.0251755714416504


100%|██████████| 29/29 [00:40<00:00,  1.41s/it]


Epoch 8:
  KL Loss: 1.111269235610962
  Cluster Loss: 2.8074893951416016
  UCell Loss: -3.57694411277771
  Total Contrastive Loss: 0.0
  Total Loss: 0.3418145179748535


100%|██████████| 29/29 [00:38<00:00,  1.33s/it]


Epoch 9:
  KL Loss: 1.0458542108535767
  Cluster Loss: 2.8296327590942383
  UCell Loss: -1.8149651288986206
  Total Contrastive Loss: 0.0
  Total Loss: 2.0605216026306152


100%|██████████| 29/29 [00:42<00:00,  1.48s/it]


Epoch 10:
  KL Loss: 0.9917914271354675
  Cluster Loss: 2.8613200187683105
  UCell Loss: -2.9914090633392334
  Total Contrastive Loss: 0.9031197428703308
  Total Loss: 1.764822244644165


100%|██████████| 29/29 [00:40<00:00,  1.41s/it]


Epoch 11:
  KL Loss: 0.9587863087654114
  Cluster Loss: 2.812537908554077
  UCell Loss: -3.0062460899353027
  Total Contrastive Loss: 0.8788116574287415
  Total Loss: 1.6438896656036377


100%|██████████| 29/29 [00:30<00:00,  1.06s/it]


Epoch 12:
  KL Loss: 0.9186223149299622
  Cluster Loss: 2.8220739364624023
  UCell Loss: -2.953612804412842
  Total Contrastive Loss: 0.0
  Total Loss: 0.7870833873748779


100%|██████████| 29/29 [00:31<00:00,  1.08s/it]


Epoch 13:
  KL Loss: 0.9058510065078735
  Cluster Loss: 2.8103415966033936
  UCell Loss: -2.765408515930176
  Total Contrastive Loss: 0.9047013521194458
  Total Loss: 1.8554855585098267


100%|██████████| 29/29 [00:44<00:00,  1.52s/it]


Epoch 14:
  KL Loss: 0.8911846876144409
  Cluster Loss: 2.8558287620544434
  UCell Loss: -3.59956431388855
  Total Contrastive Loss: 0.0
  Total Loss: 0.14744925498962402


100%|██████████| 29/29 [00:41<00:00,  1.43s/it]


Epoch 15:
  KL Loss: 0.8780050277709961
  Cluster Loss: 2.8699851036071777
  UCell Loss: -3.0269453525543213
  Total Contrastive Loss: 0.0
  Total Loss: 0.7210447788238525


100%|██████████| 29/29 [00:35<00:00,  1.21s/it]


Epoch 16:
  KL Loss: 0.8466627597808838
  Cluster Loss: 2.808638095855713
  UCell Loss: -2.7214972972869873
  Total Contrastive Loss: 1.012995958328247
  Total Loss: 1.9467995166778564


100%|██████████| 29/29 [00:40<00:00,  1.41s/it]


Epoch 17:
  KL Loss: 0.8190692663192749
  Cluster Loss: 2.810232162475586
  UCell Loss: -3.952270984649658
  Total Contrastive Loss: 0.0
  Total Loss: -0.3229694366455078


100%|██████████| 29/29 [00:38<00:00,  1.31s/it]


Epoch 18:
  KL Loss: 0.8105126023292542
  Cluster Loss: 2.8065614700317383
  UCell Loss: -2.5836429595947266
  Total Contrastive Loss: 0.9922962188720703
  Total Loss: 2.0257272720336914


100%|██████████| 29/29 [00:50<00:00,  1.73s/it]


Epoch 19:
  KL Loss: 0.7860991954803467
  Cluster Loss: 2.8021435737609863
  UCell Loss: 0.0
  Total Contrastive Loss: 0.0
  Total Loss: 3.588242769241333


100%|██████████| 29/29 [00:41<00:00,  1.44s/it]


Epoch 20:
  KL Loss: 0.7762186527252197
  Cluster Loss: 2.7814955711364746
  UCell Loss: -3.61641263961792
  Total Contrastive Loss: 0.0
  Total Loss: -0.058698415756225586
The training for the EmitGCL model has been completed.


100%|██████████| 29/29 [00:57<00:00,  1.98s/it]


### Save Model Output Results

- `EmitGCL_result['pred_label']`: Predicted cell clustering labels
- `EmitGCL_result['cell_embedding']`: Low-dimensional cell embeddings
- `EmitGCL_result['attention_weights']`: Cell-gene attention matrix

In [6]:
# Save results: one file per attention head via save_attention_matrix_to_mtx
# (presumably Matrix Market .mtx), plus the predicted labels and cell
# embeddings as .npy files in the output directory.
attention_matrices = EmitGCL_result['attention_weights']

for head, matrix in attention_matrices.items():
    save_attention_matrix_to_mtx(matrix, attention_file, head)

# os.path.join works whether or not output_file ends with a separator.
np.save(os.path.join(output_file, "pred.npy"), EmitGCL_result['pred_label'])
np.save(os.path.join(output_file, "cell_embedding.npy"), EmitGCL_result['cell_embedding'])