In [None]:
import os

def get_directory_size(directory_path):
    """Return the combined size, in bytes, of all files under a directory.

    The tree is traversed recursively with ``os.walk``. A file whose size
    cannot be read (for example it vanished mid-walk, is a broken symlink, or
    is not accessible) is skipped on its own, so a single unreadable file does
    not cut the traversal short and truncate the total.

    Args:
        directory_path: Root of the directory tree to measure.

    Returns:
        Total size in bytes of every file whose size could be read. Returns 0
        when the directory does not exist or contains no readable files.
    """
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(directory_path):
        for filename in filenames:
            filepath = os.path.join(dirpath, filename)
            try:
                total_size += os.path.getsize(filepath)
            except OSError:
                # Best effort: ignore just this file and keep walking.
                continue
    return total_size

def format_size(size_bytes):
    """Render a byte count as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit ("TB") is reached, then printed with two decimals followed by the
    unit. A count of exactly zero is special-cased to "0 B".

    Args:
        size_bytes: Size in bytes.

    Returns:
        A string such as "1.50 KB" or "0 B".
    """
    if size_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB")
    value = size_bytes
    # Try every unit except the last; the last one is the catch-all below.
    for unit in units[:-1]:
        if not value >= 1024:
            return f"{value:.2f} {unit}"
        value = value / 1024.0
    return f"{value:.2f} {units[-1]}"

# Measure the working directory and print it twice: once auto-scaled by
# format_size, once as a fixed three-decimal figure divided by 1024**3.
current_dir = "."
total_size_bytes = get_directory_size(current_dir)
total_size_pretty = format_size(total_size_bytes)
total_size_gib = total_size_bytes / (1024**3)
print(f"Total directory size: {total_size_pretty}")
print(f"Total directory size: {total_size_gib:.3f} GB")

# Same measurement for one named sub-directory (example).
competition_dir = "competition_support_set"
competition_size = get_directory_size(competition_dir)
competition_size_pretty = format_size(competition_size)
print(f"Competition directory size: {competition_size_pretty}")

In [None]:
import os 
# Print the on-disk size of a saved weights file, scaled by 1024**3.
# NOTE(review): this absolute path is specific to one machine; the cell fails
# anywhere else. Consider making it relative or configurable.
weights_path = "/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/nn/embed_model_epoch16_weights.pt"
file_size_bytes = os.path.getsize(weights_path)
file_size_gb = file_size_bytes / (1024**3)
print(f"Original file size: {file_size_gb:.3f} GB")

In [None]:
import scanpy as sc
import os

# Input file, output file and number of cells to keep, named once so the
# read / size-check / write steps cannot drift apart.
source_path = "competition_support_set/competition_val_template.h5ad"
truncated_path = "state_test.h5ad"
n_cells_to_keep = 100

# Read the dataset
adata = sc.read_h5ad(source_path)

# Get original size info
print(f"Original shape: {adata.shape}")
file_size_bytes = os.path.getsize(source_path)
file_size_gb = file_size_bytes / (1024**3)
print(f"Original file size: {file_size_gb:.3f} GB")

# Truncate the dataset (adjust `n_cells_to_keep` as needed).
# Only the first axis (cells) is sliced; every gene is kept. The result is
# copied so it is an independent object rather than a view of `adata`.
truncated_adata = adata[:n_cells_to_keep].copy()

print(f"Truncated shape: {truncated_adata.shape}")

# Save the truncated dataset
truncated_adata.write_h5ad(truncated_path)

# Check the new file size and report the relative reduction
new_file_size_bytes = os.path.getsize(truncated_path)
new_file_size_gb = new_file_size_bytes / (1024**3)
print(f"New file size: {new_file_size_gb:.3f} GB")
print(f"Size reduction: {(1 - new_file_size_gb/file_size_gb)*100:.1f}%")

