[![Open In Colab](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/badge/open-in-colab.svg)](https://colab.research.google.com/github/crunchdao/quickstarters/blob/master/competitions/broad-obesity-1/quickstarters/perturbed-mean-baseline/perturbed-mean-baseline.ipynb)

![Banner](https://raw.githubusercontent.com/crunchdao/quickstarters/refs/heads/master/competitions/broad-obesity-1/assets/banner.webp)

# Obesity ML Competition: Tackling metabolic diseases

## Crunch 1 – Predicting the effect of held-out single-gene perturbations

In Crunch 1, we will explore how well we can predict the single-cell transcriptomic response to single-gene perturbations that were held out of the training dataset.

# Setup

The first steps to get started are:
1. Get the setup command
2. Execute it in the cell below

### >> https://hub.crunchdao.com/competitions/broad-obesity-1/submit/notebook

![Reveal token](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/animations/reveal-token.gif)

In [None]:
# Install the Crunch CLI
%pip install --upgrade crunch-cli

# Setup your local environment
!crunch --env staging setup-notebook broad-obesity-1 aaaabbbbccccddddeeeeffff

# Your model

## Setup

In [None]:
# Install required dependencies
%pip install anndata scanpy

In [None]:
import gc
import os
import typing

# Import your dependencies
import anndata
import h5py
import numpy as np
import pandas as pd
import psutil
import scanpy
import scipy
from tqdm.notebook import tqdm

In [None]:
import crunch

# Load the Crunch Toolings
crunch_tools = crunch.load_notebook()

## Constants

Store your global values here.

In [6]:
# @crunch/keep:on

control_label = "NC"

## Understanding the Data

The data was downloaded when you set up your local environment and is now available in the `data/` directory.

- `obesity_challenge_1.h5ad`: The dataset includes perturbations targeting genes. For each perturbation, we provide **single-cell gene expression** (RNA-seq) profiles measured on day 14 of adipocyte differentiation, annotated with gene perturbation identity, quality control (QC) metrics, and cell metadata. The training dataset contains a subset of these perturbations, while **a distinct set of single-gene perturbations is held out for validation and test for the leaderboard**.
- `obesity_challenge_1_local_gtruth.h5ad`: A local test set containing single-gene perturbations that do not appear in the training file. This dataset allows you to evaluate your method offline before submitting predictions.
- `predict_perturbations.txt`: A list of **2,863 unseen single-gene perturbations** for which you must predict the transcriptomic effect. These constitute the validation and test gene targets used on the official leaderboard.
- `genes_to_predict.txt`: A list of gene names (columns) to include in your predictions. Your model must generate expression values for this specific set of genes and the order of columns in the final prediction file must follow this list. The set of genes may change between validation and test phases.

In [7]:
def load_data(
    data_directory_path: str = "data",
):
    # If backed='r', load AnnData in backed mode instead of fully loading it into memory
    adata_train = scanpy.read_h5ad(os.path.join(data_directory_path, "obesity_challenge_1.h5ad"), backed="r")

    return adata_train

In [8]:
adata_train = load_data()

### Understanding `adata_train`

- The dataset is provided in [`AnnData` format](https://anndata.readthedocs.io/en/stable/) (.h5ad).

- Normalized gene expression values are stored in `adata.X` after per-cell total count normalization followed by $\log_2(1 + x)$ transformation (standard single-cell RNA-seq normalization; see [lecture 2 of the crash course](https://docs.crunchdao.com/competitions/competitions/broad-obesity/crash-course#lecture-2)).

- Raw gene expression counts prior to normalization are stored in `adata.layers['counts']` for reproducibility and alternative preprocessing.

- The perturbation target gene information is provided in `adata.obs['gene']`, with values corresponding to either `"NC"` for control cells or to the target gene name if the cell is perturbed. Control cells receive a perturbation that has no effect on the cell's RNA-Seq profile.

- Cell state/program enrichment information is provided in `.obs`, with columns `pre_adipo`, `adipo`, `lipo`, and `other` indicating whether each cell was enriched for pre-adipocyte, adipocyte, or lipogenic programs. Other was defined as cells that were not enriched for either pre-adipocyte or adipocyte programs. Program enrichment assignments were based on expert-curated canonical signature genes, and the list of signature genes is provided in `signature_genes.csv`.

- We provide the cell state proportion for each of the perturbations in a separate file `program_proportion.csv`.

- During preprocessing, standard single-cell quality control (QC) was applied to remove low-quality cells and cell doublets based on sequencing library complexity, gene detection rate, and mitochondrial gene content. The dataset was then restricted to cells with a single confident guide assignment to a perturbation, and guides represented by fewer than 10 cells were excluded. Genes detected in fewer than 10 cells were removed, and known signature genes from `signature_genes.csv` were subsequently re-introduced.
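
To make the normalization described above concrete, here is a minimal sketch on a toy count matrix. The scaling target (`target_sum = 10_000`) is an assumption for illustration only, since the value used to produce `adata.X` is not stated here, and `normalize_log2` is a hypothetical helper name.

```python
import numpy as np


def normalize_log2(counts: np.ndarray, target_sum: float = 10_000.0) -> np.ndarray:
    """Per-cell total-count normalization followed by log2(1 + x).

    `target_sum` is an assumed scaling target, not a value taken from the dataset.
    """
    counts = np.asarray(counts, dtype=np.float64)

    # Total counts per cell (rows are cells, columns are genes)
    totals = counts.sum(axis=1, keepdims=True)

    # Guard against empty cells to avoid dividing by zero
    totals[totals == 0] = 1.0

    scaled = counts / totals * target_sum
    return np.log2(1.0 + scaled)


# Toy example: 3 cells x 4 genes of raw counts
toy_counts = np.array([
    [10, 0, 5, 5],
    [100, 0, 50, 50],
    [0, 0, 0, 0],
])
toy_normalized = normalize_log2(toy_counts)
print(toy_normalized.round(3))
```

Cells 0 and 1 have the same relative composition, so they end up with identical normalized profiles despite a tenfold difference in sequencing depth.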

In [9]:
adata_train.obs

| cell | nCount_RNA | nFeature_RNA | nCount_guide | nFeature_guide | percent.mt | SampleID | Day | num_features | feature_call | num_umis | gene | adipo | pre_adipo | other | lipo |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| HIF1A_P1_AAACCATTCGTTCCTG-1 | 77689 | 8476 | 349 | 5 | 6.053624 | TF150_1 | Day14 | 1 | HIF1A_P1P2_JW3 | 345 | HIF1A | 0 | 0 | 1 | 0 |
| ZNF26_P1_AAACCCGCAAGCCGAC-1 | 100670 | 9040 | 206 | 7 | 7.674580 | TF150_1 | Day14 | 1 | ZNF26_P1P2_JW4 | 200 | ZNF26 | 0 | 1 | 0 | 0 |
| TWIST1_P1_AAACCCGCACTAGGCA-1 | 56415 | 7081 | 174 | 11 | 3.435257 | TF150_1 | Day14 | 1 | TWIST1_P1P2_JW2 | 164 | TWIST1 | 0 | 1 | 0 | 0 |
| CEBPA_P1_AAACCCGCAGGAAGCG-1 | 85222 | 8283 | 421 | 6 | 7.487503 | TF150_1 | Day14 | 1 | CEBPA_P1P2_JW4 | 416 | CEBPA | 0 | 0 | 1 | 0 |
| HIF3A_P1_AAACCCTGTGCAAGCG-1 | 84351 | 9183 | 193 | 4 | 3.705943 | TF150_1 | Day14 | 1 | HIF3A_P2_JW1 | 190 | HIF3A | 0 | 0 | 1 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| REXO4_P8_TGTGTACGTTATCTCT-1 | 44773 | 6567 | 130 | 4 | 3.908606 | TF150_8 | Day14 | 1 | REXO4_P1P2_JW3 | 127 | REXO4 | 0 | 0 | 1 | 0 |
| EBF2_P8_TGTGTACGTTGCAATC-1 | 48752 | 7154 | 177 | 1 | 3.187562 | TF150_8 | Day14 | 1 | EBF2_P1P2_JW4 | 177 | EBF2 | 0 | 1 | 0 | 0 |
| PLAGL1_P8_TGTGTTAGTATGGGAC-1 | 62875 | 7976 | 186 | 10 | 7.193638 | TF150_8 | Day14 | 1 | PLAGL1_P2_JW2 | 177 | PLAGL1 | 0 | 0 | 1 | 0 |
| AFF1_P8_TGTGTTGAGGGTCTCG-1 | 63077 | 7718 | 325 | 9 | 3.392679 | TF150_8 | Day14 | 1 | AFF1_P1P2_JW4 | 317 | AFF1 | 0 | 0 | 1 | 0 |


In [10]:
def to_array(
    adata: anndata.AnnData,
) -> np.typing.NDArray[np.float64]:
    if isinstance(adata.X, np.ndarray):
        return adata.X

    return adata.X.toarray()

In [11]:
# X_train_values = to_array(adata_train)
# X_train_values.shape

adata_train.X

CSRDataset: backend hdf5, shape (88202, 21592), data_dtype float64

## Prepare Local Train/Test Datasets

To speed up experimentation, the **small** setting (`--size small`) uses a reduced dataset (~4 GB) and a predefined split:

- **Local train:** `obesity_challenge_1.h5ad` <br />
  Contains single-cell expression profiles for a subset of gene perturbations used for training.

- **Local test:** `obesity_challenge_1_local_gtruth.h5ad` <br />
  A held-out subset of perturbations used locally to evaluate model performance.

If you prefer, you can create your **own train/test split** or run **cross-validation** by switching to the full dataset:

- **Full dataset mode:** `--size default` <br />
  Downloads the entire available dataset (~18 GB), allowing you to design your own evaluation strategy.

In [None]:
def prepare_local_split_adata(
    adata,
    data_dir="data",
    control_label="NC",
    filter_genes=True,
    filter_cells=True,
    number_cells_per_gene=100,
    test_split_mode="fixed",
    fixed_test_genes=None,
    test_ratio=0.2,
    genes_to_predict_file="genes_to_predict.txt",
    perturbations_file="predict_perturbations.txt",
    random_seed=42,
):
    """
    Prepares local training and test AnnData objects.

    Parameters
    ----------
    adata : AnnData
        Input AnnData object.
    data_dir : str
        Directory to read/write data files.
    control_label : str
        Label for control cells.
    number_cells_per_gene : int
        Maximum number of cells to keep per gene.
    test_split_mode : str
        "fixed" or "ratio" for test/train split.
    fixed_test_genes : list
        List of test genes if test_split_mode="fixed".
    test_ratio : float
        Ratio of genes to keep as test if test_split_mode="ratio".
    genes_to_predict_file : str
        Filename for genes (columns) to predict (one gene per line).
    perturbations_file : str
        Filename for predicted perturbations (one gene per line).
    random_seed : int
        Seed for reproducible random splits.

    Returns
    -------
    local_adata_train : AnnData
        Filtered training AnnData.
    test_adata_eval : AnnData
        Filtered test AnnData with additional info in `uns`.
    test_genes: list
        List of test genes
    """

    if filter_genes:
        # Load predicted perturbations and genes to predict
        predict_perturbations = set(pd.read_csv(os.path.join(data_dir, perturbations_file), header=None)[0])
        observed_genes = set(adata.obs["gene"].unique()) - {control_label}
        genes_to_predict = set(pd.read_csv(os.path.join(data_dir, genes_to_predict_file), header=None)[0])

        # Keep genes in any of genes_to_predict, observed, or predicted perturbations
        genes_to_keep = genes_to_predict | observed_genes | predict_perturbations
        genes_to_keep = list(genes_to_keep & set(adata.var.index))
    else:
        genes_to_keep = adata.var.index

    print("Number of genes kept:", len(genes_to_keep))

    if filter_cells:
        # Separate control and perturbed cells
        control_mask = adata.obs["gene"] == control_label
        perturbed_mask = ~control_mask

        # Compute number of control cells to keep based on original ratio
        n_control_total = control_mask.sum()
        n_perturbed_total = perturbed_mask.sum()
        control_fraction = n_control_total / (n_control_total + n_perturbed_total)

        # Subsample perturbed cells: up to number_cells_per_gene per gene
        perturbed_cells_to_keep = adata.obs[perturbed_mask].groupby("gene", observed=True).head(number_cells_per_gene).index
        n_perturbed_keep = len(perturbed_cells_to_keep)
        n_control_keep = int(control_fraction * (n_perturbed_keep / (1 - control_fraction)))

        # Randomly sample control cells proportionally
        control_cells_to_keep = adata.obs[control_mask].sample(n=min(n_control_keep, n_control_total), random_state=random_seed).index

        # Combine
        cells_to_keep = perturbed_cells_to_keep.union(control_cells_to_keep)
        print(f"Number of perturbed cells kept: {len(perturbed_cells_to_keep)}")
        print(f"Number of control cells kept: {len(control_cells_to_keep)}")
    else:
        cells_to_keep = adata.obs.index

    print("Total cells kept:", len(cells_to_keep))

    adata_filtered = adata[cells_to_keep, genes_to_keep].copy()

    # Create train/test gene split
    genes = adata_filtered.obs["gene"].unique().tolist()
    if test_split_mode == "fixed":
        if fixed_test_genes is None:
            fixed_test_genes = []
        test_genes = fixed_test_genes
        train_genes = [g for g in genes if g not in test_genes]
    elif test_split_mode == "ratio":
        np.random.seed(random_seed)
        genes.remove(control_label)
        genes = np.random.permutation(genes).tolist()
        n_test = int(len(genes) * test_ratio)
        test_genes = genes[:n_test]
        train_genes = genes[n_test:] + [control_label]
    else:
        raise ValueError(f"Unknown test_split_mode: {test_split_mode}")

    print("Num train genes:", len(train_genes))
    print("Num test genes:", len(test_genes))

    # Filter cells by gene split
    local_adata_train = adata_filtered[adata_filtered.obs["gene"].isin(train_genes)].copy()
    local_adata_test = adata_filtered[adata_filtered.obs["gene"].isin(test_genes)].copy()

    # Convert X to dense if needed
    train_X = to_array(local_adata_train)
    test_X = to_array(local_adata_test)

    # Compute control and perturbed centroids from train
    control_mask = local_adata_train.obs["gene"] == control_label
    assert control_mask.sum() > 0, "No control cells found in train set"
    control_centroid = train_X[control_mask].mean(axis=0)
    perturbed_centroid = train_X[~control_mask].mean(axis=0)

    # Prepare test AnnData with additional info in `uns`
    test_adata_eval = anndata.AnnData(
        X=test_X,
        obs=local_adata_test.obs.copy(),
        var=local_adata_test.var.copy(),  # Keep gene names so columns stay identifiable
        uns=dict(
            control_centroid_train=control_centroid,
            perturbed_centroid_train=perturbed_centroid
        )
    )

    print("local_adata_train:\n", local_adata_train)
    print("test_adata_eval:\n", test_adata_eval)

    # Save results
    local_adata_train.write(os.path.join(data_dir, "local_train.h5ad"))
    test_adata_eval.write(os.path.join(data_dir, "local_test_eval.h5ad"))
    with open(os.path.join(data_dir, "local_test_perturbated_genes.txt"), "w") as f:
        f.write("\n".join(test_genes))

    return local_adata_train, test_adata_eval, test_genes

You could use this to create your own split / cross-validation.
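
As a rough illustration of gene-level cross-validation, the sketch below splits perturbation genes (not cells) into disjoint folds, so every test perturbation is truly unseen during training. `make_gene_folds` is a hypothetical helper that is not part of the Crunch tooling, and the toy gene list only borrows a few names from the training table.

```python
import numpy as np


def make_gene_folds(genes, n_splits=5, control_label="NC", random_seed=42):
    """Split perturbation genes (not cells) into `n_splits` disjoint test folds."""
    # The control label is not a perturbation and must stay in every train set
    perturbed_genes = sorted(set(genes) - {control_label})

    rng = np.random.default_rng(random_seed)
    shuffled = rng.permutation(perturbed_genes)

    # np.array_split tolerates a gene count that is not divisible by n_splits
    return [fold.tolist() for fold in np.array_split(shuffled, n_splits)]


toy_genes = ["NC", "CEBPA", "HIF1A", "TWIST1", "ZNF26", "EBF2", "AFF1", "REXO4"]
folds = make_gene_folds(toy_genes, n_splits=3)

for fold_index, test_genes in enumerate(folds):
    print(f"Fold {fold_index}: test genes = {test_genes}")
```

Each fold could then be passed as `fixed_test_genes` with `test_split_mode="fixed"` to `prepare_local_split_adata`. Keep in mind that the function writes its outputs under fixed file names, so successive calls overwrite each other.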

For this quickstarter, we will use the provided split:
- local train: `obesity_challenge_1.h5ad`
- and local test: `obesity_challenge_1_local_gtruth.h5ad`

In [None]:
# local_adata_train, test_adata_eval, test_genes = prepare_local_split_adata(
#     adata_train,
#     filter_genes=True,
#     filter_cells=True,
#     number_cells_per_gene=50,
#     test_split_mode="ratio",
#     test_ratio=0.2,
# )

## Strategy Implementation

In this section, we build a **simple baseline predictor** based on cell-type centroids computed from the training data.

The idea is straightforward:
- Compute the **control centroid** → average expression profile of all control cells.

- For each target perturbation in the evaluation set, generate synthetic "predicted cells" by **repeating the control centroid** a fixed number of times.

- Package these predictions into an `AnnData` object with the correct structure.

In [None]:
def get_centroids(
    adata_train: anndata.AnnData,
    n_cells_per_perturbation: int = 50
):
    """
    Compute control and perturbed centroids
    """

    # Getting the control centroid in the train set
    control_mask = adata_train.obs["gene"] == control_label
    assert control_mask.sum() != 0, "No control cell"
    control_centroid = np.asarray(
        adata_train[control_mask].X.mean(axis=0)
    ).ravel()

    # Getting the perturbed centroid
    perturbed_mask = ~control_mask

    # Create a reduced index by sampling N cells per perturbation
    sampled_perturbed_idx = (
        adata_train.obs.loc[perturbed_mask]
        .groupby("gene", observed=True)
        .head(n_cells_per_perturbation)
        .index
    )

    perturbed_centroid = np.asarray(
        adata_train[sampled_perturbed_idx].X.mean(axis=0)
    ).ravel()

    gc.collect()

    return control_mask, control_centroid, perturbed_mask, perturbed_centroid

We compute two reference profiles from the training set:

- **Control centroid**: mean expression across all control cells
- **Perturbed centroid**: mean expression across perturbed (non-control) cells, using at most `n_cells_per_perturbation` cells (50 by default) from each perturbation

In [None]:
(
    control_mask,
    control_centroid,
    perturbed_mask,
    perturbed_centroid,
) = get_centroids(
    adata_train,
)

control_centroid.shape, perturbed_centroid.shape

((21592,), (21592,))

Load the list of gene perturbations to predict.

For the local version, we evaluate on a separate test file (`"obesity_challenge_1_local_gtruth.h5ad"`):

In [None]:
gtruth = scanpy.read_h5ad(os.path.join("data", "obesity_challenge_1_local_gtruth.h5ad"), backed="r")
predict_perturbations = gtruth.obs["gene"].cat.categories.tolist()

print("Local test gene perturbations:", predict_perturbations)

Local test gene perturbations: ['CHD4', 'FOXC1', 'SOX6', 'TRIM5', 'ZBTB20']


In the full challenge, you would load the 2863 unseen perturbations via `predict_perturbations.txt`:

In [None]:
# predict_perturbations = (
#     pd.read_csv(os.path.join("data", "predict_perturbations.txt"), header=None)[0]
#     .values
# )
# 
# print("All gene perturbations:", predict_perturbations)

Generate predictions using the control centroid baseline.

We create a prediction matrix where each gene perturbation is represented by `cells_per_perturbation` (100) identical rows equal to the control centroid.

In [None]:
# GENERATE PREDICTIONS
n_genes = len(adata_train.var.index)
n_perturbations = len(predict_perturbations)
cells_per_perturbation = 100

# (n_perturbations * cells_per_pert, n_genes) prediction matrix
prediction_matrix = np.zeros((n_perturbations * cells_per_perturbation, n_genes))

# Fill with a repeated control centroid
for i, pert in enumerate(predict_perturbations):
    start = i * cells_per_perturbation
    end = (i + 1) * cells_per_perturbation
    prediction_matrix[start:end] = control_centroid

# Construct obs["gene"] (n_perturbations * cells_per_pert rows)
obs_gene = np.repeat(predict_perturbations, cells_per_perturbation)

# Build the AnnData output object
prediction = anndata.AnnData(
    X=prediction_matrix,
    obs={"gene": obs_gene},
    var=adata_train.var.copy(),   # To preserve gene names & order
)

prediction

### Why do we generate 100 predicted cells per perturbation?

Single-cell RNA-seq produces **many individual cells per perturbation**, not a single expression vector. Even when the same gene is perturbed, cells show diverse responses:
- different effect magnitudes,
- alternative differentiation paths,
- alternative metabolic states,
- natural transcriptional variability.

This heterogeneity is real **biological signal**, not noise, and the challenge evaluates how well a model reproduces the distribution of cell states.

The two evaluation metrics reflect this: **MMD** compares full distributions using pairwise similarities between cells and **Pearson Delta** evaluates perturbation effects relative to perturbed means. Predicting only one cell would collapse the distribution, give unstable or biased estimates, and distort correlations.

By generating **100 (or more) predicted cells per perturbation**, we create enough samples to capture heterogeneity, support distribution-based metrics, stabilize mean/variance estimates, and realistically approximate the "cloud" of single-cell responses observed in real data.
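
To give an intuition for the Pearson Delta idea, here is a minimal sketch on synthetic data. It assumes the metric correlates the predicted and observed mean shifts, each taken relative to a reference centroid (the perturbed mean mentioned above). The official scoring code may differ in its details, and `pearson_delta` is a hypothetical helper name.

```python
import numpy as np


def pearson_delta(pred_cells, true_cells, reference_centroid):
    """Pearson correlation between predicted and observed mean shifts.

    Assumed form: each shift is the per-gene mean of the cells minus a
    reference centroid (here, the perturbed centroid of the training set).
    """
    pred_delta = pred_cells.mean(axis=0) - reference_centroid
    true_delta = true_cells.mean(axis=0) - reference_centroid
    return float(np.corrcoef(pred_delta, true_delta)[0, 1])


rng = np.random.default_rng(0)
n_genes = 200

# Toy data: a reference centroid and a true perturbation effect on top of it
reference = rng.normal(size=n_genes)
true_shift = rng.normal(size=n_genes)
true_cells = reference + true_shift + rng.normal(scale=0.5, size=(100, n_genes))

# A good prediction recovers the shift; a poor one uses an unrelated shift
good_pred = reference + true_shift + rng.normal(scale=0.5, size=(100, n_genes))
poor_pred = reference + rng.normal(size=n_genes) + rng.normal(scale=0.5, size=(100, n_genes))

good_score = pearson_delta(good_pred, true_cells, reference)
poor_score = pearson_delta(poor_pred, true_cells, reference)
print(f"good: {good_score:.3f}, poor: {poor_score:.3f}")
```

Averaging over 100 cells keeps the per-gene mean stable, which is one reason a single predicted cell would give a noisier estimate of the shift.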

### Predict Program Proportions

For this baseline, we simply copy the mean program proportions of control cells from the training set.

In [None]:
adata_train_control = adata_train[control_mask]

pred_proportion_df = pd.DataFrame({"gene": predict_perturbations})
pred_proportion_df["adipo"] = float(adata_train_control.obs["adipo"].mean())
pred_proportion_df["pre_adipo"] = float(adata_train_control.obs["pre_adipo"].mean())
pred_proportion_df["other"] = float(adata_train_control.obs["other"].mean())
pred_proportion_df["lipo"] = float(adata_train_control.obs["lipo"].mean())
pred_proportion_df

|   | gene | adipo | pre_adipo | other | lipo |
|---|---|---|---|---|---|
| 0 | CHD4 | 0.257553 | 0.366456 | 0.375991 | 0.06996 |
| 1 | FOXC1 | 0.257553 | 0.366456 | 0.375991 | 0.06996 |
| 2 | SOX6 | 0.257553 | 0.366456 | 0.375991 | 0.06996 |
| 3 | TRIM5 | 0.257553 | 0.366456 | 0.375991 | 0.06996 |
| 4 | ZBTB20 | 0.257553 | 0.366456 | 0.375991 | 0.06996 |


### The `train()` Function

The `train()` function is intended to encapsulate the full training workflow used to fit a predictive model on the provided perturbation dataset.
In a complete implementation, this function would:

- Load the training dataset from `data_directory_path`
- Preprocess the gene expression matrix
- Construct and train a model
- Evaluate the model on a validation split
- Save the trained model and any required metadata to `model_directory_path`

**On the platform, `train()` will be run on the full large dataset `obesity_challenge_1.h5ad`.**

**Note**:
- ❗ We recommend training locally and submitting weights/models because the dataset is large and cloud resources are limited.
- Make sure that the `Method description.md` file properly documents your model, so that the Broad Institute team can reference your work in their publications.
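
One possible way to follow the "train locally, submit the weights" recommendation is to persist whatever your model learned into the model directory and reload it at inference time. The sketch below does this for the only "parameter" of a centroid baseline. The file name `control_centroid.npy` and the helpers `save_baseline` / `load_baseline` are hypothetical choices for illustration, not part of the Crunch API.

```python
import os
import tempfile

import numpy as np


def save_baseline(model_directory_path: str, control_centroid: np.ndarray) -> str:
    """Persist the only 'parameter' of the baseline: the control centroid."""
    os.makedirs(model_directory_path, exist_ok=True)
    # "control_centroid.npy" is a hypothetical file name chosen for this sketch
    path = os.path.join(model_directory_path, "control_centroid.npy")
    np.save(path, control_centroid)
    return path


def load_baseline(model_directory_path: str) -> np.ndarray:
    """Reload the centroid at inference time instead of recomputing it."""
    return np.load(os.path.join(model_directory_path, "control_centroid.npy"))


# Round-trip check in a temporary directory standing in for the model directory
with tempfile.TemporaryDirectory() as model_directory_path:
    centroid = np.linspace(0.0, 1.0, num=5, dtype=np.float32)
    save_baseline(model_directory_path, centroid)
    restored = load_baseline(model_directory_path)

print(restored)
```

In a real submission, `train()` would call something like `save_baseline(...)` and `infer()` would call the matching loader, so the heavy computation does not have to run in the cloud.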




In [19]:
def train(
    data_directory_path: str,
    model_directory_path: str,
):
    """
    Train a perturbation prediction model.

    This function is designed to:
      - Load training data from `data_directory_path`
      - Preprocess gene expression matrices
      - Train your model on the available perturbations
      - Save all required models into `model_directory_path`

    In this baseline notebook, the function is left empty.
    You can fill in the model pipeline you want to use.
    """

    pass

### The `infer()` Function

The `infer()` function performs model inference on a set of unseen perturbations.
In the context of this baseline solution, there is no trained model.
Instead, predictions are generated by reusing the control centroid, making this a simple but valid baseline for evaluating the metric computation pipeline.

During inference, the function:

1. Loads the training dataset.
2. Baseline strategy: Computes centroids for control and perturbed cells in the training data. Builds a prediction matrix by repeating the control centroid for every requested perturbation.
3. Constructs an `AnnData` object containing the predicted gene-expression profiles.
4. Saves the predictions in the correct structure expected by the specification.
5. Computes the predicted cell-type proportions and saves them as a CSV file.

[Expected Output](https://docs.crunchdao.com/competitions/competitions/broad-obesity/crunch-1#expected-output):

- `Method description.md`: At the __end of the notebook__ 👇👇👇, write a small text outlining the approaches used to generate both the predictions and the estimated proportions of cells enriched for each program.

- `prediction.h5ad`: An `AnnData` file containing predicted gene expression profiles normalized and log-transformed post-perturbation for 2,863 gene perturbations indicated in `predict_perturbations.txt`. Predictions should be stored in `adata.X` matrix with the corresponding perturbation identity recorded in `adata.obs['gene']`. For each gene perturbation, we ask you to predict the gene expression profiles for 100 cells to quantify the distribution of each perturbation prediction. The file with predictions should have dimensions [286,300 × len(genes_to_predict)] (cells × genes).

- `predict_program_proportion.csv`: A CSV file reporting the predicted proportion of cells with enriched programs for each gene perturbation listed in `predict_perturbations.txt`.
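
Before submitting, it can help to sanity-check the layout described above. The following sketch uses a hypothetical helper, `check_prediction_layout`, that works on plain Python lists, so it does not depend on how the prediction was produced. It is an illustrative check, not the validation performed by the platform.

```python
from collections import Counter


def check_prediction_layout(
    obs_genes,
    n_columns,
    predict_perturbations,
    genes_to_predict,
    cells_per_perturbation=100,
):
    """Return a list of human-readable problems; an empty list means the layout matches."""
    problems = []

    expected_rows = len(predict_perturbations) * cells_per_perturbation
    if len(obs_genes) != expected_rows:
        problems.append(f"expected {expected_rows} rows, got {len(obs_genes)}")

    if n_columns != len(genes_to_predict):
        problems.append(f"expected {len(genes_to_predict)} columns, got {n_columns}")

    # Every requested perturbation should appear exactly cells_per_perturbation times
    counts = Counter(obs_genes)
    for perturbation in predict_perturbations:
        if counts.get(perturbation, 0) != cells_per_perturbation:
            problems.append(f"{perturbation}: {counts.get(perturbation, 0)} cells")

    return problems


# Row count stated above: 2,863 perturbations x 100 cells each
print(2863 * 100)  # 286300

toy_perturbations = ["CHD4", "FOXC1"]
toy_genes = ["G1", "G2", "G3"]
toy_obs = ["CHD4"] * 100 + ["FOXC1"] * 99  # one FOXC1 cell is missing

print(check_prediction_layout(toy_obs, 3, toy_perturbations, toy_genes))
```

With a real prediction, the first two arguments could be `prediction.obs["gene"].tolist()` and `prediction.n_vars`.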

**Inference input:**
- `predict_perturbations`: A list of **single-gene perturbations** for which you must predict the transcriptomic effect.
- `genes_to_predict`: A list of **gene names (columns)** to include in your predictions. Your model must generate expression values for this specific set of genes and the order of columns in the final prediction file must follow this list.

**Provide an inference function that relies on these two inputs.**
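
Because the column order must follow `genes_to_predict`, it matters how columns are selected. The sketch below contrasts a boolean membership mask, which keeps the source order, with positional indexing, which follows the requested order. `reorder_columns` is a hypothetical helper and the gene names are toy values.

```python
import numpy as np


def reorder_columns(values, source_genes, genes_to_predict):
    """Select and reorder the last axis of `values` to follow `genes_to_predict`."""
    position = {gene: i for i, gene in enumerate(source_genes)}

    missing = [gene for gene in genes_to_predict if gene not in position]
    if missing:
        raise KeyError(f"Genes absent from the source columns: {missing}")

    indices = np.array([position[gene] for gene in genes_to_predict])
    return np.asarray(values)[..., indices]


source_genes = ["A", "B", "C", "D"]
centroid = np.array([1.0, 2.0, 3.0, 4.0])
genes_to_predict = ["D", "A", "C"]

# A boolean membership mask keeps the source order: A, C, D
mask_result = centroid[np.isin(source_genes, genes_to_predict)]

# Positional indexing follows the requested order: D, A, C
ordered_result = reorder_columns(centroid, source_genes, genes_to_predict)

print(mask_result, ordered_result)
```

The two results contain the same values in a different order, which is enough to misalign predictions with their gene names if the wrong selection method is used.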



In [None]:
def infer(
    data_directory_path: str,
    prediction_directory_path: str,
    prediction_h5ad_file_path: str,
    program_proportion_csv_file_path: str,
    model_directory_path: str,
    predict_perturbations: typing.List[str],
    genes_to_predict: typing.List[str],
    cells_per_perturbation: int = 100,
):
    """
    Run inference for a set of perturbations.

    Parameters:
        data_directory_path: str
            Directory containing the data files (including the training AnnData file).
        prediction_directory_path: str
            Directory where prediction files can be written.
        prediction_h5ad_file_path: str
            Direct path where to write the `prediction.h5ad` file.
        program_proportion_csv_file_path: str
            Direct path where to write the `predict_program_proportion.csv` file.
        model_directory_path: str
            Directory containing your persisted model files.
        predict_perturbations: List[str]
            The perturbations for which to generate predictions.
        genes_to_predict: List[str]
            List of gene names (columns) to include in the prediction.h5ad AnnData object.
        cells_per_perturbation: int
            Number of synthetic cells to generate per perturbation.

    Return:
      None: Returned value is ignored.

    Expected files:
        prediction_directory_path / "prediction.h5ad": anndata.AnnData
            AnnData matrix of size (n_perturbations * cells_per_perturbation, n_genes_to_predict), containing the predicted gene expression values.
        prediction_directory_path / "predict_program_proportion.csv": pd.DataFrame
            DataFrame (index=False) containing estimated cell-type proportions for each perturbation.
    """

    print("Loading data...")
    global adata_train
    if "adata_train" not in globals():
        # Optimization for Google Colab, avoid loading the data twice
        adata_train = load_data(data_directory_path)

    print("Extracting centroids...")
    (
        control_mask,
        control_centroid,
        perturbed_mask,
        perturbed_centroid,
    ) = get_centroids(
        adata_train,
    )

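    print("Filtering data to only genes to predict...")
    # Use positional indices rather than a boolean mask so the centroid columns
    # follow the order of `genes_to_predict`, matching `var` built below
    gene_indices = adata_train.var.index.get_indexer(genes_to_predict)
    assert (gene_indices >= 0).all(), "Some genes to predict are missing from the training data"
    control_centroid = control_centroid[gene_indices]
    perturbed_centroid = perturbed_centroid[gene_indices]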

    print("Inferring the prediction...")
    n_genes = len(genes_to_predict)
    n_perturbations = len(predict_perturbations)
    n_cells = n_perturbations * cells_per_perturbation

    print(f"Predicting {n_cells} cells for {n_perturbations} perturbations.")
    print("Each row will contain", n_genes, "genes.")

    # Construct obs["gene"] (n_perturbations * cells_per_pert rows)
    obs = {"gene": np.repeat(predict_perturbations, cells_per_perturbation)}
    # Column variables
    var = adata_train[:, genes_to_predict].var.copy()

    # Clean any prior prediction file
    if os.path.exists(prediction_h5ad_file_path):
        os.remove(prediction_h5ad_file_path)

    # Temporary disk-mapped matrix file (will be removed after final .h5ad is written)
    temporary_prediction_path = os.path.join(prediction_directory_path, "prediction_matrix.h5")
    if os.path.exists(temporary_prediction_path):
        os.remove(temporary_prediction_path)

    # RAM estimation (informational)
    needed_ram = (n_cells * n_genes * 4) / (1024**3)  # GB
    available_ram = psutil.virtual_memory().available / (1024**3)
    print(f"Available RAM: {available_ram:.2f} GB")
    print(f"Needed RAM for full matrix: ~{needed_ram:.2f} GB")

    # Detecting whether you are running inside the crunch platform or not
    use_ram_intensive_mode = crunch.is_inside_runner
    # use_ram_intensive_mode = True  # Uncomment me to force it

    if use_ram_intensive_mode:
        print("-> Using full in-memory matrix (fastest).")

        # Full RAM X
        X = np.zeros((n_cells, n_genes), dtype=np.float32)
        print("Prediction matrix shape:", X.shape)

        # Fill with a repeated control centroid
        for i, pert in tqdm(enumerate(predict_perturbations)):
            start = i * cells_per_perturbation
            end = (i + 1) * cells_per_perturbation
            X[start:end] = control_centroid

        prediction = anndata.AnnData(X=X, obs=obs, var=var)
        del X
    else:
        print("-> Using HDF5 (low-memory mode).")
        # Adjust batch size
        batch_size = 1024 if n_cells > 5000 else 50

        # Create an HDF5 dataset on disk and write predictions batch-by-batch.
        # This avoids storing a huge (e.g. 286,300 × len(genes_to_predict)) matrix in RAM.
        with h5py.File(temporary_prediction_path, "w") as f:
            print("Prediction matrix shape:", (n_cells, n_genes))
            dset = f.create_dataset(
                "X",
                shape=(n_cells, n_genes),
                dtype="float32",
                chunks=(batch_size, n_genes),
                # compression="gzip"
            )

            for start in tqdm(range(0, n_cells, batch_size)):
                end = min(start + batch_size, n_cells)
                dset[start:end] = control_centroid

        print("Finished writing matrix.")

        # Use an empty in-memory array with the same shape, then replace X with the HDF5 dataset
        prediction = anndata.AnnData(X=np.zeros((n_cells, n_genes), dtype=np.float32), obs=obs, var=var)

        # Point X to the HDF5 file
        prediction.X = h5py.File(temporary_prediction_path, "r")["X"]

    # Save to .h5ad
    print("Saving the prediction...")
    prediction.write_h5ad(prediction_h5ad_file_path)

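    # Cleanup open resources: in low-memory mode X is an h5py.Dataset (never an h5py.File),
    # so close the file that owns it before deleting the temporary file
    if isinstance(prediction.X, h5py.Dataset):
        prediction.X.file.close()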

    # Remove the temporary HDF5 file
    if os.path.exists(temporary_prediction_path):
        os.remove(temporary_prediction_path)

    print("Inferring the proportion...")
    # Cell-type proportions are estimated solely from control cells
    adata_train_control = adata_train[control_mask, genes_to_predict]
    pred_proportion_df = pd.DataFrame({"gene": predict_perturbations})
    pred_proportion_df["adipo"] = float(adata_train_control.obs["adipo"].mean())
    pred_proportion_df["pre_adipo"] = float(adata_train_control.obs["pre_adipo"].mean())
    pred_proportion_df["other"] = float(adata_train_control.obs["other"].mean())
    pred_proportion_df["lipo"] = float(adata_train_control.obs["lipo"].mean())

    print("Saving the proportion...")
    pred_proportion_df.to_csv(program_proportion_csv_file_path, index=False)

    ## Optionally for your experiments (don't close prediction.X!):
    # return (
    #     prediction,
    #     pred_proportion_df
    # )

    del prediction
    gc.collect()

## Local testing

To make sure your `train()` and `infer()` functions are working properly, you can call the `crunch_tools.test()` function, which reproduces the cloud environment locally. <br />
Even if it is not perfect, it should give you a quick idea of whether your model is working properly.

**Note**: Locally, the infer function will be run on only 5 perturbed genes.

In [None]:
# crunch_tools.test()

In [None]:
genes_to_predict = pd.read_csv(os.path.join("data", "genes_to_predict.txt"), header=None)[0].values

## If instead you want to predict all genes (columns):
# genes_to_predict = adata_train.var.index.to_list()

print("About to predict", len(genes_to_predict), "genes")

In [None]:
## In the full challenge, you would load the 2863 unseen perturbations via predict_perturbations.txt.
# predict_perturbations = (
#     pd.read_csv(os.path.join("data", "predict_perturbations.txt"), header=None)[0]
#     .values
# )

## Test infer function on local test set "obesity_challenge_1_local_gtruth.h5ad"
gtruth = scanpy.read_h5ad(os.path.join("data", "obesity_challenge_1_local_gtruth.h5ad"), backed="r")
predict_perturbations = gtruth.obs["gene"].cat.categories.tolist()
print("Local test gene perturbations:", predict_perturbations)

infer(
    data_directory_path="data/",
    prediction_directory_path="prediction/",
    model_directory_path="resources/",
    predict_perturbations=predict_perturbations,
    genes_to_predict=genes_to_predict,
    prediction_h5ad_file_path=os.path.join("prediction", "prediction.h5ad"),
    program_proportion_csv_file_path=os.path.join("prediction", "predict_program_proportion.csv")
)

## Results

Once the local tester is done, you can preview the results stored in `prediction/prediction.h5ad` and `prediction/predict_program_proportion.csv`.

In [None]:
# backed="r" opens the AnnData file in backed mode instead of fully loading it into memory
prediction = scanpy.read_h5ad(os.path.join("prediction", "prediction.h5ad"), backed="r")
prediction

In [None]:
predicted_proportion = pd.read_csv(os.path.join("prediction", "predict_program_proportion.csv"))
predicted_proportion

### Local scoring

You can run the same scoring functions that the system uses to estimate your score on your local predictions.

In [None]:
gtruth = scanpy.read_h5ad(os.path.join("data", "obesity_challenge_1_local_gtruth.h5ad"), backed="r")
gtruth

In [None]:
gtruth_proportion = pd.read_csv(os.path.join("data", "program_proportion_local_gtruth.csv"))
gtruth_proportion

### Transcriptome-wide metrics

These metrics are computed using **either a subset of genes or all genes** (i.e., columns of the predicted matrix) for each perturbation.

- For the **public leaderboard** (updated weekly), the evaluation uses **1,000 hidden genes**.
The **MMD metric** is computed with a fixed sigma value of **2326**, calibrated specifically for 1,000-dimensional gene vectors.
- For the **private leaderboard**, both the **number** and the **identity** of the scoring genes will remain unknown to participants.

In this quickstarter example, we follow the same spirit and randomly select 1,000 genes to compute transcriptome-wide evaluation metrics.

In [None]:
# Retrieve the ordered list of predicted gene names (columns)
genes_to_predict_prediction = prediction.var.index.to_list()
print("Number of genes (columns) in the prediction:", len(genes_to_predict_prediction))

# Number of genes to use for computing metrics
n_genes_for_metric = 1000

# Randomly sample genes for the metric computation
rng = np.random.default_rng(seed=42)  # Try several different seeds to make sure you are ready for the private leaderboard
genes_for_metric = rng.choice(genes_to_predict_prediction, size=n_genes_for_metric, replace=False)

# Index for selecting the same genes from adata_train and local_gtruth
indexer_genes_gtruth = adata_train.var.index.get_indexer(genes_for_metric)

print("Number of genes (columns) used to compute metrics:", len(genes_for_metric))
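
Because the scoring genes of the private leaderboard are unknown, one way to check robustness is to repeat the metrics over several random gene subsets. The sketch below is only an illustration: `sample_gene_subsets` is a hypothetical helper (not part of the Crunch tooling), and the toy gene list stands in for `prediction.var.index.to_list()`.

```python
import numpy as np


def sample_gene_subsets(gene_names, n_genes, seeds):
    """Return one random gene subset (sampled without replacement) per seed.

    Hypothetical helper for illustration; not part of the Crunch tooling.
    """
    subsets = {}
    for seed in seeds:
        rng = np.random.default_rng(seed=seed)
        subsets[seed] = rng.choice(gene_names, size=n_genes, replace=False).tolist()
    return subsets


# Toy gene names standing in for the real prediction columns
toy_genes = [f"gene_{i}" for i in range(50)]
gene_subsets = sample_gene_subsets(toy_genes, n_genes=10, seeds=[0, 1, 2])

for seed, subset in gene_subsets.items():
    print(f"seed={seed}: {len(subset)} genes, first three: {subset[:3]}")
```

Each subset can then replace `genes_for_metric` in the metric cells, and a large spread of scores across seeds would suggest the model is sensitive to the choice of scoring genes.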

#### Pearson Delta

The Pearson Delta ($\rho$) is the correlation between the predicted ($\hat{X}$) and the observed ($X$) perturbation effects, each measured relative to the perturbed mean ($X_{PM}$):

$
\displaystyle
\rho(X, \hat{X}) = \frac{1}{|P|} \sum_{p \in P} \text{cor}\left(\hat{X}_p - X_{PM}, X_p - X_{PM}\right)
$

where:
- $X_{PM}$ is the Perturbed Mean defined as the mean gene expression of all the perturbed cells (those receiving a single gene perturbation) in the training dataset,
- $P$ is the set of all perturbation targets,
- $|P|$ is the size of the set $P$,
- $\text{cor}(a, b)$ is the Pearson correlation between vectors $a$ and $b$.

---

*Higher is better!*
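
To see why the perturbed mean is subtracted before correlating, here is a minimal sketch on synthetic data (all values are made up, and `pearson_delta` is a hypothetical helper name). A prediction that simply returns the perturbed mean correlates strongly with the observed profile in raw expression space, yet carries almost no information about the perturbation effect.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n_genes = 200

# Toy perturbed mean (X_PM) and a true perturbation effect on top of it
perturbed_mean = rng.normal(loc=5.0, scale=2.0, size=n_genes)
true_effect = rng.normal(scale=0.5, size=n_genes)
observed_mean = perturbed_mean + true_effect

# Prediction A: the perturbed mean plus small noise unrelated to the effect
pred_baseline = perturbed_mean + rng.normal(scale=0.05, size=n_genes)
# Prediction B: captures the true effect up to some noise
pred_informed = perturbed_mean + true_effect + rng.normal(scale=0.25, size=n_genes)


def pearson_delta(pred, obs, centroid):
    """Pearson correlation of the effects measured relative to the centroid."""
    return stats.pearsonr(pred - centroid, obs - centroid).statistic


raw_baseline = stats.pearsonr(pred_baseline, observed_mean).statistic
delta_baseline = pearson_delta(pred_baseline, observed_mean, perturbed_mean)
delta_informed = pearson_delta(pred_informed, observed_mean, perturbed_mean)

print(f"raw correlation, baseline: {raw_baseline:.3f}")
print(f"Pearson delta, baseline:   {delta_baseline:.3f}")
print(f"Pearson delta, informed:   {delta_informed:.3f}")
```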

In [None]:
def compute_metric_pearson(
    gtruth_X: np.typing.NDArray[np.float64],
    pred_X: np.typing.NDArray[np.float64],
    perturbed_centroid: np.typing.NDArray[np.float64],
) -> float:
    gtruth_X_target = gtruth_X.mean(axis=0)
    pred_X_target = pred_X.mean(axis=0)

    return scipy.stats.pearsonr(
        gtruth_X_target - perturbed_centroid,
        pred_X_target - perturbed_centroid,
    ).statistic

perturbations = gtruth.obs["gene"].cat.categories.tolist()
scores_pearson = []

perturbed_centroid = gtruth.uns["perturbed_centroid_train"][indexer_genes_gtruth]

for p in tqdm(perturbations):
    # Select true and predicted cells for perturbation p
    gt_mask = gtruth.obs["gene"] == p
    pr_mask = prediction.obs["gene"] == p

    # Extract their filtered expression matrices
    gtruth_X = to_array(gtruth[gt_mask, genes_for_metric])
    pred_X   = to_array(prediction[pr_mask, genes_for_metric])

    # Skip perturbations without matching predicted cells
    if gtruth_X.shape[0] == 0 or pred_X.shape[0] == 0:
        print(f"skipping {p}: missing samples: will not be accepted on the platform!")
        scores_pearson.append(0)
        continue

    # Compute Pearson score for this perturbation
    score = compute_metric_pearson(
        gtruth_X=gtruth_X,
        pred_X=pred_X,
        perturbed_centroid=perturbed_centroid,
    )
    scores_pearson.append(score)

print(f"Pearson score: {np.mean(scores_pearson):.4f}")

#### Maximum Mean Discrepancy (MMD)

The MMD measures the distance between the predicted and the observed distributions of single-cell profiles.

Let $X_p^i$ be the true expression profile and $\hat{X}_p^i$ be the predicted expression profile for cell $i$ with perturbation $p$.
Then we calculate the MMD for a particular perturbation using the following formula:

$
\displaystyle
MMD^2(X_p, \hat{X}_p) =
\frac{1}{N^2} \sum_{i,j} k(X_p^i, X_p^j) +
\frac{1}{N^2} \sum_{k,l} k(\hat{X}_p^k, \hat{X}_p^l) -
\frac{2}{N^2} \sum_{m,n} k(X_p^m, \hat{X}_p^n)
$

where:
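- $k(a, b)$ is the sum of Gaussian kernels over the bandwidths in the following list: $[581.5, 1163.0, 2326.0, 4652.0, 9304.0]$,
- $N$ is the number of cells compared for perturbation $p$, after the observed and predicted sets are truncated to the same size.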
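
The bandwidth list follows from the constants used in `compute_metric_mmd` in this notebook (`fix_sigma = 2326`, `kernel_mul = 2.0`, `kernel_num = 5`). A short sketch of the same arithmetic as `_gaussian_kernel`:

```python
# Values taken from `compute_metric_mmd` in this notebook
fix_sigma = 2326
kernel_mul = 2.0
kernel_num = 5

# Start from the smallest bandwidth, then multiply by `kernel_mul` at each step
smallest_bandwidth = fix_sigma / kernel_mul ** (kernel_num // 2)
bandwidth_list = [smallest_bandwidth * kernel_mul**i for i in range(kernel_num)]

print(bandwidth_list)  # [581.5, 1163.0, 2326.0, 4652.0, 9304.0]
```

The fixed sigma sits in the middle of the list, with two smaller and two larger bandwidths around it.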

---

Then we average the MMD score across all the perturbations using the following formula:

$
\displaystyle
MMD^2(X, \hat{X}) = \frac{1}{|P|} \sum_{p \in P} MMD^2(X_p, \hat{X}_p)
$

where:
- $|P|$ is the size of the set $P$.

---

*Lower is better!*
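
As a sanity check of the formula, here is a minimal sketch on synthetic data. It is not the platform's implementation: `multi_kernel` and `toy_mmd` are hypothetical names, and the pairwise squared distances are computed with `scipy.spatial.distance.cdist` as an assumption made for brevity. Cells drawn from the same distribution as the observed ones should give a lower MMD than cells drawn from a shifted distribution.

```python
import numpy as np
from scipy.spatial.distance import cdist

BANDWIDTHS = (581.5, 1163.0, 2326.0, 4652.0, 9304.0)


def multi_kernel(a, b):
    """Sum of Gaussian kernels exp(-||a_i - b_j||^2 / bandwidth) over all bandwidths."""
    sq_dist = cdist(a, b, metric="sqeuclidean")
    return sum(np.exp(-sq_dist / bandwidth) for bandwidth in BANDWIDTHS)


def toy_mmd(x, y):
    """MMD^2 estimate for two samples with the same number of rows."""
    n = x.shape[0]
    total = multi_kernel(x, x).sum() + multi_kernel(y, y).sum() - 2 * multi_kernel(x, y).sum()
    return total / n**2


rng = np.random.default_rng(0)
observed = rng.normal(loc=0.0, size=(64, 1000))   # toy "true" cells, 1,000 genes
same_dist = rng.normal(loc=0.0, size=(64, 1000))  # prediction from the same distribution
shifted = rng.normal(loc=1.0, size=(64, 1000))    # prediction with a shifted mean

mmd_same = toy_mmd(observed, same_dist)
mmd_shifted = toy_mmd(observed, shifted)

print(f"MMD, same distribution:    {mmd_same:.4f}")
print(f"MMD, shifted distribution: {mmd_shifted:.4f}")
```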

In [None]:
def _gaussian_kernel(source, target, kernel_mul, kernel_num, fix_sigma):
    # Getting the L2 distance
    n_samples = int(source.shape[0]) + int(target.shape[0])
    total = np.concatenate([source, target], axis=0)

    total0 = np.broadcast_to(np.expand_dims(total, axis=0), (int(total.shape[0]), int(total.shape[0]), int(total.shape[1])))
    total1 = np.broadcast_to(np.expand_dims(total, axis=1), (int(total.shape[0]), int(total.shape[0]), int(total.shape[1])))
    L2_distance = np.sum((total0 - total1)**2, axis=2)

    # Now we are ready to scale this using multiple bandwidth
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = np.sum(L2_distance) / (n_samples**2 - n_samples)

    # Now we will create the multiple bandwidth list
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    kernel_val = [np.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)


def _compute_mmd(X_batch, Y_batch, kernel_mul, kernel_num, fix_sigma, kernel_func):
    num_batch_element = X_batch.shape[0]

    kernels = kernel_func(X_batch, Y_batch, kernel_mul, kernel_num, fix_sigma)
    XX = kernels[:num_batch_element, :num_batch_element]
    YY = kernels[num_batch_element:, num_batch_element:]
    XY = kernels[:num_batch_element, num_batch_element:]
    YX = kernels[num_batch_element:, :num_batch_element]
    mmd_val = np.sum(XX + YY - XY - YX)

    return mmd_val / num_batch_element ** 2


def balance_source_target_sample_per_perturbation(gtruth_X_tgt,pred_X_tgt):
    num_gtruth = gtruth_X_tgt.shape[0]
    num_pred = pred_X_tgt.shape[0]
    min_sample = min(num_gtruth,num_pred)

    return gtruth_X_tgt[0:min_sample,:], pred_X_tgt[0:min_sample,:]


def compute_metric_mmd(
    gtruth_X: np.typing.NDArray[np.float64],
    pred_X: np.typing.NDArray[np.float64],
) -> float:
    kernel_mul = 2.0
    kernel_num = 5
    fix_sigma = 2326 # with 1000 columns

    # Balancing the samples to compute mmd using equal number of samples
    gtruth_X, pred_X = balance_source_target_sample_per_perturbation(gtruth_X, pred_X)

    mmd_dist = _compute_mmd(
        gtruth_X,
        pred_X,
        kernel_mul,
        kernel_num,
        fix_sigma,
        _gaussian_kernel
    )

    return mmd_dist


perturbations = gtruth.obs["gene"].cat.categories.tolist()
scores_mmd = []

for p in tqdm(perturbations):
    # Select true and predicted cells for perturbation p
    gt_mask = gtruth.obs["gene"] == p
    pr_mask = prediction.obs["gene"] == p

    # Extract their filtered expression matrices
    gtruth_X = to_array(gtruth[gt_mask, genes_for_metric])
    pred_X   = to_array(prediction[pr_mask, genes_for_metric])

    # Skip perturbations without matching predicted cells
    if gtruth_X.shape[0] == 0 or pred_X.shape[0] == 0:
        print(f"skipping {p}: missing samples: will not be accepted on the platform!")
        scores_mmd.append(9)
        continue

    # Compute MMD for this perturbation
    score = compute_metric_mmd(
        gtruth_X=gtruth_X,
        pred_X=pred_X,
    )

    scores_mmd.append(score)

print(f"MMD score: {np.mean(scores_mmd):.4f}")

### Program-level metrics

These metrics evaluate whether models capture meaningful biological outcomes.

#### L1-distance

There are four cell state proportions for each perturbation, i.e., pre-adipogenic, adipogenic, lipogenic, and other.
For each perturbation $p$, we have the ground truth cell-state proportion $S_p = [preadipo, adipo, lipo, other]$.

Let $S_p^R$ be the vector that has the proportion of the cell states $[preadipo, adipo, other]$. <br />
Let $S_p^L$ denote the conditional probability of a cell being in the lipogenic state given the adipogenic state, i.e., $S_p^L = lipo/adipo$.

Then we define the program level loss as:

$
\displaystyle
L1(\hat{S}, S) = \frac{1}{|P|} \sum_{p \in P} \left( 0.75 \cdot | S_p^R - \hat{S}_p^R |_1 + 0.25 \cdot | S_p^L - \hat{S}_p^L | \right)
$

where $|.|_1$ is the L₁-distance, $|P|$ is the number of perturbations, and $\hat{S}_p$ is the predicted cell-state proportions.

---

*Lower is better!*
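
A worked example for a single perturbation may help. The proportions below are made up for illustration, and `program_level_loss` is a hypothetical helper that follows the formula above, with the same column names and stabilization term as the scoring cell in this notebook.

```python
def program_level_loss(true_props, pred_props, eps=1e-20):
    """Weighted L1 loss for one perturbation.

    Both arguments are dicts with the keys pre_adipo, adipo, lipo and other.
    """
    # L1 distance over the three main states
    l1_states = sum(
        abs(true_props[key] - pred_props[key])
        for key in ("pre_adipo", "adipo", "other")
    )

    # Absolute difference of the lipo/adipo ratios
    true_ratio = true_props["lipo"] / (true_props["adipo"] + eps)
    pred_ratio = pred_props["lipo"] / (pred_props["adipo"] + eps)

    return 0.75 * l1_states + 0.25 * abs(true_ratio - pred_ratio)


# Made-up proportions for illustration only
true_props = {"pre_adipo": 0.50, "adipo": 0.40, "lipo": 0.20, "other": 0.10}
pred_props = {"pre_adipo": 0.40, "adipo": 0.50, "lipo": 0.20, "other": 0.10}

loss = program_level_loss(true_props, pred_props)
print(f"{loss:.4f}")  # 0.1750
```

Here the three-state distance is $0.1 + 0.1 + 0 = 0.2$ and the ratio difference is $|0.5 - 0.4| = 0.1$, so the loss is $0.75 \cdot 0.2 + 0.25 \cdot 0.1 = 0.175$.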

In [31]:
def compute_metric_l1_distance(
    true_state_proportion_df: pd.DataFrame,
    pred_state_proportion_df: pd.DataFrame,
) -> float:
    # Going over all the genes that were perturbed in this set
    unique_perturb_genes = list(true_state_proportion_df["gene"].unique())

    all_l1_loss_list = []
    for gene in unique_perturb_genes:
        # Selecting the rows for this gene
        true_gene_df = true_state_proportion_df[true_state_proportion_df["gene"] == gene]
        pred_gene_df = pred_state_proportion_df[pred_state_proportion_df["gene"] == gene]

        # print(gene, pred_gene_df.shape[0])
        assert true_gene_df.shape[0] == 1 and pred_gene_df.shape[0] == 1, f"Invalid prediction count for state gene={gene} count={pred_gene_df.shape[0]}!=1"

        # Getting the L1 loss for the three main states: pre_adipo, adipo and other
        l1_three = (
            np.abs(true_gene_df.iloc[0]["pre_adipo"] - pred_gene_df.iloc[0]["pre_adipo"]) +
            np.abs(true_gene_df.iloc[0]["adipo"] - pred_gene_df.iloc[0]["adipo"]) +
            np.abs(true_gene_df.iloc[0]["other"] - pred_gene_df.iloc[0]["other"])
        )

        # Getting the L1 loss for lipo by adipo
        numerical_stab_term = 1e-20
        pred_lipo_adipo = pred_gene_df.iloc[0]["lipo"] / (pred_gene_df.iloc[0]["adipo"] + numerical_stab_term)
        true_lipo_adipo = true_gene_df.iloc[0]["lipo"] / (true_gene_df.iloc[0]["adipo"] + numerical_stab_term)
        l1_lipo_adipo = np.abs(true_lipo_adipo - pred_lipo_adipo)

        # Getting the average error
        average_l1 = 0.75 * l1_three + 0.25 * l1_lipo_adipo
        all_l1_loss_list.append(average_l1)

    # Getting the overall average over all the gene perturbation
    l1_loss = np.mean(all_l1_loss_list)
    return float(l1_loss)


l1_distance = compute_metric_l1_distance(
    gtruth_proportion,
    predicted_proportion,
)

print(f"L1-distance score: {l1_distance:.4f}")

L1-distance score: 0.1026


## Writing the report

The final step is to write the method description as specified [in the documentation](https://docs.crunchdao.com/competitions/competitions/broad-obesity/crunch-1#file-method-description.md).

You must:
1. Explain how your method works. *(5-10 sentences)*
2. Describe the reasoning behind your gene panel design. *(5-10 sentences)*
3. Specify the datasets and any other resources utilized. *(5-10 sentences)*

The limit is about one page.
<br />
<br />
<br />

---

<span style="font-size: 48px">👇👇👇</span> (double-click the markdown cell below)

---
file: Method description.md
---

<!-- Don't forget to change me -->

# Method Description

Explain how your method works. (5-10 sentences)

# Rationale

Describe the reasoning behind your gene panel design. (5-10 sentences)

# Data and Resources Used

Specify the datasets and any other resources utilized. (5-10 sentences)

<span style="font-size: 48px">👆👆👆</span>

---
<br />
<br />
<br />

# Submit your Notebook

To submit your work, you must:
1. Download your Notebook from Colab
2. Upload it to the platform
3. Create a run to validate it

### >> https://hub.crunchdao.com/competitions/broad-obesity-1/submit/notebook

![Download and Submit Notebook](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/animations/download-and-submit-notebook.gif)