# Imports

In [33]:
from typing import Tuple, List, Dict, Any

import pertpy as pt
import scanpy as sc
import numpy as np
import plotly.express as px
import pandas as pd
import einops
import gseapy as gp
from gseapy import enrichr

from scripts.datasets import get_classification_datasets
from scripts.bmlp import ScBMLPClassifier, Config
import torch

# Set params

In [2]:
class_key = "condition"
DEVICE = "cpu"  # faster than mps...

# Load data

## Load and format

In [3]:
adata = pt.data.kang_2018()

In [4]:
adata

AnnData object with n_obs × n_vars = 24673 × 15706
    obs: 'nCount_RNA', 'nFeature_RNA', 'tsne1', 'tsne2', 'label', 'cluster', 'cell_type', 'replicate', 'nCount_SCT', 'nFeature_SCT', 'integrated_snn_res.0.4', 'seurat_clusters'
    var: 'name'
    obsm: 'X_pca', 'X_umap'

In [5]:
sc.pp.filter_cells(adata, min_counts=100)
sc.pp.filter_genes(adata, min_cells=100)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
# sc.pp.highly_variable_genes(adata, n_top_genes=5000)

In [6]:
adata

AnnData object with n_obs × n_vars = 24673 × 9113
    obs: 'nCount_RNA', 'nFeature_RNA', 'tsne1', 'tsne2', 'label', 'cluster', 'cell_type', 'replicate', 'nCount_SCT', 'nFeature_SCT', 'integrated_snn_res.0.4', 'seurat_clusters', 'n_counts'
    var: 'name', 'n_cells'
    uns: 'log1p'
    obsm: 'X_pca', 'X_umap'

In [7]:
adata.obs

Unnamed: 0_level_0,nCount_RNA,nFeature_RNA,tsne1,tsne2,label,cluster,cell_type,replicate,nCount_SCT,nFeature_SCT,integrated_snn_res.0.4,seurat_clusters,n_counts
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
AAACATACATTTCC-1,3017.0,877,-27.640373,14.966629,ctrl,9,CD14+ Monocytes,patient_1016,1704.0,711,1,1,3017.0
AAACATACCAGAAA-1,2481.0,713,-27.493646,28.924885,ctrl,9,CD14+ Monocytes,patient_1256,1614.0,662,1,1,2481.0
AAACATACCATGCA-1,703.0,337,-10.468194,-5.984389,ctrl,3,CD4 T cells,patient_1488,908.0,337,6,6,703.0
AAACATACCTCGCT-1,3420.0,850,-24.367997,20.429285,ctrl,9,CD14+ Monocytes,patient_1256,1738.0,653,1,1,3420.0
AAACATACCTGGTA-1,3158.0,1111,27.952170,24.159738,ctrl,4,Dendritic cells,patient_1039,1857.0,928,12,12,3158.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGCATGCCTGAA-2,1033.0,468,18.268321,1.058202,stim,6,CD4 T cells,patient_1244,1128.0,468,2,2,1033.0
TTTGCATGCCTGTC-2,2116.0,819,-11.563067,2.574095,stim,4,B cells,patient_1256,1669.0,799,3,3,2116.0
TTTGCATGCTAAGC-2,1522.0,523,25.142392,6.603815,stim,6,CD4 T cells,patient_107,1422.0,523,0,0,1522.0
TTTGCATGGGACGA-2,1143.0,503,14.359657,10.965601,stim,6,CD4 T cells,patient_1488,1185.0,503,0,0,1143.0


In [8]:
random_state = 0
class_key = "label"
train_dataset, val_dataset, test_dataset = get_classification_datasets(
    adata, class_key, random_state=random_state, device=DEVICE,
)

In [10]:
train_dataset.adata

AnnData object with n_obs × n_vars = 17271 × 9113
    obs: 'nCount_RNA', 'nFeature_RNA', 'tsne1', 'tsne2', 'label', 'cluster', 'cell_type', 'replicate', 'nCount_SCT', 'nFeature_SCT', 'integrated_snn_res.0.4', 'seurat_clusters', 'n_counts'
    var: 'name', 'n_cells'
    uns: 'log1p'
    obsm: 'X_pca', 'X_umap'

In [14]:
train_dataset.adata.obs[class_key].value_counts()

label
ctrl    8653
stim    8618
Name: count, dtype: int64

