# Classification Models for Single-Cell Data with PROTOplast

This tutorial shows how to use PROTOplast to train models in PyTorch on single-cell data stored in the `h5ad` format: a simple linear classifier, the SIMS classifier, and an autoencoder.

**Download the Tahoe-100M `h5ad` files**
- The Tahoe-100M dataset can be downloaded in `h5ad` format from the **Arc Institute Google Cloud Storage**.
- For step-by-step instructions, see the [official tutorial](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md).

**Setup**  
- Configure the training environment for single-cell RNA sequencing (scRNA-seq) data using **PROTOplast** in combination with **PyTorch Lightning** and **Ray**.

In [1]:
%%time
import anndata
import numpy as np
import ray

# models
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.torch_dataloader import DistributedCellLineAnnDataset as Dcl
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner
from ray.train.lightning import RayDDPStrategy
from scsims.model import SIMSClassifier

# scvi training plan
## install scvi-tools if needed:
## uv add scvi-tools
from scvi.module import Classifier
from scvi.train import ClassifierTrainingPlan

✓ Applied AnnDataFileManager patch
CPU times: user 18.5 s, sys: 1.25 s, total: 19.7 s
Wall time: 10.6 s


## 1. Load the Tahoe-100M Dataset (`h5ad`)
- `file_paths`: list of `.h5ad` files to train on. Plate 12 of Tahoe-100M (the largest file, 35 GB) is used as a demo. To train on more plates, append their `.h5ad` file paths to this list
- `thread_per_worker`: number of threads (CPUs) allocated to each Ray worker. The default value is `1`
- `batch_size`: number of samples (cells) per training batch
- `test_size`: fraction of data reserved for testing (use `0.0` if no test set is needed)
- `val_size`: fraction of data reserved for validation


In [2]:
%%time
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
thread_per_worker = 2
batch_size = 2000
test_size = 0.0
val_size = 0.2

CPU times: user 17 μs, sys: 2 μs, total: 19 μs
Wall time: 34.3 μs
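
`batch_size`, `test_size`, and `val_size` together determine how many training steps one epoch takes. As a back-of-the-envelope check, assuming the split is applied per cell (PROTOplast may partition the data differently internally), an epoch over N cells has about ceil(N × (1 - test_size - val_size) / batch_size) training batches. `estimate_train_batches` is a hypothetical helper; compare its output with the per-epoch batch count shown in the training progress bars below.

```python
import math


def estimate_train_batches(
    n_cells: int, batch_size: int, test_size: float = 0.0, val_size: float = 0.2
) -> int:
    """Rough number of training batches per epoch, assuming a per-cell split."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if not (0.0 <= test_size < 1.0 and 0.0 <= val_size < 1.0 and test_size + val_size < 1.0):
        raise ValueError("test_size and val_size must be in [0, 1) and sum to less than 1")
    n_train = round(n_cells * (1.0 - test_size - val_size))
    # The last batch may hold fewer than batch_size cells, hence the ceiling
    return math.ceil(n_train / batch_size)
```

With the settings above, `estimate_train_batches(1_000_000, 2000, 0.0, 0.2)` returns `400` for a hypothetical one-million-cell file.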


## 2. Simple Classifier

This example illustrates how to configure a training runner with **PROTOplast** and **Ray**.

- `LinearClassifier`: a simple baseline model that can be swapped with a custom implementation
- `Dcl`: the dataset object for training, imported from `protoplast.scrna.anndata.torch_dataloader`
  - Defined as a subclass of `DistributedAnnDataset`, customized for cell line classification tasks
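- `["num_genes", "num_classes"]`: the metadata keys whose values are passed to the model constructor; here they set the model's input dimension (number of genes) and output dimension (number of classes)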
- `cell_line_metadata_cb`: a callback function that attaches dataset-specific metadata, such as cell line labels and class counts
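
The internals of `RayTrainRunner` are not shown in this notebook. Judging from the comment `# maps to SIMSClassifier(input_dim, output_dim)` used later, a reasonable mental model is: the callback fills a metadata dictionary from the AnnData object, and the listed keys are then looked up and passed to the model class. The sketch below models that flow with a hypothetical `build_model_from_metadata` helper. It is an assumption for illustration, not PROTOplast's implementation (for example, the real runner may pass the values as keyword arguments).

```python
from typing import Any, Callable, Sequence


def build_model_from_metadata(
    model_cls: Callable[..., Any],
    arg_keys: Sequence[str],
    metadata_cb: Callable[[Any, dict], None],
    adata: Any,
) -> Any:
    """Hypothetical stand-in for the runner's model-construction step.

    1. The callback fills an empty metadata dict from the AnnData object.
    2. The listed keys are looked up and passed, in order, to the model class.
    """
    metadata: dict[str, Any] = {}
    metadata_cb(adata, metadata)
    missing = [key for key in arg_keys if key not in metadata]
    if missing:
        # Fail early with a clear message instead of a bare KeyError later
        raise KeyError(f"metadata_cb did not set: {missing}")
    return model_cls(*(metadata[key] for key in arg_keys))
```

Under this model, a callback that sets `num_genes` and `num_classes` together with the key list `["num_genes", "num_classes"]` would end up calling `LinearClassifier(num_genes, num_classes)`, which is why the key list must match what the model's constructor expects.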

In [3]:
%%time
LinearClassifier_trainer = RayTrainRunner(
    LinearClassifier,  # replace with your own model
    Dcl,  # replace with your own Dataset
    ["num_genes", "num_classes"],  # change according to what you need for your model
    cell_line_metadata_cb,  # include data you need for your dataset
)

2025-09-24 10:35:39,991	INFO worker.py:1951 -- Started a local Ray instance.


CPU times: user 216 ms, sys: 286 ms, total: 502 ms
Wall time: 3.83 s
[36m(TrainTrainable pid=1150735)[0m ✓ Applied AnnDataFileManager patch
[36m(TrainTrainable pid=1150735)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1150881)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=1150735)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=1150735)[0m - (node_id=c3fc7992c9dd18093e236be96e846423f856ed9a44226d1a5590ece9, ip=192.168.1.226, pid=1150881) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=1150881)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=1150881)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1150881)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1150881)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1150881)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1150881)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1150881)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pytho ...
[36m(RayTrainWorker pid=1150881)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1150881)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1150881)[0m   return torch.sparse_compressed_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  4.25it/s]
                                                                           


[36m(RayTrainWorker pid=1150881)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1150881)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 1/4192 [00:21<25:23:43,  0.05it/s, v_num=0, train_loss=4.230]
Epoch 0:   0%|          | 4/4192 [00:22<6:24:22,  0.18it/s, v_num=0, train_loss=2.540] 
Epoch 0:   0%|          | 8/4192 [00:22<3:12:53,  0.36it/s, v_num=0, train_loss=1.310]
Epoch 0:   0%|          | 13/4192 [00:22<1:59:12,  0.58it/s, v_num=0, train_loss=0.844]
Epoch 0:   0%|          | 18/4192 [00:22<1:26:27,  0.80it/s, v_num=0, train_loss=0.430]
Epoch 0:   1%|          | 23/4192 [00:22<1:07:56,  1.02it/s, v_num=0, train_loss=0.611]
Epoch 0:   1%|          | 28/4192 [00:22<56:00,  1.24it/s, v_num=0, train_loss=0.362]  
Epoch 0:   1%|          | 29/4192 [00:22<54:06,  1.28it/s, v_num=0, train_loss=0.526]
Epoch 0:   1%|          | 32/4192 [00:22<49:08,  1.41it/s, v_num=0, train_loss=0.284]
Epoch 0:   1%|          | 33/4192 [00:25<52:32,  1.32it/s, v_num=0, train_loss=0.310]
Epoch 0:   1%|          | 37/4192 [00:25<47:12,  1.47it/s, v_num=0, train_lo

[36m(RayTrainWorker pid=1150881)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_10-36-00/TorchTrainer_43548_00000_0_2025-09-24_10-36-00/checkpoint_000000)
[36m(RayTrainWorker pid=1150881)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4192/4192 [09:35<00:00,  7.29it/s, v_num=0, train_loss=0.108]


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `LinearClassifier_trainer.train()` completed in approximately **11 minutes**.

In [4]:
%%time
LinearClassifier_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
    thread_per_worker=thread_per_worker,  # 2
)
ray.shutdown()

Using 1 workers with {'CPU': 2} each



Data splitting time: 19.35 seconds
Spawning Ray worker and initiating distributed training
2025-09-24 10:46:14,786	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-24_10-36-00' in 0.0064s.
2025-09-24 10:46:14,790	INFO tune.py:1041 -- Total run time: 614.24 seconds (614.21 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 10:46:14 (running for 00:10:14.22)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_10-35-37_276608_1141478/artifacts/2025-09-24_10-36-00/TorchTrainer_2025-09-24_10-36-00/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 25.4 s, sys: 3.98 s, total: 29.3 s
Wall time: 10min 35s


## 3. SIMS: Scalable, Interpretable Models for Cell Annotation of large-scale single-cell RNA-seq data
**SIMS** is a pipeline for building accurate and interpretable classifiers that predict any target label in single-cell RNA sequencing (scRNA-seq) data.  
- SIMS provides a framework for **cell type annotation**: it trains on labeled single-cell data and predicts cell type labels for new, unlabeled cells.
- The core SIMS model is **TabNet**, a deep learning architecture for large-scale tabular data. At each decision step, TabNet uses sequential attention to select the most informative features (genes), so its predictions are interpretable as well as accurate.  
For implementation details and source code, see the [SIMS GitHub repository](https://github.com/braingeneers/SIMS/tree/main).
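
TabNet's gene selection is driven by a learned feature mask. In the TabNet paper that mask is normalized with **sparsemax** rather than softmax: sparsemax projects a score vector onto the probability simplex and can assign exactly zero weight to uninformative features. The pure-Python sketch below applies sparsemax to a single score vector for illustration only; it is not code from SIMS or TabNet, which operate on batched tensors.

```python
def sparsemax(scores: list[float]) -> list[float]:
    """Sparsemax: Euclidean projection of `scores` onto the probability simplex.

    Returns non-negative weights that sum to 1; low scores get exactly 0.
    """
    if not scores:
        raise ValueError("scores must be non-empty")
    ranked = sorted(scores, reverse=True)
    running_sum = 0.0
    support_size = 0
    support_sum = 0.0
    for k, value in enumerate(ranked, start=1):
        running_sum += value
        # The k-th largest score stays in the support while 1 + k * z_(k) > sum of top-k scores
        if 1.0 + k * value > running_sum:
            support_size = k
            support_sum = running_sum
    # Threshold tau shifts the supported scores so that the output sums to 1
    tau = (support_sum - 1.0) / support_size
    return [max(value - tau, 0.0) for value in scores]
```

For example, `sparsemax([3.0, 1.0, 0.2, 2.5])` returns `[0.75, 0.0, 0.0, 0.25]`: two of the four "genes" receive exactly zero weight, whereas softmax would give every entry a positive weight.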

### SIMS Metadata Callback
This callback (`sims_metadata_cb`) extracts key information from the AnnData object to configure the SIMS model:
- `input_dim`: the number of genes (features) in the dataset, also stored as `num_genes`.
- `cell_lines`: list of unique cell line categories.
- `output_dim`: the number of distinct classes (cell lines) to predict, also stored as `num_classes`.

In [5]:
%%time


def sims_metadata_cb(ad: anndata.AnnData, metadata: dict):
    metadata["num_genes"] = ad.var.shape[0]
    metadata["input_dim"] = metadata["num_genes"]
    metadata["cell_lines"] = ad.obs["cell_line"].cat.categories.to_list()
    metadata["num_classes"] = len(metadata["cell_lines"])
    metadata["output_dim"] = metadata["num_classes"]

CPU times: user 32 μs, sys: 0 ns, total: 32 μs
Wall time: 42.7 μs


### Training the SIMS Classifier

- `RayTrainRunner` receives the **SIMSClassifier** model class and the dataset class (`Dcl`); the model's constructor arguments (`input_dim`, `output_dim`) are read from the metadata populated by the `sims_metadata_cb` callback
- Training is distributed using **RayDDPStrategy** with `find_unused_parameters=True`, which lets DDP handle parameters that do not receive gradients in every forward pass (without it, DDP raises an error when some parameters go unused)


In [6]:
%%time
sims_trainer = RayTrainRunner(
    SIMSClassifier,
    Dcl,
    ["input_dim", "output_dim"],  # maps to SIMSClassifier(input_dim, output_dim)
    sims_metadata_cb,
    ray_trainer_strategy=RayDDPStrategy(find_unused_parameters=True),
)

2025-09-24 10:50:12,599	INFO worker.py:1951 -- Started a local Ray instance.


CPU times: user 93.6 ms, sys: 213 ms, total: 307 ms
Wall time: 3.58 s
[36m(TrainTrainable pid=1164991)[0m ✓ Applied AnnDataFileManager patch
[36m(TrainTrainable pid=1164991)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1165153)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=1164991)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=1164991)[0m - (node_id=e9cfae65cac0d14535b39069127ce6a70b7f23ce583d528085585376, ip=192.168.1.226, pid=1165153) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=1165153)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=1165153)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1165153)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1165153)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1165153)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1165153)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pytho ...
[36m(RayTrainWorker pid=1165153)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1165153)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1165153)[0m   return torch.sparse_compressed_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1165153)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:01<00:01,  0.89it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.67it/s]


[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/macro_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the 

Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]                          
Epoch 0:   0%|          | 1/4192 [00:24<28:07:32,  0.04it/s, v_num=0, train/loss_step=4.640]
Epoch 0:   0%|          | 2/4192 [00:24<14:08:17,  0.08it/s, v_num=0, train/loss_step=4.180]
Epoch 0:   0%|          | 3/4192 [00:24<9:28:04,  0.12it/s, v_num=0, train/loss_step=3.940] 
Epoch 0:   0%|          | 4/4192 [00:24<7:08:00,  0.16it/s, v_num=0, train/loss_step=3.730]
Epoch 0:   0%|          | 5/4192 [00:24<5:43:56,  0.20it/s, v_num=0, train/loss_step=3.570]
Epoch 0:   0%|          | 6/4192 [00:24<4:47:55,  0.24it/s, v_num=0, train/loss_step=3.410]
Epoch 0:   0%|          | 7/4192 [00:24<4:07:51,  0.28it/s, v_num=0, train/loss_step=3.290]
Epoch 0:   0%|          | 8/4192 [00:24<3:37:51,  0.32it/s, v_num=0, train/loss_step=3.230]
Epoch 0:   0%|          | 9/4192 [00:25<3:14:30,  0.36it/s, v_num=0, train/loss_step=3.140]
Epoch 0:   0%|          | 10/4192 [00:25<2:55:49,  0.40it/s, v_num=0, train/loss_step=3.080]


[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1165153)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_10-50-35/TorchTrainer_4d208_00000_0_2025-09-24_10-50-35/checkpoint_000000)
[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train/f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1165153)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv

Epoch 0: 100%|██████████| 4192/4192 [14:56<00:00,  4.67it/s, v_num=0, train/loss_step=0.495, val/loss=0.583, val/f1=0.830, val/macro_acc=0.827, val/micro_acc=0.929, val/precision=0.837, val/recall=0.827, val/specificity=0.999, val/weighted_acc=0.929, train/loss_epoch=0.489]


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `sims_trainer.train()` completed in about **17 minutes**.

In [7]:
%%time
sims_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
    thread_per_worker=thread_per_worker,
)
ray.shutdown()

Using 1 workers with {'CPU': 2} each



Data splitting time: 19.49 seconds
Spawning Ray worker and initiating distributed training
2025-09-24 11:07:00,414	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-24_10-50-35' in 0.0150s.
2025-09-24 11:07:00,419	INFO tune.py:1041 -- Total run time: 984.44 seconds (984.41 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 11:07:00 (running for 00:16:24.43)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_10-50-10_126896_1141478/artifacts/2025-09-24_10-50-35/TorchTrainer_2025-09-24_10-50-35/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 32.9 s, sys: 5.4 s, total: 38.3 s
Wall time: 16min 45s


## 4. Autoencoder
- An **autoencoder** is an unsupervised neural network consisting of three main components:  
  - **Encoder**: compresses the input into a lower-dimensional representation.  
  - **Bottleneck**: stores the compressed features.  
  - **Decoder**: reconstructs the input from the bottleneck representation.  
- In this setup, separate encoders process **gene** and **protein** data. Their outputs are concatenated, passed through an additional encoder to form the bottleneck, and then decoded back to the original input.  
- Since **Tahoe-100M** does not include protein data, the number of protein features (`nfeatures_pro`) is set to `0`, and the source code was adapted so the model works on datasets without protein features.
- For testing purposes, the intermediate layer size is temporarily set to `mid = 128`, which keeps the hidden layers small and the model simple. For implementation details, see the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [8]:
%%time
# LinBnDrop groups Linear, BatchNorm1d, and Dropout layers; adapted from the citeseq_autoencoder notebook
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from torch import nn, optim


class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers, adapted from fastai."""

    def __init__(self, n_in, n_out, bn=True, p=0.0, act=None, lin_first=True):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0:
            layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            lin.append(act)
        layers = lin + layers if lin_first else layers + lin
        super().__init__(*layers)

CPU times: user 433 μs, sys: 0 ns, total: 433 μs
Wall time: 446 μs


We implement an encoder that passes the RNA features through a two-layer MLP (`nfeatures_rna` → `mid` → `hidden_rna`, where `mid = 128` is a small value chosen for testing) and then maps the result to the `latent_dim`-dimensional bottleneck. The source code is from [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [9]:
%%time


class Encoder(nn.Module):
    """Encoder for CITE-seq data"""

    def __init__(
        self, nfeatures_rna: int, nfeatures_pro: int, hidden_rna: int, hidden_pro: int, latent_dim: int, p: float = 0
    ):
        super().__init__()
        self.nfeatures_rna = nfeatures_rna
        self.nfeatures_pro = nfeatures_pro

        if nfeatures_rna > 0:
            mid = 128  # 128 is for testing the code
            self.encoder_rna = nn.Sequential(
                LinBnDrop(nfeatures_rna, mid, p=p, act=nn.LeakyReLU()),
                LinBnDrop(mid, hidden_rna, act=nn.LeakyReLU()),
            )

        if nfeatures_pro > 0:
            self.encoder_protein = LinBnDrop(nfeatures_pro, hidden_pro, p=p, act=nn.LeakyReLU())

        # make sure hidden_rna and hidden_pro are set correctly
        hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
        hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro

        hidden_dim = hidden_rna + hidden_pro

        self.encoder = LinBnDrop(hidden_dim, latent_dim, act=nn.LeakyReLU())

    def forward(self, x):
        if self.nfeatures_rna > 0 and self.nfeatures_pro > 0:
            x_rna = self.encoder_rna(x[:, : self.nfeatures_rna])
            x_pro = self.encoder_protein(x[:, self.nfeatures_rna :])
            x = torch.cat([x_rna, x_pro], 1)
        elif self.nfeatures_rna > 0 and self.nfeatures_pro == 0:
            x = self.encoder_rna(x)
        elif self.nfeatures_rna == 0 and self.nfeatures_pro > 0:
            x = self.encoder_protein(x)
        return self.encoder(x)

CPU times: user 116 μs, sys: 0 ns, total: 116 μs
Wall time: 128 μs
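
To see how the branches fit together, the sketch below traces the `(batch, features)` shape after each stage of `Encoder.forward`, using plain tuples instead of tensors. `encoder_shapes` is a hypothetical helper, and the sizes in the example are illustrative (1,000 genes is not the Tahoe-100M gene count). It makes explicit that the final `LinBnDrop` receives `hidden_rna + hidden_pro` features, or just `hidden_rna` when `nfeatures_pro = 0` as in this tutorial.

```python
def encoder_shapes(
    batch: int,
    nfeatures_rna: int,
    nfeatures_pro: int,
    hidden_rna: int,
    hidden_pro: int,
    latent_dim: int,
    mid: int = 128,
) -> dict[str, tuple[int, int]]:
    """Shapes produced by each stage of Encoder.forward (hypothetical trace)."""
    if nfeatures_rna <= 0 and nfeatures_pro <= 0:
        raise ValueError("at least one modality must have features")
    shapes = {"input": (batch, nfeatures_rna + nfeatures_pro)}
    branch_widths: list[int] = []
    if nfeatures_rna > 0:
        # encoder_rna: nfeatures_rna -> mid -> hidden_rna
        shapes["rna_mid"] = (batch, mid)
        shapes["rna_hidden"] = (batch, hidden_rna)
        branch_widths.append(hidden_rna)
    if nfeatures_pro > 0:
        # encoder_protein: nfeatures_pro -> hidden_pro
        shapes["protein_hidden"] = (batch, hidden_pro)
        branch_widths.append(hidden_pro)
    # torch.cat along dim 1 when both branches exist; otherwise the single branch output
    shapes["encoder_in"] = (batch, sum(branch_widths))
    # Final LinBnDrop: hidden_dim -> latent_dim (the bottleneck)
    shapes["latent"] = (batch, latent_dim)
    return shapes
```

For a batch of 2000 cells with the tutorial's sizes, `encoder_shapes(2000, 1000, 0, 128, 0, 16)["latent"]` is `(2000, 16)`.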


We implement a decoder that maps the latent vector back to the RNA feature space: it first expands the latent vector to `hidden_rna` dimensions, passes it through a small intermediate layer (`mid_out = 128`, used for testing), and finally projects it to the RNA output dimension with a plain linear layer (no batch normalization or activation). The source code is from [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [10]:
%%time


class Decoder(nn.Module):
    """Decoder for CITE-seq data"""

    def __init__(self, nfeatures_rna: int, nfeatures_pro: int, hidden_rna: int, hidden_pro: int, latent_dim: int):
        super().__init__()
        # make sure hidden_rna and hidden_pro are set correctly
        hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
        hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro

        hidden_dim = hidden_rna + hidden_pro
        out_dim = nfeatures_rna + nfeatures_pro
        mid_out = 128  # 128 is for testing the code

        self.decoder = nn.Sequential(
            LinBnDrop(latent_dim, hidden_dim, act=nn.LeakyReLU()),
            LinBnDrop(hidden_dim, mid_out, act=nn.LeakyReLU()),
            LinBnDrop(mid_out, out_dim, bn=False),
        )

    def forward(self, x):
        return self.decoder(x)

CPU times: user 69 μs, sys: 0 ns, total: 69 μs
Wall time: 80.3 μs
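
Because the first encoder layer and the last decoder layer both connect to every gene, they dominate the model size. The sketch below counts trainable parameters for the RNA-only case (`nfeatures_pro = 0`) with the layer sizes from the encoder and decoder cells above. It is an illustrative helper, not part of the notebook, and assumes `BatchNorm1d`'s default `affine=True` (a weight and a bias per feature; running statistics are buffers, not parameters).

```python
def lin_bn_drop_params(n_in: int, n_out: int, bn: bool = True) -> int:
    """Trainable parameters in one LinBnDrop block with lin_first=True."""
    linear = n_in * n_out + (0 if bn else n_out)  # Linear has a bias only when bn=False
    batchnorm = 2 * n_out if bn else 0  # BatchNorm1d affine weight + bias
    return linear + batchnorm  # LeakyReLU and Dropout add no parameters


def rna_autoencoder_params(
    nfeatures_rna: int, hidden_rna: int, latent_dim: int, mid: int = 128, mid_out: int = 128
) -> tuple[int, int]:
    """(encoder, decoder) parameter counts for the RNA-only autoencoder."""
    encoder = (
        lin_bn_drop_params(nfeatures_rna, mid)
        + lin_bn_drop_params(mid, hidden_rna)
        + lin_bn_drop_params(hidden_rna, latent_dim)
    )
    decoder = (
        lin_bn_drop_params(latent_dim, hidden_rna)
        + lin_bn_drop_params(hidden_rna, mid_out)
        + lin_bn_drop_params(mid_out, nfeatures_rna, bn=False)
    )
    return encoder, decoder
```

For a hypothetical 1,000-gene input with `hidden_rna=128` and `latent_dim=16`, this gives 146,976 encoder and 147,944 decoder parameters; the gene-facing layers alone account for 128,256 and 129,000 of them, and both grow linearly with the number of genes. The totals should agree with `sum(p.numel() for p in model.parameters())` on the real modules.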


The encoder and decoder are assembled into an autoencoder, defined as a PyTorch Lightning module to simplify training. The source code is from [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [11]:
%%time


class CiteAutoencoder(pl.LightningModule):
    def __init__(
        self,
        nfeatures_rna: int,
        nfeatures_pro: int,
        hidden_rna: int,
        hidden_pro: int,
        latent_dim: int,
        p: float = 0,
        lr: float = 0.1,
    ):
        """Autoencoder for citeseq data"""
        super().__init__()

        # save hyperparameters
        self.save_hyperparameters()

        self.encoder = Encoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim, p)
        self.decoder = Decoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim)

        # example input array for visualizing network graph
        self.example_input_array = torch.zeros(256, nfeatures_rna + nfeatures_pro)

    def forward(self, x):
        # extract latent embeddings
        z = self.encoder(x)
        return z

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def _get_reconstruction_loss(self, batch):
        """Calculate MSE loss for a given batch."""
        x, _ = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # MSE loss
        loss = F.mse_loss(x_hat, x)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

CPU times: user 288 μs, sys: 0 ns, total: 288 μs
Wall time: 300 μs
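
`configure_optimizers` pairs Adam with `ReduceLROnPlateau` monitoring `val_loss`. PyTorch's defaults for that scheduler are `mode="min"`, `factor=0.1`, `patience=10`, and a relative `threshold=1e-4`, and Lightning steps it once per epoch by default. In the single-epoch runs of this tutorial (`max_epochs=1`), the scheduler therefore never lowers the learning rate. The sketch below is a simplified pure-Python version of the plateau rule for illustration; `PlateauLRSketch` is hypothetical and omits the cooldown and eps handling of the real PyTorch class.

```python
import math


class PlateauLRSketch:
    """Simplified ReduceLROnPlateau rule (mode="min", relative threshold)."""

    def __init__(
        self,
        lr: float,
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        min_lr: float = 0.0,
    ):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.min_lr = min_lr
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        """Feed one epoch's monitored value (e.g. val_loss); return the lr for the next epoch."""
        if metric < self.best * (1.0 - self.threshold):
            # Relative improvement over the best value so far: reset the counter
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # The lr is cut only after more than `patience` epochs without improvement
        if self.num_bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr
```

Starting from `lr=1e-3`, the sketch keeps the learning rate through ten stagnant epochs and drops it to `1e-4` on the eleventh, so multi-epoch training is needed before the scheduler has any effect.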


### Autoencoder Metadata Callback
- The `ae_metadata_cb` function extends `cell_line_metadata_cb` and configures the metadata required for training the autoencoder: it sets up cell line information, defines the RNA and protein feature counts, and specifies key model hyperparameters such as the hidden dimensions, latent space size, dropout rate, and learning rate.

**Note (for testing):**  
In `ae_metadata_cb`, both the hidden RNA dimension (`hidden_rna=128`) and the latent dimension (`latent_dim=16`) are intentionally set to very small values. This configuration is used for quick testing and validation, not for full-scale training.

In [13]:
%%time


def ae_metadata_cb(ad, metadata):
    cell_line_metadata_cb(ad, metadata)
    metadata["cell_lines"] = np.unique(ad.obs["cell_line"].to_numpy())  # np.unique already returns sorted values
    metadata["nfeatures_rna"] = metadata["num_genes"]
    metadata["nfeatures_pro"] = 0
    metadata["hidden_rna"] = 128
    metadata["hidden_pro"] = 0
    metadata["latent_dim"] = 16
    metadata["p"] = 0.1
    metadata["lr"] = 1e-3

CPU times: user 16 μs, sys: 0 ns, total: 16 μs
Wall time: 25.7 μs


### Training the CiteAutoencoder model
- `RayTrainRunner` receives the `CiteAutoencoder` model class and the dataset class (`Dcl`); the model's constructor arguments (RNA/protein feature counts, hidden layer sizes, latent dimension, dropout `p`, and learning rate `lr`) are read from the metadata set by the `ae_metadata_cb` callback.

In [14]:
%%time
autoencoder_trainer = RayTrainRunner(
    CiteAutoencoder,
    Dcl,
    ["nfeatures_rna", "nfeatures_pro", "hidden_rna", "hidden_pro", "latent_dim", "p", "lr"],
    metadata_cb=ae_metadata_cb,
)

2025-09-24 11:10:06,808	INFO worker.py:1951 -- Started a local Ray instance.


CPU times: user 98.4 ms, sys: 249 ms, total: 347 ms
Wall time: 3.66 s
[36m(TrainTrainable pid=1179158)[0m ✓ Applied AnnDataFileManager patch
[36m(TrainTrainable pid=1179158)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1179295)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=1179158)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=1179158)[0m - (node_id=f80fa4e223aee996e709cca29ff756652a6e60f007fef994b41ebd7e, ip=192.168.1.226, pid=1179295) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=1179295)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=1179295)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1179295)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1179295)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1179295)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1179295)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1179295)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pytho ...
[36m(RayTrainWorker pid=1179295)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1179295)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1179295)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=1179295)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
                                                                           


[36m(RayTrainWorker pid=1179295)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.



[36m(RayTrainWorker pid=1179295)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-10-35/TorchTrainer_17e07_00000_0_2025-09-24_11-10-35/checkpoint_000000)


Epoch 0: 100%|██████████| 4192/4192 [09:41<00:00,  7.21it/s, v_num=0]


[36m(RayTrainWorker pid=1179295)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4192/4192 [09:41<00:00,  7.20it/s, v_num=0]


On a machine with **1 GPU (NVIDIA GeForce RTX 3080, 12 GiB) + 96 CPUs + 125 GiB RAM**, `autoencoder_trainer.train()` finished in about **11 minutes**.

In [15]:
%%time
autoencoder_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
    thread_per_worker=thread_per_worker,  # 2
)
ray.shutdown()

Using 1 workers with {'CPU': 2} each


2025-09-24 11:10:35,128	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 27.18 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-24 11:10:35 (running for 00:00:00.13)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-10-04_295190_1141478/artifacts/2025-09-24_11-10-35/TorchTrainer_2025-09-24_11-10-35/driver_artifacts
Number of trials: 1/1 (1 PENDING)



2025-09-24 11:20:56,141	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-10-35' in 0.0226s.
2025-09-24 11:20:56,145	INFO tune.py:1041 -- Total run time: 621.02 seconds (620.98 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 11:20:56 (running for 00:10:21.00)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-10-04_295190_1141478/artifacts/2025-09-24_11-10-35/TorchTrainer_2025-09-24_11-10-35/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 33.6 s, sys: 4.37 s, total: 38 s
Wall time: 10min 50s
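The wall time can be roughly reconciled with the logs above. The short sketch below uses only numbers printed in this run (the data splitting time, Ray Tune's total run time, and the final progress bar of epoch 0); attributing the remainder of the Ray run to worker startup, the sanity check, and checkpointing is our interpretation, not something PROTOplast reports.

```python
# Numbers copied from the autoencoder run logs above.
DATA_SPLIT_S = 27.18          # "Data splitting time: 27.18 seconds"
RAY_RUN_S = 621.02            # "Total run time: 621.02 seconds"
STEPS, IT_PER_S = 4192, 7.21  # final progress bar: 4192/4192 at 7.21 it/s


def mmss(seconds: float) -> str:
    """Format seconds as m:ss, rounding to the nearest second."""
    minutes, secs = divmod(round(seconds), 60)
    return f"{minutes}:{secs:02d}"


epoch_s = STEPS / IT_PER_S          # time covered by the epoch's progress bar
overhead_s = RAY_RUN_S - epoch_s    # rest of the Ray run (startup, sanity check, checkpoint, ...)
total_s = DATA_SPLIT_S + RAY_RUN_S  # splitting is printed before the Ray worker is spawned

print(f"epoch ≈ {mmss(epoch_s)}, other Ray time ≈ {overhead_s:.0f} s, "
      f"split + Ray ≈ {mmss(total_s)}")
```

The epoch itself accounts for 9:41 (matching the progress bar), and splitting plus the Ray run comes to about 10:48, within a couple of seconds of the measured wall time of 10 min 50 s.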


## 5. DistributedClassifierTrainingPlan
- **ClassifierTrainingPlan** (from `scvi-tools`) is not a model itself but a training plan: a PyTorch Lightning `LightningModule` that wraps a classifier module (here scvi-tools' `Classifier`) and defines the loss function, the optimizer configuration, and the training and validation steps.
- For details, see the [source code](https://github.com/scverse/scvi-tools/blob/main/src/scvi/train/_trainingplans.py#L1479).
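To make the model/training-plan split concrete, here is a framework-free toy in plain Python. `ToyClassifier` only maps features to logits, while `ToyTrainingPlan` owns the loss (softmax cross-entropy) and the update rule (plain gradient descent). Both classes are hypothetical illustrations of the pattern, not scvi-tools code; the real `ClassifierTrainingPlan` hands optimization to PyTorch Lightning and a torch optimizer instead of updating weights by hand.

```python
import math


class ToyClassifier:
    """Hypothetical stand-in for a classifier module: features -> one logit per class."""

    def __init__(self, weights):
        self.weights = weights  # weights[c][j]: weight of feature j for class c

    def __call__(self, x):
        return [sum(w * xj for w, xj in zip(row, x)) for row in self.weights]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ToyTrainingPlan:
    """Hypothetical training plan: owns the loss and the update rule, not the model."""

    def __init__(self, module, lr=1e-3):
        self.module = module
        self.lr = lr

    def training_step(self, x, y):
        probs = softmax(self.module(x))
        loss = -math.log(probs[y])  # cross-entropy for the true class y
        # d(loss)/d(logit_c) = probs[c] - 1[c == y]; the chain rule multiplies by x_j
        for c, row in enumerate(self.module.weights):
            grad_logit = probs[c] - (1.0 if c == y else 0.0)
            for j, xj in enumerate(x):
                row[j] -= self.lr * grad_logit * xj
        return loss


plan = ToyTrainingPlan(ToyClassifier([[0.0, 0.0], [0.0, 0.0]]), lr=0.5)
losses = [plan.training_step([1.0, 0.0], 0) for _ in range(5)]
print([round(v, 3) for v in losses])
```

With all-zero weights the two classes start out equally likely, so the first loss is ln 2, and it decreases on every subsequent step.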

In [16]:
# Install scvi-tools if it is not already available (run in a terminal):
# uv add scvi-tools

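### Classifier training metadata callback
`clf_metadata_cb` first calls `cell_line_metadata_cb` to populate `num_genes` and `num_classes` from the input AnnData object. It then uses these values to build the scvi-tools `Classifier` and stores it in `metadata` together with the optimizer settings (`lr`, `weight_decay`, `eps`, `optimizer`); these are exactly the keys later passed to `RayTrainRunner` as `model_keys`.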

In [17]:
%%time


def clf_metadata_cb(ad, metadata):
    # Populate num_genes / num_classes from the AnnData file
    cell_line_metadata_cb(ad, metadata)

    # Create the classifier instance and attach it to metadata
    metadata["classifier"] = Classifier(
        n_input=metadata["num_genes"],
        n_labels=metadata["num_classes"],
        logits=True,  # CrossEntropyLoss in ClassifierTrainingPlan expects raw logits, not probabilities
    )
    metadata["lr"] = 1e-3
    metadata["weight_decay"] = 1e-6
    metadata["eps"] = 0.01
    metadata["optimizer"] = "Adam"

CPU times: user 22 μs, sys: 0 ns, total: 22 μs
Wall time: 33.6 μs
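The ordering inside `clf_metadata_cb` matters: the `Classifier` can only be sized once `num_genes` and `num_classes` are known. The plain-Python sketch below mimics that chaining with a hypothetical `FakeAnnData` and `toy_cell_line_metadata_cb`. How the real `cell_line_metadata_cb` derives the class count is an assumption here (we count distinct cell-line labels), and a tuple stands in for the scvi-tools `Classifier`.

```python
from dataclasses import dataclass, field


@dataclass
class FakeAnnData:
    """Hypothetical stand-in for an AnnData object (n_vars genes, one label per cell)."""
    n_vars: int
    cell_lines: list = field(default_factory=list)


def toy_cell_line_metadata_cb(ad, metadata):
    # Assumption: mirrors the documented outputs of cell_line_metadata_cb.
    metadata["num_genes"] = ad.n_vars
    metadata["num_classes"] = len(set(ad.cell_lines))


def toy_clf_metadata_cb(ad, metadata):
    toy_cell_line_metadata_cb(ad, metadata)  # 1) shape information first
    # 2) the model can now be sized; a tuple stands in for scvi's Classifier
    metadata["classifier"] = ("Classifier", metadata["num_genes"], metadata["num_classes"])
    # 3) optimizer settings, same values as in clf_metadata_cb
    metadata.update(lr=1e-3, weight_decay=1e-6, eps=0.01, optimizer="Adam")


metadata = {}
toy_clf_metadata_cb(FakeAnnData(n_vars=5, cell_lines=["A", "B", "A"]), metadata)
print(metadata["classifier"])  # ('Classifier', 5, 2)
```

Swapping steps 1 and 2 would raise a `KeyError`, since `num_genes` would not exist yet when the classifier is built.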


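The `DistributedClassifierTrainingPlan` subclass extends `ClassifierTrainingPlan` by overriding `training_step` and `validation_step`. The PROTOplast dataloader yields each batch as an `(x, y)` tuple, whereas the parent class reads features and labels from a dictionary batch via its `data_key` and `labels_key` attributes. The overrides therefore unpack the tuple and otherwise keep the same logic: a forward pass, cross-entropy against the flattened integer labels, and logging of `train_loss` / `validation_loss`.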

In [18]:
%%time


class DistributedClassifierTrainingPlan(ClassifierTrainingPlan):
    def training_step(self, batch, batch_idx):
        """Training step for classifier training."""
        x, y = batch
        soft_prediction = self.forward(x)
        loss = self.loss_fn(soft_prediction, y.view(-1).long())
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step for classifier training."""
        x, y = batch
        soft_prediction = self.forward(x)
        loss = self.loss_fn(soft_prediction, y.view(-1).long())
        self.log("validation_loss", loss)

CPU times: user 292 μs, sys: 0 ns, total: 292 μs
Wall time: 303 μs
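The only change the subclass makes is how a batch is unpacked. The hypothetical toy classes below use an identity forward pass and a mean-absolute-error placeholder loss instead of a real classifier and cross-entropy. `DictBatchPlan` indexes a dictionary batch by key, in the spirit of scvi-tools' `data_key`/`labels_key` lookup, and `TupleBatchPlan` unpacks an `(x, y)` tuple the way `DistributedClassifierTrainingPlan` does.

```python
class DictBatchPlan:
    """Hypothetical parent: expects batches like {"X": ..., "labels": ...}."""

    data_key = "X"
    labels_key = "labels"

    def forward(self, x):
        return x  # placeholder model: identity

    def loss_fn(self, pred, target):
        # placeholder loss: mean absolute error
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)

    def training_step(self, batch, batch_idx):
        pred = self.forward(batch[self.data_key])
        return self.loss_fn(pred, batch[self.labels_key])


class TupleBatchPlan(DictBatchPlan):
    """Hypothetical subclass: same logic, but batches arrive as (x, y) tuples."""

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self.forward(x), y)


x, y = [1.0, 2.0], [1.0, 0.0]
print(DictBatchPlan().training_step({"X": x, "labels": y}, 0))  # 1.0
print(TupleBatchPlan().training_step((x, y), 0))                # 1.0
```

Passing the tuple to `DictBatchPlan` instead fails with a `TypeError`, because a tuple cannot be indexed by the string `"X"`; presumably this mismatch is what the overrides are there to avoid.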


### Executing ClassifierTrainingPlan

In [19]:
%%time
from protoplast.scrna.anndata.torch_dataloader import DistributedCellLineAnnDataset as Dcl

ClassifierTrainingPlan_trainer = RayTrainRunner(
    Model=DistributedClassifierTrainingPlan,
    Ds=Dcl,
    model_keys=["classifier", "lr", "weight_decay", "eps", "optimizer"],
    metadata_cb=clf_metadata_cb,
)

2025-09-24 11:21:28,271	INFO worker.py:1951 -- Started a local Ray instance.


CPU times: user 110 ms, sys: 253 ms, total: 362 ms
Wall time: 3.72 s
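`model_keys` appears to name which `metadata` entries become constructor arguments of `Model`: the five keys match the parameters that `ClassifierTrainingPlan` takes (`classifier`, `lr`, `weight_decay`, `eps`, `optimizer`), while `num_genes` and `num_classes` are only used inside the callback. The sketch below shows that selection with a hypothetical `build_model` helper and `ToyPlan` class; it is our reading of the interface, not `RayTrainRunner`'s actual code.

```python
class ToyPlan:
    """Hypothetical model class with the same keyword arguments as the training plan."""

    def __init__(self, classifier, lr, weight_decay, eps, optimizer):
        self.classifier = classifier
        self.hparams = {"lr": lr, "weight_decay": weight_decay, "eps": eps, "optimizer": optimizer}


def build_model(Model, metadata, model_keys):
    # Keep only the listed keys; anything else in metadata is not passed on.
    kwargs = {key: metadata[key] for key in model_keys}
    return Model(**kwargs)


metadata = {
    "num_genes": 5, "num_classes": 2,  # used by the callback, not by the model
    "classifier": "clf", "lr": 1e-3, "weight_decay": 1e-6, "eps": 0.01, "optimizer": "Adam",
}
model = build_model(ToyPlan, metadata, ["classifier", "lr", "weight_decay", "eps", "optimizer"])
print(model.hparams["optimizer"])  # Adam
```
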
[36m(TrainTrainable pid=1192577)[0m ✓ Applied AnnDataFileManager patch
[36m(TrainTrainable pid=1192577)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1192897)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=1192577)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=1192577)[0m - (node_id=f36cc440d80dfc099158a8e73d3e08d928895b6225c2d92911afd93b, ip=192.168.1.226, pid=1192897) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=1192897)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=1192897)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1192897)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1192897)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1192897)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1192897)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1192897)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pytho ...
[36m(RayTrainWorker pid=1192897)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly 



[36m(RayTrainWorker pid=1192897)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=1192897)[0m 
[36m(RayTrainWorker pid=1192897)[0m   | Name    | Type             | Params | Mode 
[36m(RayTrainWorker pid=1192897)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1192897)[0m 0 | module  | Classifier       | 8.0 M  | train
[36m(RayTrainWorker pid=1192897)[0m 1 | loss_fn | CrossEntropyLoss | 0      | train
[36m(RayTrainWorker pid=1192897)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1192897)[0m 8.0 M     Trainable params
[36m(RayTrainWorker pid=1192897)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=1192897)[0m 8.0 M     Total params
[36m(RayTrainWorker pid=1192897)[0m 32.135    Total estimated model params size (MB)
[36m(RayTrainWorker pid=1192897)[0m 11        Modules in train mode
[36m(RayTrainWorker pid=1192897)[0m 0         Modules in eval mode
[36m(RayTrainWork



[36m(RayTrainWorker pid=1192897)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1192897)[0m   return torch.sparse_compressed_tensor(




[36m(RayTrainWorker pid=1192897)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1192897)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('validation_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 1/4192 [00:21<25:27:26,  0.05it/s, v_num=0, train_loss_step=3.980]

[36m(RayTrainWorker pid=1192897)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1192897)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-21-49/TorchTrainer_a9e9d_00000_0_2025-09-24_11-21-49/checkpoint_000000)


Epoch 0: 100%|██████████| 4192/4192 [08:46<00:00,  7.97it/s, v_num=0, train_loss_step=0.104, train_loss_epoch=0.181]


[36m(RayTrainWorker pid=1192897)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4192/4192 [08:46<00:00,  7.96it/s, v_num=0, train_loss_step=0.104, train_loss_epoch=0.181]
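A few numbers in the classifier log can be cross-checked with simple arithmetic. Lightning's size estimate appears to assume 4 bytes per float32 parameter, so 32.135 MB corresponds to roughly 8.03 million parameters, consistent with the 8.0 M in the summary table. For cross-entropy over C classes with near-uniform initial predictions, the first loss should be close to ln C: the first logged `train_loss_step` of 3.980 corresponds to about 53.5 equally likely classes, close to ln 50 ≈ 3.91 if all 50 Tahoe-100M cell lines are present (an assumption about this subset). Finally, exp(-0.181) ≈ 0.83 is the geometric-mean probability given to the true cell line, averaged over epoch 0 while the model was still improving.

```python
import math

# Values copied from the classifier log above.
SIZE_MB = 32.135          # "Total estimated model params size (MB)"
FIRST_STEP_LOSS = 3.980   # train_loss_step at step 1/4192
EPOCH_LOSS = 0.181        # train_loss_epoch after epoch 0

# 1) Parameter count implied by the size estimate (float32 = 4 bytes, 1 MB = 1e6 bytes).
approx_params = SIZE_MB * 1e6 / 4

# 2) Cross-entropy of a uniform guess over C classes is ln(C); invert it.
implied_classes = math.exp(FIRST_STEP_LOSS)
uniform_loss_50 = math.log(50)  # assumption: 50 cell-line classes

# 3) A mean cross-entropy L corresponds to a geometric-mean true-class probability exp(-L).
geo_mean_true_prob = math.exp(-EPOCH_LOSS)

print(f"{approx_params / 1e6:.2f} M params, ~{implied_classes:.1f} classes, "
      f"ln(50) = {uniform_loss_50:.2f}, p_true ≈ {geo_mean_true_prob:.2f}")
```
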




On a machine with **1 GPU (NVIDIA GeForce RTX 3080, 12 GiB) + 96 CPUs + 125 GiB RAM**, `ClassifierTrainingPlan_trainer.train()` finished in about **10 minutes**.

In [20]:
%%time
ClassifierTrainingPlan_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
    thread_per_worker=thread_per_worker,  # 2
)
ray.shutdown()

Using 1 workers with {'CPU': 2} each


2025-09-24 11:21:49,629	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 20.11 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-24 11:21:50 (running for 00:00:00.78)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-21-25_731212_1141478/artifacts/2025-09-24_11-21-49/TorchTrainer_2025-09-24_11-21-49/driver_artifacts
Number of trials: 1/1 (1 PENDING)



2025-09-24 11:31:18,075	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-21-49' in 2.7891s.
2025-09-24 11:31:18,084	INFO tune.py:1041 -- Total run time: 568.46 seconds (565.64 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 11:31:18 (running for 00:09:28.43)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-21-25_731212_1141478/artifacts/2025-09-24_11-21-49/TorchTrainer_2025-09-24_11-21-49/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 45.3 s, sys: 17.4 s, total: 1min 2s
Wall time: 9min 50s
