# VCC Submission Notebook

Hello! 

This notebook helps you prepare your predicted AnnData so that `cell-eval` can score it against a validation dataset.

Before we begin you will need a few things:

1. `cell-eval` installed and in your `$PATH` (see our [installation guide](https://github.com/ArcInstitute/cell-eval?tab=readme-ov-file#installation))
2. The expected number of cells per perturbation in the validation dataset (CSV) ([download](https://virtualcellchallenge.org/app))
3. The gene names to predict (CSV) ([download](https://virtualcellchallenge.org/app))
4. Your model predictions in an AnnData (h5ad)
5. (Optional) The training AnnData (if you are not predicting Non-Targeting Controls) ([download](https://virtualcellchallenge.org/app))


> Note: Your model predictions **may not exceed 100K cells total**
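
One way to stay under this cap is to scale the per-perturbation cell counts by a common factor before generating predictions. The sketch below is only an illustration: `scale_counts_to_budget` and its `reserve` parameter (cells set aside for non-targeting controls added later) are hypothetical names and not part of `cell-eval`.

```python
MAX_CELLS = 100_000  # cap stated in the note above


def scale_counts_to_budget(counts, budget=MAX_CELLS, reserve=0):
    """Scale per-perturbation cell counts so that their sum fits the budget.

    `reserve` is a hypothetical allowance for cells appended later,
    such as non-targeting controls.
    """
    available = budget - reserve
    if available <= 0:
        raise ValueError("reserve leaves no room for predicted cells")
    total = sum(counts.values())
    if total <= available:
        return dict(counts)  # already within budget, nothing to do
    # Integer floor division keeps the scaled sum at or below `available`.
    # Very small counts may round down to zero.
    return {gene: n * available // total for gene, n in counts.items()}


# Toy counts (not the real validation table)
toy = {"GENE_A": 60_000, "GENE_B": 50_000, "GENE_C": 40_000}
scaled = scale_counts_to_budget(toy, reserve=10_000)
print(scaled, sum(scaled.values()))
```

Scaling every perturbation by the same factor keeps the relative cell counts of the validation set, at the cost of fewer cells per perturbation.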

## Building an Example Submission

For the purposes of this tutorial we will be generating **random predictions** and preparing them to be evaluated.

We will create an AnnData with *random gene abundances* for each cell, where the number of cells for each perturbation matches the number of cells in the validation dataset.
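
As a rough illustration of such a random baseline, the NumPy sketch below builds a matrix with one block of rows per perturbation. The helper name `random_predictions`, the Poisson rate, and the `log1p` transform are assumptions made for this example; the toy names and counts are not taken from the validation files.

```python
import numpy as np


def random_predictions(pert_names, cell_counts, n_genes, seed=0):
    """Return (X, labels): a random log1p-count matrix and one label per row."""
    rng = np.random.default_rng(seed)
    blocks, labels = [], []
    for name, n_cells in zip(pert_names, cell_counts):
        n_cells = int(n_cells)
        # Poisson counts stand in for gene abundances; the rate is arbitrary.
        counts = rng.poisson(lam=1.0, size=(n_cells, n_genes))
        blocks.append(np.log1p(counts).astype(np.float32))
        labels.extend([name] * n_cells)
    return np.vstack(blocks), np.array(labels)


# Two toy perturbations with 3 and 2 cells over 4 genes
X, labels = random_predictions(["GENE_A", "GENE_B"], [3, 2], n_genes=4)
print(X.shape, labels.tolist())
```

The returned matrix and labels map directly onto the `X` and `obs["target_gene"]` fields of an AnnData.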

### Load in our Expected Counts

In [1]:
import polars as pl

# Define our path
pert_counts_path = "../vcc_data/pert_counts_Validation.csv"

# pert_counts_path = "gene_counts_arc_h1_true_validation.csv"

# Read in the csv
pert_counts = pl.read_csv(pert_counts_path)

# Show the dimensions
print(f"Dimensions: {pert_counts.shape}")
pert_counts.head()

Dimensions: (50, 3)


target_gene,n_cells,median_umi_per_cell
str,i64,f64
"""SH3BP4""",2925,54551.0
"""ZNF581""",2502,53803.5
"""ANXA6""",2496,55175.0
"""PACSIN3""",2101,54088.0
"""MGST1""",2096,54217.5


### Load in our Expected Gene Names

In [2]:
gene_names_path = "../vcc_data/gene_names.csv"

# Read this in and immediately convert to array
gene_names = pl.read_csv(gene_names_path, has_header=False).to_numpy().flatten()

gene_names

array(['SAMD11', 'NOC2L', 'KLHL17', ..., 'MT-ND5', 'MT-ND6', 'MT-CYB'],
      shape=(18080,), dtype=object)

### Model Inference to Construct the AnnData Object for Validation

In [3]:
from protoplast.scrna.anndata.pert_modules import PerturbDataModule
dm = PerturbDataModule(
        config_path="pert-dataconfig.toml",
        pert_embedding_file="/home/tphan/Softwares/protoplast/notebooks/competition_support_set/ESM2_pert_features.pt",
        train_batch_size=1024,
        eval_batch_size=256,
        num_workers=8,
        persistent_workers=True,
        n_basal_samples=10,
        barcodes=True
    )
dm.setup()

✓ Applied AnnDataFileManager patch


2025-09-08 05:18:07,511 - protoplast.scrna.anndata.pert_dataset - INFO - write mmap file for /home/tphan/Softwares/protoplast/notebooks/competition_support_set/competition_train.h5
[INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
2025-09-08 05:18:07,550 - protoplast.scrna.anndata.pert_dataset - INFO - n_obs for /home/tphan/Softwares/protoplast/notebooks/competition_support_set/competition_train.h5: 221273
2025-09-08 05:18:07,551 - protoplast.scrna.anndata.pert_dataset - INFO - write mmap file for /home/tphan/Softwares/protoplast/notebooks/competition_support_set/hepg2.h5
[INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
2025-09-08 05:18:07,559 - protoplast.scrna.anndata.pert_dataset - INFO - n_obs for /home/tphan/Softwares/protoplast/notebooks/competition_support_set/hepg2.h5: 9386
2025-09-08 05:18:07,559 - protoplast.scrna.anndata.pert_dataset - INFO - write mmap file for /home/tphan/Softwares/protoplas

In [4]:
# Model definition

In [5]:
import torch
from protoplast.scrna.models.cpa_vae_deepset import CPAVAE_Simple

G = 18080          # number of genes (matches the length of gene_names)
n_cell_lines = 5   # number of cell lines in the one-hot map
n_batches = 536    # not used in this notebook
d_xp = 5120        # dimension of the perturbation embedding

# Pick the device before building the model so it falls back to CPU when needed
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Build the model and load the trained checkpoint ----
model = CPAVAE_Simple(G, n_cell_lines, d_xp=d_xp).to(device)

last_ck = "cpa-vae-ctrl-set/epoch=25.pt"
ckpt = torch.load(last_ck, map_location=device)
model.load_state_dict(ckpt["model_state"])

<All keys matched successfully>

In [6]:
dm.ds.cell_types_onehot_map

{np.str_('ARC_H1'): tensor([1., 0., 0., 0., 0.]),
 np.str_('hepg2'): tensor([0., 1., 0., 0., 0.]),
 np.str_('jurkat'): tensor([0., 0., 1., 0., 0.]),
 np.str_('k562'): tensor([0., 0., 0., 1., 0.]),
 np.str_('rpe1'): tensor([0., 0., 0., 0., 1.])}

In [7]:
# generate the control set
import anndata as ad
tr_adata_path = "./competition_support_set/competition_train.h5"
tr_adata = ad.read_h5ad(tr_adata_path)

In [8]:
import random
import numpy as np
random.seed(42)
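# Row indices of all non-targeting control cells in the training data
control_indices = np.where(tr_adata.obs.target_gene == "non-targeting")[0]
# Sample 8 control cells for the control set (8, despite the variable name)
control30 = random.sample(control_indices.tolist(), 8)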

In [9]:
X_control = tr_adata.X[control30, :].todense()
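
Because `random.seed(42)` sets the module-wide state, each later call to `random.sample` depends on every call made before it. A private `random.Random` instance is one way to make each draw reproducible on its own. This is a minimal sketch: `sample_indices` is a hypothetical helper, and the `range` stands in for the real control indices.

```python
import random


def sample_indices(indices, k, seed):
    """Draw k distinct indices using a private, seeded generator."""
    rng = random.Random(seed)
    return rng.sample(list(indices), k)


pool = range(1000)  # stand-in for the control cell indices
small = sample_indices(pool, 8, seed=42)
again = sample_indices(pool, 8, seed=42)
print(small == again)  # same seed, same draw
```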

In [10]:
import torch
X_control = torch.tensor(X_control).to(device)
X_control

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6931, 0.0000,  ..., 0.0000, 0.6931, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.0986, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.0986, 1.0986],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.6931, 0.0000],
        [0.0000, 1.0986, 0.0000,  ..., 0.0000, 1.3863, 1.6094]],
       device='cuda:0')

In [11]:
import numpy as np
import pandas as pd
from numpy.typing import NDArray
import anndata as ad
import torch
from scipy import sparse

def model_inference(model,
                    dm,
                    pert_names: NDArray[np.str_],
                    cell_counts: NDArray[np.int64],
                    median_library_size: NDArray[np.float64],
                    gene_names: NDArray[np.str_],
                    X_control: torch.Tensor,
                   ):
    """Generate predicted cells for each perturbation and pack them into an AnnData."""
    obs_target_genes = []
    blocks = []  # one sparse block per perturbation, stacked once at the end
    total_cell = 0
    for i, target_gene in enumerate(pert_names):
        if not i % 5:
            print(f"working on gene: {i}")
        # Generate half of the expected number of cells for each perturbation
        n_cells = int(cell_counts[i] // 2)
        total_cell += n_cells

        # Broadcast the perturbation embedding and cell-line one-hot to n_cells rows
        xp = dm.ds.pert_embedding[target_gene].to(device)
        xp = xp.squeeze(0).expand(n_cells, -1)
        y = dm.ds.cell_types_onehot_map["ARC_H1"].to(device)
        y = y.squeeze(0).expand(n_cells, -1)
        # Every generated cell is conditioned on the same control set
        x_ctrl_set = X_control.unsqueeze(0).repeat(n_cells, 1, 1)
        # ---- generate from (y, xp) ----
        with torch.no_grad():  # inference only, no gradients needed
            xhat_ctrl, xhat_pert, mu_x_pert, logvar_x_pert = model.predict_from_yxp(y, xp, x_ctrl_set, sample=True, temperature=1.0)
        del x_ctrl_set
        torch.cuda.empty_cache()
        # Optional rescaling to the median library size (disabled; median_library_size is unused):
        # xhat_counts = torch.poisson(torch.expm1(xhat_pert))
        # library_size = xhat_counts.sum(dim=1)
        # factor = median_library_size[i] / library_size
        # xhat_counts = xhat_counts * factor[:, None]
        # xhat_log = torch.log1p(xhat_counts)
        blocks.append(sparse.csr_matrix(xhat_pert.detach().cpu().numpy()))
        obs_target_genes += [target_gene] * n_cells
    # Stack all perturbation blocks in a single pass
    X = sparse.vstack(blocks, format="csr")
    return ad.AnnData(
        X=X,
        obs=pd.DataFrame(
            {
                "target_gene": obs_target_genes,
            },
            index=np.arange(total_cell).astype(str),
        ),
        var=pd.DataFrame(index=gene_names),
    )
adata = model_inference(model, dm, pert_counts["target_gene"].to_numpy(), pert_counts["n_cells"].to_numpy(), pert_counts["median_umi_per_cell"].to_numpy(), gene_names, X_control)

working on gene: 0
working on gene: 5
working on gene: 10
working on gene: 15
working on gene: 20
working on gene: 25
working on gene: 30
working on gene: 35
working on gene: 40
working on gene: 45
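
The commented-out lines inside `model_inference` outline an optional step: sample counts from the predicted log1p expression, rescale each cell to a target library size, and return to log1p space. The NumPy sketch below mirrors that idea for illustration only. `rescale_to_library_size` is a hypothetical helper, and the clipping and the guard for empty cells are additions that do not appear in the commented-out code.

```python
import numpy as np


def rescale_to_library_size(x_log, target_size, seed=0):
    """Sample counts from log1p expression and rescale each cell to target_size."""
    rng = np.random.default_rng(seed)
    # Undo log1p to get Poisson rates; clip so rates are never negative.
    rates = np.clip(np.expm1(x_log), 0.0, None)
    counts = rng.poisson(rates).astype(np.float64)
    library_size = counts.sum(axis=1)
    # Cells with zero total counts keep a factor of 0 (avoids division by zero).
    factor = np.where(
        library_size > 0, target_size / np.maximum(library_size, 1.0), 0.0
    )
    return np.log1p(counts * factor[:, None])


# Two toy cells over three genes, both rescaled to 1000 total counts
x_log = np.log1p(np.array([[5.0, 10.0, 20.0], [50.0, 50.0, 50.0]]))
out = rescale_to_library_size(x_log, target_size=1000.0)
print(np.expm1(out).sum(axis=1))
```

After rescaling, every non-empty cell has the same total count in linear space, which is what the `median_umi_per_cell` column would supply as the target.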


In [12]:
adata

AnnData object with n_obs × n_vars = 30364 × 18080
    obs: 'target_gene'

In [13]:
n_controls = 10000
controls = random.sample(control_indices.tolist(), n_controls)

### Adding in Non-Targeting Controls if you are not predicting them

Our evaluation framework expects non-targeting controls to be included in the predicted AnnData, but not every model predicts them.
If yours does not, you can copy non-targeting cells from the training AnnData into your predicted AnnData before validation.
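
Before copying controls over, it helps to confirm two things: that the predictions really lack the control label, and that both objects list genes in the same order. The plain-Python sketch below shows these checks on toy lists. `needs_controls` and `first_gene_mismatch` are hypothetical helpers written for this example.

```python
def needs_controls(target_genes, control_label="non-targeting"):
    """Return True when no predicted cell carries the control label."""
    return control_label not in set(target_genes)


def first_gene_mismatch(pred_genes, ctrl_genes):
    """Return the first index where the gene lists differ, or None if identical."""
    for i, (a, b) in enumerate(zip(pred_genes, ctrl_genes)):
        if a != b:
            return i
    # Same shared prefix but different lengths: the mismatch starts where
    # the shorter list ends.
    if len(pred_genes) != len(ctrl_genes):
        return min(len(pred_genes), len(ctrl_genes))
    return None


pred_labels = ["GENE_A", "GENE_A", "GENE_B"]
print(needs_controls(pred_labels))
print(first_gene_mismatch(["SAMD11", "NOC2L"], ["SAMD11", "KLHL17"]))
```

Reporting the first mismatching position makes an ordering problem easier to track down than a bare assertion failure.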

In [14]:
# Filter for non-targeting
ntc_adata = tr_adata[controls]

# Append the non-targeting controls to the example anndata if they're missing
if "non-targeting" not in adata.obs["target_gene"].unique():
    assert np.all(adata.var_names.values == ntc_adata.var_names.values), (
        "Gene-Names are out of order or unequal"
    )
    adata = ad.concat(
        [
            adata,
            ntc_adata,
        ]
    )

In [15]:
adata

AnnData object with n_obs × n_vars = 40364 × 18080
    obs: 'target_gene'

### Write our predictions to some output path

In [16]:
adata.write_h5ad("./cpa-vae-simple-deepset-epoch25.h5ad")

## Running `cell-eval prep`

Now that we have our predictions, we run `cell-eval prep` to prepare the AnnData for competition scoring. Pass the prediction file written above and the gene names CSV:

```bash
cell-eval prep \
    -i ./cpa-vae-simple-deepset-epoch25.h5ad \
    --genes ../vcc_data/gene_names.csv
```

And that's it! The prepared output is written to `./cpa-vae-simple-deepset-epoch25.prep.vcc` and is ready for scoring.