# Generating Example Data

We use data from the Virtual Cell Challenge for model training, inference, and finetuning.

In [None]:
# Download and unpack the Virtual Cell Challenge support set.
#
# Adapted from the Colab notebook by Adduri et al.:
# https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l#scrollTo=h0aSjKX7Rtyw

import os
from zipfile import ZipFile

import requests
from tqdm.auto import tqdm  # picks the best progress bar for the environment

# Download the Replogle-Nadig training dataset.
url = "https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip"
output_path = "competition_support_set.zip"
out_dir = "competition_support_set"

# Stream the download so we can track progress. The `with` block releases the
# connection when done, and raise_for_status() prevents an HTTP error page from
# being written to disk as if it were the zip archive.
with requests.get(url, stream=True, timeout=60) as response:
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))

    with open(output_path, "wb") as f, tqdm(
        total=total, unit="B", unit_scale=True, desc="Downloading"
    ) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            # Skip empty chunks rather than stopping: breaking here would risk
            # silently truncating the archive.
            if not chunk:
                continue
            f.write(chunk)
            bar.update(len(chunk))

# Extract member by member so the progress bar can report per-file progress.
os.makedirs(out_dir, exist_ok=True)
with ZipFile(output_path, "r") as z:
    for member in tqdm(z.infolist(), desc="Unzipping", unit="file"):
        z.extract(member, out_dir)

# State embeddings

In [None]:
# Build a State embedding model with a batch size of 16.
from helical.models.state import stateConfig, stateEmbed

EMBED_BATCH_SIZE = 16
state_config = stateConfig(batch_size=EMBED_BATCH_SIZE)
state_embed = stateEmbed(configurer=state_config)

In [None]:
# Download the human yolk-sac dataset and compute State embeddings for a few cells.
from pathlib import Path

import scanpy as sc
from helical.models.state import stateConfig, stateEmbed
from helical.utils.downloader import Downloader

YOLKSAC_PATH = Path("yolksac_human.h5ad")
YOLKSAC_URL = (
    "https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/"
    "17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true"
)
N_DEMO_CELLS = 10  # small subset so the demo runs quickly

downloader = Downloader()
downloader.download_via_link(YOLKSAC_PATH, YOLKSAC_URL)

# Load the data and keep only the first few cells. The variable is deliberately not
# named `anndata`, which would shadow the `anndata` package if it is imported by name.
adata_demo = sc.read_h5ad(YOLKSAC_PATH)
adata_demo = adata_demo[:N_DEMO_CELLS].copy()

# Initialize the model here so this cell runs on its own.
state_config = stateConfig(batch_size=16)
state_embed = stateEmbed(configurer=state_config)

# Process and get embeddings; there should be one embedding row per cell.
processed_data = state_embed.process_data(adata=adata_demo)
embeddings = state_embed.get_embeddings(processed_data)
assert embeddings.shape[0] == adata_demo.n_obs, (
    f"expected {adata_demo.n_obs} embedding rows, got {embeddings.shape[0]}"
)

In [None]:
# Inspect the structure of the full yolk-sac AnnData object.
import scanpy as sc

adata = sc.read_h5ad("yolksac_human.h5ad")

print("=== AnnData Object Info ===")
print(adata)

# obs / var are DataFrames: list their columns.
print("\n=== Observation keys (cell metadata) ===")
print("adata.obs columns:")
print(adata.obs.columns.tolist())

print("\n=== Variable keys (gene metadata) ===")
print("adata.var columns:")
print(adata.var.columns.tolist())

# The remaining slots are dict-like: list their keys in a single loop instead of
# repeating the same three prints per slot.
mapping_slots = {
    "uns": "Unstructured annotations",
    "obsm": "Observation matrices (embeddings, etc.)",
    "varm": "Variable matrices",
    "varp": "Variable pairs",
    "obsp": "Observation pairs",
}
for slot, title in mapping_slots.items():
    print(f"\n=== {title} ===")
    print(f"adata.{slot} keys:")
    print(list(getattr(adata, slot).keys()))

