# Imports

In [None]:
from typing import Tuple, List, Dict, Any

import pertpy as pt
import scanpy as sc
import numpy as np
import plotly.express as px
import pandas as pd
import einops
import gseapy as gp
from gseapy import enrichr

from scripts.datasets import ClassifierDataset
from scripts.bmlp import ScBMLPClassifier, Config
import torch

# Set parameters

In [None]:
class_key = "condition"  # obs column holding the label the classifier predicts
DEVICE = "cpu"  # observed to be faster than MPS for this workload
SEED = 42

# Seed the RNGs the training code may draw from (numpy and torch) so that
# Restart & Run All reproduces the same losses, to the extent the training
# code relies on these global generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load data

## Load and format

In [None]:
# Pre-split PBMC train / validation AnnData files from the scGen reproducibility data.
adata_train, adata_val = (
    sc.read("data/scgen-reproducibility/train_pbmc.h5ad"),
    sc.read("data/scgen-reproducibility/valid_pbmc.h5ad"),
)


This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(


In [None]:
# Summary of the training AnnData object: dimensions plus obs/var/uns/obsm/obsp keys.
adata_train

AnnData object with n_obs × n_vars = 16893 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'distances', 'connectivities'

In [None]:
# Per-cell metadata of the training split, including the `condition` label and `cell_type`.
adata_train.obs

Unnamed: 0_level_0,condition,n_counts,n_genes,mt_frac,cell_type
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
AAACATACCAAGCT-1-stimulated,stimulated,1160.0,589,0.0,NK
AAACATACCCCTAC-1-stimulated,stimulated,1729.0,795,0.0,Dendritic
AAACATACCCGTAA-1-stimulated,stimulated,1360.0,585,0.0,CD4T
AAACATACCCTCGT-1-stimulated,stimulated,1442.0,732,0.0,B
AAACATACGAGGTG-1-stimulated,stimulated,1237.0,546,0.0,CD4T
...,...,...,...,...,...
TTTGACTGGCGGAA-1-control,control,2505.0,821,0.0,CD8T
TTTGACTGTCGTAG-1-control,control,3704.0,1101,0.0,CD14+Mono
TTTGACTGTTACCT-1-control,control,2133.0,629,0.0,CD14+Mono
TTTGCATGCTTCGC-1-control,control,2317.0,875,0.0,B


In [None]:
# Class balance of the training split. Indexes by `class_key` rather than a hard-coded
# column name, so this stays consistent if the label column is changed in the config cell.
adata_train.obs[class_key].value_counts()

condition
stimulated    8886
control       8007
Name: count, dtype: int64

In [None]:
# Wrap each split as a classification dataset labelled by the `class_key` obs column.
train_dataset, val_dataset = (
    ClassifierDataset(adata_train, class_key),
    ClassifierDataset(adata_val, class_key),
)

## Visualize

Scatter the first two precomputed principal components of each split, coloured by condition.

In [None]:
# Plot the first two precomputed principal components (`obsm["X_pca"]`) for each
# split, coloured by the class label. Each figure's title names its split so the
# train and validation plots can be told apart. The scGen PBMC data (Kang et al.
# 2018) was stimulated with IFN-beta; the previous title said "IFN gamma".
for split_name, adata in [("train", adata_train), ("validation", adata_val)]:
    fig = px.scatter(
        x=adata.obsm["X_pca"][:, 0],
        y=adata.obsm["X_pca"][:, 1],
        color=adata.obs[class_key],
        title=f"PBMC IFN-beta dataset ({split_name})",
        labels={"x": "PC1", "y": "PC2"},
        width=600,
        height=600,
    )
    fig.update_traces(marker=dict(size=5))
    fig.show()

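# Train model

Train a classifier that predicts each cell's `condition` (control vs. stimulated) from its gene-expression profile.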

In [None]:
# Input/output sizes are derived from the data: one input per gene, one output per class.
n_genes, n_classes = adata_train.shape[1], adata_train.obs[class_key].nunique()

# Training hyperparameters.
d_hidden, n_epochs, lr = 128, 25, 1e-5

In [None]:
# Configure the classifier from the sizes and hyperparameters above, then train it.
# `fit` returns the training and validation loss histories used for the plot below.
cfg = Config(
    # architecture
    d_input=n_genes,
    d_hidden=d_hidden,
    d_output=n_classes,
    # optimisation
    lr=lr,
    n_epochs=n_epochs,
    batch_size=32,
    device=DEVICE,
)
model = ScBMLPClassifier(cfg)
train_losses, val_losses = model.fit(train_dataset, val_dataset)

Training for 25 epochs: 100%|██████████| 25/25 [01:13<00:00,  2.96s/it, train_acc=1.0000, train_loss=0.0009, val_acc=1.0000, val_loss=0.0018]


In [None]:
# Combine train and validation losses into one long-format frame so both curves share a plot.
# Each history is first coerced to a list of Python floats. If `fit` returned numpy arrays
# or tensors, `+` would add them element-wise instead of concatenating them.
# NOTE(review): assumes one loss value per epoch; confirm against ScBMLPClassifier.fit.
train_loss_values = [float(v) for v in train_losses]
val_loss_values = [float(v) for v in val_losses]
loss_df = pd.DataFrame({
    "Epoch": [*range(len(train_loss_values)), *range(len(val_loss_values))],
    "Loss": train_loss_values + val_loss_values,
    "Type": ["Train"] * len(train_loss_values) + ["Validation"] * len(val_loss_values),
})

px.line(
    loss_df,
    x="Epoch",
    y="Loss",
    color="Type",
    title="Training and Validation Loss",
).show()