# Classification Models for Single-Cell Data with PROTOplast

This tutorial demonstrates how to use PROTOplast to train different PyTorch classification models on single-cell data stored in `h5ad` files.

**Download the Tahoe-100M `h5ad` files**
- The Tahoe-100M dataset can be downloaded in `h5ad` format from the **Arc Institute Google Cloud Storage**.
- For step-by-step instructions, see the [official tutorial](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md).

**Setup**  
- Configure the training environment for single-cell RNA sequencing (scRNA-seq) data using **PROTOplast** in combination with **PyTorch Lightning** and **Ray**.

In [1]:
import protoplast

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


root - INFO - Logging initialized. Current level is: INFO


In [2]:
from importlib.metadata import version

print(version("protoplast"))

0.1.2


In [3]:
%%time
import anndata
import numpy as np
import ray

# models
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.torch_dataloader import DistributedCellLineAnnDataset as Dcl
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner
from ray.train.lightning import RayDDPStrategy
from scsims.model import SIMSClassifier

# scvi training plan
## install scvi-tools if needed:
## uv add scvi-tools
from scvi.module import Classifier
from scvi.train import ClassifierTrainingPlan

CPU times: user 372 ms, sys: 96.8 ms, total: 469 ms
Wall time: 474 ms


## 1. Load the Tahoe-100M Dataset (`h5ad`)
- `file_paths`: list of `.h5ad` files to train on. Plate 12 of Tahoe-100M (the largest file, 35 GB) is used as a demo; to train on more plates, append their `.h5ad` paths to this list
- `batch_size`: number of cells per training batch
- `test_size`: fraction of the data reserved for testing (use `0.0` if no test set is needed)
- `val_size`: fraction of the data reserved for validation
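
How the fractional `test_size` and `val_size` values become concrete split sizes is handled inside PROTOplast and is not shown in this notebook. As a rough illustration only, the hypothetical helper `split_counts` below assumes that each fraction is applied to a number of data units and rounded down, with the remainder going to training. Under that assumption, the split lengths logged in the training cells below (65 validation and 262 training splits, 327 in total) are what `val_size=0.2` gives, since 0.2 × 327 = 65.4 rounds down to 65.

```python
import math
from typing import Tuple


def split_counts(n_units: int, test_size: float, val_size: float) -> Tuple[int, int, int]:
    """Hypothetical helper: split n_units into (train, val, test) counts by fraction.

    Each fraction is rounded down; whatever is left over goes to training.
    This is an illustration, not PROTOplast's splitting code.
    """
    if not (0.0 <= test_size < 1.0 and 0.0 <= val_size < 1.0 and test_size + val_size < 1.0):
        raise ValueError("test_size and val_size must be in [0, 1) and sum to less than 1")
    n_val = math.floor(n_units * val_size)
    n_test = math.floor(n_units * test_size)
    n_train = n_units - n_val - n_test
    return n_train, n_val, n_test


print(split_counts(327, test_size=0.0, val_size=0.2))  # (262, 65, 0)
```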


In [4]:
%%time
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
batch_size = 2000
test_size = 0.0
val_size = 0.2

CPU times: user 10 μs, sys: 1e+03 ns, total: 11 μs
Wall time: 20.7 μs


## 2. Simple Classifier

This example illustrates how to configure a training runner with **PROTOplast** and **Ray**.

- `LinearClassifier`: a simple baseline model that can be swapped with a custom implementation
- `Dcl`: the dataset object for training, imported from `protoplast.scrna.anndata.torch_dataloader`
  - Defined as a subclass of `DistributedAnnDataset`, customized for cell line classification tasks
- `["num_genes", "num_classes"]`: arguments that specify the model's input and output dimensions
- `cell_line_metadata_cb`: a callback function that attaches dataset-specific metadata, such as cell line labels and class counts
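
The internals of `RayTrainRunner` are not shown in this notebook. Judging from the comment `# maps to SIMSClassifier(input_dim, output_dim)` in the SIMS example below, the list of names appears to select values from the metadata dictionary filled by the callback and pass them to the model constructor in order. The sketch below illustrates that assumed flow with hypothetical stand-ins (`build_model`, `ToyClassifier`, `toy_metadata_cb`) and a plain dict in place of an AnnData object; it is not PROTOplast's implementation.

```python
from typing import Any, Callable, Dict, List


def build_model(model_cls: type, arg_keys: List[str],
                metadata_cb: Callable[[Any, Dict[str, Any]], None], adata: Any) -> Any:
    """Hypothetical sketch: let the callback fill a metadata dict, then pass the
    requested values to the model constructor positionally, in list order."""
    metadata: Dict[str, Any] = {}
    metadata_cb(adata, metadata)
    missing = [key for key in arg_keys if key not in metadata]
    if missing:
        raise KeyError(f"metadata callback did not set: {missing}")
    return model_cls(*(metadata[key] for key in arg_keys))


class ToyClassifier:
    """Stand-in for LinearClassifier that only records its input/output sizes."""

    def __init__(self, num_genes: int, num_classes: int):
        self.num_genes = num_genes
        self.num_classes = num_classes


def toy_metadata_cb(adata: Dict[str, list], metadata: Dict[str, Any]) -> None:
    """Stand-in for cell_line_metadata_cb that reads a plain dict instead of AnnData."""
    metadata["num_genes"] = len(adata["genes"])
    metadata["cell_lines"] = sorted(set(adata["cell_line"]))
    metadata["num_classes"] = len(metadata["cell_lines"])


toy_adata = {"genes": ["G1", "G2", "G3"], "cell_line": ["A", "B", "A", "C"]}
model = build_model(ToyClassifier, ["num_genes", "num_classes"], toy_metadata_cb, toy_adata)
print(model.num_genes, model.num_classes)  # 3 3
```

Keeping model construction behind named metadata keys means the same runner can serve models with very different constructors; only the key list and the callback change.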

In [5]:
%%time
LinearClassifier_trainer = RayTrainRunner(
    LinearClassifier,  # replace with your own model
    Dcl,  # replace with your own Dataset
    ["num_genes", "num_classes"],  # change according to what you need for your model
    cell_line_metadata_cb,  # include data you need for your dataset
)

2025-11-04 06:50:14,623	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-11-04 06:50:14,647	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-04 06:50:14,665	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_eea06cd8e8481dee.zip' (3.11MiB) to Ray cluster...
2025-11-04 06:50:14,685	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_eea06cd8e8481dee.zip'.


CPU times: user 284 ms, sys: 337 ms, total: 622 ms
Wall time: 9.26 s




On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `LinearClassifier_trainer.train()` completed in approximately **8 minutes**.

In [6]:
%%time
LinearClassifier_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


(TrainController pid=1052917) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(TrainController pid=1052917) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1052917) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1


(RayTrainWorker pid=1053233) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=1053233) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=1053233) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1052917) Started training worker group of size 1: 
(TrainController pid=1052917) - (ip=192.168.1.226, pid=1053233) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=1053233) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1053233) GPU available: True (cuda), used: True
(RayTrainWorker pid=1053233) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1053233) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1053233) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1053233)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=1053233)[0m 
[36m(RayTrainWorker pid=1053233)[0m   | Name    | Type             | Params | Mode 
[36m(RayTrainWorker pid=1053233)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1053233)[0m 0 | model   | Linear           | 3.1 M  | train
[36m(RayTrainWorker pid=1053233)[0m 1 | loss_fn | CrossEntropyLoss | 0      | train
[36m(RayTrainWorker pid=1053233)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1053233)[0m 3.1 M     Trainable params
[36m(RayTrainWorker pid=1053233)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=1053233)[0m 3.1 M     Total params
[36m(RayTrainWorker pid=1053233)[0m 12.542    Total estimated model params size (MB)
[36m(RayTrainWorker pid=1053233)[0m 2         Modules in train mode
[36m(RayTrainWorker pid=1053233)[0m 0         Modules in eval mode
[36m(RayTrainWork

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  5.11it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  9.28it/s]
                                                                           


[36m(RayTrainWorker pid=1053233)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4192 [00:26<31:03:41,  0.04it/s, v_num=0, train_loss=4.110]
...
...
Epoch 0:  99%|█████████▉| 4170/4192 [05:12<00:01, 13.34it/s, v_num=0, train_loss=0.0547]
Epoch 0: 100%|█████████▉| 4175/4192 [05:12<00:01, 13.36it/s, v_num=0, train_loss=0.0602]
Epoch 0: 100%|█████████▉| 4176/4192 [05:12<00:01, 13.36it/s, v_num=0, train_loss=0.146] 
Epoch 0: 100%|█████████▉| 4182/4192 [05:12<00:00, 13.37it/s, v_num=0, train_loss=0.108]
Epoch 0: 100%|█████████▉| 4188/4192 [05:12<00:00, 13.39it/s, v_num=0, train_loss=0.136] 
Epoch 0: 100%|█████████▉| 4189/4192 [05:12<00:00, 13.39it/s, v_num=0, train_loss=0.055]
Epoch 0: 100%|██████████| 4192/4192 [05:12<00:00, 13.40it/s, v_num=0, train_loss=0.157] 
(RayTrainWorker pid=1053233)
Validation: |          | 0/? [00:00<?, ?it/s]

(RayTrainWorker pid=1053233) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_06-50-37/checkpoint_2025-11-04_06-58-22.072518)
(RayTrainWorker pid=1053233) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_06-50-37/checkpoint_2025-11-04_06-58-22.072518), metrics={'train_loss': 0.15701858699321747, 'val_acc': 0.9862620234489441, 'epoch': 0, 'step': 4192}, validation_spec=None)


Epoch 0: 100%|██████████| 4192/4192 [06:58<00:00, 10.02it/s, v_num=0, train_loss=0.157]


[36m(RayTrainWorker pid=1053233)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 21.8 s, sys: 3.7 s, total: 25.5 s
Wall time: 8min 11s


## 3. SIMS: Scalable, Interpretable Models for Cell Annotation of large-scale single-cell RNA-seq data
**SIMS** is a pipeline for building accurate and interpretable classifiers that predict any target label in single-cell RNA sequencing (scRNA-seq) data.
- The core SIMS model is **TabNet**, a deep learning architecture for large-scale tabular data that makes predictions over several sequential decision steps, each using attention to select a subset of the input features.
- SIMS provides a framework for **cell type annotation**: it trains on labeled single-cell data and predicts labels for new, unlabeled cells (in this tutorial, the target label is the cell line).
- Because TabNet selects the most informative genes at each decision step, its predictions can be both **accurate** and **interpretable**.

For implementation details and source code, see the [SIMS GitHub repository](https://github.com/braingeneers/SIMS/tree/main).

### SIMS Metadata Callback
This callback (`sims_metadata_cb`) extracts the information from the AnnData object that is needed to configure the SIMS model:
- `num_genes` / `input_dim`: the number of genes (features) in the dataset.
- `cell_lines`: the list of unique cell line categories.
- `num_classes` / `output_dim`: the number of distinct classes (cell lines) to predict.

In [7]:
def sims_metadata_cb(ad: anndata.AnnData, metadata: dict):
    metadata["num_genes"] = ad.var.shape[0]
    metadata["input_dim"] = metadata["num_genes"]
    metadata["cell_lines"] = ad.obs["cell_line"].cat.categories.to_list()
    metadata["num_classes"] = len(metadata["cell_lines"])
    metadata["output_dim"] = metadata["num_classes"]

### Training the SIMS Classifier

- `RayTrainRunner` receives the **SIMSClassifier** model class and the dataset class (`Dcl`); the constructor arguments `input_dim` and `output_dim` are taken from the metadata produced by the `sims_metadata_cb` callback
- Training is distributed with **RayDDPStrategy**, with `find_unused_parameters=True` so that DDP can handle parameters that do not receive a gradient in every forward pass


In [8]:
%%time
sims_trainer = RayTrainRunner(
    SIMSClassifier,
    Dcl,
    ["input_dim", "output_dim"],  # maps to SIMSClassifier(input_dim, output_dim)
    sims_metadata_cb,
    ray_trainer_strategy=RayDDPStrategy(find_unused_parameters=True),
)

2025-11-04 06:58:51,891	ERROR services.py:1360 -- Failed to start the dashboard 
2025-11-04 06:58:51,896	ERROR services.py:1385 -- Error should be written to 'dashboard.log' or 'dashboard.err'. We are printing the last 20 lines for you. See 'https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#logging-directory-structure' to find where the log file is.
2025-11-04 06:58:51,898	ERROR services.py:1429 -- 
The last 20 lines of /tmp/ray/session_2025-11-04_06-58-29_922337_1042730/logs/dashboard.log (it contains the error message from the dashboard): 
2025-11-04 06:58:53,184	INFO worker.py:2012 -- Started a local Ray instance.
2025-11-04 06:58:53,410	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-04 06:58:53,420	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_2df2b05812cc5858.zip' (0.92MiB) to Ray cluster...
2025-11-04 06:58:53,428	INFO packaging.py:393 -- Successfully push

CPU times: user 303 ms, sys: 321 ms, total: 624 ms
Wall time: 27.7 s


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `sims_trainer.train()` completed in about **16 minutes**.

In [9]:
%%time
sims_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


(TrainController pid=1072383) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(TrainController pid=1072383) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1072383) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1


(RayTrainWorker pid=1073355) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=1073355) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=1073355) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1072383) Started training worker group of size 1: 
(TrainController pid=1072383) - (ip=192.168.1.226, pid=1073355) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=1073355) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1073355) GPU available: True (cuda), used: True
(RayTrainWorker pid=1073355) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1073355) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1073355) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1073355)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
[36m(RayTrainWorker pid=1073355)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1073355)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=1073355)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1073355)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1073355)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:01<00:01,  0.56it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.08it/s]
                                                                           


[36m(RayTrainWorker pid=1073355)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1073355)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=1073355)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/macro_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the 

Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4192 [00:26<31:02:06,  0.04it/s, v_num=0, train/loss_step=4.940]
...
...
Epoch 0: 100%|█████████▉| 4190/4192 [11:06<00:00,  6.29it/s, v_num=0, train/loss_step=0.400]
Epoch 0: 100%|█████████▉| 4191/4192 [11:06<00:00,  6.29it/s, v_num=0, train/loss_step=0.456]
Epoch 0: 100%|██████████| 4192/4192 [11:06<00:00,  6.29it/s, v_num=0, train/loss_step=0.480]
(RayTrainWorker pid=1073355)
Validation: |          | 0/? [00:00<?, ?it/s]
...
...
Validation DataLoader 0: 100%|█████████▉| 1036/1040 [01:33<00:00, 11.11it/s]
Validation DataLoader 0: 100%|█████████▉| 1037/1040 [01:33<00:00, 11.12it/s]
Validation DataLoader 0: 100%|█████████▉| 1038/1040 [01:33<00:00, 11.12it/s]
(RayTrainWorker pid=1

(RayTrainWorker pid=1073355) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


(RayTrainWorker pid=1073355)
Epoch 0: 100%|██████████| 4192/4192 [13:02<00:00,  5.36it/s, v_num=0, train/loss_step=0.480, val/loss=0.596, val/f1=0.820, val/macro_acc=0.814, val/micro_acc=0.921, val/precision=0.833, val/recall=0.814, val/specificity=0.998, val/weighted_acc=0.921]


(RayTrainWorker pid=1073355) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_06-59-17/checkpoint_2025-11-04_07-15-15.179172)
(RayTrainWorker pid=1073355) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_06-59-17/checkpoint_2025-11-04_07-15-15.179172), metrics={'train/loss': 0.48666849732398987, 'train/loss_step': 0.479697585105896, 'val/loss': 0.5956600904464722, 'val/f1': 0.8201902508735657, 'val/macro_acc': 0.813599705696106, 'val/micro_acc': 0.9213297963142395, 'val/precision': 0.8329737782478333, 'val/recall': 0.813599705696106, 'val/specificity': 0.9983789920806885, 'val/weighted_acc': 0.9213297963142395, 'train/loss_epoch': 0.48666849732398987, 'epoch': 0, 'step': 4192}, validation_spec=None)
(RayTrainWorker pid=1073355) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site

Epoch 0: 100%|██████████| 4192/4192 [13:02<00:00,  5.35it/s, v_num=0, train/loss_step=0.480, val/loss=0.596, val/f1=0.820, val/macro_acc=0.814, val/micro_acc=0.921, val/precision=0.833, val/recall=0.814, val/specificity=0.998, val/weighted_acc=0.921, train/loss_epoch=0.487]
Epoch 0: 100%|██████████| 4192/4192 [13:03<00:00,  5.35it/s, v_num=0, train/loss_step=0.480, val/loss=0.596, val/f1=0.820, val/macro_acc=0.814, val/micro_acc=0.921, val/precision=0.833, val/recall=0.814, val/specificity=0.998, val/weighted_acc=0.921, train/loss_epoch=0.487]
CPU times: user 28.1 s, sys: 5.54 s, total: 33.7 s
Wall time: 16min 22s
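
The progress bars give a quick way to compare the two runs. The sketch below converts the logged iterations per second into cells per second, assuming every training batch contains exactly `batch_size` cells (an assumption made here for simplicity). With 4,192 steps per epoch and `batch_size=2000`, one epoch covers about 8.4 million cells, and at the logged end-of-epoch rates `LinearClassifier` processed roughly twice as many cells per second as `SIMSClassifier`.

```python
def cells_per_epoch(steps_per_epoch: int, batch_size: int) -> int:
    """Cells seen in one epoch, assuming every batch holds exactly batch_size cells."""
    return steps_per_epoch * batch_size


def cells_per_second(iterations_per_second: float, batch_size: int) -> float:
    """Convert a progress-bar rate (batches per second) into cells per second."""
    return iterations_per_second * batch_size


BATCH_SIZE = 2000
print(cells_per_epoch(4192, BATCH_SIZE))  # 8384000

# End-of-epoch training rates (it/s) copied from the progress bars above.
logged_rates = {"LinearClassifier": 13.40, "SIMSClassifier": 6.29}
for model_name, rate in logged_rates.items():
    print(f"{model_name}: ~{cells_per_second(rate, BATCH_SIZE):,.0f} cells/s")
```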


## 4. Autoencoder
- An **autoencoder** is an unsupervised neural network consisting of three main components:  
  - **Encoder**: compresses the input into a lower-dimensional representation.  
  - **Bottleneck**: stores the compressed features.  
  - **Decoder**: reconstructs the input from the bottleneck representation.  
- In this setup, separate encoders process **gene** and **protein** data. Their outputs are concatenated, passed through an additional encoder to form the bottleneck, and then decoded back to the original input.  
- Since **Tahoe-100M** does not include protein data, the number of protein features (`nfeatures_pro`) is set to `0`, and the source code was adapted to work with datasets that lack protein features.
- For testing purposes, the intermediate layer size is temporarily set to `mid = 128`, which keeps the hidden layers and the overall model small. For implementation details, see the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [10]:
# LinBnDrop groups Linear, BatchNorm1d, and Dropout layers (adapted from the citeseq_autoencoder notebook)
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from torch import nn, optim


class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers, adapted from fastai."""

    def __init__(self, n_in, n_out, bn=True, p=0.0, act=None, lin_first=True):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0:
            layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            lin.append(act)
        layers = lin + layers if lin_first else layers + lin
        super().__init__(*layers)
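
To make the layer ordering of `LinBnDrop` explicit, the hypothetical helper `lin_bn_drop_order` below mirrors its constructor logic using layer-type names instead of modules. With the default `lin_first=True`, the linear layer (and its activation) comes first, followed by batch normalization and dropout. The `nn.Linear` bias is disabled whenever `bn=True` (`bias=not bn`), presumably because the learnable shift of `BatchNorm1d` makes it redundant when normalization follows the linear layer.

```python
from typing import List, Optional


def lin_bn_drop_order(bn: bool = True, p: float = 0.0, act: Optional[str] = None,
                      lin_first: bool = True) -> List[str]:
    """Return the layer sequence LinBnDrop would build, as layer-type names."""
    norm_drop = ["BatchNorm1d"] if bn else []
    if p != 0:
        norm_drop.append("Dropout")
    lin = ["Linear"]  # corresponds to nn.Linear(n_in, n_out, bias=not bn)
    if act is not None:
        lin.append(act)
    return lin + norm_drop if lin_first else norm_drop + lin


# First encoder layer: LinBnDrop(nfeatures_rna, mid, p=p, act=nn.LeakyReLU())
print(lin_bn_drop_order(p=0.1, act="LeakyReLU"))  # ['Linear', 'LeakyReLU', 'BatchNorm1d', 'Dropout']
# Last decoder layer: LinBnDrop(mid_out, out_dim, bn=False)
print(lin_bn_drop_order(bn=False))  # ['Linear']
```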

We implement an encoder that passes the RNA features through a two-layer MLP (`nfeatures_rna` → `mid` → `hidden_rna`, with `mid=128` set for testing); a final `LinBnDrop` layer then maps the hidden features (RNA plus protein, when present) to `latent_dim`. The code is adapted from the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [11]:
class Encoder(nn.Module):
    """Encoder for CITE-seq data"""

    def __init__(
        self, nfeatures_rna: int, nfeatures_pro: int, hidden_rna: int, hidden_pro: int, latent_dim: int, p: float = 0
    ):
        super().__init__()
        self.nfeatures_rna = nfeatures_rna
        self.nfeatures_pro = nfeatures_pro

        if nfeatures_rna > 0:
            mid = 128  # 128 is for testing the code
            self.encoder_rna = nn.Sequential(
                LinBnDrop(nfeatures_rna, mid, p=p, act=nn.LeakyReLU()),
                LinBnDrop(mid, hidden_rna, act=nn.LeakyReLU()),
            )

        if nfeatures_pro > 0:
            self.encoder_protein = LinBnDrop(nfeatures_pro, hidden_pro, p=p, act=nn.LeakyReLU())

        # make sure hidden_rna and hidden_pro are set correctly
        hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
        hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro

        hidden_dim = hidden_rna + hidden_pro

        self.encoder = LinBnDrop(hidden_dim, latent_dim, act=nn.LeakyReLU())

    def forward(self, x):
        if self.nfeatures_rna > 0 and self.nfeatures_pro > 0:
            x_rna = self.encoder_rna(x[:, : self.nfeatures_rna])
            x_pro = self.encoder_protein(x[:, self.nfeatures_rna :])
            x = torch.cat([x_rna, x_pro], 1)
        elif self.nfeatures_rna > 0 and self.nfeatures_pro == 0:
            x = self.encoder_rna(x)
        elif self.nfeatures_rna == 0 and self.nfeatures_pro > 0:
            x = self.encoder_protein(x)
        return self.encoder(x)

We implement a decoder that maps the latent vector back to the feature space: it first expands the latent vector to the hidden size (`hidden_rna + hidden_pro`, which equals `hidden_rna` here), passes it through a small intermediate layer (`mid_out = 128`, used for testing), and finally projects it to the output dimension (`nfeatures_rna + nfeatures_pro`). The code is adapted from the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [12]:
class Decoder(nn.Module):
    """Decoder for CITE-seq data"""

    def __init__(self, nfeatures_rna: int, nfeatures_pro: int, hidden_rna: int, hidden_pro: int, latent_dim: int):
        super().__init__()
        # make sure hidden_rna and hidden_pro are set correctly
        hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
        hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro

        hidden_dim = hidden_rna + hidden_pro
        out_dim = nfeatures_rna + nfeatures_pro
        mid_out = 128  # 128 is for testing the code

        self.decoder = nn.Sequential(
            LinBnDrop(latent_dim, hidden_dim, act=nn.LeakyReLU()),
            LinBnDrop(hidden_dim, mid_out, act=nn.LeakyReLU()),
            LinBnDrop(mid_out, out_dim, bn=False),
        )

    def forward(self, x):
        return self.decoder(x)
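
Because the first encoder layer and the last decoder layer scale with `nfeatures_rna × mid` and `mid_out × nfeatures_rna`, they hold most of the autoencoder's parameters, and the testing values `mid = mid_out = 128` keep them manageable. The sketch below counts trainable parameters for the RNA-only case (`nfeatures_pro = 0`, `hidden_pro = 0`) by mirroring the `Encoder` and `Decoder` layer sizes above. It assumes `lin_first=True` (the default used there), counts one weight and one bias per feature for `BatchNorm1d` (its running statistics are buffers, not parameters), and uses a hypothetical gene count of 20,000 in the usage line.

```python
def lin_bn_drop_params(n_in: int, n_out: int, bn: bool = True) -> int:
    """Trainable parameters of LinBnDrop(n_in, n_out, bn=bn) with lin_first=True.

    nn.Linear has no bias when bn=True; BatchNorm1d(n_out) adds a weight and a bias
    per output feature. Dropout and LeakyReLU have no parameters.
    """
    linear = n_in * n_out + (0 if bn else n_out)
    batchnorm = 2 * n_out if bn else 0
    return linear + batchnorm


def autoencoder_params(nfeatures_rna: int, hidden_rna: int, latent_dim: int,
                       mid: int = 128, mid_out: int = 128) -> int:
    """Parameter count for the RNA-only Encoder + Decoder defined above."""
    encoder = (
        lin_bn_drop_params(nfeatures_rna, mid)        # nfeatures_rna -> mid
        + lin_bn_drop_params(mid, hidden_rna)         # mid -> hidden_rna
        + lin_bn_drop_params(hidden_rna, latent_dim)  # hidden_rna -> latent_dim
    )
    decoder = (
        lin_bn_drop_params(latent_dim, hidden_rna)               # latent_dim -> hidden_rna
        + lin_bn_drop_params(hidden_rna, mid_out)                # hidden_rna -> mid_out
        + lin_bn_drop_params(mid_out, nfeatures_rna, bn=False)   # mid_out -> output, with bias
    )
    return encoder + decoder


# Hypothetical gene count; hidden_rna and latent_dim as in ae_metadata_cb below.
print(autoencoder_params(nfeatures_rna=20_000, hidden_rna=128, latent_dim=16))  # 5177920
```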

The encoder and decoder are assembled into an autoencoder, defined as a PyTorch Lightning `LightningModule` to simplify training; it is trained to minimize the mean squared error (MSE) between each input and its reconstruction. The code is adapted from the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [13]:
class CiteAutoencoder(pl.LightningModule):
    def __init__(
        self,
        nfeatures_rna: int,
        nfeatures_pro: int,
        hidden_rna: int,
        hidden_pro: int,
        latent_dim: int,
        p: float = 0,
        lr: float = 0.1,
    ):
        """Autoencoder for citeseq data"""
        super().__init__()

        # save hyperparameters
        self.save_hyperparameters()

        self.encoder = Encoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim, p)
        self.decoder = Decoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim)

        # example input array for visualizing network graph
        self.example_input_array = torch.zeros(256, nfeatures_rna + nfeatures_pro)

    def forward(self, x):
        # extract latent embeddings
        z = self.encoder(x)
        return z

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def _get_reconstruction_loss(self, batch):
        """Calculate MSE loss for a given batch."""
        x, _ = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # MSE loss
        loss = F.mse_loss(x_hat, x)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)
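
`configure_optimizers` pairs Adam with `ReduceLROnPlateau` at its default settings, monitoring `val_loss`. The pure-Python class `PlateauLRSketch` below is a simplified, hypothetical sketch of the documented default rule (`mode='min'`, `factor=0.1`, `patience=10`, relative `threshold=1e-4`); it omits the `cooldown` and `eps` options and is not PyTorch's implementation. Lightning steps such a scheduler once per epoch by default, so if training runs for a single epoch (the linear classifier run above stopped at `max_epochs=1`), the learning rate is never reduced.

```python
import math


class PlateauLRSketch:
    """Simplified ReduceLROnPlateau rule for a metric that should decrease."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10,
                 threshold: float = 1e-4, min_lr: float = 0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.min_lr = min_lr
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, val_loss: float) -> float:
        # An epoch counts as an improvement only if it beats the best loss by a relative margin.
        if val_loss < self.best * (1 - self.threshold):
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Cut the learning rate once more than `patience` epochs pass without improvement.
        if self.num_bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


scheduler = PlateauLRSketch(lr=1e-3)  # lr=1e-3 matches ae_metadata_cb below
lrs = [scheduler.step(1.0) for _ in range(12)]  # a completely flat validation loss
print(lrs[10], lrs[11])  # the cut happens on the 12th epoch
```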

### Autoencoder Metadata Callback
- The `ae_metadata_cb` function builds on `cell_line_metadata_cb` and adds the metadata required to train the autoencoder: it sets up the cell line information, defines the RNA and protein feature counts, and specifies key model hyperparameters such as the hidden dimensions, latent space size, dropout probability, and learning rate.

**Note (for testing):**  
In `ae_metadata_cb`, both the hidden RNA dimension (`hidden_rna=128`) and the latent dimension (`latent_dim=16`) are intentionally set to very small values. This configuration is used for quick testing and validation, not for full-scale training.

In [14]:
def ae_metadata_cb(ad, metadata):
    cell_line_metadata_cb(ad, metadata)
    metadata["cell_lines"] = np.sort(np.unique(ad.obs["cell_line"].to_numpy()))
    metadata["nfeatures_rna"] = metadata["num_genes"]
    metadata["nfeatures_pro"] = 0
    metadata["hidden_rna"] = 128
    metadata["hidden_pro"] = 0
    metadata["latent_dim"] = 16
    metadata["p"] = 0.1
    metadata["lr"] = 1e-3

### Training the CiteAutoencoder model
- The dataset class (`Dcl`) is passed together with the names of the model's constructor arguments: RNA/protein feature counts, hidden layer sizes, latent dimension, dropout probability `p`, and learning rate `lr`. Their values are supplied through the `ae_metadata_cb` callback.
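
If the argument names are matched to the model constructor by position (an assumption based on the `# maps to SIMSClassifier(input_dim, output_dim)` comment earlier), a reordered list could swap hyperparameters without raising an error. A hypothetical pre-flight check built on Python's `inspect.signature` is sketched below. It uses a torch-free stand-in with the same constructor signature as `CiteAutoencoder`; in the notebook you could pass `CiteAutoencoder` itself.

```python
import inspect
from typing import List


def check_arg_order(model_cls: type, arg_keys: List[str]) -> None:
    """Raise if arg_keys do not match the leading parameters of model_cls.__init__."""
    params = [name for name in inspect.signature(model_cls.__init__).parameters if name != "self"]
    expected = params[: len(arg_keys)]
    if expected != list(arg_keys):
        raise ValueError(f"argument keys {arg_keys} do not match constructor order {expected}")


class CiteAutoencoderSignature:
    """Stand-in with the same __init__ signature as CiteAutoencoder (no torch needed)."""

    def __init__(self, nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim, p=0.0, lr=0.1):
        pass


ae_keys = ["nfeatures_rna", "nfeatures_pro", "hidden_rna", "hidden_pro", "latent_dim", "p", "lr"]
check_arg_order(CiteAutoencoderSignature, ae_keys)  # passes silently
```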

In [15]:
%%time
autoencoder_trainer = RayTrainRunner(
    CiteAutoencoder,
    Dcl,
    ["nfeatures_rna", "nfeatures_pro", "hidden_rna", "hidden_pro", "latent_dim", "p", "lr"],
    metadata_cb=ae_metadata_cb,
)

2025-11-04 07:15:31,804	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266
2025-11-04 07:15:31,824	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-04 07:15:31,834	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_0a9329e6e0c18e86.zip' (1.38MiB) to Ray cluster...
2025-11-04 07:15:31,846	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_0a9329e6e0c18e86.zip'.


CPU times: user 203 ms, sys: 316 ms, total: 519 ms
Wall time: 15.1 s


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `autoencoder_trainer.train()` completed in approximately **8 minutes**.

In [16]:
%%time
autoencoder_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


(TrainController pid=1097366) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(TrainController pid=1097366) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1097366) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1


(RayTrainWorker pid=1098099) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=1098099) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=1098099) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1097366) Started training worker group of size 1: 
(TrainController pid=1097366) - (ip=192.168.1.226, pid=1098099) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=1098099) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1098099) GPU available: True (cuda), used: True
(RayTrainWorker pid=1098099) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1098099) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1098099) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1098099)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
[36m(RayTrainWorker pid=1098099)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1098099)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1098099)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=1098099)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=1098099)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00,  8.45it/s]
                                                                           


[36m(RayTrainWorker pid=1098099)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4192 [00:29<34:47:22,  0.03it/s, v_num=0]
...
...
Epoch 0: 100%|█████████▉| 4186/4192 [06:55<00:00, 10.08it/s, v_num=0]
Epoch 0: 100%|█████████▉| 4189/4192 [06:55<00:00, 10.09it/s, v_num=0]
Epoch 0: 100%|██████████| 4192/4192 [06:55<00:00, 10.09it/s, v_num=0]
(RayTrainWorker pid=1098099) 
Validation: |          | 0/? [00:00<?, ?it/s]
(RayTrainWorker pid=1098099) 
Validation: |          | 0/? [00:00<?, ?it/s]
...
...
Validation DataLoader 0:  99%|█████████▉| 1034/1040 [01:09<00:00, 14.94it/s]
Validation DataLoader 0: 100%|█████████▉| 1035/1040 [01:09<00:00, 14.95it/s]
Validation DataLoader 0: 100%|█████████▉| 1036/1040 [01:09<00:00, 14.96it/s]
Validation DataLoader 0: 100%|█████████▉| 1037/1040 [01:09<00:00, 14.97it/s]
Validation DataLoa

(RayTrainWorker pid=1098099) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_07-16-01/checkpoint_2025-11-04_07-25-25.031792)
(RayTrainWorker pid=1098099) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_07-16-01/checkpoint_2025-11-04_07-25-25.031792), metrics={'train_loss': 0.21229848265647888, 'val_loss': 0.07657571136951447, 'epoch': 0, 'step': 4192}, validation_spec=None)


Epoch 0: 100%|██████████| 4192/4192 [08:30<00:00,  8.22it/s, v_num=0]
Epoch 0: 100%|██████████| 4192/4192 [08:30<00:00,  8.21it/s, v_num=0]


(RayTrainWorker pid=1098099) `Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 29.3 s, sys: 3.68 s, total: 33 s
Wall time: 9min 53s
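As a rough sanity check on these timings, the progress-bar figures above can be converted into throughput. The sketch below assumes that every training step consumes one full batch of `batch_size` (2000) cells; the helper names are our own and are not part of PROTOplast. It also checks that the 65 validation and 262 training splits reported in the log correspond to the requested `val_size` of 0.2.

```python
# Back-of-the-envelope throughput from the autoencoder run's progress bar.
# Assumption: every training step processes one full batch of `batch_size` cells.


def cells_per_second(steps: int, batch_size: int, elapsed_s: float) -> float:
    """Approximate number of cells processed per second."""
    return steps * batch_size / elapsed_s


def val_fraction(n_val: int, n_train: int, n_test: int = 0) -> float:
    """Fraction of splits assigned to validation."""
    return n_val / (n_val + n_train + n_test)


# 4192 training steps finished in 06:55 with batch_size=2000
train_rate = cells_per_second(steps=4192, batch_size=2000, elapsed_s=6 * 60 + 55)
print(f"~{train_rate:,.0f} cells/s during training")

# 65 validation splits vs 262 training splits (test_size=0.0)
print(f"validation fraction: {val_fraction(65, 262):.3f}")
```

The split counts also divide evenly into the step counts (4192 / 262 = 1040 / 65 = 16), so each split appears to supply 16 batches.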


## 5. DistributedClassifierTrainingPlan
- **ClassifierTrainingPlan** (from `scvi-tools`) is not a model but a training plan: a PyTorch Lightning module that wraps an scvi-tools classifier module and coordinates its training workflow, including optimization, scheduling, and evaluation.
- For details, see the [source code](https://github.com/scverse/scvi-tools/blob/main/src/scvi/train/_trainingplans.py#L1479).

### Classifier Training metadata callback
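`clf_metadata_cb` first calls `cell_line_metadata_cb` to populate `num_genes` and `num_classes` from the input AnnData object. It then builds an scvi-tools `Classifier` with those dimensions and stores it in `metadata`, together with the optimizer hyperparameters (`lr`, `weight_decay`, `eps`, `optimizer`).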
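The callback pattern can be illustrated without PROTOplast or AnnData. In this sketch `ToyAnnData`, `toy_cell_line_metadata_cb`, `toy_clf_metadata_cb`, `build_model`, and `ToyPlan` are hypothetical stand-ins. We assume that `cell_line_metadata_cb` derives `num_genes` from the number of variables and `num_classes` from the distinct cell-line labels, and that `RayTrainRunner` forwards the `metadata` entries named in `model_keys` to the model constructor as keyword arguments. The latter is consistent with `model_keys` matching `ClassifierTrainingPlan`'s `classifier`, `lr`, `weight_decay`, `eps`, and `optimizer` arguments, but we have not verified it against the PROTOplast source.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins: none of these names are PROTOplast APIs.


@dataclass
class ToyAnnData:
    n_vars: int  # number of genes (columns)
    cell_lines: list = field(default_factory=list)  # one label per cell


def toy_cell_line_metadata_cb(ad, metadata):
    # Assumed behaviour: record the gene count and the number of distinct labels
    metadata["num_genes"] = ad.n_vars
    metadata["num_classes"] = len(set(ad.cell_lines))


def toy_clf_metadata_cb(ad, metadata):
    toy_cell_line_metadata_cb(ad, metadata)
    # Placeholder for Classifier(n_input=..., n_labels=..., logits=True)
    metadata["classifier"] = ("Classifier", metadata["num_genes"], metadata["num_classes"])
    metadata["lr"] = 1e-3


def build_model(Model, metadata, model_keys):
    # Assumption: only the keys listed in model_keys become constructor kwargs
    return Model(**{k: metadata[k] for k in model_keys})


class ToyPlan:
    def __init__(self, classifier, lr):
        self.classifier = classifier
        self.lr = lr


metadata = {}
toy_clf_metadata_cb(ToyAnnData(n_vars=5, cell_lines=["A", "B", "A"]), metadata)
plan = build_model(ToyPlan, metadata, ["classifier", "lr"])
print(plan.classifier, plan.lr)
```

Keeping model construction inside the callback means the classifier's input and output sizes come from the data file itself rather than from hard-coded values.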

In [17]:
def clf_metadata_cb(ad, metadata):
    # Populate num_genes / num_classes from the AnnData file
    cell_line_metadata_cb(ad, metadata)

    # Create the classifier instance and attach it to metadata
    metadata["classifier"] = Classifier(
        n_input=metadata["num_genes"],
        n_labels=metadata["num_classes"],
        logits=True,  # ClassifierTrainingPlan requires raw logits when its loss is CrossEntropyLoss
    )
    metadata["lr"] = 1e-3
    metadata["weight_decay"] = 1e-6
    metadata["eps"] = 0.01
    metadata["optimizer"] = "Adam"
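The `logits=True` flag matters because the training plan's loss, listed as `CrossEntropyLoss` in the model summary of the training output below, applies log-softmax internally and therefore expects unnormalized scores. Passing softmax probabilities instead would normalize twice. The standard-library sketch below reproduces the computation `torch.nn.CrossEntropyLoss` performs with its default mean reduction; it illustrates the formula on plain lists and is not a replacement for the PyTorch loss.

```python
import math


def cross_entropy_from_logits(logits, targets):
    """Mean cross-entropy over a batch, computed from raw (unnormalized) logits.

    For each row: log-sum-exp of the logits minus the logit of the target class,
    i.e. the negative log-softmax probability of the target.
    """
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


# Two cells, three classes; targets are integer class indices
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
print(cross_entropy_from_logits(logits, targets))
```

With uniform logits the loss equals log(number of classes), which is a useful reference point when reading the first `train_loss_step` values.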

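The `DistributedClassifierTrainingPlan` subclass extends `ClassifierTrainingPlan` by overriding `training_step` and `validation_step` so that each step unpacks the `(x, y)` tuples yielded by the PROTOplast dataset instead of the dictionary batches the base class indexes by key: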

In [18]:
class DistributedClassifierTrainingPlan(ClassifierTrainingPlan):
    def training_step(self, batch, batch_idx):
        """Training step for classifier training."""
        # PROTOplast batches are (features, labels) tuples
        x, y = batch
        soft_prediction = self.forward(x)  # raw logits, since the classifier uses logits=True
        # CrossEntropyLoss expects a 1-D tensor of integer class indices
        loss = self.loss_fn(soft_prediction, y.view(-1).long())
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step for classifier training."""
        x, y = batch
        soft_prediction = self.forward(x)
        loss = self.loss_fn(soft_prediction, y.view(-1).long())
        # With more than one worker, sync_dist=True would average this metric across devices
        self.log("validation_loss", loss)
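The reason for these overrides is the batch layout. The sketch below contrasts the two layouts using plain Python lists in place of tensors. It assumes, based on our reading of scvi-tools, that the base-class steps index the batch as a dictionary (default keys `"X"` and `"labels"`), whereas the `x, y = batch` line above shows that the PROTOplast dataset yields tuples. The helper names `unpack_scvi_style`, `unpack_protoplast_style`, and `flatten_labels` are hypothetical.

```python
# Hypothetical helpers contrasting the two batch layouts (lists stand in for tensors).


def unpack_scvi_style(batch, data_key="X", labels_key="labels"):
    # Assumed scvi-tools default: the batch is a dict indexed by key
    return batch[data_key], batch[labels_key]


def unpack_protoplast_style(batch):
    # PROTOplast dataset: the batch is an (x, y) tuple
    x, y = batch
    return x, y


def flatten_labels(y):
    # Mirrors y.view(-1).long(): an (N, 1) column of labels becomes N integers
    return [int(v) for row in y for v in row]


x = [[0.0, 1.5], [3.0, 0.0], [0.0, 0.0]]  # 3 cells x 2 genes
y = [[0.0], [2.0], [1.0]]                 # one label per cell, shape (N, 1)
dict_batch = {"X": x, "labels": y}
tuple_batch = (x, y)

print(unpack_scvi_style(dict_batch) == unpack_protoplast_style(tuple_batch))
print(flatten_labels(unpack_protoplast_style(tuple_batch)[1]))
try:
    unpack_scvi_style(tuple_batch)  # a tuple cannot be indexed with "X"
except TypeError as err:
    print("default-style step would fail:", err)
```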

### Executing ClassifierTrainingPlan

In [19]:
%%time
ClassifierTrainingPlan_trainer = RayTrainRunner(
    Model=DistributedClassifierTrainingPlan,
    Ds=Dcl,
    model_keys=["classifier", "lr", "weight_decay", "eps", "optimizer"],
    metadata_cb=clf_metadata_cb,
)

2025-11-04 07:25:36,178	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-11-04 07:25:36,221	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-04 07:25:36,231	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_073c9cf1e7e22cdb.zip' (1.94MiB) to Ray cluster...
2025-11-04 07:25:36,246	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_073c9cf1e7e22cdb.zip'.


CPU times: user 187 ms, sys: 317 ms, total: 505 ms
Wall time: 10.2 s


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12GiB) + 96 CPUs + 125GiB RAM**, `ClassifierTrainingPlan_trainer.train()` finished in about **8 minutes** of wall time; the single epoch, including validation, took about 7 minutes.

In [20]:
%%time
ClassifierTrainingPlan_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


[36m(TrainController pid=1116077)[0m âœ“ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(TrainController pid=1116077)[0m âœ“ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(TrainController pid=1116077)[0m root - INFO - Logging initialized. Current level is: INFO
[36m(TrainController pid=1116077)[0m Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1


[36m(RayTrainWorker pid=1116630)[0m âœ“ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=1116630)[0m âœ“ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=1116630)[0m root - INFO - Logging initialized. Current level is: INFO
[36m(RayTrainWorker pid=1116630)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TrainController pid=1116077)[0m Started training worker group of size 1: 
[36m(TrainController pid=1116077)[0m - (ip=192.168.1.226, pid=1116630) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=1116630)[0m ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1116630)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1116630)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1116630)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1116630)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1116630)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=1116630)[0m 
[36m(RayTrainWorker pid=1116630)[0m   | Name    | Type             | Params | Mode 
[36m(RayTrainWorker pid=1116630)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1116630)[0m 0 | module  | Classifier       | 8.0 M  | train
[36m(RayTrainWorker pid=1116630)[0m 1 | loss_fn | CrossEntropyLoss | 0      | train
[36m(RayTrainWorker pid=1116630)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1116630)[0m 8.0 M     Trainable params
[36m(RayTrainWorker pid=1116630)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=1116630)[0m 8.0 M     Total params
[36m(RayTrainWorker pid=1116630)[0m 32.135    Total estimated model params size (MB)
[36m(RayTrainWorker pid=1116630)[0m 11        Modules in train mode
[36m(RayTrainWorker pid=1116630)[0m 0         Modules in eval mode
[36m(RayTrainWork

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1116630)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:  50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 1/2 [00:00<00:00,  1.23it/s]
                                                                           


[36m(RayTrainWorker pid=1116630)[0m /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('validation_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4192 [00:29<34:24:10,  0.03it/s, v_num=0, train_loss_step=3.990]
...
...
Epoch 0: 100%|█████████▉| 4187/4192 [05:36<00:00, 12.45it/s, v_num=0, train_loss_step=0.0984]
Epoch 0: 100%|█████████▉| 4190/4192 [05:36<00:00, 12.46it/s, v_num=0, train_loss_step=0.0704]
Epoch 0: 100%|██████████| 4192/4192 [05:36<00:00, 12.46it/s, v_num=0, train_loss_step=0.153] 
(RayTrainWorker pid=1116630) 
Validation: |          | 0/? [00:00<?, ?it/s]
(RayTrainWorker pid=1116630) 
Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1040 [00:00<?, ?it/s]
...
...
Validation DataLoader 0: 100%|█████████▉| 1037/1040 [01:02<00:00, 16.51it/s]
Validation DataLoader 0: 100%|█████████▉| 1038/1040 [01:02<00:00, 16.52it/s]
Validation DataLoader 0: 100%|████████

(RayTrainWorker pid=1116630) /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0: 100%|██████████| 4192/4192 [07:04<00:00,  9.87it/s, v_num=0, train_loss_step=0.153, train_loss_epoch=0.180]


(RayTrainWorker pid=1116630) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_07-25-59/checkpoint_2025-11-04_07-33-57.004318)
(RayTrainWorker pid=1116630) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_07-25-59/checkpoint_2025-11-04_07-33-57.004318), metrics={'train_loss': 0.1795438826084137, 'train_loss_step': 0.15308770537376404, 'validation_loss': 0.11337331682443619, 'train_loss_epoch': 0.1795438826084137, 'epoch': 0, 'step': 4192}, validation_spec=None)


Epoch 0: 100%|██████████| 4192/4192 [07:04<00:00,  9.87it/s, v_num=0, train_loss_step=0.153, train_loss_epoch=0.180]


(RayTrainWorker pid=1116630) `Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 21.9 s, sys: 3.47 s, total: 25.4 s
Wall time: 8min 22s
