# Generating Example Data

This notebook uses data from the Virtual Cell Challenge to demonstrate model training, inference, and fine-tuning.

In [None]:
'''
Download the dataset

(taken from Colab Notebook by Adduri et al.
https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l#scrollTo=h0aSjKX7Rtyw)
'''

import os
from zipfile import ZipFile

import requests
from tqdm.auto import tqdm  # picks the best progress bar for the environment

# Download the Replogle-Nadig training dataset.
url = "https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip"
output_path = "competition_support_set.zip"
out_dir = "competition_support_set"

# Stream the download so progress can be tracked without holding the whole
# archive in memory. The `with` block releases the connection even on error,
# and raise_for_status() stops an HTTP error page from being written to disk
# (which would otherwise surface later as a confusing "not a zip file" error).
with requests.get(url, stream=True, timeout=60) as response:
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))

    with open(output_path, "wb") as f, tqdm(
        total=total, unit="B", unit_scale=True, desc="Downloading"
    ) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            # Skip empty keep-alive chunks instead of aborting the download.
            if not chunk:
                continue
            f.write(chunk)
            bar.update(len(chunk))

# Extract every archive member into `out_dir`, with a per-file progress bar.
os.makedirs(out_dir, exist_ok=True)
with ZipFile(output_path, "r") as z:
    for member in tqdm(z.infolist(), desc="Unzipping", unit="file"):
        z.extract(member, out_dir)

# State embeddings

In [6]:
from helical.models.state import stateConfig, stateEmbed

# Build the State embedding model with a batch size of 16; the log below
# shows which checkpoint was loaded.
state_config = stateConfig(batch_size=16)
state_embed = stateEmbed(configurer=state_config)

INFO:helical.models.state.state_embeddings_torch:Using model checkpoint: /home/rasched/.cache/helical/models/state/state_embed/se600m_model_weights.pt
INFO:helical.models.state.state_embeddings_torch:number of free parameters: 603245738
INFO:helical.models.state.state_embeddings_torch:Missing keys: []
INFO:helical.models.state.state_embeddings_torch:Successfully loaded model


In [7]:
import scanpy as sc

# Keep only the first two cells of the validation template: enough to exercise
# the embedding pipeline quickly. `.copy()` turns the slice view into an
# independent AnnData object.
adata = sc.read_h5ad(
    "competition_support_set/competition_val_template.h5ad"
)[:2].copy()

In [8]:
# Preprocess a copy, so `adata` itself cannot be modified by process_data,
# then embed the processed cells.
processed_data1 = state_embed.process_data(
    adata=adata.copy(),
)
embeddings = state_embed.get_embeddings(processed_data1)

INFO:helical.models.state.state_embeddings_torch:Auto-detected gene column: var.index (overlap: 17773/19790 protein embeddings, 98.3% of genes)
INFO:/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/data/loader.py:17773 genes mapped to embedding file (out of 18080)
INFO:/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/data/loader.py:17773 genes mapped to embedding file (out of 18080)
Encoding: 100%|██████████| 1/1 [00:00<00:00,  6.24it/s]


In [None]:
# NOTE(review): disabled debugging cell (every line is commented out). It
# references `state_embed_torch`, which is not defined anywhere in this
# notebook, so it cannot be re-enabled as-is; consider deleting it. If revived:
# - `compare_models` zips the two parameter lists, so a parameter-count
#   mismatch is only printed, never reported as a difference;
# - the final print labels `torch.max(diff)` as the difference, but `diff` is
#   already the *summed* absolute difference (a scalar), not a maximum;
# - `torch.autocast(device_type="cuda", ...)` assumes a CUDA device.
# import torch

# def compare_models(model1, model2):
#     params1 = list(model1.parameters())
#     params2 = list(model2.parameters())
    
#     print(f"Model 1 has {len(params1)} parameters")
#     print(f"Model 2 has {len(params2)} parameters")
    
#     for i, (p1, p2) in enumerate(zip(params1, params2)):
#         if not torch.allclose(p1, p2, atol=1e-6):
#             print(f"Parameter {i} differs! Max diff: {torch.max(torch.abs(p1 - p2)).item()}")
#             return False
#     print("All parameters are identical!")
#     return True

