# Perturbation Models for Single-Cell Data with PROTOplast

This notebook showcases **perturbation models** for the **Tahoe-100M** dataset, focusing on predicting gene expression changes under drug perturbations. We demonstrate two approaches: a statistical baseline and a neural embedding model.

**Download the Tahoe-100M `h5ad` files**
- The Tahoe-100M dataset can be downloaded in `h5ad` format from the **Arc Institute Google Cloud Storage**. For step-by-step instructions, see the [official tutorial](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md).

**Set up**
- Set up the training environment for single-cell RNA sequencing (scRNA-seq) data using PROTOplast together with PyTorch Lightning and Ray

In [1]:
import anndata
import numpy as np
import torch
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset, cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner

# models from state
from state.tx.models.embed_sum import EmbedSumPerturbationModel
from state.tx.models.perturb_mean import PerturbMeanPerturbationModel

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


## 1. Load the Tahoe-100M Dataset (`h5ad`)
- `file_paths`: only Plate 12 of Tahoe-100M (the largest file, 35 GB) is used here as a demo. To add more plates, append their `.h5ad` file paths to the list.
- `batch_size`: number of samples per training batch
- `test_size`: fraction of data reserved for testing
- `val_size`: fraction of data reserved for validation (use `0.0` if no validation set is needed)
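
To get a feel for how these three settings interact, the sketch below estimates the number of batches per split. It assumes a plain fractional split of cells, which may differ from how PROTOplast actually partitions the file; `expected_batches` and the cell count are hypothetical and used only for illustration.

```python
import math


def expected_batches(n_cells: int, batch_size: int, test_size: float, val_size: float) -> dict:
    """Rough number of batches per split, assuming a plain fractional split of cells."""
    if not 0.0 <= test_size + val_size < 1.0:
        raise ValueError("test_size + val_size must be in [0, 1)")
    n_test = int(n_cells * test_size)
    n_val = int(n_cells * val_size)
    n_train = n_cells - n_test - n_val
    return {
        "train": math.ceil(n_train / batch_size),
        "val": math.ceil(n_val / batch_size),
        "test": math.ceil(n_test / batch_size),
    }


# Hypothetical cell count, chosen only to make the arithmetic easy to follow.
batches = expected_batches(n_cells=1_000_000, batch_size=2000, test_size=0.0, val_size=0.2)
print(batches)  # {'train': 400, 'val': 100, 'test': 0}
```

With `val_size = 0.2`, roughly one batch in five goes to validation, and a larger `batch_size` lowers the number of steps per epoch proportionally.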

In [2]:
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
batch_size = 2000
test_size = 0.0
val_size = 0.2

## 2. Perturbation Mean
**PerturbMeanPerturbationModel** (from STATE) is a *statistical baseline* that predicts perturbed expression by combining a control baseline (global or per-sample) with a perturbation-specific offset averaged across cell types.
- **Inputs**
    - Perturbation identifier
    - Cell type (cell line in Tahoe-100M)
    - Perturbed counts or embeddings
    - (Optional) control embedding
- **Output**
    - Predicted gene expression profile (or latent embedding, depending on configuration)

**Note:** This model is not trained: it has no learnable weights and requires no validation data. Its predictions come purely from statistics of the training dataset.

**Source code:** [perturb_mean.py](https://github.com/ArcInstitute/state/blob/b6d26731e41d78c8c789d6973fe3d7db7853e9ad/src/state/tx/models/perturb_mean.py)

### Metadata Callback
The `perturbmean_metadata_cb` function prepares metadata for the **Perturbation Mean** model.  
- It casts the `drug` and `cell_line` columns to categorical and sets the input/output dimensions, hidden size, perturbation dimension, and learning rate.
- It also stores gene names, perturbation names, and cell types, and designates `DMSO_TF` as the control perturbation, `X` as the embedding key, and `gene` as the output space.

In [3]:
def perturbmean_metadata_cb(ad: anndata.AnnData, metadata: dict):
    ad.obs["drug"] = ad.obs["drug"].astype("category")
    ad.obs["cell_line"] = ad.obs["cell_line"].astype("category")

    metadata["input_dim"] = ad.var.shape[0]
    metadata["output_dim"] = ad.var.shape[0]
    metadata["hidden_dim"] = 0  # hidden_dim: Not used here, but required by base-class signature.
    metadata["pert_dim"] = ad.obs["drug"].astype(str).nunique()
    metadata["lr"] = 1e-3

    metadata["gene_names"] = ad.var_names.tolist()
    metadata["pert_names"] = ad.obs["drug"].cat.categories.tolist()
    metadata["cell_types"] = ad.obs["cell_line"].cat.categories.tolist()
    metadata["control_pert"] = "DMSO_TF"
    metadata["embed_key"] = "X"
    metadata["output_space"] = "gene"

### Perturbation Dataset for Training (PerturbAnnDataset)
`PerturbAnnDataset` prepares batches for the **Perturbation Mean** model. It loads expression data, collects `drug` and `cell_line` metadata, and returns a dictionary containing perturbation names, cell types, and the corresponding expression features (used both as counts and embeddings) for training.

In [4]:
class PerturbAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)

        # Metadata from self.ad
        pert_names = self.ad.obs["drug"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_line"].iloc[start:end].astype(str).to_list()

        return {
            "pert_name": pert_names,
            "cell_type": cell_lines,
            "pert_cell_counts": X,
            "pert_cell_emb": X,
        }

### Extending STATE Models
The **STATE** framework provides baseline model classes such as `PerturbMeanPerturbationModel`, which can be imported and used directly.
To customize behavior, you can **extend an existing class** and override only the methods that need modification.  
In the example below, we subclass `PerturbMeanPerturbationModel` and redefine the `forward()` method. Rather than relying on the per-cell `ctrl_cell_emb`, the model predicts using only the **global basal vector** combined with the corresponding perturbation offset.

In [5]:
class PerturbMeanGlobalModel(PerturbMeanPerturbationModel):
    """
    Extended class of PerturbMeanPerturbationModel where prediction ignores
    per-cell control embedding and uses only global basal + offset.
    """

    def forward(self, batch: dict) -> torch.Tensor:
        B = len(batch["pert_name"])
        device = self.dummy_param.device
        pred_out = torch.zeros((B, self.output_dim), device=device)

        for i in range(B):
            p_name = str(batch["pert_name"][i])
            offset_vec = self.pert_mean_offsets.get(p_name)
            if offset_vec is None:
                offset_vec = torch.zeros(self.output_dim, device=device)

            # Use global basal instead of batch["ctrl_cell_emb"]
            pred_out[i] = self.global_basal.to(device) + offset_vec.to(device)

        return pred_out

### Training the model
- Collect statistics (`on_fit_start`)
    - Compute control means per cell type
    - Compute perturbation deltas
    - Average deltas across cell types → perturbation offsets
    - Compute global basal = mean of all control means
- Forward: for each sample, `prediction = global_basal + offset[perturbation]`
- Training: no parameters are learned; the loop only logs the MSE loss against the ground truth
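
The steps above can be sketched in plain NumPy. This is a minimal illustration on toy data, not the STATE implementation: the function names `fit_perturb_mean` and `predict`, the drug name `drugA`, and the cell lines `c1`/`c2` are all made up for the example.

```python
import numpy as np


def fit_perturb_mean(X, perts, cell_types, control="DMSO_TF"):
    """Collect the statistics listed above from a (cells x genes) matrix."""
    perts = np.asarray(perts)
    cell_types = np.asarray(cell_types)

    # 1. Control mean per cell type.
    ctrl_means = {}
    for ct in np.unique(cell_types):
        mask = (cell_types == ct) & (perts == control)
        if mask.any():
            ctrl_means[ct] = X[mask].mean(axis=0)

    # 2. Delta of each (perturbation, cell type) mean relative to that cell type's control mean.
    deltas = {}
    for p in np.unique(perts):
        if p == control:
            continue
        for ct, ctrl_mean in ctrl_means.items():
            mask = (cell_types == ct) & (perts == p)
            if mask.any():
                deltas.setdefault(p, []).append(X[mask].mean(axis=0) - ctrl_mean)

    # 3. Average the deltas across cell types: one offset per perturbation.
    offsets = {p: np.mean(d, axis=0) for p, d in deltas.items()}

    # 4. Global basal = mean of the per-cell-type control means.
    global_basal = np.mean(list(ctrl_means.values()), axis=0)
    return global_basal, offsets


def predict(pert_names, global_basal, offsets):
    """prediction = global_basal + offset[perturbation]; names without an offset get zeros."""
    zero = np.zeros_like(global_basal)
    return np.stack([global_basal + offsets.get(p, zero) for p in pert_names])


# Toy data: 2 genes, 2 cell lines, 1 drug.
X = np.array([[1.0, 1.0], [3.0, 3.0], [2.0, 1.0], [5.0, 3.0]])
perts = ["DMSO_TF", "DMSO_TF", "drugA", "drugA"]
cell_types = ["c1", "c2", "c1", "c2"]

global_basal, offsets = fit_perturb_mean(X, perts, cell_types)
pred = predict(["drugA", "DMSO_TF"], global_basal, offsets)
print(global_basal)      # [2. 2.]
print(offsets["drugA"])  # [1.5 0. ]
print(pred)              # [[3.5 2. ] [2.  2. ]]
```

Because every quantity is a mean over the training data, there is nothing to optimize; the logged loss only measures how far these averages are from the observed expression.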

In [6]:
%%time
PerturbMeanPerturbationModel_trainer = RayTrainRunner(
    PerturbMeanGlobalModel,
    PerturbAnnDataset,
    [
        "input_dim",
        "output_dim",
        "hidden_dim",
        "pert_dim",
        "lr",
        "control_pert",  # "DMSO_TF"
        "embed_key",
        "output_space",  # "gene"
    ],
    perturbmean_metadata_cb,
)

2025-09-29 16:31:51,373	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 16:31:52,487	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 16:31:52,853	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_1f721884ccb69d33.zip' (70.69MiB) to Ray cluster...
2025-09-29 16:31:53,343	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_1f721884ccb69d33.zip'.


CPU times: user 922 ms, sys: 799 ms, total: 1.72 s
Wall time: 10.8 s


(raylet) Using CPython 3.11.13
(raylet) Creating virtual environment at: .venv
(raylet) Installed 296 packages in 377ms


(TrainTrainable pid=3371578) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=3372413) Setting up process group for: env:// [rank=0, world_size=1]
(TorchTrainer pid=3371578) Started distributed worker processes: 
(TorchTrainer pid=3371578) - (node_id=b8b683fd763132fb5be6d98f0e4d56e917dd267ccfd076f40669efc5, ip=192.168.1.226, pid=3372413) world_rank=0, local_rank=0, node_rank=0


(RayTrainWorker pid=3372413) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=3372413) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=3372413) GPU available: True (cuda), used: True
(RayTrainWorker pid=3372413) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=3372413) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=3372413) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
(RayTrainWorker pid=3372413) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  2.93it/s]
                                                                           
Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 1/4160 [00:23<27:27:04,  0.04it/s, v_num=0, train_loss=0.245]
Epoch 0:   0%|          | 2/4160 [00:24<13:54:30,  0.08it/s, v_num=0, train_loss=0.406]
.
.
.
Epoch 0: 100%|█████████▉| 4158/4160 [24:51<00:00,  2.79it/s, v_num=0, train_loss=0.365]
Epoch 0: 100%|█████████▉| 4159/4160 [24:52<00:00,  2.79it/s, v_num=0, train_loss=0.996]
Epoch 0: 100%|██████████| 4160/4160 [24:52<00:00,  2.79it/s, v_num=0, train_loss=0.474]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/1024 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1024 

(RayTrainWorker pid=3372413) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-32-21/TorchTrainer_dfadf_00000_0_2025-09-29_16-32-22/checkpoint_000000)


Epoch 0: 100%|██████████| 4160/4160 [30:53<00:00,  2.24it/s, v_num=0, train_loss=0.474]


(RayTrainWorker pid=3372413) `Trainer.fit` stopped: `max_epochs=1` reached.


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `PerturbMeanPerturbationModel_trainer.train()` completed in approximately **39 minutes**.

In [7]:
%%time
PerturbMeanPerturbationModel_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each
Data splitting time: 24.26 seconds
Spawning Ray worker and initiating distributed training


2025-09-29 16:32:21,891	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-09-29 16:32:22 (running for 00:00:00.21)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-31-46_659688_3349722/artifacts/2025-09-29_16-32-21/TorchTrainer_2025-09-29_16-32-21/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-29 16:32:52 (running for 00:00:30.46)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-31-46_659688_3349722/artifacts/2025-09-29_16-32-21/TorchTrainer_2025-09-29_16-32-21/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-29 17:10:45 (running for 00:38:23.50)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-31-46_659688_3349722/artifacts/2025-09-29_1

2025-09-29 17:10:50,421	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-32-21' in 0.1060s.
2025-09-29 17:10:50,471	INFO tune.py:1041 -- Total run time: 2308.58 seconds (2307.96 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 17:10:50 (running for 00:38:28.07)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_16-31-46_659688_3349722/artifacts/2025-09-29_16-32-21/TorchTrainer_2025-09-29_16-32-21/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 55.3 s, sys: 13.4 s, total: 1min 8s
Wall time: 38min 53s


Result(
  metrics={'train_loss': 0.47350403666496277, 'val_loss': 0.5732423067092896, 'epoch': 0, 'step': 4160},
  path='/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-32-21/TorchTrainer_dfadf_00000_0_2025-09-29_16-32-22',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_16-32-21/TorchTrainer_dfadf_00000_0_2025-09-29_16-32-22/checkpoint_000000)
)

In [8]:
import ray
ray.shutdown()

## 3. EmbedSum
**EmbedSumPerturbationModel** (from STATE) is a *neural embedding model* that predicts gene expression under perturbations by combining a **control (basal) cell state** with a **learned perturbation embedding**.
- **Inputs**
    - Control (basal) expression counts or embedding
    - Perturbation one-hot vector
- **Output**
    - Predicted gene expression profile

**Source code:** [embed_sum.py](https://github.com/ArcInstitute/state/blob/b6d26731e41d78c8c789d6973fe3d7db7853e9ad/src/state/tx/models/embed_sum.py#L7)
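
The idea of summing a basal representation and a perturbation embedding can be illustrated with a forward pass in NumPy. This is only a sketch under simplifying assumptions: single linear maps stand in for whatever encoders and decoder the real model uses, the weights are random rather than learned, and the names `W_basal`, `W_pert`, `W_out`, and `embed_sum_forward` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_drugs, hidden_dim = 6, 3, 4

# Randomly initialised weights stand in for parameters that would be learned.
W_basal = rng.normal(size=(n_genes, hidden_dim))  # encodes the control state
W_pert = rng.normal(size=(n_drugs, hidden_dim))   # one embedding row per drug
W_out = rng.normal(size=(hidden_dim, n_genes))    # projects back to gene space


def embed_sum_forward(ctrl_expr, pert_onehot):
    """Encode both inputs into the same latent space, add them, then decode."""
    basal_latent = ctrl_expr @ W_basal
    pert_latent = pert_onehot @ W_pert  # a one-hot row selects one row of W_pert
    return (basal_latent + pert_latent) @ W_out


ctrl_expr = rng.random((2, n_genes))   # two control profiles
pert_onehot = np.eye(n_drugs)[[0, 2]]  # drugs 0 and 2, one-hot encoded
pred = embed_sum_forward(ctrl_expr, pert_onehot)
print(pred.shape)  # (2, 6)
```

Because the perturbation enters as a one-hot vector, multiplying by `W_pert` is equivalent to an embedding lookup, which is why the model's `pert_dim` has to equal the number of distinct drugs.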

### Metadata Callback
The `embedsum_metadata_cb` function prepares metadata for the `EmbedSumPerturbationModel`. It sets the input and output dimensions (equal to the **number of genes**), defines the perturbation dimension based on the unique drugs in the dataset, and specifies training parameters such as hidden layer size and the control perturbation (`DMSO_TF`).

In [9]:
def embedsum_metadata_cb(ad: anndata.AnnData, metadata: dict):
    cell_line_metadata_cb(ad, metadata)
    metadata["input_dim"] = ad.var.shape[0]
    metadata["output_dim"] = ad.var.shape[0]

    uniq_drugs = sorted(ad.obs["drug"].astype(str).unique().tolist())
    metadata["pert_dim"] = len(uniq_drugs)

    metadata["hidden_dim"] = 10  # here kept small for testing
    metadata["control_pert"] = "DMSO_TF"

### EmbedSumAnnDataset
The `EmbedSumAnnDataset` class extends `DistributedAnnDataset` and prepares batches for the `EmbedSumPerturbationModel`. It enriches each batch with drug embeddings, metadata, and control information needed for training.

In [10]:
class EmbedSumAnnDataset(DistributedAnnDataset):
    control_drug = "DMSO_TF"

    def transform(self, start: int, end: int):
        # Loads gene expression (X) and converts it into a tensor (target_gene_expr).
        X = super().transform(start, end)
        target_gene_expr = torch.as_tensor(X, dtype=torch.float32)
        device = target_gene_expr.device

        # Collects metadata: perturbation names (drug) and cell line labels
        pert_names = self.ad.obs["drug"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_line"].iloc[start:end].astype(str).to_list()

        # Create drug index mapping
        if not hasattr(self, "_drug_to_idx"):
            drug_names = sorted(self.ad.obs["drug"].astype(str).unique())
            self._drug_to_idx = {d: i for i, d in enumerate(drug_names)}
            self._num_drugs = len(drug_names)

        # encodes drugs as one-hot embeddings
        idxs = [self._drug_to_idx.get(p, 0) for p in pert_names]
        pert_emb = torch.nn.functional.one_hot(torch.tensor(idxs, device=device), num_classes=self._num_drugs).float()

        # Computes a global control mean expression vector from cells treated with DMSO_TF
        if not hasattr(self, "_ctrl_global"):
            mask = self.ad.obs["drug"] == self.control_drug
            if mask.sum() == 0:
                ctrl_vec = np.zeros(self.ad.shape[1], dtype=np.float32)
            else:
                ctrl_vec = np.asarray(self.ad[mask].X.mean(axis=0)).ravel().astype(np.float32)
            self._ctrl_global = torch.from_numpy(ctrl_vec)

        ctrl_cell_emb = self._ctrl_global.to(device).unsqueeze(0).expand(len(pert_names), -1)

        # Returns a dictionary containing embeddings, control features, target expression, and metadata for perturbation training
        return {
            "pert_emb": pert_emb,
            "ctrl_cell_emb": ctrl_cell_emb,
            "target_gene_expr": target_gene_expr,
            "pert_cell_emb": target_gene_expr,
            "pert_name": pert_names,
            "cell_type": cell_lines,
        }
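
Both `embedsum_metadata_cb` and `EmbedSumAnnDataset` derive the drug vocabulary from the sorted unique values of `ad.obs["drug"]`, so the one-hot width produced by the dataset equals `pert_dim`. The sketch below reproduces that mapping on a toy list; `build_drug_index`, `one_hot`, and the drug names `drugA`/`drugB` are placeholders. Unlike the `.get(p, 0)` fallback used in the dataset class, it uses a strict lookup, which is a design choice of this example rather than of the notebook.

```python
import numpy as np


def build_drug_index(drugs):
    """Map sorted unique drug names to column indices."""
    names = sorted(set(drugs))
    return {name: i for i, name in enumerate(names)}


def one_hot(drugs, drug_to_idx):
    # Strict lookup: an unknown drug raises KeyError instead of silently using column 0.
    idxs = [drug_to_idx[d] for d in drugs]
    return np.eye(len(drug_to_idx), dtype=np.float32)[idxs]


# Placeholder drug names, not taken from Tahoe-100M.
all_drugs = ["drugB", "DMSO_TF", "drugA", "drugB"]
drug_to_idx = build_drug_index(all_drugs)
pert_dim = len(drug_to_idx)

batch = one_hot(["drugA", "DMSO_TF"], drug_to_idx)
print(drug_to_idx)  # {'DMSO_TF': 0, 'drugA': 1, 'drugB': 2}
print(batch.shape)  # (2, 3)
```

A silent fallback to index 0 would label any unseen drug as whichever name sorts first, so a strict lookup can make vocabulary mismatches between files easier to spot.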

### Training the EmbedSumPerturbationModel
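- `RayTrainRunner` pairs `EmbedSumPerturbationModel` with the custom `EmbedSumAnnDataset`. The listed keys (dimensions, learning rate, control perturbation, embedding key, and output space) name the metadata entries passed to the model, and `embedsum_metadata_cb` is the callback that fills in the metadata.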

In [11]:
%%time
EmbedSumPerturbationModel_trainer = RayTrainRunner(
    EmbedSumPerturbationModel,
    EmbedSumAnnDataset,
    [
        "input_dim",
        "output_dim",
        "hidden_dim",
        "pert_dim",
        "lr",
        "control_pert",  # "DMSO_TF"
        "embed_key",
        "output_space",  # "gene"
    ],
    embedsum_metadata_cb,
)

2025-09-29 17:13:50,011	INFO worker.py:1951 -- Started a local Ray instance.
2025-09-29 17:13:50,728	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example'.
2025-09-29 17:13:51,105	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_3e11bd015efed140.zip' (70.55MiB) to Ray cluster...
2025-09-29 17:13:51,625	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_3e11bd015efed140.zip'.


CPU times: user 679 ms, sys: 748 ms, total: 1.43 s
Wall time: 11.9 s


(raylet) `VIRTUAL_ENV=/mnt/hdd1/dung/protoplast-ml-example/.venv` does not match the project environment path `.venv` and will be ignored; use `--active` to target the active environment instead
(raylet) Using CPython 3.11.13
(raylet) Creating virtual environment at: .venv
(raylet) Installed 296 packages in 347ms


(TrainTrainable pid=3410205) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=3410711) Setting up process group for: env:// [rank=0, world_size=1]
(TorchTrainer pid=3410205) Started distributed worker processes: 
(TorchTrainer pid=3410205) - (node_id=285a3d5c7ed27d7c54d3177df3fb09f94e0b897acd99ca14b9cb0bb5, ip=192.168.1.226, pid=3410711) world_rank=0, local_rank=0, node_rank=0


(RayTrainWorker pid=3410711) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


(RayTrainWorker pid=3410711) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=3410711) GPU available: True (cuda), used: True
(RayTrainWorker pid=3410711) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=3410711) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=3410711) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
(RayTrainWorker pid=3410711) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  6.13it/s]
                                                                           
Epoch 0:   0%|          | 0/4160 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 1/4160 [00:52<60:48:48,  0.02it/s, v_num=0]
Epoch 0:   0%|          | 2/4160 [00:55<31:51:02,  0.04it/s, v_num=0]
.
.
.
Epoch 0: 100%|█████████▉| 4158/4160 [19:34<00:00,  3.54it/s, v_num=0]
Epoch 0: 100%|█████████▉| 4159/4160 [19:34<00:00,  3.54it/s, v_num=0]
Epoch 0: 100%|██████████| 4160/4160 [19:34<00:00,  3.54it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/1024 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1024 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/1024 [00:00<00:17, 59.34it

(RayTrainWorker pid=3410711) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_17-14-19/TorchTrainer_bc5c2_00000_0_2025-09-29_17-14-19/checkpoint_000000)
(RayTrainWorker pid=3410711) `Trainer.fit` stopped: `max_epochs=1` reached.


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `EmbedSumPerturbationModel_trainer.train()` completed in approximately **27 minutes**.

In [12]:
%%time
EmbedSumPerturbationModel_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)

Setting thread_per_worker to half of the available CPUs capped at 4
Using 1 workers with {'CPU': 4} each


2025-09-29 17:14:19,631	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 24.40 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-29 17:14:19 (running for 00:00:00.16)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_17-13-43_388460_3349722/artifacts/2025-09-29_17-14-19/TorchTrainer_2025-09-29_17-14-19/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-29 17:14:55 (running for 00:00:35.42)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_17-13-43_388460_3349722/artifacts/2025-09-29_17-14-19/TorchTrainer_2025-09-29_17-14-19/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-29 17:40:49 (running for 00:26:30.11)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G

2025-09-29 17:40:50,117	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-29_17-14-19' in 0.0069s.
2025-09-29 17:40:50,123	INFO tune.py:1041 -- Total run time: 1590.49 seconds (1590.47 seconds for the tuning loop).


== Status ==
Current time: 2025-09-29 17:40:50 (running for 00:26:30.48)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-29_17-13-43_388460_3349722/artifacts/2025-09-29_17-14-19/TorchTrainer_2025-09-29_17-14-19/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 45.4 s, sys: 9.44 s, total: 54.8 s
Wall time: 26min 54s


Result(
  metrics={'train_loss': 0.51564621925354, 'val_loss': 0.6005661487579346, 'epoch': 0, 'step': 4160},
  path='/home/dtran/protoplast_results/TorchTrainer_2025-09-29_17-14-19/TorchTrainer_bc5c2_00000_0_2025-09-29_17-14-19',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-29_17-14-19/TorchTrainer_bc5c2_00000_0_2025-09-29_17-14-19/checkpoint_000000)
)
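
For a quick side-by-side view, the metrics reported by the two `Result` objects in this notebook can be tabulated as below. The numbers are copied from the outputs above; the dictionary layout and variable names are just one possible way to organize them. Both runs cover a single epoch, and the EmbedSum model uses a deliberately small `hidden_dim` of 10, so this should be read as a smoke test rather than a benchmark.

```python
# Metrics copied from the two Result(...) outputs in this notebook.
results = {
    "PerturbMeanGlobalModel": {"train_loss": 0.47350403666496277, "val_loss": 0.5732423067092896},
    "EmbedSumPerturbationModel": {"train_loss": 0.51564621925354, "val_loss": 0.6005661487579346},
}

for name, metrics in results.items():
    print(f"{name:<28} train_loss={metrics['train_loss']:.4f}  val_loss={metrics['val_loss']:.4f}")

# Model with the lowest validation loss after one epoch.
best = min(results, key=lambda name: results[name]["val_loss"])
print("lowest val_loss:", best)  # lowest val_loss: PerturbMeanGlobalModel
```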