# Perturbation Models for Single-Cell Data with PROTOplast

This notebook showcases **perturbation models** for the **Tahoe-100M** dataset, focusing on predicting gene expression changes under drug perturbations. We demonstrate two approaches: a statistical baseline and a neural embedding model.

**Download the Tahoe-100M `h5ad` files**
- The Tahoe-100M dataset can be downloaded in `h5ad` format from the **Arc Institute Google Cloud Storage**. For step-by-step instructions, see the [official tutorial](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md).

**Set up**
- Set up the training environment for single-cell RNA sequencing (scRNA-seq) data using PROTOplast together with PyTorch Lightning and Ray.

In [1]:
import anndata
import numpy as np
import torch
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset, cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner

# models from state
from state.tx.models.embed_sum import EmbedSumPerturbationModel
from state.tx.models.perturb_mean import PerturbMeanPerturbationModel

✓ Applied AnnDataFileManager patch


2025-09-24 11:33:18,450	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


## 1. Load the Tahoe-100M Dataset (`h5ad`)
- `file_paths`: list of `.h5ad` files to load. For this demo, only Plate 12 of Tahoe-100M (the largest file, about 35 GB) is used; to include more plates, append their paths to the list
- `thread_per_worker`: number of CPU threads allocated to each Ray worker
- `batch_size`: number of samples per training batch
- `test_size`: fraction of data reserved for testing
- `val_size`: fraction of data reserved for validation (use `0.0` if no validation set is needed)

In [2]:
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
thread_per_worker = 2
batch_size = 2000
test_size = 0.0
val_size = 0.2

## 2. Perturbation Mean
**PerturbMeanPerturbationModel** (from STATE) is a *statistical baseline* that predicts perturbed expression by combining a control baseline (global or per-sample) with a perturbation-specific offset averaged across cell types.
- **Inputs**
    - Perturbation identifier
    - Cell type (cell line in Tahoe-100M)
    - Perturbed counts or embeddings
    - (Optional) control embedding
- **Output**
    - Predicted gene expression profile (or latent embedding, depending on configuration)
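
Note: this model is not trained in the usual sense. It has no learnable weights (the single parameter in the Lightning model summary is presumably the placeholder `dummy_param` that `forward()` uses to look up the device), so its predictions come purely from statistics of the training data. Validation loss is still computed and logged.

**Source code:** [perturb_mean.py](https://github.com/ArcInstitute/state/blob/b6d26731e41d78c8c789d6973fe3d7db7853e9ad/src/state/tx/models/perturb_mean.py)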

### Metadata Callback
The `perturbmean_metadata_cb` function prepares metadata for the **Perturbation Mean** model.  
- It converts drug and cell line columns to categorical values, sets input/output dimensions, hidden size, perturbation dimension, and training hyperparameters.  
- It also stores gene names, perturbation names, and cell types, while designating `DMSO_TF` as the control, `X` as the embedding key, and `gene` as the output space.

In [3]:
%%time


def perturbmean_metadata_cb(ad: anndata.AnnData, metadata: dict):
    ad.obs["drug"] = ad.obs["drug"].astype("category")
    ad.obs["cell_line"] = ad.obs["cell_line"].astype("category")

    metadata["input_dim"] = ad.var.shape[0]
    metadata["output_dim"] = ad.var.shape[0]
    metadata["hidden_dim"] = 0  # hidden_dim: Not used here, but required by base-class signature.
    metadata["pert_dim"] = ad.obs["drug"].astype(str).nunique()
    metadata["lr"] = 1e-3

    metadata["gene_names"] = ad.var_names.tolist()
    metadata["pert_names"] = ad.obs["drug"].cat.categories.tolist()
    metadata["cell_types"] = ad.obs["cell_line"].cat.categories.tolist()
    metadata["control_pert"] = "DMSO_TF"
    metadata["embed_key"] = "X"
    metadata["output_space"] = "gene"

CPU times: user 23 μs, sys: 0 ns, total: 23 μs
Wall time: 43.4 μs
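One detail worth checking in `perturbmean_metadata_cb`: `pert_dim` counts the drug values actually present, while `pert_names` lists the column's categories. The two agree for a freshly categorized string column, but if `drug` is already categorical and carries unused categories, `astype("category")` leaves them in place. The pandas toy example below (hypothetical values) shows the mismatch and one fix, `cat.remove_unused_categories()`; whether this arises with the Tahoe-100M files is not checked in this notebook.

```python
import pandas as pd

# Toy categorical column whose category list includes a value ("drugZ") that no
# row uses, as can happen after subsetting a larger table.
drug = pd.Series(
    pd.Categorical(["drugA", "DMSO_TF", "drugA"], categories=["DMSO_TF", "drugA", "drugZ"])
)

# astype("category") on an already-categorical Series keeps its existing categories.
drug = drug.astype("category")
pert_names = drug.cat.categories.tolist()  # includes the unused "drugZ"
pert_dim = drug.astype(str).nunique()      # counts only the observed values

# Dropping unused categories makes the two counts agree again.
cleaned = drug.cat.remove_unused_categories()
print(pert_names, pert_dim, cleaned.cat.categories.tolist())
```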


### Perturbation Dataset for Training (PerturbAnnDataset)
`PerturbAnnDataset` prepares batches for the **Perturbation Mean** model. It loads expression data, collects `drug` and `cell_line` metadata, and returns a dictionary containing perturbation names, cell types, and the corresponding expression features (used both as counts and embeddings) for training.

In [4]:
%%time


class PerturbAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)

        # Metadata from self.ad for rows [start, end)
        pert_names = self.ad.obs["drug"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_line"].iloc[start:end].astype(str).to_list()

        return {
            "pert_name": pert_names,
            "cell_type": cell_lines,
            "pert_cell_counts": X,
            "pert_cell_emb": X,
        }

CPU times: user 149 μs, sys: 0 ns, total: 149 μs
Wall time: 170 μs


### Extending STATE Models
The **STATE** framework provides baseline model classes such as `PerturbMeanPerturbationModel`, which can be imported and used directly.
To customize behavior, you can **extend an existing class** and override only the methods that need modification.  
In the example below, we subclass `PerturbMeanPerturbationModel` and redefine the `forward()` method. Rather than relying on the per-cell `ctrl_cell_emb`, the model predicts using only the **global basal vector** combined with the corresponding perturbation offset.

In [5]:
%%time


class PerturbMeanGlobalModel(PerturbMeanPerturbationModel):
    """
    Extended class of PerturbMeanPerturbationModel where prediction ignores
    per-cell control embedding and uses only global basal + offset.
    """

    def forward(self, batch: dict) -> torch.Tensor:
        B = len(batch["pert_name"])
        device = self.dummy_param.device
        pred_out = torch.zeros((B, self.output_dim), device=device)

        for i in range(B):
            p_name = str(batch["pert_name"][i])
            offset_vec = self.pert_mean_offsets.get(p_name)
            if offset_vec is None:
                offset_vec = torch.zeros(self.output_dim, device=device)

            # Use global basal instead of batch["ctrl_cell_emb"]
            pred_out[i] = self.global_basal.to(device) + offset_vec.to(device)

        return pred_out

CPU times: user 135 μs, sys: 0 ns, total: 135 μs
Wall time: 156 μs
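The per-sample loop in `forward()` is easy to follow, but it performs `B` separate row assignments per batch. One alternative, sketched below in NumPy under the assumption that the offsets are available as a name-to-vector dict (as `self.pert_mean_offsets.get(...)` suggests), is to stack the offsets into a table once and gather rows by index. `predict_global` is a hypothetical helper, not part of STATE or PROTOplast.

```python
import numpy as np


def predict_global(pert_names, global_basal, pert_offsets):
    """Vectorized global_basal + offset[pert] for a batch of perturbation names.

    Perturbations missing from pert_offsets get a zero offset, mirroring the
    fallback in PerturbMeanGlobalModel.forward.
    """
    n_genes = global_basal.shape[0]
    names = sorted(pert_offsets)
    # Row i holds the offset of names[i]; the extra last row is the zero fallback.
    table = np.vstack([pert_offsets[n] for n in names] + [np.zeros(n_genes)])
    index = {n: i for i, n in enumerate(names)}
    missing = len(names)  # index of the all-zero fallback row
    idx = np.array([index.get(str(p), missing) for p in pert_names], dtype=np.intp)
    # One gather for the whole batch; broadcasting adds the basal to every row.
    return global_basal[None, :] + table[idx]


basal = np.array([1.0, 2.0, 3.0])
offsets = {"drugA": np.array([0.5, 0.0, -1.0]), "DMSO_TF": np.zeros(3)}
print(predict_global(["drugA", "unknown", "DMSO_TF"], basal, offsets))
```

In a Lightning module, the table and the name-to-index map would typically be built once after the statistics are collected; the same gather works with `torch` tensor indexing.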


### Training the model
- Collect statistics (`on_fit_start`)
    - Compute control means per cell type
    - Compute perturbation deltas
    - Average deltas across cell types → perturbation offsets
    - Compute global basal = mean of all control means
- Forward: for each sample, `prediction = global_basal + offset[perturbation]`
- Training: no parameters are learned; only logs MSE loss vs. ground truth
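The statistics listed above can be reproduced on a toy matrix. The NumPy sketch below assumes that a perturbation's delta in a cell type is the mean of its perturbed cells minus that cell type's control mean, and that the offset averages those deltas over the cell types where the perturbation was observed. STATE's `on_fit_start` may differ in details (for example, how it accumulates statistics over batches), and `perturb_mean_stats` is a hypothetical name.

```python
import numpy as np


def perturb_mean_stats(X, perts, cell_types, control="DMSO_TF"):
    """Toy version of the Perturbation Mean statistics (not STATE's code).

    X: (n_cells, n_genes) expression; perts, cell_types: per-cell labels.
    Returns (global_basal, offsets), where offsets maps perturbation -> vector.
    """
    X = np.asarray(X, dtype=np.float64)
    perts = np.asarray(perts)
    cell_types = np.asarray(cell_types)

    # 1. Control mean per cell type (skipping cell types without control cells).
    ctrl_means = {}
    for ct in np.unique(cell_types):
        mask = (cell_types == ct) & (perts == control)
        if mask.any():
            ctrl_means[ct] = X[mask].mean(axis=0)

    # 2. Delta of p in ct: mean of p's cells in ct minus ct's control mean.
    deltas = {}
    for p in np.unique(perts):
        for ct, ctrl in ctrl_means.items():
            mask = (cell_types == ct) & (perts == p)
            if mask.any():
                deltas.setdefault(str(p), []).append(X[mask].mean(axis=0) - ctrl)

    # 3. Offset of p: its deltas averaged across the cell types where it appears.
    offsets = {p: np.mean(ds, axis=0) for p, ds in deltas.items()}

    # 4. Global basal: mean of the per-cell-type control means.
    global_basal = np.mean(list(ctrl_means.values()), axis=0)
    return global_basal, offsets


X = [[1, 1], [3, 1], [5, 5], [5, 9]]
perts = ["DMSO_TF", "drugX", "DMSO_TF", "drugX"]
cell_types = ["A", "A", "B", "B"]
basal, offs = perturb_mean_stats(X, perts, cell_types)
print(basal, offs["drugX"], basal + offs["drugX"])
```

The prediction for a cell treated with `p` is then `global_basal + offsets[p]`, which is what `PerturbMeanGlobalModel.forward` computes for each sample.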

In [6]:
%%time
PerturbMeanPerturbationModel_trainer = RayTrainRunner(
    PerturbMeanGlobalModel,
    PerturbAnnDataset,
    [
        "input_dim",
        "output_dim",
        "hidden_dim",
        "pert_dim",
        "lr",
        "control_pert",  # "DMSO_TF"
        "embed_key",
        "output_space",  # "gene"
    ],
    perturbmean_metadata_cb,
)

2025-09-24 11:33:26,414	INFO worker.py:1951 -- Started a local Ray instance.


CPU times: user 232 ms, sys: 331 ms, total: 564 ms
Wall time: 2.81 s
[36m(TrainTrainable pid=1209221)[0m ✓ Applied AnnDataFileManager patch
[36m(TrainTrainable pid=1209221)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1209389)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=1209221)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=1209221)[0m - (node_id=3f1e363fcefb582b194b252798310b8d931ee304e5f7ccca2792aa00, ip=192.168.1.226, pid=1209389) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=1209389)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=1209389)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1209389)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1209389)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1209389)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1209389)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1209389)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision




[36m(RayTrainWorker pid=1209389)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=1209389)[0m 
[36m(RayTrainWorker pid=1209389)[0m   | Name         | Type    | Params | Mode 
[36m(RayTrainWorker pid=1209389)[0m -------------------------------------------------
[36m(RayTrainWorker pid=1209389)[0m 0 | loss_fn      | MSELoss | 0      | train
[36m(RayTrainWorker pid=1209389)[0m   | other params | n/a     | 1      | n/a  
[36m(RayTrainWorker pid=1209389)[0m -------------------------------------------------
[36m(RayTrainWorker pid=1209389)[0m 1         Trainable params
[36m(RayTrainWorker pid=1209389)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=1209389)[0m 1         Total params
[36m(RayTrainWorker pid=1209389)[0m 0.000     Total estimated model params size (MB)
[36m(RayTrainWorker pid=1209389)[0m 1         Modules in train mode
[36m(RayTrainWorker pid=1209389)[0m 0         Modules in eval mode



[36m(RayTrainWorker pid=1209389)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-33-51/TorchTrainer_58516_00000_0_2025-09-24_11-33-51/checkpoint_000000)


Epoch 0: 100%|██████████| 4192/4192 [29:09<00:00,  2.40it/s, v_num=0, train_loss=0.953]


[36m(RayTrainWorker pid=1209389)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4192/4192 [29:09<00:00,  2.40it/s, v_num=0, train_loss=0.953]




On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `PerturbMeanPerturbationModel_trainer.train()` completed in approximately **40 minutes**.

In [7]:
%%time
PerturbMeanPerturbationModel_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
    thread_per_worker=thread_per_worker,  # 2
)

Using 1 workers with {'CPU': 2} each


2025-09-24 11:33:51,717	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 24.04 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-24 11:33:51 (running for 00:00:00.18)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-33-24_769318_1202711/artifacts/2025-09-24_11-33-51/TorchTrainer_2025-09-24_11-33-51/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 11:33:57 (running for 00:00:05.21)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-33-24_769318_1202711/artifacts/2025-09-24_11-33-51/TorchTrainer_2025-09-24_11-33-51/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 11:34:02 (running for 00:00:10.29)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G

2025-09-24 12:13:07,400	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-33-51' in 0.1191s.
2025-09-24 12:13:07,405	INFO tune.py:1041 -- Total run time: 2355.69 seconds (2355.48 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 12:13:07 (running for 00:39:15.60)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_11-33-24_769318_1202711/artifacts/2025-09-24_11-33-51/TorchTrainer_2025-09-24_11-33-51/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 56.8 s, sys: 14.4 s, total: 1min 11s
Wall time: 39min 39s


Result(
  metrics={'train_loss': 0.9527178406715393, 'val_loss': 0.6070448756217957, 'epoch': 0, 'step': 4192},
  path='/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-33-51/TorchTrainer_58516_00000_0_2025-09-24_11-33-51',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_11-33-51/TorchTrainer_58516_00000_0_2025-09-24_11-33-51/checkpoint_000000)
)

In [9]:
import ray

ray.shutdown()

## 3. EmbedSum
The **EmbedSumPerturbationModel** (part of the STATE framework) is a neural embedding model that predicts gene expression under perturbations.  
It works by combining a **control (basal) cell state** with a **learned perturbation embedding**.  
- **Inputs**
    - Control (basal) expression counts or embedding
    - Perturbation one-hot vector
- **Output**
    - Predicted gene expression profile

**Source code:** [embed_sum.py](https://github.com/ArcInstitute/state/blob/b6d26731e41d78c8c789d6973fe3d7db7853e9ad/src/state/tx/models/embed_sum.py#L7)
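The embed-and-sum idea can be illustrated with a minimal NumPy sketch, assuming a single linear map per input. STATE's actual `EmbedSumPerturbationModel` may use multi-layer encoders and a different decoder (see the source link), so the weight names and shapes here (`W_basal`, `W_pert`, `W_out`) are hypothetical and the weights are random rather than learned. The sketch also shows why multiplying a one-hot perturbation vector by a weight matrix is just an embedding-table lookup.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_perts, hidden = 5, 3, 4

# Hypothetical single-layer weights (random here; learned in a real model).
W_basal = rng.normal(size=(n_genes, hidden))  # basal expression -> latent
W_pert = rng.normal(size=(n_perts, hidden))   # one-hot perturbation -> latent
W_out = rng.normal(size=(hidden, n_genes))    # latent -> gene expression


def embed_sum_forward(basal, pert_onehot):
    """Encode basal state and perturbation separately, add them, then decode."""
    z = basal @ W_basal + pert_onehot @ W_pert  # the "sum" in EmbedSum
    return z @ W_out


basal = rng.normal(size=(2, n_genes))  # two control cells
pert = np.eye(n_perts)[[1, 2]]         # one-hot rows for perturbations 1 and 2
pred = embed_sum_forward(basal, pert)  # shape (2, n_genes)

# A one-hot row times W_pert selects a row of W_pert: an embedding lookup.
print(np.allclose(pert @ W_pert, W_pert[[1, 2]]))
```

With purely linear maps the whole model collapses to one linear function; a practical model would place nonlinearities (for example, MLPs) in the encoders and decoder.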

### Metadata Callback
The `embedsum_metadata_cb` function prepares metadata for the `EmbedSumPerturbationModel`. It first calls PROTOplast's `cell_line_metadata_cb`, then sets the input and output dimensions (both equal to the **number of genes**), derives the perturbation dimension from the number of unique drugs in the dataset, and sets the hidden-layer size and the control perturbation (`DMSO_TF`).

In [10]:
%%time


def embedsum_metadata_cb(ad: anndata.AnnData, metadata: dict):
    cell_line_metadata_cb(ad, metadata)
    metadata["input_dim"] = ad.var.shape[0]
    metadata["output_dim"] = ad.var.shape[0]

    uniq_drugs = sorted(ad.obs["drug"].astype(str).unique().tolist())
    metadata["pert_dim"] = len(uniq_drugs)

    metadata["hidden_dim"] = 10  # here kept small for testing
    metadata["control_pert"] = "DMSO_TF"

CPU times: user 18 μs, sys: 4 μs, total: 22 μs
Wall time: 34.3 μs


### EmbedSumAnnDataset
The `EmbedSumAnnDataset` class extends `DistributedAnnDataset` and prepares batches for the `EmbedSumPerturbationModel`. Each batch contains one-hot drug encodings, the global control (`DMSO_TF`) mean expression repeated for every cell, the target expression, and the drug and cell-line labels needed for training.

In [12]:
%%time


class EmbedSumAnnDataset(DistributedAnnDataset):
    control_drug = "DMSO_TF"

    def transform(self, start: int, end: int):
        # Loads gene expression (X) and converts it into a tensor (target_gene_expr).
        X = super().transform(start, end)
        target_gene_expr = torch.as_tensor(X, dtype=torch.float32)
        device = target_gene_expr.device

        # Collects metadata: perturbation names (drug) and cell line labels
        pert_names = self.ad.obs["drug"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_line"].iloc[start:end].astype(str).to_list()

        # Create drug index mapping
        if not hasattr(self, "_drug_to_idx"):
            drug_names = sorted(self.ad.obs["drug"].astype(str).unique())
            self._drug_to_idx = {d: i for i, d in enumerate(drug_names)}
            self._num_drugs = len(drug_names)

        # encodes drugs as one-hot embeddings
        idxs = [self._drug_to_idx.get(p, 0) for p in pert_names]
        pert_emb = torch.nn.functional.one_hot(torch.tensor(idxs, device=device), num_classes=self._num_drugs).float()

        # Computes a global control mean expression vector from cells treated with DMSO_TF
        if not hasattr(self, "_ctrl_global"):
            mask = self.ad.obs["drug"] == self.control_drug
            if mask.sum() == 0:
                ctrl_vec = np.zeros(self.ad.shape[1], dtype=np.float32)
            else:
                ctrl_vec = np.asarray(self.ad[mask].X.mean(axis=0)).ravel().astype(np.float32)
            self._ctrl_global = torch.from_numpy(ctrl_vec)

        ctrl_cell_emb = self._ctrl_global.to(device).unsqueeze(0).expand(len(pert_names), -1)

        # Returns a dictionary containing embeddings, control features, target expression, and metadata for perturbation training
        return {
            "pert_emb": pert_emb,
            "ctrl_cell_emb": ctrl_cell_emb,
            "target_gene_expr": target_gene_expr,
            "pert_cell_emb": target_gene_expr,
            "pert_name": pert_names,
            "cell_type": cell_lines,
        }

CPU times: user 144 μs, sys: 0 ns, total: 144 μs
Wall time: 158 μs
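Stripped of PROTOplast and AnnData, the two batch-level steps in `EmbedSumAnnDataset.transform` (one-hot encoding the drugs and broadcasting the global control mean) look like this on toy NumPy arrays. The arrays are hypothetical stand-ins for `ad.obs["drug"]` and `ad.X`, and `np.broadcast_to` plays the role of `torch.Tensor.expand` (a read-only view rather than a copy).

```python
import numpy as np

# Hypothetical stand-ins for ad.obs["drug"] (whole dataset) and ad.X.
all_drugs = np.array(["DMSO_TF", "drugB", "drugA", "DMSO_TF", "drugB"])
X = np.array([[1.0, 2.0], [0.0, 4.0], [3.0, 3.0], [3.0, 0.0], [2.0, 2.0]])

# Sorted unique names -> indices; sorting keeps the mapping deterministic and
# its size matches pert_dim as computed in embedsum_metadata_cb.
drug_names = sorted(set(all_drugs.tolist()))
drug_to_idx = {d: i for i, d in enumerate(drug_names)}

# One-hot encode a batch (rows start:end) by selecting rows of an identity matrix.
start, end = 1, 4
batch_drugs = all_drugs[start:end].tolist()
idxs = [drug_to_idx[d] for d in batch_drugs]
pert_emb = np.eye(len(drug_names), dtype=np.float32)[idxs]

# Global control mean over ALL control cells, repeated once per cell in the batch.
ctrl_vec = X[all_drugs == "DMSO_TF"].mean(axis=0).astype(np.float32)
ctrl_cell_emb = np.broadcast_to(ctrl_vec, (len(batch_drugs), ctrl_vec.shape[0]))
print(pert_emb, ctrl_cell_emb, sep="\n")
```

The sketch uses `drug_to_idx[d]`, which raises `KeyError` on an unseen drug, whereas the notebook's `.get(p, 0)` silently maps it to index 0 (the first name in sorted order). Because the mapping is built from the full `obs` table, unseen drugs should not occur in practice.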


### Training the EmbedSumPerturbationModel
- `RayTrainRunner` pairs the model with the custom `EmbedSumAnnDataset`, takes the names of the metadata keys the model requires (dimensions, learning rate, control perturbation, embedding key, and output space), and uses `embedsum_metadata_cb` to populate the metadata.

In [13]:
%%time
EmbedSumPerturbationModel_trainer = RayTrainRunner(
    EmbedSumPerturbationModel,
    EmbedSumAnnDataset,
    [
        "input_dim",
        "output_dim",
        "hidden_dim",
        "pert_dim",
        "lr",
        "control_pert",  # "DMSO_TF"
        "embed_key",
        "output_space",  # "gene"
    ],
    embedsum_metadata_cb,
)

2025-09-24 12:14:53,176	INFO worker.py:1951 -- Started a local Ray instance.


CPU times: user 112 ms, sys: 301 ms, total: 412 ms
Wall time: 3.66 s
[36m(TrainTrainable pid=1251740)[0m ✓ Applied AnnDataFileManager patch
[36m(TrainTrainable pid=1251740)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1252096)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=1251740)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=1251740)[0m - (node_id=76088a489aa8c92b9eeb053e2cee79abba9656a0d134181ab362cea8, ip=192.168.1.226, pid=1252096) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=1252096)[0m ✓ Applied AnnDataFileManager patch
[36m(RayTrainWorker pid=1252096)[0m ✓ Applied AnnDataFileManager patch


[36m(RayTrainWorker pid=1252096)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=1252096)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=1252096)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=1252096)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=1252096)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[36m(RayTrainWorker pid=1252096)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=1252096)[0m 
[36m(


[36m(RayTrainWorker pid=1252096)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_12-15-18/TorchTrainer_22951_00000_0_2025-09-24_12-15-18/checkpoint_000000)
[36m(RayTrainWorker pid=1252096)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4192/4192 [33:32<00:00,  2.08it/s, v_num=0]




On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `EmbedSumPerturbationModel_trainer.train()` completed in approximately **35 minutes**.

In [14]:
%%time
EmbedSumPerturbationModel_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
    thread_per_worker=thread_per_worker,  # 2
)

Using 1 workers with {'CPU': 2} each


2025-09-24 12:15:18,559	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 24.34 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-24 12:15:18 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_12-14-50_582135_1202711/artifacts/2025-09-24_12-15-18/TorchTrainer_2025-09-24_12-15-18/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 12:15:23 (running for 00:00:05.17)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_12-14-50_582135_1202711/artifacts/2025-09-24_12-15-18/TorchTrainer_2025-09-24_12-15-18/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 12:15:28 (running for 00:00:10.24)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G

2025-09-24 12:49:52,903	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/dtran/protoplast_results/TorchTrainer_2025-09-24_12-15-18' in 0.0066s.
2025-09-24 12:49:52,908	INFO tune.py:1041 -- Total run time: 2074.35 seconds (2074.32 seconds for the tuning loop).


== Status ==
Current time: 2025-09-24 12:49:52 (running for 00:34:34.33)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-24_12-14-50_582135_1202711/artifacts/2025-09-24_12-15-18/TorchTrainer_2025-09-24_12-15-18/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


CPU times: user 50.9 s, sys: 10.9 s, total: 1min 1s
Wall time: 34min 58s


Result(
  metrics={'train_loss': 1.076431155204773, 'val_loss': 0.628422200679779, 'epoch': 0, 'step': 4192},
  path='/home/dtran/protoplast_results/TorchTrainer_2025-09-24_12-15-18/TorchTrainer_22951_00000_0_2025-09-24_12-15-18',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/TorchTrainer_2025-09-24_12-15-18/TorchTrainer_22951_00000_0_2025-09-24_12-15-18/checkpoint_000000)
)

