This notebook shows how to use the `STATE` model through the `helical` package.

# Download Example Data

We start by using the helical downloader to fetch an example dataset from Hugging Face.

In [None]:
from helical.utils.downloader import Downloader
from pathlib import Path

downloader = Downloader()
downloader.download_via_link(
    Path("yolksac_human.h5ad"),
    "https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true",
)

# STATE Embeddings

Using the STATE model we can obtain single-cell transcriptome embeddings. We first slice the dataset to keep the demonstration fast.

In [None]:
# load the data 
import scanpy as sc

adata = sc.read_h5ad("yolksac_human.h5ad")
# for demonstration we subset to 10 cells and 2000 genes
adata = adata[:10, :2000].copy()

print(adata.shape)
n_cells = adata.n_obs
print(n_cells)
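
The slicing step above can be illustrated without the real dataset. The sketch below uses a random NumPy count matrix as a hypothetical stand-in for `adata.X`; it only shows what `[:10, :2000]` selects and why `.copy()` is used. AnnData slicing likewise returns a view of the parent object until it is copied.

```python
import numpy as np

# Stand-in for adata.X: a small cells x genes count matrix (hypothetical data).
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(50, 3000))

# Keep the first 10 cells (rows) and the first 2000 genes (columns).
subset = counts[:10, :2000].copy()
print(subset.shape)  # (10, 2000)

# Because of .copy(), editing the subset leaves the original untouched.
original_value = counts[0, 0]
subset[0, 0] = original_value + 1
print(counts[0, 0] == original_value)  # True
```

Note that taking the first 2000 columns is an arbitrary choice made for speed, not a gene-selection strategy.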

Initialise the model. On the first run this downloads the required files to `~/.cache/helical/models/state/`, so initialisation takes slightly longer than on later runs.


In [2]:
from helical.models.state import stateConfig
from helical.models.state import stateEmbed

state_config = stateConfig(batch_size=16)
state_embed = stateEmbed(configurer=state_config)

INFO:helical.models.state.state_embeddings:Using model checkpoint: /home/rasched/.cache/helical/models/state/state_embed/se600m_model_weights.pt
INFO:helical.models.state.state_embeddings:Successfully loaded model


We process the data by calling `state_embed.process_data` and pass the result to `state_embed.get_embeddings` to get the final embeddings.

In [None]:
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

# note that the STATE model returns a numpy array of shape (n_cells, 1024)
print(embeddings.shape)
print(type(embeddings))

# store the embeddings in adata.obsm['state_emb']
adata.obsm['state_emb'] = embeddings
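
Once embeddings are available, a common first check is how similar cells are to each other. The following is a minimal sketch using a random array as a hypothetical stand-in for the `(n_cells, 1024)` embeddings; with real data you would pass `adata.obsm['state_emb']` instead.

```python
import numpy as np

# Hypothetical stand-in for the embeddings: 10 cells x 1024 dimensions.
rng = np.random.default_rng(0)
emb = rng.normal(size=(10, 1024))

# L2-normalise each row so that the dot product equals cosine similarity.
norms = np.linalg.norm(emb, axis=1, keepdims=True)
unit = emb / np.clip(norms, 1e-12, None)
cosine_sim = unit @ unit.T  # shape (10, 10)

# For each cell, find the most similar *other* cell.
sim_no_self = cosine_sim.copy()
np.fill_diagonal(sim_no_self, -np.inf)
nearest = sim_no_self.argmax(axis=1)

print(cosine_sim.shape)
print(nearest)
```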

# STATE Perturbations

To use the perturbation model you can either pass in embeddings by specifying the `embed_key` argument in `stateConfig`, or leave it at the default `None`, in which case the expression values in `adata.X` are used.

When `embed_key` is set, it must be a key in `adata.obsm`; otherwise an error is raised.
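
The selection rule described above can be sketched in plain Python. This is an illustration of the documented behaviour, not helical's implementation: `resolve_input` is a hypothetical helper, and a dictionary stands in for `adata.obsm`.

```python
import numpy as np

def resolve_input(obsm, X, embed_key=None):
    """Return the matrix the model would consume under the rule above.

    Hypothetical helper: `obsm` is a dict-like stand-in for adata.obsm
    and `X` is a stand-in for adata.X.
    """
    if embed_key is None:
        return X  # default: raw expression values
    if embed_key not in obsm:
        raise KeyError(f"embed_key {embed_key!r} not found in obsm")
    return obsm[embed_key]

# Stand-in data: 10 cells, 2000 genes, 1024-dimensional embeddings.
X = np.zeros((10, 2000))
obsm = {"state_emb": np.zeros((10, 1024))}

print(resolve_input(obsm, X).shape)               # (10, 2000)
print(resolve_input(obsm, X, "state_emb").shape)  # (10, 1024)
```

Running a check like this before building the config catches a misspelled key early.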

Let's create some dummy perturbation data for the dataset from the previous example.

In [None]:
import numpy as np
# some default control and non-control perturbations
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]

n_cells = adata.n_obs
# we assign perturbations to cells randomly
adata.obs['target_gene'] = np.random.choice(perturbations, size=n_cells)
adata.obs['cell_type'] = adata.obs['LVL1']  # Use your cell type column
# we can also add a batch variable to take into account batch effects
batch_labels = np.random.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)
adata.obs['batch_var'] = batch_labels

config = stateConfig(
    embed_key='state_emb', # our custom embedding key from above
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    output_path="yolksac_perturbed.h5ad",
)
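
Random label assignment changes on every run, and with only 10 cells it can occasionally produce no control cells at all. Below is a small sketch of a reproducible alternative using a seeded NumPy generator. It assumes that at least one control cell is desirable, which this notebook does not state explicitly; the seed value is arbitrary.

```python
import numpy as np

perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]
control = perturbations[0]
n_cells = 10  # matches the 10-cell subset used in this notebook

# A seeded generator gives the same assignment on every run.
rng = np.random.default_rng(seed=42)
labels = rng.choice(perturbations, size=n_cells)

# Assumption: we want at least one control cell, so force the first one.
labels[0] = control

print(np.unique(labels, return_counts=True))
```

The resulting array can be assigned to `adata.obs['target_gene']` in place of the unseeded `np.random.choice` call.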


Now we can run the perturbation model.

In [None]:
from helical.models.state import stateTransitionModel

state_transition = stateTransitionModel(configurer=config)

# again we process the data and get the perturbed embeddings
processed_data = state_transition.process_data(adata)
perturbed_embeds = state_transition.get_embeddings(processed_data)

print(perturbed_embeds.shape)
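
One way to inspect the result is to measure how far each cell moved in embedding space. The sketch below uses random arrays as hypothetical stand-ins and assumes the perturbed embeddings have the same shape as the input embeddings, which should be verified against the shape printed above before applying it to real output.

```python
import numpy as np

# Hypothetical stand-ins for adata.obsm['state_emb'] and perturbed_embeds.
rng = np.random.default_rng(0)
base = rng.normal(size=(10, 1024))
perturbed = base + rng.normal(scale=0.1, size=base.shape)

# Assumption: both arrays are (n_cells, dim); fail loudly otherwise.
if perturbed.shape != base.shape:
    raise ValueError(f"shape mismatch: {perturbed.shape} vs {base.shape}")

# Per-cell Euclidean distance between perturbed and original embeddings.
shift = np.linalg.norm(perturbed - base, axis=1)
print(shift.round(3))
```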

# Finetuning STATE

We can finetune the STATE perturbation embeddings using an additional head for downstream classification and regression. Below is a dummy example using data above to get you started.

In [None]:
from helical.models.state import stateFineTuningModel

# Dummy cell types and labels for demonstration
cell_types = list(adata.obs['LVL1'])
label_set = set(cell_types)
print(f"Found {len(label_set)} unique cell types")

config = stateConfig(
    embed_key="state_emb",  # Use the STATE embeddings stored in adata.obsm['state_emb']
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    batch_size=8,
)

# Create the fine-tuning model - we use a classification head for demonstration
model = stateFineTuningModel(
    configurer=config, 
    fine_tuning_head="classification", 
    output_size=len(label_set),
)

# Process the data for training - returns a dataset object
data = model.process_data(adata)

# Create a dictionary mapping the classes to unique integers for training
class_id_dict = {label: i for i, label in enumerate(label_set)}

# Convert cell type labels to integers
cell_type_labels = [class_id_dict[ct] for ct in cell_types]

print(f"Class mapping: {class_id_dict}")

# Fine-tune
model.train(train_input_data=data, train_labels=cell_type_labels)
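
Because Python sets have no guaranteed order across runs, the class-to-integer mapping above can differ between sessions. The sketch below shows a deterministic variant that sorts the labels, plus the inverse mapping needed to turn predicted integers back into names. The cell type names are hypothetical examples, not values taken from the dataset.

```python
# Hypothetical cell type labels, one per cell.
cell_types = ["Erythroid", "Macrophage", "Erythroid", "Endothelium", "Macrophage"]

# Sorting the unique labels makes the mapping identical on every run.
class_id_dict = {label: i for i, label in enumerate(sorted(set(cell_types)))}

# Inverse mapping for decoding integer predictions back to names.
id_class_dict = {i: label for label, i in class_id_dict.items()}

encoded = [class_id_dict[ct] for ct in cell_types]
decoded = [id_class_dict[i] for i in encoded]

print(class_id_dict)  # {'Endothelium': 0, 'Erythroid': 1, 'Macrophage': 2}
print(encoded)        # [1, 2, 1, 0, 2]
```

Sorting assumes all labels are comparable strings; missing values would need to be handled first.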

# Training STATE for the Virtual Cell Challenge

We use data from the Virtual Cell Challenge (VCC) for model training and downstream inference. This requires the VCC dataset used in the authors' Colab notebook, linked below, which contains the code snippet for downloading the entire dataset:

[STATE Colab Notebook](https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l)

For demonstration we use a subset of the data. 

In [None]:
from helical.utils.downloader import Downloader

downloader = Downloader()
downloader.download_via_name("sample_vcc_data")

We use the `stateTransitionTrainModel` class and initialise a `trainingConfig`. Be sure to edit the `competition_support_set/starter.toml` file to point to the correct dataset path (see top of file).

In [None]:
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

# default configs for competition dataset
train_config = trainingConfig(
    output_dir="competition",
    name="first_run",
    toml_config_path="competition_support_set/starter.toml",
    checkpoint_name="final.ckpt",
    max_steps=40000,
    max_epochs=1,
    ckpt_every_n_steps=20000,
    num_workers=4,
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    perturbation_features_file="competition_support_set/ESM2_pert_features.pt",
)
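
To make the step-related settings concrete, the arithmetic below lists the steps at which periodic checkpoints would be written. It assumes a checkpoint is saved whenever the step count is a multiple of `ckpt_every_n_steps`; how `final.ckpt` relates to these periodic checkpoints is not specified here.

```python
max_steps = 40000
ckpt_every_n_steps = 20000

# Assumption: one checkpoint at every multiple of ckpt_every_n_steps.
checkpoint_steps = list(range(ckpt_every_n_steps, max_steps + 1, ckpt_every_n_steps))

print(checkpoint_steps)       # [20000, 40000]
print(len(checkpoint_steps))  # 2
```

For a quick smoke test, lowering both values proportionally keeps the same number of checkpoints while shortening the run.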

In [None]:
# we can then train the model and perform inference on a held out test set
state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train() 
state_train.predict() 

The trained model will be saved to the `competition/first_run` directory, alongside the necessary files and checkpoints to initialise a new model. We can initialise `stateTransitionModel` as before and run inference.

In [None]:
import scanpy as sc
from helical.models.state import stateConfig, stateTransitionModel

# load the held-out validation data
adata = sc.read_h5ad("competition_support_set/validation_data.h5ad")

state_config = stateConfig(
    output_path="competition/prediction.h5ad",
    model_dir="competition/first_run",  # directory written by the training run
    pert_col="target_gene",
)

state_transition = stateTransitionModel(configurer=state_config)
processed_data = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(processed_data)

Now you can use the `cell-eval` package to create a submission to the Virtual Cell Challenge. Helical provides a convenience wrapper around the main evaluation function that generates the `.vcc` file.

In [None]:
# evaluate the model - underlying function uses cell-eval package 
# (https://github.com/ArcInstitute/cell-eval)
from helical.models.state import vcc_eval

# default configs for competition dataset
EXPECTED_GENE_DIM = 18080
MAX_CELL_DIM = 100000
DEFAULT_PERT_COL = "target_gene"
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

configs = {
    "input": "competition/prediction.h5ad",
    "genes": "competition_support_set/gene_names.csv",
    "output": None, # path to the output file - if None will be created with default naming
    "pert_col": DEFAULT_PERT_COL,
    "celltype_col": None,
    "ntc_name": DEFAULT_NTC_NAME,
    "output_pert_col": DEFAULT_PERT_COL,
    "output_celltype_col": DEFAULT_CELLTYPE_COL,
    "encoding": 32,
    "allow_discrete": False,
    "expected_gene_dim": EXPECTED_GENE_DIM,
    "max_cell_dim": MAX_CELL_DIM,
}

# this creates a submission file in the output directory which can be uploaded to the challenge leaderboard
vcc_eval(configs)
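
Before generating a submission it can help to check the prediction matrix against the size limits in the config. The helper below is a hypothetical sketch: it assumes `expected_gene_dim` must match the number of genes exactly and `max_cell_dim` is an upper bound on the number of cells, which is an interpretation of the parameter names rather than documented behaviour.

```python
def check_submission_dims(n_cells, n_genes,
                          expected_gene_dim=18080, max_cell_dim=100000):
    """Return a list of human-readable problems (empty if none).

    Hypothetical helper; the exact rules enforced by the evaluation
    code are assumed, not confirmed.
    """
    problems = []
    if n_genes != expected_gene_dim:
        problems.append(f"expected {expected_gene_dim} genes, got {n_genes}")
    if n_cells > max_cell_dim:
        problems.append(f"at most {max_cell_dim} cells allowed, got {n_cells}")
    return problems

print(check_submission_dims(50000, 18080))   # []
print(check_submission_dims(120000, 2000))
```

With a loaded prediction file, the two arguments would come from `adata.n_obs` and `adata.n_vars`.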