# Imports

In [1]:
from scripts.datasets import simulation_classes
from scripts.bmlp import ScBMLPClassifier, Config

import scanpy as sc
import numpy as np
import plotly.express as px
import pandas as pd
import einops
import torch

# Set params

In [2]:
DEVICE = "cpu"  # CPU on purpose: noted as faster than the MPS backend for this notebook

# Fix the global RNG seeds so that Restart & Run All gives repeatable results.
# NOTE(review): assumes the simulation and the model initialisation draw from
# NumPy's / PyTorch's global generators -- confirm against scripts.datasets / scripts.bmlp.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the cell output

# Load data

In [3]:
# Name of the adata.obs column holding the class label; also handed to the simulator.
class_key = "cell_type"
# simulation_classes yields the annotated data plus the train/val/test datasets, in that order.
splits = simulation_classes(class_key=class_key, n_cell_types=5, device=DEVICE)
adata, train_dataset, val_dataset, test_dataset = splits

In [4]:
# Dataset dimensions; n_genes and n_cell_types size the model's input and output below.
n_cells, n_genes = adata.shape
# Look the label column up through class_key (not a hard-coded name) so this stays
# in sync with the key that was passed to simulation_classes.
n_cell_types = adata.obs[class_key].nunique()

## Visualize

In [5]:
# Two-component PCA; the embedding is read back from adata.obsm["X_pca"] below
# and reused by the later scatter plots.
sc.pp.pca(adata, n_comps=2)
fig = px.scatter(
    x=adata.obsm["X_pca"][:, 0],
    y=adata.obsm["X_pca"][:, 1],
    # Use class_key rather than a hard-coded column name, matching the data-loading cell.
    color=adata.obs[class_key],
    labels={"x": "PC1", "y": "PC2", "color": class_key},
    # The data come from simulation_classes, so the title must say "simulated".
    title="Simulated cell types (PCA)",
    width=600,
    height=600,
)
fig.update_traces(marker=dict(size=5))
fig.show()

# Train model

In [6]:
# Hyperparameters forwarded to Config in the next cell.
d_hidden = 64  # passed as Config(d_hidden=...)
n_epochs = 100  # passed as Config(n_epochs=...)
lr = 1e-5  # passed as Config(lr=...)

In [7]:
# Gather every model/training setting in one mapping, then build the config from it.
model_kwargs = {
    "d_input": n_genes,
    "d_hidden": d_hidden,
    "d_output": n_cell_types,
    "n_epochs": n_epochs,
    "lr": lr,
    "device": DEVICE,
}
cfg = Config(**model_kwargs)
model = ScBMLPClassifier(cfg)
# fit hands back the training and validation loss histories, which are plotted next.
train_losses, val_losses = model.fit(train_dataset, val_dataset)

Training for 100 epochs: 100%|██████████| 100/100 [00:01<00:00, 93.41it/s, train_acc=1.0000, train_loss=0.0037, val_acc=1.0000, val_loss=0.0139]


In [8]:
# Long-format table with one row per (split, epoch), so both curves share one figure.
loss_curves = {'Train': train_losses, 'Validation': val_losses}
loss_df = pd.DataFrame({
    'Epoch': [epoch for curve in loss_curves.values() for epoch in range(len(curve))],
    'Loss': [value for curve in loss_curves.values() for value in curve],
    'Type': [split for split, curve in loss_curves.items() for _ in curve],
})

loss_fig = px.line(
    loss_df,
    x='Epoch',
    y='Loss',
    color='Type',
    title='Training and Validation Loss',
    labels={'Loss': 'Loss', 'Epoch': 'Epoch'},
)
loss_fig.show()

# Weight interpretation

In [9]:
# Contract the three weight matrices over the shared hidden axis, giving one
# (in1 x in2) interaction matrix per output unit.
b = einops.einsum(model.w_p, model.w_l, model.w_r, "out hid, hid in1, hid in2 -> out in1 in2")
# Symmetrize each matrix over its last two axes; torch.linalg.eigh (used below)
# expects symmetric input.
b = 0.5 * (b + b.mT)

print(b.shape)
# Axis 0 is the `out` axis, sized by d_output=n_cell_types in Config -- i.e. one
# matrix per cell type (the old "transcriptional scales" label was wrong).
print(f"Number of cell types: {b.shape[0]}")
print(f"Number of genes: {b.shape[1]}")

torch.Size([5, 100, 100])
Number of transcriptional scales: 5
Number of genes: 100


## Per cell type