# Show some basic info about the truncated dataset
print("\nTruncated dataset info:")
print(f"Number of cells: {truncated_adata.n_obs}")
print(f"Number of genes: {truncated_adata.n_vars}")
print(f"Obs columns: {list(truncated_adata.obs.columns)}")
print(f"Var columns: {list(truncated_adata.var.columns)}")

# Generating Example Data

We use data from the Virtual Cell Challenge for model training, inference, and finetuning.

In [None]:
'''
Download the dataset

(taken from Colab Notebook by Adduri et al.
https://colab.research.google.com/drive/1QKOtYP7bMpdgDJEipDxaJqOchv7oQ-_l#scrollTo=h0aSjKX7Rtyw)
'''

import requests
from tqdm.auto import tqdm  # picks the best bar for the environment
from zipfile import ZipFile
from tqdm.auto import tqdm
import os

# Download the Replogle-Nadig training dataset.
url = "https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip"
output_path = "competition_support_set.zip"

# Stream the download so progress can be tracked without holding the whole
# archive in memory. The `with` block releases the connection when done, and
# the timeout stops the cell from hanging forever on a stalled connection.
with requests.get(url, stream=True, timeout=60) as response:
    # Fail loudly on an HTTP error instead of saving the error body as a ".zip".
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))

    with open(output_path, "wb") as f, tqdm(
        total=total, unit='B', unit_scale=True, desc="Downloading"
    ) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            if not chunk:
                # Skip an empty chunk rather than ending the download early.
                continue
            f.write(chunk)
            bar.update(len(chunk))

# Extract every archive member into `out_dir`, one progress tick per file.
out_dir  = "competition_support_set"
os.makedirs(out_dir, exist_ok=True)
with ZipFile(output_path, 'r') as z:
    for member in tqdm(z.infolist(), desc="Unzipping", unit="file"):
        z.extract(member, out_dir)

# State embeddings

In [None]:
from helical.models.state import stateEmbed
from helical.models.state import stateConfig

In [None]:
# from helical.models.state import stateEmbed
# from helical.models.state import stateConfig
# from helical.models.state import stateEmbedTorch
# import scanpy as sc

# state_config = stateConfig()
# state_embed = stateEmbed(configurer=state_config)


# state_config2 = stateConfig()
# state_embed_torch = stateEmbedTorch(configurer=state_config2)

INFO:datasets:PyTorch version 2.6.0 available.
INFO:datasets:Polars version 1.33.0 available.
INFO:helical.models.state.state_embeddings:Using model checkpoint: /home/rasched/.cache/helical/models/state/state_embed/se600m_epoch16.ckpt
INFO:helical.models.state.state_embeddings_torch:Using model checkpoint: /home/rasched/.cache/helical/models/state/state_embed/se600m_model_weights.pt
INFO:helical.models.state.state_embeddings_torch:number of free parameters: 603245738
INFO:helical.models.state.state_embeddings_torch:Missing keys: []
INFO:helical.models.state.state_embeddings_torch:Successfully loaded model


In [None]:
# import torch

# def compare_models(model1, model2):
#     params1 = list(model1.parameters())
#     params2 = list(model2.parameters())
    
#     print(f"Model 1 has {len(params1)} parameters")
#     print(f"Model 2 has {len(params2)} parameters")
    
#     for i, (p1, p2) in enumerate(zip(params1, params2)):
#         if not torch.allclose(p1, p2, atol=1e-6):
#             print(f"Parameter {i} differs! Max diff: {torch.max(torch.abs(p1 - p2)).item()}")
#             return False
#     print("All parameters are identical!")
#     return True

# compare_models(state_embed.model, state_embed_torch.model)
# print(state_embed.model.training, state_embed_torch.model.training)


# adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")
# adata = adata[:2].copy()

# processed_data1 = state_embed.process_data(adata=adata.copy())
# batch = next(iter(processed_data1))
# print(batch)


Model 1 has 234 parameters
Model 2 has 234 parameters
All parameters are identical!
False False