# Last expression, so Jupyter actually displays the expression matrix.
# (The stray `import pic` that used to follow raised ModuleNotFoundError; removed.)
adata.X

In [None]:
# Inspect the dimension metadata stored next to the State transition checkpoint.
import pickle
from pathlib import Path

# helical's per-user model cache; built from the home directory instead of
# hardcoding one user's absolute path.
STATE_TRANSITION_DIR = Path.home() / ".cache" / "helical" / "models" / "state" / "state_transition"

# NOTE: pickle.load can execute arbitrary code -- only load files you trust
# (here, files in helical's own model cache).
with open(STATE_TRANSITION_DIR / "var_dims.pkl", "rb") as f:
    var_dims = pickle.load(f)

pert_dim = var_dims.get("pert_dim")    # number of perturbation types (None if absent)
batch_dim = var_dims.get("batch_dim")  # number of batch types (None if absent)

print(var_dims)

In [None]:
# Embed a small subset of the yolk-sac data and attach the embeddings to `adata`.
import numpy as np
import scanpy as sc
from helical.models.state import stateConfig, stateEmbed, stateTransitionModel

N_SUBSET_CELLS = 10

# 1. Load the yolk-sac data.
adata = sc.read_h5ad("yolksac_human.h5ad")
print(f"Original data shape: {adata.shape}")

# 2. Subset before embedding so embeddings and cells line up one-to-one.
adata = adata[:N_SUBSET_CELLS].copy()
print(f"Subset data shape: {adata.shape}")

# 3. Generate State embeddings for the subset.
print("Generating embeddings...")
embed_config = stateConfig(batch_size=16)
state_embed = stateEmbed(configurer=embed_config)
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)
print(f"Embeddings shape: {embeddings.shape}")

# 4. Attach the embeddings, failing loudly if there is not one row per cell.
assert embeddings.shape[0] == adata.n_obs, (
    f"expected {adata.n_obs} embedding rows, got {embeddings.shape[0]}"
)
adata.obsm['X_state_emb'] = embeddings
print(f"Successfully added embeddings: {adata.obsm['X_state_emb'].shape}")

In [None]:
# Shape of the State embeddings stored on `adata`: (n_cells, embedding_dim).
emb_shape = adata.obsm['X_state_emb'].shape
emb_shape

In [None]:
# Embed a 2000-cell subset of the yolk-sac data and attach the embeddings to `adata`.
import numpy as np
import scanpy as sc
from helical.models.state import stateConfig, stateEmbed, stateTransitionModel

N_SUBSET_CELLS = 2000

# 1. Load the yolk-sac data.
adata = sc.read_h5ad("yolksac_human.h5ad")
print(f"Loaded data: {adata.shape}")

# 2. Subset FIRST, then generate embeddings, so rows line up one-to-one.
adata = adata[:N_SUBSET_CELLS].copy()
print(f"Subset data: {adata.shape}")

# 3. Generate embeddings for the subset.
print("Generating embeddings...")
embed_config = stateConfig(batch_size=16)
state_embed = stateEmbed(configurer=embed_config)
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

print(f"Embeddings shape: {embeddings.shape}")
print(f"Data shape: {adata.shape}")

# 4. Attach the embeddings, failing loudly if there is not one row per cell.
assert embeddings.shape[0] == adata.n_obs, (
    f"expected {adata.n_obs} embedding rows, got {embeddings.shape[0]}"
)
adata.obsm['X_state_emb'] = embeddings
print(f"Successfully added embeddings: {adata.obsm['X_state_emb'].shape}")

In [None]:
# Attach mock perturbation, cell-type and batch labels to `adata` for a smoke test.
# The perturbation and batch labels are random placeholders, not real metadata; a
# seeded generator keeps them identical across Restart & Run All.
import numpy as np

SEED = 0
rng = np.random.default_rng(SEED)

n_cells = adata.n_obs
print(adata.obsm['X_state_emb'].shape)

perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]

