# Classification Models for Single-Cell Data with PROTOplast

This tutorial demonstrates how to use PROTOplast to train several PyTorch models (two classifiers and an autoencoder) on single-cell data stored in the `h5ad` format.

**Download the Tahoe-100M `h5ad` files**
- The Tahoe-100M dataset can be downloaded in `h5ad` format from the **Arc Institute Google Cloud Storage**.
- For step-by-step instructions, see the [official tutorial](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md).

**Setup**  
- Configure the training environment for single-cell RNA sequencing (scRNA-seq) data using **PROTOplast** in combination with **PyTorch Lightning** and **Ray**.

In [1]:
%%time
import anndata
import numpy as np
import ray

# models
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.torch_dataloader import DistributedCellLineAnnDataset as Dcl
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner
from ray.train.lightning import RayDDPStrategy
from scsims.model import SIMSClassifier

# scvi training plan
## install scvi-tools if needed:
## uv add scvi-tools
from scvi.module import Classifier
from scvi.train import ClassifierTrainingPlan

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
CPU times: user 18.9 s, sys: 1.54 s, total: 20.4 s
Wall time: 8.36 s


## 1. Load the Tahoe-100M Dataset (`h5ad`)
- `file_paths`: Plate 12 of Tahoe-100M (the largest file, 35 GB) is used as a demo. To train on more plates, append their `.h5ad` file paths to the list.
- `batch_size`: number of samples per training batch
- `test_size`: fraction of data reserved for testing (use `0.0` if no test set is needed)
- `val_size`: fraction of data reserved for validation 
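To train on several plates, the list can be built programmatically instead of by hand. The sketch below is a hypothetical helper, not part of PROTOplast: it assumes every plate file name starts with `plate<N>_`, as the Plate 12 file above does, and that all plate files sit in one directory.

```python
from pathlib import Path


def collect_plate_paths(data_dir, plate_ids, pattern="plate{plate}_*.h5ad"):
    """Return the .h5ad paths for the requested plates, in the order requested.

    Hypothetical helper: assumes file names start with "plate<N>_",
    like the Plate 12 file used in this tutorial.
    """
    data_dir = Path(data_dir)
    paths = []
    for plate in plate_ids:
        # "plate1_*" does not match "plate12_...", so plate numbers stay distinct
        matches = sorted(data_dir.glob(pattern.format(plate=plate)))
        if not matches:
            raise FileNotFoundError(f"no .h5ad file for plate {plate} in {data_dir}")
        paths.extend(str(p) for p in matches)
    return paths
```

Usage, with a placeholder directory: `file_paths = collect_plate_paths("/path/to/tahoe100m", [12, 13])`. Failing loudly on a missing plate avoids silently training on fewer cells than intended.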


In [2]:
%%time
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
batch_size = 2000
test_size = 0.0
val_size = 0.2

CPU times: user 10 μs, sys: 1 μs, total: 11 μs
Wall time: 21.9 μs


## 2. Simple Classifier

This example illustrates how to configure a training runner with **PROTOplast** and **Ray**.

- `LinearClassifier`: a simple baseline model that can be swapped with a custom implementation
- `Dcl`: the dataset object for training, imported from `protoplast.scrna.anndata.torch_dataloader`
  - Defined as a subclass of `DistributedAnnDataset`, customized for cell line classification tasks
- `["num_genes", "num_classes"]`: the names of the metadata entries passed to the model constructor, here its input and output dimensions
- `cell_line_metadata_cb`: a callback function that attaches dataset-specific metadata, such as cell line labels and class counts
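One way to picture how these arguments fit together: the callback fills a metadata dictionary, and the listed keys select which entries are handed to the model constructor. The plain-Python sketch below illustrates that assumption only; `toy_metadata_cb`, `build_model`, and `TinyModel` are hypothetical names, and `RayTrainRunner`'s real internals may differ (for example, positional versus keyword passing).

```python
def toy_metadata_cb(cell_line_labels, n_genes, metadata):
    """Hypothetical stand-in for a metadata callback such as cell_line_metadata_cb."""
    metadata["num_genes"] = n_genes
    metadata["cell_lines"] = sorted(set(cell_line_labels))
    metadata["num_classes"] = len(metadata["cell_lines"])


def build_model(model_cls, keys, metadata):
    """Hypothetical: pass only the listed metadata entries to the model constructor."""
    return model_cls(**{key: metadata[key] for key in keys})


class TinyModel:
    """Placeholder model whose constructor takes the two listed arguments."""

    def __init__(self, num_genes, num_classes):
        self.num_genes = num_genes
        self.num_classes = num_classes


metadata = {}
toy_metadata_cb(["line_A", "line_B", "line_A"], n_genes=500, metadata=metadata)
model = build_model(TinyModel, ["num_genes", "num_classes"], metadata)
print(model.num_genes, model.num_classes)  # prints: 500 2
```

Under this reading, swapping in a custom model only requires that its constructor parameter names match keys the callback sets.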

In [3]:
%%time
LinearClassifier_trainer = RayTrainRunner(
    LinearClassifier,  # replace with your own model
    Dcl,  # replace with your own Dataset
    ["num_genes", "num_classes"],  # change according to what you need for your model
    cell_line_metadata_cb,  # include data you need for your dataset
)

2025-09-29 15:48:10,463	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 15:48:10,692	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 15:48:11,056	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_dbb4b323aeef291d.zip' (69.15MiB) to Ray cluster...
2025-09-29 15:48:11,561	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_dbb4b323aeef291d.zip'.


CPU times: user 794 ms, sys: 726 ms, total: 1.52 s
Wall time: 8.63 s


[33m(raylet)[0m Using CPython [36m3.11.13[39m
[33m(raylet)[0m Creating virtual environment at: [36m.venv[39m
[33m(raylet)[0m [2mInstalled [1m296 packages[0m [2min 348ms[0m[0m


[36m(TrainTrainable pid=3295631)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(TrainTrainable pid=3295631)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3296882)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=3295631)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=3295631)[0m - (node_id=6322a982fe7253de9609b5bb9b18d97905853f8cf5146fdbd78447c3, ip=192.168.1.226, pid=3296882) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=3296882)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=3296882)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3296882)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=3296882)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=3296882)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=3296882)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=3296882)[0m /tmp/ray/session_2025-09-29_15-48-06_722491_3287714/runtime_resources/working_dir_files/_ray_pkg_dbb4b323aeef291d/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pyth ...
[36m(RayTrainWorker pid=3296882)[0m You are using 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3296882)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3296882)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=3296882)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3296882)[0m   return torch.sparse_csr_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3296882)[0m /tmp/ray/session_2025-09-29_15-48-06_722491_3287714/runtime_resources/working_dir_files/_ray_pkg_dbb4b323aeef291d/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


                                                                           
Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 


[36m(RayTrainWorker pid=3296882)[0m   return torch.sparse_csr_tensor(


Epoch 0:   0%|          | 1/4160 [00:29<33:58:56,  0.03it/s, v_num=0, train_loss=3.950]
Epoch 0:   0%|          | 3/4160 [00:31<12:11:50,  0.09it/s, v_num=0, train_loss=3.050]
.
.
.
Epoch 0: 100%|█████████▉| 4147/4160 [05:22<00:01, 12.86it/s, v_num=0, train_loss=0.146]
Epoch 0: 100%|█████████▉| 4148/4160 [05:22<00:00, 12.86it/s, v_num=0, train_loss=0.154]
Epoch 0: 100%|█████████▉| 4153/4160 [05:22<00:00, 12.87it/s, v_num=0, train_loss=0.0986]
Epoch 0: 100%|█████████▉| 4154/4160 [05:22<00:00, 12.88it/s, v_num=0, train_loss=0.0893]
Epoch 0: 100%|█████████▉| 4159/4160 [05:22<00:00, 12.89it/s, v_num=0, train_loss=0.108] 
Epoch 0: 100%|██████████| 4160/4160 [05:22<00:00, 12.89it/s, v_num=0, train_loss=0.130]
[36m(RayTrainWorker pid=3296882)[0m 
Validation: |          | 0/? [00:00<?, ?it/s][A
[36m(RayTrainWorker pid=3296882)[0m 
Validation:   0%|          | 0/1024 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1024 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|     

[36m(RayTrainWorker pid=3296882)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_15-48-41/TorchTrainer_c5f9c_00000_0_2025-09-29_15-48-43/checkpoint_000000)
[36m(RayTrainWorker pid=3296882)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4160/4160 [06:51<00:00, 10.12it/s, v_num=0, train_loss=0.130]


On a machine with **1 GPU (NVIDIA GeForce RTX 3080, 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `LinearClassifier_trainer.train()` took about **8.5 minutes** of wall time.

In [4]:
%%time
LinearClassifier_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each
Data splitting time: 26.24 seconds


2025-09-29 15:48:41,802	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-29 15:48:43 (running for 00:00:00.16)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_15-48-06_722491_3287714/artifacts/2025-09-29_15-48-41/TorchTrainer_2025-09-29_15-48-41/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-29 15:49:09 (running for 00:00:25.33)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_15-48-06_722491_3287714/artifacts/2025-09-29_15-48-41/TorchTrainer_2025-09-29_15-48-41/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-29 15:56:36 (running for 00:07:53.21)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2

2025-09-29 15:56:42,595	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_15-48-41' in 0.0100s.
2025-09-29 15:56:42,600	INFO tune.py:1041 -- Total run time: 480.80 seconds (478.89 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 15:56:42 (running for 00:07:58.90)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_15-48-06_722491_3287714/artifacts/2025-09-29_15-48-41/TorchTrainer_2025-09-29_15-48-41/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 27.2 s, sys: 6.66 s, total: 33.8 s
Wall time: 8min 29s


## 3. SIMS: Scalable, Interpretable Models for Cell Annotation of large-scale single-cell RNA-seq data
**SIMS** is a pipeline designed to build interpretable and accurate classifiers for identifying any target in single-cell RNA sequencing (scRNA-seq) data.  
- The core SIMS model is based on a **sequential transformer**, a transformer architecture built for large-scale tabular datasets.
- SIMS provides a framework for **cell type annotation**: it trains on labeled single-cell data and predicts cell type labels for new, unlabeled cells.
- It builds on the **TabNet** deep learning model, which selects the most informative genes for each prediction, making its predictions interpretable.

For implementation details and source code, see the [SIMS GitHub repository](https://github.com/braingeneers/SIMS/tree/main).
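TabNet's per-step feature masks are commonly computed with sparsemax, a softmax alternative that can assign exactly zero weight to some inputs; this is what lets the model focus on a small subset of genes. The NumPy sketch below implements sparsemax on a single score vector for illustration only; it is not taken from the SIMS or TabNet code bases.

```python
import numpy as np


def sparsemax(z):
    """Sparsemax projection of a 1-D score vector onto the probability simplex."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]              # scores in descending order
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, z.size + 1)
    support = 1 + k * z_sorted > cumsum      # top-k entries that stay nonzero
    k_max = k[support][-1]
    tau = (cumsum[k_max - 1] - 1) / k_max    # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)
```

For the scores `[1.0, 0.8, 0.1]` it returns approximately `[0.6, 0.4, 0.0]`, whereas softmax would give every entry a positive weight.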

### SIMS Metadata Callback
This callback (`sims_metadata_cb`) extracts key information from the AnnData object to configure the SIMS model.
- `input_dim`: the number of genes (features) in the dataset (also stored as `num_genes`).
- `cell_lines`: list of unique cell line categories.
- `output_dim`: the number of distinct classes (cell lines) to be predicted (also stored as `num_classes`).

In [5]:
def sims_metadata_cb(ad: anndata.AnnData, metadata: dict):
    metadata["num_genes"] = ad.var.shape[0]
    metadata["input_dim"] = metadata["num_genes"]
    metadata["cell_lines"] = ad.obs["cell_line"].cat.categories.to_list()
    metadata["num_classes"] = len(metadata["cell_lines"])
    metadata["output_dim"] = metadata["num_classes"]
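The two callbacks in this tutorial build the class list differently: `sims_metadata_cb` reads `.cat.categories`, while `ae_metadata_cb` in the autoencoder section uses `np.unique`. The toy pandas example below (labels are made up) shows that they can disagree when a categorical column declares categories with no cells, which would change `output_dim`. It also shows `cat.codes`, which yields integer labels in category order; whether `Dcl` derives its targets this way is an assumption not confirmed here.

```python
import numpy as np
import pandas as pd

# Toy obs column: "line_C" is a declared category with no cells in this subset.
cell_line = pd.Series(
    pd.Categorical(["line_B", "line_A", "line_B"], categories=["line_A", "line_B", "line_C"])
)

from_categories = cell_line.cat.categories.to_list()    # what sims_metadata_cb stores
from_unique = np.unique(cell_line.to_numpy()).tolist()   # what ae_metadata_cb stores
codes = cell_line.cat.codes.to_list()                    # integer labels in category order
```

Here `from_categories` has three entries and `from_unique` only two, so the two callbacks would report different class counts for the same cells.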

### Training the SIMS Classifier

- `RayTrainRunner` pairs the **SIMSClassifier** model with the dataset class (`Dcl`); the constructor arguments (`input_dim`, `output_dim`) are supplied through the `sims_metadata_cb` callback.
- Training is distributed using **RayDDPStrategy** with `find_unused_parameters=True`, so that DDP tolerates parameters that do not receive gradients in every forward pass.
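Without `find_unused_parameters=True`, PyTorch DDP expects every parameter to receive a gradient in each iteration and typically raises an error when some do not; enabling the flag makes DDP traverse the autograd graph to find them, at some extra cost per step. The toy module below (hypothetical, unrelated to the SIMS code) shows what such "unused" parameters look like after a single-process backward pass.

```python
import torch
from torch import nn


class GatedHead(nn.Module):
    """Toy module: only one of two heads runs in a given forward pass."""

    def __init__(self, n_in=4, n_out=2):
        super().__init__()
        self.head_a = nn.Linear(n_in, n_out)
        self.head_b = nn.Linear(n_in, n_out)

    def forward(self, x, use_b=False):
        return self.head_b(x) if use_b else self.head_a(x)


model = GatedHead()
model(torch.randn(3, 4)).sum().backward()

# Parameters that took no part in this loss never received a gradient.
unused = [name for name, p in model.named_parameters() if p.grad is None]
```

Here `unused` lists `head_b.weight` and `head_b.bias`; under DDP, those are the parameters the flag lets it skip.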


In [6]:
%%time
sims_trainer = RayTrainRunner(
    SIMSClassifier,
    Dcl,
    ["input_dim", "output_dim"],  # maps to SIMSClassifier(input_dim, output_dim)
    sims_metadata_cb,
    ray_trainer_strategy=RayDDPStrategy(find_unused_parameters=True),
)

2025-09-29 15:57:36,769	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 15:57:38,275	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 15:57:38,726	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_0bc7f6254ceca816.zip' (69.14MiB) to Ray cluster...
2025-09-29 15:57:39,243	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_0bc7f6254ceca816.zip'.


CPU times: user 746 ms, sys: 760 ms, total: 1.51 s
Wall time: 10.7 s


[33m(raylet)[0m Using CPython [36m3.11.13[39m
[33m(raylet)[0m Creating virtual environment at: [36m.venv[39m
[33m(raylet)[0m [2mInstalled [1m296 packages[0m [2min 322ms[0m[0m


[36m(TrainTrainable pid=3309981)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(TrainTrainable pid=3309981)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3311122)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=3309981)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=3309981)[0m - (node_id=b7aa76930fb0b23923b4c09f1ac8ae73d51c8f7724cf1b4cacb7631c, ip=192.168.1.226, pid=3311122) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=3311122)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=3311122)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3311122)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=3311122)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=3311122)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=3311122)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pyth ...
[36m(RayTrainWorker pid=3311122)[0m You are using 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3311122)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3311122)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3311122)[0m   return torch.sparse_compressed_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3311122)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3311122)[0m   return torch.sparse_csr_tensor(


                                                                           


[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val/f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python

Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 1/4160 [00:24<27:53:00,  0.04it/s, v_num=0, train/loss_step=4.750]
.
.
.
Epoch 0: 100%|█████████▉| 4157/4160 [09:05<00:00,  7.62it/s, v_num=0, train/loss_step=0.382]
Epoch 0: 100%|█████████▉| 4158/4160 [09:05<00:00,  7.62it/s, v_num=0, train/loss_step=0.433]
Epoch 0: 100%|█████████▉| 4159/4160 [09:05<00:00,  7.62it/s, v_num=0, train/loss_step=0.520]
Epoch 0: 100%|██████████| 4160/4160 [09:06<00:00,  7.62it/s, v_num=0, train/loss_step=0.443]
[36m(RayTrainWorker pid=3311122)[0m 
Validation: |          | 0/? [00:00<?, ?it/s][A
[36m(RayTrainWorker pid=3311122)[0m 
Validation:   0%|          | 0/1024 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1024 [00:00<?, ?it/s][A
.
.
.
Validation DataLoader 0: 100%|█████████▉| 1020/1024 [01:14<00:00, 13.65it/s]
Validation DataLoader 0: 100%|█████████▉| 1021/1024 [01:14<00:00, 13.66it/s]
Validation DataLoader 0: 100%|█████████▉| 1022/1024 [01:14<0

[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0: 100%|██████████| 4160/4160 [10:44<00:00,  6.45it/s, v_num=0, train/loss_step=0.443, val/loss=0.543, val/f1=0.843, val/macro_acc=0.841, val/micro_acc=0.945, val/precision=0.847, val/recall=0.841, val/specificity=0.999, val/weighted_acc=0.945, train/loss_epoch=0.503]


[36m(RayTrainWorker pid=3311122)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_15-58-04/TorchTrainer_15473_00000_0_2025-09-29_15-58-04/checkpoint_000000)
[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train/f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
[36m(RayTrainWorker pid=3311122)[0m /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/runtime_resources/working_dir_files/_ray_pkg_0bc7f6254ceca816/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train/macro_acc', ..., sync_dist=True)` when logging 

On a machine with **1 GPU (NVIDIA GeForce RTX 3080, 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `sims_trainer.train()` took about **13 minutes** of wall time.
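The epoch-0 progress bars in the two runs also allow a rough throughput comparison. The arithmetic below assumes every batch holds the full 2,000 cells (the last batch may be smaller, so these are upper bounds) and covers only the training loop, excluding validation, data splitting, and Ray start-up.

```python
def cells_per_second(n_batches, seconds, batch_size):
    """Upper-bound training throughput, assuming every batch is full."""
    return n_batches * batch_size / seconds


# 4160 training batches per epoch in both runs, batch_size = 2000
linear = cells_per_second(4160, 5 * 60 + 22, 2000)  # progress bar: 05:22 for LinearClassifier
sims = cells_per_second(4160, 9 * 60 + 6, 2000)     # progress bar: 09:06 for SIMSClassifier
print(f"LinearClassifier ~{linear:,.0f} cells/s, SIMS ~{sims:,.0f} cells/s")
```

This gives roughly 25.8k cells/s for `LinearClassifier` and 15.2k cells/s for SIMS, about a 1.7-fold difference on this machine.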

In [7]:
%%time
sims_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each


2025-09-29 15:58:04,345	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 20.92 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-29 15:58:04 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/artifacts/2025-09-29_15-58-04/TorchTrainer_2025-09-29_15-58-04/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-29 15:58:29 (running for 00:00:25.37)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/artifacts/2025-09-29_15-58-04/TorchTrainer_2025-09-29_15-58-04/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-29 16:10:54 (running for 00:12:50.01)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G

2025-09-29 16:10:56,469	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_15-58-04' in 0.0130s.
2025-09-29 16:10:56,473	INFO tune.py:1041 -- Total run time: 772.13 seconds (772.10 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 16:10:56 (running for 00:12:52.11)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_15-57-32_695037_3287714/artifacts/2025-09-29_15-58-04/TorchTrainer_2025-09-29_15-58-04/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 31.8 s, sys: 4.9 s, total: 36.7 s
Wall time: 13min 15s


## 4. Autoencoder
- An **autoencoder** is an unsupervised neural network consisting of three main components:  
  - **Encoder**: compresses the input into a lower-dimensional representation.  
  - **Bottleneck**: stores the compressed features.  
  - **Decoder**: reconstructs the input from the bottleneck representation.  
- In this setup, separate encoders process **gene** and **protein** data. Their outputs are concatenated, passed through an additional encoder to form the bottleneck, and then decoded back to the original input.  
- Since **Tahoe-100M** does not include protein data, the number of protein features (`nfeatures_pro`) is set to `0`, and the source code was adapted to work with datasets that lack protein features.
- For testing purposes, we temporarily set `mid = 128`, which reduces the hidden layer size and simplifies the model architecture. For implementation details, see the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).
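A quick parameter count shows why the width of the first hidden layer matters: the first `Linear` layer alone holds `n_genes × mid` weights. The helper below mirrors the `LinBnDrop` module defined in the next cell with `lin_first=True` (bias disabled when batch norm is on, plus two affine parameters per `BatchNorm1d` feature). The 60,000-gene figure in the usage note is a placeholder, not the exact Tahoe-100M gene count.

```python
def linbndrop_params(n_in, n_out, bn=True):
    """Trainable parameters of LinBnDrop(n_in, n_out, bn) with lin_first=True.

    Linear has n_in * n_out weights, plus n_out biases only when bn=False;
    BatchNorm1d(n_out) adds one weight and one bias per output feature.
    Dropout has no parameters.
    """
    linear = n_in * n_out + (0 if bn else n_out)
    batchnorm = 2 * n_out if bn else 0
    return linear + batchnorm


def rna_encoder_params(n_genes, mid, hidden_rna, latent_dim):
    """Encoder path without a protein branch: n_genes -> mid -> hidden_rna -> latent_dim."""
    return (
        linbndrop_params(n_genes, mid)
        + linbndrop_params(mid, hidden_rna)
        + linbndrop_params(hidden_rna, latent_dim)
    )
```

With 60,000 genes, `rna_encoder_params(60_000, 128, 128, 16)` gives 7,698,976 parameters; raising `mid` to 1024 gives 61,575,456, almost all of them in the first layer.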

In [8]:
# group linear, batchnorm, and dropout layers. This module was from citeseq_autoencoder notebook
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from torch import nn, optim


class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers, adapted from fastai."""

    def __init__(self, n_in, n_out, bn=True, p=0.0, act=None, lin_first=True):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0:
            layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            lin.append(act)
        layers = lin + layers if lin_first else layers + lin
        super().__init__(*layers)

We implement an encoder that passes RNA features through a two-layer MLP (`nfeatures_rna` → `mid` → `hidden_rna`, with `mid = 128` set for testing) and then maps the resulting hidden features, concatenated with the protein branch when protein data are present, to `latent_dim`. Adapted from the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [9]:
class Encoder(nn.Module):
    """Encoder for CITE-seq data"""

    def __init__(
        self, nfeatures_rna: int, nfeatures_pro: int, hidden_rna: int, hidden_pro: int, latent_dim: int, p: float = 0
    ):
        super().__init__()
        self.nfeatures_rna = nfeatures_rna
        self.nfeatures_pro = nfeatures_pro

        if nfeatures_rna > 0:
            mid = 128  # 128 is for testing the code
            self.encoder_rna = nn.Sequential(
                LinBnDrop(nfeatures_rna, mid, p=p, act=nn.LeakyReLU()),
                LinBnDrop(mid, hidden_rna, act=nn.LeakyReLU()),
            )

        if nfeatures_pro > 0:
            self.encoder_protein = LinBnDrop(nfeatures_pro, hidden_pro, p=p, act=nn.LeakyReLU())

        # make sure hidden_rna and hidden_pro are set correctly
        hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
        hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro

        hidden_dim = hidden_rna + hidden_pro

        self.encoder = LinBnDrop(hidden_dim, latent_dim, act=nn.LeakyReLU())

    def forward(self, x):
        if self.nfeatures_rna > 0 and self.nfeatures_pro > 0:
            x_rna = self.encoder_rna(x[:, : self.nfeatures_rna])
            x_pro = self.encoder_protein(x[:, self.nfeatures_rna :])
            x = torch.cat([x_rna, x_pro], 1)
        elif self.nfeatures_rna > 0 and self.nfeatures_pro == 0:
            x = self.encoder_rna(x)
        elif self.nfeatures_rna == 0 and self.nfeatures_pro > 0:
            x = self.encoder_protein(x)
        return self.encoder(x)

We implement a decoder that maps the latent vector back to the input feature space: it first expands the latent vector to `hidden_rna`, passes it through a small intermediate layer (`mid_out = 128`, used for testing), and finally projects it to the output dimension (`nfeatures_rna + nfeatures_pro`, i.e. the number of genes here). Adapted from the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [10]:
class Decoder(nn.Module):
    """Decoder for CITE-seq data"""

    def __init__(self, nfeatures_rna: int, nfeatures_pro: int, hidden_rna: int, hidden_pro: int, latent_dim: int):
        super().__init__()
        # make sure hidden_rna and hidden_pro are set correctly
        hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
        hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro

        hidden_dim = hidden_rna + hidden_pro
        out_dim = nfeatures_rna + nfeatures_pro
        mid_out = 128  # 128 is for testing the code

        self.decoder = nn.Sequential(
            LinBnDrop(latent_dim, hidden_dim, act=nn.LeakyReLU()),
            LinBnDrop(hidden_dim, mid_out, act=nn.LeakyReLU()),
            LinBnDrop(mid_out, out_dim, bn=False),
        )

    def forward(self, x):
        return self.decoder(x)
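The shape bookkeeping in `Encoder` and `Decoder` can be traced without building any tensors. The pure-Python sketch below is a hypothetical helper that mirrors the two constructors, including the step that zeroes `hidden_pro` when `nfeatures_pro == 0`. Note that the protein branch, when present, runs in parallel with the RNA branch, so the encoder list is not strictly sequential.

```python
def autoencoder_layer_dims(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro,
                           latent_dim, mid=128, mid_out=128):
    """Return (in, out) sizes of each LinBnDrop in Encoder and Decoder (hypothetical helper)."""
    # Same guard as the constructors: a missing modality contributes no hidden width.
    hidden_rna = 0 if nfeatures_rna == 0 else hidden_rna
    hidden_pro = 0 if nfeatures_pro == 0 else hidden_pro
    hidden_dim = hidden_rna + hidden_pro

    encoder = []
    if nfeatures_rna > 0:
        encoder += [(nfeatures_rna, mid), (mid, hidden_rna)]   # RNA branch
    if nfeatures_pro > 0:
        encoder += [(nfeatures_pro, hidden_pro)]               # protein branch
    encoder += [(hidden_dim, latent_dim)]                      # bottleneck layer

    decoder = [
        (latent_dim, hidden_dim),
        (hidden_dim, mid_out),
        (mid_out, nfeatures_rna + nfeatures_pro),
    ]
    return encoder, decoder
```

With `nfeatures_rna=1000`, `nfeatures_pro=0`, `hidden_rna=128`, and a stray `hidden_pro=64`, the bottleneck layer is still `(128, 16)` because the protein width is zeroed.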

The encoder and decoder are assembled into an autoencoder, defined as a PyTorch Lightning module to simplify training. Its `forward` method returns the latent embedding, while `training_step` and `validation_step` minimize the MSE reconstruction loss. Adapted from the [CITE-seq autoencoder source code](https://github.com/naity/citeseq_autoencoder/blob/main/autoencoder_citeseq_saturn.ipynb).

In [11]:
class CiteAutoencoder(pl.LightningModule):
    def __init__(
        self,
        nfeatures_rna: int,
        nfeatures_pro: int,
        hidden_rna: int,
        hidden_pro: int,
        latent_dim: int,
        p: float = 0,
        lr: float = 0.1,
    ):
        """Autoencoder for citeseq data"""
        super().__init__()

        # save hyperparameters
        self.save_hyperparameters()

        self.encoder = Encoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim, p)
        self.decoder = Decoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim)

        # example input array for visualizing network graph
        self.example_input_array = torch.zeros(256, nfeatures_rna + nfeatures_pro)

    def forward(self, x):
        # extract latent embeddings
        z = self.encoder(x)
        return z

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def _get_reconstruction_loss(self, batch):
        """Calculate MSE loss for a given batch."""
        x, _ = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # MSE loss
        loss = F.mse_loss(x_hat, x)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)
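`F.mse_loss` uses `reduction="mean"` by default, so the logged `train_loss` and `val_loss` are averages over every cell-gene entry in the batch rather than per-cell errors, and are not directly comparable across datasets with different gene counts. The NumPy sketch below reproduces that reduction on a toy 2-cell, 4-gene batch (made-up values), next to a per-cell sum as one alternative scale.

```python
import numpy as np

x = np.array([[0.0, 2.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])       # toy batch: 2 cells x 4 genes
x_hat = np.array([[0.0, 1.0, 0.0, 0.0],
                  [0.5, 0.0, 0.0, 1.0]])   # toy reconstruction

per_entry = (x_hat - x) ** 2
mse_mean = per_entry.mean()                 # same reduction as F.mse_loss(x_hat, x)
mse_per_cell = per_entry.sum(axis=1)        # summed squared error for each cell
```

Here the total squared error of 1.25 is divided by all 8 entries, giving 0.15625, while the per-cell sums are 1.0 and 0.25.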

### Autoencoder Metadata Callback
- The `ae_metadata_cb` function first calls `cell_line_metadata_cb`, then overrides `cell_lines` and adds the metadata required for training the autoencoder: feature counts and model hyperparameters such as hidden dimensions, latent space size, dropout, and learning rate.

**Note (for testing):**  
In `ae_metadata_cb`, both the hidden RNA dimension (`hidden_rna=128`) and the latent dimension (`latent_dim=16`) are intentionally set to very small values. This configuration is used for quick testing and validation, not for full-scale training.

In [12]:
def ae_metadata_cb(ad, metadata):
    cell_line_metadata_cb(ad, metadata)
    metadata["cell_lines"] = np.sort(np.unique(ad.obs["cell_line"].to_numpy()))
    metadata["nfeatures_rna"] = metadata["num_genes"]
    metadata["nfeatures_pro"] = 0
    metadata["hidden_rna"] = 128
    metadata["hidden_pro"] = 0
    metadata["latent_dim"] = 16
    metadata["p"] = 0.1
    metadata["lr"] = 1e-3
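Before handing a callback to the runner, one can check that it sets every key the model constructor expects. The sketch below is a hypothetical helper built on `inspect.signature`; `ToyAutoencoder` copies the constructor signature of `CiteAutoencoder` without the network, so the check runs instantly.

```python
import inspect


def check_metadata_keys(model_cls, keys, metadata):
    """Return (keys missing from metadata, keys the constructor does not accept)."""
    params = inspect.signature(model_cls.__init__).parameters
    missing_in_metadata = [k for k in keys if k not in metadata]
    unknown_to_model = [k for k in keys if k not in params]
    return missing_in_metadata, unknown_to_model


class ToyAutoencoder:
    # Same constructor signature as CiteAutoencoder above, without the network.
    def __init__(self, nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro,
                 latent_dim, p=0, lr=0.1):
        pass


keys = ["nfeatures_rna", "nfeatures_pro", "hidden_rna", "hidden_pro", "latent_dim", "p", "lr"]
metadata = {"nfeatures_rna": 500, "nfeatures_pro": 0, "hidden_rna": 128,
            "hidden_pro": 0, "latent_dim": 16, "p": 0.1}  # "lr" deliberately omitted
report = check_metadata_keys(ToyAutoencoder, keys, metadata)
```

Here `report` flags `"lr"` as missing from the metadata and finds no unknown keys. Note that a missing key would not necessarily fail at construction time, since `lr` has a default of `0.1` that differs from the `1e-3` set in `ae_metadata_cb`.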

### Training the CiteAutoencoder model
- The dataset class (`Dcl`) is passed together with the names of the model's constructor arguments: RNA/protein feature counts, hidden layer sizes, latent dimension, dropout `p`, and learning rate `lr`. Their values are supplied by the `ae_metadata_cb` callback.

In [13]:
%%time
autoencoder_trainer = RayTrainRunner(
    CiteAutoencoder,
    Dcl,
    ["nfeatures_rna", "nfeatures_pro", "hidden_rna", "hidden_pro", "latent_dim", "p", "lr"],
    metadata_cb=ae_metadata_cb,
)

2025-09-29 16:12:13,554	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 16:12:14,575	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 16:12:14,931	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_44eaeac1790d7085.zip' (70.05MiB) to Ray cluster...
2025-09-29 16:12:15,371	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_44eaeac1790d7085.zip'.


CPU times: user 586 ms, sys: 694 ms, total: 1.28 s
Wall time: 10.6 s


[33m(raylet)[0m Using CPython [36m3.11.13[39m
[33m(raylet)[0m Creating virtual environment at: [36m.venv[39m
[33m(raylet)[0m [2mInstalled [1m296 packages[0m [2min 323ms[0m[0m


[36m(TrainTrainable pid=3330047)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(TrainTrainable pid=3330047)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3330754)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=3330047)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=3330047)[0m - (node_id=661dfa72632cd7bdac4b2f4696086471aadc4dae065ed3a2b92450de, ip=192.168.1.226, pid=3330754) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=3330754)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=3330754)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3330754)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=3330754)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=3330754)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=3330754)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=3330754)[0m /tmp/ray/session_2025-09-29_16-12-08_641985_3287714/runtime_resources/working_dir_files/_ray_pkg_44eaeac1790d7085/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pyth ...
[36m(RayTrainWorker pid=3330754)[0m You are using 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3330754)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3330754)[0m   return torch.sparse_compressed_tensor(
[36m(RayTrainWorker pid=3330754)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3330754)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3330754)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3330754)[0m /tmp/ray/session_2025-09-29_16-12-08_641985_3287714/runtime_resources/working_dir_files/_ray_pkg_44eaeac1790d7085/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


                                                                           
Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 2/4160 [00:22<13:16:53,  0.09it/s, v_num=0]
Epoch 0:   0%|          | 3/4160 [00:23<8:51:51,  0.13it/s, v_num=0] 
.
.
.
Epoch 0: 100%|█████████▉| 4153/4160 [05:31<00:00, 12.52it/s, v_num=0]
Epoch 0: 100%|█████████▉| 4156/4160 [05:31<00:00, 12.52it/s, v_num=0]
Epoch 0: 100%|█████████▉| 4159/4160 [05:31<00:00, 12.53it/s, v_num=0]
Epoch 0: 100%|██████████| 4160/4160 [05:31<00:00, 12.53it/s, v_num=0]
[36m(RayTrainWorker pid=3330754)[0m 
Validation: |          | 0/? [00:00<?, ?it/s][A
[36m(RayTrainWorker pid=3330754)[0m 
Validation:   0%|          | 0/1024 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1024 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 1/1024 [00:00<00:12, 79.73it/s][A
.
.
.
Validation DataLoader 0: 100%|█████████▉| 1019/1024 [01:03<00:00, 16.02it/s]
Validation DataLoader 0: 100%|██

[36m(RayTrainWorker pid=3330754)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-12-46/TorchTrainer_23398_00000_0_2025-09-29_16-12-46/checkpoint_000000)


Epoch 0: 100%|██████████| 4160/4160 [07:07<00:00,  9.72it/s, v_num=0]


[36m(RayTrainWorker pid=3330754)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4160/4160 [07:08<00:00,  9.71it/s, v_num=0]


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12GiB) + 96 CPUs + 125GiB RAM**, `autoencoder_trainer.train()` finished in about **9 minutes** (8 min 45 s wall time).

In [14]:
%%time
autoencoder_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each
Data splitting time: 27.48 seconds
Spawning Ray worker and initiating distributed training


2025-09-29 16:12:46,739	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-09-29 16:12:46 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-12-08_641985_3287714/artifacts/2025-09-29_16-12-46/TorchTrainer_2025-09-29_16-12-46/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-29 16:13:12 (running for 00:00:25.26)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-12-08_641985_3287714/artifacts/2025-09-29_16-12-46/TorchTrainer_2025-09-29_16-12-46/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-29 16:21:00 (running for 00:08:13.75)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-12-08_641985_3287714/artifacts/2025-09-29_1

2025-09-29 16:21:01,711	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-12-46' in 0.0064s.
2025-09-29 16:21:01,808	INFO tune.py:1041 -- Total run time: 495.07 seconds (494.95 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 16:21:01 (running for 00:08:14.97)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-12-08_641985_3287714/artifacts/2025-09-29_16-12-46/TorchTrainer_2025-09-29_16-12-46/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 33.5 s, sys: 4.13 s, total: 37.6 s
Wall time: 8min 45s
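The progress bars can be cross-checked against the split settings. Assuming every batch holds `batch_size` = 2000 cells (the last batch of each split may be smaller, so these are approximations), the 4160 training batches and 1024 validation batches correspond to roughly 8.3M and 2.0M cells, a validation share close to the requested `val_size` of 0.2. The training throughput follows from the iteration rate at the end of the training pass (12.53 it/s); the lower 9.72 it/s shown after the epoch also includes the validation pass.

```python
cells_per_batch = 2000  # the batch_size passed to .train()
train_batches = 4160    # from "Epoch 0: ... 4160/4160"
val_batches = 1024      # from "Validation DataLoader 0: ... /1024"
train_it_per_s = 12.53  # iteration rate at the end of the training pass

# Upper bounds: the last batch of each split may be partial
train_cells = train_batches * cells_per_batch
val_cells = val_batches * cells_per_batch
val_fraction = val_cells / (train_cells + val_cells)
cells_per_s = train_it_per_s * cells_per_batch

print(f"train ~{train_cells:,} cells, val ~{val_cells:,} cells")
print(f"validation fraction ~{val_fraction:.3f}")
print(f"training throughput ~{cells_per_s:,.0f} cells/s")
```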


## 5. DistributedClassifierTrainingPlan
- **ClassifierTrainingPlan** (from `scvi-tools`) is not a model itself but a training plan: a PyTorch Lightning module that wraps a classifier module and defines the loss, the `training_step`/`validation_step` logic, and the optimizer configuration (`lr`, `weight_decay`, `eps`, `optimizer`).  
- For details, see the [source code](https://github.com/scverse/scvi-tools/blob/main/src/scvi/train/_trainingplans.py#L1479).

In [None]:
# install scvi-tools (run in a terminal):
# uv add scvi-tools

### Classifier Training metadata callback
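Calls `cell_line_metadata_cb` to extract `num_genes` and `num_classes` from the input AnnData object, uses them to build an scvi-tools `Classifier`, and stores the classifier in `metadata` together with the optimizer hyperparameters (`lr`, `weight_decay`, `eps`, `optimizer`).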

In [15]:
def clf_metadata_cb(ad, metadata):
    # Populate num_genes / num_classes from the AnnData file
    cell_line_metadata_cb(ad, metadata)

    # Create the classifier instance and attach it to metadata
    metadata["classifier"] = Classifier(
        n_input=metadata["num_genes"],
        n_labels=metadata["num_classes"],
        logits=True,  # return raw logits, as ClassifierTrainingPlan's default CrossEntropyLoss expects
    )
    metadata["lr"] = 1e-3
    metadata["weight_decay"] = 1e-6
    metadata["eps"] = 0.01
    metadata["optimizer"] = "Adam"
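The `logits=True` flag matters because `ClassifierTrainingPlan` uses PyTorch's `CrossEntropyLoss` by default, and that loss applies log-softmax to its input itself; feeding it already-softmaxed probabilities would apply softmax twice. The stdlib sketch below computes single-sample cross-entropy from raw logits the same way (log-sum-exp minus the true-class logit). It also gives a sanity check for the training log: an untrained classifier with near-uniform logits over K classes has a loss of about ln K. K = 50 is an illustrative choice here, not a value read from this dataset; the classifier's first logged `train_loss_step` (3.970) is of that order.

```python
import math


def cross_entropy_from_logits(logits, target):
    """Cross-entropy for one sample, computed from raw (unnormalized) logits.

    Uses the log-sum-exp trick for numerical stability, mirroring
    log-softmax followed by negative log-likelihood.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


K = 50  # illustrative number of classes
print(cross_entropy_from_logits([0.0] * K, target=3), math.log(K))  # both about 3.912

# Softmaxing first squashes the scores into [0, 1], so even a confident,
# correct prediction keeps a large loss.
confident = [10.0] + [0.0] * (K - 1)
print(cross_entropy_from_logits(confident, 0))           # about 0.0022
print(cross_entropy_from_logits(softmax(confident), 0))  # about 2.95
```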

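`ClassifierTrainingPlan`'s built-in steps read the inputs and labels from a dictionary batch (`batch[self.data_key]` and `batch[self.labels_key]`), whereas `Dcl` yields `(x, y)` tuples. The `DistributedClassifierTrainingPlan` subclass therefore overrides `training_step` and `validation_step` to unpack the tuple directly: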

In [16]:
class DistributedClassifierTrainingPlan(ClassifierTrainingPlan):
    def training_step(self, batch, batch_idx):
        """Training step for classifier training."""
        x, y = batch
        soft_prediction = self.forward(x)
        loss = self.loss_fn(soft_prediction, y.view(-1).long())
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step for classifier training."""
        x, y = batch
        soft_prediction = self.forward(x)
        loss = self.loss_fn(soft_prediction, y.view(-1).long())
        self.log("validation_loss", loss)
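Instead of overriding the step logic, one could reshape each `(x, y)` tuple into the dictionary layout that the inherited scvi-tools steps index, and keep those steps unchanged. The sketch below assumes the default keys are scvi-tools' `REGISTRY_KEYS.X_KEY` (`"X"`) and `REGISTRY_KEYS.LABELS_KEY` (`"labels"`), and that the plan stores them as `self.data_key` and `self.labels_key`. `tuple_batch_to_dict` is a hypothetical helper, not part of PROTOplast or scvi-tools.

```python
def tuple_batch_to_dict(batch, data_key="X", labels_key="labels"):
    """Hypothetical helper: reshape an (x, y) batch into the dict layout
    that scvi-tools' ClassifierTrainingPlan indexes by default."""
    x, y = batch
    return {data_key: x, labels_key: y}


# Inside a ClassifierTrainingPlan subclass, the inherited step could then be reused:
#     def training_step(self, batch, batch_idx):
#         batch = tuple_batch_to_dict(batch, self.data_key, self.labels_key)
#         return super().training_step(batch, batch_idx)

example = tuple_batch_to_dict(([[0.0, 1.5]], [[3]]))
print(example)  # {'X': [[0.0, 1.5]], 'labels': [[3]]}
```

Overriding the steps directly, as the notebook does, is shorter and keeps the logged metric names visible in one place; the adapter approach keeps any future upstream changes to the inherited steps.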

### Executing ClassifierTrainingPlan

In [17]:
%%time
ClassifierTrainingPlan_trainer = RayTrainRunner(
    Model=DistributedClassifierTrainingPlan,
    Ds=Dcl,
    model_keys=["classifier", "lr", "weight_decay", "eps", "optimizer"],
    metadata_cb=clf_metadata_cb,
)

2025-09-29 16:22:02,568	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 16:22:03,828	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 16:22:04,333	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_1da4f615a458d958.zip' (70.38MiB) to Ray cluster...
2025-09-29 16:22:05,419	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_1da4f615a458d958.zip'.


CPU times: user 800 ms, sys: 831 ms, total: 1.63 s
Wall time: 12.9 s


[33m(raylet)[0m Using CPython [36m3.11.13[39m
[33m(raylet)[0m Creating virtual environment at: [36m.venv[39m
[33m(raylet)[0m [2mInstalled [1m296 packages[0m [2min 347ms[0m[0m


[36m(TrainTrainable pid=3351177)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(TrainTrainable pid=3351177)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3352013)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=3351177)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=3351177)[0m - (node_id=5f290d146a3444f6b76f01c2fd8cbb4ef9675d111d416f43402f2952, ip=192.168.1.226, pid=3352013) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=3352013)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=3352013)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=3352013)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=3352013)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=3352013)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=3352013)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=3352013)[0m /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/runtime_resources/working_dir_files/_ray_pkg_1da4f615a458d958/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /mnt/hdd1/dung/protoplast-ml-example/.venv/lib/pyth ...
[36m(RayTrainWorker pid=3352013)[0m You are using 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=3352013)[0m /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/runtime_resources/working_dir_files/_ray_pkg_1da4f615a458d958/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
[36m(RayTrainWorker pid=3352013)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3352013)[0m   return torch.sparse_compressed_tensor(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  5.04it/s]


[36m(RayTrainWorker pid=3352013)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3352013)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=3352013)[0m   return torch.sparse_csr_tensor(


                                                                           
Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 


[36m(RayTrainWorker pid=3352013)[0m /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/runtime_resources/working_dir_files/_ray_pkg_1da4f615a458d958/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('validation_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0:   0%|          | 1/4160 [00:31<36:00:19,  0.03it/s, v_num=0, train_loss_step=3.970]
Epoch 0:   0%|          | 2/4160 [00:31<18:15:47,  0.06it/s, v_num=0, train_loss_step=3.490]
Epoch 0:   0%|          | 5/4160 [00:31<7:19:19,  0.16it/s, v_num=0, train_loss_step=2.810] 
.
.
.
Epoch 0: 100%|█████████▉| 4142/4160 [05:12<00:01, 13.24it/s, v_num=0, train_loss_step=0.129] 
Epoch 0: 100%|█████████▉| 4143/4160 [05:12<00:01, 13.24it/s, v_num=0, train_loss_step=0.0985]
Epoch 0: 100%|█████████▉| 4148/4160 [05:12<00:00, 13.25it/s, v_num=0, train_loss_step=0.153] 
Epoch 0: 100%|█████████▉| 4149/4160 [05:13<00:00, 13.26it/s, v_num=0, train_loss_step=0.153]
Epoch 0: 100%|█████████▉| 4149/4160 [05:13<00:00, 13.25it/s, v_num=0, train_loss_step=0.115]
Epoch 0: 100%|█████████▉| 4154/4160 [05:13<00:00, 13.27it/s, v_num=0, train_loss_step=0.090] 
Epoch 0: 100%|██████████| 4160/4160 [05:13<00:00, 13.28it/s, v_num=0, train_loss_step=0.121] 
[36m(RayTrainWorker pid=3352013)[0m 
Validation: |       

[36m(RayTrainWorker pid=3352013)[0m /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/runtime_resources/working_dir_files/_ray_pkg_1da4f615a458d958/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Epoch 0: 100%|██████████| 4160/4160 [06:38<00:00, 10.43it/s, v_num=0, train_loss_step=0.121, train_loss_epoch=0.180]


[36m(RayTrainWorker pid=3352013)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-22-48/TorchTrainer_89c02_00000_0_2025-09-29_16-22-48/checkpoint_000000)


Epoch 0: 100%|██████████| 4160/4160 [06:39<00:00, 10.42it/s, v_num=0, train_loss_step=0.121, train_loss_epoch=0.180]


[36m(RayTrainWorker pid=3352013)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12GiB) + 96 CPUs + 125GiB RAM**, `ClassifierTrainingPlan_trainer.train()` finished in about **8 minutes** (8 min 17 s wall time).

In [19]:
%%time
ClassifierTrainingPlan_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)
ray.shutdown()

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each
Data splitting time: 23.91 seconds
Spawning Ray worker and initiating distributed training


2025-09-29 16:22:48,239	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-09-29 16:22:49 (running for 00:00:00.86)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/artifacts/2025-09-29_16-22-48/TorchTrainer_2025-09-29_16-22-48/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-29 16:23:19 (running for 00:00:31.34)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/artifacts/2025-09-29_16-22-48/TorchTrainer_2025-09-29_16-22-48/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-29 16:30:36 (running for 00:07:48.74)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/artifacts/2025-09-29_1

2025-09-29 16:30:38,727	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-22-48' in 1.7167s.
2025-09-29 16:30:38,792	INFO tune.py:1041 -- Total run time: 470.55 seconds (468.75 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 16:30:38 (running for 00:07:50.48)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-21-57_025895_3287714/artifacts/2025-09-29_16-22-48/TorchTrainer_2025-09-29_16-22-48/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 46.1 s, sys: 16.3 s, total: 1min 2s
Wall time: 8min 17s
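Both `.train()` calls print the time spent splitting the data (`Data splitting time`) and Ray Tune's `Total run time`. Adding the two reproduces the `%%time` wall time to within a few seconds for each run, which indicates that almost all of the roughly 8 to 9 minutes is spent inside the Ray training job rather than in data splitting. The snippet below is only that arithmetic on the logged numbers.

```python
def fmt(seconds):
    """Format a duration in seconds as 'Xmin YYs'."""
    minutes, secs = divmod(round(seconds), 60)
    return f"{minutes}min {secs:02d}s"


# (data splitting time, Ray "Total run time", %%time wall time), all in seconds,
# copied from the outputs of the two .train() cells
runs = {
    "CiteAutoencoder": (27.48, 495.07, 8 * 60 + 45),
    "ClassifierTrainingPlan": (23.91, 470.55, 8 * 60 + 17),
}

for name, (split_s, ray_s, wall_s) in runs.items():
    accounted = split_s + ray_s
    print(
        f"{name}: split {split_s:.0f}s + Ray {ray_s:.0f}s = {fmt(accounted)} "
        f"of {fmt(wall_s)} wall ({accounted / wall_s:.1%}); "
        f"Ray job alone {ray_s / wall_s:.1%}"
    )
```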
