# CTv2 Value-Masked SSL Demo

This notebook demonstrates how to run the PyTorch/Accelerate implementation of the value-masked self-supervised objective using AnnData-backed datasets.

In [1]:
from pathlib import Path
import sys
# Make the directory above the notebook's working directory importable
# (presumably the repository root that holds the local packages — confirm
# against where this notebook lives).
# `.parent` is used rather than `.parents[0]`: the two agree everywhere except
# at a filesystem root, where `.parents` is empty and indexing it raises
# IndexError, while `.parent` simply returns the root itself.
repo_root = Path.cwd().resolve().parent
# Guard against duplicates so re-running the cell does not grow sys.path.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

Found existing installation: ClinicalTranformer 0.0.1
Uninstalling ClinicalTranformer-0.0.1:
  Successfully uninstalled ClinicalTranformer-0.0.1
Collecting git+https://github.com/gaarangoa/ClinicalTransformerV1.git
  Cloning https://github.com/gaarangoa/ClinicalTransformerV1.git to /private/var/folders/8g/qx8n9nfx1339jqj88bc8fv000000gq/T/pip-req-build-hxgppjig
  Running command git clone --filter=blob:none --quiet https://github.com/gaarangoa/ClinicalTransformerV1.git /private/var/folders/8g/qx8n9nfx1339jqj88bc8fv000000gq/T/pip-req-build-hxgppjig
  Resolved https://github.com/gaarangoa/ClinicalTransformerV1.git to commit a5394dd613111b63c48e22d50c9e741ffa08bb1a
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: ClinicalTranformer
[33m  DEPRECATION: Building 'ClinicalTranformer' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to us

In [2]:
# Install the notebook's third-party dependencies quietly; the extra index adds
# PyTorch's CPU-only wheel index as an additional package source.
# NOTE(review): versions are unpinned, so the resolved packages can drift over
# time — consider pinning (pkg==X.Y.Z) once a known-good set is established.
%pip install --quiet anndata accelerate torch transformers --extra-index-url https://download.pytorch.org/whl/cpu

Note: you may need to restart the kernel to use updated packages.


In [3]:
import anndata as ad
import numpy as np
import pandas as pd

In [4]:
# Create a toy AnnData object if you do not have a cohort at hand.
# The values come from a seeded generator so the synthetic matrix is identical
# on every Restart & Run All; change SEED to draw a different toy cohort.
SEED = 0
num_samples = 64
num_features = 50
rng = np.random.default_rng(SEED)
# Standard-normal draws (same distribution as the previous np.random.randn
# call), cast to float32 as before.
data = rng.standard_normal((num_samples, num_features)).astype(np.float32)
var_names = [f'feat_{i}' for i in range(num_features)]
adata = ad.AnnData(X=data, var=pd.DataFrame(index=var_names))

In [5]:
from torch.utils.data import DataLoader
from accelerate import Accelerator

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from ctv2 import (
    FeatureTokenizer,
    ValueMaskedAnnDataDataset,
    value_mask_collate,
    ValueMaskedConfig,
    ValueMaskedTransformer,
    ValueMaskedTrainer,
)

In [7]:
# Tokenizer is constructed from the AnnData variable (feature) names.
tokenizer = FeatureTokenizer(adata.var_names)

# Value-masked dataset over `adata`; `max_features` and `mask_fraction` are
# passed straight through to ValueMaskedAnnDataDataset (see ctv2 for their
# exact semantics).
dataset = ValueMaskedAnnDataDataset(adata, tokenizer, max_features=48, mask_fraction=0.2)

# Shuffled batches of 16 items, assembled by ctv2's collate function.
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=value_mask_collate)

In [8]:
# Model hyper-parameters for the demo. Vocabulary size and maximum sequence
# length are read from the tokenizer and dataset built above; the remaining
# values are fixed literals.
config = ValueMaskedConfig(
    vocab_size=tokenizer.vocab_size, max_seq_len=dataset.seq_len,
    d_model=128, num_heads=4, num_layers=4,
    d_ff=256, dropout=0.1,
)

# The tokenizer's pad id is forwarded to the model as `pad_token_id`.
model = ValueMaskedTransformer(config, pad_token_id=tokenizer.pad_id)

In [9]:
import torch
# AdamW over every parameter of `model`, with a learning rate of 3e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [12]:
# Run training: an Accelerator instance is handed to the trainer, which then
# fits `model` on `loader` with `optimizer` for 100 epochs.
# NOTE(review): `history` is whatever `trainer.fit` returns — presumably
# per-epoch training metrics; confirm against ValueMaskedTrainer.
accelerator = Accelerator()
trainer = ValueMaskedTrainer(model, accelerator=accelerator)
history = trainer.fit(loader, optimizer, epochs=100)