In [21]:
def get_comps(adata, class_idx, bilinear=None):
    """Eigendecompose one class's interaction matrix, largest eigenvalue first.

    Args:
        adata: Unused; kept so existing call sites keep working.
        class_idx: Index along the first axis of the interaction tensor.
        bilinear: Tensor whose ``bilinear[class_idx]`` is a symmetric square
            matrix. When omitted, the notebook-level ``b`` is used (the
            original behaviour).

    Returns:
        ``(vals, vecs)``: eigenvalues in descending (signed) order and the
        matching eigenvectors as the columns of ``vecs``.
    """
    if bilinear is None:
        bilinear = b  # fall back to the tensor built in the weight-interpretation cell
    vals, vecs = torch.linalg.eigh(bilinear[class_idx])
    # eigh returns eigenvalues in ascending order; reverse both outputs so that
    # index 0 / column 0 correspond to the largest eigenvalue.
    return vals.flip([0]), vecs.flip([1])


def print_marker_genes(adata, b, class_idx, n_top_comps=1, n_top_genes=10):
    """Print the genes with the strongest loadings on the leading eigenvectors.

    For each of the first ``n_top_comps`` components two lists are printed:
    the genes with the most positive loadings, then the genes with the most
    negative loadings. A single blank line is printed after the last component.

    Note that the sign of an eigenvector is arbitrary, so which list counts as
    "positive" can swap between runs.

    Args:
        adata: Object whose ``var_names`` maps gene positions to gene names.
        b: Interaction tensor; ``b[class_idx]`` must be a symmetric square matrix.
        class_idx: Which matrix of ``b`` to analyse.
        n_top_comps: Number of leading components to report (capped at the
            number of available components).
        n_top_genes: Number of genes per list (capped at the number of genes).
    """
    # Hand the tensor argument through explicitly so the function analyses the
    # tensor it was given instead of whatever the global `b` currently is.
    _, vecs = get_comps(adata, class_idx, b)
    n_rows, n_cols = vecs.shape
    k = min(n_top_genes, n_rows)  # topk raises if k exceeds the vector length
    for comp in range(min(n_top_comps, n_cols)):
        loadings = vecs[:, comp]
        # Convert to plain Python ints before indexing var_names so the lookup
        # does not depend on how the index type handles tensors.
        top_idxs = loadings.topk(k).indices.tolist()
        bottom_idxs = (-loadings).topk(k).indices.tolist()
        print(adata.var_names[top_idxs].tolist())
        print(adata.var_names[bottom_idxs].tolist())
    print()


def plot_scatter(adata, color):
    """Show a 600x600 scatter of the first two PCA coordinates, colored by ``color``.

    The coordinates are read from ``adata.obsm["X_pca"]`` (columns 0 and 1).
    Colors use the diverging ``RdBu_r`` scale centered on 0 and markers have
    size 3. The figure is displayed; nothing is returned.
    """
    pca_coords = adata.obsm["X_pca"]
    scatter_kwargs = {
        "x": pca_coords[:, 0],
        "y": pca_coords[:, 1],
        "color": color,
        "labels": {"x": "PC1", "y": "PC2"},
        "color_continuous_scale": 'RdBu_r',
        "color_continuous_midpoint": 0,
        "width": 600,
        "height": 600,
    }
    scatter_fig = px.scatter(**scatter_kwargs)
    scatter_fig.update_traces(marker={"size": 3})
    scatter_fig.show()

In [22]:
# One report per class: a banner line followed by that class's marker-gene lists.
n_classes = adata.obs[class_key].nunique()
banner = "=" * 20
for class_idx in range(n_classes):
    print(banner, f"Cell type {class_idx}", banner)
    print_marker_genes(adata, b, class_idx)

['97', '68', '30', '55', '34', '63', '4', '67', '99', '94']
['33', '22', '2', '49', '91', '18', '53', '83', '37', '6']

['80', '59', '82', '62', '29', '16', '11', '39', '13', '1']
['64', '86', '48', '92', '56', '44', '74', '41', '22', '58']

['90', '4', '79', '7', '3', '5', '11', '13', '1', '80']
['87', '30', '84', '71', '92', '99', '26', '51', '41', '81']

['4', '92', '93', '74', '97', '84', '76', '19', '89', '57']
['49', '85', '21', '87', '27', '46', '33', '24', '31', '94']

['25', '77', '49', '27', '83', '91', '73', '53', '65', '33']
['43', '92', '39', '4', '37', '63', '34', '30', '62', '78']



In [35]:
# Hand-picked genes to inspect (e.g. "68" appears in the first marker list printed
# above); each one is drawn on the PCA embedding, colored by its expression values.
bmlp_marker_genes = ["68"]
for gene in bmlp_marker_genes:
    gene_expression = adata[:, gene].X.flatten()
    plot_scatter(adata, gene_expression)