## Visualize

In [18]:
for adata in [train_dataset.adata, val_dataset.adata]:
    fig = px.scatter(
        x=adata.obsm["X_pca"][:, 0],
        y=adata.obsm["X_pca"][:, 1],
        color=adata.obs[class_key],
        title="PBMC IFN gamma dataset",
        width=600,
        height=600,
    )
    fig.update_traces(marker=dict(size=5))
    fig.show()

# Train model

In [23]:
n_genes = train_dataset.adata.shape[1]
n_classes = train_dataset.adata.obs[class_key].nunique()

d_hidden = 68
n_epochs = 25
lr = 1e-5

In [24]:
cfg = Config(
    d_input=n_genes,
    d_hidden=d_hidden,
    d_output=n_classes,
    n_epochs=n_epochs,
    lr=lr,
    device=DEVICE,
    batch_size=32,
)
model = ScBMLPClassifier(cfg)
train_losses, val_losses = model.fit(train_dataset, val_dataset)

Training for 25 epochs: 100%|██████████| 25/25 [01:08<00:00,  2.72s/it, train_acc=1.0000, train_loss=0.0014, val_acc=0.9878, val_loss=0.0360]


In [25]:
# Combine train and val losses into a single plot
loss_df = pd.DataFrame({
    'Epoch': list(range(len(train_losses))) + list(range(len(val_losses))),
    'Loss': train_losses + val_losses,
    'Type': ['Train'] * len(train_losses) + ['Validation'] * len(val_losses)
})

px.line(loss_df, x='Epoch', y='Loss', color='Type', 
        title='Training and Validation Loss', 
        labels={'Loss': 'Loss', 'Epoch': 'Epoch'}).show()

# Weight interpretation

In [26]:
def get_marker_gene_lists(
    gene_names: np.ndarray,
    vecs: np.ndarray,
    n_top_comps: int = 1,
    n_top_genes: int = 50,
) -> np.ndarray:
    """Extract marker genes optimized for GO analysis."""
    gene_lists = []
    for i in range(n_top_comps):
        top_idxs = vecs[:,i].topk(n_top_genes).indices
        top_genes = gene_names[top_idxs].tolist()
        bottom_idxs = (-vecs[:,i]).topk(n_top_genes).indices
        bottom_genes = gene_names[bottom_idxs].tolist()
        gene_lists.append([top_genes, bottom_genes])
    return np.array(gene_lists)

In [27]:
adata = sc.concat([train_dataset.adata, val_dataset.adata])

## Bilinear

### Gene markers

In [28]:
# Binary classification ==> output directions yield the same results
q = einops.einsum(model.w_p[0], model.w_l, model.w_r, "hid, hid in1, hid in2 -> in1 in2")
q = 0.5 * (q + q.mT)  # symmetrize

In [29]:
# Eigendecompose to get gene module weights
_, vecs_bmlp = torch.linalg.eigh(q)
vecs_bmlp = vecs_bmlp.flip([1])

In [None]:
# Get gene names per module (i.e. component; "comp")
n_top_comps = 3
n_top_genes = 20
gene_names = adata.var_names.values
gene_lists_bmlp = get_marker_gene_lists(
    gene_names, vecs_bmlp, n_top_comps=n_top_comps, n_top_genes=n_top_genes
)  # [comp, top/bottom, gene]

In [30]:
for i in range(n_top_comps):
    print("="*20, "Component", i, "="*20)
    print(f"Top genes: {gene_lists_bmlp[i,0,:8]}...")
    print(f"Bottom genes: {gene_lists_bmlp[i,1,:8]}...")

Top genes: ['IFI6' 'ISG15' 'IFIT3' 'LAG3' 'IFIT1' 'LY6E' 'IFI44' 'ISG20']...
Bottom genes: ['ARL6IP5' 'EEF1A1' 'RPL6' 'FTH1' 'EIF3D' 'BTG1' 'VIM' 'STK17A']...
Top genes: ['MT2A' 'DDX6' 'NCL' 'USE1' 'ACAP1' 'PARL' 'DSTYK' 'EEF1B2']...
Bottom genes: ['FGL2' 'SERTAD1' 'CST3' 'ZXDC' 'DNTTIP2' 'SNX6' 'PARP6' 'ARIH2']...
Top genes: ['COX7C' 'CBY1' 'PITHD1' 'FZR1' 'UQCRC2' 'GLIPR1' 'NR1D2' 'FEM1A']...
Bottom genes: ['SEPHS1' 'DCUN1D1' 'PSME2' 'LSR' 'CXCL10' 'MPPE1' 'LGALS9' 'GPR183']...


