# Perturbation Models for Single-Cell Data with PROTOplast

This notebook showcases **perturbation models** for the **Tahoe-100M** dataset, focusing on predicting gene expression changes under drug perturbations. We demonstrate two approaches: a statistical baseline and a neural embedding model.

**Download the Tahoe-100M `h5ad` files**
- The Tahoe-100M dataset can be downloaded in `h5ad` format from the **Arc Institute Google Cloud Storage**. For step-by-step instructions, see the [official tutorial](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md).

**Set up**
- Set up the training environment for single-cell RNA sequencing (scRNA-seq) data using PROTOplast together with PyTorch Lightning and Ray.

In [1]:
import protoplast

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


root - INFO - Logging initialized. Current level is: INFO


In [2]:
from importlib.metadata import version

print(version("protoplast"))

0.1.2


In [3]:
import anndata
import numpy as np
import torch
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset, cell_line_metadata_cb
from protoplast.scrna.anndata.trainer import RayTrainRunner

# models from state
from state.tx.models.embed_sum import EmbedSumPerturbationModel
from state.tx.models.perturb_mean import PerturbMeanPerturbationModel

## 1. Load the Tahoe-100M Dataset (`h5ad`)
- `file_paths`: only Plate 12 of Tahoe-100M (the largest file, 35 GB) is used here as a demo. To add more plates, append their `.h5ad` file paths to the list.
- `batch_size`: number of samples per training batch
- `test_size`: fraction of data reserved for testing
- `val_size`: fraction of data reserved for validation (use `0.0` if no validation set is needed)

In [4]:
file_paths = ["/mnt/hdd2/tan/tahoe100m/plate12_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad"]
batch_size = 2000
test_size = 0.0
val_size = 0.2
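To train on several plates, the list can also be built programmatically. The following is a minimal sketch. It assumes all plates were downloaded into one directory and that their file names start with `plate`. Both the directory and the `plate*.h5ad` pattern are assumptions, so adjust them to match your download. `collect_plate_paths` is a hypothetical helper and is not part of PROTOplast.

```python
from pathlib import Path


def collect_plate_paths(data_dir: str, pattern: str = "plate*.h5ad") -> list[str]:
    """Hypothetical helper: return sorted .h5ad paths in data_dir matching an assumed plate naming pattern."""
    paths = sorted(str(p) for p in Path(data_dir).glob(pattern))
    if not paths:
        raise FileNotFoundError(f"No files matching {pattern!r} in {data_dir}")
    return paths


# Hypothetical usage, with every plate downloaded into the same directory:
# file_paths = collect_plate_paths("/path/to/tahoe100m")
```

Sorting makes the order deterministic across runs. The sort is lexical, so a file named `plate12_...` sorts before `plate1_...`.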

## 2. Perturbation Mean
**PerturbMeanPerturbationModel** (from STATE) is a *statistical baseline* that predicts perturbed expression by combining a control baseline (global or per-sample) with a perturbation-specific offset averaged across cell types.
- **Inputs**
    - Perturbation identifier
    - Cell type (cell line in Tahoe-100M)
    - Perturbed counts or embeddings
    - (Optional) control embedding
- **Output**
    - Predicted gene expression profile (or latent embedding, depending on configuration)
Note: this model is not trained in the usual sense (it learns no weights). Its predictions come purely from statistics of the training dataset, although a validation loss is still logged for comparison. 
**Source code:** [perturb_mean.py](https://github.com/ArcInstitute/state/blob/b6d26731e41d78c8c789d6973fe3d7db7853e9ad/src/state/tx/models/perturb_mean.py)
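The averaging described above can be written compactly. This formulation is a sketch based on the steps listed under *Training the model* below, and the exact handling in `perturb_mean.py` (for example, of cell types where a perturbation is absent) may differ. The symbols are defined as follows:

- $\mu^{\mathrm{ctrl}}_c$ is the mean control (`DMSO_TF`) expression in cell type $c$.
- $\mu_{p,c}$ is the mean expression of cells of type $c$ treated with perturbation $p$.
- $C$ is the set of cell types that have control cells.
- $C_p \subseteq C$ is the set of those cell types in which $p$ is observed.
- $x^{\mathrm{ctrl}}$ is a per-cell control embedding.

$$
\begin{aligned}
\delta_{p,c} &= \mu_{p,c} - \mu^{\mathrm{ctrl}}_c, \\
o_p &= \frac{1}{|C_p|} \sum_{c \in C_p} \delta_{p,c}, \\
b &= \frac{1}{|C|} \sum_{c \in C} \mu^{\mathrm{ctrl}}_c, \\
\hat{x}_{\text{global}} &= b + o_p, \qquad \hat{x}_{\text{per-sample}} = x^{\mathrm{ctrl}} + o_p .
\end{aligned}
$$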

### Metadata Callback
The `perturbmean_metadata_cb` function prepares metadata for the **Perturbation Mean** model.  
- It converts the drug and cell line columns to categorical, and sets the input/output dimensions, hidden size, perturbation dimension, and learning rate.
- It also stores gene names, perturbation names, and cell types, while designating `DMSO_TF` as the control, `X` as the embedding key, and `gene` as the output space.

In [5]:
def perturbmean_metadata_cb(ad: anndata.AnnData, metadata: dict):
    ad.obs["drug"] = ad.obs["drug"].astype("category")
    ad.obs["cell_line"] = ad.obs["cell_line"].astype("category")

    metadata["input_dim"] = ad.var.shape[0]
    metadata["output_dim"] = ad.var.shape[0]
    metadata["hidden_dim"] = 0  # hidden_dim: Not used here, but required by base-class signature.
    metadata["pert_dim"] = ad.obs["drug"].astype(str).nunique()
    metadata["lr"] = 1e-3

    metadata["gene_names"] = ad.var_names.tolist()
    metadata["pert_names"] = ad.obs["drug"].cat.categories.tolist()
    metadata["cell_types"] = ad.obs["cell_line"].cat.categories.tolist()
    metadata["control_pert"] = "DMSO_TF"
    metadata["embed_key"] = "X"
    metadata["output_space"] = "gene"
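A callback like this has one subtlety worth checking. `astype("category")` keeps any existing categories. If `drug` is already categorical and includes categories absent from this file (for example, after subsetting), `pert_names` can be longer than `pert_dim`. The pandas sketch below uses toy data with hypothetical drug names to show the mismatch. It also shows one possible fix using `cat.remove_unused_categories()`. Whether this applies to your files depends on how the `.h5ad` was written.

```python
import pandas as pd

# Toy 'drug' column that is already categorical with an unused category ("drugC")
drug = pd.Series(
    ["DMSO_TF", "drugA", "drugA"],
    dtype=pd.CategoricalDtype(["DMSO_TF", "drugA", "drugC"]),
)

as_cat = drug.astype("category")             # existing categories are preserved
pert_dim = drug.astype(str).nunique()        # counts only values that actually occur
pert_names = as_cat.cat.categories.tolist()  # still includes the unused "drugC"

# Dropping unused categories makes pert_names agree with pert_dim
pert_names_clean = as_cat.cat.remove_unused_categories().cat.categories.tolist()
```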

### Perturbation Dataset for Training (PerturbAnnDataset)
`PerturbAnnDataset` prepares batches for the **Perturbation Mean** model. It loads expression data, collects `drug` and `cell_line` metadata, and returns a dictionary containing perturbation names, cell types, and the corresponding expression features (used both as counts and embeddings) for training.

In [6]:
class PerturbAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)

        # Metadata from self.ad
        pert_names = self.ad.obs["drug"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_line"].iloc[start:end].astype(str).to_list()

        return {
            "pert_name": pert_names,
            "cell_type": cell_lines,
            "pert_cell_counts": X,
            "pert_cell_emb": X,
        }
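The batch layout that `transform()` returns can be checked without PROTOplast or a 35 GB file. The sketch below is a hypothetical in-memory stand-in, `make_perturb_batch`, that operates on toy NumPy arrays and lists. It is not a PROTOplast API. It mirrors the notebook's choice to use the same expression slice for both `pert_cell_counts` and `pert_cell_emb`.

```python
import numpy as np


def make_perturb_batch(X: np.ndarray, drugs: list, cell_lines: list, start: int, end: int) -> dict:
    """Hypothetical stand-in for PerturbAnnDataset.transform on in-memory data."""
    X_slice = X[start:end]  # rows start..end-1, as in the notebook's slicing
    return {
        "pert_name": [str(d) for d in drugs[start:end]],
        "cell_type": [str(c) for c in cell_lines[start:end]],
        "pert_cell_counts": X_slice,  # the same array serves as counts...
        "pert_cell_emb": X_slice,     # ...and as the embedding
    }


# Toy data: 4 cells x 3 genes, hypothetical drug and cell line names
X = np.arange(12, dtype=np.float32).reshape(4, 3)
drugs = ["DMSO_TF", "drugA", "drugA", "DMSO_TF"]
cell_lines = ["CL1", "CL1", "CL2", "CL2"]
batch = make_perturb_batch(X, drugs, cell_lines, 1, 3)
```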

### Extending STATE Models
The **STATE** framework provides baseline model classes such as `PerturbMeanPerturbationModel`, which can be imported and used directly.
To customize behavior, you can **extend an existing class** and override only the methods that need modification.  
In the example below, we subclass `PerturbMeanPerturbationModel` and redefine the `forward()` method. Rather than relying on the per-cell `ctrl_cell_emb`, the model predicts using only the **global basal vector** combined with the corresponding perturbation offset.

In [7]:
class PerturbMeanGlobalModel(PerturbMeanPerturbationModel):
    """
    Extended class of PerturbMeanPerturbationModel where prediction ignores
    per-cell control embedding and uses only global basal + offset.
    """

    def forward(self, batch: dict) -> torch.Tensor:
        B = len(batch["pert_name"])
        device = self.dummy_param.device
        pred_out = torch.zeros((B, self.output_dim), device=device)

        for i in range(B):
            p_name = str(batch["pert_name"][i])
            offset_vec = self.pert_mean_offsets.get(p_name)
            if offset_vec is None:
                offset_vec = torch.zeros(self.output_dim, device=device)

            # Use global basal instead of batch["ctrl_cell_emb"]
            pred_out[i] = self.global_basal.to(device) + offset_vec.to(device)

        return pred_out
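The per-sample loop in `forward()` is easy to read, but with `batch_size = 2000` it performs 2,000 separate tensor additions and assignments per batch. One possible alternative is to stack the offsets into a matrix once and then select rows by index. The dictionary lookups remain, but the arithmetic becomes a single gather plus a broadcasted add. This sketch is written in NumPy so that it runs without STATE, and the helper names are hypothetical. The same pattern should carry over to PyTorch with `torch.stack` and tensor indexing. As in the original `forward()`, unknown perturbations map to a zero offset.

```python
import numpy as np


def build_offset_table(pert_mean_offsets: dict[str, np.ndarray], output_dim: int):
    """Stack per-perturbation offsets into one matrix, plus a trailing all-zero row for unknown names."""
    names = sorted(pert_mean_offsets)
    index = {name: i for i, name in enumerate(names)}
    table = np.zeros((len(names) + 1, output_dim), dtype=np.float32)  # last row stays zero
    for name, i in index.items():
        table[i] = pert_mean_offsets[name]
    return index, table


def predict_global(pert_names, global_basal, index, table):
    """Vectorized global_basal + offset[perturbation] for a whole batch."""
    unknown_row = table.shape[0] - 1
    rows = np.array([index.get(str(p), unknown_row) for p in pert_names])
    return global_basal[None, :] + table[rows]


# Toy usage with hypothetical drug names
offsets = {
    "drugA": np.array([1.0, 0.0], dtype=np.float32),
    "drugB": np.array([0.0, 2.0], dtype=np.float32),
}
index, table = build_offset_table(offsets, output_dim=2)
basal = np.array([0.5, 0.5], dtype=np.float32)
pred = predict_global(["drugB", "unknown", "drugA"], basal, index, table)
# pred -> [[0.5, 2.5], [0.5, 0.5], [1.5, 0.5]]
```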

### Training the model
- Collect statistics (`on_fit_start`)
    - Compute control means per cell type
    - Compute perturbation deltas
    - Average deltas across cell types → perturbation offsets
    - Compute global basal = mean of all control means.
- Forward: for each sample, `prediction = global_basal + offset[perturbation]`
- Training: no parameters are learned; only logs MSE loss vs. ground truth
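The statistics-collection steps above can be sketched in plain NumPy on an in-memory matrix. This is an illustrative re-implementation under simplifying assumptions, not the code in `perturb_mean.py`. It assumes a dense `X` and at least one control cell, and it skips cell types without control cells. The function name is hypothetical.

```python
import numpy as np


def perturb_mean_stats(X, perts, cell_types, control="DMSO_TF"):
    """Hypothetical helper: return (global_basal, offsets) for a dense (cells x genes) matrix."""
    perts = np.asarray(perts)
    cell_types = np.asarray(cell_types)

    # Control mean per cell type (cell types without control cells are skipped)
    ctrl_means = {}
    for c in np.unique(cell_types):
        mask = (cell_types == c) & (perts == control)
        if mask.any():
            ctrl_means[c] = X[mask].mean(axis=0)

    # Per-cell-type deltas, averaged across cell types -> one offset per perturbation
    offsets = {}
    for p in np.unique(perts):
        deltas = []
        for c, ctrl in ctrl_means.items():
            mask = (cell_types == c) & (perts == p)
            if mask.any():
                deltas.append(X[mask].mean(axis=0) - ctrl)
        if deltas:
            offsets[str(p)] = np.mean(deltas, axis=0)

    # Global basal = mean of all per-cell-type control means
    global_basal = np.mean(list(ctrl_means.values()), axis=0)
    return global_basal, offsets


# Toy data: 5 cells x 2 genes, two cell lines, one hypothetical drug
X = np.array([[1.0, 1.0], [3.0, 3.0], [5.0, 1.0], [2.0, 2.0], [4.0, 6.0]])
perts = ["DMSO_TF", "DMSO_TF", "drugA", "DMSO_TF", "drugA"]
cell_types = ["CL1", "CL1", "CL1", "CL2", "CL2"]
global_basal, offsets = perturb_mean_stats(X, perts, cell_types)
prediction = global_basal + offsets["drugA"]  # -> array([4.5, 3.5])
```

The control perturbation automatically receives a zero offset because its delta in every cell type is its own mean minus itself.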

In [8]:
%%time
PerturbMeanPerturbationModel_trainer = RayTrainRunner(
    PerturbMeanGlobalModel,
    PerturbAnnDataset,
    [
        "input_dim",
        "output_dim",
        "hidden_dim",
        "pert_dim",
        "lr",
        "control_pert",  # "DMSO_TF"
        "embed_key",
        "output_space",  # "gene"
    ],
    perturbmean_metadata_cb,
)

2025-11-04 17:21:11,603	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-11-04 17:21:11,692	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-04 17:21:11,720	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_30a635a9689a3416.zip' (3.58MiB) to Ray cluster...
2025-11-04 17:21:11,741	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_30a635a9689a3416.zip'.


CPU times: user 339 ms, sys: 315 ms, total: 654 ms
Wall time: 10.5 s


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `PerturbMeanPerturbationModel_trainer.train()` completed in approximately **47 minutes**.

In [9]:
%%time
PerturbMeanPerturbationModel_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


(TrainController pid=1568983) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!

(TrainController pid=1568983) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1568983) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1

(RayTrainWorker pid=1569566) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!

(RayTrainWorker pid=1569566) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=1569566) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1568983) Started training worker group of size 1: 
(TrainController pid=1568983) - (ip=192.168.1.226, pid=1569566) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=1569566) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1569566) GPU available: True (cuda), used: True
(RayTrainWorker pid=1569566) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1569566) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1569566) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize

Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.58it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s]
                                                                           
Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4192 [00:27<32:05:11,  0.04it/s, v_num=0, train_loss=0.406]
...
...
Epoch 0: 100%|█████████▉| 4191/4192 [30:43<00:00,  2.27it/s, v_num=0, train_loss=0.582]
Epoch 0: 100%|██████████| 4192/4192 [30:44<00:00,  2.27it/s, v_num=0, train_loss=0.268]
(RayTrainWorker pid=1569566)
Validation: |          | 0/? [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1040 [00:00<?, ?it/s]
...
...
Validati

[36m(RayTrainWorker pid=1569566)[0m PerturbMean: Saved global_basal and pert_mean_offsets to checkpoint.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 4192/4192 [38:35<00:00,  1.81it/s, v_num=0, train_loss=0.268]


[36m(RayTrainWorker pid=1569566)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_17-21-43/checkpoint_2025-11-04_18-08-23.026793)
[36m(RayTrainWorker pid=1569566)[0m Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_17-21-43/checkpoint_2025-11-04_18-08-23.026793), metrics={'train_loss': 0.26821431517601013, 'val_loss': 0.5752615928649902, 'epoch': 0, 'step': 4192}, validation_spec=None)
[36m(RayTrainWorker pid=1569566)[0m PerturbMean: Saved global_basal and pert_mean_offsets to checkpoint.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 4192/4192 [38:36<00:00,  1.81it/s, v_num=0, train_loss=0.268]


[36m(RayTrainWorker pid=1569566)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 40.7 s, sys: 14.2 s, total: 54.9 s
Wall time: 47min 11s


Result(metrics={'train_loss': 0.26821431517601013, 'val_loss': 0.5752615928649902, 'epoch': 0, 'step': 4192}, checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_17-21-43/checkpoint_2025-11-04_18-08-23.026793), error=None, path='/home/dtran/protoplast_results/ray_train_run-2025-11-04_17-21-43', metrics_dataframe=   train_loss  val_loss  epoch  step
0    0.268214  0.575262      0  4192, best_checkpoints=[(Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_17-21-43/checkpoint_2025-11-04_18-08-23.026793), {'train_loss': 0.26821431517601013, 'val_loss': 0.5752615928649902, 'epoch': 0, 'step': 4192})], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7dee4c6561f0>)

In [10]:
import ray
ray.shutdown()

## 3. EmbedSum
The **EmbedSumPerturbationModel** (part of the STATE framework) is a neural embedding model that predicts gene expression under perturbations.  
It works by combining a **control (basal) cell state** with a **learned perturbation embedding**.  
**Inputs**
  - Control (basal) expression counts or embedding  
  - Perturbation one-hot vector  

**Output**
  - Predicted gene expression profile  
**Source code:** [embed_sum.py](https://github.com/ArcInstitute/state/blob/b6d26731e41d78c8c789d6973fe3d7db7853e9ad/src/state/tx/models/embed_sum.py#L7)
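The idea of combining a basal state with a perturbation embedding can be made concrete with a deliberately simplified, linear NumPy sketch. The control profile and the one-hot perturbation are each projected into a shared hidden space, summed, and projected back to gene space. The weight names are hypothetical. The actual `EmbedSumPerturbationModel` uses `torch.nn.Sequential` modules, which appear in the training log further down as `basal_encoder`, `pert_encoder`, and `project_out`. Their exact layers may include nonlinearities that are not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_perts, hidden_dim = 6, 3, 4  # toy sizes; the notebook uses hidden_dim = 10

# Hypothetical weights standing in for the basal encoder, perturbation encoder, and output projection
W_basal = rng.normal(size=(n_genes, hidden_dim))
W_pert = rng.normal(size=(n_perts, hidden_dim))
W_out = rng.normal(size=(hidden_dim, n_genes))


def embed_sum_forward(ctrl_expr, pert_onehot):
    """Linear embed-sum: encode basal state and perturbation, add them, project back to genes."""
    h = ctrl_expr @ W_basal + pert_onehot @ W_pert  # sum in the shared hidden space
    return h @ W_out


ctrl = rng.random((2, n_genes))       # batch of 2 control profiles
pert = np.eye(n_perts)[[0, 2]]        # one-hot vectors for perturbations 0 and 2
pred = embed_sum_forward(ctrl, pert)  # shape (2, n_genes)
```

With a one-hot input, `pert_onehot @ W_pert` simply selects one row per perturbation, which is why a one-hot vector acts as a learned embedding lookup.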

### Metadata Callback
The `embedsum_metadata_cb` function prepares metadata for the `EmbedSumPerturbationModel`. It first applies PROTOplast's `cell_line_metadata_cb`, then sets the input and output dimensions (both equal to the **number of genes**), defines the perturbation dimension as the number of unique drugs in the dataset, and specifies the hidden layer size and the control perturbation (`DMSO_TF`).

In [11]:
def embedsum_metadata_cb(ad: anndata.AnnData, metadata: dict):
    cell_line_metadata_cb(ad, metadata)
    metadata["input_dim"] = ad.var.shape[0]
    metadata["output_dim"] = ad.var.shape[0]

    uniq_drugs = sorted(ad.obs["drug"].astype(str).unique().tolist())
    metadata["pert_dim"] = len(uniq_drugs)

    metadata["hidden_dim"] = 10  # here kept small for testing
    metadata["control_pert"] = "DMSO_TF"

### EmbedSumAnnDataset
The `EmbedSumAnnDataset` class extends `DistributedAnnDataset` and prepares batches for the `EmbedSumPerturbationModel`. It enriches each batch with drug embeddings, metadata, and control information needed for training.

In [12]:
class EmbedSumAnnDataset(DistributedAnnDataset):
    control_drug = "DMSO_TF"

    def transform(self, start: int, end: int):
        # Loads gene expression (X) and converts it into a tensor (target_gene_expr).
        X = super().transform(start, end)
        target_gene_expr = torch.as_tensor(X, dtype=torch.float32)
        device = target_gene_expr.device

        # Collects metadata: perturbation names (drug) and cell line labels
        pert_names = self.ad.obs["drug"].iloc[start:end].astype(str).to_list()
        cell_lines = self.ad.obs["cell_line"].iloc[start:end].astype(str).to_list()

        # Create drug index mapping
        if not hasattr(self, "_drug_to_idx"):
            drug_names = sorted(self.ad.obs["drug"].astype(str).unique())
            self._drug_to_idx = {d: i for i, d in enumerate(drug_names)}
            self._num_drugs = len(drug_names)

        # Encode drugs as one-hot vectors; a name missing from the mapping falls back to index 0
        idxs = [self._drug_to_idx.get(p, 0) for p in pert_names]
        pert_emb = torch.nn.functional.one_hot(torch.tensor(idxs, device=device), num_classes=self._num_drugs).float()

        # Computes a global control mean expression vector from cells treated with DMSO_TF
        if not hasattr(self, "_ctrl_global"):
            mask = self.ad.obs["drug"] == self.control_drug
            if mask.sum() == 0:
                ctrl_vec = np.zeros(self.ad.shape[1], dtype=np.float32)
            else:
                ctrl_vec = np.asarray(self.ad[mask].X.mean(axis=0)).ravel().astype(np.float32)
            self._ctrl_global = torch.from_numpy(ctrl_vec)

        ctrl_cell_emb = self._ctrl_global.to(device).unsqueeze(0).expand(len(pert_names), -1)

        # Returns a dictionary containing embeddings, control features, target expression, and metadata for perturbation training
        return {
            "pert_emb": pert_emb,
            "ctrl_cell_emb": ctrl_cell_emb,
            "target_gene_expr": target_gene_expr,
            "pert_cell_emb": target_gene_expr,
            "pert_name": pert_names,
            "cell_type": cell_lines,
        }
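The two preprocessing steps inside `transform()` can be checked in isolation on toy arrays: one-hot drug encoding and the global `DMSO_TF` mean. The NumPy sketch below mirrors both steps with one deliberate difference, offered as a design choice rather than a fix. An unknown drug raises `KeyError` instead of silently mapping to index 0 as `self._drug_to_idx.get(p, 0)` does. Silent mapping would conflate the unknown drug with the alphabetically first drug. The function names are hypothetical.

```python
import numpy as np


def build_drug_index(all_drugs):
    """Sorted drug -> column index, the same sorted order used in embedsum_metadata_cb."""
    return {d: i for i, d in enumerate(sorted(set(map(str, all_drugs))))}


def one_hot_drugs(pert_names, drug_to_idx):
    """One-hot encode drug names; unlike .get(p, 0), an unknown name raises KeyError."""
    idxs = [drug_to_idx[p] for p in pert_names]
    return np.eye(len(drug_to_idx), dtype=np.float32)[idxs]


def control_mean(X, drugs, control="DMSO_TF"):
    """Mean expression over control cells, or zeros if there are none (as in the notebook)."""
    mask = np.asarray(drugs) == control
    if not mask.any():
        return np.zeros(X.shape[1], dtype=np.float32)
    return X[mask].mean(axis=0).astype(np.float32)


# Toy data with hypothetical drug names
drugs = ["drugB", "DMSO_TF", "drugA", "DMSO_TF"]
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)
idx = build_drug_index(drugs)                      # {'DMSO_TF': 0, 'drugA': 1, 'drugB': 2}
onehot = one_hot_drugs(["drugA", "DMSO_TF"], idx)  # rows [0, 1, 0] and [1, 0, 0]
ctrl = control_mean(X, drugs)                      # array([5., 6.], dtype=float32)
```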

### Training the EmbedSumPerturbationModel
- `RayTrainRunner` pairs the model with the custom `EmbedSumAnnDataset`, lists the metadata keys the model needs (dimensions, learning rate, control perturbation, embedding key, and output space), and uses `embedsum_metadata_cb` to populate that metadata.

In [13]:
%%time
EmbedSumPerturbationModel_trainer = RayTrainRunner(
    EmbedSumPerturbationModel,
    EmbedSumAnnDataset,
    [
        "input_dim",
        "output_dim",
        "hidden_dim",
        "pert_dim",
        "lr",
        "control_pert",  # "DMSO_TF"
        "embed_key",
        "output_space",  # "gene"
    ],
    embedsum_metadata_cb,
)

2025-11-04 18:08:34,716	INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-11-04 18:08:34,843	INFO packaging.py:588 -- Creating a file package for local module '/mnt/hdd1/dung/protoplast-ml-example/notebooks'.
2025-11-04 18:08:34,871	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_5c046a9a69b2a95d.zip' (3.07MiB) to Ray cluster...
2025-11-04 18:08:34,891	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_5c046a9a69b2a95d.zip'.


CPU times: user 210 ms, sys: 326 ms, total: 536 ms
Wall time: 10 s


On a machine with **1 GPU (NVIDIA GeForce RTX 3080 - 12 GiB)**, **96 CPUs**, and **125 GiB RAM**, running `EmbedSumPerturbationModel_trainer.train()` completed in approximately **28 minutes**.

In [14]:
%%time
EmbedSumPerturbationModel_trainer.train(
    file_paths,
    batch_size,  # 2000
    test_size,  # 0.0
    val_size,  # 0.2
)

protoplast.scrna.anndata.trainer - INFO - Setting thread_per_worker to half of the available CPUs capped at 4
protoplast.scrna.anndata.trainer - INFO - Using 1 workers where each worker uses: {'CPU': 4, 'GPU': 1}
protoplast.scrna.anndata.strategy - INFO - Length of val_split: 65 length of test_split: 0, length of train_split: 262
protoplast.scrna.anndata.strategy - INFO - Length of after dropping remainder val_split: 65, length of test_split: 0, length of train_split: 262


(TrainController pid=1595724) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!

(TrainController pid=1595724) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1595724) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 4, 'GPU': 1}] * 1

(RayTrainWorker pid=1596111) ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!

(RayTrainWorker pid=1596111) Setting up process group for: env:// [rank=0, world_size=1]
(RayTrainWorker pid=1596111) root - INFO - Logging initialized. Current level is: INFO
(TrainController pid=1595724) Started training worker group of size 1: 
(TrainController pid=1595724) - (ip=192.168.1.226, pid=1596111) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=1596111) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=1596111) GPU available: True (cuda), used: True
(RayTrainWorker pid=1596111) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=1596111) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=1596111) You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


[36m(RayTrainWorker pid=1596111)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(RayTrainWorker pid=1596111)[0m 
[36m(RayTrainWorker pid=1596111)[0m   | Name          | Type       | Params | Mode 
[36m(RayTrainWorker pid=1596111)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1596111)[0m 0 | loss_fn       | MSELoss    | 0      | train
[36m(RayTrainWorker pid=1596111)[0m 1 | pert_encoder  | Sequential | 1.1 K  | train
[36m(RayTrainWorker pid=1596111)[0m 2 | basal_encoder | Sequential | 627 K  | train
[36m(RayTrainWorker pid=1596111)[0m 3 | project_out   | Sequential | 689 K  | train
[36m(RayTrainWorker pid=1596111)[0m -----------------------------------------------------
[36m(RayTrainWorker pid=1596111)[0m 1.3 M     Trainable params
[36m(RayTrainWorker pid=1596111)[0m 0         Non-trainable params
[36m(RayTrainWorker pid=1596111)[0m 1.3 M     Total params
[36m(RayTrainWorker pid=1596111)[0m 5.273     Total estimated model pa

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s]
                                                                           
Epoch 0:   0%|          | 0/4192 [00:00<?, ?it/s]
Epoch 0:   0%|          | 1/4192 [01:19<92:52:41,  0.01it/s, v_num=0]
...
...
Epoch 0: 100%|██████████| 4192/4192 [20:29<00:00,  3.41it/s, v_num=0]
(RayTrainWorker pid=1596111)
Validation: |          | 0/? [00:00<?, ?it/s]
...
...
Epoch 0: 100%|██████████| 4192/4192 [25:46<00:00,  2.71it/s, v_num=0]

(RayTrainWorker pid=1596111) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_18-09-02/checkpoint_2025-11-04_18-36-31.757231)
(RayTrainWorker pid=1596111) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_18-09-02/checkpoint_2025-11-04_18-36-31.757231), metrics={'train_loss': 0.29701322317123413, 'val_loss': 0.5971354246139526, 'epoch': 0, 'step': 4192}, validation_spec=None)
(RayTrainWorker pid=1596111) `Trainer.fit` stopped: `max_epochs=1` reached.


CPU times: user 36.9 s, sys: 9.78 s, total: 46.7 s
Wall time: 27min 56s


Result(metrics={'train_loss': 0.29701322317123413, 'val_loss': 0.5971354246139526, 'epoch': 0, 'step': 4192}, checkpoint=Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_18-09-02/checkpoint_2025-11-04_18-36-31.757231), error=None, path='/home/dtran/protoplast_results/ray_train_run-2025-11-04_18-09-02', metrics_dataframe=   train_loss  val_loss  epoch  step
0    0.297013  0.597135      0  4192, best_checkpoints=[(Checkpoint(filesystem=local, path=/home/dtran/protoplast_results/ray_train_run-2025-11-04_18-09-02/checkpoint_2025-11-04_18-36-31.757231), {'train_loss': 0.29701322317123413, 'val_loss': 0.5971354246139526, 'epoch': 0, 'step': 4192})], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7de4f8da1ff0>)