# Showcasing PROTOplast Checkpointing in a Cell-line Classification Model

## 1. Introduction

This notebook showcases the checkpointing feature in PROTOplast, which lets you resume model training after finishing one dataset and switching to another. It demonstrates how to save and load training checkpoints so that model development can continue without starting from scratch. This is particularly useful for long training sessions, experiments across several datasets, or training that spans multiple sessions or environments.

In [1]:
import anndata
import glob
import numpy as np
import pandas as pd
import os
import pathlib
import protoplast as pt
import ray
import torch

from anndata.experimental import AnnCollection
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb, DistributedCellLineAnnDataset

from ray.train import Checkpoint
from ray.train.lightning import RayDDPStrategy



✓ Applied AnnDataFileManager patch
✓ Applied AnnDataFileManager patch


## 2. Dataset pre-processing

We begin by reading the two datasets used to train the cell-line classification model in this notebook. To ensure compatibility, the model requires that both datasets have the same output dimensions, that is, the same number of cell-line classes.

In the following section, we create a unified view by performing an **inner join** on the two datasets based on shared features. During this step, we:

- Identify and record the **number of output classes** (cell-lines),
- Extract the list of **cell-lines** across both datasets.

This alignment is essential to ensure the model receives a consistent input/output structure regardless of the dataset source.
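
As a rough, library-free sketch of what this alignment amounts to (this is not how PROTOplast or `AnnCollection` implement it; `align_datasets` is a hypothetical helper and the inputs are made-up toy values): an inner join keeps only the features present in every dataset, while the label space covers every cell-line seen in any dataset, mapped to stable integer indices.

```python
def align_datasets(gene_lists, cell_line_lists):
    """Toy illustration of aligning inputs and outputs across datasets.

    gene_lists: one list of gene names per dataset.
    cell_line_lists: one list of per-cell cell-line labels per dataset.
    """
    # Inner join on features: keep genes present in every dataset.
    # Keeping the order of the first dataset is an arbitrary choice here.
    shared = set(gene_lists[0]).intersection(*map(set, gene_lists[1:]))
    shared_genes = [g for g in gene_lists[0] if g in shared]

    # Output classes: every cell-line seen in any dataset, sorted so the
    # label -> index mapping is the same whichever dataset is loaded.
    classes = sorted(set().union(*map(set, cell_line_lists)))
    class_to_index = {name: i for i, name in enumerate(classes)}
    return shared_genes, class_to_index


# Hypothetical toy inputs (not taken from the actual plates).
plate_a_genes = ["TP53", "MYC", "EGFR", "KRAS"]
plate_b_genes = ["MYC", "KRAS", "TP53", "BRCA1"]
plate_a_labels = ["CVCL_0131", "CVCL_0480", "CVCL_0131"]
plate_b_labels = ["CVCL_0292", "CVCL_0131"]

shared_genes, class_to_index = align_datasets(
    [plate_a_genes, plate_b_genes], [plate_a_labels, plate_b_labels]
)
print(shared_genes)         # ['TP53', 'MYC', 'KRAS']
print(len(class_to_index))  # 3
```

With a fixed feature list and a fixed label-to-index mapping, a classifier trained on one dataset sees the same input width and the same output classes when it later continues on the other.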

In [2]:
DS_PATHS = ["/mnt/hdd2/tan/tahoe100m/plate1_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad",
           "/mnt/hdd2/tan/tahoe100m/plate2_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
adatas = [anndata.io.read_h5ad(p, backed = "r") for p in DS_PATHS]

In [3]:
# Create a view of all datasets
collection = AnnCollection(adatas, join_vars = "inner")

# Record the cell-lines (output classes) in both datasets
cell_lines = collection.obs.cell_line.unique().tolist()
cell_lines_count = collection.obs.cell_line.nunique()

## 3. Configure training step

In [4]:
thread_per_worker = 12
test_size = 0.0 # We don't have the test step in the model, so we can set this to 0
val_size = 0.2

## 4. Train on `plate1_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab` dataset

In [5]:
plate1_adata = adatas[0]

In [6]:
plate1_adata.obs.head(n = 5)

| BARCODE_SUB_LIB_ID | sample | gene_count | tscp_count | mread_count | drugname_drugconc | drug | cell_line | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | cell_name | plate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01_001_025-lib_841 | smp_1495 | 1676 | 2441 | 2892 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0131 | lib_841 | 01_001_025 | 0.025399 | -0.066667 | -0.095055 | G1 | full | A-172 | plate1 |
| 01_001_026-lib_841 | smp_1495 | 1657 | 2454 | 2925 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0480 | lib_841 | 01_001_026 | 0.042787 | 0.128571 | 0.650549 | G2M | full | PANC-1 | plate1 |
| 01_001_048-lib_841 | smp_1495 | 1749 | 2521 | 2963 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0293 | lib_841 | 01_001_048 | 0.056724 | 0.242857 | 0.308791 | G2M | full | HEC-1-A | plate1 |
| 01_001_076-lib_841 | smp_1495 | 834 | 1038 | 1258 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0397 | lib_841 | 01_001_076 | 0.066474 | 0.009524 | 0.245788 | G2M | full | LS 180 | plate1 |
| 01_001_088-lib_841 | smp_1495 | 1275 | 1710 | 2006 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_1097 | lib_841 | 01_001_088 | 0.028655 | -0.1 | -0.085348 | G1 | full | C32 | plate1 |


In [11]:
# Set up training
trainer = RayTrainRunner(
    LinearClassifier,
    DistributedCellLineAnnDataset,
    model_keys = ["num_genes",
                  "num_classes"],
    metadata_cb = cell_line_metadata_cb,
    sparse_key = "X"
)

2025-09-24 08:30:40,605	INFO worker.py:1951 -- Started a local Ray instance.
(pid=1092253)   import pynvml  # type: ignore[import]

(TrainTrainable pid=1092253) ✓ Applied AnnDataFileManager patch
(TrainTrainable pid=1092253) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=1092401)   import pynvml  # type: ignore[import]
(RayTrainWorker pid=1092401) Setting up process group for: env:// [rank=0, world_size=1]
(TorchTrainer pid=1092253) Started distributed worker processes: 
(TorchTrainer pid=1092253) - (node_id=ebdeee0faec90155601309a734c802e76b452887457840f4cbbf9649, ip=192.168.1.226, pid=1092401) world_rank=0, local_rank=0, node_rank=0

(RayTrainWorker pid=1092401) ✓ Applied AnnDataFileManager patch
(RayTrainWorker pid=1092401) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=1092401) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1092401) GPU available: True (cuda), used: True
(RayTrainWorker pid=1092401) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1092401) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1092401) /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.1 ...
(RayTrainWorker pid=1092401) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize 

(RayTrainWorker pid=1092401) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayTrainWorker pid=1092401)
(RayTrainWorker pid=1092401)   | Name    | Type             | Params | Mode 
(RayTrainWorker pid=1092401) -----------------------------------------------------
(RayTrainWorker pid=1092401) 0 | model   | Linear           | 3.1 M  | train
(RayTrainWorker pid=1092401) 1 | loss_fn | CrossEntropyLoss | 0      | train
(RayTrainWorker pid=1092401) -----------------------------------------------------
(RayTrainWorker pid=1092401) 3.1 M     Trainable params
(RayTrainWorker pid=1092401) 0         Non-trainable params
(RayTrainWorker pid=1092401) 3.1 M     Total params
(RayTrainWorker pid=1092401) 12.542    Total estimated model params size (MB)
(RayTrainWorker pid=1092401) 2         Modules in train mode
(RayTrainWorker pid=1092401) 0         Modules in eval mode
(RayTrainWork

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  6.53it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 12.06it/s]

(RayTrainWorker pid=1092401) /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4224 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4224 [00:15<18:07:27,  0.06it/s, v_num=0, train_loss=3.980]
Epoch 0:   0%|          | 4/4224 [00:15<4:38:12,  0.25it/s, v_num=0, train_loss=2.700] 
.
.
.
Epoch 0: 100%|█████████▉| 4223/4224 [01:50<00:00, 38.25it/s, v_num=0, train_loss=0.158]
Epoch 0: 100%|██████████| 4224/4224 [01:50<00:00, 38.26it/s, v_num=0, train_loss=0.065]

Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/960 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/960 [00:00<00:03, 251.17it/s]
Validation DataLoader 0:   0%|          | 2/960 [00:00<00:32, 29.05it/s]
.
.
.
Validation DataLoader 0: 100%|█████████▉| 957/960 [00:17<00:00, 53.84it/s]
Validation DataLoader 0: 100%|█████████▉| 958/960 [00:17<00:00, 53.86it/s]
Validation DataLoader 0: 100%|█████████▉| 959/960 [00:17<00:00, 53.87it/s]
Validation DataLoader 0: 100%|██████████|

(RayTrainWorker pid=1092401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-24_08-30-55/TorchTrainer_c9b6d_00000_0_2025-09-24_08-30-55/checkpoint_000000)
(RayTrainWorker pid=1092401) `Trainer.fit` stopped: `max_epochs=1` reached.


In [12]:
result = trainer.train([DS_PATHS[0]],
                       batch_size = 1024,
                       test_size = test_size, 
                       val_size = val_size,
                       num_workers = 1,
                       thread_per_worker = thread_per_worker)

Using 1 workers with {'CPU': 12} each


2025-09-24 08:30:55,062	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 10.40 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-24 08:30:55 (running for 00:00:00.13)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_08-30-36_061134_1073100/artifacts/2025-09-24_08-30-55/TorchTrainer_2025-09-24_08-30-55/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 08:33:56 (running for 00:03:01.42)
Using FIFO scheduling algorithm.
Logical resource usage: 13.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_08-30-36_061134_1073100/artifacts/2025-09-24_08-30-55/TorchTrainer_2025-09-24_08-30-55/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-09-24 08:33:57,070	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/nam/protoplast_results/TorchTrainer_2025-09-24_08-30-55' in 0.0064s.
2025-09-24 08:33:57,075	INFO tune.py:1041 -- Total run time: 182.01 seconds (181.99 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 08:33:57 (running for 00:03:01.99)
Using FIFO scheduling algorithm.
Logical resource usage: 13.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_08-30-36_061134_1073100/artifacts/2025-09-24_08-30-55/TorchTrainer_2025-09-24_08-30-55/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)




## 5. Train on `plate2_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab` dataset

We now have a checkpoint saved after training the classification model on the first dataset. To resume from it, we pass the path to the checkpoint file into `train()`.
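
Before resuming, it can help to confirm that the checkpoint file actually exists. The sketch below uses only the standard library; `resolve_checkpoint_file` is a hypothetical helper (not part of PROTOplast or Ray), and the default file name `checkpoint.ckpt` simply mirrors the path this notebook builds from `result.checkpoint.path`.

```python
import os
import tempfile


def resolve_checkpoint_file(checkpoint_dir, filename="checkpoint.ckpt"):
    """Return the path of the checkpoint file inside checkpoint_dir.

    Raises FileNotFoundError early instead of failing deep inside training.
    """
    ckpt_path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"No checkpoint file at {ckpt_path}")
    return ckpt_path


# Demonstration: a temporary directory stands in for a real checkpoint
# directory such as `result.checkpoint.path`.
with tempfile.TemporaryDirectory() as fake_checkpoint_dir:
    open(os.path.join(fake_checkpoint_dir, "checkpoint.ckpt"), "wb").close()
    found = resolve_checkpoint_file(fake_checkpoint_dir)
    print(os.path.basename(found))  # checkpoint.ckpt
```

In this notebook, the directory passed in would be `result.checkpoint.path` from the first training run.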

In [14]:
plate2_adata = adatas[1]

In [15]:
plate2_adata.obs.head(n = 5)

| BARCODE_SUB_LIB_ID | sample | gene_count | tscp_count | mread_count | drugname_drugconc | drug | cell_line | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | cell_name | plate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01_001_053-lib_1000 | smp_1591 | 2671 | 5629 | 6830 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_1119 | lib_1000 | 01_001_053 | 0.016522 | -0.265873 | -0.313553 | G1 | full | CFPAC-1 | plate2 |
| 01_001_082-lib_1000 | smp_1591 | 2148 | 3173 | 3826 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0292 | lib_1000 | 01_001_082 | 0.025843 | 0.400794 | 0.520879 | G2M | full | HCT15 | plate2 |
| 01_001_145-lib_1000 | smp_1591 | 683 | 886 | 1073 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_1098 | lib_1000 | 01_001_145 | 0.029345 | -0.019841 | -0.032967 | G1 | full | HepG2/C3A | plate2 |
| 01_001_175-lib_1000 | smp_1591 | 1845 | 2786 | 3368 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0131 | lib_1000 | 01_001_175 | 0.031587 | -0.123016 | -0.118498 | G1 | full | A-172 | plate2 |
| 01_001_181-lib_1000 | smp_1591 | 1228 | 1849 | 2226 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0399 | lib_1000 | 01_001_181 | 0.015143 | 0.02381 | -0.008791 | S | full | LoVo | plate2 |


In [16]:
# Set up training
trainer = RayTrainRunner(
    LinearClassifier,
    DistributedCellLineAnnDataset,
    model_keys = ["num_genes",
                  "num_classes"],
    metadata_cb = cell_line_metadata_cb,
    sparse_key = "X"
)

2025-09-24 08:34:02,725	INFO worker.py:1951 -- Started a local Ray instance.
(pid=1100011)   import pynvml  # type: ignore[import]

(TrainTrainable pid=1100011) ✓ Applied AnnDataFileManager patch
(TrainTrainable pid=1100011) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=1100199)   import pynvml  # type: ignore[import]
(RayTrainWorker pid=1100199) Setting up process group for: env:// [rank=0, world_size=1]
(TorchTrainer pid=1100011) Started distributed worker processes: 
(TorchTrainer pid=1100011) - (node_id=54aaa5046cdd2adea85166bb856363fac49613c2efcd5479680fcaae, ip=192.168.1.226, pid=1100199) world_rank=0, local_rank=0, node_rank=0

(RayTrainWorker pid=1100199) ✓ Applied AnnDataFileManager patch
(RayTrainWorker pid=1100199) ✓ Applied AnnDataFileManager patch

(RayTrainWorker pid=1100199) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1100199) GPU available: True (cuda), used: True
(RayTrainWorker pid=1100199) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1100199) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1100199) /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.1 ...
(RayTrainWorker pid=1100199) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize 

(RayTrainWorker pid=1100199) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayTrainWorker pid=1100199)
(RayTrainWorker pid=1100199)   | Name    | Type             | Params | Mode 
(RayTrainWorker pid=1100199) -----------------------------------------------------
(RayTrainWorker pid=1100199) 0 | model   | Linear           | 3.1 M  | train
(RayTrainWorker pid=1100199) 1 | loss_fn | CrossEntropyLoss | 0      | train
(RayTrainWorker pid=1100199) -----------------------------------------------------
(RayTrainWorker pid=1100199) 3.1 M     Trainable params
(RayTrainWorker pid=1100199) 0         Non-trainable params
(RayTrainWorker pid=1100199) 3.1 M     Total params
(RayTrainWorker pid=1100199) 12.542    Total estimated model params size (MB)
(RayTrainWorker pid=1100199) 2         Modules in train mode
(RayTrainWorker pid=1100199) 0         Modules in eval mode
(RayTrainWork

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  6.32it/s]


Epoch 0:   0%|          | 0/6144 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/6144 [00:23<40:24:33,  0.04it/s, v_num=0, train_loss=4.090]
Epoch 0:   0%|          | 5/6144 [00:24<8:14:43,  0.21it/s, v_num=0, train_loss=2.480] 
.
.
.
Epoch 0: 100%|█████████▉| 6134/6144 [02:49<00:00, 36.16it/s, v_num=0, train_loss=0.343] 
Epoch 0: 100%|█████████▉| 6142/6144 [02:49<00:00, 36.18it/s, v_num=0, train_loss=0.124] 
Epoch 0: 100%|██████████| 6144/6144 [02:49<00:00, 36.19it/s, v_num=0, train_loss=0.249]

Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1536 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/1536 [00:00<00:13, 111.01it/s]
Validation DataLoader 0:   0%|          | 2/1536 [00:00<00:18, 83.82it/s]
.
.
.
Validation DataLoader 0: 100%|█████████▉| 1534/1536 [00:33<00:00, 46.39it/s]
Validation DataLoader 0: 100%|█████████▉| 1535/1536 [00:33<00:00, 46.40it/s]
Validation DataLoader

(RayTrainWorker pid=1100199) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-24_08-34-24/TorchTrainer_466ae_00000_0_2025-09-24_08-34-24/checkpoint_000000)
(RayTrainWorker pid=1100199) `Trainer.fit` stopped: `max_epochs=1` reached.


In [17]:
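# Path to the checkpoint file written by the plate1 training run
ckpt_path = os.path.join(result.checkpoint.path, "checkpoint.ckpt")

# Resume from that checkpoint. The `ckpt_path` keyword is assumed here;
# check the signature of `RayTrainRunner.train()` in your PROTOplast version.
result = trainer.train([DS_PATHS[1]],
                       batch_size = 1024,
                       test_size = test_size, 
                       val_size = val_size,
                       num_workers = 1,
                       thread_per_worker = thread_per_worker,
                       resource_per_worker = {"GPU": 1, "CPU": thread_per_worker},
                       ckpt_path = ckpt_path)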

Using 1 workers with {'GPU': 1, 'CPU': 12} each


2025-09-24 08:34:24,282	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 17.42 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-24 08:34:24 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_08-33-59_087314_1073100/artifacts/2025-09-24_08-34-24/TorchTrainer_2025-09-24_08-34-24/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 08:34:39 (running for 00:00:15.27)
Using FIFO scheduling algorithm.
Logical resource usage: 13.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_08-33-59_087314_1073100/artifacts/2025-09-24_08-34-24/TorchTrainer_2025-09-24_08-34-24/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-09-24 08:38:56,148	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/nam/protoplast_results/TorchTrainer_2025-09-24_08-34-24' in 0.0116s.
2025-09-24 08:38:56,163	INFO tune.py:1041 -- Total run time: 271.88 seconds (271.84 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 08:38:56 (running for 00:04:31.86)
Using FIFO scheduling algorithm.
Logical resource usage: 13.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_08-33-59_087314_1073100/artifacts/2025-09-24_08-34-24/TorchTrainer_2025-09-24_08-34-24/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)




## 6. Conclusion

This brings us to the end of the tutorial notebook.

This workflow shows how checkpointing in **PROTOplast** lets a model trained on one dataset continue training on another, enabling efficient model development across diverse datasets.

Feel free to explore and extend this notebook to suit your own data and use cases!