# Assign perturbations uniformly at random (pass `p=` to rng.choice to change that).
adata.obs['target_gene'] = rng.choice(perturbations, size=n_cells)
adata.obs['cell_type'] = adata.obs['LVL1']  # use the LVL1 annotation as the cell type
adata.obs['batch_var'] = rng.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)

print("Perturbation distribution:")
adata.obs['target_gene'].value_counts()

In [None]:
# Copy `adata` and keep only the first 2000 columns of its State embeddings.
# NOTE(review): this slices off dimensions rather than projecting them; confirm that
# the transition model really expects the first 2000 embedding dims.
emb_key = 'X_state_emb'
print(adata.obsm[emb_key].shape)

adata_new = adata.copy()
full_emb = adata_new.obsm[emb_key]
adata_new.obsm[emb_key] = full_emb[:, :2000]
print(adata_new.obsm[emb_key].shape)

In [None]:
# Run the State transition model on `adata_new` to predict perturbation responses.
import numpy as np
import scanpy as sc
from helical.models.state import stateConfig, stateTransitionModel

# state transition model trained on 2000 genes
SEED = 0
rng = np.random.default_rng(SEED)
n_cells = adata_new.n_obs

# Mock labels (random placeholders for a smoke test), re-assigned here so this cell
# does not depend on how `adata_new.obs` was populated earlier.
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]
adata_new.obs['target_gene'] = rng.choice(perturbations, size=n_cells)
adata_new.obs['cell_type'] = adata_new.obs['LVL1']  # use the LVL1 annotation as the cell type
adata_new.obs['batch_var'] = rng.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)

# Configure the transition model; the DMSO entry is the control perturbation.
config = stateConfig(
    embed_key='X_state_emb',
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    output="yolksac_predictions.h5ad",
)

# Run inference.
print("Running perturbation predictions...")
state_transition = stateTransitionModel(configurer=config)
adata_processed = state_transition.process_data(adata_new)
embeddings = state_transition.get_embeddings(adata_processed)

In [None]:
# Fine-tune State as a cell-type classifier on a small yolk-sac subset.
import numpy as np
import pandas as pd
import scanpy as sc
from helical.models.state import stateConfig, stateFineTuningModel

SEED = 0
N_CELLS = 100
N_GENES = 2000
rng = np.random.default_rng(SEED)

# Load the yolk-sac data and keep a small cells x genes block for a quick run.
adata = sc.read_h5ad("yolksac_human.h5ad")
adata = adata[:N_CELLS, :N_GENES].copy()
n_cells = adata.n_obs

# Mock perturbation and batch labels (random placeholders for a smoke test).
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]
adata.obs['target_gene'] = rng.choice(perturbations, size=n_cells)
adata.obs['cell_type'] = adata.obs['LVL1']  # use the LVL1 annotation as the cell type
adata.obs['batch_var'] = rng.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)

# Classification targets come from the LVL1 annotation. Sort the class names so the
# class -> integer mapping is the same on every run: iterating a set of strings
# follows hash randomization and can change between kernel restarts.
cell_types = list(adata.obs['LVL1'])
class_names = sorted(set(cell_types), key=str)

print(f"Found {len(class_names)} unique cell types:")
print(class_names)

# Create the fine-tuning model.
config = stateConfig(
    embed_key=None,  # use gene expression instead of embeddings
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    batch_size=8,
)

model = stateFineTuningModel(
    configurer=config,
    fine_tuning_head="classification",
    output_size=len(class_names),
    freeze_backbone=False,
)

# Process the data for training.
data = model.process_data(adata)

# Map each class to a stable integer id and encode the per-cell labels.
class_id_dict = {name: idx for idx, name in enumerate(class_names)}
cell_type_labels = [class_id_dict[ct] for ct in cell_types]
print(f"Class mapping: {class_id_dict}")

# Fine-tune.
model.train(train_input_data=data, train_labels=cell_type_labels)

In [None]:
# Inspect the metadata pickles stored next to the State transition checkpoint.
import pickle
from itertools import islice
from pathlib import Path