# compare_models(state_embed.model, state_embed_torch.model)
# print(state_embed.model.training, state_embed_torch.model.training)


# adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")
# adata = adata[:2].copy()

# processed_data1 = state_embed.process_data(adata=adata.copy())
# batch = next(iter(processed_data1))
# print(batch)

# with torch.no_grad():
#     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#         _, _, _, emb1, _ = state_embed.model._compute_embedding_for_batch(batch)
#         _, _, _, emb2, _ = state_embed_torch.model._compute_embedding_for_batch(batch)
    
#     print(f"Model 1 embedding: {emb1[0, :5]}")
#     print(f"Model 2 embedding: {emb2[0, :5]}")
    
#     # Compare
#     diff = torch.abs(emb1 - emb2).sum()
#     print(f"Difference: {torch.max(diff).item()}")

In [None]:
# Extract weights from checkpoint
# NOTE(review): disabled one-off utility (every line is commented out). It
# saved `state_embed.checkpoint['state_dict']` to
# "embed_model_epoch16_weights.pt". `checkpoint` is assigned but never used.
# Consider deleting this cell.
# import torch
# checkpoint = state_embed.checkpoint['state_dict']
# state_dict = state_embed.checkpoint['state_dict']
# torch.save(state_dict, "embed_model_epoch16_weights.pt")

# Training the Model

Running the download cell above creates the `competition_support_set/` directory. Before training, edit `competition_support_set/starter.toml` so that its paths point to the correct location on your machine. This file also defines the train/test split, which you can leave at its default.

In [9]:
# Train the State transition model on the training data, then generate
# predictions for the test split defined in starter.toml.
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

train_config = trainingConfig(
    # Run outputs; the final checkpoint lands at competition/first_run/final.ckpt.
    output_dir="competition",
    name="first_run",
    checkpoint_name="final.ckpt",
    # Dataset definition, perturbation features and column names.
    toml_config_path="competition_support_set/starter.toml",
    perturbation_features_file="competition_support_set/ESM2_pert_features.pt",
    pert_col="target_gene",
    batch_col="batch_var",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    # Training budget. With max_epochs=1, this run stops after a single epoch
    # (210 steps in the log below), so neither max_steps nor
    # ckpt_every_n_steps is reached. Raise max_epochs for a longer run.
    max_steps=40000,
    max_epochs=1,
    ckpt_every_n_steps=20000,
    num_workers=4,
)

state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train()
state_train.predict()

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
INFO:cell_load.config:Configuration validation passed
INFO:cell_load.data_modules.perturbation_dataloader:Initializing DataModule: batch_size=16, workers=4, random_seed=42


/home/rasched/final_helical_with_state/helical/helical/models/state/competition_support_set/{competition_train,k562_gwps,rpe1,jurkat,k562,hepg2}.h5


INFO:cell_load.data_modules.perturbation_dataloader:Set 2 missing perturbations to zero vectors.
INFO:cell_load.data_modules.perturbation_dataloader:Loaded custom perturbation featurizations for 19792 perturbations.
INFO:cell_load.data_modules.perturbation_dataloader:Processing dataset replogle_h1:
INFO:cell_load.data_modules.perturbation_dataloader:  - Training dataset: True
INFO:cell_load.data_modules.perturbation_dataloader:  - Zeroshot cell types: ['hepg2']
INFO:cell_load.data_modules.perturbation_dataloader:  - Fewshot cell types: []
Processing replogle_h1: 100%|██████████| 6/6 [00:00<00:00, 27.33it/s]
INFO:cell_load.data_modules.perturbation_dataloader:

INFO:cell_load.data_modules.perturbation_dataloader:Done! Train / Val / Test splits: 5 / 0 / 1


Processed competition_train: 221273 train, 0 val, 0 test
Processed k562_gwps: 111605 train, 0 val, 0 test
Processed rpe1: 22317 train, 0 val, 0 test
Processed jurkat: 21412 train, 0 val, 0 test
Processed k562: 18465 train, 0 val, 0 test
Processed hepg2: 0 train, 0 val, 9386 test
Model created. Estimated params size: 0.61 GB and 650505936 parameters


