# Simple Classifier Model for Single-Cell Data with PROTOplast

This tutorial demonstrates how to use PROTOplast to train a simple PyTorch classification model on single-cell data stored in the `h5ad` format.

**Setup**  
- Configure the training environment for single-cell RNA sequencing (scRNA-seq) data using **PROTOplast** in combination with **PyTorch Lightning** and **Ray**.

In [1]:
%%time
import ray
import protoplast
import glob
import os
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.torch_dataloader import DistributedCellLineAnnDataset as Dcl
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


root - INFO - Logging initialized. Current level is: INFO


CPU times: user 18.7 s, sys: 1.29 s, total: 20 s
Wall time: 7.65 s


In [2]:
from importlib.metadata import version

print(version("protoplast"))

0.1.2


## Simple Classifier

This example illustrates how to configure a training runner with **PROTOplast** and **Ray**.

- `LinearClassifier`: a simple baseline model that can be swapped with a custom implementation
- `Dcl`: the dataset object for training, imported from `protoplast.scrna.anndata.torch_dataloader`
  - Defined as a subclass of `DistributedAnnDataset`, customized for cell line classification tasks
- `["num_genes", "num_classes"]`: arguments that specify the model's input and output dimensions
- `cell_line_metadata_cb`: a callback function that attaches dataset-specific metadata, such as cell line labels and class counts
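The list of names passed to the runner ties the metadata callback to the model's constructor. The sketch below shows one plausible way that wiring could work, assuming the callback fills a plain `dict` and the runner forwards each listed key to the model as a keyword argument. `toy_metadata_cb`, `ToyClassifier`, and `build_model` are hypothetical stand-ins, not PROTOplast's API, and the gene and cell line values are toy data.

```python
# Hypothetical sketch of how listed names could connect a metadata callback
# to a model constructor. This is NOT PROTOplast's actual implementation.


def toy_metadata_cb(gene_names, cell_line_labels, metadata):
    """Fill `metadata` the way a callback such as cell_line_metadata_cb might."""
    classes = sorted(set(cell_line_labels))
    metadata["num_genes"] = len(gene_names)
    metadata["num_classes"] = len(classes)
    metadata["cell_line"] = classes  # hypothetical extra key for the dataset


class ToyClassifier:
    """Stand-in for a model class such as LinearClassifier."""

    def __init__(self, num_genes, num_classes):
        self.num_genes = num_genes
        self.num_classes = num_classes


def build_model(model_cls, model_keys, metadata):
    # Forward only the keys the model asks for, as keyword arguments.
    return model_cls(**{key: metadata[key] for key in model_keys})


metadata = {}
toy_metadata_cb(["GENE_A", "GENE_B", "GENE_C"], ["A549", "HepG2", "A549"], metadata)
model = build_model(ToyClassifier, ["num_genes", "num_classes"], metadata)
print(model.num_genes, model.num_classes)  # 3 2
```

Under this assumption, swapping in a custom model only requires that every name in the list is both produced by the callback and accepted by the model's `__init__`.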

In [3]:
%%time
trainer = RayTrainRunner(
    LinearClassifier,  # replace with your own model
    Dcl,  # replace with your own Dataset
    ["num_genes", "num_classes"],  # change according to what you need for your model
    cell_line_metadata_cb,  # include data you need for your dataset
    runtime_env_config={"working_dir": os.getcwd()},
)

2025-11-07 06:08:44,275	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 
2025-11-07 06:08:44,357	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-07 06:08:44,372	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_beaf74cebfeabf4e.zip' (3.25MiB) to Ray cluster...
2025-11-07 06:08:44,390	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_beaf74cebfeabf4e.zip'.


CPU times: user 263 ms, sys: 352 ms, total: 615 ms
Wall time: 10.3 s
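As the log above shows, `runtime_env_config={"working_dir": ...}` makes Ray package the notebook directory and push it to the cluster (3.25 MiB here), so any large `.h5ad` data or checkpoint files stored in that directory would be uploaded as well. Ray's `runtime_env` accepts an `excludes` list of gitignore-style patterns for this case; whether `RayTrainRunner` forwards that key unchanged to Ray is an assumption. The helper `package_size_mib` is hypothetical and matches patterns with `fnmatch` on file names only, as a rough approximation of Ray's matching, to estimate the upload size before creating the runner.

```python
import fnmatch
import os


def package_size_mib(working_dir, excludes=()):
    """Roughly estimate how many MiB Ray would package from working_dir.

    fnmatch on bare file names is only an approximation of Ray's
    gitignore-style `excludes` matching.
    """
    total_bytes = 0
    for root, _dirs, files in os.walk(working_dir):
        for name in files:
            if any(fnmatch.fnmatch(name, pattern) for pattern in excludes):
                continue  # skipped by one of the exclude patterns
            total_bytes += os.path.getsize(os.path.join(root, name))
    return total_bytes / 2**20


# Hypothetical config: keep data files and checkpoints out of the upload.
excludes = ["*.h5ad", "*.ckpt"]
runtime_env_config = {"working_dir": os.getcwd(), "excludes": excludes}
```

For example, `print(f"{package_size_mib(os.getcwd(), excludes):.2f} MiB")` run before constructing `RayTrainRunner` gives a quick sense of what will be shipped.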




On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `train()` completed in approximately **7 minutes**.

- `file_paths`: Plate 12 of Tahoe-100M (the largest plate file, about 35 GB) is used as a demo. To train on more plates, append their `.h5ad` file paths to the list.
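Instead of typing each path by hand, the plate files can be collected with `glob` (already imported in the first cell). This sketch assumes the other plates sit in the same directory as the plate 12 file and follow its `plate<N>_*.h5ad` naming; `find_plates` and `plate_number` are hypothetical helpers. Sorting by the extracted plate number keeps `plate2` ahead of `plate12`, which a plain string sort would not.

```python
import glob
import os
import re


def plate_number(path):
    """Extract N from a file name like 'plate12_....h5ad' (inf if absent)."""
    match = re.search(r"plate(\d+)", os.path.basename(path))
    return int(match.group(1)) if match else float("inf")


def find_plates(data_dir, pattern="plate*.h5ad"):
    """Return .h5ad paths in data_dir, ordered numerically by plate."""
    return sorted(glob.glob(os.path.join(data_dir, pattern)), key=plate_number)
```

Usage under the same assumption: `file_paths = find_plates("/mnt/hdd2/tan/tahoe100m")[:2]` would select the first two plates before calling `trainer.train(file_paths)`.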

In [4]:
%%time
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
trainer.train(file_paths)

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


(TrainController pid=2998691) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
(TrainController pid=2998691) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(TrainController pid=2998691) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=2998691) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1


(RayTrainWorker pid=2998964) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
(RayTrainWorker pid=2998964) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=2998964) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=2998964) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=2998691) Started training worker group of size 1: 
(TrainController pid=2998691) - (ip=192.168.1.226, pid=2998964) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=2998964) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=2998964) GPU available: True (cuda), used: True
(RayTrainWorker pid=2998964) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=2998964) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=2998964) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=2998964)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
[36m(RayTrainWorker pid=2998964)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=2998964)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=2998964)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 10.13it/s]
                                                                           


(RayTrainWorker pid=2998964)   return torch.sparse_csr_tensor(
(RayTrainWorker pid=2998964) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]


(RayTrainWorker pid=2998964)   return torch.sparse_csr_tensor(


Epoch 0:   0%|          | 3/4192 [00:24<9:22:42,  0.12it/s, v_num=0, train_loss=2.920] 
Epoch 0:   0%|          | 4/4192 [00:24<7:02:19,  0.17it/s, v_num=0, train_loss=2.420]
Epoch 0:   0%|          | 9/4192 [00:24<3:08:20,  0.37it/s, v_num=0, train_loss=1.230]
...
...
Epoch 0: 100%|█████████▉| 4186/4192 [04:35<00:00, 15.19it/s, v_num=0, train_loss=0.0745]
Epoch 0: 100%|█████████▉| 4191/4192 [04:35<00:00, 15.20it/s, v_num=0, train_loss=0.121] 
Epoch 0: 100%|██████████| 4192/4192 [04:35<00:00, 15.20it/s, v_num=0, train_loss=0.157]
(RayTrainWorker pid=2998964) 
Validation: |          | 0/? [00:00<?, ?it/s]
(RayTrainWorker pid=2998964) 
Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1040 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/1040 [00:00<00:25, 40.46it/s]
(RayTrainWorker pid=2998964) 
Validation DataLoader 0:   0%|          | 2

(RayTrainWorker pid=2998964) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-09-09/checkpoint_2025-11-07_06-16-00.959916)
(RayTrainWorker pid=2998964) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-09-09/checkpoint_2025-11-07_06-16-00.959916), metrics={'train_loss': 0.15666209161281586, 'val_acc': 0.9862812757492065, 'epoch': 0, 'step': 4192}, validation_spec=None)
(RayTrainWorker pid=2998964) `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4192/4192 [06:03<00:00, 11.52it/s, v_num=0, train_loss=0.157]
Epoch 0: 100%|██████████| 4192/4192 [06:04<00:00, 11.52it/s, v_num=0, train_loss=0.157]
CPU times: user 21.4 s, sys: 3.69 s, total: 25 s
Wall time: 7min 15s


Result(metrics={'train_loss': 0.15666209161281586, 'val_acc': 0.9862812757492065, 'epoch': 0, 'step': 4192}, checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-09-09/checkpoint_2025-11-07_06-16-00.959916), error=None, path='/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-09-09', metrics_dataframe=   train_loss   val_acc  epoch  step
0    0.156662  0.986281      0  4192, best_checkpoints=[(Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-07_06-09-09/checkpoint_2025-11-07_06-16-00.959916), {'train_loss': 0.15666209161281586, 'val_acc': 0.9862812757492065, 'epoch': 0, 'step': 4192})], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7ccac5760570>)

Other settings that control batching and data splits (defaults in parentheses):

- `batch_size`: number of samples per training batch (`2000`)
- `test_size`: fraction of the data reserved for testing (`0.0`)
- `val_size`: fraction of the data reserved for validation (`0.2`)
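The split sizes in the training log can be checked against the `val_size` and `test_size` defaults. This is a minimal sketch, assuming the counts are truncated to integers and training receives the remainder; PROTOplast's exact rounding rule is not shown in this notebook. With the 327 splits logged for plate 12, it reproduces the logged 262/65/0. The log also implies 16 batches per split in this run (262 * 16 = 4192 training steps, 65 * 16 = 1040 validation steps); that figure is read off this run's progress bars, not a documented constant.

```python
def split_counts(n_splits, val_size=0.2, test_size=0.0):
    """Return (train, val, test) counts, assuming truncation to int."""
    n_val = int(n_splits * val_size)
    n_test = int(n_splits * test_size)
    n_train = n_splits - n_val - n_test
    return n_train, n_val, n_test


n_train, n_val, n_test = split_counts(262 + 65)  # 327 splits logged for plate 12
print(n_train, n_val, n_test)  # 262 65 0

# Batches per split inferred from this run's progress bars (4192 train steps).
batches_per_split = 4192 // n_train
print(batches_per_split, n_val * batches_per_split)  # 16 1040
```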