In [None]:
import os

def get_directory_size(directory_path):
    """Get the total size of a directory in bytes.

    The tree rooted at ``directory_path`` is walked recursively with
    ``os.walk`` and the sizes reported by ``os.path.getsize`` are summed.

    Args:
        directory_path: Path of the directory to measure.

    Returns:
        int: Sum of the sizes of all readable files, in bytes. Files whose
        size cannot be read (broken symlinks, files removed during the walk,
        permission errors) are skipped individually. A path that does not
        exist yields 0, because ``os.walk`` produces nothing for it.
    """
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(directory_path):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            try:
                total_size += os.path.getsize(filepath)
            except OSError:
                # Skip just this entry. Catching the error around the whole
                # walk would discard every file that comes after it.
                continue
    return total_size

def format_size(size_bytes):
    """Render a byte count as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TB) is reached, then formatted with two decimals. An input equal
    to zero is returned as the literal ``"0 B"``.
    """
    if size_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB")
    largest = units[-1]
    value = size_bytes
    unit = units[0]
    for unit in units:
        # Stop at the first unit where the value fits, or at the last unit.
        if not value >= 1024 or unit is largest:
            break
        value /= 1024.0
    return f"{value:.2f} {unit}"

# Directories to measure: the working directory and the example-data folder.
current_dir = "."
competition_dir = "competition_support_set"

# Measure both first, then report.
total_size_bytes = get_directory_size(current_dir)
competition_size = get_directory_size(competition_dir)

# Working-directory total, first auto-scaled and then in units of 1024**3 bytes.
print(f"Total directory size: {format_size(total_size_bytes)}")
print(f"Total directory size: {total_size_bytes / (1024**3):.3f} GB")

# Size of the specific (example) directory.
print(f"Competition directory size: {format_size(competition_size)}")

In [12]:
import os 
# Report the on-disk size of the exported weights file, in units of 1024**3 bytes.
# NOTE: the path below is an absolute, machine-specific location.
file_size_bytes = os.stat("/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/nn/embed_model_epoch16_weights.pt").st_size
file_size_gb = file_size_bytes / 1024**3
print(f"Original file size: {file_size_gb:.3f} GB")

Original file size: 2.625 GB


In [None]:
import scanpy as sc
import os

# Input file, output file and number of cells to keep, defined once so the
# read and the size checks cannot drift apart.
source_path = "competition_support_set/competition_val_template.h5ad"
truncated_path = "state_test.h5ad"
n_cells_to_keep = 100

# Read the dataset
adata = sc.read_h5ad(source_path)

# Get original size info
print(f"Original shape: {adata.shape}")
file_size_bytes = os.path.getsize(source_path)
file_size_gb = file_size_bytes / (1024**3)
print(f"Original file size: {file_size_gb:.3f} GB")

# Truncate the dataset (adjust `n_cells_to_keep` as needed).
# Only the first `n_cells_to_keep` cells are kept; genes are NOT subset,
# so the truncated object retains every gene of the original.
truncated_adata = adata[:n_cells_to_keep].copy()

print(f"Truncated shape: {truncated_adata.shape}")

# Save the truncated dataset
truncated_adata.write_h5ad(truncated_path)

# Check the new file size
new_file_size_bytes = os.path.getsize(truncated_path)
new_file_size_gb = new_file_size_bytes / (1024**3)
print(f"New file size: {new_file_size_gb:.3f} GB")
# The ratio is taken on the raw byte counts.
print(f"Size reduction: {(1 - new_file_size_bytes/file_size_bytes)*100:.1f}%")

# Show some basic info about the truncated dataset
print("\nTruncated dataset info:")
print(f"Number of cells: {truncated_adata.n_obs}")
print(f"Number of genes: {truncated_adata.n_vars}")
print(f"Obs columns: {list(truncated_adata.obs.columns)}")
print(f"Var columns: {list(truncated_adata.var.columns)}")

# Generating Example Data

We use data from the Virtual Cell Challenge for model training, inference, and finetuning.

In [None]:
'''
Download the dataset

(taken from Colab Notebook by Adduri et al.
https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l#scrollTo=h0aSjKX7Rtyw)
'''

import requests
from tqdm.auto import tqdm  # picks the best bar for the environment
from zipfile import ZipFile
from tqdm.auto import tqdm
import os

