# Showcasing PROTOplast Checkpointing in a Cell-line Classification Model

## 1. Introduction

This notebook showcases the checkpointing feature in PROTOplast, which lets you resume model training after finishing one dataset and switching to another. It demonstrates how to save and load training checkpoints so that model development can continue without starting from scratch. This is particularly useful for long training runs, for experimenting with several datasets, and for training across multiple sessions or environments.

In [1]:
import anndata
import glob
import numpy as np
import pandas as pd
import os
import pathlib
import protoplast as pt
import ray
import torch

from anndata.experimental import AnnCollection
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb, DistributedCellLineAnnDataset

from ray.train import Checkpoint
from ray.train.lightning import RayDDPStrategy

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


root - INFO - Logging initialized. Current level is: INFO


In [2]:
from importlib.metadata import version

print(version("protoplast"))

0.1.2


## 2. Dataset pre-processing

We begin by reading the two datasets used to train the cell-line classification model in this notebook. For a checkpoint trained on one dataset to be reused on the other, the model must see the same input and output dimensions in both.

In the following section, we create a unified view by performing an **inner join** on the two datasets based on shared features. During this step, we:

- Identify and record the **number of output classes** (cell lines),
- Extract the list of **cell lines** present across both datasets.

This alignment is essential to ensure the model receives a consistent input/output structure regardless of the dataset source.
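The effect of an inner join on features can be illustrated without loading any data. The sketch below uses plain Python lists with made-up gene names (hypothetical values, not taken from the Tahoe plates) and a hypothetical helper `inner_join_features`; it only shows the idea of keeping the features common to every dataset, and makes no claim about how `AnnCollection` orders or stores them internally.

```python
# Toy gene lists standing in for the feature index of each plate (hypothetical values).
plate1_genes = ["TP53", "MYC", "EGFR", "KRAS"]
plate2_genes = ["MYC", "KRAS", "BRCA1", "TP53"]


def inner_join_features(*gene_lists):
    """Keep only the genes present in every list, in the order of the first list."""
    shared = set(gene_lists[0]).intersection(*gene_lists[1:])
    return [gene for gene in gene_lists[0] if gene in shared]


shared_genes = inner_join_features(plate1_genes, plate2_genes)
print(shared_genes)       # ['TP53', 'MYC', 'KRAS']
print(len(shared_genes))  # 3 -> the input dimension both datasets agree on
```

Genes that appear in only one list (`EGFR`, `BRCA1`) are dropped, so a model sized for the joined view sees the same number of input features from either plate.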

In [3]:
DS_PATHS = ["/mnt/hdd2/tan/tahoe100m/plate1_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad",
           "/mnt/hdd2/tan/tahoe100m/plate2_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
adatas = [anndata.io.read_h5ad(p, backed = "r") for p in DS_PATHS]

In [4]:
# Create a joint view over all datasets, keeping only the variables shared by every dataset
collection = AnnCollection(adatas, join_vars = "inner")

# Record the cell-lines (output classes) in both datasets
cell_lines = collection.obs.cell_line.unique().tolist()
cell_lines_count = collection.obs.cell_line.nunique()
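One reason to collect the class list over both plates at once is that the label-to-index mapping then stays the same across training runs. The sketch below is an illustration only: `build_label_index` is a hypothetical helper, the label lists are small hand-picked subsets of the `cell_line` IDs that appear in the tables later in this notebook, and how PROTOplast encodes labels internally is not shown here.

```python
# Small hand-picked subsets of cell-line IDs (illustrative, not the full plates).
plate1_labels = ["CVCL_0131", "CVCL_0480", "CVCL_0293"]
plate2_labels = ["CVCL_1119", "CVCL_0131", "CVCL_0292"]


def build_label_index(*label_lists):
    """Map every label seen in any dataset to a stable integer index."""
    all_labels = sorted(set().union(*label_lists))
    return {label: idx for idx, label in enumerate(all_labels)}


label_to_index = build_label_index(plate1_labels, plate2_labels)
num_classes = len(label_to_index)

print(num_classes)                  # 5
print(label_to_index["CVCL_0131"])  # 0, regardless of which plate the cell came from
```

Building the mapping from a single plate instead would give a smaller `num_classes` and could assign different indices to the same cell line, which would make a checkpoint from one plate meaningless on the other.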

## 3. Configure training step

In [5]:
thread_per_worker = 12
test_size = 0.0 # We don't have the test step in the model, so we can set this to 0
val_size = 0.2
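The training logs later in this notebook report split lengths of 66 validation and 268 training units for plate 1, and 98 and 394 for plate 2. Those numbers are consistent with rounding the validation and test fractions down and assigning the remainder to training. The sketch below reproduces them under that assumption; `split_lengths` is a hypothetical helper, the totals (334 and 492) are simply the sums of the logged split lengths, and the actual splitting logic inside PROTOplast may differ.

```python
def split_lengths(total, val_size, test_size):
    """Assumed rule: round val/test down and give everything else to training."""
    n_val = int(total * val_size)
    n_test = int(total * test_size)
    n_train = total - n_val - n_test
    return n_train, n_val, n_test


# Totals inferred from the logged split lengths (66 + 268 and 98 + 394).
print(split_lengths(334, val_size=0.2, test_size=0.0))  # (268, 66, 0)
print(split_lengths(492, val_size=0.2, test_size=0.0))  # (394, 98, 0)
```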

## 4. Train on `plate1_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab` dataset

In [6]:
plate1_adata = adatas[0]

In [7]:
plate1_adata.obs.head(n = 5)

| BARCODE_SUB_LIB_ID | sample | gene_count | tscp_count | mread_count | drugname_drugconc | drug | cell_line | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | cell_name | plate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01_001_025-lib_841 | smp_1495 | 1676 | 2441 | 2892 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0131 | lib_841 | 01_001_025 | 0.025399 | -0.066667 | -0.095055 | G1 | full | A-172 | plate1 |
| 01_001_026-lib_841 | smp_1495 | 1657 | 2454 | 2925 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0480 | lib_841 | 01_001_026 | 0.042787 | 0.128571 | 0.650549 | G2M | full | PANC-1 | plate1 |
| 01_001_048-lib_841 | smp_1495 | 1749 | 2521 | 2963 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0293 | lib_841 | 01_001_048 | 0.056724 | 0.242857 | 0.308791 | G2M | full | HEC-1-A | plate1 |
| 01_001_076-lib_841 | smp_1495 | 834 | 1038 | 1258 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_0397 | lib_841 | 01_001_076 | 0.066474 | 0.009524 | 0.245788 | G2M | full | LS 180 | plate1 |
| 01_001_088-lib_841 | smp_1495 | 1275 | 1710 | 2006 | [('Infigratinib', 0.05, 'uM')] | Infigratinib | CVCL_1097 | lib_841 | 01_001_088 | 0.028655 | -0.1 | -0.085348 | G1 | full | C32 | plate1 |


In [8]:
# Set up training
trainer = RayTrainRunner(
    LinearClassifier,
    DistributedCellLineAnnDataset,
    model_keys = ["num_genes",
                  "num_classes"],
    metadata_cb = cell_line_metadata_cb,
    sparse_key = "X",
    runtime_env_config = {"working_dir": os.getcwd()},
)

2025-11-07 05:59:05,041	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-11-07 05:59:05,068	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-07 05:59:05,089	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_540b6eeb2de965aa.zip' (2.08MiB) to Ray cluster...
2025-11-07 05:59:05,107	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_540b6eeb2de965aa.zip'.


In [9]:
# Start training. The training logs are written to the output of the cell above,
# where we initialize the Ray train runner.
result = trainer.train([DS_PATHS[0]],
                       batch_size = 1024,
                       test_size = test_size, 
                       val_size = val_size,
                       num_workers = 1,
                       thread_per_worker = thread_per_worker,
                       result_storage_path = "~/training_results")

protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 12, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 66 length of test_split: 0, length of train_split: 268
protoplast.scrna.anndata.strategy - INFO - Dropping 4 mini-batches
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 66, length of test_split: 0, length of train_split: 268


(TrainController pid=2964641) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
(TrainController pid=2964641) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(TrainController pid=2964641) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=2964641) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 12, 'GPU': 1}] * 1


(RayTrainWorker pid=2965202) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
(RayTrainWorker pid=2965202) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=2965202) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=2965202) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=2964641) Started training worker group of size 1: 
(TrainController pid=2964641) - (ip=192.168.1.226, pid=2965202) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=2965202) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=2965202) GPU available: True (cuda), used: True
(RayTrainWorker pid=2965202) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=2965202) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=2965202) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


(RayTrainWorker pid=2965202) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_compressed_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 11.92it/s]



(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2965202)   return torch.sparse_csr_tensor(


Epoch 0:   0%|          | 0/4284 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4284 [00:13<16:22:35,  0.07it/s, v_num=0, train_loss=3.980]
Epoch 0:   0%|          | 2/4284 [00:13<8:12:29,  0.14it/s, v_num=0, train_loss=3.840]
...
...
...
Epoch 0: 100%|█████████▉| 4278/4284 [01:38<00:00, 43.47it/s, v_num=0, train_loss=0.137]
Epoch 0: 100%|█████████▉| 4279/4284 [01:38<00:00, 43.47it/s, v_num=0, train_loss=0.211]
Epoch 0: 100%|██████████| 4284/4284 [01:38<00:00, 43.50it/s, v_num=0, train_loss=0.145]
(RayTrainWorker pid=2965202)
Validation: |          | 0/? [00:00<?, ?it/s]
(RayTrainWorker pid=2965202)
Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1056 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/1056 [00:00<00:04, 220.40it/s]
Validation DataLoader 0:   0%|          | 2/1056 [00:00<00:10, 99.65it/s]
(RayTrainWorker pid=2965202)

(RayTrainWorker pid=2965202) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/training_results/ray_train_run-2025-11-07_05-59-18/checkpoint_2025-11-07_06-02-11.615409)
(RayTrainWorker pid=2965202) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/training_results/ray_train_run-2025-11-07_05-59-18/checkpoint_2025-11-07_06-02-11.615409), metrics={'train_loss': 0.1452408730983734, 'val_acc': 0.982025146484375, 'epoch': 0, 'step': 4284}, validation_spec=None)
(RayTrainWorker pid=2965202) `Trainer.fit` stopped: `max_epochs=1` reached.


In [10]:
ray.shutdown()

## 5. Train on `plate2_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab` dataset

We now have a checkpoint saved after training the classification model on the first dataset. To resume from it, pass the path of the checkpoint file to `train()` via `ckpt_path`. The path can be read from the result object returned by the previous `train()` call.
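Before launching a second training run, it can be worth confirming that the checkpoint file actually exists, since a wrong path would only surface once the workers have started. The sketch below uses only the standard library; `resolve_ckpt_path` is a hypothetical helper, and the file name `checkpoint.ckpt` is the one used in this notebook rather than a documented guarantee. The demo creates an empty placeholder file in a temporary directory instead of a real checkpoint.

```python
import os
import tempfile


def resolve_ckpt_path(checkpoint_dir, filename="checkpoint.ckpt"):
    """Return the full checkpoint file path, or raise if the file is missing."""
    ckpt_path = os.path.join(os.path.expanduser(checkpoint_dir), filename)
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"No checkpoint file at {ckpt_path}")
    return ckpt_path


# Demo with a temporary directory standing in for a real checkpoint directory.
with tempfile.TemporaryDirectory() as fake_checkpoint_dir:
    # Create an empty placeholder file; a real run would already have written it.
    open(os.path.join(fake_checkpoint_dir, "checkpoint.ckpt"), "wb").close()
    resolved = resolve_ckpt_path(fake_checkpoint_dir)
    print(os.path.basename(resolved))  # checkpoint.ckpt
```

In this notebook the directory would be `result.checkpoint.path` from the plate 1 run.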

In [11]:
plate2_adata = adatas[1]

In [12]:
plate2_adata.obs.head(n = 5)

| BARCODE_SUB_LIB_ID | sample | gene_count | tscp_count | mread_count | drugname_drugconc | drug | cell_line | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | cell_name | plate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01_001_053-lib_1000 | smp_1591 | 2671 | 5629 | 6830 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_1119 | lib_1000 | 01_001_053 | 0.016522 | -0.265873 | -0.313553 | G1 | full | CFPAC-1 | plate2 |
| 01_001_082-lib_1000 | smp_1591 | 2148 | 3173 | 3826 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0292 | lib_1000 | 01_001_082 | 0.025843 | 0.400794 | 0.520879 | G2M | full | HCT15 | plate2 |
| 01_001_145-lib_1000 | smp_1591 | 683 | 886 | 1073 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_1098 | lib_1000 | 01_001_145 | 0.029345 | -0.019841 | -0.032967 | G1 | full | HepG2/C3A | plate2 |
| 01_001_175-lib_1000 | smp_1591 | 1845 | 2786 | 3368 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0131 | lib_1000 | 01_001_175 | 0.031587 | -0.123016 | -0.118498 | G1 | full | A-172 | plate2 |
| 01_001_181-lib_1000 | smp_1591 | 1228 | 1849 | 2226 | [('Infigratinib', 0.5, 'uM')] | Infigratinib | CVCL_0399 | lib_1000 | 01_001_181 | 0.015143 | 0.02381 | -0.008791 | S | full | LoVo | plate2 |


In [13]:
# Set up training
trainer = RayTrainRunner(
    LinearClassifier,
    DistributedCellLineAnnDataset,
    model_keys = ["num_genes",
                  "num_classes"],
    metadata_cb = cell_line_metadata_cb,
    sparse_key = "X",
    runtime_env_config = {"working_dir": os.getcwd()},
)

2025-11-07 06:02:23,753	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-11-07 06:02:23,777	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-07 06:02:23,795	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_6ce7ea5a35b55e7c.zip' (2.58MiB) to Ray cluster...
2025-11-07 06:02:23,815	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_6ce7ea5a35b55e7c.zip'.


In [14]:
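# Get the checkpoint path from the result of the Tahoe plate 1 training run. The training
# logs are written to the output of cell 13 above, where the Ray train runner was initialized.
# max_epochs is 2 because the checkpoint already records one completed epoch (epoch 0),
# so this run resumes at epoch 1.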
ckpt_path = os.path.join(result.checkpoint.path, "checkpoint.ckpt")

trainer.train([DS_PATHS[1]],
              max_epochs = 2,
              batch_size = 1024,
              test_size = test_size, 
              val_size = val_size,
              num_workers = 1,
              thread_per_worker = thread_per_worker,
              ckpt_path = ckpt_path)

protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 12, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 98 length of test_split: 0, length of train_split: 394
protoplast.scrna.anndata.strategy - INFO - Dropping 8 mini-batches
protoplast.scrna.anndata.strategy - INFO - Dropping 4 mini-batches
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 98, length of test_split: 0, length of train_split: 394


(TrainController pid=2980637) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
(TrainController pid=2980637) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(TrainController pid=2980637) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=2980637) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 12, 'GPU': 1}] * 1


(RayTrainWorker pid=2980921) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
(RayTrainWorker pid=2980921) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=2980921) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=2980921) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=2980637) Started training worker group of size 1: 
(TrainController pid=2980637) - (ip=192.168.1.226, pid=2980921) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=2980921) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=2980921) GPU available: True (cuda), used: True
(RayTrainWorker pid=2980921) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=2980921) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=2980921)   return _C._get_float32_matmul_precision()
(RayTrainWorker pid=2980921) You are using a 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_compressed_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  7.03it/s]


(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2980921) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in di

                                                                           
Epoch 1:   0%|          | 0/6300 [00:00<?, ?it/s]
Epoch 1:   0%|          | 3/6300 [00:22<13:08:33,  0.13it/s, v_num=0, train_loss=0.101]
...
...
...
Epoch 1: 100%|█████████▉| 6293/6300 [03:17<00:00, 31.92it/s, v_num=0, train_loss=0.153]
Epoch 1: 100%|██████████| 6300/6300 [03:17<00:00, 31.94it/s, v_num=0, train_loss=0.108]
(RayTrainWorker pid=2980921)
Validation: |          | 0/? [00:00<?, ?it/s]
(RayTrainWorker pid=2980921)
Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1560 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/1560 [00:00<00:37, 41.08it/s]
(RayTrainWorker pid=2980921)
Validation DataLoader 0:   0%|          | 2/1560 [00:00<00:38, 40.40it/s]
...
...
Validation DataLoader 0: 100%|█████████▉| 1559/1560 [00:30<00:00, 50.56it/s]
Validatio

(RayTrainWorker pid=2980921) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-02-42/checkpoint_2025-11-07_06-07-35.887101)
(RayTrainWorker pid=2980921) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-02-42/checkpoint_2025-11-07_06-07-35.887101), metrics={'train_loss': 0.10751987993717194, 'val_acc': 0.981541097164154, 'epoch': 1, 'step': 10584}, validation_spec=None)
(RayTrainWorker pid=2980921) `Trainer.fit` stopped: `max_epochs=2` reached.


Result(metrics={'train_loss': 0.10751987993717194, 'val_acc': 0.981541097164154, 'epoch': 1, 'step': 10584}, checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-02-42/checkpoint_2025-11-07_06-07-35.887101), error=None, path='/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-02-42', metrics_dataframe=   train_loss   val_acc  epoch   step
0     0.10752  0.981541      1  10584, best_checkpoints=[(Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-02-42/checkpoint_2025-11-07_06-07-35.887101), {'train_loss': 0.10751987993717194, 'val_acc': 0.981541097164154, 'epoch': 1, 'step': 10584})], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7ac7398edcf0>)

In [15]:
ray.shutdown()
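The logs give a simple way to confirm that the second run really continued from the checkpoint instead of starting over: the final `step` value (10584) should equal the steps of the plate 1 epoch plus the steps of the plate 2 epoch. The sketch below reproduces that arithmetic. The factor of 16 mini-batches per split unit is inferred from the logged numbers (268 units and 4 dropped batches give 4284 steps), not taken from PROTOplast documentation, and `steps_per_epoch` is a hypothetical helper.

```python
# Inferred from the logs in this notebook; not a documented PROTOplast constant.
BATCHES_PER_SPLIT_UNIT = 16


def steps_per_epoch(train_split_len, dropped_batches, per_unit=BATCHES_PER_SPLIT_UNIT):
    """Training steps in one epoch, after dropping the remainder mini-batches."""
    return train_split_len * per_unit - dropped_batches


plate1_steps = steps_per_epoch(268, dropped_batches=4)  # epoch 0 on plate 1
plate2_steps = steps_per_epoch(394, dropped_batches=4)  # epoch 1 on plate 2

print(plate1_steps)                 # 4284, matches the "Epoch 0" progress bar
print(plate2_steps)                 # 6300, matches the "Epoch 1" progress bar
print(plate1_steps + plate2_steps)  # 10584, matches the final reported step
```

A run that had ignored the checkpoint would have reported `epoch: 0` and a step count of 6300 for its first epoch, rather than `epoch: 1` and 10584.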

## 6. Conclusion

This brings us to the end of the tutorial notebook.

This workflow highlights using checkpointing in **PROTOplast**, enabling efficient model development across diverse datasets.

Feel free to explore and extend this notebook to suit your own data and use cases!