# Test dependencies

In [None]:
# Smoke test: this import raises immediately if the helical STATE subpackage
# (or anything it imports at load time) is missing from the environment.
import helical.models.state
print("All imports successful")

# Generating Example Data

We use data from the Virtual Cell Challenge for model training, inference, and finetuning.

In [None]:
'''
Download the dataset

(taken from Colab Notebook by Adduri et al.
https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l#scrollTo=h0aSjKX7Rtyw)
'''

import os
from zipfile import ZipFile

import requests
from tqdm.auto import tqdm  # picks the best bar for the environment

# Download the Replogle-Nadig training dataset.
url = "https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip"
output_path = "competition_support_set.zip"

# Stream the download so we can track progress. Using the response as a
# context manager releases the connection even if the transfer fails part-way,
# and the timeout keeps the cell from hanging forever on a stalled connection.
with requests.get(url, stream=True, timeout=60) as response:
    # Fail here on a 4xx/5xx reply instead of saving the error body as the
    # archive and hitting a confusing BadZipFile during extraction.
    response.raise_for_status()
    # content-length may be absent; tqdm then shows a bar without a total.
    total = int(response.headers.get("content-length", 0))

    with open(output_path, "wb") as f, tqdm(
        total=total, unit='B', unit_scale=True, desc="Downloading"
    ) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            if not chunk:
                # An empty chunk does not mean end-of-stream; skip it rather
                # than break, which would silently truncate the archive.
                continue
            f.write(chunk)
            bar.update(len(chunk))

# Extract member by member so the progress bar can advance per file.
out_dir = "competition_support_set"
os.makedirs(out_dir, exist_ok=True)
with ZipFile(output_path, 'r') as z:
    for member in tqdm(z.infolist(), desc="Unzipping", unit="file"):
        z.extract(member, out_dir)

# State embeddings

We can generate embeddings using the STATE Embedding model in the helical package.

In [None]:
# Instantiate the STATE embedding model; stateConfig() is called without
# arguments, so every option keeps whatever default the library defines.
from helical.models.state import stateConfig, stateEmbeddingsModel

state_config = stateConfig()
state_embed = stateEmbeddingsModel(configurer=state_config)

In [None]:
import scanpy as sc

# Read the validation template and keep only its first two observations
# (presumably to keep the example fast); .copy() turns the slice into an
# independent object.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")[:2].copy()

# Run the model's own preprocessing, then compute the embeddings from it.
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

print(embeddings.shape)

# Training the Model

Unzipping the example data creates the `competition_support_set` directory. Before training, edit the `competition_support_set/starter.toml` file so that it points to the correct paths on your machine. You can also change the train/test split there, but feel free to leave it at the default.

In [None]:
# Train the STATE transition model on the training data, then run prediction.
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

train_config = trainingConfig(
    # Run location and naming (the inference example reads "competition/first_run").
    output_dir="competition",
    name="first_run",
    checkpoint_name="final.ckpt",
    # Input files shipped with the downloaded support set.
    toml_config_path="competition_support_set/starter.toml",
    perturbation_features_file="competition_support_set/ESM2_pert_features.pt",
    # Column names and the control label used in the training data.
    pert_col="target_gene",
    control_pert="non-targeting",
    cell_type_key="cell_type",
    batch_col="batch_var",
    # Training schedule and data loading.
    max_steps=40000,
    max_epochs=1,
    ckpt_every_n_steps=20000,
    num_workers=4,
)

state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train()
state_train.predict()

Once the model is trained we can perform inference on a new dataset using:

In [None]:
# Imports are repeated here so the cell also works in a fresh kernel: training
# is long-running, and inference is often run in a later session where `sc`
# and `stateConfig` from earlier cells are no longer defined.
import scanpy as sc

from helical.models.state import stateConfig, stateTransitionModel

# Point the model at the run produced by the training step
# (output_dir="competition", name="first_run").
state_config = stateConfig(
    output = "competition/prediction.h5ad",
    model_dir = "competition/first_run",
    model_config = "configs/config.yaml",
    pert_col = "target_gene",
)

adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

state_transition = stateTransitionModel(configurer=state_config)
# NOTE(review): `adata` is rebound to the processed object; reload the file if
# the raw template is needed again.
adata = state_transition.process_data(adata)
# Presumably this also writes the predictions to the `output` path configured
# above, which the submission step reads - confirm against the helical docs.
embeds = state_transition.get_embeddings(adata)

