# ctv2 Value-Masked SSL Demo

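This notebook shows how to train the PyTorch/Accelerate implementation of ctv2's value-masked self-supervised learning (SSL) objective on AnnData-backed datasets. If you do not have a cohort at hand, it builds a small random AnnData object so the whole pipeline runs end to end.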

In [1]:
from pathlib import Path
import sys

# Make the in-repo `ctv2` package importable without installing it.
# Walk up from the working directory to the first folder that contains the
# `ctv2` package (or module). If none is found, fall back to the parent
# directory, which assumes the notebook lives one level below the repo root.
_cwd = Path.cwd().resolve()
repo_root = next(
    (
        candidate
        for candidate in (_cwd, *_cwd.parents)
        if (candidate / 'ctv2').is_dir() or (candidate / 'ctv2.py').is_file()
    ),
    _cwd.parent,  # unlike `parents[0]`, this does not raise at the filesystem root
)
# Prepend rather than append so the local checkout wins over any installed copy.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

In [None]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
# The extra index URL adds PyTorch's CPU-only wheel index as a package source.
# NOTE(review): versions are unpinned; pin them (e.g. `torch==X.Y.Z`) for reproducible runs.
# If anything was newly installed or upgraded, restart the kernel before running the next cells.
%pip install --quiet anndata accelerate torch --extra-index-url https://download.pytorch.org/whl/cpu

In [None]:
import anndata as ad
import numpy as np
import pandas as pd

In [None]:
# Create a toy AnnData object if you do not have a cohort at hand.
# Seed every RNG up front so Restart & Run All reproduces the same run.
# accelerate's `set_seed` seeds Python's `random`, NumPy's global RNG and
# torch, which covers later draws such as weight initialisation, DataLoader
# shuffling and dropout. Keep this call even if you swap in a real cohort.
from accelerate.utils import set_seed

SEED = 0
set_seed(SEED)

num_samples = 64
num_features = 50
rng = np.random.default_rng(SEED)
data = rng.standard_normal((num_samples, num_features), dtype=np.float32)
var_names = [f'feat_{i}' for i in range(num_features)]
adata = ad.AnnData(X=data, var=pd.DataFrame(index=var_names))
assert adata.shape == (num_samples, num_features)

In [None]:
from torch.utils.data import DataLoader
from accelerate import Accelerator

In [None]:
from ctv2 import (
    FeatureTokenizer,
    ValueMaskedAnnDataDataset,
    value_mask_collate,
    ValueMaskedConfig,
    ValueMaskedTransformer,
    ValueMaskedTrainer,
)

In [None]:
# Knobs for building value-masked training sequences.
MAX_FEATURES = 48    # passed as `max_features` (the toy AnnData above has 50 features)
MASK_FRACTION = 0.2  # passed as `mask_fraction`, i.e. the share of values to mask
BATCH_SIZE = 16

tokenizer = FeatureTokenizer(adata.var_names)
dataset = ValueMaskedAnnDataDataset(
    adata,
    tokenizer,
    mask_fraction=MASK_FRACTION,
    max_features=MAX_FEATURES,
)
# `value_mask_collate` batches the dataset's items; shuffle each epoch.
loader = DataLoader(
    dataset,
    collate_fn=value_mask_collate,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Transformer size hyperparameters, kept together so they are easy to tune.
model_hparams = dict(
    d_model=128,
    num_heads=4,
    num_layers=4,
    d_ff=256,
    dropout=0.1,
)
# Vocabulary size and sequence length are taken from the tokenizer and
# dataset so the config agrees with the data the loader will produce.
config = ValueMaskedConfig(
    max_seq_len=dataset.seq_len,
    vocab_size=tokenizer.vocab_size,
    **model_hparams,
)
model = ValueMaskedTransformer(config, pad_token_id=tokenizer.pad_id)

In [None]:
import torch

LEARNING_RATE = 3e-4

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
NUM_EPOCHS = 5

# The trainer receives the Accelerator explicitly; `fit` returns the
# training history (iterated later for `.epoch` / `.loss`).
accelerator = Accelerator()
trainer = ValueMaskedTrainer(model, accelerator=accelerator)
history = trainer.fit(loader, optimizer, epochs=NUM_EPOCHS)

In [None]:
# Tabulate loss per epoch; a DataFrame renders as a labeled table in Jupyter
# instead of a bare list of tuples.
pd.DataFrame(
    [(state.epoch, state.loss) for state in history],
    columns=["epoch", "loss"],
)