This notebook shows how to use the `STATE` model through `helical`.

# Download Example Data

We start by using the `helical` downloader to fetch an example dataset from Hugging Face.

In [None]:
from helical.utils.downloader import Downloader
from pathlib import Path

downloader = Downloader()
downloader.download_via_link(
    Path("yolksac_human.h5ad"),
    "https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true",
)

# STATE Embeddings

Using the STATE model we can obtain single-cell transcriptome embeddings. We first slice the dataset for demonstration purposes.

In [None]:
# load the data
import scanpy as sc

adata = sc.read_h5ad("yolksac_human.h5ad")

# for demonstration we subset to 10 cells and 2000 genes
n_cells = 10
n_genes = 2000
adata = adata[:n_cells, :n_genes].copy()

# confirm the shape and the number of cells after slicing
print(adata.shape)
n_cells = adata.n_obs
print(n_cells)
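
Taking the first rows of a file can over-represent whichever cells happen to be stored first. A minimal sketch of drawing a reproducible random subset of cell indices instead, using only the standard library; the helper name `sample_indices`, the seed, and the total of 1000 cells are illustrative assumptions and not part of `helical`.

```python
import random


def sample_indices(n_total, n_keep, seed=0):
    """Return a sorted, reproducible random sample of row indices."""
    if n_keep > n_total:
        raise ValueError("cannot keep more rows than are available")
    rng = random.Random(seed)  # local generator, leaves global random state untouched
    return sorted(rng.sample(range(n_total), n_keep))


# Example: pick 10 of 1000 hypothetical cells
cell_idx = sample_indices(n_total=1000, n_keep=10)
print(cell_idx)
```

The resulting list can be used in place of the slice, for example `adata[cell_idx, :n_genes].copy()`.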

Initialise the model. On the first run this downloads the required files to `.cache/helical/state/`, so it takes slightly longer than later runs.


In [None]:
from helical.models.state import StateConfig, StateEmbed

state_config = StateConfig(batch_size=16)
state_embed = StateEmbed(configurer=state_config)

We process the data with `state_embed.process_data` and pass the result to `state_embed.get_embeddings` to obtain the final embeddings.

In [None]:
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

# note that the STATE model returns a numpy array of shape (n_cells, 1024)
print(embeddings.shape)
print(type(embeddings))

# store the embeddings in adata.obsm['state_emb']
adata.obsm['state_emb'] = embeddings
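
As a quick sanity check on embeddings of this shape, one can compare cells by cosine similarity. The sketch below uses a random placeholder array in place of the real `embeddings`, so the values are meaningless; only the shapes and the computation carry over. The name `fake_embeddings` is an assumption made for illustration.

```python
import numpy as np

# Placeholder standing in for the (n_cells, 1024) array returned above
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(10, 1024))

# L2-normalise each row so that dot products become cosine similarities
norms = np.linalg.norm(fake_embeddings, axis=1, keepdims=True)
unit = fake_embeddings / np.clip(norms, 1e-12, None)

similarity = unit @ unit.T  # (n_cells, n_cells), ones on the diagonal
print(similarity.shape)
```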

# STATE Perturbations

The perturbation model can take either precomputed embeddings or raw expression values as input. To use embeddings, set the `embed_key` argument in `StateConfig`; the key must exist in `adata.obsm`, otherwise an error is raised. With the default value of `None`, the model uses the expression values in `adata.X`.
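
The `embed_key` rule can be checked before building a config. The sketch below is a hypothetical helper (`check_embed_key` is not part of `helical`) that mirrors the rule as described here, with a plain set standing in for `adata.obsm.keys()`; it does not show how the library itself performs the check.

```python
def check_embed_key(obsm_keys, embed_key):
    """Return which input would be used, or raise if the key is missing.

    `obsm_keys` stands in for `adata.obsm.keys()`.
    """
    if embed_key is None:
        return "X"  # fall back to the raw expression values
    if embed_key not in obsm_keys:
        raise KeyError(
            f"embed_key {embed_key!r} not found; available keys: {sorted(obsm_keys)}"
        )
    return f"obsm[{embed_key!r}]"


print(check_embed_key({"state_emb"}, None))
print(check_embed_key({"state_emb"}, "state_emb"))
```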

Let's add some dummy perturbation metadata to the data from the previous example.

In [None]:
import numpy as np
# some default control and non-control perturbations
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]

n_cells = adata.n_obs
# we assign perturbations to cells randomly
adata.obs['target_gene'] = np.random.choice(perturbations, size=n_cells)
adata.obs['cell_type'] = adata.obs['LVL1']  # Use your cell type column
# we can also add a batch variable to take into account batch effects
batch_labels = np.random.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)
adata.obs['batch_var'] = batch_labels

config = StateConfig(
    embed_key=None,
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    output_path="yolksac_perturbed.h5ad",
)
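
The perturbation labels above look like string-encoded lists of `(name, dose, unit)` tuples. Assuming that reading is right (it is inferred from the example values only), a label can be unpacked with the standard library for inspection or grouping. The helper name `parse_perturbation` is hypothetical.

```python
import ast


def parse_perturbation(label):
    """Parse a label such as "[('Aspirin', 0.5, 'uM')]" into a list of dicts."""
    entries = ast.literal_eval(label)  # only evaluates Python literals, never code
    return [{"name": n, "dose": d, "unit": u} for n, d, u in entries]


control = parse_perturbation("[('DMSO_TF', 0.0, 'uM')]")
print(control)
```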


Now we can run the perturbation model.

In [None]:
from helical.models.state import StateTransitionModel

state_transition = StateTransitionModel(configurer=config)

# again we process the data and get the perturbed embeddings
processed_data = state_transition.process_data(adata)
perturbed_embeds = state_transition.get_embeddings(processed_data)

print(perturbed_embeds.shape)

# Finetuning STATE

We can finetune the STATE perturbation embeddings with an additional head for downstream classification or regression. Below is a dummy example, using the data above, to get you started.

In [None]:
from helical.models.state import StateFineTuningModel

# Dummy cell types and labels for demonstration
cell_types = list(adata.obs['LVL1'])
label_set = set(cell_types)
print(f"Found {len(label_set)} unique cell types")

config = StateConfig(
    embed_key=None,
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    batch_size=8,
)

# Create the fine-tuning model - we use a classification head for demonstration
model = StateFineTuningModel(
    configurer=config, 
    fine_tuning_head="classification", 
    output_size=len(label_set),
)

# Process the data for training - returns a dataset object
dataset = model.process_data(adata)

# Create a dictionary mapping the classes to unique integers for training
class_id_dict = {label: i for i, label in enumerate(label_set)}

# Convert cell type labels to integers
cell_type_labels = [class_id_dict[ct] for ct in cell_types]

print(f"Class mapping: {class_id_dict}")

# Fine-tune
model.train(train_input_data=dataset, train_labels=cell_type_labels)
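
Iterating over a set of strings can give a different order from one Python session to the next, so the class-to-integer mapping may change between runs. A minimal sketch of a reproducible mapping and its inverse, which is useful for turning predicted class ids back into labels; the cell type names here are made up for illustration.

```python
# Hypothetical labels standing in for adata.obs['LVL1']
cell_types = ["ERYTHROID", "MACROPHAGE", "ERYTHROID", "FIBROBLAST"]

# Sorting makes the mapping identical across runs, unlike iterating a set
class_id_dict = {label: i for i, label in enumerate(sorted(set(cell_types)))}
id_class_dict = {i: label for label, i in class_id_dict.items()}

# Encode labels for training, then decode ids back to the original labels
encoded = [class_id_dict[ct] for ct in cell_types]
decoded = [id_class_dict[i] for i in encoded]

print(encoded)
print(decoded == cell_types)
```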

# Training STATE for the Virtual Cell Challenge

We use data from the Virtual Cell Challenge (VCC) for model training and downstream inference. This requires the VCC dataset used in the authors' Colab notebook, which includes the code snippet for obtaining the full dataset:

[STATE Colab Notebook](https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l)

For demonstration we have created a subset of the data. The dataset path at the top of `starter.toml` must also point to the correct location; the code below updates it. Start by downloading the data:

In [None]:
from helical.utils.downloader import Downloader
from helical.constants.paths import CACHE_DIR_HELICAL
import toml
from pathlib import Path

downloader = Downloader()
downloader.download_via_name("state/sample_vcc_data/config.yaml")
downloader.download_via_name("state/sample_vcc_data/starter.toml")
downloader.download_via_name("state/sample_vcc_data/gene_names.csv")
downloader.download_via_name("state/sample_vcc_data/ESM2_pert_features.pt")
downloader.download_via_name("state/sample_vcc_data/hepg2_mini.h5")
downloader.download_via_name("state/sample_vcc_data/rpe1_mini.h5")
downloader.download_via_name("state/sample_vcc_data/test.h5ad")

# point the dataset entry in starter.toml at the downloaded sample files
data_dir = Path(CACHE_DIR_HELICAL, "state/sample_vcc_data")
toml_path = data_dir / "starter.toml"

with open(toml_path) as f:
    starter = toml.load(f)

starter["datasets"] = {"replogle_h1": str(data_dir.absolute() / "{rpe1_mini,hepg2_mini}.h5")}

with open(toml_path, "w") as f:
    toml.dump(starter, f)
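
The dataset path written above contains a `{rpe1_mini,hepg2_mini}` group. Assuming this is read as a shell-style brace pattern covering both sample files (the notebook does not state this explicitly), the sketch below shows which paths such a pattern would stand for. The helper `expand_braces` and the `/data` prefix are hypothetical.

```python
import re


def expand_braces(pattern):
    """Expand the first {a,b,...} group in a pattern; other text is kept as-is."""
    match = re.search(r"\{([^{}]*)\}", pattern)
    if match is None:
        return [pattern]
    head, tail = pattern[:match.start()], pattern[match.end():]
    return [head + option + tail for option in match.group(1).split(",")]


print(expand_braces("/data/{rpe1_mini,hepg2_mini}.h5"))
```

Checking each expanded path with `Path(p).exists()` before training is a cheap way to catch a misplaced download.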

We use the `StateTransitionTrainModel` class and initialise the training configuration from the `config.yaml` file in the sample directory. You can edit it to suit your training preferences. It is currently set to one epoch for demonstration.

In [None]:
# we can then train the model and perform inference on a held out test set
from helical.models.state import StateTransitionTrainModel
from omegaconf import OmegaConf

train_configs = OmegaConf.load(Path(CACHE_DIR_HELICAL, "state/sample_vcc_data/config.yaml"))
# set the correct paths for the data
train_configs.data.kwargs.toml_config_path = str(CACHE_DIR_HELICAL / "state/sample_vcc_data/starter.toml")
train_configs.data.kwargs.perturbation_features_file = str(CACHE_DIR_HELICAL / "state/sample_vcc_data/ESM2_pert_features.pt")

state_train = StateTransitionTrainModel(configurer=train_configs)
state_train.train() 
state_train.predict() 

The trained model will be saved to the `sample_vcc_data/first_run` directory, alongside the files and checkpoints needed to initialise a new model. We can initialise `StateTransitionModel` as before and run inference.

In [None]:
from helical.models.state import StateTransitionModel
from helical.models.state import StateConfig
import scanpy as sc

adata = sc.read_h5ad(Path(CACHE_DIR_HELICAL, "state/sample_vcc_data/test.h5ad"))

state_config = StateConfig(
    output_path = "sample_run/prediction.h5ad",
    perturb_dir = "sample_run/first_run",
    pert_col = "target_gene",
)

state_transition = StateTransitionModel(configurer=state_config)
processed_data = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(processed_data)

Now you can use the `cell-eval` package to create a submission to the Virtual Cell Challenge (generates a `.vcc` file).

In [None]:
gene_file = CACHE_DIR_HELICAL / "state/sample_vcc_data/gene_names.csv"
input_file = "sample_run/prediction.h5ad"

! pip install cell-eval
! cell-eval prep -i {input_file} -g {gene_file}