[![Open In Colab](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/badge/open-in-colab.svg)](https://colab.research.google.com/github/crunchdao/quickstarters/blob/master/competitions/broad-obesity-1/quickstarters/regression-baseline/regression-baseline.ipynb)
[![Open In Kaggle](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/badge/open-in-kaggle.svg)](https://www.kaggle.com/code/crunchdao/broad-obesity-1-regression-baseline)

![Banner](https://raw.githubusercontent.com/crunchdao/quickstarters/refs/heads/master/competitions/broad-obesity-1/assets/banner.webp)

# Obesity ML Competition: Tackling metabolic diseases

## Crunch 1 – Predicting the effect of held-out single-gene perturbations

In Crunch 1, we explore how well we can predict the single-cell transcriptomic response to single-gene perturbations that were not measured and are therefore absent from the training dataset.

# Setup

The first steps to get started are:
1. Get the setup command
2. Execute it in the cell below

### >> https://hub.crunchdao.com/competitions/broad-obesity-1/submit/notebook

![Reveal token](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/animations/reveal-token.gif)

In [None]:
# Install the Crunch CLI
%pip install --upgrade crunch-cli

# Set up your local environment
!crunch setup-notebook broad-obesity-1 aaaabbbbccccddddeeeeffff

# Regression baseline

## Setup

In [None]:
# Install required dependencies
%pip install anndata scanpy

In [1]:
import gc
import os
import typing

# Import your dependencies
import anndata
import h5py
import numpy as np
import pandas as pd
import psutil
import scanpy
import scipy
import sklearn  # <1.8
from tqdm.notebook import tqdm

In [None]:
import crunch

# Load the Crunch tooling
crunch_tools = crunch.load_notebook()

## Constants

Store your global values here.

In [4]:
# @crunch/keep:on

# Label identifying negative control cells (no gene perturbation), used as the baseline reference
control_label = "NC"

## Understanding the Data

The data was downloaded when you set up your local environment and is now available in the `data/` directory.

- `obesity_challenge_1.h5ad`: The dataset includes perturbations targeting genes. For each perturbation, we provide **single-cell gene expression** (RNA-seq) profiles measured at day 14 of adipocyte differentiation, annotated with gene perturbation identity, quality control (QC) metrics, and cell metadata. The training dataset contains a subset of these perturbations, while **a distinct set of single-gene perturbations is held out for the validation and test phases of the leaderboard**.
- `obesity_challenge_1_local_gtruth.h5ad`: A local test set containing single-gene perturbations that do not appear in the training file. This dataset allows you to evaluate your method offline before submitting predictions.
- `predict_perturbations.txt`: A list of **2,863 unseen single-gene perturbations** for which you must predict the transcriptomic effect. These constitute the validation and test gene targets used on the official leaderboard.
- `genes_to_predict.txt`: A list of gene names (columns) to include in your predictions. Your model must generate expression values for this specific set of genes and the order of columns in the final prediction file must follow this list. The set of genes may change between validation and test phases.
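Because the column order of the prediction file must follow `genes_to_predict.txt`, it can help to reorder the columns explicitly before saving. The sketch below is only an illustration: it uses an in-memory stand-in for the file and a toy prediction table, and the gene names (`GENE_A`, ...) and values are placeholders rather than competition data.

```python
import io

import pandas as pd

# Stand-in for data/genes_to_predict.txt (one gene name per line); toy content.
genes_file = io.StringIO("GENE_B\nGENE_A\nGENE_C\n")
genes_to_predict = pd.read_csv(genes_file, header=None)[0].tolist()

# Toy predictions: columns are in a different order and include an extra gene.
predictions = pd.DataFrame(
    [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
    columns=["GENE_A", "GENE_B", "GENE_C", "GENE_D"],
)

# Fail early if a required gene is missing, then enforce the required order.
missing = [g for g in genes_to_predict if g not in predictions.columns]
if missing:
    raise ValueError(f"Missing predicted genes: {missing}")
ordered = predictions[genes_to_predict]
```

Selecting columns with the list itself both drops extra genes and fixes the order, so the same code keeps working if the gene list changes between the validation and test phases.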

In [5]:
def load_data(
    data_directory_path: str = "data",
):
    # backed="r" opens the file in read-only backed mode: X stays on disk instead of being fully loaded into memory
    adata_train = scanpy.read_h5ad(os.path.join(data_directory_path, "obesity_challenge_1.h5ad"), backed="r")

    return adata_train

In [6]:
# Load training data from the 'obesity_challenge_1.h5ad' file
adata_train = load_data()

### Understanding `adata_train`

- The dataset is provided in [`AnnData` format](https://anndata.readthedocs.io/en/stable/) (.h5ad).

- Normalized gene expression values are stored in `adata.X` after per-cell total count normalization followed by $\log_2(1 + x)$ transformation (standard single-cell RNA-seq normalization; see [lecture 2 of the crash course](https://docs.crunchdao.com/competitions/competitions/broad-obesity/crash-course#lecture-2)).

- Raw gene expression counts prior to normalization are stored in `adata.layers['counts']` for reproducibility and alternative preprocessing.

- The perturbation target gene information is provided in `adata.obs['gene']`, with values corresponding to either `"NC"` for control cells or to the target gene name if the cell is perturbed. Control cells receive a perturbation that has no effect on the cell’s RNA-Seq profile.

- Cell state/program enrichment information is provided in `.obs`, with binary columns `pre_adipo`, `adipo`, `lipo`, and `other`. The first three indicate whether each cell was enriched for the pre-adipocyte, adipocyte, or lipogenic program; `other` marks cells that were not enriched for either the pre-adipocyte or the adipocyte program. **Program enrichment assignments were based on expert-curated canonical signature genes, and the list of signature genes is provided in `signature_genes.csv`.**

- We provide the cell state proportion for each of the perturbations in a separate file `program_proportion.csv`.

- During preprocessing, standard single-cell quality control (QC) was applied to remove low-quality cells and cell doublets based on sequencing library complexity, gene detection rate, and mitochondrial gene content. The dataset was then restricted to cells with a single confident guide assignment to a perturbation, and guides represented by fewer than 10 cells were excluded. Genes detected in fewer than 10 cells were removed, and known signature genes from `signature_genes.csv` were subsequently re-introduced.
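To make the gene-level filter described above concrete, here is a minimal sketch on a toy count matrix. It is not the organizers' preprocessing pipeline: the matrix, the gene names, and the `signature_genes` set are hypothetical, and only the threshold of 10 cells comes from the description.

```python
import numpy as np

# Toy raw count matrix: 30 cells x 4 genes (hypothetical values).
counts = np.zeros((30, 4), dtype=int)
counts[:, 0] = 1      # gene 0 detected in all 30 cells
counts[:12, 1] = 2    # gene 1 detected in 12 cells
counts[:3, 2] = 5     # gene 2 detected in only 3 cells
                      # gene 3 is never detected
gene_names = np.array(["G0", "G1", "G2", "SIG1"])
signature_genes = {"SIG1"}  # placeholder for entries of signature_genes.csv

min_cells = 10
n_cells_detected = (counts > 0).sum(axis=0)

# Keep genes detected in at least `min_cells` cells...
keep = n_cells_detected >= min_cells
# ...then re-introduce signature genes regardless of their detection rate.
keep |= np.isin(gene_names, list(signature_genes))

kept_genes = gene_names[keep].tolist()
```

One consequence worth keeping in mind is that re-introduced signature genes can be very sparse, even all-zero, in the filtered matrix.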

**The .obs columns are defined as:**

- **cell index column**: The original sample ID.

- **nCount_RNA**: The number of Unique Molecular Identifiers (UMIs) detected per cell.

- **nFeature_RNA**: Number of genes with at least one detected UMI in the cell.

- **nCount_guide**: The number of single guide RNA (sgRNA) UMIs detected per cell.

- **nFeature_guide**: The number of sgRNAs detected per cell.

- **percent.mt**: The percentage of UMIs per cell that map to mitochondrial transcripts.

- **SampleID**: The sample ID.

- **Day**: The day of sample collection.

- **num_features**: The number of guides per cell (for low multiplicity of infection (MOI) data, after quality control, only the cells with 1 guide are kept).

- **feature_call**: The guide assignment of each cell.

- **num_umis**: The number of guide UMIs per cell.

- **gene**: **The perturbation target gene** (or perturbation identity).

- **pre_adipo**, **adipo**, **lipo**, **other**: Cell state/program enrichment flags. The first three indicate whether each cell was enriched for the pre-adipocyte, adipocyte, or lipogenic program; **other** marks cells that were not enriched for either the pre-adipocyte or the adipocyte program.


**The "Cell Identity":**

- **Pre-adipocyte**: Early-stage cells that haven't differentiated yet.
- **Adipocyte**: Mature fat cells (the standard white fat).
- **Lipogenic**: Specialized fat-producing cells.
- **Other**: Cells that followed a different developmental path.

To understand how these **cell identity program proportions** were computed, please refer to the [program_analysis section in the GitHub repository](https://github.com/julielaffy/obesity-broad-ml-competition-2025?tab=readme-ov-file).
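As a rough illustration only, per-perturbation program proportions can be approximated by averaging the binary flags within each perturbation. The linked repository remains the authoritative definition, and `program_proportion.csv` may be computed differently; the `obs` table below is a toy stand-in for `adata.obs`.

```python
import pandas as pd

# Toy stand-in for adata.obs: perturbation label plus binary program flags.
obs = pd.DataFrame(
    {
        "gene": ["NC", "NC", "AATF", "AATF", "AATF", "AATF"],
        "pre_adipo": [1, 0, 0, 0, 1, 0],
        "adipo": [0, 1, 1, 0, 0, 0],
        "lipo": [0, 1, 0, 0, 0, 0],
        "other": [0, 0, 0, 1, 0, 1],
    }
)

program_cols = ["pre_adipo", "adipo", "lipo", "other"]

# The mean of a 0/1 flag within a group is the fraction of cells carrying it.
proportions = obs.groupby("gene")[program_cols].mean()
```

Because a cell can carry more than one flag (for example `adipo` and `lipo` together, as in the `adata_train.obs` preview), the proportions in a row computed this way need not sum to 1.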

In [7]:
# Metadata for the cells in adata_train
# Cells with "NC" in the "gene" column are control cells (no gene perturbation)
adata_train.obs

cell,nCount_RNA,nFeature_RNA,nCount_guide,nFeature_guide,percent.mt,SampleID,Day,num_features,feature_call,num_umis,gene,adipo,pre_adipo,other,lipo
AATF_P1_AACACACCAAGGTCCA-1,6779,2107,41,12,1.047352,TF150_1,Day14,1,AATF_P1P2_JW2,30,AATF,0,0,1,0
AATF_P1_AACGCTATCACCTAGA-1,67952,8001,119,1,3.358253,TF150_1,Day14,1,AATF_P1P2_JW3,119,AATF,0,0,1,0
AATF_P1_AAGGTTAAGATTCACG-1,52474,7595,259,4,7.298853,TF150_1,Day14,1,AATF_P1P2_JW3,256,AATF,1,0,0,0
AATF_P1_AATCGGCTCACTATGG-1,62788,7786,320,4,4.389374,TF150_1,Day14,1,AATF_P1P2_JW3,317,AATF,0,0,1,0
AATF_P1_AATCGTTAGCGATAGG-1,65373,7985,331,6,6.450675,TF150_1,Day14,1,AATF_P1P2_JW3,325,AATF,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZNF480_P7_AGTATCACATAATCGA-1,89522,8706,76,2,4.662541,TF150_7,Day14,1,ZNF480_P1P2_JW4,75,ZNF480,1,0,0,1
ZNF480_P7_ATCAAGGGTTGAGGTA-1,61802,7887,137,5,4.562959,TF150_7,Day14,1,ZNF480_P1P2_JW4,133,ZNF480,0,0,1,0
ZNF480_P7_ATCATTCCAAGCCGAC-1,85171,8599,230,7,3.454227,TF150_7,Day14,1,ZNF480_P1P2_JW1,224,ZNF480,0,1,0,0
ZNF480_P7_ATCCTGTTCCCTGGTC-1,71994,8411,92,5,3.764203,TF150_7,Day14,1,ZNF480_P1P2_JW4,88,ZNF480,1,0,0,1


### Subset of X_train

Normalized gene expression values are stored in `adata.X` after per-cell total count normalization followed by $\log_2(1 + x)$ transformation (standard single-cell RNA-seq normalization).
Use `adata.layers['counts']` instead of `adata.X` if you want to use unnormalized expression.
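If you start from `adata.layers['counts']`, the normalization described above can be sketched in plain NumPy as follows. The scale factor is not stated in this notebook, so `target_sum = 1e4` is an assumption (a common default), and the count matrix is a toy example.

```python
import numpy as np

# Toy raw counts (cells x genes).
counts = np.array([[1, 3, 0], [10, 0, 10]], dtype=float)

# ASSUMPTION: the scale factor used for this dataset is not documented here.
target_sum = 1e4

# Per-cell total count normalization: every cell sums to `target_sum`.
totals = counts.sum(axis=1, keepdims=True)
normalized = counts / totals * target_sum

# log2(1 + x) transformation, which maps zero counts to zero.
log_normalized = np.log2(1.0 + normalized)
```

Zero counts stay exactly zero after the transformation, so the normalized matrix remains as sparse as the raw one.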

In [8]:
number_cells = 100
print(f"Subset of X_train: {number_cells} cells")
pd.DataFrame(adata_train.X[:number_cells, :].toarray(), index=adata_train.obs.index[:number_cells], columns=adata_train.var.index)

Subset of X_train: 100 cells


cell,MEX3D,LNCTAM34A,NIBAN1,ARHGEF28,PER2,KCNJ2,HDAC7,KCNIP2-AS1,ARMC7,NLK,...,LY86-AS1,RNASEL,TTC29,RFC5,ZNF517,DDX60,MYEOV,C1orf87,ASF1B,FLAD1
AATF_P1_AACACACCAAGGTCCA-1,0.000000,0.0,3.977611,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.000000
AATF_P1_AACGCTATCACCTAGA-1,0.000000,0.0,1.979513,1.305562,0.000000,0.0,1.305562,0.000000,0.000000,2.437068,...,0.000000,0.000000,0.0,0.000000,0.000000,1.305562,0.0,0.0,0.0,0.000000
AATF_P1_AAGGTTAAGATTCACG-1,2.266591,0.0,0.000000,0.000000,0.000000,0.0,1.538997,1.538997,1.538997,3.108306,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.0,1.538997
AATF_P1_AATCGGCTCACTATGG-1,1.374547,0.0,1.374547,0.000000,0.000000,0.0,2.065479,0.000000,0.000000,1.374547,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.0,2.065479
AATF_P1_AATCGTTAGCGATAGG-1,1.339050,0.0,2.831753,0.000000,0.000000,0.0,0.000000,0.000000,1.339050,2.482730,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.0,1.339050
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AATF_P2_CGAATTTAGCCTCTAT-1,0.000000,0.0,4.422914,1.094992,0.000000,0.0,0.000000,0.000000,1.094992,1.710277,...,0.000000,1.094992,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.0,1.710277
AATF_P2_CGGGTTTAGCCAACTA-1,0.000000,0.0,0.000000,0.000000,0.000000,0.0,2.448929,0.000000,0.000000,1.314233,...,0.000000,0.000000,0.0,1.990375,1.314233,1.314233,0.0,0.0,0.0,1.990375
AATF_P2_CGTGGCTGTGACGAAC-1,0.923467,0.0,3.030926,0.923467,0.923467,0.0,0.923467,0.000000,0.923467,1.481989,...,0.923467,0.000000,0.0,0.000000,1.481989,0.000000,0.0,0.0,0.0,1.883622
AATF_P2_CTCCTGATCGCATCCC-1,1.844848,0.0,1.199170,0.000000,0.000000,0.0,1.844848,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.000000


In [9]:
# Convert the `.X` matrix of the AnnData object to a dense NumPy array
def to_array(
    adata: anndata.AnnData,
) -> np.typing.NDArray[np.float64]:
    if isinstance(adata.X, np.ndarray):
        return adata.X

    return adata.X.toarray()

In [10]:
# If you want to convert to a NumPy array
# X_train_values = to_array(adata_train)
# X_train_values.shape

print(f"Train set: {adata_train.X.shape[0]} cells and {adata_train.X.shape[1]} gene columns")

adata_train.X

Train set: 44846 cells and 11046 gene columns


CSRDataset: backend hdf5, shape (44846, 11046), data_dtype float64

## Prepare Local Train/Test Datasets

To speed up experimentation, the **small** setting (`--size small`) uses a reduced dataset (~4 GB) and a predefined split:

- **Local train:** `obesity_challenge_1.h5ad` (44846 cells, 11046 gene columns and 123 unique perturbed genes) <br />
  Contains single-cell expression profiles for a subset of gene perturbations used for training.

- **Local test:** `obesity_challenge_1_local_gtruth.h5ad` <br />
  A held-out subset of perturbations used locally to evaluate model performance.

If you prefer, you can create your **own train/test split** or run **cross-validation** by switching to the full dataset:

- **Full dataset mode:** `--size default` (88202 cells, 21592 gene columns and 123 unique perturbed genes) <br />
  Downloads the entire available dataset (~18 GB), allowing you to design your own evaluation strategy.
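One possible way to design such an evaluation is a k-fold split at the perturbation level, so that every test perturbation is unseen during training, as on the leaderboard. The helper `perturbation_kfold` below is a hypothetical illustration and not part of the Crunch tooling; it always keeps control cells on the training side.

```python
import numpy as np


def perturbation_kfold(perturbations, n_splits=5, control_label="NC", seed=42):
    """Yield (train_genes, test_genes) pairs; control stays in every train fold."""
    genes = sorted(set(perturbations) - {control_label})
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(genes).tolist()

    for fold in np.array_split(shuffled, n_splits):
        test_genes = fold.tolist()
        held_out = set(test_genes)
        train_genes = [g for g in shuffled if g not in held_out]
        yield train_genes + [control_label], test_genes


# Toy perturbation labels, e.g. taken from adata.obs["gene"].unique().
labels = ["NC", "AATF", "ABT1", "AFF1", "ANKRA2", "BCL6", "BDP1"]
folds = list(perturbation_kfold(labels, n_splits=3))
```

Each `(train_genes, test_genes)` pair can then be used to mask cells with `adata.obs["gene"].isin(...)`.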

In [11]:
def prepare_local_split_adata(
    adata,
    data_dir="data",
    control_label="NC",
    filter_genes=True,
    filter_cells=True,
    number_cells_per_gene=100,
    test_split_mode="fixed",
    fixed_test_genes=None,
    test_ratio=0.2,
    genes_to_predict_file="genes_to_predict.txt",
    perturbations_file="predict_perturbations.txt",
    random_seed=42,
):
    """
    Prepares local training and test AnnData objects.

    Parameters
    ----------
    adata : AnnData
        Input AnnData object.
    data_dir : str
        Directory to read/write data files.
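    control_label : str
        Label for control cells.
    filter_genes : bool
        If True, keep only genes that are listed in the prediction files
        or observed as perturbations.
    filter_cells : bool
        If True, subsample perturbed cells per gene and control cells proportionally.
    number_cells_per_gene : int
        Maximum number of cells to keep per gene.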
    test_split_mode : str
        "fixed" or "ratio" for test/train split.
    fixed_test_genes : list
        List of test genes if test_split_mode="fixed".
    test_ratio : float
        Ratio of genes to keep as test if test_split_mode="ratio".
    genes_to_predict_file : str
        Filename for genes (columns) to predict (one gene per line).
    perturbations_file : str
        Filename for predicted perturbations (one gene per line).
    random_seed : int
        Seed for reproducible random splits.

    Returns
    -------
    local_adata_train : AnnData
        Filtered training AnnData.
    test_adata_eval : AnnData
        Filtered test AnnData with additional info in `uns`.
    test_genes: list
        List of test genes
    """

    if filter_genes:
        # Load predicted perturbations and genes to predict
        predict_perturbations = set(pd.read_csv(os.path.join(data_dir, perturbations_file), header=None)[0])
        observed_genes = set(adata.obs["gene"].unique()) - {control_label}
        genes_to_predict = set(pd.read_csv(os.path.join(data_dir, genes_to_predict_file), header=None)[0])

        # Keep genes in any of genes_to_predict, observed, or predicted perturbations
        genes_to_keep = genes_to_predict | observed_genes | predict_perturbations
        genes_to_keep = list(genes_to_keep & set(adata.var.index))
    else:
        genes_to_keep = adata.var.index

    print("Number of genes kept:", len(genes_to_keep))

    if filter_cells:
        # Separate control and perturbed cells
        control_mask = adata.obs["gene"] == control_label
        perturbed_mask = ~control_mask

        # Compute number of control cells to keep based on original ratio
        n_control_total = control_mask.sum()
        n_perturbed_total = perturbed_mask.sum()
        control_fraction = n_control_total / (n_control_total + n_perturbed_total)

        # Subsample perturbed cells: up to number_cells_per_gene per gene
        perturbed_cells_to_keep = adata.obs[perturbed_mask].groupby("gene", observed=True).head(number_cells_per_gene).index
        n_perturbed_keep = len(perturbed_cells_to_keep)
        n_control_keep = int(control_fraction * (n_perturbed_keep / (1 - control_fraction)))

        # Randomly sample control cells proportionally (seeded with random_seed for reproducibility)
        control_cells_to_keep = adata.obs[control_mask].sample(n=min(n_control_keep, n_control_total), random_state=random_seed).index

        # Combine
        cells_to_keep = perturbed_cells_to_keep.union(control_cells_to_keep)
        print(f"Number of perturbed cells kept: {len(perturbed_cells_to_keep)}")
        print(f"Number of control cells kept: {len(control_cells_to_keep)}")
    else:
        cells_to_keep = adata.obs.index

    print("Total cells kept:", len(cells_to_keep))

    adata_filtered = adata[cells_to_keep, genes_to_keep].copy()

    # Create train/test gene split
    genes = adata_filtered.obs["gene"].unique().tolist()
    if test_split_mode == "fixed":
        if fixed_test_genes is None:
            fixed_test_genes = []
        test_genes = fixed_test_genes
        train_genes = [g for g in genes if g not in test_genes]
    elif test_split_mode == "ratio":
        np.random.seed(random_seed)
        genes.remove(control_label)
        genes = np.random.permutation(genes).tolist()
        n_test = int(len(genes) * test_ratio)
        test_genes = genes[:n_test]
        train_genes = genes[n_test:] + [control_label]
    else:
        raise ValueError(f"Unknown test_split_mode: {test_split_mode}")

    print("Num train genes:", len(train_genes))
    print("Num test genes:", len(test_genes))

    # Filter cells by gene split
    local_adata_train = adata_filtered[adata_filtered.obs["gene"].isin(train_genes)].copy()
    local_adata_test = adata_filtered[adata_filtered.obs["gene"].isin(test_genes)].copy()

    # Convert X to dense if needed
    train_X = to_array(local_adata_train)
    test_X = to_array(local_adata_test)

    # Compute control and perturbed centroids from train
    control_mask = local_adata_train.obs["gene"] == control_label
    assert control_mask.sum() > 0, "No control cells found in train set"
    control_centroid = train_X[control_mask].mean(axis=0)
    perturbed_centroid = train_X[~control_mask].mean(axis=0)

    # Prepare test AnnData with additional info in `uns`
    test_adata_eval = anndata.AnnData(
        X=test_X,
        obs=local_adata_test.obs.copy(),
        uns=dict(
            control_centroid_train=control_centroid,
            perturbed_centroid_train=perturbed_centroid
        )
    )

    print("local_adata_train:\n", local_adata_train)
    print("test_adata_eval:\n", test_adata_eval)

    # Save results
    local_adata_train.write(os.path.join(data_dir, "local_train.h5ad"))
    test_adata_eval.write(os.path.join(data_dir, "local_test_eval.h5ad"))
    with open(os.path.join(data_dir, "local_test_perturbated_genes.txt"), "w") as f:
        f.write("\n".join(test_genes))

    return local_adata_train, test_adata_eval, test_genes

You can use `prepare_local_split_adata` to create your own train/test split or cross-validation folds.

For this quickstarter, we will use the provided split:
- local train: `obesity_challenge_1.h5ad`
- and local test: `obesity_challenge_1_local_gtruth.h5ad`

In [12]:
# local_adata_train, test_adata_eval, test_genes = prepare_local_split_adata(
#     adata_train,
#     filter_genes=True,
#     filter_cells=True,
#     number_cells_per_gene=50,
#     test_split_mode="ratio",
#     test_ratio=0.2,
# )

## Strategy Implementation: regression baseline - gene profiles

In this section, we implement a **perturbation-level regression baseline** that predicts gene expression responses and expands them to the single-cell level.

The approach proceeds as follows:

- **Aggregate training data at the perturbation level** by computing the mean expression profile of each gene perturbation. This yields a compact representation that avoids operating on the full single-cell matrix.

- **Learn a low-dimensional representation** of perturbation effects using Principal Component Analysis (**PCA**), followed by **a regularized linear regression** to model gene expression responses across perturbations.

- **Predict perturbation-level expression profiles** for unseen target genes using the trained PCA–regression model.

- **Expand perturbation-level predictions to the cell level** by repeating each predicted profile a fixed number of times, corresponding to synthetic cells.

- **Inject gene-wise Gaussian noise**, with variance estimated from the training perturbation means, to approximate biological variability while preserving scalability.

- Package the resulting synthetic cells into an AnnData object with the appropriate observation and variable metadata.
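The expansion and noise-injection steps listed above can be sketched on toy data as follows. This is a simplified illustration, not the exact code used later in the notebook: the array shapes, the perturbation names, `n_cells`, and the clipping at zero are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy perturbation-level training means (genes x perturbations).
Y_train = rng.random((5, 8))
# Toy predicted profiles for 2 unseen perturbations (perturbations x genes).
predicted_means = rng.random((2, 5))
predicted_names = ["PERT_A", "PERT_B"]

n_cells = 3  # synthetic cells per perturbation (hypothetical value)

# Gene-wise standard deviation estimated across training perturbation means.
gene_std = Y_train.std(axis=1)

# Repeat each predicted profile n_cells times, then add gene-wise Gaussian noise.
expanded = np.repeat(predicted_means, n_cells, axis=0)
noise = rng.normal(loc=0.0, scale=gene_std, size=expanded.shape)

# ASSUMPTION: clip at zero because log-normalized expression is non-negative.
synthetic_cells = np.clip(expanded + noise, 0.0, None)
cell_labels = np.repeat(predicted_names, n_cells)
```

The variance here comes from differences between perturbation means, not between cells, so the synthetic cells will generally be less variable than real single cells.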

In [13]:
def compute_perturbation_means(
    adata_train: anndata.AnnData,
    gene_col: str = "gene",
    n_cells_per_perturbation: int | None = None,
):
    """
    Compute mean expression per perturbation without loading full X into memory.

    Returns
    -------
    Y_train : pd.DataFrame
        Genes × perturbations matrix
    """

    perturbations = adata_train.obs[gene_col].astype("category")
    gene_names = adata_train.var_names

    perturbation_means = {}

    for pert in tqdm(perturbations.cat.categories):
        pert_mask = perturbations == pert

        if pert_mask.sum() == 0:
            continue

        # Optional downsampling: keep the first n cells (deterministic, not a random sample)
        if n_cells_per_perturbation is not None:
            idx = (
                adata_train.obs.loc[pert_mask]
                .head(n_cells_per_perturbation)
                .index
            )
            X_pert = adata_train[idx].X
        else:
            X_pert = adata_train[pert_mask].X

        # Compute mean safely for sparse / dense
        pert_mean = np.asarray(X_pert.mean(axis=0)).ravel()
        perturbation_means[pert] = pert_mean

    # Build genes × perturbations DataFrame
    Y_train = pd.DataFrame(
        perturbation_means,
        index=gene_names
    )

    return Y_train

Compute the mean expression profile of each gene perturbation:

In [14]:
# ======================================================
# STEP 1: Aggregate Training Data per Perturbation
# ======================================================
Y_train_gene_exp = compute_perturbation_means(adata_train, n_cells_per_perturbation=300)
Y_train_gene_exp # shape: (n_genes, n_perturbations)


gene,AATF,ABT1,AFF1,ANKRA2,BCL6,BDP1,BRCA1,BTAF1,BTG2,CBFB,...,ZNF141,ZNF146,ZNF215,ZNF254,ZNF283,ZNF331,ZNF334,ZNF354B,ZNF419,ZNF480
MEX3D,0.801940,0.648885,0.717339,0.719150,0.628722,0.659352,0.772598,0.663117,0.772341,0.707593,...,0.742868,0.726407,0.717818,0.585405,0.735812,0.809255,0.697943,0.748158,0.659558,0.684653
LNCTAM34A,0.221202,0.147458,0.176294,0.220083,0.119103,0.157657,0.210497,0.185567,0.223767,0.233119,...,0.197217,0.230852,0.204709,0.188129,0.174588,0.242954,0.202724,0.181750,0.264274,0.228765
NIBAN1,1.800681,1.783491,1.840822,1.517416,1.684059,1.704172,1.661471,1.736631,1.696868,2.050926,...,1.796720,1.714299,1.588454,1.696785,1.908280,1.819232,1.803758,1.832787,1.918774,1.777491
ARHGEF28,0.410630,0.408500,0.390281,0.520536,0.440210,0.545472,0.278301,0.501184,0.475083,0.422723,...,0.371797,0.414036,0.529561,0.406471,0.477613,0.397199,0.519653,0.379641,0.377131,0.400610
PER2,0.199727,0.232834,0.254728,0.169376,0.113593,0.191881,0.230349,0.246556,0.217533,0.211207,...,0.244051,0.199651,0.183931,0.210403,0.173499,0.187054,0.211336,0.146639,0.192058,0.184599
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DDX60,0.306737,0.328172,0.407768,0.308033,0.312921,0.302626,0.269351,0.297162,0.345882,0.357660,...,0.299636,0.303833,0.318676,0.303219,0.301906,0.304018,0.306228,0.339541,0.317460,0.279997
MYEOV,0.008205,0.000000,0.009504,0.000000,0.000000,0.004700,0.000000,0.004732,0.003927,0.004313,...,0.000000,0.000000,0.006178,0.004357,0.008768,0.014092,0.009516,0.000000,0.015788,0.004576
C1orf87,0.015403,0.007291,0.013503,0.004594,0.012991,0.019702,0.015203,0.012066,0.000000,0.004294,...,0.016985,0.004006,0.005066,0.004372,0.011247,0.004540,0.004598,0.000000,0.000000,0.004863
ASF1B,0.014893,0.037829,0.044730,0.040228,0.043262,0.031503,0.019173,0.025893,0.062201,0.042065,...,0.032756,0.036437,0.074029,0.024444,0.027514,0.033922,0.036754,0.031687,0.040896,0.036570


In [15]:
# ======================================================
# STEP 2: PCA on Training Data
# ======================================================
from sklearn.decomposition import PCA

def prepare_pca_for_regression(
    Y_train: pd.DataFrame,
    control_label: str = "NC",
    n_components: int = 10,
    display_output: bool = True
):
    """
    Perform PCA on the training data and prepare perturbation-level matrices for regression.

    Parameters
    ----------
    Y_train : pd.DataFrame
        Genes × perturbations matrix (aggregated per perturbation).
    control_label : str
        Label of the control perturbation to exclude.
    n_components : int
        Number of PCA components.
    display_output : bool
        Whether to display PCA matrices and Y_train for sanity check.

    Returns
    -------
    pca_genes : pd.DataFrame
        Full PCA representation of all genes.
    pca_perturbations : pd.DataFrame
        PCA representation of selected perturbations used for regression.
    Y_train_subset : pd.DataFrame
        Subset of Y_train columns matching selected perturbations.
    gene_mean_vector : pd.Series
        Mean expression per gene across selected perturbations.
    """
    # ------------------------------------------------------
    # PCA on genes × perturbations
    # ------------------------------------------------------
    pca = PCA(n_components=n_components)
    principal_components = pca.fit_transform(Y_train)
    pca_df = pd.DataFrame(principal_components, index=Y_train.index)

    # ------------------------------------------------------
    # Prepare perturbation list for regression
    # Exclude control and keep only perturbations that are also genes
    # ------------------------------------------------------
    train_perturbations = Y_train.columns.values.tolist()
    if control_label in train_perturbations:
        train_perturbations.remove(control_label)
    train_perturbations = [ptb for ptb in train_perturbations if ptb in Y_train.index]

    # Full PCA representation of all genes
    pca_genes = pca_df.copy()

    # PCA representation of perturbations used for regression
    pca_perturbations = pca_genes.loc[train_perturbations]

    # Subset Y_train to match selected perturbations
    Y_train_subset = Y_train[pca_perturbations.index]

    # Gene-wise mean across selected perturbations
    gene_mean_vector = Y_train_subset.mean(axis=1)

    # Sanity check
    assert set(pca_perturbations.index) == set(Y_train_subset.columns)

    # Optional display
    if display_output:
        print("PCA representation of perturbations (rows=perturbation genes, columns=PCA components)")
        print("(It plays the role of X_train for regression)")
        display(pca_perturbations)
        print("\nY_train:")
        display(Y_train_subset.T)

    return pca_genes, pca_perturbations, Y_train_subset, gene_mean_vector

pca_genes, pca_perturbations, Y_train_gene_exp_subset, gene_mean_vector = prepare_pca_for_regression(
    Y_train_gene_exp,
    control_label="NC",
    n_components=10,
    display_output=True
)

PCA representation of perturbations (rows=perturbation genes, columns=PCA components)
(It plays the role of X_train for regression)


gene,0,1,2,3,4,5,6,7,8,9
AATF,8.604356,-0.074695,0.069683,0.033321,0.219478,-0.029004,0.029453,-0.113226,0.000268,0.000304
ABT1,-2.100445,-0.032756,-0.062296,-0.047705,-0.066221,-0.007291,0.044781,0.111082,0.032072,-0.002289
AFF1,4.581484,-0.246102,0.101982,-0.138330,0.174381,-0.017872,0.258326,0.107344,0.104765,-0.022510
ANKRA2,0.106986,0.007098,-0.095608,-0.080267,0.321377,-0.103199,0.076587,0.023118,-0.105169,-0.075170
BCL6,17.916299,0.261334,0.145843,0.134444,0.036186,0.195170,-0.057083,-0.329898,0.019132,-0.058063
...,...,...,...,...,...,...,...,...,...,...
ZNF331,-1.604509,-0.174777,-0.155737,-0.217065,-0.091361,0.000113,-0.120636,-0.005180,-0.088044,-0.198560
ZNF334,-8.915408,-0.007816,-0.030514,-0.006382,0.005296,0.003854,-0.109315,0.161640,0.099072,-0.067741
ZNF354B,-4.108700,0.619342,0.084007,-0.202947,0.042006,0.029707,0.228061,-0.079778,0.072382,-0.072972
ZNF419,-11.719326,-0.024960,-0.016346,0.025844,-0.033803,0.029082,0.008869,0.011603,0.038593,-0.027275



Y_train:


gene,MEX3D,LNCTAM34A,NIBAN1,ARHGEF28,PER2,KCNJ2,HDAC7,KCNIP2-AS1,ARMC7,NLK,...,LY86-AS1,RNASEL,TTC29,RFC5,ZNF517,DDX60,MYEOV,C1orf87,ASF1B,FLAD1
AATF,0.801940,0.221202,1.800681,0.410630,0.199727,0.063945,1.643240,0.196687,0.259148,1.630660,...,0.249752,0.235438,0.004822,0.390048,0.118791,0.306737,0.008205,0.015403,0.014893,1.163528
ABT1,0.648885,0.147458,1.783491,0.408500,0.232834,0.107059,1.608334,0.202129,0.219984,1.639126,...,0.444420,0.255868,0.004065,0.364859,0.104723,0.328172,0.000000,0.007291,0.037829,1.197529
AFF1,0.717339,0.176294,1.840822,0.390281,0.254728,0.070754,1.666717,0.149300,0.240083,1.547383,...,0.350297,0.277749,0.005151,0.400475,0.190977,0.407768,0.009504,0.013503,0.044730,1.187638
ANKRA2,0.719150,0.220083,1.517416,0.520536,0.169376,0.067999,1.556885,0.253034,0.267486,1.711355,...,0.380542,0.245874,0.010580,0.322367,0.145086,0.308033,0.000000,0.004594,0.040228,1.190062
BCL6,0.628722,0.119103,1.684059,0.440210,0.113593,0.052900,1.597000,0.205062,0.300392,1.706626,...,0.275024,0.167642,0.003827,0.308869,0.113717,0.312921,0.000000,0.012991,0.043262,1.268831
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZNF331,0.809255,0.242954,1.819232,0.397199,0.187054,0.088170,1.641512,0.185611,0.255814,1.773537,...,0.199603,0.244482,0.009115,0.373115,0.187955,0.304018,0.014092,0.004540,0.033922,1.092434
ZNF334,0.697943,0.202724,1.803758,0.519653,0.211336,0.050217,1.713922,0.255777,0.289250,1.756987,...,0.215579,0.262741,0.008709,0.363784,0.202647,0.306228,0.009516,0.004598,0.036754,1.146263
ZNF354B,0.748158,0.181750,1.832787,0.379641,0.146639,0.053729,1.709781,0.148526,0.220272,1.606879,...,0.195656,0.222565,0.000000,0.370311,0.198948,0.339541,0.000000,0.000000,0.031687,1.144783
ZNF419,0.659558,0.264274,1.918774,0.377131,0.192058,0.041497,1.770884,0.187163,0.214975,1.621950,...,0.238976,0.207856,0.000000,0.358737,0.118337,0.317460,0.015788,0.000000,0.040896,1.021013
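The key idea in this step is that PCA is run with genes as samples, so every measured gene gets an embedding, and a perturbation's feature vector is simply the embedding of the gene it targets. The sketch below reproduces this on toy data using an SVD in NumPy instead of `sklearn.decomposition.PCA` (the two agree up to the sign of each component); the matrix values are random placeholders.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

genes = ["AATF", "ABT1", "BCL6", "MEX3D", "PER2", "NLK"]
perturbations = ["AATF", "ABT1", "BCL6"]  # each one targets a measured gene

# Toy genes x perturbations matrix of mean expression.
Y = pd.DataFrame(
    rng.random((len(genes), len(perturbations))),
    index=genes,
    columns=perturbations,
)

# PCA via SVD: rows (genes) are samples, columns (perturbations) are features.
n_components = 2
centered = Y.values - Y.values.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
gene_embedding = pd.DataFrame(U[:, :n_components] * S[:n_components], index=genes)

# A perturbation's features are the embedding of the gene it targets.
perturbation_features = gene_embedding.loc[perturbations]
```

Under this construction, an unseen perturbation can be featurized as long as its target gene is a row of the expression matrix, which is why perturbations that are not also genes are dropped.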


In [16]:
# ======================================================
# STEP 3: Fit Linear Regression in PCA Space (W Matrix)
# ======================================================
# Goal: Learn a linear mapping from perturbation features (pca_perturbations)
# to gene expression in PCA space. This W matrix will be used to predict
# the expression profiles of new perturbations.

def fit_ridge_pca_regression(
    Y_train: pd.DataFrame,
    G: pd.DataFrame,
    pca_perturbations: pd.DataFrame,
    mean_vector: pd.Series,
    ridge_alpha: float = 0.1
):
    """
    Fit a linear Ridge regression model in PCA space to learn a mapping
    from perturbation features to output profiles.

    The output profiles can represent either gene expression levels or program-level proportions.

    Parameters
    ----------
    Y_train : pd.DataFrame
        Training output matrix for selected perturbations.
        Shape: (n_outputs × n_perturbations),
        where outputs are genes or programs.
    G : pd.DataFrame
        PCA representation of outputs (genes or programs).
        Shape: (n_outputs × n_pca_components).
    pca_perturbations : pd.DataFrame
        PCA representation of perturbations used for regression.
        Shape: (n_perturbations × n_pca_components).
    mean_vector : pd.Series
        Mean output value (expression or proportion) per output,
        used for centering during training.
    ridge_alpha : float
        Ridge regularization strength (λ).

    Returns
    -------
    W : np.ndarray
        Regression weight matrix mapping perturbation features
        to output profiles in PCA space.
    """
    # Convert DataFrames/Series to NumPy arrays for matrix operations
    Y_train_mtrx = Y_train.values                  # genes × perturbations
    G_mtrx = G.values                              # genes × PCA components
    P_mtrx = pca_perturbations.values              # perturbations × PCA components
    b_mtrx = mean_vector.values               # mean expression per gene

    # Center Y by subtracting per-gene mean
    Y_centered = Y_train_mtrx - b_mtrx[:, np.newaxis]

    # Compute W using closed-form Ridge regression in PCA space
    # Formula: W = (G^T G + λ I)^-1 * G^T * (Y - b) * P * (P^T P + λ I)^-1
    I_G = np.eye(G_mtrx.shape[1])
    GtG_inv = np.linalg.inv(G_mtrx.T @ G_mtrx + ridge_alpha * I_G)

    Gt_Y_b = G_mtrx.T @ Y_centered

    I_P = np.eye(P_mtrx.shape[1])
    PtP_inv = np.linalg.inv(P_mtrx.T @ P_mtrx + ridge_alpha * I_P)

    W = GtG_inv @ Gt_Y_b @ P_mtrx @ PtP_inv

    return W

W = fit_ridge_pca_regression(
    Y_train=Y_train_gene_exp_subset,
    G=pca_genes,
    pca_perturbations=pca_perturbations,
    mean_vector=gene_mean_vector,
    ridge_alpha=0.1
)
print("W shape:", W.shape)

W shape: (10, 10)
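
As a sanity check on the closed-form expression used in `fit_ridge_pca_regression`, the following minimal sketch rebuilds the same formula with plain NumPy on synthetic data. The helper name `ridge_w_closed_form`, the toy sizes, and the random matrices are assumptions made for illustration and are not part of the notebook's pipeline. When `Y` is generated exactly as `G @ W_true @ P.T + b` and λ is close to zero, the formula should recover `W_true`.

```python
import numpy as np


def ridge_w_closed_form(Y, G, P, b, ridge_alpha):
    """W = (GᵀG + λI)⁻¹ Gᵀ (Y − b) P (PᵀP + λI)⁻¹ (same formula as the cell above)."""
    Y_centered = Y - b[:, np.newaxis]
    left = np.linalg.inv(G.T @ G + ridge_alpha * np.eye(G.shape[1]))
    right = np.linalg.inv(P.T @ P + ridge_alpha * np.eye(P.shape[1]))
    return left @ G.T @ Y_centered @ P @ right


# Hypothetical toy sizes, chosen only to keep the example fast
rng = np.random.default_rng(0)
n_genes, n_perts, k_g, k_p = 50, 20, 4, 3

G = rng.normal(size=(n_genes, k_g))      # outputs × components
P = rng.normal(size=(n_perts, k_p))      # perturbations × components
W_true = rng.normal(size=(k_g, k_p))     # ground-truth mapping
b = rng.normal(size=n_genes)             # per-output mean

# Noise-free data generated by the model itself
Y = G @ W_true @ P.T + b[:, np.newaxis]

W_hat = ridge_w_closed_form(Y, G, P, b, ridge_alpha=1e-10)
print("W_hat shape:", W_hat.shape)
print("max abs error:", np.abs(W_hat - W_true).max())
```

Increasing `ridge_alpha` shrinks the estimate toward zero, which trades some bias for stability when `G` or `P` is poorly conditioned.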


Load the list of gene perturbations to predict.

For the local version, we evaluate on a separate test file (`"obesity_challenge_1_local_gtruth.h5ad"`):

In [17]:
# For your local evaluation, we use a separate test file with ground truth data.
# The file contains a subset of genes for which we want to predict the perturbation effects.
gtruth = scanpy.read_h5ad(os.path.join("data", "obesity_challenge_1_local_gtruth.h5ad"), backed="r")
predict_perturbations = gtruth.obs["gene"].cat.categories.tolist()

print("Local test gene perturbations:", predict_perturbations)

Local test gene perturbations: ['CHD4', 'FOXC1', 'SOX6', 'TRIM5', 'ZBTB20']


In the full challenge, you would load the 2863 unseen perturbations via `predict_perturbations.txt`:

In [18]:
# predict_perturbations = (
#     pd.read_csv(os.path.join("data", "predict_perturbations.txt"), header=None)[0]
#     .values
# )
#
# print("All gene perturbations:", predict_perturbations)

Generate predictions using the learned W matrix.

In [19]:
# ======================================================
# STEP 4: Predict on Validation Set
# ======================================================
# Goal: Use the learned W matrix to predict the gene expression profiles
# of unseen perturbations in PCA space, then transform back to gene space.

def predict_perturbation_expression(
    W: np.ndarray,
    G: pd.DataFrame,
    pca_genes: pd.DataFrame,
    mean_vector: pd.Series,
    target_perturbations: list[str]
):
    """
    Predict gene expression profiles or program proportion for unseen perturbations using the
    learned PCA-space regression (W matrix).

    Parameters
    ----------
    W : np.ndarray
        Learned regression weight matrix mapping perturbation features
        to expression (or program) space in PCA coordinates.
    G : pd.DataFrame
        PCA representation of output variables (genes or programs),
        shape: (n_outputs × n_pca_components).
    pca_genes : pd.DataFrame
        PCA representation of all genes (genes × PCA components).
    mean_vector : pd.Series
        Per-output mean (gene expression or program proportion)
        used to re-center predictions.
    target_perturbations : list of str
        List of perturbation genes to predict.

    Returns
    -------
    Y_hat_df : pd.DataFrame
        Predicted perturbation-level expression profiles or program proportion.
        Rows = perturbations, Columns = genes or programs.
    """

    # ------------------------------------------------------
    # Select the PCA representation of target perturbations
    # Ensures the order matches target_perturbations
    # ------------------------------------------------------
    P_pred = pca_genes.loc[target_perturbations]
    P_pred_mtrx = P_pred.values

    # ------------------------------------------------------
    # Compute predictions in output space
    # Formula:
    #   Y_hat = G · W · P_predᵀ + b
    #
    # where:
    #   G       : outputs × PCA components
    #   W       : regression weights
    #   P_predᵀ : PCA components × perturbations
    #   b       : per-output mean (added back after centering)
    # ------------------------------------------------------
    G_mtrx = G.values
    b_mtrx = mean_vector.values

    Y_hat = G_mtrx @ W @ P_pred_mtrx.T
    Y_hat = Y_hat + b_mtrx[:, np.newaxis]  # Add back mean

    # Transpose so rows = perturbations, columns = genes
    Y_hat = Y_hat.T

    # Convert to DataFrame for convenience
    Y_hat_df = pd.DataFrame(
        Y_hat,
        index=P_pred.index,         # perturbation genes
        columns=G.index     # all genes or programs
    )

    return Y_hat_df

Y_hat_df_gene_exp = predict_perturbation_expression(
    W=W,
    G=pca_genes,
    pca_genes=pca_genes,
    mean_vector=gene_mean_vector,
    target_perturbations=predict_perturbations
)
Y_hat_df_gene_exp

gene,MEX3D,LNCTAM34A,NIBAN1,ARHGEF28,PER2,KCNJ2,HDAC7,KCNIP2-AS1,ARMC7,NLK,...,LY86-AS1,RNASEL,TTC29,RFC5,ZNF517,DDX60,MYEOV,C1orf87,ASF1B,FLAD1
CHD4,0.716206,0.187553,1.826635,0.439815,0.198105,0.06682,1.706065,0.1539,0.266262,1.674663,...,0.208652,0.252496,0.004654,0.35402,0.143845,0.324288,0.008178,0.008293,0.030142,1.12574
FOXC1,0.699685,0.187667,1.79336,0.44616,0.199444,0.069188,1.65267,0.180808,0.26678,1.689904,...,0.282258,0.247558,0.005098,0.351556,0.144908,0.31503,0.008511,0.009249,0.033912,1.146814
SOX6,0.708853,0.18612,1.808989,0.443491,0.199735,0.067709,1.672638,0.171374,0.267163,1.684944,...,0.257179,0.247631,0.004636,0.352057,0.144036,0.317015,0.007926,0.008678,0.033587,1.142117
TRIM5,0.697785,0.187252,1.810748,0.446584,0.206927,0.069441,1.653757,0.1848,0.268141,1.701036,...,0.298921,0.24577,0.005848,0.351831,0.148577,0.307621,0.00952,0.00967,0.036055,1.151326
ZBTB20,0.705224,0.192957,1.781921,0.43523,0.209105,0.075844,1.655918,0.185451,0.262525,1.687528,...,0.283569,0.244927,0.008896,0.352047,0.147982,0.306423,0.012669,0.012369,0.035002,1.150072
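
The prediction step is a chain of matrix products, so a quick shape check can catch orientation mistakes. The sketch below uses toy sizes and random values (all names ending in `_toy`, the gene labels, and the sizes are illustrative assumptions). It also mirrors how the function above looks up the features of each target perturbation: as that gene's own row in the gene PCA table.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Toy setup: 6 genes embedded in k = 2 PCA components (hypothetical values)
genes = [f"g{i}" for i in range(6)]
k = 2
pca_genes_toy = pd.DataFrame(rng.normal(size=(6, k)), index=genes)
W_toy = rng.normal(size=(k, k))
b_toy = pd.Series(rng.normal(size=6), index=genes)
targets = ["g4", "g1"]

# Features of the target perturbations, in the order given by `targets`
P_pred = pca_genes_toy.loc[targets]                       # (2, k)

# Y_hat = G · W · P_predᵀ + b, then transpose to perturbations × genes
Y_hat = pca_genes_toy.values @ W_toy @ P_pred.values.T    # (6, 2)
Y_hat = (Y_hat + b_toy.values[:, np.newaxis]).T           # (2, 6)

Y_hat_df = pd.DataFrame(Y_hat, index=P_pred.index, columns=genes)
print(Y_hat_df.shape)
```

Because the lookup is label-based, a target perturbation that is missing from the gene PCA table raises a `KeyError` rather than silently producing a misaligned row.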


In [20]:
# ======================================================
# STEP 5: Expand Perturbation-level Prediction to Cell-level
# ======================================================
# Goal: Generate synthetic single-cell expression profiles for each
# target perturbation by repeating the predicted perturbation-level
# profile (Y_hat_df) multiple times, optionally adding gene-wise noise.

def expand_perturbation_to_cells(
    Y_hat_df: pd.DataFrame,
    Y_train: pd.DataFrame,
    predict_perturbations: list[str],
    adata_var: pd.DataFrame,
    cells_per_perturbation: int = 100,
    add_train_std: bool = True,
    random_seed: int | None = 42
) -> anndata.AnnData:
    """
    Expand perturbation-level predictions to synthetic single-cell expression profiles.

    Parameters
    ----------
    Y_hat_df : pd.DataFrame
        Predicted perturbation-level gene expression (rows = perturbations, columns = genes).
    Y_train : pd.DataFrame
        Training perturbation-level gene expression (genes × perturbations), used to compute gene-wise std.
    predict_perturbations : list[str]
        List of perturbation genes to expand to synthetic cells.
    adata_var : pd.DataFrame
        Variable/gene information to include in the AnnData object.
    cells_per_perturbation : int
        Number of synthetic cells to generate per perturbation.
    add_train_std : bool
        Whether to add Gaussian noise based on training gene-wise std.
    random_seed : int | None
        Optional random seed for reproducibility.

    Returns
    -------
    prediction_adata : anndata.AnnData
        Synthetic single-cell predictions.
        Rows = cells, Columns = genes
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    n_genes = adata_var.shape[0]
    n_perturbations = len(predict_perturbations)

    # Initialize the prediction matrix
    prediction_matrix = np.zeros((n_perturbations * cells_per_perturbation, n_genes))

    # Compute gene-wise std from training data
    train_std = Y_train.std(axis=1).values  # shape: (n_genes,)

    # Fill prediction matrix perturbation by perturbation
    for i, pert in enumerate(predict_perturbations):
        start = i * cells_per_perturbation
        end = start + cells_per_perturbation

        if pert not in Y_hat_df.index:
            # No prediction available -> leave zeros
            continue

        # Base perturbation-level prediction
        pred_profile = Y_hat_df.loc[pert].values

        # Optionally add gene-wise Gaussian noise
        if add_train_std:
            noise = np.random.normal(
                loc=0.0,
                scale=train_std,
                size=(cells_per_perturbation, n_genes)
            )
            pred_profile = pred_profile + noise

        prediction_matrix[start:end] = pred_profile

    # Construct obs["gene"] for the AnnData object
    # Each perturbation label is repeated 'cells_per_perturbation' times
    # Total rows = n_perturbations * cells_per_perturbation
    # Example: 5 perturbations × 100 cells -> 500 rows
    obs_gene = np.repeat(predict_perturbations, cells_per_perturbation)

    # Build AnnData object
    prediction_adata = anndata.AnnData(
        X=prediction_matrix,
        obs={"gene": obs_gene},
        var=adata_var.copy()  # preserve gene metadata
    )

    return prediction_adata


prediction = expand_perturbation_to_cells(
    Y_hat_df=Y_hat_df_gene_exp,
    Y_train=Y_train_gene_exp_subset,
    predict_perturbations=predict_perturbations,
    adata_var=adata_train.var,
    cells_per_perturbation=100,
    add_train_std=True,
    random_seed=42
)

# The resulting object contains synthetic single-cell predictions
# with one row per synthetic cell and one column per gene
prediction

AnnData object with n_obs × n_vars = 500 × 11046
    obs: 'gene'

### Why do we generate 100 predicted synthetic cells per perturbation?

Single-cell RNA-seq produces **many individual cells per perturbation**, not a single expression vector. Even when the same gene is perturbed, cells show diverse responses:
- different effect magnitudes,
- alternative differentiation paths,
- alternative metabolic states,
- natural transcriptional variability.

This heterogeneity is real **biological signal**, not noise, and the challenge evaluates how well a model reproduces the distribution of cell states.

The two evaluation metrics reflect this: **MMD** compares full distributions using pairwise similarities between cells and **Pearson Delta** evaluates perturbation effects relative to perturbed means. Predicting only one cell would collapse the distribution, give unstable or biased estimates, and distort correlations.

By generating **100 (or more) predicted synthetic cells per perturbation**, we create enough samples to capture heterogeneity, support distribution-based metrics, stabilize mean/variance estimates, and realistically approximate the “cloud” of single-cell responses observed in real data.
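
To make the distribution argument concrete, here is a small sketch of a (biased) squared-MMD estimate with an RBF kernel on toy data. The kernel choice, the bandwidth `gamma`, the toy dimensions, and the helper names are assumptions for illustration; the official scorer may use a different kernel or estimator. The sketch compares a "real" cloud of cells against 100 identical copies of the mean profile and against the mean profile plus gene-wise Gaussian noise. The noise scale is set equal to the real spread, which is the favorable case.

```python
import numpy as np


def rbf_kernel(A, B, gamma):
    """Pairwise RBF similarities exp(-gamma * ||a - b||²)."""
    sq_dist = (A**2).sum(axis=1)[:, None] + (B**2).sum(axis=1)[None, :] - 2 * A @ B.T
    return np.exp(-gamma * np.maximum(sq_dist, 0.0))  # clamp tiny negative values


def mmd2(A, B, gamma=0.1):
    """Biased estimate of squared MMD between the samples in A and B."""
    return (
        rbf_kernel(A, A, gamma).mean()
        + rbf_kernel(B, B, gamma).mean()
        - 2 * rbf_kernel(A, B, gamma).mean()
    )


rng = np.random.default_rng(0)
n_cells, n_genes = 100, 20                 # hypothetical toy sizes
mean_profile = rng.normal(size=n_genes)
gene_std = np.full(n_genes, 0.5)

# "Real" cells scatter around the mean profile
real_cells = mean_profile + rng.normal(scale=gene_std, size=(n_cells, n_genes))

# Prediction A: the same profile repeated 100 times (no spread)
collapsed = np.tile(mean_profile, (n_cells, 1))

# Prediction B: the profile plus gene-wise Gaussian noise
noisy = mean_profile + rng.normal(scale=gene_std, size=(n_cells, n_genes))

print("MMD² collapsed vs real:", mmd2(collapsed, real_cells))
print("MMD² noisy vs real:    ", mmd2(noisy, real_cells))
```

In this toy setting the noisy prediction is closer to the real cloud than the collapsed one, which is the motivation for the `add_train_std` option above.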

## Strategy Implementation: Regression Baseline - Program Proportions

For program proportions, we apply the same PCA-based Ridge regression strategy used for gene-expression prediction.

Concretely:

- We load the training program proportion matrix (perturbations × programs) and transpose it to programs × perturbations to match the regression setup.

- A PCA is fitted on the program-proportion space to obtain a low-dimensional representation of programs.

- Using the PCA representation of perturbations (`pca_perturbations`) and the PCA representation of programs, we fit a Ridge regression in PCA space to learn a linear mapping (`W_program_proportion`) from perturbation features to program proportions.

- This learned mapping is then used to predict program proportions for unseen perturbations, followed by adding back the mean program-proportion vector to obtain predictions in the original space.
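
For intuition, the PCA step on a programs × perturbations matrix can be sketched with a plain SVD. This is an illustrative NumPy equivalent under the usual PCA convention (center each column, then project onto the top-K right singular vectors). The toy matrix and variable names are assumptions, and component signs may differ from what scikit-learn returns.

```python
import numpy as np

rng = np.random.default_rng(0)
n_programs, n_perts, K = 5, 12, 3          # hypothetical toy sizes

# Toy programs × perturbations matrix (rows are the "samples" for PCA)
Y = rng.random(size=(n_programs, n_perts))

# Center each column (feature), as PCA does before decomposing
Y_centered = Y - Y.mean(axis=0, keepdims=True)

# Thin SVD: Y_centered = U · diag(S) · Vt
U, S, Vt = np.linalg.svd(Y_centered, full_matrices=False)

# PCA scores of the programs on the first K components
program_embedding = U[:, :K] * S[:K]       # equals Y_centered @ Vt[:K].T
print(program_embedding.shape)
```

With only five programs, at most four components carry variance after centering, so a small `K` is a natural choice here.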

In [21]:
# Load the training program proportion matrix (perturbations × programs)
train_program_proportion = pd.read_csv(os.path.join("data", "program_proportion.csv"), index_col="gene")
# Note: use the "default" version of the dataset (not the small one) to get an exact match with program_proportion.csv.
# Transpose to programs × perturbations to match the regression setup
Y_train_program_proportion = train_program_proportion.T
Y_train_program_proportion

gene,HIF1A,ZNF26,TWIST1,CEBPA,HIF3A,TSHZ2,TRRAP,RB1,ZFP2,NC,...,PPARD,ZHX3,PDCD11,SUPT5H,NAB1,CNOT8,CEBPB,EEF1A1,HMGA1,SRPK1
pre_adipo,0.304972,0.427457,0.451327,0.461864,0.431746,0.435484,0.345476,0.441501,0.331034,0.366456,...,0.384615,0.361851,0.383636,0.368201,0.348993,0.474576,0.489796,0.387755,0.331658,0.414634
adipo,0.237569,0.268331,0.223009,0.086864,0.22381,0.209677,0.289072,0.225166,0.321839,0.257553,...,0.294872,0.291725,0.249091,0.248954,0.313758,0.172881,0.136054,0.267857,0.345059,0.195122
other,0.457459,0.304212,0.325664,0.451271,0.344444,0.354839,0.365452,0.333333,0.347126,0.375991,...,0.320513,0.346424,0.367273,0.382845,0.337248,0.352542,0.37415,0.344388,0.323283,0.390244
lipo,0.060773,0.057722,0.067257,0.014831,0.065079,0.069892,0.085781,0.059603,0.075862,0.06996,...,0.138889,0.095372,0.074545,0.085774,0.09396,0.027119,0.030612,0.071429,0.092127,0.0
lipo_adipo,0.255814,0.215116,0.301587,0.170732,0.29078,0.333333,0.296748,0.264706,0.235714,0.271632,...,0.471014,0.326923,0.29927,0.344538,0.299465,0.156863,0.225,0.266667,0.26699,0.0


In [22]:
# ------------------------------------------------------
# PCA on programs × perturbations
# ------------------------------------------------------
K = 3
pca = PCA(n_components=K)
principal_components = pca.fit_transform(Y_train_program_proportion)
pca_program_proportion = pd.DataFrame(data=principal_components, index=Y_train_program_proportion.index)

print("W matrix will be a mapping between pca_program_proportion and pca_perturbations:")
print("\npca_program_proportion:")
display(pca_program_proportion)
print("\npca_perturbations:")
display(pca_perturbations)

W matrix will be a mapping between pca_program_proportion and pca_perturbations:

pca_program_proportion:


Unnamed: 0,0,1,2
pre_adipo,1.296724,-0.373225,-0.132756
adipo,-0.155706,0.473354,0.266442
other,0.949588,-0.172355,0.221742
lipo,-2.130675,-0.330967,-0.007975
lipo_adipo,0.040069,0.403192,-0.347453



pca_perturbations:


gene,0,1,2,3,4,5,6,7,8,9
AATF,8.604356,-0.074695,0.069683,0.033321,0.219478,-0.029004,0.029453,-0.113226,0.000268,0.000304
ABT1,-2.100445,-0.032756,-0.062296,-0.047705,-0.066221,-0.007291,0.044781,0.111082,0.032072,-0.002289
AFF1,4.581484,-0.246102,0.101982,-0.138330,0.174381,-0.017872,0.258326,0.107344,0.104765,-0.022510
ANKRA2,0.106986,0.007098,-0.095608,-0.080267,0.321377,-0.103199,0.076587,0.023118,-0.105169,-0.075170
BCL6,17.916299,0.261334,0.145843,0.134444,0.036186,0.195170,-0.057083,-0.329898,0.019132,-0.058063
...,...,...,...,...,...,...,...,...,...,...
ZNF331,-1.604509,-0.174777,-0.155737,-0.217065,-0.091361,0.000113,-0.120636,-0.005180,-0.088044,-0.198560
ZNF334,-8.915408,-0.007816,-0.030514,-0.006382,0.005296,0.003854,-0.109315,0.161640,0.099072,-0.067741
ZNF354B,-4.108700,0.619342,0.084007,-0.202947,0.042006,0.029707,0.228061,-0.079778,0.072382,-0.072972
ZNF419,-11.719326,-0.024960,-0.016346,0.025844,-0.033803,0.029082,0.008869,0.011603,0.038593,-0.027275


In [23]:
# ======================================================
# STEP 3: Fit Linear Regression in PCA Space (W Matrix)
# ======================================================
# Goal: Learn a linear mapping from perturbation features (pca_perturbations)
# to program proportion in PCA space. This W matrix will be used to predict
# the program proportion of new perturbations.

# We reuse the same function as before, with the gene representation replaced by the program-proportion representation

Y_train_program_proportion = Y_train_program_proportion[pca_perturbations.index]  # keep order
# mean vector
program_proportion_mean_vector = Y_train_program_proportion.mean(axis=1)

W_program_proportion = fit_ridge_pca_regression(
    Y_train=Y_train_program_proportion,
    G=pca_program_proportion,
    pca_perturbations=pca_perturbations,
    mean_vector=program_proportion_mean_vector,
    ridge_alpha=0.1
)
print("W shape:", W_program_proportion.shape)

W shape: (3, 10)


In [24]:
# ======================================================
# STEP 4: Predict on Validation Set
# ======================================================
# Goal: Use the learned W matrix to predict the program proportions
# of unseen perturbations in PCA space, then map them back to output space.

Y_hat_df_program_proportion = predict_perturbation_expression(
    W=W_program_proportion,
    G=pca_program_proportion,
    pca_genes=pca_genes,
    mean_vector=program_proportion_mean_vector,
    target_perturbations=predict_perturbations
)
pred_proportion_df = Y_hat_df_program_proportion.reset_index()
pred_proportion_df

Unnamed: 0,gene,pre_adipo,adipo,other,lipo,lipo_adipo
0,CHD4,0.406188,0.250419,0.365248,0.065839,0.275131
1,FOXC1,0.371313,0.275189,0.347031,0.080894,0.288399
2,SOX6,0.382876,0.267353,0.354086,0.076233,0.282278
3,TRIM5,0.365621,0.278538,0.344431,0.08555,0.288685
4,ZBTB20,0.388453,0.260789,0.353218,0.075589,0.284777


### The `train()` Function

The `train()` function is intended to encapsulate the full training workflow used to fit a predictive model on the provided perturbation dataset.
In a complete implementation, this function would:

- Load the training dataset from `data_directory_path`
- Preprocess the gene expression matrix
- Construct and train a model
- Evaluate the model on a validation split
- Save the trained model and any required metadata to `model_directory_path`

**On the platform, `train()` will be run on the full large dataset `obesity_challenge_1.h5ad`.**

**Note**:
- ❗ We recommend training locally and submitting weights/models because the dataset is large and cloud resources are limited.
- Make sure that the `Method description.md` file properly documents your model, so that the Broad Institute team can reference your work in their publications.




In [25]:
def train(
    data_directory_path: str,
    model_directory_path: str,
):
    """
    Train a perturbation prediction model.

    This function is designed to:
      - Load training data from `data_directory_path`
      - Preprocess gene expression matrices
      - Train your model on the available perturbations
      - Save all required models into `model_directory_path`

    In this baseline notebook, the function is left empty.
    You can fill in the model pipeline you want to use.
    """

    pass
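
Since `train()` is left empty here, one possible way to fill it is to fit the matrices once and persist them so that `infer()` can reload them. The sketch below only illustrates the save/load round trip with NumPy. The file name `regression_baseline.npz`, the helper names `save_model` and `load_model`, and the stored keys are hypothetical and are not part of the Crunch API.

```python
import os
import tempfile

import numpy as np

MODEL_FILE_NAME = "regression_baseline.npz"  # hypothetical file name


def save_model(model_directory_path, W, mean_vector):
    """Persist the regression weights and the per-output mean in one file."""
    os.makedirs(model_directory_path, exist_ok=True)
    path = os.path.join(model_directory_path, MODEL_FILE_NAME)
    np.savez(path, W=W, mean_vector=mean_vector)
    return path


def load_model(model_directory_path):
    """Reload the arrays written by save_model."""
    path = os.path.join(model_directory_path, MODEL_FILE_NAME)
    with np.load(path) as data:
        return data["W"], data["mean_vector"]


# Round trip on toy arrays inside a temporary directory
with tempfile.TemporaryDirectory() as model_dir:
    W_toy = np.arange(6, dtype=np.float64).reshape(2, 3)
    mean_toy = np.array([0.1, 0.2])
    save_model(model_dir, W_toy, mean_toy)
    W_loaded, mean_loaded = load_model(model_dir)

print(W_loaded.shape, mean_loaded.shape)
```

A real implementation would also need to store the row labels (gene and perturbation names) alongside the arrays, since the prediction step relies on label-based lookups.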

### The `infer()` Function

The `infer()` function performs model inference on a set of unseen perturbations.

During inference, the function:

1. Loads the training dataset.
2. **Performs PCA and regression to predict perturbation-level expression**:
The training gene-expression matrix is reduced to a low-dimensional PCA space (`pca_genes`). Selected perturbations (`pca_perturbations`) are used to fit a Ridge regression model (the `W` matrix) that maps perturbation features to gene expression, and this model then predicts expression profiles for the unseen target perturbations. Each predicted perturbation-level profile is expanded to synthetic single cells by repeating it a fixed number of times (100 synthetic cells per perturbation), optionally adding gene-wise Gaussian noise scaled by the standard deviation of each gene across training perturbations.
3. Constructs an `AnnData` object containing the predicted gene-expression profiles.
4. Saves the predictions in the correct structure expected by the specification.
5. **Performs PCA and regression to predict cell-type proportions** and saves them as a CSV file.

[Expected Output](https://docs.crunchdao.com/competitions/competitions/broad-obesity/crunch-1#expected-output):

- `Method description.md`: At the __end of the notebook__ 👇👇👇, write a small text outlining the approaches used to generate both the predictions and the estimated proportions of cells enriched for each program.

- `prediction.h5ad`: An `AnnData` file containing predicted gene expression profiles normalized and log-transformed post-perturbation for 2,863 gene perturbations indicated in `predict_perturbations.txt`. Predictions should be stored in `adata.X` matrix with the corresponding perturbation identity recorded in `adata.obs['gene']`. For each gene perturbation, we ask you to predict the gene expression profiles for 100 synthetic cells to quantify the distribution of each perturbation prediction. The file with predictions should have dimensions [286,300 × len(genes_to_predict)] (cells × genes).

- `predict_program_proportion.csv`: A CSV file reporting the predicted proportion of cells with enriched programs for each gene perturbation listed in `predict_perturbations.txt`.
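
The raw regression output is not guaranteed to satisfy the constraints enforced later in `infer()`, namely that `pre_adipo + adipo + other` sums to 1 and that `lipo` does not exceed `adipo`. A minimal sketch of that post-processing on a toy table follows; the gene labels and numeric values are made up for illustration.

```python
import pandas as pd

# Toy predictions (hypothetical values that violate both constraints somewhere)
pred = pd.DataFrame({
    "gene": ["GENE_A", "GENE_B"],
    "pre_adipo": [0.42, 0.30],
    "adipo": [0.25, 0.05],
    "other": [0.37, 0.60],
    "lipo": [0.07, 0.09],
})

# Renormalize the three state proportions so that each row sums to 1
cols = ["pre_adipo", "adipo", "other"]
pred[cols] = pred[cols].div(pred[cols].sum(axis=1), axis=0)

# Cap lipo at the (renormalized) adipo proportion
pred["lipo"] = pred[["lipo", "adipo"]].min(axis=1)

print(pred)
```

Note that the cap is applied after renormalization, so `lipo` is compared against the rescaled `adipo` value.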

**Inference input:**
- `predict_perturbations`: A list of **single-gene perturbations** for which you must predict the transcriptomic effect.
- `genes_to_predict`: A list of **gene names (columns)** to include in your predictions. Your model must generate expression values for this specific set of genes and the order of columns in the final prediction file must follow this list.

**Provide an inference function that relies on these two inputs.**
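
How the two inputs shape the output can be sketched without AnnData: rows follow `predict_perturbations` (each label repeated once per synthetic cell) and columns follow `genes_to_predict` in the given order. The toy names and values below are placeholders, and no noise is added so that the layout is easy to read.

```python
import numpy as np
import pandas as pd

# Hypothetical toy inputs
toy_perturbations = ["P1", "P2"]
toy_genes_to_predict = ["g3", "g1"]        # note: not the training column order
toy_cells_per_perturbation = 3

# Toy perturbation-level predictions (rows = perturbations, columns = all genes)
Y_hat_toy = pd.DataFrame(
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    index=toy_perturbations,
    columns=["g1", "g2", "g3"],
)

# One label per synthetic cell: P1, P1, P1, P2, P2, P2
obs_gene = np.repeat(toy_perturbations, toy_cells_per_perturbation)

# Select rows and columns by label, then repeat each profile row-wise
profiles = Y_hat_toy.loc[toy_perturbations, toy_genes_to_predict].to_numpy()
X = np.repeat(profiles, toy_cells_per_perturbation, axis=0)

print(X.shape)
print(list(obs_gene))
```

Selecting by label, rather than by position, keeps the columns aligned with `genes_to_predict` even when that list is ordered differently from the training data.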



In [None]:
def infer(
    data_directory_path: str,
    prediction_directory_path: str,
    prediction_h5ad_file_path: str,
    program_proportion_csv_file_path: str,
    model_directory_path: str,
    predict_perturbations: typing.List[str],
    genes_to_predict: typing.List[str],
    cells_per_perturbation: int = 100,
):
    """
    Run inference for a set of perturbations.

    Parameters:
        data_directory_path: str
            Directory containing the training data (the training AnnData file and `program_proportion.csv`).
        prediction_directory_path: str
            Directory where prediction files can be written.
        prediction_h5ad_file_path: str
            Direct path where to write the `prediction.h5ad` file.
        program_proportion_csv_file_path: str
            Direct path where to write the `predict_program_proportion.csv` file.
        model_directory_path: str
            Directory containing your persisted model files.
        predict_perturbations: List[str]
            The perturbations for which to generate predictions.
        genes_to_predict: List[str]
            List of gene names (columns) to include in the prediction.h5ad AnnData object.
        cells_per_perturbation: int
            Number of synthetic cells to generate per perturbation.

    Return:
      None: Returned value is ignored.

    Expected files:
        prediction_directory_path / "prediction.h5ad": anndata.AnnData
            AnnData matrix of size (n_perturbations * cells_per_perturbation, n_genes_to_predict), containing the predicted gene expression values.
        prediction_directory_path / "predict_program_proportion.csv": pd.DataFrame
            DataFrame (index=False) containing estimated cell-type proportions for each perturbation.
    """

    print("Loading data...")
    global adata_train
    if "adata_train" not in globals():
        # Optimization for Google Colab, avoid loading the data twice
        adata_train = load_data(data_directory_path)

    print(
        "###############################################################\n"
        "# Gene profiles\n"
        "###############################################################"
    )

    # ======================================================
    # STEP 1: Aggregate Training Data per Perturbation
    # ======================================================
    print("Aggregate Training Data per Perturbation")
    Y_train_gene_exp = compute_perturbation_means(adata_train, n_cells_per_perturbation=300) # shape: (n_genes, n_perturbations)

    # ======================================================
    # STEP 2: PCA on Training Data
    # ======================================================
    print("PCA gene profiles")
    pca_genes, pca_perturbations, Y_train_gene_exp_subset, gene_mean_vector = prepare_pca_for_regression(
        Y_train_gene_exp,
        control_label="NC",
        n_components=10,
        display_output=False
    )

    # ======================================================
    # STEP 3: Fit Linear Regression in PCA Space (W Matrix)
    # ======================================================
    print("Compute matrix W")
    W = fit_ridge_pca_regression(
        Y_train=Y_train_gene_exp_subset,
        G=pca_genes,
        pca_perturbations=pca_perturbations,
        mean_vector=gene_mean_vector,
        ridge_alpha=0.1
    )
    print("W shape:", W.shape)

    # ======================================================
    # STEP 4: Predict on Test Set
    # ======================================================
    print("Predict with matrix W on test set")
    Y_hat_df_gene_exp = predict_perturbation_expression(
        W=W,
        G=pca_genes,
        pca_genes=pca_genes,
        mean_vector=gene_mean_vector,
        target_perturbations=predict_perturbations
    )

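    # Compute gene-wise std from training data, aligned with the order of genes_to_predict
    # (label-based selection keeps train_std consistent with the columns used in fill_X)
    train_std = Y_train_gene_exp_subset.loc[genes_to_predict].std(axis=1).values  # shape: (len(genes_to_predict),)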
    # Whether to add Gaussian noise based on training gene-wise std:
    add_train_std = True

    print("Inferring the prediction...")
    n_genes = len(genes_to_predict)
    n_perturbations = len(predict_perturbations)
    n_cells = n_perturbations * cells_per_perturbation

    print(f"Predicting {n_cells} synthetic cells for {n_perturbations} perturbations.")
    print("Each row will contain", n_genes, "genes.")

    # Construct obs["gene"] (n_perturbations * cells_per_perturbation rows)
    obs = {"gene": np.repeat(predict_perturbations, cells_per_perturbation)}
    # Column variables
    var = adata_train[:, genes_to_predict].var.copy()

    # Clean any prior prediction file
    if os.path.exists(prediction_h5ad_file_path):
        os.remove(prediction_h5ad_file_path)

    # Temporary disk-mapped matrix file (will be removed after final .h5ad is written)
    temporary_prediction_path = os.path.join(prediction_directory_path, "prediction_matrix.h5")
    if os.path.exists(temporary_prediction_path):
        os.remove(temporary_prediction_path)

    # RAM estimation (informational)
    needed_ram = (n_cells * n_genes * 4) / (1024**3)  # GB
    available_ram = psutil.virtual_memory().available / (1024**3)
    print(f"Available RAM: {available_ram:.2f} GB")
    print(f"Needed RAM for full matrix: ~{needed_ram:.2f} GB")

    # Detecting whether you are running inside the crunch platform or not
    use_ram_intensive_mode = crunch.is_inside_runner
    # use_ram_intensive_mode = True  # Uncomment me to force it

    def fill_X(X, predict_perturbations, Y_hat_df_gene_exp, genes_to_predict,
               cells_per_perturbation, train_std, add_noise):
        """
        Fill the prediction matrix with synthetic single-cell gene expression profiles.
        For each perturbation, the function repeats the perturbation-level predicted
        gene expression and optionally adds gene-wise Gaussian noise to simulate
        cell-to-cell variability.
        """
        np.random.seed(42)
        for i, pert in tqdm(enumerate(predict_perturbations), total=len(predict_perturbations), desc="Generating synthetic cells"):
            start = i * cells_per_perturbation
            end = (i + 1) * cells_per_perturbation

            # Base perturbation-level prediction
            pred_profile = Y_hat_df_gene_exp.loc[pert][genes_to_predict].values.astype(np.float32)

            # Optionally add gene-wise Gaussian noise
            if add_noise:
                noise = np.random.normal(
                    loc=0.0,
                    scale=train_std,
                    size=(cells_per_perturbation, len(genes_to_predict))
                )
                pred_profile = pred_profile + noise

            X[start:end] = pred_profile

    if use_ram_intensive_mode:
        print("-> Using full in-memory matrix (fastest).")

        # Full RAM X
        X = np.zeros((n_cells, n_genes), dtype=np.float32)
        print("Prediction matrix shape:", X.shape)

        fill_X(X, predict_perturbations, Y_hat_df_gene_exp, genes_to_predict,
               cells_per_perturbation, train_std, add_train_std)

        prediction = anndata.AnnData(X=X, obs=obs, var=var)
        del X

        # Save to .h5ad
        print("Saving the prediction...")
        prediction.write_h5ad(prediction_h5ad_file_path)
    else:
        print("-> Using HDF5 (low-memory mode).")
        # One HDF5 chunk per perturbation, matching the row slices written by fill_X
        batch_size = cells_per_perturbation

        # Create an HDF5 dataset on disk and write predictions batch-by-batch.
        # This avoids storing a huge (e.g. 286,300 × len(genes_to_predict)) matrix in RAM.
        with h5py.File(temporary_prediction_path, "w") as f:
            print("Prediction matrix shape:", (n_cells, n_genes))
            dset = f.create_dataset(
                "X",
                shape=(n_cells, n_genes),
                dtype="float32",
                chunks=(batch_size, n_genes),
                # compression="gzip"
            )

            fill_X(dset, predict_perturbations, Y_hat_df_gene_exp, genes_to_predict,
               cells_per_perturbation, train_std, add_train_std)

        print("Finished writing matrix.")

        # Use an empty in-memory array with the same shape, then replace X with the HDF5 dataset
        prediction = anndata.AnnData(X=np.zeros((n_cells, n_genes), dtype=np.float32), obs=obs, var=var)

        with h5py.File(temporary_prediction_path, "r") as f:
            # Point X to the HDF5 file
            prediction.X = f["X"]

            # Save to .h5ad
            print("Saving the prediction...")
            prediction.write_h5ad(prediction_h5ad_file_path)

        # Remove the temporary HDF5 file
        if os.path.exists(temporary_prediction_path):
            os.remove(temporary_prediction_path)

    print(
        "###############################################################\n"
        "# Program Proportions\n"
        "###############################################################"
    )

    # STEP 1: Load the (train) program proportion matrix (and transpose it to match the regression setup)
    train_program_proportion = pd.read_csv(os.path.join(data_directory_path, "program_proportion.csv"), index_col="gene")
    Y_train_program_proportion = train_program_proportion.T

    # STEP 2: PCA on programs × perturbations
    print("PCA program proportion")
    K = 3
    pca = PCA(n_components=K)
    principal_components = pca.fit_transform(Y_train_program_proportion)
    pca_program_proportion = pd.DataFrame(data=principal_components, index=Y_train_program_proportion.index)

    # ======================================================
    # STEP 3: Fit Linear Regression in PCA Space (W Matrix)
    # ======================================================
    print("Compute matrix W")
    Y_train_program_proportion = Y_train_program_proportion[pca_perturbations.index]  # keep order
    # mean vector
    program_proportion_mean_vector = Y_train_program_proportion.mean(axis=1)

    W_program_proportion = fit_ridge_pca_regression(
        Y_train=Y_train_program_proportion,
        G=pca_program_proportion,
        pca_perturbations=pca_perturbations,
        mean_vector=program_proportion_mean_vector,
        ridge_alpha=0.1
    )
    print("W shape:", W_program_proportion.shape)

    # ======================================================
    # STEP 4: Predict on Validation Set
    # ======================================================
    print("Predict with matrix W on the test set")
    Y_hat_df_program_proportion = predict_perturbation_expression(
        W=W_program_proportion,
        G=pca_program_proportion,
        pca_genes=pca_genes,
        mean_vector=program_proportion_mean_vector,
        target_perturbations=predict_perturbations
    )
    pred_proportion_df = Y_hat_df_program_proportion.reset_index()

    # (pre_adipo + adipo + other) must be equal to 1
    cols = ["pre_adipo", "adipo", "other"]
    pred_proportion_df[cols] = (
        pred_proportion_df[cols]
        .div(pred_proportion_df[cols].sum(axis=1), axis=0)
    )

    # lipo must be less or equal to adipo
    pred_proportion_df["lipo"] = pred_proportion_df[["lipo", "adipo"]].min(axis=1)

    print("Saving the proportion...")
    pred_proportion_df.drop(columns=["lipo_adipo"], errors="ignore", inplace=True)
    pred_proportion_df.to_csv(program_proportion_csv_file_path, index=False)

    del prediction
    gc.collect()

    print("Finished !")

## Local testing

To make sure your `train()` and `infer()` functions are working properly, you can call the `crunch_tools.test()` function, which reproduces the cloud environment locally. <br />
Even if the reproduction is not perfect, it should give you a quick idea of whether your model is working properly.

**Note**: Locally, the infer function will be run on only 5 perturbed genes.

In [None]:
crunch_tools.test()

In [28]:
genes_to_predict = pd.read_csv(os.path.join("data", "genes_to_predict.txt"), header=None)[0].values

## If instead you want to predict all genes (columns):
# genes_to_predict = adata_train.var.index.to_list()

print("About to predict", len(genes_to_predict), "genes")

About to predict 10238 genes


In [None]:
## In the full challenge, you would load the 2863 unseen perturbations via predict_perturbations.txt.
# predict_perturbations = (
#     pd.read_csv(os.path.join("data", "predict_perturbations.txt"), header=None)[0]
#     .values
# )
## Test infer function on local test set "obesity_challenge_1_local_gtruth.h5ad"
gtruth = scanpy.read_h5ad(os.path.join("data", "obesity_challenge_1_local_gtruth.h5ad"), backed="r")
predict_perturbations = gtruth.obs["gene"].cat.categories.tolist()

print("Local test gene perturbations:", predict_perturbations)

infer(
    data_directory_path="data/",
    prediction_directory_path="prediction/",
    model_directory_path="resources/",
    predict_perturbations=predict_perturbations,
    genes_to_predict=genes_to_predict,
    prediction_h5ad_file_path=os.path.join("prediction", "prediction.h5ad"),
    program_proportion_csv_file_path=os.path.join("prediction", "predict_program_proportion.csv")
)

## Results

Once the local tester is done, you can preview the results stored in `prediction/prediction.h5ad` and `prediction/predict_program_proportion.csv`.

In [30]:
# If backed='r', load AnnData in backed mode instead of fully loading it into memory
prediction = scanpy.read_h5ad(os.path.join("prediction", "prediction.h5ad"))#, backed="r")
prediction

AnnData object with n_obs × n_vars = 500 × 10238
    obs: 'gene'

In [31]:
number_cells = 102
print(f"Subset of the prediction: {number_cells} cells")
pd.DataFrame(prediction.X[:number_cells, :], index=prediction.obs["gene"][:number_cells], columns=prediction.var.index)

Subset of the prediction: 102 cells


gene,CNDP1,PROS1,HACD2,Z98752.4,PCLAF,SP100,ALKBH4,RGCC,DNAJC25,SLC2A5,...,TRPM8,TRPV1,TRPV2,UCP3,WNT10B,ADIG,BMP7,LRG1,MB,VSTM2A
CHD4,0.037713,0.746573,1.634118,0.101823,0.037341,2.989954,1.212286,0.281577,2.521185,0.046670,...,-0.064588,0.239637,0.005315,-0.015155,0.010072,-0.052506,-0.003307,0.000071,0.020309,-0.013256
CHD4,-0.090243,0.779270,1.777196,-0.014558,0.029499,3.006407,1.216747,0.218300,2.492968,0.053271,...,0.013518,0.260848,0.003329,0.023807,0.066110,-0.015750,0.000169,-0.007047,0.014384,0.001083
CHD4,0.046467,0.771688,1.266429,-0.041052,0.012108,2.977184,1.003166,0.292967,2.575722,-0.015400,...,0.084289,0.338093,-0.001310,-0.007682,0.015721,0.081599,0.016849,-0.003881,-0.028588,-0.014657
CHD4,-0.058773,0.734685,1.399504,-0.066238,0.105272,2.980090,1.194868,0.232971,2.499562,0.044181,...,0.041567,0.203072,0.000962,0.006205,0.072698,-0.046553,0.007706,-0.003313,-0.003606,0.024327
CHD4,-0.020750,0.813098,1.272668,0.053975,0.042132,2.993988,1.078725,0.292898,2.545202,0.080754,...,0.048708,0.219761,-0.005340,0.045742,0.067912,-0.092860,0.001636,0.005748,0.017132,-0.062779
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CHD4,0.061200,0.725051,1.490097,-0.018409,0.066511,2.985085,1.109044,0.282769,2.516453,0.058828,...,-0.140279,0.272090,0.002335,0.063488,0.002276,-0.056713,0.012740,-0.001539,-0.013066,0.042439
CHD4,0.015995,0.716518,1.503511,0.055535,0.051432,2.985452,1.073213,0.286233,2.515693,-0.035001,...,-0.101231,0.283209,-0.002794,0.020687,0.050987,0.001446,0.010473,-0.011076,0.021479,-0.047753
CHD4,0.039516,0.713974,1.683520,-0.018435,0.048895,3.024854,0.932447,0.201798,2.534642,0.069860,...,0.065075,0.310904,0.002718,0.009903,0.019625,-0.020288,0.008866,0.001966,0.012249,0.041308
FOXC1,0.065955,0.684546,1.518314,0.073549,0.027721,2.979673,1.223528,0.193653,2.535188,-0.092337,...,0.040487,0.285277,0.002271,0.015356,0.052653,0.081390,0.010527,0.002662,0.011454,0.067922


In [32]:
predicted_proportion = pd.read_csv(os.path.join("prediction", "predict_program_proportion.csv"))
predicted_proportion

Unnamed: 0,gene,pre_adipo,adipo,other,lipo,lipo_adipo
0,CHD4,0.406188,0.250419,0.365248,0.065839,0.275131
1,FOXC1,0.371313,0.275189,0.347031,0.080894,0.288399
2,SOX6,0.382876,0.267353,0.354086,0.076233,0.282278
3,TRIM5,0.365621,0.278538,0.344431,0.08555,0.288685
4,ZBTB20,0.388453,0.260789,0.353218,0.075589,0.284777
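
The `infer()` function post-processes the predicted proportions so that `pre_adipo + adipo + other` sums to 1 and `lipo` never exceeds `adipo`. The sketch below replays those two steps on hypothetical values; the gene names, the numbers, and the helper name `enforce_proportion_constraints` are illustrative assumptions, not part of the competition tooling.

```python
import pandas as pd

# Hypothetical raw predictions that violate both constraints
raw = pd.DataFrame({
    "gene": ["GENE_A", "GENE_B"],
    "pre_adipo": [0.50, 0.30],
    "adipo": [0.30, 0.20],
    "other": [0.40, 0.30],
    "lipo": [0.10, 0.25],
})


def enforce_proportion_constraints(df):
    """Renormalize the three main states, then cap lipo at adipo."""
    out = df.copy()
    cols = ["pre_adipo", "adipo", "other"]
    # Divide each row by its own sum so the three states add up to 1
    out[cols] = out[cols].div(out[cols].sum(axis=1), axis=0)
    # lipo is a subset of adipo, so it cannot be larger
    out["lipo"] = out[["lipo", "adipo"]].min(axis=1)
    return out


fixed = enforce_proportion_constraints(raw)
print(fixed.round(3))
```

The same two checks (row sums close to 1, `lipo <= adipo`) can be run on the CSV you are about to submit as a quick sanity test.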


### Local scoring

You can call the function that the system uses to estimate your score locally.

In [33]:
gtruth = scanpy.read_h5ad(os.path.join("data", "obesity_challenge_1_local_gtruth.h5ad"))#, backed="r")
gtruth

AnnData object with n_obs × n_vars = 1500 × 11046
    obs: 'gene'
    uns: 'control_centroid_train', 'perturbed_centroid_train'

In [34]:
gtruth_proportion = pd.read_csv(os.path.join("data", "program_proportion_local_gtruth.csv"))
gtruth_proportion

Unnamed: 0,gene,pre_adipo,adipo,other,lipo,lipo_adipo
0,ZBTB20,0.408851,0.189673,0.401475,0.037935,0.2
1,FOXC1,0.401361,0.27551,0.323129,0.073696,0.26749
2,SOX6,0.418065,0.238065,0.343871,0.073548,0.308943
3,CHD4,0.306306,0.326126,0.367568,0.064865,0.198895
4,TRIM5,0.349301,0.309381,0.341317,0.117764,0.380645


### Transcriptome-wide metrics

These metrics are computed using **either a subset of genes or all genes** (i.e., columns of the predicted matrix) for each perturbation.

- For the **public leaderboard** (updated weekly), the evaluation uses **1,000 hidden genes**.
The **MMD metric** is computed with a fixed sigma value of **2326**, calibrated specifically for 1,000-dimensional gene vectors.
- For the **private leaderboard**, both the **number** and the **identity** of the scoring genes will remain unknown to participants.

In this quickstarter example, we follow the same spirit and randomly select 1,000 genes to compute transcriptome-wide evaluation metrics.
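
Because the scoring genes of the private leaderboard are unknown, it can help to check how stable a metric is across several random gene subsets. The sketch below only shows the sampling step; the gene names and the helper `sample_gene_subsets` are hypothetical stand-ins for `prediction.var.index` and your own evaluation loop.

```python
import numpy as np

# Hypothetical gene names standing in for `prediction.var.index`
gene_names = np.array([f"GENE{i}" for i in range(5000)])


def sample_gene_subsets(genes, n_genes, seeds):
    """Draw one subset of `n_genes` distinct genes per seed."""
    subsets = {}
    for seed in seeds:
        rng = np.random.default_rng(seed=seed)
        subsets[seed] = rng.choice(genes, size=n_genes, replace=False)
    return subsets


subsets = sample_gene_subsets(gene_names, n_genes=1000, seeds=[0, 1, 2])
for seed, subset in subsets.items():
    # Each subset has 1,000 genes and no duplicates
    print(seed, len(subset), len(set(subset)))
```

Keeping the subset size at 1,000 matters here, since the fixed MMD sigma of 2326 is calibrated for 1,000-dimensional gene vectors.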

In [35]:
# Retrieve the ordered list of predicted gene names (columns)
genes_to_predict_prediction = prediction.var.index.to_list()
print("Number of genes (columns) in the prediction:", len(genes_to_predict_prediction))

# Number of genes to use for computing metrics
n_genes_for_metric = 1000

# Randomly sample genes for the metric computation
rng = np.random.default_rng(seed=42)  # Try different seeds to make sure you are ready for the private leaderboard
genes_for_metric = rng.choice(genes_to_predict_prediction, size=n_genes_for_metric, replace=False)

# Index for selecting the same genes from adata_train and local_gtruth
indexer_genes_gtruth = adata_train.var.index.get_indexer(genes_for_metric)

print("Number of genes (columns) used to compute metrics:", len(genes_for_metric))

Number of genes (columns) in the prediction: 10238
Number of genes (columns) used to compute metrics: 1000


#### Pearson Delta

The Pearson Delta ($\rho$) is the correlation between the predicted ($\hat{X}$) and observed ($X$) perturbation effects, each measured relative to the perturbed mean ($X_{PM}$):

$
\displaystyle
\rho(X, \hat{X}) = \frac{1}{|P|} \sum_{p \in P} \text{cor}\left(\hat{X}_p - X_{PM}, X_p - X_{PM}\right)
$

where:
- $X_{PM}$ is the Perturbed Mean defined as the mean gene expression of all the perturbed cells (those receiving a single gene perturbation) in the training dataset,
- $P$ is the set of all perturbation targets,
- $|P|$ is the size of the set $P$,
- $\text{cor}(a, b)$ is the correlation between vectors $a$ and $b$.

---

*Higher is better!*
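
To build intuition, here is a minimal sketch of the Pearson Delta for a single perturbation on synthetic data. All arrays are randomly generated toy values, the helper name `pearson_delta` is an assumption, and `np.corrcoef` is used in place of `scipy.stats.pearsonr` to keep the example dependency-free.

```python
import numpy as np


def pearson_delta(gtruth_cells, pred_cells, perturbed_mean):
    """Correlate observed and predicted mean shifts relative to the perturbed mean."""
    true_delta = gtruth_cells.mean(axis=0) - perturbed_mean
    pred_delta = pred_cells.mean(axis=0) - perturbed_mean
    return float(np.corrcoef(true_delta, pred_delta)[0, 1])


rng = np.random.default_rng(0)
n_genes = 200
perturbed_mean = rng.normal(size=n_genes)  # toy stand-in for X_PM
true_effect = rng.normal(size=n_genes)     # toy perturbation effect

# 50 observed cells and 30 predicted cells for one hypothetical perturbation
gtruth_cells = perturbed_mean + true_effect + 0.1 * rng.normal(size=(50, n_genes))
good_pred = perturbed_mean + true_effect + 0.1 * rng.normal(size=(30, n_genes))
flipped_pred = perturbed_mean - true_effect + 0.1 * rng.normal(size=(30, n_genes))

score_good = pearson_delta(gtruth_cells, good_pred, perturbed_mean)
score_flipped = pearson_delta(gtruth_cells, flipped_pred, perturbed_mean)
print(f"good: {score_good:.3f}, flipped: {score_flipped:.3f}")
```

A prediction that moves genes in the right direction relative to the perturbed mean scores close to 1, while one that moves them in the opposite direction scores close to -1. Only the per-gene mean of the predicted cells enters this metric, so the number of predicted cells does not need to match the number of observed cells.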

In [36]:
def compute_metric_pearson(
    gtruth_X: np.typing.NDArray[np.float64],
    pred_X: np.typing.NDArray[np.float64],
    perturbed_centroid: np.typing.NDArray[np.float64],
) -> float:
    gtruth_X_target = gtruth_X.mean(axis=0)
    pred_X_target = pred_X.mean(axis=0)

    return scipy.stats.pearsonr(
        gtruth_X_target - perturbed_centroid,
        pred_X_target - perturbed_centroid,
    ).statistic

perturbations = gtruth.obs["gene"].cat.categories.tolist()
scores_pearson = []

perturbed_centroid = gtruth.uns["perturbed_centroid_train"][indexer_genes_gtruth]

for p in tqdm(perturbations):
    # Select true and predicted cells for perturbation p
    gt_mask = gtruth.obs["gene"] == p
    pr_mask = prediction.obs["gene"] == p

    # Extract their filtered expression matrices
    gtruth_X = to_array(gtruth[gt_mask, genes_for_metric])
    pred_X   = to_array(prediction[pr_mask, genes_for_metric])

    # Skip perturbations without matching predicted cells
    if gtruth_X.shape[0] == 0 or pred_X.shape[0] == 0:
        print(f"skipping {p}: missing samples: will not be accepted on the platform!")
        scores_pearson.append(0)
        continue

    # Compute Pearson score for this perturbation
    score = compute_metric_pearson(
        gtruth_X=gtruth_X,
        pred_X=pred_X,
        perturbed_centroid=perturbed_centroid,
    )
    scores_pearson.append(score)

print(f"Pearson score: {np.mean(scores_pearson):.4f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Pearson score: -0.0656


#### Maximum Mean Discrepancy (MMD)

The MMD measures the distance between the predicted and the observed distributions of single-cell profiles.

Let $X_p^i$ be the true expression profile and $\hat{X}_p^i$ be the predicted expression profile for cell $i$ with perturbation $p$.
Then we calculate the MMD distance for a particular perturbation using the following formula:

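$
\displaystyle
MMD^2(X_p, \hat{X}_p) =
\frac{1}{N^2} \sum_{i,j} k(X_p^i, X_p^j) +
\frac{1}{N^2} \sum_{i,j} k(\hat{X}_p^i, \hat{X}_p^j) -
\frac{2}{N^2} \sum_{i,j} k(X_p^i, \hat{X}_p^j)
$

where:
- $N$ is the number of cells compared for perturbation $p$ (the true and predicted sets are truncated to the same size),
- $k(a, b) = \sum_{\sigma} \exp\left(-\lVert a - b \rVert^2 / \sigma\right)$ is a sum of Gaussian kernels over the bandwidths $\sigma \in [581.5, 1163.0, 2326.0, 4652.0, 9304.0]$.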

---

Then we average the MMD score across all the perturbations using the following formula:

$
\displaystyle
MMD^2(X, \hat{X}) = \frac{1}{|P|} \sum_{p \in P} MMD^2(X_p, \hat{X}_p)
$

where:
- $|P|$ is the size of the set $P$.

---

*Lower is better!*
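
The following sketch computes the per-perturbation MMD² on synthetic data to show how the score behaves. It is not the platform's scoring code: the helper name `mmd2_multi_bandwidth` and the toy arrays are assumptions, and pairwise distances are obtained with the Gram-matrix identity $\lVert a - b \rVert^2 = \lVert a \rVert^2 + \lVert b \rVert^2 - 2 a \cdot b$, which avoids building a cells × cells × genes array.

```python
import numpy as np


def mmd2_multi_bandwidth(X, Y, sigmas=(581.5, 1163.0, 2326.0, 4652.0, 9304.0)):
    """Biased MMD^2 estimate with a sum of Gaussian kernels exp(-d^2 / sigma)."""
    assert X.shape == Y.shape, "this sketch assumes equal sample counts"
    Z = np.concatenate([X, Y], axis=0)
    # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    sq_norms = np.sum(Z**2, axis=1)
    d2 = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * Z @ Z.T, 0.0)
    K = sum(np.exp(-d2 / s) for s in sigmas)
    n = X.shape[0]
    # XX + YY - 2 * XY, averaged over the n^2 pairs of each block
    return float((K[:n, :n].sum() + K[n:, n:].sum() - 2.0 * K[:n, n:].sum()) / n**2)


rng = np.random.default_rng(0)
X = rng.normal(size=(40, 1000))  # toy "observed" cells, 1,000 genes
Y_shifted = X + 1.0              # toy "predicted" cells, every gene shifted by 1

mmd_same = mmd2_multi_bandwidth(X, X)
mmd_shifted = mmd2_multi_bandwidth(X, Y_shifted)
print(f"identical: {mmd_same:.6f}, shifted: {mmd_shifted:.4f}")
```

Identical samples give a score of zero (up to floating-point error), and the score grows as the predicted cells drift away from the observed ones.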

In [37]:
def _gaussian_kernel(source, target, kernel_mul, kernel_num, fix_sigma):
    # Getting the L2 distance
    n_samples = int(source.shape[0]) + int(target.shape[0])
    total = np.concatenate([source, target], axis=0)

    total0 = np.broadcast_to(np.expand_dims(total, axis=0), (int(total.shape[0]), int(total.shape[0]), int(total.shape[1])))
    total1 = np.broadcast_to(np.expand_dims(total, axis=1), (int(total.shape[0]), int(total.shape[0]), int(total.shape[1])))
    L2_distance = np.sum((total0 - total1)**2, axis=2)

    # Now we are ready to scale this using multiple bandwidth
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = np.sum(L2_distance) / (n_samples**2 - n_samples)

    # Now we will create the multiple bandwidth list
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    kernel_val = [np.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)


def _compute_mmd(X_batch, Y_batch, kernel_mul, kernel_num, fix_sigma, kernel_func):
    num_batch_element = X_batch.shape[0]

    kernels = kernel_func(X_batch, Y_batch, kernel_mul, kernel_num, fix_sigma)
    XX = kernels[:num_batch_element, :num_batch_element]
    YY = kernels[num_batch_element:, num_batch_element:]
    XY = kernels[:num_batch_element, num_batch_element:]
    YX = kernels[num_batch_element:, :num_batch_element]
    mmd_val = np.sum(XX + YY - XY - YX)

    return mmd_val / num_batch_element ** 2


def balance_source_target_sample_per_perturbation(gtruth_X_tgt,pred_X_tgt):
    num_gtruth = gtruth_X_tgt.shape[0]
    num_pred = pred_X_tgt.shape[0]
    min_sample = min(num_gtruth,num_pred)

    return gtruth_X_tgt[0:min_sample,:], pred_X_tgt[0:min_sample,:]


def compute_metric_mmd(
    gtruth_X: np.typing.NDArray[np.float64],
    pred_X: np.typing.NDArray[np.float64],
) -> float:
    kernel_mul = 2.0
    kernel_num = 5
    fix_sigma = 2326 # with 1000 columns

    # Balancing the samples to compute mmd using equal number of samples
    gtruth_X, pred_X = balance_source_target_sample_per_perturbation(gtruth_X, pred_X)

    mmd_dist = _compute_mmd(
        gtruth_X,
        pred_X,
        kernel_mul,
        kernel_num,
        fix_sigma,
        _gaussian_kernel
    )

    return mmd_dist


perturbations = gtruth.obs["gene"].cat.categories.tolist()
scores_mmd = []

for p in tqdm(perturbations):
    # Select true and predicted cells for perturbation p
    gt_mask = gtruth.obs["gene"] == p
    pr_mask = prediction.obs["gene"] == p

    # Extract their filtered expression matrices
    gtruth_X = to_array(gtruth[gt_mask, genes_for_metric])
    pred_X   = to_array(prediction[pr_mask, genes_for_metric])

    # Skip perturbations without matching predicted cells
    if gtruth_X.shape[0] == 0 or pred_X.shape[0] == 0:
        print(f"skipping {p}: missing samples: will not be accepted on the platform!")
        scores_mmd.append(9)
        continue

    # Compute MMD for this perturbation
    score = compute_metric_mmd(
        gtruth_X=gtruth_X,
        pred_X=pred_X,
    )

    scores_mmd.append(score)

print(f"MMD score: {np.mean(scores_mmd):.4f}")

  0%|          | 0/5 [00:00<?, ?it/s]

MMD score: 0.7307


### Program-level metrics

These metrics evaluate whether models capture meaningful biological outcomes.

#### L1-distance

There are four cell state proportions for each perturbation, i.e., pre-adipogenic, adipogenic, lipogenic, and other.
For each perturbation $p$, we have the ground truth cell-state proportion $S_p = [preadipo, adipo, lipo, other]$.

Let $S_p^R$ be the vector that has the proportion of the cell states $[preadipo, adipo, other]$. <br />
Let $S_p^L$ denote the conditional probability of a cell being in the lipogenic state given the adipogenic state, i.e., $S_p^L = lipo/adipo$.

Then we define the program level loss as:

$
\displaystyle
L1(\hat{S}, S) = \frac{1}{|P|} \sum_{p \in P} \left( 0.75 \cdot | S_p^R - \hat{S}_p^R |_1 + 0.25 \cdot | S_p^L - \hat{S}_p^L | \right)
$

where $|.|_1$ is the L₁-distance, $|P|$ is the number of perturbations, and $\hat{S}_p$ is the predicted cell-state proportions.

---

*Lower is better!*
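
As a worked example of the formula above for a single perturbation, the sketch below uses hypothetical proportions; the helper name `program_l1` and the numbers are illustrative only.

```python
def program_l1(true_row, pred_row, eps=1e-20):
    """Weighted L1 loss for one perturbation; rows are dicts of proportions."""
    # L1 distance over the three main states
    l1_three = sum(
        abs(true_row[key] - pred_row[key])
        for key in ("pre_adipo", "adipo", "other")
    )
    # Absolute error on the conditional probability lipo / adipo
    true_ratio = true_row["lipo"] / (true_row["adipo"] + eps)
    pred_ratio = pred_row["lipo"] / (pred_row["adipo"] + eps)
    return 0.75 * l1_three + 0.25 * abs(true_ratio - pred_ratio)


# Hypothetical proportions for a single perturbation
true_row = {"pre_adipo": 0.40, "adipo": 0.30, "other": 0.30, "lipo": 0.15}
pred_row = {"pre_adipo": 0.50, "adipo": 0.25, "other": 0.25, "lipo": 0.10}

loss = program_l1(true_row, pred_row)
print(f"{loss:.3f}")
```

Here the three-state distance is 0.10 + 0.05 + 0.05 = 0.20 and the ratio error is |0.5 - 0.4| = 0.10, so the loss works out to 0.75 · 0.20 + 0.25 · 0.10 = 0.175.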

In [38]:
def compute_metric_l1_distance(
    true_state_proportion_df: pd.DataFrame,
    pred_state_proportion_df: pd.DataFrame,
) -> float:
    # Going over all the genes that were perturbed in this set
    unique_perturb_genes = list(true_state_proportion_df["gene"].unique())

    all_l1_loss_list = []
    for gene in unique_perturb_genes:
        # Selecting the row for this gene
        true_gene_df = true_state_proportion_df[true_state_proportion_df["gene"] == gene]
        pred_gene_df = pred_state_proportion_df[pred_state_proportion_df["gene"] == gene]

        # print(gene, pred_gene_df.shape[0])
        assert true_gene_df.shape[0] == 1 and pred_gene_df.shape[0] == 1, f"Invalid prediction count for state gene={gene} count={pred_gene_df.shape[0]}!=1"

        # Getting the L1 loss for the three main states: pre_adipo, adipo and other
        l1_three = (
            np.abs(true_gene_df.iloc[0]["pre_adipo"] - pred_gene_df.iloc[0]["pre_adipo"]) +
            np.abs(true_gene_df.iloc[0]["adipo"] - pred_gene_df.iloc[0]["adipo"]) +
            np.abs(true_gene_df.iloc[0]["other"] - pred_gene_df.iloc[0]["other"])
        )

        # Getting the L1 loss for lipo by adipo
        numerical_stab_term = 1e-20
        pred_lipo_adipo = pred_gene_df.iloc[0]["lipo"] / (pred_gene_df.iloc[0]["adipo"] + numerical_stab_term)
        true_lipo_adipo = true_gene_df.iloc[0]["lipo"] / (true_gene_df.iloc[0]["adipo"] + numerical_stab_term)
        l1_lipo_adipo = np.abs(true_lipo_adipo - pred_lipo_adipo)

        # Getting the average error
        average_l1 = 0.75 * l1_three + 0.25 * l1_lipo_adipo
        all_l1_loss_list.append(average_l1)

    # Getting the overall average over all the gene perturbations
    l1_loss = np.mean(all_l1_loss_list)
    return float(l1_loss)


l1_distance = compute_metric_l1_distance(
    gtruth_proportion,
    predicted_proportion,
)

print(f"L1-distance score: {l1_distance:.4f}")

L1-distance score: 0.0884


## Writing the report

The final step is to write the method description as specified [in the documentation](https://docs.crunchdao.com/competitions/competitions/broad-obesity/crunch-1#file-method-description.md).

You must:
1. Explain how your method works. *(5-10 sentences)*
2. Describe the reasoning behind your gene panel design. *(5-10 sentences)*
3. Specify the datasets and any other resources utilized. *(5-10 sentences)*

The limit is about one page.
<br />
<br />
<br />

---

<span style="font-size: 48px">👇👇👇</span> (double-click the markdown cell below)

---
file: Method description.md
---

<!-- Don't forget to change me -->

# Method Description

Explain how your method works. (5-10 sentences)

# Rationale

Describe the reasoning behind your gene panel design. (5-10 sentences)

# Data and Resources Used

Specify the datasets and any other resources utilized. (5-10 sentences)

<span style="font-size: 48px">👆👆👆</span>

---
<br />
<br />
<br />

# Submit your Notebook

To submit your work, you must:
1. Download your Notebook from Colab
2. Upload it to the platform
3. Create a run to validate it

Executing the cell below will take care of everything (only available on Google Colab), or show you how to submit manually.

In [None]:
# @title  {"display-mode":"form", "form-width":"400px"}

# @markdown Describe your changes, then run the cell.
Message = "" # @param {"type":"string","placeholder":"Short description (optional)"}

# @markdown ---
# @markdown **Advanced:** Should the `requirements.txt` file be frozen using locally installed packages?
Pip_Freeze = False # @param {"type":"boolean"}

# ---
# THIS METHOD IS ONLY POSSIBLE ON COLAB.
# RUNNING THIS CELL WILL PROMPT YOU TO USE THE OLD WAY OF SUBMITTING A NOTEBOOK.

crunch_tools.submit(
    message=Message,
    include_installed_packages_version=Pip_Freeze,
)