# Download the Replogle-Nadig training dataset.
url = "https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip"
output_path = "competition_support_set.zip"

# Stream the download so we can track progress. The response is used as a
# context manager so the connection is released even if writing fails, and
# the timeout keeps a stalled connection from hanging the cell indefinitely.
with requests.get(url, stream=True, timeout=60) as response:
    # Fail loudly on 4xx/5xx instead of saving an error page as the archive.
    response.raise_for_status()
    # A missing content-length header gives total=0.
    total = int(response.headers.get("content-length", 0))

    with open(output_path, "wb") as f:
        with tqdm(total=total, unit='B', unit_scale=True, desc="Downloading") as bar:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:
                    # Skip empty chunks; stopping here could truncate the file.
                    continue
                f.write(chunk)
                bar.update(len(chunk))

# Extract every archive member into the example-data directory.
out_dir = "competition_support_set"
os.makedirs(out_dir, exist_ok=True)
with ZipFile(output_path, 'r') as z:
    for member in tqdm(z.infolist(), desc="Unzipping", unit="file"):
        z.extract(member, out_dir)

# State embeddings

We can generate embeddings using the STATE Embedding model in the helical package.

In [1]:
from helical.models.state import stateEmbed
from helical.models.state import stateConfig
from helical.models.state import stateEmbedTorch

# Disabled alternative: pass an explicit `embed_checkpoint` (here a safetensors
# file) instead of relying on the configuration defaults.
# state_config = stateConfig(embed_checkpoint="model.safetensors")
# Build the configuration with its default values, then construct the embedding
# model from it. `state_embed` is reused by later cells.
state_config = stateConfig()
state_embed = stateEmbed(configurer=state_config)

INFO:datasets:PyTorch version 2.6.0 available.
INFO:datasets:Polars version 1.33.0 available.
INFO:helical.models.state.state_embeddings:Using model checkpoint: /home/rasched/.cache/helical/models/state/state_embed/se600m_epoch16.ckpt