INFO:helical.models.state.state_embeddings:Auto-detected gene column: var.index (overlap: 17773/19790 protein embeddings, 98.3% of genes)
INFO:/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/data/loader.py:17773 genes mapped to embedding file (out of 18080)
INFO:/home/rasched/final_helical_with_state/helical/helical/models/state/model_dir/embed_utils/data/loader.py:17773 genes mapped to embedding file (out of 18080)


(tensor([[12646,   205,  2577,  ...,  5543, 11824, 13012],
        [12646,   205,  4770,  ..., 18803,   367, 11218]], dtype=torch.int32), tensor([[15287,  4376,  5904,  ...,  5158, 19639, 12565],
        [ 1413, 14521, 13956,  ...,  5158, 19639, 12565]], dtype=torch.int32), tensor([[0.6914, 0.6914, 1.6094,  ..., 2.9375, 0.0000, 0.0000],
        [3.3281, 1.6094, 1.6094,  ..., 1.6094, 0.0000, 0.0000]],
       dtype=torch.bfloat16), tensor([0, 1], dtype=torch.int32), tensor([[0.0000, 0.6914, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 1.1016, 1.3828,  ..., 0.0000, 0.0000, 0.0000]],
       dtype=torch.bfloat16), tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), tensor([1376., 1456.], dtype=torch.bfloat16), tensor([[0.0000, 0.0403, 0.0344,  ..., 0.0148, 0.0148, 0.0148],
        [0.0040, 0.0369, 0.0322,  ..., 0.0145, 0.0145, 0.0145]],
       dtype=torch.bfloat16), tensor([0, 0], dtype=torch.int32))


In [None]:
# with torch.no_grad():
#     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#         _, _, _, emb1, _ = state_embed.model._compute_embedding_for_batch(batch)
#         _, _, _, emb2, _ = state_embed_torch.model._compute_embedding_for_batch(batch)
    
#     print(f"Model 1 embedding: {emb1[0, :5]}")
#     print(f"Model 2 embedding: {emb2[0, :5]}")
    
#     # Compare
#     diff = torch.abs(emb1 - emb2).sum()
#     print(f"Difference: {torch.max(diff).item()}")

Model 1 embedding: tensor([ 0.0048,  0.0108,  0.0209, -0.0180,  0.0364], device='cuda:0')
Model 2 embedding: tensor([ 0.0048,  0.0108,  0.0209, -0.0180,  0.0364], device='cuda:0')
Difference: 0.0


In [None]:
# Extract weights from checkpoint
# import torch
# checkpoint = state_embed.checkpoint['state_dict']
# state_dict = state_embed.checkpoint['state_dict']
# torch.save(state_dict, "embed_model_epoch16_weights.pt")

# Training the Model

Downloading the example data above creates the `competition_support_set` directory. Before training, edit the `competition_support_set/starter.toml` file so that it points to the correct path on your machine. You can also change the train/test split there, but feel free to leave it at the default.

In [None]:
# train the model on the training data
from helical.models.state import stateTransitionTrainModel
from helical.models.state.train_configs import trainingConfig

# Training settings collected in one mapping and expanded into keyword
# arguments for trainingConfig below.
train_settings = {
    "output_dir": "competition",
    "name": "first_run",
    "toml_config_path": "competition_support_set/starter.toml",
    "checkpoint_name": "final.ckpt",
    "max_steps": 40000,
    "max_epochs": 1,
    "ckpt_every_n_steps": 20000,
    "num_workers": 4,
    "batch_col": "batch_var",
    "pert_col": "target_gene",
    "cell_type_key": "cell_type",
    "control_pert": "non-targeting",
    "perturbation_features_file": "competition_support_set/ESM2_pert_features.pt",
}
train_config = trainingConfig(**train_settings)

# Build the trainer from the config, then call train() followed by predict().
state_train = stateTransitionTrainModel(configurer=train_config)
state_train.train()
state_train.predict()

Once the model is trained we can perform inference on a new dataset using:

In [None]:
from helical.models.state import stateTransitionModel