# helical's per-user model cache; built from the home directory instead of
# hardcoding one user's absolute path.
STATE_TRANSITION_DIR = Path.home() / ".cache" / "helical" / "models" / "state" / "state_transition"
MAX_INLINE_ITEMS = 10  # longer lists/tuples are summarized instead of printed in full
PREVIEW_ITEMS = 5

# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(STATE_TRANSITION_DIR / "var_dims.pkl", "rb") as f:
    var_dims = pickle.load(f)

print("=== var_dims.pkl contents ===")
for key, value in var_dims.items():
    if isinstance(value, (list, tuple)) and len(value) > MAX_INLINE_ITEMS:
        print(f"{key}: {type(value)} with {len(value)} items")
        print(f"  First {PREVIEW_ITEMS}: {value[:PREVIEW_ITEMS]}")
    else:
        print(f"{key}: {value}")

with open(STATE_TRANSITION_DIR / "batch_onehot_map.pkl", "rb") as f:
    batch_map = pickle.load(f)

print("\n=== batch_onehot_map.pkl contents ===")
print(f"Number of batch mappings: {len(batch_map)}")
for key, value in islice(batch_map.items(), PREVIEW_ITEMS):  # show only the first few
    print(f"  '{key}': {value}")

In [None]:
# NOTE(review): this whole cell is commented-out debugging code. It compares
# `state_embed.model` with `state_embed_torch.model`, but `state_embed_torch` is never
# defined in this notebook, so un-commenting it as-is raises NameError. Either restore
# the cell that builds `state_embed_torch` or delete this cell.
# NOTE(review): the autocast block below uses device_type="cuda" and so needs a GPU.
# import torch

# def compare_models(model1, model2):
#     params1 = list(model1.parameters())
#     params2 = list(model2.parameters())
    
#     print(f"Model 1 has {len(params1)} parameters")
#     print(f"Model 2 has {len(params2)} parameters")
    
#     for i, (p1, p2) in enumerate(zip(params1, params2)):
#         if not torch.allclose(p1, p2, atol=1e-6):
#             print(f"Parameter {i} differs! Max diff: {torch.max(torch.abs(p1 - p2)).item()}")
#             return False
#     print("All parameters are identical!")
#     return True

# compare_models(state_embed.model, state_embed_torch.model)
# print(state_embed.model.training, state_embed_torch.model.training)


# adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")
# adata = adata[:2].copy()

# processed_data1 = state_embed.process_data(adata=adata.copy())
# batch = next(iter(processed_data1))
# print(batch)

# with torch.no_grad():
#     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#         _, _, _, emb1, _ = state_embed.model._compute_embedding_for_batch(batch)
#         _, _, _, emb2, _ = state_embed_torch.model._compute_embedding_for_batch(batch)
    
#     print(f"Model 1 embedding: {emb1[0, :5]}")
#     print(f"Model 2 embedding: {emb2[0, :5]}")
    
#     # Compare
#     # NOTE(review): `diff` is already a scalar (.sum()), so torch.max(diff) is just
#     # the summed absolute difference, not a maximum.
#     diff = torch.abs(emb1 - emb2).sum()
#     print(f"Difference: {torch.max(diff).item()}")

In [None]:
# NOTE(review): commented-out helper that saves the embedding model's state_dict to
# "embed_model_epoch16_weights.pt". `checkpoint` and `state_dict` are bound to the same
# object and only `state_dict` is saved. Un-comment it when the export is needed;
# otherwise consider removing the cell.
# Extract weights from checkpoint
# import torch
# checkpoint = state_embed.checkpoint['state_dict']
# state_dict = state_embed.checkpoint['state_dict']
# torch.save(state_dict, "embed_model_epoch16_weights.pt")

# Training the Model

Downloading the example data creates a `competition_support_set/` directory. Before training, edit `competition_support_set/starter.toml` so that it points to the correct paths on your machine. You can also change the train/test split in this file, but the default split is fine for this example.

In [None]:
# Train the State transition model on the competition support set, then run its
# prediction step.
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

