This notebook shows how to use the `STATE` model through `helical`.

# Download Example Data

We start by using the `helical` downloader to fetch an example dataset from Hugging Face.

In [None]:
from helical.utils.downloader import Downloader
from pathlib import Path

downloader = Downloader()
# download the yolk sac dataset from Hugging Face to a local .h5ad file
downloader.download_via_link(
    Path("yolksac_human.h5ad"),
    "https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true",
)

# STATE Embeddings

The STATE model produces an embedding for each single-cell transcriptome. For demonstration purposes, we first subset the dataset.

In [None]:
# load the data
import scanpy as sc

adata = sc.read_h5ad("yolksac_human.h5ad")

# for demonstration we subset to the first 10 cells and the first 2000 genes
n_cells = 10
n_genes = 2000
adata = adata[:n_cells, :n_genes].copy()

print(adata.shape)
n_cells = adata.n_obs  # number of cells actually kept after subsetting
print(n_cells)
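
Slicing the first 10 rows means the demo only sees whichever cells happen to sit at the top of the file, which may all share one cell type. Below is a minimal sketch of a reproducible random subset, assuming only NumPy. The helper `random_cell_indices` is hypothetical and not part of `helical` or `scanpy`.

```python
import numpy as np

def random_cell_indices(n_obs: int, n_cells: int, seed: int = 0) -> np.ndarray:
    """Pick up to n_cells distinct row indices out of n_obs, sorted to keep the original cell order."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(n_obs, size=min(n_cells, n_obs), replace=False)
    return np.sort(idx)

# in the notebook, apply it to the full dataset instead of the first-rows slice, e.g.:
#   idx = random_cell_indices(adata.n_obs, n_cells)
#   adata = adata[idx, :n_genes].copy()

idx = random_cell_indices(n_obs=5000, n_cells=10)
print(idx)
```

Sorting the indices is a design choice: it keeps the subset in the same relative order as the file, which makes it easier to compare against the unsubsetted data.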

Initialise the model. The first time it runs, it downloads the required model files to `.cache/helical/state/`, so the first run takes longer.


In [None]:
from helical.models.state import StateConfig, StateEmbed

state_config = StateConfig(batch_size=16)
state_embed = StateEmbed(configurer=state_config)

We preprocess the data with `state_embed.process_data` and pass the result to `state_embed.get_embeddings` to obtain the final embeddings.

In [None]:
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

# note that the STATE model returns a numpy array of shape (n_cells, 1024)
print(embeddings.shape)
print(type(embeddings))

# store the embeddings in adata.obsm['state_emb']
adata.obsm['state_emb'] = embeddings
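
Each row of the stored array is one cell's embedding, so cells can be compared directly in that space. The sketch below computes pairwise cosine similarity with NumPy and finds each cell's nearest neighbour. It uses a random stand-in array in place of the real `adata.obsm['state_emb']` (the 1024 width follows the comment in the cell above), and `cosine_similarity_matrix` is a hypothetical helper, not a `helical` function.

```python
import numpy as np

def cosine_similarity_matrix(emb: np.ndarray) -> np.ndarray:
    """Return the (n_cells, n_cells) matrix of cosine similarities between rows of emb."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.clip(norms, 1e-12, None)  # guard against all-zero rows
    return unit @ unit.T

# stand-in for adata.obsm['state_emb']; swap in the real embeddings
emb = np.random.default_rng(0).normal(size=(10, 1024))
sim = cosine_similarity_matrix(emb)

# nearest neighbour of each cell, excluding the cell itself
masked = sim.copy()
np.fill_diagonal(masked, -np.inf)
nearest = masked.argmax(axis=1)
print(nearest)
```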

# STATE Perturbations

To use the perturbation model, you can either pass in embeddings by setting the `embed_key` argument of `StateConfig`, or keep the default value `None`, in which case the expression values in `adata.X` are used.

To use previously computed embeddings, `embed_key` must name an existing entry in `adata.obsm` (i.e. `adata.obsm[<embed_key>]` must exist); otherwise an error is raised.
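
The input-selection rule described above can be summarised as a small function. This only illustrates the documented behaviour and is not `helical`'s implementation: `select_model_input` is a hypothetical name, plain Python objects stand in for `adata.X` and `adata.obsm`, and the exception type is an assumption.

```python
from typing import Any, Mapping, Optional

def select_model_input(X: Any, obsm: Mapping[str, Any], embed_key: Optional[str]) -> Any:
    """Mirror the documented rule: use X when embed_key is None, else obsm[embed_key]."""
    if embed_key is None:
        return X  # expression values (adata.X)
    if embed_key not in obsm:
        # the real model also errors here; KeyError is an assumed stand-in
        raise KeyError(f"embed_key {embed_key!r} not found in adata.obsm")
    return obsm[embed_key]

X = [[0.0, 1.0], [2.0, 0.0]]          # stand-in for adata.X
obsm = {"state_emb": [[0.1], [0.2]]}  # stand-in for adata.obsm
print(select_model_input(X, obsm, None) is X)
print(select_model_input(X, obsm, "state_emb"))
```

With the embeddings computed earlier in this notebook, setting `embed_key="state_emb"` in `StateConfig` would make the perturbation model read `adata.obsm['state_emb']` instead of `adata.X`.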

Let's add some dummy perturbation labels to the data from the previous example.

In [None]:
import numpy as np

# one control and two drug perturbations, each written as "[(compound, dose, unit)]"
perturbations = [
    "[('DMSO_TF', 0.0, 'uM')]",  # control
    "[('Aspirin', 0.5, 'uM')]",
    "[('Dexamethasone', 1.0, 'uM')]",
]

rng = np.random.default_rng(0)  # fixed seed so the dummy labels are reproducible
n_cells = adata.n_obs
# assign perturbations to cells at random
adata.obs['target_gene'] = rng.choice(perturbations, size=n_cells)
# copy the cell type annotation into the column named in the config (replace 'LVL1' with your own column)
adata.obs['cell_type'] = adata.obs['LVL1']
# we can also add a batch variable to account for batch effects
batch_labels = rng.choice(['batch_1', 'batch_2', 'batch_3', 'batch_4'], size=n_cells)
adata.obs['batch_var'] = batch_labels

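config = StateConfig(
    embed_key=None,                           # use adata.X rather than stored embeddings
    pert_col="target_gene",                   # obs column holding the perturbation labels
    celltype_col="cell_type",                 # obs column holding the cell types
    control_pert="[('DMSO_TF', 0.0, 'uM')]",  # label of the control condition
    output_path="yolksac_perturbed.h5ad",     # output file path
)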
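
Each perturbation label is a string holding a Python-style list of `(compound, dose, unit)` tuples, so it can be inspected with the standard library's `ast.literal_eval`. The sketch below parses the labels, counts cells per perturbation, and checks that at least one cell carries the control label: with random assignment over 10 cells and three labels, there is roughly a 1.7% chance ((2/3)^10) that no cell gets the control. The helpers `parse_perturbation` and `has_control` are hypothetical, and treating a missing control as a problem is our assumption, not a documented `helical` requirement.

```python
import ast
from collections import Counter

def parse_perturbation(label: str) -> list:
    """Turn a label such as "[('Aspirin', 0.5, 'uM')]" into [('Aspirin', 0.5, 'uM')]."""
    return [tuple(item) for item in ast.literal_eval(label)]

def has_control(labels, control_pert: str) -> bool:
    """True if at least one cell carries the control perturbation label."""
    return Counter(labels)[control_pert] > 0

control = "[('DMSO_TF', 0.0, 'uM')]"
labels = [control, "[('Aspirin', 0.5, 'uM')]", "[('Dexamethasone', 1.0, 'uM')]", "[('Aspirin', 0.5, 'uM')]"]

for label in sorted(set(labels)):
    compound, dose, unit = parse_perturbation(label)[0]
    print(compound, dose, unit)

print(Counter(labels))  # cells per perturbation
print(has_control(labels, control))
```

In the notebook, the same check would be `has_control(adata.obs['target_gene'], "[('DMSO_TF', 0.0, 'uM')]")`, since iterating a pandas Series yields its values.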


Now we can run the perturbation model.

In [None]:
from helical.models.state import StatePerturb

state_perturb = StatePerturb(configurer=config)

# again we process the data and get the perturbed embeddings
processed_data = state_perturb.process_data(adata)
perturbed_embeds = state_perturb.get_embeddings(processed_data)

print(perturbed_embeds.shape)
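
Assuming `perturbed_embeds` has one row per cell, in the same order as `adata.obs` (the notebook does not state this, so check the printed shape first), one simple summary is the mean shift of each perturbation group relative to the control group. This is a generic NumPy sketch on toy data with a hypothetical helper, `mean_shift_vs_control`; it is not part of `helical`.

```python
import numpy as np

def mean_shift_vs_control(values: np.ndarray, labels, control: str) -> dict:
    """For each non-control label, return the column-wise mean of its rows minus the control mean.

    The control label must be present, otherwise the control mean is NaN.
    """
    labels = np.asarray(labels)
    control_mean = values[labels == control].mean(axis=0)
    return {
        str(lab): values[labels == lab].mean(axis=0) - control_mean
        for lab in np.unique(labels)
        if lab != control
    }

# toy data: 6 cells x 3 features, with labels aligned to the rows
values = np.array([
    [0.0, 0.0, 0.0], [0.0, 0.0, 2.0],  # ctrl
    [1.0, 1.0, 1.0], [3.0, 1.0, 1.0],  # A
    [5.0, 5.0, 5.0], [5.0, 5.0, 5.0],  # B
])
labels = ["ctrl", "ctrl", "A", "A", "B", "B"]
shifts = mean_shift_vs_control(values, labels, "ctrl")
print(shifts)
```

If the row-alignment assumption holds, it could be applied as `mean_shift_vs_control(np.asarray(perturbed_embeds), adata.obs['target_gene'], "[('DMSO_TF', 0.0, 'uM')]")`.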