# Inference settings, expanded into keyword arguments for stateConfig.
# `stateConfig` and `sc` (scanpy) are not imported in this cell; they are
# imported by earlier cells of this notebook.
# NOTE(review): `model_config` is given as "model_dir/config.yaml" while
# `model_dir` is "competition/first_run" — confirm this is the intended path.
inference_settings = dict(
    output="competition/prediction.h5ad",
    model_dir="competition/first_run",
    model_config="model_dir/config.yaml",
    pert_col="target_gene",
)
state_config = stateConfig(**inference_settings)

# Load the validation template that predictions are generated for.
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Build the model, run its preprocessing on the data, then get the embeddings.
state_transition = stateTransitionModel(configurer=state_config)
adata = state_transition.process_data(adata)
embeds = state_transition.get_embeddings(adata)

# Creating a Virtual Cell Challenge Submission

To create a submission for the Virtual Cell Challenge we pass our previous prediction file into `helical.models.state.vcc_eval` which uses the `cell-eval` package. This will generate a `.vcc` file that can be uploaded to the public leaderboard.

In [None]:
# evaluate the model - underlying function uses cell-eval package 
# (https://github.com/ArcInstitute/cell-eval)
from helical.models.state import vcc_eval

# Default values for the competition dataset.
# DEFAULT_CTRL and DEFAULT_COUNTS_COL are defined here but are not referenced
# by the `configs` mapping below.
EXPECTED_GENE_DIM = 18080
MAX_CELL_DIM = 100000
DEFAULT_PERT_COL = "target_gene"
DEFAULT_CTRL = "non-targeting"
DEFAULT_COUNTS_COL = "n_cells"
DEFAULT_CELLTYPE_COL = "celltype"
DEFAULT_NTC_NAME = "non-targeting"

# Options handed to vcc_eval as a single mapping.
configs = dict(
    # path to the prediction file
    input="competition/prediction.h5ad",
    # path to the gene names file
    genes="competition_support_set/gene_names.csv",
    # path to the output file - if None will be created with default naming
    output=None,
    pert_col=DEFAULT_PERT_COL,
    celltype_col=None,
    ntc_name=DEFAULT_NTC_NAME,
    output_pert_col=DEFAULT_PERT_COL,
    output_celltype_col=DEFAULT_CELLTYPE_COL,
    encoding=32,
    allow_discrete=False,
    expected_gene_dim=EXPECTED_GENE_DIM,
    max_cell_dim=MAX_CELL_DIM,
)

# this creates a submission file in the output directory which can be uploaded to the challenge leaderboard
vcc_eval(configs)

# Finetuning Example

We can further finetune the model. This can start with a model directory containing a `.ckpt` file. If the model directory does not exist, one will be created, along with a new model instance. Finetuning produces model weights and head weights that are saved to the directory; if these `.pt` files are already present, training resumes from them.

In [None]:
from helical.models.state import stateFineTuningModel

# Load the desired dataset
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Get the desired label class: one label per cell, plus the distinct labels
cell_types = list(adata.obs.cell_type)
label_set = set(cell_types)

# Create the fine-tuning model (no need to specify var_dims location)
config = stateConfig(
    batch_size=8,
    model_dir="competition/first_run",
    model_config="model_dir/config.yaml",
    freeze_backbone=True
)

# One output unit per distinct label.
model = stateFineTuningModel(
    configurer=config, 
    fine_tuning_head="classification", 
    output_size=len(label_set),
)

# Process the data for training
data = model.process_data(adata)

# Create a dictionary mapping the classes to unique integers for training.
# The labels are sorted first so the mapping is the same on every run:
# iterating a set of strings directly can give a different order in each
# interpreter session, which would silently permute the class ids when
# fine-tuning resumes from previously saved head weights.
# key=str keeps the sort well-defined even if a label is not a string.
class_names = sorted(label_set, key=str)
class_id_dict = {name: class_id for class_id, name in enumerate(class_names)}

# Replace every per-cell label with its integer class id.
cell_types = [class_id_dict[label] for label in cell_types]

# Fine-tune
model.train(train_input_data=data, train_labels=cell_types)