# Example state embeddings

In [None]:
from helical.models.state import stateConfig, stateEmbeddingsModel

# Configuration object created with no arguments (library defaults).
# It stays at notebook scope: the transition-model example further down
# passes the same `state_config` to `stateTransitionModel`.
state_config = stateConfig()

# Embeddings model driven by that configuration.
state_embed = stateEmbeddingsModel(configurer=state_config)

In [None]:
# Preprocess the AnnData file on disk, then embed the processed data.
# NOTE(review): this path is relative to the working directory, while the
# download cell later in the notebook extracts files under
# "competition_support_set/" - confirm where this file is expected to live.
processed_data = state_embed.process_data(
    ann_data_path="competition_val_template.h5ad"
)
embeddings = state_embed.get_embeddings(processed_data)

# Example state transition model inference

In [None]:
from helical.models.state import stateTransitionModel
import scanpy as sc

# Reuses the `state_config` object created in the embeddings example above,
# so that cell must have been run first.
state_transition = stateTransitionModel(configurer=state_config)

# Load the example dataset and pass it through the model's preprocessing.
adata = sc.read_h5ad("example_data.h5ad")
adata = state_transition.process_data(adata)

# Embeddings after perturbation. Note that `adata` is rebound at every step,
# so after this line it holds the result of get_embeddings().
adata = state_transition.get_embeddings(adata)

# Example: fine-tuning a classification head on the State Transition (ST) model

In [1]:
from helical.models.state import stateFineTuningModel, stateConfig
import scanpy as sc

# Dataset used for fine-tuning.
adata = sc.read_h5ad("competition_val_template.h5ad")

# One label per cell, taken from the `cell_type` column of `adata.obs`.
cell_types = list(adata.obs["cell_type"])

# Distinct labels; their count sets the output size of the classification head.
label_set = set(cell_types)
print(label_set)
# NOTE(review): the recorded output of this cell shows a single label ({'H1'}),
# which yields a one-output head and a reported loss of 0. Use a dataset with
# several classes for a meaningful fine-tuning run.

# Fine-tuning model: default configuration, a classification head, and
# freeze_backbone=True (the recorded log reports that only the head is trained).
config = stateConfig()
model = stateFineTuningModel(
    configurer=config,
    fine_tuning_head="classification",
    output_size=len(label_set),
    freeze_backbone=True,
)

print(model)

# Fine-tune on the loaded cells, using their string labels as targets.
model.train(train_input_data=adata, train_labels=cell_types)

INFO:datasets:PyTorch version 2.6.0 available.
INFO:datasets:Polars version 1.33.0 available.


{'H1'}


INFO:helical.models.state.state_finetune:Backbone model frozen - only fine-tuning head will be trained
INFO:helical.models.state.state_finetune:Created label mapping: {'H1': 0}
INFO:helical.models.state.state_finetune:Converted string labels to integers: 1 classes