train_kwargs = dict(
    output_dir="competition",
    name="first_run",
    toml_config_path="competition_support_set/starter.toml",
    checkpoint_name="final.ckpt",
    max_steps=40000,
    max_epochs=1,
    ckpt_every_n_steps=20000,
    num_workers=4,
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    perturbation_features_file="competition_support_set/ESM2_pert_features.pt",
)
train_config = trainingConfig(**train_kwargs)

state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train()
state_train.predict()

Once the model is trained, we can run inference on a new dataset:

In [None]:
# Run the trained transition model on the competition validation template.
# Imports are local so this cell does not depend on which earlier cell ran last.
import scanpy as sc
from helical.models.state import stateConfig, stateTransitionModel

state_config = stateConfig(
    output="competition/prediction.h5ad",
    model_dir="competition/first_run",
    # NOTE(review): this is the literal string "model_dir/config.yaml", not a path built
    # from `model_dir` above; confirm whether stateConfig resolves it or whether it
    # should be "competition/first_run/config.yaml".
    model_config="model_dir/config.yaml",
    pert_col="target_gene",
)

adata_val = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

state_transition = stateTransitionModel(configurer=state_config)
# Keep raw and processed data under separate names instead of overwriting one with the other.
adata_val_processed = state_transition.process_data(adata_val)
embeds = state_transition.get_embeddings(adata_val_processed)

# Creating a Virtual Cell Challenge Submission

To create a submission for the Virtual Cell Challenge, we pass the prediction file from the previous step to `helical.models.state.vcc_eval`, which uses the `cell-eval` package. This generates a `.vcc` file that can be uploaded to the public leaderboard.

In [None]:
# Package the predictions into a Virtual Cell Challenge submission. vcc_eval is built
# on the cell-eval package (https://github.com/ArcInstitute/cell-eval).
from helical.models.state import vcc_eval

# Defaults for the competition dataset.
EXPECTED_GENE_DIM = 18080
MAX_CELL_DIM = 100000
DEFAULT_PERT_COL = "target_gene"
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

configs = dict(
    input="competition/prediction.h5ad",            # prediction file to package
    genes="competition_support_set/gene_names.csv",  # gene-name list
    output=None,                                     # None -> default output naming
    pert_col=DEFAULT_PERT_COL,
    celltype_col=None,
    ntc_name=DEFAULT_NTC_NAME,
    output_pert_col=DEFAULT_PERT_COL,
    output_celltype_col=DEFAULT_CELLTYPE_COL,
    encoding=32,
    allow_discrete=False,
    expected_gene_dim=EXPECTED_GENE_DIM,
    max_cell_dim=MAX_CELL_DIM,
)

# Writes the submission file that can be uploaded to the challenge leaderboard.
vcc_eval(configs)

# Finetuning Example

We can further fine-tune the model, starting from a model directory that contains a `.ckpt` file. If the model directory does not exist, it will be created together with a new model instance. Fine-tuning produces model weights and head weights that are saved as `.pt` files in that directory; if these files are present, training resumes from them.

In [None]:
# NOTE(review): this fine-tuning example is fully commented out. If re-enabled:
#   - it uses `sc` and `stateConfig`, which are not imported in this cell;
#   - `set(cell_types)` iteration order changes between kernel restarts (string hash
#     randomization), so build the class ids from `sorted(label_set)` to keep them stable;
#   - `model_config="model_dir/config.yaml"` is a literal string, not a path built from
#     `model_dir`; confirm it is what stateConfig expects.
# from helical.models.state import stateFineTuningModel

# # Load the desired dataset
# adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# # Get the desired label class
# cell_types = list(adata.obs.cell_type)
# label_set = set(cell_types)

# # Create the fine-tuning model (no need to specify var_dims location)
# config = stateConfig(
#     batch_size=8,
#     model_dir="competition/first_run",
#     model_config="model_dir/config.yaml",
#     freeze_backbone=True
# )

# model = stateFineTuningModel(
#     configurer=config, 
#     fine_tuning_head="classification", 
#     output_size=len(label_set),
# )

# # Process the data for training
# data = model.process_data(adata)

# # Create a dictionary mapping the classes to unique integers for training
# class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

