ScBMLP: Bilinear Multi-Layer Perceptrons for Single-Cell Analysis

A simple application of weights-based mechanistic interpretability to single-cell transcriptomics. See the associated blog post here.

Method

Multilayer perceptrons (MLPs) are the backbone of deep learning. However, they derive their flexibility and power from nonlinearity between their inputs and outputs, and it's this nonlinearity that prevents us from interpreting exactly how they work.

Like conventional MLPs, bilinear MLPs are nonlinear in their inputs and highly performant. However, they are linear in pairwise products of their inputs, so each output dimension $\mathbf{y}_i$ can be written as a quadratic form in the inputs $\mathbf{x}$ (a bilinear form with $\mathbf{x}$ in both arguments):

$$ \mathbf{y}_i = \mathbf{x}^{\top} \mathbf{Q}_i \mathbf{x} $$

In other words, for any output feature we want to predict, training yields a corresponding matrix $\mathbf{Q}_i$ of interactions between input features. Only the symmetric part of $\mathbf{Q}_i$ affects the output, so $\mathbf{Q}_i$ can be taken to be symmetric. Then, just as PCA identifies modules from a covariance matrix, eigendecomposition of $\mathbf{Q}_i$ identifies modules of input features. Relating these modules to the output gives a concise interpretation of which feature combinations the model relies on to make its predictions.
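As a minimal sketch of the equivalence described above, the NumPy snippet below builds a toy bilinear layer and checks that its outputs match the quadratic forms $\mathbf{x}^{\top} \mathbf{Q}_i \mathbf{x}$. The weight names mirror `w_l`, `w_r`, and `w_p` from the Quick Start; the bias-free layer form, the toy sizes, and the random weights are assumptions made for illustration, not the package's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
g, d_hidden, d_output = 6, 4, 3  # toy sizes: genes, hidden units, outputs

# Hypothetical random weights; shapes follow the Quick Start einsum pattern.
w_l = rng.normal(size=(d_hidden, g))
w_r = rng.normal(size=(d_hidden, g))
w_p = rng.normal(size=(d_output, d_hidden))


def bilinear_forward(x):
    """Assumed bias-free bilinear layer: product of two linear maps, then a projection."""
    return w_p @ ((w_l @ x) * (w_r @ x))


# Interaction tensor: Q[i] = sum_h w_p[i, h] * outer(w_l[h], w_r[h])
Q = np.einsum("oh,ha,hb->oab", w_p, w_l, w_r)

# Only the symmetric part of Q[i] contributes to x^T Q[i] x.
Q_sym = 0.5 * (Q + Q.transpose(0, 2, 1))

x = rng.normal(size=g)
y_layer = bilinear_forward(x)
y_quadratic = np.einsum("a,oab,b->o", x, Q_sym, x)
```

Here `y_layer` and `y_quadratic` agree up to floating-point error, which is what makes the weights themselves interpretable without reference to any particular input.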

Applied to single-cell transcriptomics, the input is the gene expression profile of each cell, $\mathbf{x} \in \mathbb{R}^g$, and the resulting interaction matrix $\mathbf{Q}_i \in \mathbb{R}^{g \times g}$ represents a gene-gene network predictive of output $\mathbf{y}_i$. Eigendecomposition of $\mathbf{Q}_i$ yields the corresponding gene modules. The question then becomes: which outputs would let this approach provide scientific insight? Here, we demonstrate three such tasks.
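To illustrate how eigendecomposition turns an interaction matrix into gene modules, here is a small NumPy sketch on a made-up symmetric matrix with one planted module on genes 0 to 2. The matrix, its size, and the choice to rank genes by absolute loading are illustrative assumptions rather than the notebooks' exact procedure.

```python
import numpy as np

rng = np.random.default_rng(1)
g = 8  # toy number of genes

# Toy symmetric interaction matrix: one strong module on genes 0-2 plus weak noise.
module = np.zeros(g)
module[:3] = 1.0 / np.sqrt(3.0)
noise = rng.normal(scale=0.01, size=(g, g))
Q_i = 5.0 * np.outer(module, module) + 0.5 * (noise + noise.T)

vals, vecs = np.linalg.eigh(Q_i)      # eigh returns ascending eigenvalues
order = np.argsort(-np.abs(vals))     # strongest modules first, either sign
vals, vecs = vals[order], vecs[:, order]

# Eigenvector signs are arbitrary, so rank genes by absolute loading.
top_genes = np.argsort(-np.abs(vecs[:, 0]))[:3]

# The output decomposes into per-module contributions:
# x^T Q_i x = sum_k lambda_k * (v_k . x)^2
x = rng.normal(size=g)
y_direct = x @ Q_i @ x
y_modules = np.sum(vals * (vecs.T @ x) ** 2)
```

Each eigenvalue weights the squared projection of a cell onto its module, so positive eigenvalues push the output up and negative ones push it down.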

Tasks & Notebooks

Below are the core analysis tasks along with their associated notebooks and interpretations.

1. Cell Type Classification (cell_type.ipynb)

  • Dataset / Context: Developing mouse pancreas (endocrinogenesis) cells.
  • Input: Gene expression $\mathbf{x} \in \mathbb{R}^g$
  • Output: Cell type probabilities (softmax over discovered clusters)
  • Architecture: ScBMLPClassifier with cross-entropy loss
  • Interpretation: Eigendecomposition of per-class interaction matrices yields gene modules distinguishing early progenitors vs mature hormone-producing lineages.
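The classification setup above can be sketched in NumPy as follows, assuming the per-class quadratic forms act as logits that are passed through a softmax and scored with cross-entropy. The random interaction matrices, toy sizes, and labels are hypothetical and stand in for a trained model.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true class."""
    logp = log_softmax(logits)
    return -logp[np.arange(len(labels)), labels].mean()


rng = np.random.default_rng(2)
n_cells, g, n_classes = 5, 6, 3  # toy sizes

X = rng.normal(size=(n_cells, g))        # hypothetical expression matrix
Q = rng.normal(size=(n_classes, g, g))   # hypothetical per-class interaction matrices

# One logit per class: logits[n, c] = x_n^T Q[c] x_n
logits = np.einsum("na,cab,nb->nc", X, Q, X)
probs = np.exp(log_softmax(logits))

labels = np.array([0, 2, 1, 1, 0])       # made-up cell type labels
loss = cross_entropy(logits, labels)
```

Under this assumption the quadratic form describes the logit of each class, so the modules of a class's matrix are read as evidence for or against that cell type before normalization.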

2. Frequency Regression (frequency.ipynb)

  • Dataset / Context: Same pancreas developmental trajectory; graph constructed over cells.
  • Input: Gene expression $\mathbf{x} \in \mathbb{R}^g$
  • Output: Graph Laplacian eigenvector coefficients $f \in \mathbb{R}^k$ ("frequency" components)
  • Architecture: ScBMLPRegressor with MSE loss
  • Interpretation: Interaction modules stratify transcriptional programs by spatial/graph frequency: low frequencies capture broad developmental and cell cycle exit programs; higher frequencies isolate lineage-specific (e.g. α vs β) specification signals.
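One way to picture the regression targets above is sketched below: build a k-nearest-neighbor graph over cells, form its Laplacian, and take each cell's entries in the lowest-frequency non-trivial eigenvectors. The function name, the Euclidean kNN graph, the unnormalized Laplacian, and the toy trajectory are all assumptions for illustration; the notebook may construct the graph differently.

```python
import numpy as np


def laplacian_targets(X, n_neighbors=3, k=2):
    """Return the k smallest non-trivial Laplacian eigenvalues and eigenvectors of a kNN graph."""
    n = X.shape[0]
    # Pairwise squared Euclidean distances between cells.
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # a cell is not its own neighbor
    nn = np.argsort(d2, axis=1)[:, :n_neighbors]

    # Symmetric, unweighted adjacency matrix.
    A = np.zeros((n, n))
    A[np.repeat(np.arange(n), n_neighbors), nn.ravel()] = 1.0
    A = np.maximum(A, A.T)

    L = np.diag(A.sum(axis=1)) - A       # unnormalized graph Laplacian
    vals, vecs = np.linalg.eigh(L)       # ascending eigenvalues ("frequencies")
    # Skip the constant eigenvector at eigenvalue 0 (assumes a connected graph).
    return vals[1 : k + 1], vecs[:, 1 : k + 1]


# Toy data: 20 cells along a one-dimensional trajectory embedded in 4 dimensions.
rng = np.random.default_rng(3)
t = np.arange(20, dtype=float)
X = np.outer(t, np.array([1.0, 0.0, 0.0, 0.0])) + rng.normal(scale=0.01, size=(20, 4))

freqs, targets = laplacian_targets(X, n_neighbors=3, k=2)
```

Each row of `targets` would then play the role of $f \in \mathbb{R}^k$ for one cell, with small eigenvalues corresponding to smooth, trajectory-wide variation and larger ones to more local variation.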

3. Perturbation Regression (perturbation.ipynb)

  • Dataset / Context: SciPlex3 single-cell chemical screening (multi-compound, multi-dose perturbations).
  • Input: Gene expression $\mathbf{x} \in \mathbb{R}^g$
  • Output: Drug perturbation response embedding $p \in \mathbb{R}^k$ (latent representation summarizing compound + dose effect per cell)
  • Architecture: ScBMLPRegressor with MSE loss (predicting the learned/defined per-cell perturbation embedding)
  • Interpretation: Gene interaction modules reveal coordinated pathways (cell cycle control, stress response, metabolic rewiring) mediating compound- and dose-specific responses.

Usage as a package

Quick Start

import scvelo as scv
import scanpy as sc
from scbmlp.datasets import get_classification_datasets
from scbmlp.models import ScBMLPClassifier, Config
import torch
import einops

# Load pancreas endocrinogenesis data
adata = scv.datasets.pancreas()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, subset=True, n_top_genes=10000)

# Create train/validation/test datasets
train_dataset, val_dataset, test_dataset = get_classification_datasets(
    adata, class_key="clusters", random_state=42
)

# Configure and train model
cfg = Config(
    d_input=adata.n_vars,  # number of genes kept after HVG selection
    d_hidden=128,
    d_output=adata.obs["clusters"].nunique(),
    n_epochs=100,
    lr=1e-4,
    batch_size=64,
)
model = ScBMLPClassifier(cfg)
model.fit(train_dataset, val_dataset)

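# Extract gene interaction matrices for each cell type
Q = einops.einsum(
    model.w_p,
    model.w_l,
    model.w_r,
    "out hid, hid in1, hid in2 -> out in1 in2",
)
# eigh assumes a symmetric input and reads only one triangle, so keep the
# symmetric part of Q (the only part that contributes to x^T Q x)
Q = 0.5 * (Q + Q.mT)

# Analyze modules for each cell type
for cell_type_idx in range(adata.obs["clusters"].nunique()):
    vals, vecs = torch.linalg.eigh(Q[cell_type_idx])  # ascending eigenvalues
    vals, vecs = vals.flip([0]), vecs.flip([1])  # largest eigenvalue first
    # Genes with the largest loadings on the top eigenvector
    top_idx = vecs[:, 0].topk(20).indices.cpu().numpy()
    top_genes = adata.var_names[top_idx]
    print(f"Cell type {cell_type_idx}: {top_genes.tolist()}")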
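The full interaction tensor has shape (outputs, genes, genes), and at 10,000 genes a single float32 gene-by-gene matrix already takes 400 MB. A lighter alternative, sketched here in NumPy, builds one symmetrized matrix at a time. The helper name `interaction_matrix` is hypothetical, and the weight shapes are assumed to follow the einsum pattern in the Quick Start.

```python
import numpy as np


def interaction_matrix(w_p, w_l, w_r, i):
    """Symmetrized interaction matrix for output i, without building the full tensor.

    Assumed shapes: w_p (out, hid), w_l (hid, genes), w_r (hid, genes).
    """
    # Scale each hidden unit's left weights by its projection weight, then contract.
    q = (w_l.T * w_p[i]) @ w_r
    return 0.5 * (q + q.T)


# Check against the full einsum on small random weights.
rng = np.random.default_rng(4)
w_p = rng.normal(size=(3, 4))
w_l = rng.normal(size=(4, 6))
w_r = rng.normal(size=(4, 6))

Q_full = np.einsum("oh,ha,hb->oab", w_p, w_l, w_r)
Q_full = 0.5 * (Q_full + Q_full.transpose(0, 2, 1))
Q_one = interaction_matrix(w_p, w_l, w_r, i=1)
```

Looping over outputs with a helper like this keeps peak memory at roughly one gene-by-gene matrix plus its eigenvectors.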

Installation

Prerequisites

  • Python ≥3.9
  • PyTorch ≥1.13.0

From Source

git clone https://github.com/kmaherx/ScBMLP.git
cd ScBMLP
pip install -e .

Virtual Environment (Recommended)

python -m venv scbmlp-env
source scbmlp-env/bin/activate  # on Windows: scbmlp-env\Scripts\activate
pip install -e .
