This notebook shows how to use the `STATE` model with `helical` to compute cell embeddings and predict perturbation effects.

# Download Example Data

We start by using the `helical` `Downloader` to fetch an example dataset (human yolk sac, 553 MB) from Hugging Face.

In [1]:
from helical.utils.downloader import Downloader
from pathlib import Path

downloader = Downloader()
downloader.download_via_link(
    Path("yolksac_human.h5ad"),
    "https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true",
)


INFO:datasets:PyTorch version 2.6.0 available.
INFO:datasets:Polars version 1.33.0 available.
INFO:helical.utils.downloader:Starting to download: 'https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true'
yolksac_human.h5ad: 100%|██████████| 553M/553M [00:04<00:00, 116MB/s]  
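Re-running the cell above may fetch the 553 MB file again; we have not checked whether `Downloader.download_via_link` skips files that already exist. If you want to avoid repeated downloads, a minimal guard using only the standard library is sketched below. `download_if_missing` is a hypothetical helper, not part of `helical`.

```python
import tempfile
from pathlib import Path
from urllib.request import urlretrieve


def download_if_missing(url: str, dest: Path) -> Path:
    """Download `url` to `dest` unless a non-empty file is already there (hypothetical helper)."""
    dest = Path(dest)
    if dest.exists() and dest.stat().st_size > 0:
        # File already present: skip the download entirely.
        return dest
    dest.parent.mkdir(parents=True, exist_ok=True)
    urlretrieve(url, dest)
    return dest


# Demo: the URL is never contacted because the destination file already exists.
with tempfile.TemporaryDirectory() as tmp:
    existing = Path(tmp) / "yolksac_human.h5ad"
    existing.write_bytes(b"placeholder")
    result = download_if_missing("https://example.invalid/data.h5ad", existing)
    skipped = result.read_bytes() == b"placeholder"

print(skipped)  # True
```

In the notebook you would call it with the Hugging Face URL and `Path("yolksac_human.h5ad")` instead of the placeholder values.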


# STATE Embeddings

Using the STATE model, we can obtain single-cell transcriptome embeddings. For demonstration purposes, we first subset the dataset to a few cells and genes.

In [2]:
# load the data
import scanpy as sc

adata = sc.read_h5ad("yolksac_human.h5ad")
# subset to the first 10 cells and the first 2000 genes;
# .copy() materialises the slice so it can be modified safely later
n_cells = 10
n_genes = 2000
adata = adata[:n_cells, :n_genes].copy()

print(adata.shape)
n_cells = adata.n_obs
print(n_cells)

(10, 2000)
10
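Indexing an AnnData object, as in `adata[:n_cells, :n_genes]`, returns a view of the original object. That is why the cell above calls `.copy()` before the object is modified later (for example, when new `.obs` columns are added). The NumPy snippet below is only an analogy for this behaviour, not AnnData code: basic slicing shares memory with the source array, while `.copy()` allocates independent storage.

```python
import numpy as np

# NumPy analogue of view-vs-copy semantics.
X = np.arange(12, dtype=float).reshape(4, 3)  # 4 "cells" x 3 "genes"

view = X[:2, :2]          # basic slicing returns a view onto X's memory
copy = X[:2, :2].copy()   # .copy() allocates independent storage

view[0, 0] = -1.0         # writes through to X
copy[1, 1] = -2.0         # leaves X untouched

print(X[0, 0])                    # -1.0
print(X[1, 1])                    # 4.0
print(np.shares_memory(view, X))  # True
print(np.shares_memory(copy, X))  # False
```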


Initialise the model. The first time this runs, it downloads the required model files to `~/.cache/helical/models/state/`, so it takes longer than subsequent runs.


In [3]:
from helical.models.state import StateConfig, StateEmbed

state_config = StateConfig(batch_size=16)
state_embed = StateEmbed(configurer=state_config)

INFO:helical.models.state.state_embeddings:Using model checkpoint: /home/rasched/.cache/helical/models/state/state_embed/se600m_model_weights.pt
INFO:helical.models.state.state_embeddings:Successfully loaded model
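The log above shows where the checkpoint was stored. To see which checkpoint files have already been downloaded (for example, before working offline), you could list them with `pathlib`. `list_cached_checkpoints` is a hypothetical helper, and the directory is taken from the log output rather than from any documented `helical` setting.

```python
from pathlib import Path


def list_cached_checkpoints(cache_root: Path) -> list:
    """Return all *.pt files under cache_root, sorted (hypothetical helper)."""
    cache_root = Path(cache_root).expanduser()
    if not cache_root.is_dir():
        return []  # nothing downloaded yet
    return sorted(cache_root.rglob("*.pt"))


# Directory inferred from the log; adjust if your cache lives elsewhere.
for ckpt in list_cached_checkpoints(Path("~/.cache/helical/models/state")):
    size_mb = ckpt.stat().st_size / 1e6
    print(f"{ckpt} ({size_mb:.0f} MB)")
```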


We preprocess the data with `state_embed.process_data` and pass the result to `state_embed.get_embeddings` to obtain the final embeddings.

In [4]:
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

# STATE returns a NumPy array with one row per cell: shape (n_cells, embedding_dim)
print(embeddings.shape)
print(type(embeddings))

INFO:helical.models.state.state_embeddings:Auto-detected gene column: var.index (overlap: 113/19790 protein embeddings, 5.7% of genes)
INFO:/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/loader.py:113 genes mapped to embedding file (out of 2000)
Encoding: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]

(10, 2058)
<class 'numpy.ndarray'>
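The log reports that only 113 of the 2,000 genes in this subset (about 5.7%) were matched to STATE's protein embedding file. Before embedding a larger dataset, it can be worth checking this overlap yourself. The sketch below assumes you have the dataset's gene names (e.g. `adata.var_names`) and the embedding vocabulary as plain lists of strings; `gene_overlap` is a hypothetical helper and the gene names in the demo are illustrative only.

```python
def gene_overlap(var_names, vocabulary):
    """Count dataset genes present in an embedding vocabulary (hypothetical helper).

    Returns (n_matched, fraction_of_dataset_genes), mirroring how the log
    expresses overlap as a percentage of the dataset's genes.
    """
    vocab = set(vocabulary)
    matched = [g for g in var_names if g in vocab]
    return len(matched), len(matched) / max(len(var_names), 1)


# Toy example: 2 of 4 dataset genes are present in the vocabulary.
n_matched, frac = gene_overlap(
    ["GATA1", "HBB", "ENSG00000000003", "XIST"],
    ["HBB", "GATA1", "TP53"],
)
print(n_matched, f"{frac:.0%}")  # 2 50%
```

Using a `set` for the vocabulary keeps each lookup constant-time, which matters when the vocabulary holds tens of thousands of genes.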





# STATE Perturbations

To use the perturbation model, pass perturbation labels and, optionally, batch labels alongside the raw gene expression data.

For this example, we add dummy perturbation, cell-type, and batch labels to the subset created above.

In [5]:
import numpy as np
# one control and two non-control perturbations, each written as (compound, dose, unit)
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # Control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]

n_cells = adata.n_obs
# we assign perturbations to cells randomly
adata.obs['target_gene'] = np.random.choice(perturbations, size=n_cells)
adata.obs['cell_type'] = adata.obs['LVL1']  # replace 'LVL1' with your dataset's cell-type column
# we can also add a batch variable to take into account batch effects
batch_labels = np.random.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)
adata.obs['batch_var'] = batch_labels

config = StateConfig(
    pert_col="target_gene",
    celltype_col="cell_type",
    control_pert="[('DMSO_TF', 0.0, 'uM')]",
    output_path="yolksac_perturbed.h5ad",
)
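Each perturbation label is a string shaped like a Python list of `(compound, dose, unit)` tuples, and control cells are identified by the exact string passed as `control_pert`. The sketch below shows one way to inspect these labels with the standard library. `parse_perturbation` is a hypothetical helper; we have not verified how STATE parses the labels internally.

```python
import ast
from collections import Counter

CONTROL = "[('DMSO_TF', 0.0, 'uM')]"  # same string as control_pert above


def parse_perturbation(label: str):
    """Turn "[('Aspirin', 0.5, 'uM')]" into [('Aspirin', 0.5, 'uM')] (hypothetical helper)."""
    # ast.literal_eval only evaluates Python literals, so no arbitrary code runs.
    return [(str(name), float(dose), str(unit)) for name, dose, unit in ast.literal_eval(label)]


labels = [
    "[('DMSO_TF', 0.0, 'uM')]",
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
    "[('Aspirin', 0.5, 'uM')]",
]
print(parse_perturbation(labels[1]))  # [('Aspirin', 0.5, 'uM')]

# A cell counts as control when its label string equals the configured control.
counts = Counter("control" if lab == CONTROL else "non-control" for lab in labels)
print(counts)
```

Applied to `adata.obs['target_gene']`, the same `Counter` expression gives the control/non-control split that the model reports in its log.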


Now we can run the perturbation model.

In [6]:
from helical.models.state import StatePerturb

state_perturb = StatePerturb(configurer=config)

# process the data, then predict perturbed expression (the log reports output space "gene": one column per input gene)
processed_data = state_perturb.process_data(adata)
perturbed_embeds = state_perturb.get_embeddings(processed_data)

print(perturbed_embeds.shape)

INFO:helical.models.state.model_dir.perturb_utils.base:Loaded decoder from checkpoint decoder_cfg
INFO:helical.models.state.state_perturb:Checkpoint: /home/rasched/.cache/helical/models/state/state_transition/ST_all.pt, Device: cuda
INFO:helical.models.state.state_perturb:Cell set length (max sequence length): 256
INFO:helical.models.state.state_perturb:Batch encoder: True
INFO:helical.models.state.state_perturb:Output space: gene
INFO:helical.models.state.state_perturb:Grouping by cell type column: cell_type
INFO:helical.models.state.state_perturb:Batch column: batch_var
INFO:helical.models.state.state_perturb:Using adata.X as input features
INFO:helical.models.state.state_perturb:Using batch encoder
INFO:helical.models.state.state_perturb:Batch column found: batch_var
INFO:helical.models.state.state_perturb:Batch onehot map found, converting labels to indices
INFO:helical.models.state.state_perturb:Cells: total=10, control=3, non-control=7
INFO:helical.models.state.state_perturb:Runn

(10, 2000)
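Because the output has the same shape as the input, `(n_cells, n_genes)`, a natural next step is to summarise the predicted shift per perturbation. The sketch below rests on two assumptions we have not confirmed: that row `i` of the prediction corresponds to row `i` of `adata`, and that the predicted values are on the same scale as `adata.X`. `mean_shift_by_label` is a hypothetical helper; if `adata.X` is sparse, densify it first (e.g. `adata.X.toarray()`).

```python
import numpy as np


def mean_shift_by_label(X_in, X_pred, labels):
    """Mean (predicted - input) expression per label (hypothetical helper).

    Assumes row i of X_pred corresponds to row i of X_in and to labels[i].
    """
    X_in = np.asarray(X_in, dtype=float)
    X_pred = np.asarray(X_pred, dtype=float)
    labels = np.asarray(labels)
    if X_in.shape != X_pred.shape:
        raise ValueError(f"shape mismatch: {X_in.shape} vs {X_pred.shape}")
    delta = X_pred - X_in
    # Average the per-cell deltas within each label group.
    return {lab: delta[labels == lab].mean(axis=0) for lab in np.unique(labels)}


# Toy data: 3 cells x 2 genes.
X_in = np.array([[1.0, 0.0], [1.0, 0.0], [2.0, 2.0]])
X_pred = np.array([[1.0, 0.0], [3.0, 0.0], [2.0, 5.0]])
shifts = mean_shift_by_label(X_in, X_pred, ["ctrl", "drug", "drug"])
print(shifts["drug"])  # [1.  1.5]
```

Under those assumptions, the notebook call would be `mean_shift_by_label(adata.X.toarray(), perturbed_embeds, adata.obs['target_gene'])`.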
