# VCC Submission Notebook

Hello! 

This notebook will help you prepare your predicted AnnData for scoring by `cell-eval` against a validation dataset.

Before we begin you will need a few things:

1. `cell-eval` installed and in your `$PATH` (see our [installation guide](https://github.com/ArcInstitute/cell-eval?tab=readme-ov-file#installation))
2. The expected number of cells per perturbation in the validation dataset (CSV) ([download](https://virtualcellchallenge.org/app))
3. The gene names to predict (CSV) ([download](https://virtualcellchallenge.org/app))
4. Your model predictions in an AnnData (h5ad)
5. (Optional) The training AnnData (if you are not predicting Non-Targeting Controls) ([download](https://virtualcellchallenge.org/app))


> Note: Your model predictions **may not exceed 100K cells total**
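
Because the limit applies to the whole file, it can be worth checking the planned cell total before running inference. The following is a minimal sketch, assuming the per-perturbation counts are available as a plain dictionary; the helper name `check_cell_budget`, its `limit` parameter, and the example counts are illustrative and not part of `cell-eval`.

```python
def check_cell_budget(cells_per_pert: dict[str, int], limit: int = 100_000) -> int:
    """Return the total number of planned cells, raising if it exceeds `limit`."""
    total = sum(cells_per_pert.values())
    if total > limit:
        raise ValueError(f"{total} cells planned, which exceeds the limit of {limit}")
    return total


# Hypothetical counts, including the non-targeting controls
planned = {"non-targeting": 1_000, "SH3BP4": 2_925, "ZNF581": 2_502}
total = check_cell_budget(planned)
print(f"Planned cells: {total}")
```

Remember to count non-targeting control cells as well, since they are stored in the same AnnData.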

## Building an Example Submission

For the purposes of this tutorial we will generate predictions with a **baseline perturbation model** and prepare them to be evaluated.

We will create an AnnData with *predicted gene abundances* for each cell, where the number of cells for each perturbation is derived from the number of cells in the validation dataset.

### Load in our Expected Counts

In [1]:
import polars as pl

# Define our path
pert_counts_path = "../vcc_data/pert_counts_Validation.csv"

# Read in the csv
pert_counts = pl.read_csv(pert_counts_path)

# Show the dimensions
print(f"Dimensions: {pert_counts.shape}")
pert_counts.head()

Dimensions: (50, 3)


| target_gene | n_cells | median_umi_per_cell |
| --- | --- | --- |
| str | i64 | f64 |
| "SH3BP4" | 2925 | 54551.0 |
| "ZNF581" | 2502 | 53803.5 |
| "ANXA6" | 2496 | 55175.0 |
| "PACSIN3" | 2101 | 54088.0 |
| "MGST1" | 2096 | 54217.5 |


### Load in our Expected Gene Names

In [2]:
gene_names_path = "../vcc_data/gene_names.csv"

# Read this in and immediately convert to array
gene_names = pl.read_csv(gene_names_path, has_header=False).to_numpy().flatten()

gene_names

array(['SAMD11', 'NOC2L', 'KLHL17', ..., 'MT-ND5', 'MT-ND6', 'MT-CYB'],
      shape=(18080,), dtype=object)

### Run Model Inference to Construct the AnnData for Validation

In [3]:
from cell_load.data_modules import PerturbationDataModule

dm = PerturbationDataModule(
    toml_config_path="starter.toml",
    embed_key=None,
    num_workers=8,
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    use_scplode=True,
    perturbation_features_file="/home/tphan/state/state/competition_support_set/ESM2_pert_features.pt",
    output_space="gene",
    basal_mapping_strategy="random",
    n_basal_samples=1,
    should_yield_control_cells=True,
    batch_size=16,
)
dm.setup()

Dataset path does not exist: /home/tphan/state/state/competition_support_set/{competition_train,k562_gwps,rpe1,jurkat,k562,hepg2}.h5


/home/tphan/state/state/competition_support_set/{competition_train,k562_gwps,rpe1,jurkat,k562,hepg2}.h5


Processing replogle_h1:   0%|                                                                                        | 0/6 [00:00<?, ?it/s][INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)

Processed competition_train: 221273 train, 0 val, 0 test
Processed k562_gwps: 111605 train, 0 val, 0 test
Processed rpe1: 22317 train, 0 val, 0 test
Processed jurkat: 21412 train, 0 val, 0 test
Processed k562: 18465 train, 0 val, 0 test
Processed hepg2: 0 train, 0 val, 9386 test





In [4]:
import torch
from protoplast.scrna.models.baseline import BaselinePerturbModel

G = 18080           # number of genes (matches gene_names)
n_cell_lines = 5    # number of entries in dm.cell_type_onehot_map
pert_d = 5120       # perturbation feature dimension (assumed to match ESM2_pert_features.pt)

device = "cuda" if torch.cuda.is_available() else "cpu"
last_ck = "baseline-delta-pert-emb/epoch=30.pt"
ckpt = torch.load(last_ck, map_location=device)
model = BaselinePerturbModel(G, n_cell_lines, pert_d).to(device)
model.eval()  # inference mode: disables dropout and batch-norm updates
model.load_state_dict(ckpt["model_state"])

<All keys matched successfully>

In [5]:
dm.cell_type_onehot_map

{np.str_('ARC_H1'): tensor([1., 0., 0., 0., 0.]),
 np.str_('hepg2'): tensor([0., 1., 0., 0., 0.]),
 np.str_('jurkat'): tensor([0., 0., 1., 0., 0.]),
 np.str_('k562'): tensor([0., 0., 0., 1., 0.]),
 np.str_('rpe1'): tensor([0., 0., 0., 0., 1.])}

In [6]:
import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
import anndata as ad
from scipy import sparse

def model_inference(model,
                    dm,
                    pert_names: NDArray[np.str_],
                    cell_counts: NDArray[np.int64],
                    gene_names: NDArray[np.str_],
                    max_count: int | float = 1e4,  # currently unused
                    control_ratio = .3
                   ):
    """Predict control and perturbed cells for every target gene."""
    obs_target_genes = []
    blocks = []  # sparse per-gene blocks, stacked once at the end
    total_cell = 0
    for i, target_gene in enumerate(pert_names):
        if not i % 5:
            print(f"working on gene: {i}")
        # Inflate the expected count, then carve out the control share.
        # Note: with control_ratio=.3, n_pert ends up at about 0.91x the expected count.
        n_cells = int(cell_counts[i] * (1 + control_ratio))
        total_cell += n_cells
        n_ctrl = int(n_cells * control_ratio)
        n_pert = n_cells - n_ctrl

        # Repeat the perturbation and cell-type encodings for every cell
        xp = dm.pert_onehot_map[target_gene].to(device)
        xp = xp.squeeze(0).expand(n_cells, -1)
        y = dm.cell_type_onehot_map["ARC_H1"].to(device)
        y = y.squeeze(0).expand(n_cells, -1)
        # Inference only, so no gradients are needed
        with torch.no_grad():
            xh_ctrl, delta, xh_prt = model(y, xp)
        # The first n_ctrl rows become controls, the remaining rows perturbed cells
        ctrl_X = sparse.csr_matrix(xh_ctrl[0:n_ctrl].detach().cpu().numpy())
        pert_X = sparse.csr_matrix(xh_prt[n_ctrl:n_cells].detach().cpu().numpy())
        blocks += [ctrl_X, pert_X]
        obs_target_genes += ["non-targeting"] * n_ctrl + [target_gene] * n_pert
    # Stack once instead of re-stacking the growing matrix on every iteration
    X = sparse.vstack(blocks, format='csr')
    return ad.AnnData(
        X=X,
        obs=pd.DataFrame(
            {
                "target_gene": obs_target_genes,
            },
            index=np.arange(total_cell).astype(str),
        ),
        var=pd.DataFrame(index=gene_names),
    )

adata = model_inference(
    model,
    dm,
    pert_counts["target_gene"].to_numpy(),
    pert_counts["n_cells"].to_numpy(),
    gene_names,
)
adata.write_h5ad("./baseline-delta-pert-emb.h5ad")

working on gene: 0
working on gene: 5
working on gene: 10
working on gene: 15
working on gene: 20
working on gene: 25
working on gene: 30
working on gene: 35
working on gene: 40
working on gene: 45
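
To see how many control and perturbed cells `model_inference` emits for one perturbation, the count arithmetic can be traced by hand. The sketch below mirrors the three lines that compute `n_cells`, `n_ctrl`, and `n_pert`; the helper name `split_counts` is illustrative, and whether the perturbed count is meant to equal the expected count exactly is not stated in this notebook.

```python
def split_counts(expected: int, control_ratio: float = 0.3) -> tuple[int, int, int]:
    """Mirror the count arithmetic in model_inference for one perturbation."""
    n_cells = int(expected * (1 + control_ratio))  # inflated total
    n_ctrl = int(n_cells * control_ratio)          # rows labelled "non-targeting"
    n_pert = n_cells - n_ctrl                      # rows labelled with the target gene
    return n_cells, n_ctrl, n_pert


# SH3BP4 has 2925 expected cells in pert_counts
n_cells, n_ctrl, n_pert = split_counts(2925)
print(n_cells, n_ctrl, n_pert)  # 3802 1140 2662
```

With the default ratio, the perturbed count (2662) comes out below the expected 2925, so adjust the arithmetic if your evaluation needs the counts to match exactly.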


### Adding in Non-Targeting Controls if you are not predicting them

Our evaluation framework expects non-targeting controls to be included in the predicted AnnData, but not every model predicts them.
If yours does not, you can copy the non-targeting control cells from the training AnnData into your predicted AnnData for validation.

In [5]:
# Define our path to the training anndata
tr_adata_path = "./adata_Training.h5ad"

# Read in the anndata
tr_adata = ad.read_h5ad(tr_adata_path)

# Filter for non-targeting
ntc_adata = tr_adata[tr_adata.obs["target_gene"] == "non-targeting"]

# Append the non-targeting controls to the example anndata if they're missing
if "non-targeting" not in adata.obs["target_gene"].unique():
    assert np.all(adata.var_names.values == ntc_adata.var_names.values), (
        "Gene-Names are out of order or unequal"
    )
    adata = ad.concat(
        [
            adata,
            ntc_adata,
        ]
    )
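
If the gene-name assertion fails, it helps to know where the two gene lists first diverge. The helper below is a small sketch for that purpose; the name `first_gene_mismatch` is hypothetical and it works on any two sequences of names, for example `adata.var_names` and `ntc_adata.var_names`.

```python
from itertools import zip_longest


def first_gene_mismatch(expected, predicted):
    """Return (index, expected_name, predicted_name) at the first difference, or None."""
    # zip_longest pads the shorter sequence with None, so a length mismatch is reported too
    for idx, (exp, pred) in enumerate(zip_longest(expected, predicted)):
        if exp != pred:
            return idx, exp, pred
    return None


# Toy example with two genes swapped
expected_genes = ["SAMD11", "NOC2L", "KLHL17"]
predicted_genes = ["SAMD11", "KLHL17", "NOC2L"]
mismatch = first_gene_mismatch(expected_genes, predicted_genes)
print(mismatch)  # (1, 'NOC2L', 'KLHL17')
```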

### Write our predictions to some output path

In [11]:
adata.write_h5ad("./baseline-delta-1.h5ad")

## Running `cell-eval prep`

Now that we have our predictions, we will run `cell-eval` to prepare our AnnData for competition scoring.

```bash
cell-eval prep \
    -i ./baseline-delta-1.h5ad \
    --genes ../vcc_data/gene_names.csv
```

And that's it! Your prepared predictions will be written to `./baseline-delta-1.prep.vcc` and are ready for scoring.
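
If you would rather launch this step from Python (for example at the end of the notebook), the same command can be assembled and run with `subprocess`. This is a minimal sketch that uses only the `-i` and `--genes` flags shown above; the helper name `build_prep_command` is illustrative, and the paths are the ones written earlier in this notebook.

```python
import shutil
import subprocess
from pathlib import Path


def build_prep_command(h5ad_path: str, genes_path: str) -> list[str]:
    """Assemble the `cell-eval prep` invocation as an argument list."""
    return ["cell-eval", "prep", "-i", h5ad_path, "--genes", genes_path]


cmd = build_prep_command("./baseline-delta-1.h5ad", "../vcc_data/gene_names.csv")

# Only run when the executable is on $PATH and the input file exists
if shutil.which("cell-eval") is not None and Path(cmd[3]).exists():
    subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
else:
    print("Skipping run; command would be:", " ".join(cmd))
```

Passing the arguments as a list avoids shell quoting issues with unusual file names.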