# Creating a Virtual Cell Challenge Submission

To create a submission for the Virtual Cell Challenge we pass our previous prediction file into `helical.models.state.vcc_eval` which uses the `cell-eval` package. This will generate a `.vcc` file that can be uploaded to the public leaderboard.

In [None]:
# Package the prediction file into a Virtual Cell Challenge submission.
# vcc_eval relies on the cell-eval package
# (https://github.com/ArcInstitute/cell-eval).
from helical.models.state import vcc_eval

# Default settings for the competition dataset: size limits first ...
EXPECTED_GENE_DIM = 18080
MAX_CELL_DIM = 100000
# ... then column names and labels. DEFAULT_CTRL and DEFAULT_COUNTS_COL are
# defined for completeness but are not referenced by `configs` below.
DEFAULT_PERT_COL = "target_gene"
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

configs = {
    "input": "competition/prediction.h5ad",  # path to the prediction file
    "genes": "competition_support_set/gene_names.csv",  # path to the gene names file
    "output": None,  # output path; None means one is created with default naming
    "pert_col": DEFAULT_PERT_COL,
    "celltype_col": None,
    "ntc_name": DEFAULT_NTC_NAME,
    "output_pert_col": DEFAULT_PERT_COL,
    "output_celltype_col": DEFAULT_CELLTYPE_COL,
    "encoding": 32,
    "allow_discrete": False,
    "expected_gene_dim": EXPECTED_GENE_DIM,
    "max_cell_dim": MAX_CELL_DIM,
}

# Writes the submission file to the output directory; that file can then be
# uploaded to the challenge leaderboard.
vcc_eval(configs)

# Finetuning Example

We can further finetune the model. Finetuning can start from a model directory containing a `.ckpt` file. If the model directory does not exist, it will be created along with a new model instance. Finetuning produces model weights and head weights that are saved as `.pt` files in the directory; if these files are already present, training resumes from them.

In [1]:
import scanpy as sc

from helical.models.state import stateConfig, stateFineTuningModel

# Load the desired dataset
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Get the desired label class
cell_types = list(adata.obs.cell_type)
label_set = set(cell_types)

# Create a dictionary mapping the classes to unique integers for training.
# The labels are sorted first because set iteration order for strings changes
# between kernel sessions (hash randomization). Head weights are saved and
# reloaded on later runs, so the class ids must stay the same across sessions
# to keep matching the saved head's outputs. key=str keeps the sort from
# failing should the column mix strings with other values.
class_id_dict = {
    label: class_id
    for class_id, label in enumerate(sorted(label_set, key=str))
}

# Create the fine-tuning model (no need to specify var_dims location)
config = stateConfig(
    batch_size=8,
    model_dir="competition/first_run",
    model_config="configs/config.yaml",
    freeze_backbone=True
)

model = stateFineTuningModel(
    configurer=config,
    fine_tuning_head="classification",
    output_size=len(label_set),  # one output per distinct class
)

# Process the data for training
data = model.process_data(adata)

# Replace each cell-type name with its integer class id (one label per cell,
# in the same order as the rows of `adata`).
cell_types = [class_id_dict[label] for label in cell_types]

# Fine-tune
model.train(train_input_data=data, train_labels=cell_types)

INFO:datasets:PyTorch version 2.6.0 available.
INFO:datasets:Polars version 1.33.1 available.
INFO:helical.models.state.state_finetune:Loading existing config.yaml from: configs/config.yaml
INFO:helical.models.state.state_finetune:Loading pre-trained model from: competition/first_run/final.ckpt
INFO:helical.models.state.state_finetune:Loaded head weights from competition/first_run/head_weights.pt
INFO:helical.models.state.state_finetune:Backbone model frozen - only fine-tuning head will be trained
INFO:helical.models.state.state_finetune:Processing data for state model fine-tuning.
INFO:helical.models.state.state_finetune:Loaded perturbation mapping with 19792 perturbations
INFO:helical.models.state.state_finetune:Successfully processed the data for state model fine-tuning.
INFO:helical.models.state.state_finetune:Optimizer set up for fine-tuning head only
INFO:helical.models.state.state_finetune:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 12366/12366 [00:14<00:00, 84