"cfg":          {'dataset': {'N': 512, 'P': 512, 'S': 512, 'cellxgene': {'ds_type': 'h5ad', 'filter': False, 'num_datasets': 1139, 'train': '/large_storage/ctc/userspace/aadduri/data/auxillary/esm_cellxgene_train.csv', 'val': '/large_storage/ctc/userspace/aadduri/data/auxillary/esm_cellxgene_val.csv'}, 'cellxgene-tahoe': {'ds_type': 'filtered_h5ad', 'filter': True, 'filter_by_species': None, 'num_datasets': 1139, 'train': '/large_storage/ctc/userspace/aadduri/data/auxillary/esm_tahoe_cellxgene_train_filtered.csv', 'val': '/large_storage/ctc/userspace/aadduri/data/auxillary/esm_tahoe_cellxgene_val_filtered.csv'}, 'chrom_token_right_idx': 2, 'cls_token_idx': 3, 'current': 'scbasecamp-cellxgene-tahoe-filtered', 'name': 'vci', 'num_cells': 36238464, 'num_train_workers': 32, 'num_val_workers': 8, 'overrides': {'rpe1_top5000_variable': '/large_storage/ctc/datasets/vci/validation/rpe1_top5000_variable.h5ad'}, 'pad_length': 2048, 'pad_token_idx': 0, 'scbasecamp-cellxgene-tahoe': {'ds_type': 'f

In [None]:
# Get the model cfg: read the `cfg` entry stored on the loaded model's `hparams`.
# Requires `state_embed` from the embedding-model cell above.
cfg = state_embed.model.hparams.cfg

In [10]:
import torch

# Take the parameter dictionary out of the loaded checkpoint. Both names are
# bound to the same object.
state_dict = state_embed.checkpoint['state_dict']
checkpoint = state_dict

# List the parameter names, then write the weights to the working directory.
print(checkpoint.keys())
torch.save(state_dict, "embed_model_epoch16_weights.pt")

odict_keys(['cls_token', 'dataset_token', 'encoder.0.weight', 'encoder.0.bias', 'encoder.1.weight', 'encoder.1.bias', 'transformer_encoder.layers.0.qkv_proj.weight', 'transformer_encoder.layers.0.qkv_proj.bias', 'transformer_encoder.layers.0.out_proj.weight', 'transformer_encoder.layers.0.out_proj.bias', 'transformer_encoder.layers.0.norm1.weight', 'transformer_encoder.layers.0.norm1.bias', 'transformer_encoder.layers.0.norm2.weight', 'transformer_encoder.layers.0.norm2.bias', 'transformer_encoder.layers.0.linear1.weight', 'transformer_encoder.layers.0.linear1.bias', 'transformer_encoder.layers.0.linear2.weight', 'transformer_encoder.layers.0.linear2.bias', 'transformer_encoder.layers.1.qkv_proj.weight', 'transformer_encoder.layers.1.qkv_proj.bias', 'transformer_encoder.layers.1.out_proj.weight', 'transformer_encoder.layers.1.out_proj.bias', 'transformer_encoder.layers.1.norm1.weight', 'transformer_encoder.layers.1.norm1.bias', 'transformer_encoder.layers.1.norm2.weight', 'transformer_

In [6]:
from safetensors.torch import load_file
import os

# Resolve the cached safetensors file relative to the current user's home
# directory rather than one specific account's absolute path, so the cell
# also runs on other machines.
safetensors_state_dict = load_file(
    os.path.expanduser("~/.cache/helical/models/state/state_embed/model.safetensors")
)
# List the tensor names stored in the file.
print(safetensors_state_dict.keys())

dict_keys(['bin_encoder.weight', 'binary_decoder.0.dense.bias', 'binary_decoder.0.dense.weight', 'binary_decoder.0.intermediate_dense.bias', 'binary_decoder.0.intermediate_dense.weight', 'binary_decoder.0.layer_norm.bias', 'binary_decoder.0.layer_norm.weight', 'binary_decoder.1.dense.bias', 'binary_decoder.1.dense.weight', 'binary_decoder.1.intermediate_dense.bias', 'binary_decoder.1.intermediate_dense.weight', 'binary_decoder.1.layer_norm.bias', 'binary_decoder.1.layer_norm.weight', 'binary_decoder.2.bias', 'binary_decoder.2.weight', 'cls_token', 'count_encoder.0.bias', 'count_encoder.0.weight', 'count_encoder.2.bias', 'count_encoder.2.weight', 'dataset_embedder.bias', 'dataset_embedder.weight', 'dataset_encoder.0.bias', 'dataset_encoder.0.weight', 'dataset_encoder.2.bias', 'dataset_encoder.2.weight', 'dataset_encoder.4.bias', 'dataset_encoder.4.weight', 'dataset_token', 'decoder.0.dense.bias', 'decoder.0.dense.weight', 'decoder.0.intermediate_dense.bias', 'decoder.0.intermediate_dens

In [None]:
import torch
from safetensors.torch import load_file

import os

# Load both sets of weights: the state dict of the in-memory checkpoint and
# the tensors stored in the cached safetensors file. The cache path is resolved
# from the current user's home directory instead of a hard-coded account path.
ckpt_state_dict = state_embed.checkpoint['state_dict']
safetensors_state_dict = load_file(
    os.path.expanduser("~/.cache/helical/models/state/state_embed/model.safetensors")
)

# Compare parameter names only (no tensor values are inspected here).
print("Keys match:", set(ckpt_state_dict.keys()) == set(safetensors_state_dict.keys()))

# Optional per-tensor comparison for keys present in both, left disabled:
# all_match = True
# for key in ckpt_state_dict.keys():
#     if key in safetensors_state_dict:
#         shape_match = ckpt_state_dict[key].shape == safetensors_state_dict[key].shape
#         values_match = torch.allclose(ckpt_state_dict[key], safetensors_state_dict[key], rtol=1e-5)
#         print(f"{key}: shapes={shape_match}, values={values_match}")
#         if not (shape_match and values_match):
#             all_match = False

# print(f"\nAll tensors match: {all_match}")

Keys match: False


In [None]:
import scanpy as sc

# Load the validation template and keep a copy of only its first two rows,
# so the embedding call below stays quick.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")[:2].copy()

# Preprocess the subset, then compute embeddings from the processed data.
processed_data = state_embed.process_data(adata=adata)
embeddings = state_embed.get_embeddings(processed_data)

print(embeddings.shape)

# Training the Model

Downloading the example data creates the `competition_support_set` directory. Before training, edit the `competition_support_set/starter.toml` file so that it points to the correct paths on your machine. You can also change the train/test split there, but feel free to leave it at the default.

In [None]:
# train the model on the training data
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

train_config = trainingConfig(
    # Run identity and output files
    name="first_run",
    output_dir="competition",
    checkpoint_name="final.ckpt",
    # Input files from the downloaded support set
    toml_config_path="competition_support_set/starter.toml",
    perturbation_features_file="competition_support_set/ESM2_pert_features.pt",
    # Training length, checkpoint cadence and loader workers
    max_steps=40000,
    max_epochs=1,
    ckpt_every_n_steps=20000,
    num_workers=4,
    # Column names / labels (presumably fields of the dataset's obs table; verify)
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
)

# Build the trainer from the configuration, fit it, then run prediction.
state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train()
state_train.predict()

Once the model is trained we can perform inference on a new dataset using:

In [None]:
from helical.models.state import stateTransitionModel

# Configuration for the trained run. Relies on `stateConfig` and `sc` being
# imported by earlier cells.
state_config = stateConfig(
    model_dir="competition/first_run",
    model_config="model_dir/config.yaml",
    pert_col="target_gene",
    output="competition/prediction.h5ad",
)
state_transition = stateTransitionModel(configurer=state_config)

# Load the dataset, preprocess it with the model, then compute its output.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")
adata = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(adata)

# Creating a Virtual Cell Challenge Submission

To create a submission for the Virtual Cell Challenge we pass our previous prediction file into `helical.models.state.vcc_eval` which uses the `cell-eval` package. This will generate a `.vcc` file that can be uploaded to the public leaderboard.

In [None]:
# evaluate the model - underlying function uses cell-eval package 
# (https://github.com/ArcInstitute/cell-eval)
from helical.models.state import vcc_eval

# Default settings for the competition dataset.
# DEFAULT_CTRL and DEFAULT_COUNTS_COL are defined but not used in this cell.
EXPECTED_GENE_DIM = 18_080
MAX_CELL_DIM = 100_000
DEFAULT_PERT_COL = "target_gene"
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

configs = dict(
    # path to the prediction file
    input="competition/prediction.h5ad",
    # path to the gene names file
    genes="competition_support_set/gene_names.csv",
    # path to the output file - if None will be created with default naming
    output=None,
    pert_col=DEFAULT_PERT_COL,
    celltype_col=None,
    ntc_name=DEFAULT_NTC_NAME,
    output_pert_col=DEFAULT_PERT_COL,
    output_celltype_col=DEFAULT_CELLTYPE_COL,
    encoding=32,
    allow_discrete=False,
    expected_gene_dim=EXPECTED_GENE_DIM,
    max_cell_dim=MAX_CELL_DIM,
)

# this creates a submission file in the output directory which can be uploaded to the challenge leaderboard
vcc_eval(configs)

# Finetuning Example

We can further finetune the model. Finetuning can start from a model directory containing a `.ckpt` file. If the model directory does not exist, it will be created along with a new model instance. Finetuning produces model weights and head weights that are saved as `.pt` files in that directory, and training resumes from these files if they are present.

In [None]:
from helical.models.state import stateFineTuningModel

# Load the desired dataset
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Get the desired label class: one label per cell, plus the distinct labels
cell_types = list(adata.obs.cell_type)
label_set = set(cell_types)

# Create the fine-tuning model (no need to specify var_dims location)
config = stateConfig(
    batch_size=8,
    model_dir="competition/first_run",
    model_config="model_dir/config.yaml",
    freeze_backbone=True
)

model = stateFineTuningModel(
    configurer=config,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

# Process the data for training
data = model.process_data(adata)

# Create a dictionary mapping the classes to unique integers for training.
# The labels are sorted first so the mapping is identical on every run. Set
# iteration order is not stable between interpreter sessions, so an unsorted
# mapping could assign different ids to the same class when training resumes
# from previously saved head weights.
class_id_dict = {label: idx for idx, label in enumerate(sorted(label_set, key=str))}

# Replace each cell's label with its integer class id.
cell_types = [class_id_dict[label] for label in cell_types]

# Fine-tune
model.train(train_input_data=data, train_labels=cell_types)