INFO:helical.models.state.state_train:Loggers and callbacks set up.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:helical.models.state.state_train:Starting trainer fit.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name                 | Type                    | Params | Mode 
-------------------------------------------------------------------------
0 | loss_fn              | SamplesLoss             | 0      | train
1 | pert_encoder         | Sequential              | 4.8 M  | train
2 | basal_encoder        | Linear                  | 12.2 M | train
3 | t

Trainer built successfully
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

INFO:cell_load.data_modules.samplers:Creating perturbation batch sampler with metadata caching (using codes)...
INFO:cell_load.data_modules.samplers:Total # cells 9386. Cell set size mean / std before resampling: 78.22 / 47.32.
INFO:cell_load.data_modules.samplers:Creating meta-batches with cell_sentence_len=128...
INFO:cell_load.data_modules.samplers:Of all batches, 51 were full and 69 were partial.
INFO:cell_load.data_modules.samplers:Sampler created with 8 batches in 0.02 seconds.
INFO:cell_load.data_modules.samplers:Of all batches, 51 were full and 69 were partial.


                                                                           

INFO:cell_load.data_modules.samplers:Creating perturbation batch sampler with metadata caching (using codes)...




INFO:cell_load.data_modules.samplers:Total # cells 395072. Cell set size mean / std before resampling: 117.72 / 27.93.
INFO:cell_load.data_modules.samplers:Creating meta-batches with cell_sentence_len=128...
INFO:cell_load.data_modules.samplers:Of all batches, 2831 were full and 525 were partial.
INFO:cell_load.data_modules.samplers:Sampler created with 210 batches in 0.40 seconds.
INFO:cell_load.data_modules.samplers:Of all batches, 2831 were full and 525 were partial.


Epoch 0: 100%|██████████| 210/210 [01:47<00:00,  1.95it/s, v_num=0]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 210/210 [01:47<00:00,  1.95it/s, v_num=0]
Training completed, saving final checkpoint...


INFO:cell_load.data_modules.samplers:Creating perturbation batch sampler with metadata caching (using codes)...
INFO:cell_load.data_modules.samplers:Total # cells 9386. Cell set size mean / std before resampling: 78.22 / 47.32.
INFO:cell_load.data_modules.samplers:Creating meta-batches with cell_sentence_len=128...
INFO:cell_load.data_modules.samplers:Of all batches, 120 were full and 0 were partial.
INFO:cell_load.data_modules.samplers:Sampler created with 120 batches in 0.01 seconds.
INFO:helical.models.state.state_train:Loading model from competition/first_run/final.ckpt
INFO:helical.models.state.state_train:Model loaded successfully.
INFO:helical.models.state.state_train:Generating predictions on test set using manual loop...
Predicting:   0%|          | 0/120 [00:00<?, ?batch/s]INFO:cell_load.data_modules.samplers:Of all batches, 120 were full and 0 were partial.
Predicting: 100%|██████████| 120/120 [00:20<00:00,  5.93batch/s]
INFO:helical.models.state.state_train:Creating anndata

Once the model is trained, we can use it to run inference on a dataset (here, the competition validation template):

In [10]:
from helical.models.state import stateConfig, stateTransitionModel

# Point the transition model at the run trained above. `output` is the path
# that the submission cell below reads predictions from.
state_config = stateConfig(
    output="competition/prediction.h5ad",
    model_dir="competition/first_run",
    model_config="model_dir/config.yaml",
    pert_col="target_gene",
)

adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

state_transition = stateTransitionModel(configurer=state_config)
adata = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(adata)

INFO:helical.models.state.state_transition:Using checkpoint: competition/first_run/final.ckpt
INFO:helical.models.state.state_transition:Model device: cpu
INFO:helical.models.state.state_transition:Model cell_set_len (max sequence length): 128
INFO:helical.models.state.state_transition:Model uses batch encoder: False
INFO:helical.models.state.state_transition:Model output space: all
INFO:helical.models.state.state_transition:Grouping by cell type column: cell_type
INFO:helical.models.state.state_transition:Using adata.X as input features
INFO:helical.models.state.state_transition:Cells: total=98927, control=38176, non-control=60751
INFO:helical.models.state.state_transition:Running virtual experiment (homogeneous per-perturbation forward passes; controls included)...
Group H1: 100%|██████████| 51/51 [00:28<00:00,  1.81it/s, Pert: non-targeting           ]
INFO:helical.models.state.state_transition:
=== Inference complete ===
INFO:helical.models.state.state_transition:Input cells:      

