# Showcasing PROTOplast Checkpointing in a Cell-line Classification Model

## 1. Introduction

This notebook showcases the checkpointing feature in PROTOplast, which lets you resume model training after an interruption or continue training on a different dataset. It demonstrates how to save and load training checkpoints so that model development can continue without starting from scratch. This is particularly useful for long training runs, for experimenting with several datasets, and for training across multiple sessions or environments.
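To make the idea concrete, here is a minimal, framework-free sketch of what a training checkpoint holds and how resuming works. It is not PROTOplast's implementation (the notebook relies on Ray Train and PyTorch Lightning); the toy model `y = w * x` and the helpers `train_steps`, `save_checkpoint` and `load_checkpoint` are hypothetical names used only for illustration.

```python
import os
import pickle
import tempfile


def train_steps(state, data, steps, lr=0.1):
    """Run `steps` gradient-descent updates of the toy model y = w * x, continuing from `state`."""
    w = state["w"]
    for _ in range(steps):
        # Gradient of the mean squared error with respect to w
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
        state["step"] += 1
    state["w"] = w
    return state


def save_checkpoint(state, path):
    """Persist everything needed to resume: the weights plus the training progress."""
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(path):
    """Restore a previously saved training state."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Two toy "datasets" that share the same underlying relation (true w = 2)
dataset_a = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
dataset_b = [(0.5, 1.0), (1.5, 3.0)]

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "checkpoint.pkl")

    # Session 1: train on the first dataset, then checkpoint (e.g. before an interruption)
    state = train_steps({"w": 0.0, "step": 0}, dataset_a, steps=5)
    save_checkpoint(state, ckpt)

    # Session 2: resume from the checkpoint and continue on a different dataset
    resumed = load_checkpoint(ckpt)
    resumed = train_steps(resumed, dataset_b, steps=20)

print(resumed["step"], round(resumed["w"], 4))
```

The second session starts from the saved weights and step counter instead of from `w = 0`, which is the same benefit the notebook obtains when it reuses the plate 1 checkpoint for plate 2.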

In [1]:
import anndata
import glob
import numpy as np
import pandas as pd
import os
import pathlib
import protoplast as pt
import ray
import torch

from anndata.experimental import AnnCollection
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb, DistributedCellLineAnnDataset

from ray.train import Checkpoint
from ray.train.lightning import RayDDPStrategy

✓ Applied AnnDataFileManager patch


## 2. Dataset pre-processing

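We begin by reading the two datasets used to train the cell-line classification model in this notebook. Because the model trained on the first dataset is reused for the second, both datasets must yield the same input dimension (number of genes) and the same output dimension (number of cell-line classes).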

In the next cell, we create a unified view by performing an **inner join** of the two datasets on their shared features (genes). During this step, we:

- Identify and record the **number of output classes** (cell lines).
- Extract the list of **cell lines** across both datasets.

This alignment is essential to ensure the model receives a consistent input/output structure regardless of the dataset source.
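The alignment step can be illustrated without AnnData. The sketch below is plain Python with a hypothetical `align_datasets` helper; in the notebook itself, `AnnCollection(..., join_vars = "inner")` and `collection.obs.cell_line` do this work. The gene names are made up, while the cell lines are taken from the `obs` tables shown later in this notebook.

```python
def align_datasets(gene_lists, label_lists):
    """Inner-join the feature lists and build one shared label vocabulary."""
    # Inner join: keep genes present in every dataset, ordered as in the first one
    shared = set(gene_lists[0]).intersection(*gene_lists[1:])
    genes = [g for g in gene_lists[0] if g in shared]

    # Union of cell lines across datasets, sorted so class indices are reproducible
    cell_lines = sorted(set().union(*label_lists))
    class_index = {cell_line: i for i, cell_line in enumerate(cell_lines)}
    return genes, class_index


# Toy inputs: gene names are made up; cell lines come from this notebook's obs tables
plate1_genes = ["TP53", "MYC", "EGFR", "KRAS"]
plate2_genes = ["MYC", "KRAS", "TP53", "BRCA1"]
plate1_lines = ["A-172", "PANC-1", "C32"]
plate2_lines = ["A-172", "LoVo", "HCT15"]

genes, class_index = align_datasets([plate1_genes, plate2_genes],
                                    [plate1_lines, plate2_lines])
print(genes)             # ['TP53', 'MYC', 'KRAS']
print(len(class_index))  # 5 output classes
```

Sorting the label set is one way to guarantee that a given cell line maps to the same class index in every session, which is what allows a single checkpoint to be reused across datasets. Note that pandas' `unique()`, used in the next cells, returns labels in order of appearance instead; how PROTOplast's `cell_line_metadata_cb` orders classes is not shown in this notebook.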

In [2]:
DS_PATHS = ["/mnt/hdd2/tan/tahoe100m/plate1_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad",
           "/mnt/hdd2/tan/tahoe100m/plate2_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
adatas = [anndata.io.read_h5ad(p, backed = "r") for p in DS_PATHS]

In [3]:
# Create a joint view of both datasets, keeping only the genes they share (inner join)
collection = AnnCollection(adatas, join_vars = "inner")

# Record the cell lines (output classes) present across both datasets
cell_lines = collection.obs.cell_line.unique().tolist()
cell_lines_count = collection.obs.cell_line.nunique()

## 3. Configure training parameters

In [4]:
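thread_per_worker = 12  # CPU threads allocated to each Ray worker
test_size = 0.1         # fraction of cells held out for testing
val_size = 0.1          # fraction of cells held out for validation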

## 4. Train on `plate1_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab` dataset

In [5]:
plate1_adata = adatas[0]

In [6]:
plate1_adata.obs.head(n = 5)

| BARCODE_SUB_LIB_ID | sample | gene_count | tscp_count | mread_count | drugname_drugconc | drug | cell_line | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | cell_name | plate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01_001_025-lib_841 | smp_1495 | 1676 | 2441 | 2892 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0131 | lib_841 | 01_001_025 | 0.025399 | -0.066667 | -0.095055 | G1 | full | A-172 | plate1 |
| 01_001_026-lib_841 | smp_1495 | 1657 | 2454 | 2925 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0480 | lib_841 | 01_001_026 | 0.042787 | 0.128571 | 0.650549 | G2M | full | PANC-1 | plate1 |
| 01_001_048-lib_841 | smp_1495 | 1749 | 2521 | 2963 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0293 | lib_841 | 01_001_048 | 0.056724 | 0.242857 | 0.308791 | G2M | full | HEC-1-A | plate1 |
| 01_001_076-lib_841 | smp_1495 | 834 | 1038 | 1258 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0397 | lib_841 | 01_001_076 | 0.066474 | 0.009524 | 0.245788 | G2M | full | LS 180 | plate1 |
| 01_001_088-lib_841 | smp_1495 | 1275 | 1710 | 2006 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_1097 | lib_841 | 01_001_088 | 0.028655 | -0.1 | -0.085348 | G1 | full | C32 | plate1 |


In [7]:
# Set up training
trainer = RayTrainRunner(
    LinearClassifier,
    DistributedCellLineAnnDataset,
    model_keys = ["num_genes",
                  "num_classes"],
    metadata_cb = cell_line_metadata_cb,
    sparse_keys = "X"
)

2025-09-22 02:23:25,243	INFO worker.py:1951 -- Started a local Ray instance.


(TrainTrainable pid=203998) ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=204121)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=203998)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=203998)[0m - (node_id=6a13510f495dd252e402f8bfa2d7f5008bbf3e59edf6d92a68fd6167, ip=192.168.1.226, pid=204121) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=204121)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=204121)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=204121)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=204121)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=204121)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=204121)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=204121)[0m /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.1 ...
[36m(RayTrainWorker pid=204121)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, 



[36m(RayTrainWorker pid=204121)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=204121)[0m 
[36m(RayTrainWorker pid=204121)[0m   | Name    | Type             | Params | Mode 
[36m(RayTrainWorker pid=204121)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=204121)[0m 0 | model   | Linear           | 3.1 M  | train
[36m(RayTrainWorker pid=204121)[0m 1 | loss_fn | CrossEntropyLoss | 0      | train
[36m(RayTrainWorker pid=204121)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=204121)[0m 3.1 M     Trainable params
[36m(RayTrainWorker pid=204121)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=204121)[0m 3.1 M     Total params
[36m(RayTrainWorker pid=204121)[0m 12.542    Total estimated model params size (MB)
[36m(RayTrainWorker pid=204121)[0m 2         Modules in train mode
[36m(RayTrainWorker pid=204121)[0m 0         Modules in eval mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  5.44it/s]


[36m(RayTrainWorker pid=204121)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=204121)[0m /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


                                                                           
Epoch 0:   0%|          | 0/4329 [00:00<?, ?it/s]
Epoch 0:   0%|          | 2/4329 [00:13<8:10:07,  0.15it/s, v_num=0, train_loss=3.550] 
Epoch 0:   0%|          | 9/4329 [00:15<2:02:00,  0.59it/s, v_num=0, train_loss=1.320]
Epoch 0:   0%|          | 10/4329 [00:15<1:49:53,  0.66it/s, v_num=0, train_loss=0.780]
Epoch 0:   0%|          | 18/4329 [00:15<1:01:21,  1.17it/s, v_num=0, train_loss=0.446]
Epoch 0:   0%|          | 19/4329 [00:15<58:09,  1.23it/s, v_num=0, train_loss=0.481]  
Epoch 0:   1%|          | 27/4329 [00:15<41:08,  1.74it/s, v_num=0, train_loss=0.231]
Epoch 0:   1%|          | 28/4329 [00:15<39:41,  1.81it/s, v_num=0, train_loss=0.306]
Epoch 0:   1%|          | 29/4329 [00:15<38:20,  1.87it/s, v_num=0, train_loss=0.458]
Epoch 0:   1%|          | 37/4329 [00:15<30:12,  2.37it/s, v_num=0, train_loss=0.191]
Epoch 0:   1%|          | 38/4329 [00:15<29:25,  2.43it/s, v_num=0, train_loss=0.209]

[36m(RayTrainWorker pid=204121)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-22_02-23-39/TorchTrainer_2679d_00000_0_2025-09-22_02-23-39/checkpoint_000000)
[36m(RayTrainWorker pid=204121)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4329/4329 [02:15<00:00, 31.95it/s, v_num=0, train_loss=0.0386]
Epoch 0: 100%|██████████| 4329/4329 [02:15<00:00, 31.93it/s, v_num=0, train_loss=0.0386]




In [None]:
# Train on plate 1 only; the returned Ray result points to the checkpoint saved by this run
result = trainer.train([DS_PATHS[0]],
                       batch_size = 1024,
                       test_size = test_size, 
                       val_size = val_size,
                       num_workers = 1,
                       resource_per_worker = {"GPU": 1, "CPU": thread_per_worker})

## 5. Train on `plate2_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab` dataset

Training on the first dataset saved a checkpoint of the classification model. To continue training from it on the second dataset, we pass the path to the checkpoint file into `train()`.
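If the kernel was restarted after the first training run, the `result` object is no longer in memory, but the checkpoint is still on disk: the training log reports it under `/home/nam/protoplast_results/<run>/<trial>/checkpoint_000000` on the author's machine. The sketch below assumes that layout (inferred from the log lines) and that the weights file is named `checkpoint.ckpt`, as in the `ckpt_path` cell; `find_latest_checkpoint` is a hypothetical helper, not part of PROTOplast.

```python
import os
import tempfile
from pathlib import Path


def find_latest_checkpoint(results_root):
    """Return the most recently written checkpoint.ckpt under results_root, or None."""
    # Assumed layout: <results_root>/<run>/<trial>/checkpoint_<n>/checkpoint.ckpt
    pattern = "*/*/checkpoint_*/checkpoint.ckpt"
    candidates = Path(results_root).expanduser().glob(pattern)
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)


# Demo on a throwaway directory that mimics the two runs logged in this notebook
with tempfile.TemporaryDirectory() as root:
    runs = [
        "TorchTrainer_2025-09-22_02-23-39/TorchTrainer_2679d_00000_0_2025-09-22_02-23-39",
        "TorchTrainer_2025-09-22_02-29-56/TorchTrainer_07872_00000_0_2025-09-22_02-29-56",
    ]
    for i, run in enumerate(runs):
        ckpt = Path(root, run, "checkpoint_000000", "checkpoint.ckpt")
        ckpt.parent.mkdir(parents=True)
        ckpt.touch()
        # Give later runs a later modification time so the demo is deterministic
        os.utime(ckpt, (1_000_000 + i, 1_000_000 + i))
    latest = find_latest_checkpoint(root)
    print(latest.relative_to(root))
```

Selecting by modification time is a simple heuristic; if several experiments share the same results directory, filtering by run name is safer before passing the path to `train()`.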

In [9]:
plate2_adata = adatas[1]

In [10]:
plate2_adata.obs.head(n = 5)

| BARCODE_SUB_LIB_ID | sample | gene_count | tscp_count | mread_count | drugname_drugconc | drug | cell_line | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | cell_name | plate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01_001_053-lib_1000 | smp_1591 | 2671 | 5629 | 6830 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_1119 | lib_1000 | 01_001_053 | 0.016522 | -0.265873 | -0.313553 | G1 | full | CFPAC-1 | plate2 |
| 01_001_082-lib_1000 | smp_1591 | 2148 | 3173 | 3826 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0292 | lib_1000 | 01_001_082 | 0.025843 | 0.400794 | 0.520879 | G2M | full | HCT15 | plate2 |
| 01_001_145-lib_1000 | smp_1591 | 683 | 886 | 1073 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_1098 | lib_1000 | 01_001_145 | 0.029345 | -0.019841 | -0.032967 | G1 | full | HepG2/C3A | plate2 |
| 01_001_175-lib_1000 | smp_1591 | 1845 | 2786 | 3368 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0131 | lib_1000 | 01_001_175 | 0.031587 | -0.123016 | -0.118498 | G1 | full | A-172 | plate2 |
| 01_001_181-lib_1000 | smp_1591 | 1228 | 1849 | 2226 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0399 | lib_1000 | 01_001_181 | 0.015143 | 0.02381 | -0.008791 | S | full | LoVo | plate2 |


In [12]:
# Set up training
trainer = RayTrainRunner(
    LinearClassifier,
    DistributedCellLineAnnDataset,
    model_keys = ["num_genes",
                  "num_classes"],
    metadata_cb = cell_line_metadata_cb,
    sparse_keys = "X"
)

2025-09-22 02:29:38,621	INFO worker.py:1951 -- Started a local Ray instance.


(TrainTrainable pid=211376) ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=211522)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=211376)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=211376)[0m - (node_id=3324969ef194e9e5d3e410c253ef86193a325fd59630fe3965868918, ip=192.168.1.226, pid=211522) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=211522)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=211522)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=211522)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=211522)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=211522)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=211522)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=211522)[0m /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.1 ...
[36m(RayTrainWorker pid=211522)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=211522)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=211522)[0m 
[36m(RayTrainWorker pid=211522)[0m   | Name    | Type             | Params | Mode 
[36m(RayTrainWorker pid=211522)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=211522)[0m 0 | model   | Linear           | 3.1 M  | train
[36m(RayTrainWorker pid=211522)[0m 1 | loss_fn | CrossEntropyLoss | 0      | train
[36m(RayTrainWorker pid=211522)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=211522)[0m 3.1 M     Trainable params
[36m(RayTrainWorker pid=211522)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=211522)[0m 3.1 M     Total params
[36m(RayTrainWorker pid=211522)[0m 12.542    Total estimated model params size (MB)
[36m(RayTrainWorker pid=211522)[0m 2         Modules in train mode
[36m(RayTrainWorker pid=211522)[0m 0         Modules in eval mode

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
                                                                           


[36m(RayTrainWorker pid=211522)[0m /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=211522)[0m   return torch.sparse_csr_tensor(


Epoch 0:   0%|          | 0/6368 [00:00<?, ?it/s]
Epoch 0:   0%|          | 2/6368 [00:23<20:24:12,  0.09it/s, v_num=0, train_loss=3.560]
Epoch 0:   0%|          | 3/6368 [00:25<14:50:02,  0.12it/s, v_num=0, train_loss=3.150]
Epoch 0:   0%|          | 12/6368 [00:26<3:53:11,  0.45it/s, v_num=0, train_loss=0.511]
Epoch 0:   0%|          | 13/6368 [00:26<3:35:17,  0.49it/s, v_num=0, train_loss=2.080]
Epoch 0:   0%|          | 22/6368 [00:26<2:07:35,  0.83it/s, v_num=0, train_loss=0.612]
Epoch 0:   0%|          | 31/6368 [00:26<1:30:49,  1.16it/s, v_num=0, train_loss=0.509]
Epoch 0:   1%|          | 39/6368 [00:26<1:12:22,  1.46it/s, v_num=0, train_loss=0.243]
Epoch 0:   1%|          | 40/6368 [00:26<1:10:35,  1.49it/s, v_num=0, train_loss=0.166]
Epoch 0:   1%|          | 49/6368 [00:26<57:46,  1.82it/s, v_num=0, train_loss=0.781]  
Epoch 0:   1%|          | 50/6368 [00:26<56:38,  1.86it/s, v_num=0, train_loss=0.194]
Epoch 0:   1%|          | 57/6368 [00:27<49:50,  2.11it/s, v_num=0, trai

[36m(RayTrainWorker pid=211522)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-22_02-29-56/TorchTrainer_07872_00000_0_2025-09-22_02-29-56/checkpoint_000000)
[36m(RayTrainWorker pid=211522)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 6368/6368 [03:19<00:00, 31.92it/s, v_num=0, train_loss=0.277]




In [None]:
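# The checkpoint directory from the plate 1 run holds the Lightning weights as `checkpoint.ckpt`
ckpt_path = os.path.join(result.checkpoint.path, "checkpoint.ckpt")

# Resume from the plate 1 checkpoint while training on plate 2
# (assumes `train()` accepts the checkpoint path through a `ckpt_path` keyword)
result = trainer.train([DS_PATHS[1]],
                       batch_size = 1024,
                       test_size = test_size, 
                       val_size = val_size,
                       num_workers = 1,
                       resource_per_worker = {"GPU": 1, "CPU": thread_per_worker},
                       ckpt_path = ckpt_path)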

## 6. Conclusion

This brings us to the end of the tutorial notebook.

This workflow shows how checkpointing in **PROTOplast** lets you carry a trained model from one dataset to the next instead of retraining from scratch, making model development across diverse datasets more efficient.

Feel free to explore and extend this notebook to suit your own data and use cases!