stateFineTuningModel(
  (fine_tuning_head): ClassificationHead(
    (dropout): Dropout(p=0.02, inplace=False)
    (linear): Linear(in_features=18080, out_features=1, bias=True)
  )
  (model): StateTransitionPerturbationModel(
    (loss_fn): SamplesLoss()
    (pert_encoder): Sequential(
      (0): Linear(in_features=5120, out_features=672, bias=True)
      (1): GELU(approximate='none')
      (2): Dropout(p=0.1, inplace=False)
      (3): Linear(in_features=672, out_features=672, bias=True)
      (4): GELU(approximate='none')
      (5): Dropout(p=0.1, inplace=False)
      (6): Linear(in_features=672, out_features=672, bias=True)
      (7): GELU(approximate='none')
      (8): Dropout(p=0.1, inplace=False)
      (9): Linear(in_features=672, out_features=672, bias=True)
    )
    (basal_encoder): Linear(in_features=18080, out_features=672, bias=True)
    (transformer_backbone): LlamaBidirectionalModel(
      (embed_tokens): Embedding(32000, 672, padding_idx=0)
      (layers): ModuleList(
   

INFO:helical.models.state.state_finetune:Optimizer set up for fine-tuning head only
INFO:helical.models.state.state_finetune:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 6183/6183 [00:08<00:00, 714.68it/s, loss=0]
INFO:helical.models.state.state_finetune:Fine-Tuning Complete. Epochs: 1


# Creating a Virtual Cell Challenge Submission using Helical

In [None]:
'''
Download the dataset

(taken from Colab Notebook by Adduri et al.
https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l#scrollTo=h0aSjKX7Rtyw)
'''

import requests
from tqdm.auto import tqdm  # picks the best bar for the environment
from zipfile import ZipFile
import os

# Download the Replogle-Nadig training dataset.
url = "https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip"
output_path = "competition_support_set.zip"

# Stream the download so we can track progress. The response is used as a
# context manager so the connection is released even if the download fails.
# timeout=(seconds allowed to connect, seconds allowed between received bytes).
with requests.get(url, stream=True, timeout=(10, 60)) as response:
    # Fail fast on HTTP errors instead of saving an error page as the archive.
    response.raise_for_status()
    # Falls back to 0 when the server sends no Content-Length header.
    total = int(response.headers.get("content-length", 0))

    with open(output_path, "wb") as f, tqdm(
        total=total, unit='B', unit_scale=True, desc="Downloading"
    ) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            # Skip empty chunks rather than aborting the download on them.
            if not chunk:
                continue
            f.write(chunk)
            bar.update(len(chunk))

# Unpack the archive into `out_dir` (relative to the working directory),
# one member at a time so that extraction progress is visible.
out_dir = "competition_support_set"
os.makedirs(out_dir, exist_ok=True)
with ZipFile(output_path, 'r') as z:
    for member in tqdm(z.infolist(), desc="Unzipping", unit="file"):
        z.extract(member, out_dir)

In [None]:
# Train the model on the training data from the downloaded support set.
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

train_config = trainingConfig(
    # Files located in the extracted support-set directory.
    toml_config_path="competition_support_set/starter.toml",
    perturbation_features_file="competition_support_set/ESM2_pert_features.pt",
    # Column names and control label handed to the trainer
    # (presumably `adata.obs` columns - confirm against the helical docs).
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    # Run settings.
    model="state",
    num_workers=4,
    max_steps=40000,
    ckpt_every_n_steps=20000,
)

state_train = stateTransitionTrainModel(configurer=train_config)

# Train first, then call predict() on the same trainer object.
state_train.train()
state_train.predict()

Once the model is trained, we can run inference on a new dataset with `stateTransitionModel`, as in the earlier transition-model example.

In [None]:
from helical.models.state import stateConfig, stateTransitionModel
import scanpy as sc

# Inference configuration for the trained checkpoint. This rebinds the
# notebook-level `state_config` created in the first example.
# `output` is the same path the evaluation cell below reads as its "input".
state_config = stateConfig(
    output="competition/prediction.h5ad",
    model_dir="competition/first_run",
    checkpoint="competition/first_run/checkpoints/final.ckpt",
    pert_col="target_gene",
    # Every remaining option except the seed is passed explicitly as None.
    embed_key=None,
    celltype_col=None,
    celltypes=None,
    batch_col=None,
    control_pert=None,
    seed=42,
    max_set_len=None,
    tsv=None,
)

# Validation template from the extracted support set.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Preprocess with the transition model, then compute the embeddings.
state_transition = stateTransitionModel(configurer=state_config)
adata = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(adata)

In [None]:
import scanpy as sc

# Inspect the validation template. It is read from the directory the download
# cell extracts into, which is also the path the inference cell uses.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")


# Print the AnnData object structure and keys
print("AnnData object:")
print(adata)
print("\nobs columns:", adata.obs.columns.tolist())
print("var columns:", adata.var.columns.tolist())
print("uns keys:", list(adata.uns.keys()))
print("obsm keys:", list(adata.obsm.keys()))
print("varm keys:", list(adata.varm.keys()))


Now we can evaluate the model's predictions and create the submission file.

In [None]:
# Evaluate the model. Per the original note, the underlying function uses the
# cell-eval package (https://github.com/ArcInstitute/cell-eval).
from helical.models.state import vcc_eval

# Default settings for the competition dataset.
EXPECTED_GENE_DIM = 18080
MAX_CELL_DIM = 100000
DEFAULT_PERT_COL = "target_gene"
# NOTE(review): DEFAULT_CTRL and DEFAULT_COUNTS_COL are not referenced by
# `configs` below - confirm whether they are still needed.
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

configs = {
    "input": "competition/prediction.h5ad",  # path to the prediction file
    "genes": "competition_support_set/gene_names.csv",  # path to the gene names file
    "output": None,  # output file path; None -> created with default naming
    "pert_col": DEFAULT_PERT_COL,
    "celltype_col": None,
    "ntc_name": DEFAULT_NTC_NAME,
    "output_pert_col": DEFAULT_PERT_COL,
    "output_celltype_col": DEFAULT_CELLTYPE_COL,
    "encoding": 32,
    "allow_discrete": False,
    "expected_gene_dim": EXPECTED_GENE_DIM,
    "max_cell_dim": MAX_CELL_DIM,
}

# Creates a submission file in the output directory that can be uploaded to
# the challenge leaderboard.
vcc_eval(configs)