# Simple Classifier Model for Single-Cell Data with PROTOplast

This tutorial demonstrates how to use PROTOplast to train a simple PyTorch classification model on single-cell data stored in the `h5ad` format.

**Setup**  
- Configure the training environment for single-cell RNA sequencing (scRNA-seq) data using **PROTOplast** in combination with **PyTorch Lightning** and **Ray**.

In [1]:
%%time
import ray
import protoplast
import glob
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.torch_dataloader import DistributedCellLineAnnDataset as Dcl
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
CPU times: user 18.7 s, sys: 1.91 s, total: 20.6 s
Wall time: 18.9 s


## Simple Classifier

This example illustrates how to configure a training runner with **PROTOplast** and **Ray**.

- `LinearClassifier`: a simple baseline model that can be swapped with a custom implementation
- `Dcl`: the dataset object for training, imported from `protoplast.scrna.anndata.torch_dataloader`
  - Defined as a subclass of `DistributedAnnDataset`, customized for cell line classification tasks
- `["num_genes", "num_classes"]`: arguments that specify the model’s input and output dimensions
- `cell_line_metadata_cb`: a callback function that attaches dataset-specific metadata, such as cell line labels and class counts
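
To make the roles of the last two arguments concrete, the sketch below mimics the flow in plain Python, without PROTOplast. It assumes, for illustration only, that a metadata callback receives a dataset-like object together with a mutable metadata dictionary, and that the listed keys are then looked up in that dictionary and passed to the model constructor. `toy_metadata_cb`, `select_model_kwargs`, and the stand-in dataset are hypothetical names, not part of the PROTOplast API; the actual callback signature should be checked in the library source.

```python
from types import SimpleNamespace


def toy_metadata_cb(dataset, metadata):
    """Hypothetical callback: record label names and model dimensions."""
    labels = sorted(set(dataset.cell_lines))
    metadata["cell_lines"] = labels
    metadata["num_genes"] = len(dataset.gene_names)
    metadata["num_classes"] = len(labels)


def select_model_kwargs(metadata, keys):
    """Pick only the metadata entries the model constructor needs."""
    return {key: metadata[key] for key in keys}


# Stand-in for an AnnData-like object (assumption: 3 genes, 4 cells, 2 cell lines).
toy_dataset = SimpleNamespace(
    gene_names=["GENE_A", "GENE_B", "GENE_C"],
    cell_lines=["HeLa", "A549", "HeLa", "A549"],
)

metadata = {}
toy_metadata_cb(toy_dataset, metadata)
model_kwargs = select_model_kwargs(metadata, ["num_genes", "num_classes"])
print(model_kwargs)  # {'num_genes': 3, 'num_classes': 2}
```

Under this assumption, a custom model only needs a constructor whose parameter names match the listed keys, and a custom callback only needs to fill those keys in.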

In [2]:
%%time
trainer = RayTrainRunner(
    LinearClassifier,  # replace with your own model
    Dcl,  # replace with your own Dataset
    ["num_genes", "num_classes"],  # change according to what you need for your model
    cell_line_metadata_cb,  # include data you need for your dataset
)

2025-09-29 12:33:20,673	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 12:33:23,058	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 12:33:23,448	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_5a587d83feb1136c.zip' (69.28MiB) to Ray cluster...
2025-09-29 12:33:23,931	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_5a587d83feb1136c.zip'.


CPU times: user 918 ms, sys: 833 ms, total: 1.75 s
Wall time: 11.9 s


[33m(raylet)[0m Using CPython [36m3.11.13[39m
[33m(raylet)[0m Creating virtual environment at: [36m.venv[39m
[33m(raylet)[0m [2mInstalled [1m296 packages[0m [2min 407ms[0m[0m


(TrainTrainable pid=3232244) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3233388)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=3232244)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=3232244)[0m - (node_id=6090e6db35961bd68520f85638e41d21a6a20d06c51b719a5a02951f, ip=192.168.1.226, pid=3233388) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=3233388)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=3233388)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3233388)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=3233388)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=3233388)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=3233388)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=3233388)[0m /tmp/ray/session_2025-09-29_12-33-16_154841_3224373/runtime_resources/working_dir_files/_ray_pkg_5a587d83feb1136c/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pyth ...
[36m(RayTrainWorker pid=3233388)[0m You are using 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3233388)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3233388)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=3233388)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  6.50it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 11.51it/s]
                                                                           


[36m(RayTrainWorker pid=3233388)[0m /tmp/ray/session_2025-09-29_12-33-16_154841_3224373/runtime_resources/working_dir_files/_ray_pkg_5a587d83feb1136c/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 


[36m(RayTrainWorker pid=3233388)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3233388)[0m   return torch.sparse_csr_tensor(


Epoch 0:   0%|          | 4/4160 [00:23<6:42:21,  0.17it/s, v_num=0, train_loss=2.430] 
Epoch 0:   0%|          | 9/4160 [00:23<2:59:26,  0.39it/s, v_num=0, train_loss=1.450]
Epoch 0:   0%|          | 14/4160 [00:23<1:55:46,  0.60it/s, v_num=0, train_loss=0.564]
.
.
.
Epoch 0: 100%|█████████▉| 4143/4160 [04:44<00:01, 14.59it/s, v_num=0, train_loss=0.112] 
Epoch 0: 100%|█████████▉| 4148/4160 [04:44<00:00, 14.60it/s, v_num=0, train_loss=0.154]
Epoch 0: 100%|█████████▉| 4149/4160 [04:44<00:00, 14.60it/s, v_num=0, train_loss=0.117]
Epoch 0: 100%|█████████▉| 4154/4160 [04:44<00:00, 14.61it/s, v_num=0, train_loss=0.0889]
Epoch 0: 100%|█████████▉| 4155/4160 [04:44<00:00, 14.61it/s, v_num=0, train_loss=0.048] 
Epoch 0: 100%|██████████| 4160/4160 [04:44<00:00, 14.63it/s, v_num=0, train_loss=0.130] 
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/1024 [00:00<?, ?it/s]
Validation DataL

[36m(RayTrainWorker pid=3233388)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_12-33-50/TorchTrainer_8d4a2_00000_0_2025-09-29_12-33-50/checkpoint_000000)
[36m(RayTrainWorker pid=3233388)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4160/4160 [06:10<00:00, 11.23it/s, v_num=0, train_loss=0.130]




On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, the `train()` call in this run completed in approximately **7.5 minutes** (reported wall time: 7 min 33 s).

- `file_paths`: a list of `.h5ad` file paths. Plate 12 from Tahoe-100M (the largest file, 35 GB) is used here as a demo. To train on more plates, append their `.h5ad` file paths to the list.
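
When several plates sit in one directory, the list can be built with the standard-library `glob` module instead of typing each path by hand. The following is a minimal sketch that assumes a hypothetical directory layout; `collect_h5ad_paths` is an illustrative helper, not a PROTOplast function, and the directory name is a placeholder.

```python
import glob
import os


def collect_h5ad_paths(data_dir):
    """Return all .h5ad files directly under data_dir, sorted for reproducibility."""
    pattern = os.path.join(data_dir, "*.h5ad")
    return sorted(glob.glob(pattern))


# Hypothetical location; replace with the directory that holds your plates.
file_paths = collect_h5ad_paths("/path/to/tahoe100m")
print(f"Found {len(file_paths)} file(s)")
```

Sorting the result keeps the file order stable between runs, which makes it easier to compare experiments.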

In [3]:
%%time
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
trainer.train(file_paths)

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each
Data splitting time: 22.06 seconds
Spawning Ray worker and initiating distributed training


2025-09-29 12:33:50,287	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-09-29 12:33:50 (running for 00:00:00.16)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_12-33-16_154841_3224373/artifacts/2025-09-29_12-33-50/TorchTrainer_2025-09-29_12-33-50/driver_artifacts
Number of trials: 1/1 (1 PENDING)


.
.
.
== Status ==
Current time: 2025-09-29 12:40:58 (running for 00:07:07.94)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_12-33-16_154841_3224373/artifacts/2025-09-29_12-33-50/TorchTrainer_2025-09-29_12-33-50/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-09-29 12:41:01,331	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_12-33-50' in 0.0099s.
2025-09-29 12:41:01,335	INFO tune.py:1041 -- Total run time: 431.05 seconds (430.65 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 12:41:01 (running for 00:07:10.66)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_12-33-16_154841_3224373/artifacts/2025-09-29_12-33-50/TorchTrainer_2025-09-29_12-33-50/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 24.9 s, sys: 3.91 s, total: 28.8 s
Wall time: 7min 33s


Result(
  metrics={'train_loss': 0.13011327385902405, 'val_acc': 0.9862006902694702, 'epoch': 0, 'step': 4160},
  path='/home/dtran/protoplast_results/TorchTrainer_2025-09-29_12-33-50/TorchTrainer_8d4a2_00000_0_2025-09-29_12-33-50',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_12-33-50/TorchTrainer_8d4a2_00000_0_2025-09-29_12-33-50/checkpoint_000000)
)

- `batch_size`: number of samples per training batch. The default value is `2000`
- `test_size`: fraction of data reserved for testing. The default value is `0.0`
- `val_size`: fraction of data reserved for validation. The default value is `0.2`
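
As a rough illustration of how these three settings interact, the sketch below computes how many cells and batches end up in each split. It is plain arithmetic under the assumption that the fractions are applied to the total cell count; the actual splitting in PROTOplast may work differently (for example, on chunks within each file). `split_sizes` is a hypothetical helper and the cell count is made up.

```python
import math


def split_sizes(n_cells, batch_size=2000, test_size=0.0, val_size=0.2):
    """Return {split: (n_cells, n_batches)} for the train/val/test fractions."""
    if not 0.0 <= test_size + val_size < 1.0:
        raise ValueError("test_size + val_size must be in [0, 1)")
    n_test = int(n_cells * test_size)
    n_val = int(n_cells * val_size)
    n_train = n_cells - n_val - n_test  # everything not held out is used for training
    return {
        name: (count, math.ceil(count / batch_size))
        for name, count in [("train", n_train), ("val", n_val), ("test", n_test)]
    }


# Made-up dataset of 1,000,000 cells with the default settings.
sizes = split_sizes(1_000_000)
print(sizes)
# {'train': (800000, 400), 'val': (200000, 100), 'test': (0, 0)}
```

In the run logged above, the progress bars report 4160 training batches and 1024 validation batches, which is close to the default 80/20 split.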