# for i in range(len(cell_types)):
#     cell_types[i] = class_id_dict[cell_types[i]]

# # Fine-tune
# model.train(train_input_data=data, train_labels=cell_types)

In [None]:
import scanpy as sc
import anndata as ad
import numpy as np

def truncate_adata_file_complete(input_path, output_path, max_cells=1000, max_genes=None):
    """Write a truncated copy of an .h5ad file, keeping the first cells and/or genes.

    Slicing an AnnData object also slices every per-cell field (obs, obsm, obsp,
    layers) and every per-gene field (var, varm, varp). A single slice therefore keeps
    all fields aligned, and no manual obsm/varm trimming is needed.

    Parameters
    ----------
    input_path : str or path-like
        File to read with ``scanpy.read_h5ad``.
    output_path : str or path-like
        Destination, written with ``AnnData.write_h5ad``.
    max_cells : int or None, default 1000
        Keep only the first ``max_cells`` cells. ``None`` or ``0`` keeps all cells.
    max_genes : int or None, default None
        Keep only the first ``max_genes`` genes. ``None`` or ``0`` keeps all genes.

    Returns
    -------
    anndata.AnnData
        The (possibly truncated) object that was written to ``output_path``.
    """
    print(f"Loading {input_path}...")
    adata = sc.read_h5ad(input_path)

    print(f"Original shape: {adata.shape}")
    print(f"Original obsm keys: {list(adata.obsm.keys())}")

    # A falsy limit (None or 0) means "keep everything"; None as a slice stop keeps
    # the whole axis.
    cell_stop = max_cells if max_cells and adata.n_obs > max_cells else None
    gene_stop = max_genes if max_genes and adata.n_vars > max_genes else None

    if cell_stop is not None:
        print(f"Truncating to {max_cells} cells...")
    if gene_stop is not None:
        print(f"Truncating to {max_genes} genes...")

    if cell_stop is not None or gene_stop is not None:
        # One slice + one copy; .copy() turns the view into an independent object.
        adata = adata[:cell_stop, :gene_stop].copy()

    print(f"New shape: {adata.shape}")
    print(f"New obsm keys: {list(adata.obsm.keys())}")

    # Save truncated file
    print(f"Saving to {output_path}...")
    adata.write_h5ad(output_path)

    return adata

# Usage: write 100-cell copies of the sample train/val/test files for quick experiments.
SMALL_N_CELLS = 100
train_small = truncate_adata_file_complete('sample_vcc_data/train.h5', 'train_small.h5', max_cells=SMALL_N_CELLS)
val_small = truncate_adata_file_complete('sample_vcc_data/val.h5', 'val_small.h5', max_cells=SMALL_N_CELLS)
test_small = truncate_adata_file_complete('sample_vcc_data/test.h5ad', 'test_small.h5ad', max_cells=SMALL_N_CELLS)

In [2]:
# Load the sample model configuration with OmegaConf.
import os
from pathlib import Path

from omegaconf import OmegaConf

# Directory holding the sample VCC data and config. Override it with the
# SAMPLE_VCC_DATA_DIR environment variable instead of hardcoding a user's absolute
# path; the default is relative to the current working directory.
SAMPLE_VCC_DATA_DIR = Path(os.environ.get("SAMPLE_VCC_DATA_DIR", "sample_vcc_data"))
model_configs = OmegaConf.load(SAMPLE_VCC_DATA_DIR / "config.yaml")

In [4]:
# Display the `training` section of the loaded config (attribute access on a DictConfig).
model_configs.training

{'wandb_track': True, 'weight_decay': 0.0005, 'batch_size': 16, 'lr': 0.0001, 'max_steps': 40000, 'train_seed': 42, 'val_freq': 2000, 'ckpt_every_n_steps': 20000, 'gradient_clip_val': 10, 'loss_fn': 'mse', 'devices': 1, 'strategy': 'auto', 'use_mfu': True, 'mfu_kwargs': {'available_flops': 60000000000000.0, 'use_backward': True, 'logging_interval': 10, 'window_size': 2}}