### GO analysis

In [35]:
n_results = 5
results_cols = ["Term", "Genes", "Gene_set", "Adjusted P-value"]

for comp in range(n_top_comps):
    print("="*40, "Component", comp, "="*40)
    for i in range(2):
        enr = gp.enrichr(
            gene_list=gene_lists_bmlp[comp, i].tolist(),
            gene_sets=[
                # "GO_Biological_Process_2023",
                # "KEGG_2021_Human",
                "Reactome_2022"
            ],
            cutoff=0.05,
        )
        display(enr.results.head(n_results)[results_cols])



Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Interferon Alpha/Beta Signaling R-HSA-909733,ISG20;OAS1;STAT1;OAS3;MX1;IFI6;ISG15;IFIT1;IFI...,Reactome_2022,3.3490470000000002e-18
1,Interferon Signaling R-HSA-913531,ISG20;MT2A;OAS1;STAT1;OAS3;MX1;IFI6;EIF2AK2;IS...,Reactome_2022,4.121495e-18
2,Cytokine Signaling In Immune System R-HSA-1280215,ISG20;MT2A;OAS1;STAT1;OAS3;MX1;IFI6;EIF2AK2;IS...,Reactome_2022,1.013465e-11
3,Antiviral Mechanism By IFN-stimulated Genes R-...,OAS1;STAT1;OAS3;MX1;EIF2AK2;ISG15;IFIT1,Reactome_2022,2.975361e-11
4,Immune System R-HSA-168256,LAG3;STAT1;MX1;IFI6;EIF2AK2;ISG15;IFIT1;IFIT3;...,Reactome_2022,5.216679e-08


Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Translation R-HSA-72766,EEF1A1;RPS25;EEF1D;EIF3L;EIF3D;RPL6;RPL7,Reactome_2022,6.250253e-07
1,Eukaryotic Translation Elongation R-HSA-156842,EEF1A1;RPS25;EEF1D;RPL6;RPL7,Reactome_2022,1.139117e-06
2,Formation Of A Pool Of Free 40S Subunits R-HSA...,RPS25;EIF3L;EIF3D;RPL6;RPL7,Reactome_2022,1.16759e-06
3,L13a-mediated Translational Silencing Of Cerul...,RPS25;EIF3L;EIF3D;RPL6;RPL7,Reactome_2022,1.196904e-06
4,GTP Hydrolysis And Joining Of 60S Ribosomal Su...,RPS25;EIF3L;EIF3D;RPL6;RPL7,Reactome_2022,1.196904e-06




Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Metallothioneins Bind Metals R-HSA-5661231,MT2A,Reactome_2022,0.099573
1,Response To Metal Ions R-HSA-5660526,MT2A,Reactome_2022,0.099573
2,mRNA Decay By 5 To 3 Exoribonuclease R-HSA-430039,DDX6,Reactome_2022,0.099573
3,Response Of EIF2AK1 (HRI) To Heme Deficiency R...,EIF2AK1,Reactome_2022,0.099573
4,Processing Of SMDT1 R-HSA-8949664,PARL,Reactome_2022,0.099573


Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Post-translational Protein Phosphorylation R-H...,CST3;MIA3,Reactome_2022,0.122784
1,Regulation Of IGF Transport And Uptake By IGFB...,CST3;MIA3,Reactome_2022,0.122784
2,Aryl Hydrocarbon Receptor Signaling R-HSA-8937144,MIA3,Reactome_2022,0.122784
3,SLBP Independent Processing Of Histone Pre-mRN...,LSM10,Reactome_2022,0.122784
4,Regulation Of Gene Expression By Hypoxia-induc...,MIA3,Reactome_2022,0.122784




Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Aberrant Regulation Of Mitotic Cell Cycle Due ...,FZR1;CDK4,Reactome_2022,0.042157
1,Diseases Of Mitotic Cell Cycle R-HSA-9675126,FZR1;CDK4,Reactome_2022,0.042157
2,RNA Polymerase II Transcription R-HSA-73857,POMC;FZR1;CDK4;NABP2;NR1D2;COX7C,Reactome_2022,0.059051
3,Gene Expression (Transcription) R-HSA-74160,POMC;FZR1;CDK4;NABP2;NR1D2;COX7C,Reactome_2022,0.064346
4,Senescence-Associated Secretory Phenotype (SAS...,FZR1;CDK4,Reactome_2022,0.064346


Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Regulation Of Expression Of SLITs And ROBOs R-...,PSME2;RNPS1;RPL15,Reactome_2022,0.119211
1,Signaling By ROBO Receptors R-HSA-376176,PSME2;RNPS1;RPL15,Reactome_2022,0.119211
2,VLDL Clearance R-HSA-8964046,LSR,Reactome_2022,0.160475
3,RUNX1 Regulates Transcription Of Genes Involve...,ELOF1,Reactome_2022,0.160475
4,MECP2 Regulates Transcription Factors R-HSA-90...,MECP2,Reactome_2022,0.160475


## PCA

### Gene markers

In [36]:
cov = adata.X.T @ adata.X
cov = torch.tensor(cov.toarray()).to(DEVICE)

In [37]:
# Eigendecompose to get gene module weights
_, vecs_cov = torch.linalg.eigh(cov)
vecs_cov = vecs_cov.flip([1])

In [38]:
# Get gene names per module (i.e. component; "comp")
n_top_comps = 3
n_top_genes = 20
gene_names = adata.var_names.values
gene_lists_cov = get_marker_gene_lists(
    gene_names, vecs_cov, n_top_comps=n_top_comps, n_top_genes=n_top_genes
)  # [comp, top/bottom, gene]

In [39]:
for i in range(n_top_comps):
    print("="*20, "Component", i, "="*20)
    print(f"Top genes: {gene_lists_cov[i,0,:8]}...")
    print(f"Bottom genes: {gene_lists_cov[i,1,:8]}...")

Top genes: ['HBD' 'NRP1' 'GUCY1A3' 'LINC00900' 'ARHGAP22' 'LIPN' 'RP11-262H14.1'
 'IDO2']...
Bottom genes: ['MALAT1' 'B2M' 'TMSB4X' 'FTH1' 'HLA-B' 'RPS2' 'RPL10' 'RPL13']...
Top genes: ['FTL' 'FTH1' 'TIMP1' 'ISG15' 'SOD2' 'CXCL10' 'C15orf48' 'CCL2']...
Bottom genes: ['RPL3' 'RPS6' 'RPL21' 'RPS18' 'RPL7' 'RPL13A' 'RPS3' 'RPL13']...
Top genes: ['ISG15' 'ISG20' 'IFIT3' 'IFI6' 'CXCL10' 'IFIT1' 'LY6E' 'TNFSF10']...
Bottom genes: ['FTH1' 'FTL' 'TIMP1' 'ACTB' 'IL8' 'GAPDH' 'PFN1' 'S100A8']...


### GO analysis

In [41]:
n_results = 5
results_cols = ["Term", "Genes", "Gene_set", "Adjusted P-value"]

for comp in range(n_top_comps):
    print("="*40, "Component", comp, "="*40)
    for i in range(2):
        enr = gp.enrichr(
            gene_list=gene_lists_cov[comp, i].tolist(),
            gene_sets=[
                # "GO_Biological_Process_2023",
                # "KEGG_2021_Human",
                "Reactome_2022"
            ],
            cutoff=0.05,
        )
        display(enr.results.head(n_results)[results_cols])



Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Transport And Synthesis Of PAPS R-HSA-174362,PAPSS2,Reactome_2022,0.129487
1,CHL1 Interactions R-HSA-447041,NRP1,Reactome_2022,0.129487
2,Tryptophan Catabolism R-HSA-71240,IDO2,Reactome_2022,0.129487
3,SEMA3A-Plexin Repulsion Signaling By Inhibitin...,NRP1,Reactome_2022,0.129487
4,Glycogen Breakdown (Glycogenolysis) R-HSA-70221,PYGL,Reactome_2022,0.129487


Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Peptide Chain Elongation R-HSA-156902,RPS4X;RPS14;RPL3;RPL21;RPL32;RPS19;RPLP1;RPS6;...,Reactome_2022,2.4295169999999998e-20
1,Selenocysteine Synthesis R-HSA-2408557,RPS4X;RPS14;RPL3;RPL21;RPL32;RPS19;RPLP1;RPS6;...,Reactome_2022,2.4295169999999998e-20
2,Viral mRNA Translation R-HSA-192823,RPS4X;RPS14;RPL3;RPL21;RPL32;RPS19;RPLP1;RPS6;...,Reactome_2022,2.4295169999999998e-20
3,Eukaryotic Translation Elongation R-HSA-156842,RPS4X;RPS14;RPL3;RPL21;RPL32;RPS19;RPLP1;RPS6;...,Reactome_2022,2.4295169999999998e-20
4,Eukaryotic Translation Termination R-HSA-72764,RPS4X;RPS14;RPL3;RPL21;RPL32;RPS19;RPLP1;RPS6;...,Reactome_2022,2.4295169999999998e-20




Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Immune System R-HSA-168256,IFITM3;CD63;CXCL10;TYROBP;FCER1G;FTH1;HLA-DRA;...,Reactome_2022,3e-06
1,Neutrophil Degranulation R-HSA-6798695,CD63;TYROBP;FCER1G;FTH1;S100A11;FTL,Reactome_2022,0.000194
2,Interleukin-10 Signaling R-HSA-6783783,CXCL10;CCL2;TIMP1,Reactome_2022,0.000327
3,Innate Immune System R-HSA-168249,CD63;TYROBP;FCER1G;FTH1;ISG15;S100A11;FTL,Reactome_2022,0.000773
4,Cytokine Signaling In Immune System R-HSA-1280215,IFITM3;CXCL10;HLA-DRA;CCL2;ISG15;TIMP1,Reactome_2022,0.000773


Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Peptide Chain Elongation R-HSA-156902,RPL3;RPL21;RPL32;RPL34;RPS5;RPS6;RPL13A;RPS3A;...,Reactome_2022,2.192849e-36
1,Viral mRNA Translation R-HSA-192823,RPL3;RPL21;RPL32;RPL34;RPS5;RPS6;RPL13A;RPS3A;...,Reactome_2022,2.192849e-36
2,Selenocysteine Synthesis R-HSA-2408557,RPL3;RPL21;RPL32;RPL34;RPS5;RPS6;RPL13A;RPS3A;...,Reactome_2022,2.192849e-36
3,Eukaryotic Translation Elongation R-HSA-156842,RPL3;RPL21;RPL32;RPL34;RPS5;RPS6;RPL13A;RPS3A;...,Reactome_2022,2.192849e-36
4,Eukaryotic Translation Termination R-HSA-72764,RPL3;RPL21;RPL32;RPL34;RPS5;RPS6;RPL13A;RPS3A;...,Reactome_2022,2.192849e-36




Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Interferon Alpha/Beta Signaling R-HSA-909733,IFITM3;ISG20;RSAD2;OAS1;MX1;IFI6;IRF7;ISG15;IF...,Reactome_2022,1.913127e-20
1,Interferon Signaling R-HSA-913531,IFITM3;ISG20;MT2A;RSAD2;OAS1;MX1;IFI6;IRF7;ISG...,Reactome_2022,8.327103e-18
2,Cytokine Signaling In Immune System R-HSA-1280215,IFITM3;RSAD2;MX1;IFI6;ISG15;IFIT1;IFIT3;IFIT2;...,Reactome_2022,7.983423e-15
3,Immune System R-HSA-168256,IFITM3;RSAD2;MX1;IFI6;ISG15;IFIT1;IFIT3;IFIT2;...,Reactome_2022,6.965282e-09
4,Antiviral Mechanism By IFN-stimulated Genes R-...,OAS1;MX1;ISG15;IFIT1,Reactome_2022,5.028582e-05


Unnamed: 0,Term,Genes,Gene_set,Adjusted P-value
0,Scavenging By Class A Receptors R-HSA-3000480,FTH1;FTL,Reactome_2022,0.007335
1,Platelet Degranulation R-HSA-114608,CD63;TIMP1;PFN1,Reactome_2022,0.007335
2,Response To Elevated Platelet Cytosolic Ca2+ R...,CD63;TIMP1;PFN1,Reactome_2022,0.007335
3,Binding And Uptake Of Ligands By Scavenger Rec...,FTH1;FTL,Reactome_2022,0.01484
4,Neutrophil Degranulation R-HSA-6798695,CD63;FTH1;S100A8;FTL,Reactome_2022,0.016198
