This notebook shows how to use the `STATE` model with `helical`.

# Download Example Data

We start by using the `helical` downloader to obtain an example Hugging Face dataset.

In [None]:
from helical.utils.downloader import Downloader
from pathlib import Path

downloader = Downloader()
downloader.download_via_link(
    Path("yolksac_human.h5ad"),
    "https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true",
)

# STATE Embeddings

Using the STATE model, we can obtain single-cell transcriptome embeddings. We first slice the dataset for demonstration purposes.

In [None]:
# load the data 
import scanpy as sc

adata = sc.read_h5ad("yolksac_human.h5ad")
# for demonstration we subset to 10 cells and 2000 genes
adata = adata[:10, :2000].copy()

print(adata.shape)
n_cells = adata.n_obs
print(n_cells)
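Note that `adata[:10, :2000]` keeps the first 10 cells and the first 2000 genes in file order. If you would rather demo on a random subset of cells, a minimal NumPy sketch follows; `random_cell_indices` is a hypothetical helper (not part of `helical`), and you would call it on the full dataset before slicing, e.g. `adata[idx, :2000].copy()`.

```python
import numpy as np

def random_cell_indices(n_obs: int, n_sample: int, seed: int = 0) -> np.ndarray:
    """Return sorted, unique row indices for a reproducible random subset of cells."""
    rng = np.random.default_rng(seed)
    # sample without replacement so no cell is picked twice
    idx = rng.choice(n_obs, size=n_sample, replace=False)
    # sorting keeps the original cell order, which makes the subset easier to inspect
    return np.sort(idx)

# 1000 stands in for adata.n_obs of the full (unsliced) dataset
idx = random_cell_indices(n_obs=1000, n_sample=10)
print(idx)
# with the real data: adata = adata[idx, :2000].copy()
```

Fixing the seed makes the demo subset identical across runs, which keeps downstream outputs comparable.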

Initialise the model. On the first run this downloads the required files to `.cache/helical/state/`, so it takes slightly longer.

In [None]:
from helical.models.state import stateConfig
from helical.models.state import stateEmbed

state_config = stateConfig(batch_size=16)
state_embed = stateEmbed(configurer=state_config)

We process the data by calling `state_embed.process_data` and pass the result to `state_embed.get_embeddings` to get the final embeddings.

In [None]:
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

# note that the STATE model returns a numpy array of shape (n_cells, 1024)
print(embeddings.shape)
print(type(embeddings))

# store the embeddings in adata.obsm['state_emb']
adata.obsm['state_emb'] = embeddings
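Once stored in `adata.obsm['state_emb']`, the embeddings are an ordinary `(n_cells, 1024)` array, so standard vector operations apply. As an illustration (not part of `helical`), the sketch below computes pairwise cosine similarity between cells; a random array stands in for `adata.obsm['state_emb']`.

```python
import numpy as np

def cosine_similarity_matrix(emb: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows (cells) of an embedding matrix."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    # guard against division by zero for all-zero rows
    unit = emb / np.clip(norms, 1e-12, None)
    return unit @ unit.T

# stand-in for adata.obsm['state_emb'] (10 cells x 1024 dims)
rng = np.random.default_rng(0)
emb = rng.normal(size=(10, 1024))
sim = cosine_similarity_matrix(emb)
print(sim.shape)  # (10, 10)
```

With the real data you would pass `adata.obsm['state_emb']` instead of `emb`.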

# STATE Perturbations

To use the perturbation model, you can either pass in embeddings by specifying the `embed_key` argument in `stateConfig`, or keep the default value `None`, in which case the expression values in `adata.X` are used.

When using precomputed embeddings, `adata.obsm[<embed_key>]` must exist; otherwise an error is raised.
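The input-selection rule for `embed_key` can be sketched as follows. This illustrates the documented behaviour only and is not `helical`'s implementation: `select_input_features` is a hypothetical helper, and `SimpleData` is a minimal stand-in for an `AnnData` object with `X` and `obsm` attributes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class SimpleData:
    """Minimal stand-in for AnnData: an expression matrix plus an obsm mapping."""
    X: Any
    obsm: Dict[str, Any] = field(default_factory=dict)

def select_input_features(adata: SimpleData, embed_key: Optional[str] = None) -> Any:
    """Return adata.X when embed_key is None, otherwise adata.obsm[embed_key]."""
    if embed_key is None:
        return adata.X
    if embed_key not in adata.obsm:
        # mirrors the documented behaviour: a missing key is an error
        raise KeyError(f"embed_key '{embed_key}' not found in adata.obsm")
    return adata.obsm[embed_key]

data = SimpleData(X=[[1, 0], [0, 2]], obsm={"state_emb": [[0.1, 0.2], [0.3, 0.4]]})
print(select_input_features(data))               # [[1, 0], [0, 2]]
print(select_input_features(data, "state_emb"))  # [[0.1, 0.2], [0.3, 0.4]]
```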

Let's add some dummy perturbation metadata to the data from the previous example.

In [None]:
import numpy as np
# some default control and non-control perturbations
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]

n_cells = adata.n_obs
# we assign perturbations to cells randomly
adata.obs['target_gene'] = np.random.choice(perturbations, size=n_cells)
adata.obs['cell_type'] = adata.obs['LVL1']  # Use your cell type column
# we can also add a batch variable to take into account batch effects
batch_labels = np.random.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)
adata.obs['batch_var'] = batch_labels

config = stateConfig(
    embed_key='state_emb', # our custom embedding key from above
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    output_path="yolksac_perturbed.h5ad",
)
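Each perturbation label above is a string encoding a list of `(compound, dose, unit)` tuples. If you need the individual fields (for example, to filter by dose), Python's `ast.literal_eval` can parse them. The tuple interpretation is inferred from the example labels, and `parse_perturbation` is a hypothetical helper.

```python
import ast
from typing import List, Tuple

def parse_perturbation(label: str) -> List[Tuple[str, float, str]]:
    """Parse a label such as "[('Aspirin', 0.5, 'uM')]" into a list of tuples."""
    # literal_eval only evaluates Python literals, not arbitrary code
    parsed = ast.literal_eval(label)
    return [(str(name), float(dose), str(unit)) for name, dose, unit in parsed]

print(parse_perturbation("[('Aspirin', 0.5, 'uM')]"))        # [('Aspirin', 0.5, 'uM')]
print(parse_perturbation("[('DMSO_TF', 0.0, 'uM')]")[0][1])  # 0.0
```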


Now we can run the perturbation model.

In [None]:
from helical.models.state import stateTransitionModel

state_transition = stateTransitionModel(configurer=config)

# again we process the data and get the perturbed embeddings
processed_data = state_transition.process_data(adata)
perturbed_embeds = state_transition.get_embeddings(processed_data)

print(perturbed_embeds.shape)

# Finetuning STATE

We can fine-tune the STATE perturbation embeddings with an additional head for downstream classification or regression. Below is a dummy example using the data above to get you started.

In [None]:
from helical.models.state import stateFineTuningModel

# Dummy cell types and labels for demonstration
cell_types = list(adata.obs['LVL1'])
label_set = sorted(set(cell_types))  # sorted so the class order is reproducible
print(f"Found {len(label_set)} unique cell types: {label_set}")

config = stateConfig(
    embed_key="state_emb",  # use the STATE embeddings stored in adata.obsm['state_emb']
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    batch_size=8,
)

# Create the fine-tuning model - we use a classification head for demonstration
model = stateFineTuningModel(
    configurer=config, 
    fine_tuning_head="classification", 
    output_size=len(label_set),
)

# Process the data for training - returns a dataset object
data = model.process_data(adata)

# Create a dictionary mapping the classes to unique integers for training
class_id_dict = {label: i for i, label in enumerate(label_set)}

# Convert cell type labels to integers
cell_type_labels = [class_id_dict[ct] for ct in cell_types]

print(f"Class mapping: {class_id_dict}")

# Fine-tune
model.train(train_input_data=data, train_labels=cell_type_labels)
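After training, integer class predictions need to be mapped back to cell-type names. A minimal sketch of a deterministic encoder/decoder pair follows; `build_label_maps` is a hypothetical helper, and sorting the labels is our choice so the mapping is identical across sessions (iteration order of a Python `set` of strings can differ between interpreter runs). The labels below are made up for illustration.

```python
from typing import Dict, List, Tuple

def build_label_maps(labels: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Map each unique label to an integer (sorted order) and build the inverse map."""
    classes = sorted(set(labels))
    class_to_id = {label: i for i, label in enumerate(classes)}
    id_to_class = {i: label for label, i in class_to_id.items()}
    return class_to_id, id_to_class

# hypothetical labels standing in for adata.obs['LVL1']
labels = ["Erythroid", "Myeloid", "Erythroid", "Endothelial"]
class_to_id, id_to_class = build_label_maps(labels)
encoded = [class_to_id[label] for label in labels]
print(encoded)                            # [1, 2, 1, 0]
print([id_to_class[i] for i in encoded])  # ['Erythroid', 'Myeloid', 'Erythroid', 'Endothelial']
```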

# Training STATE for the Virtual Cell Challenge

Here we use data from the Virtual Cell Challenge (VCC) for model training and downstream inference. This requires the VCC dataset as prepared in the authors' Colab notebook; the code snippet for downloading the full dataset is in the notebook below:

[STATE Colab Notebook](https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l)

For demonstration we use a subset of the data. We also need to update the dataset path in `starter.toml` so that it points to the downloaded data (see the top of the file).

In [None]:
from pathlib import Path
import toml
from helical.utils.downloader import Downloader

downloader = Downloader()
downloader.download_via_name("sample_vcc_data")

# point the "replogle_h1" dataset entry in starter.toml at the downloaded files
toml_path = Path("sample_vcc_data/starter.toml")
with open(toml_path) as f:
    toml_config = toml.load(f)
toml_config["datasets"] = {"replogle_h1": str(Path("sample_vcc_data").absolute() / "{rpe1,hepg2}.h5")}
with open(toml_path, "w") as f:
    toml.dump(toml_config, f)
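The dataset path written to `starter.toml` uses a brace pattern, `{rpe1,hepg2}.h5`, which refers to two files; the training log further down reports processing both `rpe1` and `hepg2`. To check which files the pattern should resolve to before training, a small expansion helper can be used. `expand_braces` is hypothetical and handles only a single, non-nested `{a,b,...}` group; it makes no claim about how the data loader itself parses the path.

```python
import re
from pathlib import Path
from typing import List

def expand_braces(pattern: str) -> List[str]:
    """Expand one non-nested {a,b,...} group into the individual paths."""
    match = re.search(r"\{([^{}]*)\}", pattern)
    if match is None:
        return [pattern]
    options = match.group(1).split(",")
    return [pattern[:match.start()] + opt + pattern[match.end():] for opt in options]

paths = expand_braces("sample_vcc_data/{rpe1,hepg2}.h5")
print(paths)  # ['sample_vcc_data/rpe1.h5', 'sample_vcc_data/hepg2.h5']
# list any expanded file that is not present after the download
missing = [p for p in paths if not Path(p).exists()]
print("missing files:", missing)
```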

We use the `stateTransitionTrainModel` class, configured with a `trainConfig`.

In [1]:
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainConfig

# default configs for the competition dataset
train_config = trainConfig(
    output_dir="sample_vcc_data",
    name="first_run",  # run directory: sample_vcc_data/first_run
    toml_config_path="sample_vcc_data/starter.toml",
    checkpoint_name="final.ckpt",
    max_steps=40000,
    max_epochs=1,  # training stops at whichever of max_steps / max_epochs is reached first
    ckpt_every_n_steps=20000,
    num_workers=4,
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    perturbation_features_file="sample_vcc_data/ESM2_pert_features.pt",
)

INFO:datasets:PyTorch version 2.6.0 available.
INFO:datasets:Polars version 1.33.0 available.
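With these settings, training stops at whichever limit is reached first, `max_steps` or `max_epochs` (PyTorch Lightning `Trainer` behaviour; the log below confirms that `Trainer.fit` stopped at `max_epochs=1`). On the sample data each epoch has only 13 batches according to that log, so a quick calculation shows that neither `max_steps=40000` nor the `ckpt_every_n_steps=20000` interval comes into play. `effective_training_steps` is a hypothetical helper for this arithmetic and assumes one optimizer step per batch.

```python
def effective_training_steps(max_steps: int, max_epochs: int, batches_per_epoch: int) -> int:
    """Steps actually run when training stops at the first of max_steps / max_epochs.

    Assumes one optimizer step per batch (no gradient accumulation).
    """
    return min(max_steps, max_epochs * batches_per_epoch)

steps = effective_training_steps(max_steps=40000, max_epochs=1, batches_per_epoch=13)
periodic_checkpoints = steps // 20000  # full ckpt_every_n_steps intervals completed
print(steps, periodic_checkpoints)  # 13 0
```

For a real run, raise `max_epochs` (or rely on `max_steps`) so that training is long enough to be useful.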


In [2]:
# we can then train the model and perform inference on a held out test set
state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train() 
state_train.predict() 

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
INFO:cell_load.config:Configuration validation passed
INFO:cell_load.data_modules.perturbation_dataloader:Initializing DataModule: batch_size=16, workers=4, random_seed=42


/home/rasched/final_helical_with_state/helical/helical/models/state/competition_support_set/{rpe1,hepg2}.h5


INFO:cell_load.data_modules.perturbation_dataloader:Set 2 missing perturbations to zero vectors.
INFO:cell_load.data_modules.perturbation_dataloader:Loaded custom perturbation featurizations for 19792 perturbations.
INFO:cell_load.data_modules.perturbation_dataloader:Processing dataset replogle_h1:
INFO:cell_load.data_modules.perturbation_dataloader:  - Training dataset: True
INFO:cell_load.data_modules.perturbation_dataloader:  - Zeroshot cell types: ['hepg2']
INFO:cell_load.data_modules.perturbation_dataloader:  - Fewshot cell types: []
Processing replogle_h1: 100%|██████████| 2/2 [00:00<00:00, 151.30it/s]
INFO:cell_load.data_modules.perturbation_dataloader:

INFO:cell_load.data_modules.perturbation_dataloader:Done! Train / Val / Test splits: 1 / 0 / 1


Processed rpe1: 22317 train, 0 val, 0 test
Processed hepg2: 0 train, 0 val, 9386 test
Model created. Estimated params size: 0.61 GB and 650505936 parameters


INFO:helical.models.state.state_train:Loggers and callbacks set up.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:helical.models.state.state_train:Starting trainer fit.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Trainer built successfully


INFO: 
  | Name                 | Type                    | Params | Mode 
-------------------------------------------------------------------------
0 | loss_fn              | SamplesLoss             | 0      | train
1 | pert_encoder         | Sequential              | 4.8 M  | train
2 | basal_encoder        | Linear                  | 12.2 M | train
3 | transformer_backbone | LlamaBidirectionalModel | 50.4 M | train
4 | project_out          | Sequential              | 13.5 M | train
5 | final_down_then_up   | Sequential              | 81.7 M | train
6 | relu                 | ReLU                    | 0      | train
-------------------------------------------------------------------------
141 M     Trainable params
21.5 M    Non-trainable params
162 M     Total params
650.506   Total estimated model params size (MB)
86        Modules in train mode
0         Modules in eval mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

INFO:cell_load.data_modules.samplers:Creating perturbation batch sampler with metadata caching (using codes)...
INFO:cell_load.data_modules.samplers:Total # cells 9386. Cell set size mean / std before resampling: 78.22 / 47.32.
INFO:cell_load.data_modules.samplers:Creating meta-batches with cell_sentence_len=128...
INFO:cell_load.data_modules.samplers:Of all batches, 51 were full and 69 were partial.
INFO:cell_load.data_modules.samplers:Sampler created with 8 batches in 0.01 seconds.
INFO:cell_load.data_modules.samplers:Creating perturbation batch sampler with metadata caching (using codes)...
INFO:cell_load.data_modules.samplers:Total # cells 22317. Cell set size mean / std before resampling: 107.29 / 36.13.
INFO:cell_load.data_modules.samplers:Creating meta-batches with cell_sentence_len=128...
INFO:cell_load.data_modules.samplers:Of all batches, 139 were full and 69 were partial.
INFO:cell_load.data_modules.samplers:Sampler created with 13 batches in 0.02 seconds.




Epoch 0: 100%|██████████| 13/13 [00:04<00:00,  2.94it/s, v_num=0]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


Training completed, saving final checkpoint...


INFO:cell_load.data_modules.samplers:Creating perturbation batch sampler with metadata caching (using codes)...
INFO:cell_load.data_modules.samplers:Total # cells 9386. Cell set size mean / std before resampling: 78.22 / 47.32.
INFO:cell_load.data_modules.samplers:Creating meta-batches with cell_sentence_len=128...
INFO:cell_load.data_modules.samplers:Of all batches, 120 were full and 0 were partial.
INFO:cell_load.data_modules.samplers:Sampler created with 120 batches in 0.01 seconds.
INFO:helical.models.state.state_train:Loading model from sample_vcc_data/first_run/final.ckpt
INFO:helical.models.state.state_train:Model loaded successfully.
INFO:helical.models.state.state_train:Generating predictions on test set using manual loop...
Predicting: 100%|██████████| 120/120 [00:06<00:00, 18.06batch/s]
INFO:helical.models.state.state_train:Creating ann

The trained model is saved to the `sample_vcc_data/first_run` directory, together with the files and checkpoints needed to initialise a new model. We can then initialise `stateTransitionModel` as before and run inference. For consistency with `stateConfig`, copy `config.yaml` into the directory given by `name` (here `sample_vcc_data/first_run`).
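A minimal sketch of that copy step using only the standard library. The source location of `config.yaml` is not specified here, so `src_config` is a placeholder you must fill in; `prepare_run_dir` is a hypothetical helper that also checks for `final.ckpt`, the checkpoint name set in `trainConfig`.

```python
import shutil
from pathlib import Path

def prepare_run_dir(src_config: Path, run_dir: Path, checkpoint_name: str = "final.ckpt") -> Path:
    """Copy config.yaml into the run directory after confirming the checkpoint is there."""
    ckpt = run_dir / checkpoint_name
    if not ckpt.exists():
        raise FileNotFoundError(f"expected checkpoint not found: {ckpt}")
    # shutil.copy returns the destination path
    return Path(shutil.copy(src_config, run_dir / "config.yaml"))

# usage (fill in the real location of config.yaml):
# prepare_run_dir(Path("<path/to/config.yaml>"), Path("sample_vcc_data/first_run"))
```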

In [4]:
from helical.models.state import stateTransitionModel
from helical.models.state import stateConfig
import scanpy as sc

adata = sc.read_h5ad("sample_vcc_data/test.h5ad")

state_config = stateConfig(
    output_path="sample_vcc_data/prediction.h5ad",
    perturb_dir="sample_vcc_data/first_run",
    pert_col="target_gene",
)

state_transition = stateTransitionModel(configurer=state_config)
processed_data = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(processed_data)

INFO:helical.models.state.state_transition:Using checkpoint: sample_vcc_data/first_run/final.ckpt
INFO:helical.models.state.state_transition:Model device: cpu
INFO:helical.models.state.state_transition:Model cell_set_len (max sequence length): 128
INFO:helical.models.state.state_transition:Model uses batch encoder: False
INFO:helical.models.state.state_transition:Model output space: all
INFO:helical.models.state.state_transition:Grouping by cell type column: cell_type
INFO:helical.models.state.state_transition:Using adata.X as input features
INFO:helical.models.state.state_transition:Cells: total=98927, control=38176, non-control=60751
INFO:helical.models.state.state_transition:Running virtual experiment (homogeneous per-perturbation forward passes; controls included)...
Group H1: 100%|██████████| 51/51 [00:31<00:00,  1.64it/s, Pert: non-targeting           ]
INFO:helical.models.state.state_transition:--Complete--
Input cells: 98927, Control simulated: 38176, Treated simulated: 60751
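The log reports how the input cells split into control and non-control groups (38176 + 60751 = 98927). A similar split can be computed on your own `obs` column by comparing each perturbation label with the control label. A minimal sketch with a hypothetical helper, `count_control_cells`; for the test data you would presumably pass `adata.obs['target_gene']` and `"non-targeting"`, the control label used in `trainConfig` (the labels below are made up).

```python
from collections import Counter
from typing import Iterable, Tuple

def count_control_cells(perturbations: Iterable[str], control_pert: str) -> Tuple[int, int, int]:
    """Return (total, control, non_control) cell counts for a column of perturbation labels."""
    counts = Counter(perturbations)
    total = sum(counts.values())
    control = counts.get(control_pert, 0)
    return total, control, total - control

labels = ["non-targeting", "GENE_A", "non-targeting", "GENE_B", "GENE_A"]
print(count_control_cells(labels, "non-targeting"))  # (5, 2, 3)
```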

Now you can use the `cell-eval` package to prepare a submission for the Virtual Cell Challenge (this generates a `.vcc` file).

In [None]:
! pip install cell-eval
! cell-eval prep -i sample_vcc_data/prediction.h5ad -g sample_vcc_data/gene_names.csv