# Showcasing PROTOplast Checkpointing with a Perturbation Model

## 1. Introduction

This notebook showcases the checkpointing feature in PROTOplast, which lets you resume model training after an interruption or continue it on a different dataset. It demonstrates how to save and load training checkpoints so that model development can continue without starting from scratch. This is particularly useful for long training sessions, for experiments that span several datasets, and for training across multiple sessions or environments.

In [1]:
import anndata
import glob
import numpy as np
import pandas as pd
import os
import pathlib
import protoplast as pt
import ray
import torch

from anndata.experimental import AnnCollection
from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb

from ray.train import Checkpoint
from ray.train.lightning import RayDDPStrategy

✓ Applied AnnDataFileManager patch
✓ Applied AnnDataFileManager patch


## 2. Dataset pre-processing

We begin by reading the two datasets used to train the perturbation model in this notebook. To ensure compatibility, the model requires that both datasets share a common set of features (e.g., genes).

In the following section, we create a unified view by performing an **inner join** on the two datasets based on shared features. During this step, we:

- Identify and record the **number of overlapping genes** (shared features),
- Capture the **indices** of these shared genes in each dataset,
- Extract the list of **perturbed genes** specific to each dataset,
- And prepare **metadata** necessary for consistent training across datasets.

This alignment is essential to ensure the model receives a consistent input/output structure regardless of the dataset source.

In [2]:
DS_PATHS = ["/mnt/hdd2/nam/hct116.h5ad",
           "/mnt/hdd2/tan/competition_support_set/competition_train.h5"]
adatas = [anndata.io.read_h5ad(p, backed = "r") for p in DS_PATHS]

In [3]:
for idx, adata in enumerate(adatas):
    if idx == 0:
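        # The HCT116 dataset stores the perturbed gene under "gene_target";
        # copy it to "target_gene" so both datasets share the same column name
        adata.obs["target_gene"] = adata.obs["gene_target"]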
        adata.obs["cell_type"] = "HCT116"
    else:
        adata.obs = adata.obs[["target_gene", "cell_type"]]
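The column harmonization in this cell can be illustrated on toy `obs` tables. This is a minimal sketch with made-up data: `harmonize_obs` is a hypothetical helper (not part of PROTOplast or anndata), and the toy barcodes `c1` to `c3` are placeholders.

```python
import pandas as pd

def harmonize_obs(obs: pd.DataFrame, pert_col: str, cell_type=None) -> pd.DataFrame:
    """Return a new table holding only the two columns every dataset must share."""
    out = pd.DataFrame(index=obs.index)
    out["target_gene"] = obs[pert_col].astype(str)
    # Use a constant label when one is given, otherwise keep the dataset's own column.
    out["cell_type"] = cell_type if cell_type is not None else obs["cell_type"].astype(str)
    return out

# Toy stand-ins for the two datasets' obs tables (hypothetical values).
obs_a = pd.DataFrame({"gene_target": ["ST14", "VSNL1"], "sample": ["b1", "b1"]},
                     index=["c1", "c2"])
obs_b = pd.DataFrame({"target_gene": ["CHMP3"], "cell_type": ["ARC_H1"],
                      "batch": ["Flex_1_01"]}, index=["c3"])

merged = pd.concat([
    harmonize_obs(obs_a, "gene_target", cell_type="HCT116"),
    harmonize_obs(obs_b, "target_gene"),
])
print(merged)
```

After harmonization both tables expose exactly `target_gene` and `cell_type`, which is what allows them to be queried together in the joint view.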

In [4]:
# Create a joint view over all datasets
ds_view = AnnCollection(adatas, join_vars = "inner")

# Record the genes shared by the training datasets
n_genes = ds_view.n_vars
genes = ds_view.var_names.tolist()
perts = ds_view.obs["target_gene"].unique().tolist()
cell_types = ds_view.obs["cell_type"].unique().tolist()

print("Number of genes", n_genes)

Number of genes 18080


In [5]:
# Record, for each dataset, the column indices of the shared genes so that
# each yielded data batch can be subset to them later in the training step
shared_vars = {}
for idx_i, adata_i in enumerate(adatas):
    shared_vars[idx_i] = np.where(np.isin(adata_i.var_names, genes))[0]
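One detail worth checking when using `np.isin` this way: the returned indices follow each dataset's own column order, not the order of `genes`. The sketch below uses toy gene names (`G1` to `G5`, all hypothetical) to show the difference, and uses `pandas.Index.get_indexer` as one possible way to obtain indices in a common order. Whether the two real datasets store their genes in the same order is not stated in this notebook, so treat this as a sanity check rather than a required change.

```python
import numpy as np
import pandas as pd

shared = ["G2", "G4"]                              # toy stand-in for `genes`
names_a = np.array(["G1", "G2", "G3", "G4"])       # toy var_names, dataset 0
names_b = np.array(["G4", "G2", "G5"])             # toy var_names, dataset 1

# np.isin keeps each dataset's own column order.
idx_a = np.where(np.isin(names_a, shared))[0]
idx_b = np.where(np.isin(names_b, shared))[0]
print(names_a[idx_a].tolist())   # ['G2', 'G4']
print(names_b[idx_b].tolist())   # ['G4', 'G2']

# get_indexer returns positions in the order of `shared`
# (it would return -1 for a gene that is absent from the dataset).
aligned_b = pd.Index(names_b).get_indexer(shared)
print(names_b[aligned_b].tolist())   # ['G2', 'G4']
```

If the subset columns come out in different orders, the same column position would refer to different genes in the two datasets, so comparing `var_names[shared_vars[i]]` across datasets is a cheap check before training.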

## 3. Define model & configure training step

In [6]:
thread_per_worker = 48
test_size = 0.2 
val_size = 0.0 # if you have only training and test data, just put val_size = 0.0

In [7]:
from state.tx.models.perturb_mean import PerturbMeanPerturbationModel  # original model class

class PerturbMeanGlobalModel(PerturbMeanPerturbationModel):
    """
    Extension of PerturbMeanPerturbationModel whose prediction ignores the
    per-cell control embedding and uses only the global basal + offset.
    """

    def __init__(self, *args, **kwargs):
        # Force gene_decoder_bool off regardless of what the caller passes
        kwargs['gene_decoder_bool'] = False
        super().__init__(*args, **kwargs)

    def forward(self, batch: dict) -> torch.Tensor:
        B = len(batch["pert_name"])
        device = self.dummy_param.device
        pred_out = torch.zeros((B, self.output_dim), device=device)

        for i in range(B):
            p_name = str(batch["pert_name"][i])
            offset_vec = self.pert_mean_offsets.get(p_name)
            if offset_vec is None:
                offset_vec = torch.zeros(self.output_dim, device=device)

            # Use global basal instead of batch["ctrl_cell_emb"]
            pred_out[i] = self.global_basal.to(device) + offset_vec.to(device)

        return pred_out
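The prediction rule in `forward` can be tried in isolation. The following is a minimal sketch that uses NumPy arrays as a stand-in for the torch tensors; `predict_global_mean` and the toy values are hypothetical and only mirror the logic above (global basal plus a per-perturbation offset, with a zero offset for perturbations that have no stored mean).

```python
import numpy as np

def predict_global_mean(pert_names, global_basal, pert_mean_offsets):
    """Predict one expression vector per cell as global basal + perturbation offset."""
    out = np.zeros((len(pert_names), global_basal.shape[0]))
    for i, name in enumerate(pert_names):
        offset = pert_mean_offsets.get(str(name))
        if offset is None:
            # Unseen perturbation: fall back to the basal profile alone.
            offset = np.zeros_like(global_basal)
        out[i] = global_basal + offset
    return out

global_basal = np.array([1.0, 2.0, 3.0])                # toy 3-gene basal profile
offsets = {"ST14": np.array([0.5, 0.0, -1.0])}          # toy learned offset
pred = predict_global_mean(["ST14", "UNSEEN"], global_basal, offsets)
print(pred)
```

Every cell with the same perturbation label receives the same prediction, which is the intended difference from a variant that starts from each cell's own control embedding.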

## 4. Train on `HCT116_filtered_dual_guide_cells` dataset

We first train the perturbation model on a dataset that contains HCT116 cells. To do so, we define a callback function that fills in the metadata the model needs, based on the dataset's attributes.

In [8]:
hct116_adata = anndata.read_h5ad(DS_PATHS[0], backed = "r")

In [9]:
hct116_adata.obs.head(n = 5)

| | sample | num_features | guide_target | gene_target | n_genes_by_counts | total_counts | total_counts_mt | pct_counts_mt | pass_guide_filter | target_gene | cell_type | batch_var |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AAACCAAAGACATGTT-HCT116_Batch1 | HCT116_Batch1 | 2 | ST14_P1P2-1\|ST14_P1P2-2 | ST14 | 4883 | 19136.0 | 1179.0 | 6.161162 | True | ST14 | hct116 | HCT116_Batch1 |
| AAACCAAAGACCCAAC-HCT116_Batch1 | HCT116_Batch1 | 2 | SIGLEC5_P1P2-1\|SIGLEC5_P1P2-2 | SIGLEC5 | 8130 | 47916.0 | 1562.0 | 3.259871 | True | SIGLEC5 | hct116 | HCT116_Batch1 |
| AAACCAAAGAGGTACG-HCT116_Batch1 | HCT116_Batch1 | 2 | VSNL1_P1P2-1\|VSNL1_P1P2-2 | VSNL1 | 6531 | 28435.0 | 1042.0 | 3.664498 | True | VSNL1 | hct116 | HCT116_Batch1 |
| AAACCAAAGCGATTAT-HCT116_Batch1 | HCT116_Batch1 | 2 | KCNK7_P1P2-1\|KCNK7_P1P2-2 | KCNK7 | 5931 | 26080.0 | 1087.0 | 4.167945 | True | KCNK7 | hct116 | HCT116_Batch1 |
| AAACCAAAGGCTTAAT-HCT116_Batch1 | HCT116_Batch1 | 2 | APOA4_P1P2-1\|APOA4_P1P2-2 | APOA4 | 7157 | 38366.0 | 955.0 | 2.489183 | True | APOA4 | hct116 | HCT116_Batch1 |


In [10]:
def hct116_perturbmean_metadata_cb(ad: anndata.AnnData, metadata: dict):
    metadata["input_dim"] = n_genes
    metadata["output_dim"] = n_genes
    metadata["hidden_dim"] = 10
    metadata["pert_dim"] = len(perts)
    metadata["lr"] = 1e-3  

    metadata["gene_names"] = genes
    metadata["pert_names"] = perts
    metadata["cell_types"] = cell_types

    metadata["control_pert"] = "Non-Targeting"
    metadata["embed_key"] = "X"
    metadata["output_space"] = "gene"
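A metadata callback of this shape fills a plain dictionary in place. Assuming the runner later looks up each key that is listed when it is constructed (an assumption about `RayTrainRunner`, not something this notebook confirms), a quick pre-flight check can catch a missing key before any worker starts. `missing_metadata_keys` and `toy_metadata_cb` are hypothetical names used only for illustration.

```python
REQUIRED_KEYS = ["input_dim", "output_dim", "hidden_dim", "pert_dim",
                 "lr", "control_pert", "embed_key", "output_space"]

def missing_metadata_keys(metadata_cb, required_keys, ad=None):
    """Run the callback on an empty dict and report which required keys it did not set."""
    metadata = {}
    metadata_cb(ad, metadata)   # the callback mutates `metadata` in place
    return [key for key in required_keys if key not in metadata]

def toy_metadata_cb(ad, metadata):
    # Hypothetical callback that forgets to set "lr".
    metadata["input_dim"] = 4
    metadata["output_dim"] = 4
    metadata["hidden_dim"] = 10
    metadata["pert_dim"] = 2
    metadata["control_pert"] = "Non-Targeting"
    metadata["embed_key"] = "X"
    metadata["output_space"] = "gene"

missing = missing_metadata_keys(toy_metadata_cb, REQUIRED_KEYS)
print(missing)
```

Passing `ad=None` works here only because the toy callback, like the one above, does not read from the AnnData object; a callback that does would need a real object.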

In [11]:
class HCT116_PerturbAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)

        # Subset the X matrix to the genes that appear in all datasets;
        # the sparse batch must be densified before column indexing
        X = X.to_dense()[:, shared_vars[0]]

        # Metadata from self.ad
        pert_names = self.ad.obs["gene_target"].iloc[start:end].astype(str).to_list()
        cell_lines = ["HCT116"] * len(pert_names)

        return {
            "pert_name": pert_names,
            "cell_type": cell_lines,
            "pert_cell_counts": X,
            "pert_cell_emb": X,
        }

In [12]:
# Set up training
PerturbMeanPerturbationModel_trainer = RayTrainRunner(
    PerturbMeanGlobalModel,
    HCT116_PerturbAnnDataset,
    ["input_dim",
    "output_dim",
    "hidden_dim",      
    "pert_dim",        
    "lr",
    "control_pert",    # "Non-Targeting"
    "embed_key",       
    "output_space",    # "gene"
    ],
    metadata_cb = hct116_perturbmean_metadata_cb,
    sparse_keys = "X"
)

2025-09-19 03:27:45,030	INFO worker.py:1951 -- Started a local Ray instance.

(TrainTrainable pid=198522) ✓ Applied AnnDataFileManager patch
(TrainTrainable pid=198522) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=198782) Setting up process group for: env:// [rank=0, world_size=1]
(TorchTrainer pid=198522) Started distributed worker processes: 
(TorchTrainer pid=198522) - (node_id=49b56875e4b5a843fc7ea0f31d82a4dd8fd27def9d910ce27d0768fd, ip=192.168.1.226, pid=198782) world_rank=0, local_rank=0, node_rank=0

(RayTrainWorker pid=198782) ✓ Applied AnnDataFileManager patch
(RayTrainWorker pid=198782) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=198782) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=198782) GPU available: True (cuda), used: True
(RayTrainWorker pid=198782) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=198782) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=198782) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision

(RayTrainWorker pid=198782) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayTrainWorker pid=198782)
(RayTrainWorker pid=198782)   | Name         | Type    | Params | Mode 
(RayTrainWorker pid=198782) -------------------------------------------------
(RayTrainWorker pid=198782) 0 | loss_fn      | MSELoss | 0      | train
(RayTrainWorker pid=198782)   | other params | n/a     | 1      | n/a  
(RayTrainWorker pid=198782) -------------------------------------------------
(RayTrainWorker pid=198782) 1         Trainable params
(RayTrainWorker pid=198782) 0         Non-trainable params
(RayTrainWorker pid=198782) 1         Total params
(RayTrainWorker pid=198782) 0.000     Total estimated model params size (MB)
(RayTrainWorker pid=198782) 1         Modules in train mode
(RayTrainWorker pid=198782) 0         Modules in eval mode


                                                  
Epoch 0:   0%|          | 0/125 [00:00<?, ?it/s]
Epoch 0:   6%|▌         | 7/125 [00:00<00:09, 12.28it/s, v_num=0, train_loss=0.0167]
Epoch 0:  11%|█         | 14/125 [00:00<00:05, 20.71it/s, v_num=0, train_loss=0.0169]
Epoch 0:  12%|█▏        | 15/125 [00:00<00:05, 21.75it/s, v_num=0, train_loss=0.0117]
Epoch 0:  18%|█▊        | 23/125 [00:00<00:03, 28.44it/s, v_num=0, train_loss=0.0155]
Epoch 0:  23%|██▎       | 29/125 [00:00<00:03, 31.73it/s, v_num=0, train_loss=0.0195]
Epoch 0:  24%|██▍       | 30/125 [00:00<00:02, 32.32it/s, v_num=0, train_loss=0.0136]
Epoch 0:  30%|██▉       | 37/125 [00:01<00:02, 35.40it/s, v_num=0, train_loss=0.0158]
Epoch 0:  34%|███▍      | 43/125 [00:01<00:02, 37.53it/s, v_num=0, train_loss=0.0162]
Epoch 0:  35%|███▌      | 44/125 [00:01<00:02, 37.84it/s, v_num=0, train_loss=0.0124]
Epoch 0:  41%|████      | 51/125 [00:01<00:01, 40.15it/s, v_num=0, train_loss=0.0231]
Epoch 0:  42%|████▏     | 52/125 [00:01<0

(RayTrainWorker pid=198782) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-19_03-27-49/TorchTrainer_9e62d_00000_0_2025-09-19_03-27-49/checkpoint_000000)

Epoch 0: 100%|██████████| 125/125 [00:09<00:00, 13.39it/s, v_num=0, train_loss=0.0169]

(RayTrainWorker pid=198782) `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 125/125 [00:14<00:00,  8.49it/s, v_num=0, train_loss=0.0169]




In [None]:
result = PerturbMeanPerturbationModel_trainer.train([DS_PATHS[0]],
                                                    batch_size = 64,
                                                    test_size = test_size, 
                                                    val_size = val_size,
                                                    num_workers = 1,
                                                    resource_per_worker = {"GPU": 1, "CPU": thread_per_worker})

## 5. Train on `competition_train` dataset

We now have a checkpoint saved after training the perturbation model on the first dataset. Continuing training on a different dataset requires two adjustments to keep the model's configuration and inputs compatible.

First, we define a new **metadata callback function** that configures the model for the new dataset.

Second, because the second dataset may have different **input dimensions** or **metadata fields**, we define a custom dataset class (a `DistributedAnnDataset` subclass). This class transforms each training batch so that:

- Features are mapped to the expected input space,
- Metadata is correctly aligned with the model's expectations,
- Any dataset-specific preprocessing is applied consistently.

In [15]:
competition_adata = anndata.read_h5ad(DS_PATHS[1], backed = "r")

In [16]:
competition_adata.obs.head(n = 5)

| | target_gene | guide_id | batch | batch_var | cell_type |
|---|---|---|---|---|---|
| AAACAAGCAACCTTGTACTTTAGG-Flex_1_01 | CHMP3 | CHMP3_P1P2_A\|CHMP3_P1P2_B | Flex_1_01 | Flex_1_01 | ARC_H1 |
| AAACAAGCATTGCCGCACTTTAGG-Flex_1_01 | AKT2 | AKT2_P1P2_A\|AKT2_P1P2_B | Flex_1_01 | Flex_1_01 | ARC_H1 |
| AAACCAATCAATGTTCACTTTAGG-Flex_1_01 | SHPRH | SHPRH_P1P2_A\|SHPRH_P1P2_B | Flex_1_01 | Flex_1_01 | ARC_H1 |
| AAACCAATCCCTCGCTACTTTAGG-Flex_1_01 | TMSB4X | TMSB4X_P1_A\|TMSB4X_P1_B | Flex_1_01 | Flex_1_01 | ARC_H1 |
| AAACCAATCTAAATCCACTTTAGG-Flex_1_01 | KLF10 | KLF10_P2_A\|KLF10_P2_B | Flex_1_01 | Flex_1_01 | ARC_H1 |


In [17]:
def competition_perturbmean_metadata_cb(ad: anndata.AnnData, metadata: dict):
    metadata["input_dim"] = n_genes
    metadata["output_dim"] = n_genes
    metadata["hidden_dim"] = 10
    metadata["pert_dim"] = len(perts)
    metadata["lr"] = 1e-3

    metadata["gene_names"] = genes
    metadata["pert_names"] = perts
    metadata["cell_types"] = cell_types

    metadata["control_pert"] = "non-targeting"
    metadata["embed_key"] = "X"
    metadata["output_space"] = "gene"
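Compared with the HCT116 callback, the only value that changes here is `control_pert`: the control label is spelled `"Non-Targeting"` for the first dataset and `"non-targeting"` for the second. A small check like the one below can confirm which spelling a dataset actually uses before it is hard-coded. This is a minimal sketch on toy labels; `find_control_labels` is a hypothetical helper, not a PROTOplast function.

```python
def find_control_labels(labels, keyword="non-targeting"):
    """Return the distinct labels that match `keyword` when case is ignored."""
    return sorted({str(label) for label in labels if str(label).lower() == keyword})

# Toy perturbation labels standing in for obs["target_gene"] of each dataset.
hct116_labels = ["ST14", "Non-Targeting", "VSNL1"]
competition_labels = ["CHMP3", "non-targeting", "AKT2", "non-targeting"]

print(find_control_labels(hct116_labels))        # ['Non-Targeting']
print(find_control_labels(competition_labels))   # ['non-targeting']
```

An empty result would mean the dataset has no control cells under that name, and more than one entry would mean the spelling is inconsistent within a single dataset.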

In [18]:
class Competition_PerturbAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)

        # Subset the X matrix to the genes that appear in all datasets;
        # the sparse batch must be densified before column indexing
        X = X.to_dense()[:, shared_vars[1]]

        # Metadata from self.ad
        pert_names = self.ad.obs["target_gene"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_type"].iloc[start:end].astype(str).to_list()

        return {
            "pert_name": pert_names,
            "cell_type": cell_lines,
            "pert_cell_counts": X,
            "pert_cell_emb": X,
        }

In [19]:
# Set up training
competition_trainer = RayTrainRunner(
    PerturbMeanGlobalModel,
    Competition_PerturbAnnDataset,
    ["input_dim",
    "output_dim",
    "hidden_dim",      
    "pert_dim",
    "lr",
    "control_pert",    # "non-targeting"
    "embed_key",       
    "output_space",    # "gene"
    ],
    metadata_cb = competition_perturbmean_metadata_cb,
    sparse_keys = "X"
)

2025-09-19 03:29:20,289	INFO worker.py:1951 -- Started a local Ray instance.

(TrainTrainable pid=206826) ✓ Applied AnnDataFileManager patch
(TrainTrainable pid=206826) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=207034) Setting up process group for: env:// [rank=0, world_size=1]
(TorchTrainer pid=206826) Started distributed worker processes: 
(TorchTrainer pid=206826) - (node_id=56af46918d2afaab8fe704d215fd39de3de026f40333c93c60883644, ip=192.168.1.226, pid=207034) world_rank=0, local_rank=0, node_rank=0

(RayTrainWorker pid=207034) ✓ Applied AnnDataFileManager patch
(RayTrainWorker pid=207034) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=207034) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=207034) GPU available: True (cuda), used: True
(RayTrainWorker pid=207034) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=207034) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=207034) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
(RayTrainWorker pid=207034) Restoring states from the checkpoint path at /home/nam/protoplast_results/TorchTrainer_2025-0

(RayTrainWorker pid=207034) Restored all states from the checkpoint at /home/nam/protoplast_results/TorchTrainer_2025-09-19_03-27-49/TorchTrainer_9e62d_00000_0_2025-09-19_03-27-49/checkpoint_000000/checkpoint.ckpt


Epoch 1:   0%|          | 0/109 [00:00<?, ?it/s]
Epoch 1:   1%|          | 1/109 [00:07<14:13,  0.13it/s, v_num=0, train_loss=0.184]
Epoch 1:   2%|▏         | 2/109 [00:08<07:35,  0.23it/s, v_num=0, train_loss=0.173]
Epoch 1:   3%|▎         | 3/109 [00:09<05:20,  0.33it/s, v_num=0, train_loss=0.187]
Epoch 1:   4%|▎         | 4/109 [00:09<04:14,  0.41it/s, v_num=0, train_loss=0.177]
Epoch 1:   5%|▍         | 5/109 [00:10<03:32,  0.49it/s, v_num=0, train_loss=0.184]
Epoch 1:   6%|▌         | 6/109 [00:10<03:04,  0.56it/s, v_num=0, train_loss=0.175]
Epoch 1:   6%|▋         | 7/109 [00:11<02:43,  0.62it/s, v_num=0, train_loss=0.187]
Epoch 1:   7%|▋         | 8/109 [00:11<02:29,  0.68it/s, v_num=0, train_loss=0.178]
Epoch 1:   8%|▊         | 9/109 [00:12<02:17,  0.73it/s, v_num=0, train_loss=0.185]
Epoch 1:   9%|▉         | 10/109 [00:12<02:07,  0.78it/s, v_num=0, train_loss=0.175]
Epoch 1:  10%|█         | 11/109 [00:13<01:59,  0.82it/s, v_num=0, train_loss=0.178]
Epoch 1:  11%|█         |

(RayTrainWorker pid=207034) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-19_03-29-26/TorchTrainer_d7c14_00000_0_2025-09-19_03-29-26/checkpoint_000000)
(RayTrainWorker pid=207034) `Trainer.fit` stopped: `max_epochs=2` reached.


In [None]:
ckpt_path = os.path.join(result.checkpoint.path, "checkpoint.ckpt")

competition_trainer.train([DS_PATHS[1]],
                          max_epochs = 2,
                          batch_size = 2048, 
                          test_size = test_size, 
                          val_size = val_size,
                          num_workers = 1,
                          resource_per_worker = {"GPU": 1, "CPU": thread_per_worker},
                          ckpt_path = ckpt_path)
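Two details of this call are easy to get wrong: the checkpoint file must exist inside the checkpoint directory, and `max_epochs` has to exceed the number of epochs already stored in the checkpoint. The sketch below illustrates both with the standard library only. `resolve_ckpt_file` and `epochs_left` are hypothetical helpers, and the epoch arithmetic assumes `max_epochs` is a cumulative total, which is consistent with the log above (training resumes at `Epoch 1` and stops when `max_epochs=2` is reached).

```python
import tempfile
from pathlib import Path

def resolve_ckpt_file(checkpoint_dir, filename="checkpoint.ckpt"):
    """Return the checkpoint file path, failing early if the file is missing."""
    path = Path(checkpoint_dir) / filename
    if not path.is_file():
        raise FileNotFoundError(f"No {filename} found in {checkpoint_dir}")
    return str(path)

def epochs_left(completed_epochs, max_epochs):
    """Epochs that would still run, assuming max_epochs counts from the very first epoch."""
    return max(max_epochs - completed_epochs, 0)

# Simulate a checkpoint directory with an empty placeholder file.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "checkpoint_000000"
    ckpt_dir.mkdir()
    (ckpt_dir / "checkpoint.ckpt").touch()
    resolved = resolve_ckpt_file(ckpt_dir)
    print(Path(resolved).name)

print(epochs_left(completed_epochs=1, max_epochs=1))   # 0
print(epochs_left(completed_epochs=1, max_epochs=2))   # 1
```

Under that assumption, resuming the one-epoch checkpoint with `max_epochs = 1` would leave nothing to train, which is why the call above passes `max_epochs = 2`.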

## 6. Conclusion

This brings us to the end of the tutorial notebook.

This workflow shows how checkpointing in **PROTOplast** lets a model trained on one dataset continue training on another without starting from scratch.

Feel free to explore and extend this notebook to suit your own data and use cases!