A simple application of weights-based mechanistic interpretability to single-cell transcriptomics. See the associated blog post here.
Multilayer perceptrons (MLPs) are the backbone of deep learning. However, they derive their flexibility and power from nonlinearity between their inputs and outputs, and it's this nonlinearity that prevents us from interpreting exactly how they work.
Like conventional MLPs, bilinear MLPs are nonlinear in their inputs and highly performant.
However, they're linear in terms of pairs of their inputs.
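Each output dimension $y_k$ is a quadratic form in the input: $y_k = x^\top Q_k x = \sum_{i,j} (Q_k)_{ij}\, x_i x_j$.
In other words, for any output feature we want to predict, we get a corresponding matrix $Q_k \in \mathbb{R}^{g \times g}$ whose entries weight pairwise interactions between input features.
Applied to single-cell transcriptomics, the input data consists of gene expression profiles for each cell, $x \in \mathbb{R}^g$, so each $Q_k$ is a gene-gene interaction matrix that can be decomposed into gene modules.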
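The claim that a bilinear layer is linear in pairs of its inputs can be checked numerically. The sketch below uses NumPy with small random weights. The names `W_l`, `W_r`, and `W_p` are illustrative stand-ins for the left, right, and output projection weights (they are not the package's API), and biases are omitted for simplicity.

```python
import numpy as np

rng = np.random.default_rng(0)
g, h, k = 6, 4, 3  # toy sizes: genes (inputs), hidden units, outputs

# Hypothetical weights of a bias-free bilinear layer:
#   y = W_p @ ((W_l @ x) * (W_r @ x))
W_l = rng.normal(size=(h, g))
W_r = rng.normal(size=(h, g))
W_p = rng.normal(size=(k, h))

x = rng.normal(size=g)

# Forward pass: elementwise product of two linear projections, then a linear readout
y_layer = W_p @ ((W_l @ x) * (W_r @ x))

# Equivalent interaction tensor: Q[k] = sum_h W_p[k, h] * outer(W_l[h], W_r[h])
Q = np.einsum("kh,hi,hj->kij", W_p, W_l, W_r)

# Only the symmetric part of Q[k] contributes to x^T Q[k] x
Q_sym = 0.5 * (Q + Q.transpose(0, 2, 1))

# Quadratic-form view of the same layer: y[k] = x^T Q_sym[k] x
y_quadratic = np.einsum("i,kij,j->k", x, Q_sym, x)

print(np.allclose(y_layer, y_quadratic))
```

Because the two computations agree, everything the layer does for output `k` is captured by the single matrix `Q_sym[k]`, which is what makes a weights-based analysis possible.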
Below are the core analysis tasks along with their associated notebooks and interpretations.
1. Cell Type Classification (cell_type.ipynb)
- Dataset / Context: Developing mouse pancreas (endocrinogenesis) cells.
- Input: Gene expression $x \in \mathbb{R}^g$
- Output: Cell type probabilities (softmax over discovered clusters)
- Architecture: `ScBMLPClassifier` with cross-entropy loss
- Interpretation: Eigendecomposition of per-class interaction matrices yields gene modules distinguishing early progenitors vs mature hormone-producing lineages.
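As a toy illustration of how eigendecomposition turns an interaction matrix into gene modules, the NumPy sketch below plants two modules in a small symmetric matrix and recovers them. The matrix, the gene names, and the module sizes are all made up for illustration. One practical detail it highlights: eigenvector signs are arbitrary, so ranking genes by absolute loading is a safer default than ranking by raw value.

```python
import numpy as np

genes = np.array([f"gene_{i}" for i in range(8)])  # placeholder gene names

# Planted module on genes 0-2 (will carry a positive eigenvalue)
v_pos = np.zeros(8)
v_pos[[0, 1, 2]] = 1 / np.sqrt(3)

# Opposing module on genes 5-6 (will carry a negative eigenvalue)
v_neg = np.zeros(8)
v_neg[[5, 6]] = 1 / np.sqrt(2)

# Toy symmetric interaction matrix for one output class
Q_k = 3.0 * np.outer(v_pos, v_pos) - 2.0 * np.outer(v_neg, v_neg)

# eigh returns eigenvalues in ascending order; reverse to put the largest first
vals, vecs = np.linalg.eigh(Q_k)
vals, vecs = vals[::-1], vecs[:, ::-1]

# Rank genes by absolute loading because eigenvector signs are arbitrary
top_module = genes[np.argsort(-np.abs(vecs[:, 0]))[:3]]
bottom_module = genes[np.argsort(-np.abs(vecs[:, -1]))[:2]]

print(vals[0], sorted(top_module.tolist()))
print(vals[-1], sorted(bottom_module.tolist()))
```

In this construction the leading eigenvector picks out the module that pushes the class score up, and the trailing eigenvector picks out the module that pushes it down.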
2. Frequency Regression (frequency.ipynb)
- Dataset / Context: Same pancreas developmental trajectory; graph constructed over cells.
- Input: Gene expression $x \in \mathbb{R}^g$
- Output: Graph Laplacian eigenvector coefficients $f \in \mathbb{R}^k$ ("frequency" components)
- Architecture: `ScBMLPRegressor` with MSE loss
- Interpretation: Interaction modules stratify transcriptional programs by spatial/graph frequency: low frequencies capture broad developmental and cell cycle exit programs; higher frequencies isolate lineage-specific (e.g. α vs β) specification signals.
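To make the regression target more concrete, here is a minimal NumPy sketch of how per-cell graph Laplacian eigenvector coefficients could be computed. It assumes an unnormalized Laplacian on a toy path graph standing in for a cell-cell neighbor graph. The notebook may construct the graph and normalize the Laplacian differently, and all variable names here are illustrative.

```python
import numpy as np

n_cells, k = 10, 3

# Toy cell-cell graph: a path in which each cell connects to its neighbors,
# standing in for a neighbor graph along a developmental trajectory
A = np.zeros((n_cells, n_cells))
idx = np.arange(n_cells - 1)
A[idx, idx + 1] = 1.0
A[idx + 1, idx] = 1.0

# Unnormalized graph Laplacian L = D - A
L = np.diag(A.sum(axis=1)) - A

# Eigenvalues act as graph frequencies, sorted from low (smooth over the graph)
# to high (rapidly varying between neighboring cells)
freqs, modes = np.linalg.eigh(L)

# Per-cell regression targets: each cell's entries in the k lowest-frequency eigenvectors
targets = modes[:, :k]  # shape (n_cells, k)

print(targets.shape)
print(np.round(freqs[:k], 3))
```

On a connected graph the lowest-frequency eigenvector is constant across cells, so in practice the informative targets start from the second component.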
3. Perturbation Regression (perturbation.ipynb)
- Dataset / Context: SciPlex3 single-cell chemical screening (multi-compound, multi-dose perturbations).
- Input: Gene expression $x \in \mathbb{R}^g$
- Output: Drug perturbation response embedding $p \in \mathbb{R}^k$ (latent representation summarizing compound + dose effect per cell)
- Architecture: `ScBMLPRegressor` with MSE loss (predicting the learned/defined per-cell perturbation embedding)
- Interpretation: Gene interaction modules reveal coordinated pathways (cell cycle control, stress response, metabolic rewiring) mediating compound- and dose-specific responses.
```python
import einops
import scanpy as sc
import scvelo as scv
import torch

from scbmlp.datasets import get_classification_datasets
from scbmlp.models import Config, ScBMLPClassifier

# Load pancreas endocrinogenesis data
adata = scv.datasets.pancreas()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, subset=True, n_top_genes=10000)

# Create train/validation/test datasets
train_dataset, val_dataset, test_dataset = get_classification_datasets(
    adata, class_key="clusters", random_state=42
)

# Configure and train model
n_classes = adata.obs["clusters"].nunique()
cfg = Config(
    d_input=adata.n_vars,  # number of genes kept after HVG selection
    d_hidden=128,
    d_output=n_classes,
    n_epochs=100,
    lr=1e-4,
    batch_size=64,
)
model = ScBMLPClassifier(cfg)
model.fit(train_dataset, val_dataset)

# Extract gene interaction matrices for each cell type
with torch.no_grad():
    Q = einops.einsum(
        model.w_p,
        model.w_l,
        model.w_r,
        "out hid, hid in1, hid in2 -> out in1 in2",
    )

# Analyze modules for each cell type
for cell_type_idx in range(n_classes):
    # eigh assumes a symmetric input and reads only one triangle, so use the
    # symmetric part of Q (the quadratic form x^T Q x is unchanged by this)
    Q_sym = 0.5 * (Q[cell_type_idx] + Q[cell_type_idx].T)
    vals, vecs = torch.linalg.eigh(Q_sym)
    # Reorder eigenpairs from largest to smallest eigenvalue
    vals, vecs = vals.flip([0]), vecs.flip([1])
    # Genes with the largest loadings on the leading eigenvector
    top_idx = vecs[:, 0].topk(20).indices.cpu().numpy()
    top_genes = adata.var_names[top_idx]
    print(f"Cell type {cell_type_idx}: {top_genes.tolist()}")
```

Requirements:

- Python ≥3.9
- PyTorch ≥1.13.0
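
To install from source:

```shell
git clone https://github.com/kmaherx/ScBMLP.git
cd ScBMLP
pip install -e .
```

To install inside a virtual environment instead, run the following from the repository root:

```shell
python -m venv scbmlp-env
source scbmlp-env/bin/activate  # on Windows: scbmlp-env\Scripts\activate
pip install -e .
```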