# Creating a Virtual Cell Challenge Submission

To create a submission for the Virtual Cell Challenge, we pass the prediction file produced above to `helical.models.state.vcc_eval`, which uses the `cell-eval` package. This generates a `.vcc` file that can be uploaded to the public leaderboard.

In [None]:
# Package the predictions for submission. The underlying function uses the
# cell-eval package (https://github.com/ArcInstitute/cell-eval).
from helical.models.state import vcc_eval

# Default settings for the competition dataset.
# (DEFAULT_CTRL and DEFAULT_COUNTS_COL are defined but not used by this cell.)
EXPECTED_GENE_DIM = 18080
MAX_CELL_DIM = 100000
DEFAULT_PERT_COL = "target_gene"
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

configs = dict(
    input="competition/prediction.h5ad",  # prediction file to package
    genes="competition_support_set/gene_names.csv",  # gene names file
    output=None,  # None -> output file is created with default naming
    pert_col=DEFAULT_PERT_COL,
    celltype_col=None,
    ntc_name=DEFAULT_NTC_NAME,
    output_pert_col=DEFAULT_PERT_COL,
    output_celltype_col=DEFAULT_CELLTYPE_COL,
    encoding=32,
    allow_discrete=False,
    expected_gene_dim=EXPECTED_GENE_DIM,
    max_cell_dim=MAX_CELL_DIM,
)

# Writes a submission file that can be uploaded to the challenge leaderboard.
vcc_eval(configs)

# Finetuning Example

We can further fine-tune the model. The starting point is a model directory containing a `.ckpt` file; if the model directory does not exist, one is created along with a new model instance. Fine-tuning produces model weights and head weights, which are saved as `.pt` files in the directory. If these files are already present, training resumes from them.

In [11]:
from helical.models.state import stateFineTuningModel

# Load the desired dataset.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Per-cell labels for the classification head.
cell_types = list(adata.obs.cell_type)
label_set = set(cell_types)
# NOTE(review): the inference log above only shows a "Group H1" bar, so this
# template may contain a single cell type. That would give a one-class head
# and explain the trivial loss=0 below. Confirm with
# `adata.obs.cell_type.unique()`, and use multi-label data for a meaningful demo.

# Create the fine-tuning model (no need to specify var_dims location).
config = stateConfig(
    batch_size=8,
    model_dir="competition/first_run",
    model_config="model_dir/config.yaml",
    freeze_backbone=True,
)

model = stateFineTuningModel(
    configurer=config,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

# Process the data for training.
data = model.process_data(adata)

# Map each class to a unique integer id. Sort before enumerating: the
# iteration order of a set of strings changes between interpreter sessions
# (string hash randomization), which would silently permute the label ids and
# make a saved head's outputs impossible to map back to class names.
# `key=str` keeps the sort well-defined even if labels mix types (e.g. NaN).
class_id_dict = {
    label: idx for idx, label in enumerate(sorted(label_set, key=str))
}
cell_types = [class_id_dict[label] for label in cell_types]

# Fine-tune.
model.train(train_input_data=data, train_labels=cell_types)

INFO:helical.models.state.state_finetune:Loading existing config.yaml from: model_dir/config.yaml
INFO:helical.models.state.state_finetune:Loading pre-trained model from: competition/first_run/final.ckpt
INFO:helical.models.state.state_finetune:Backbone model frozen - only fine-tuning head will be trained
INFO:helical.models.state.state_finetune:Processing data for state model fine-tuning.
INFO:helical.models.state.state_finetune:Loaded perturbation mapping with 19792 perturbations
INFO:helical.models.state.state_finetune:Successfully processed the data for state model fine-tuning.
INFO:helical.models.state.state_finetune:Optimizer set up for fine-tuning head only
INFO:helical.models.state.state_finetune:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 12366/12366 [00:13<00:00, 922.64it/s, loss=0]
INFO:helical.models.state.state_finetune:Fine-Tuning Complete. Epochs: 1
