In [1]:
import anndata
import glob
import numpy as np
import pandas as pd
import os
import pathlib
import protoplast as pt
import ray
import torch

from anndata.experimental import AnnCollection
from protoplast.scrna.anndata.lightning_models import LinearClassifier
from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset
from protoplast.scrna.anndata.torch_dataloader import cell_line_metadata_cb, DistributedCellLineAnnDataset

from ray.train import Checkpoint
from ray.train.lightning import RayDDPStrategy

✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


  import pynvml  # type: ignore[import]


✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


## Training Set Overview
To start, we load the dataset used to train the example gene perturbation model featured in this notebook. This dataset will also serve as the basis for filtering non-targeting controls in the final prediction set, which we'll submit to the Virtual Cell Challenge.

In [2]:
file_paths = ["/mnt/hdd2/tan/competition_support_set/competition_train.h5"]
adata = anndata.read_h5ad(file_paths[0], backed = "r")
adata.obs.head(n = 5)

Unnamed: 0,target_gene,guide_id,batch,batch_var,cell_type
AAACAAGCAACCTTGTACTTTAGG-Flex_1_01,CHMP3,CHMP3_P1P2_A|CHMP3_P1P2_B,Flex_1_01,Flex_1_01,ARC_H1
AAACAAGCATTGCCGCACTTTAGG-Flex_1_01,AKT2,AKT2_P1P2_A|AKT2_P1P2_B,Flex_1_01,Flex_1_01,ARC_H1
AAACCAATCAATGTTCACTTTAGG-Flex_1_01,SHPRH,SHPRH_P1P2_A|SHPRH_P1P2_B,Flex_1_01,Flex_1_01,ARC_H1
AAACCAATCCCTCGCTACTTTAGG-Flex_1_01,TMSB4X,TMSB4X_P1_A|TMSB4X_P1_B,Flex_1_01,Flex_1_01,ARC_H1
AAACCAATCTAAATCCACTTTAGG-Flex_1_01,KLF10,KLF10_P2_A|KLF10_P2_B,Flex_1_01,Flex_1_01,ARC_H1


## Define example perturbation model

In [3]:
import lightning.pytorch as pl
import torch
from torch import nn

class ExampleModel(pl.LightningModule):
    def __init__(self, num_genes, num_classes, control_pert, pert_names):
        super().__init__()
        self.control_pert = control_pert
        self.pert_names = pert_names
        self.embedding = nn.Embedding(num_embeddings = len(pert_names), embedding_dim = num_genes)
        self.output = nn.Linear(num_genes, num_classes)
    
    def forward(self, batch):
        embed = self.embedding(batch)
        out = self.output(embed).squeeze(1)
        return out

    def training_step(self, batch, batch_idx):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")        
        
        batch_pert_cell_counts = batch["pert_cell_counts"]
        X = batch_pert_cell_counts

        batch_pert_names = np.array(batch["pert_name"])
        sorted_idx = np.argsort(self.pert_names)
        pos = np.searchsorted(self.pert_names[sorted_idx], batch_pert_names)
        indices = torch.tensor(sorted_idx[pos]).unsqueeze(1).to(device)

        out = self(indices)
        mse_loss_fn = nn.MSELoss()
        loss = mse_loss_fn(out, X)
        
        self.log("train_loss", loss, on_step=True, prog_bar=True, sync_dist=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [4]:
class PerturbAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)

        # Metadata froms self.ad
        pert_names = self.ad.obs["target_gene"].iloc[start:end].astype(str).to_list()

        return {
            "pert_name": pert_names,
            "pert_cell_counts": X
        }

## Configure the model parameters

In [5]:
num_genes = adata.n_vars
output_size = adata.n_vars
control_pert = "non-targeting"

validation_genes = pd.read_csv("/mnt/hdd2/tan/competition_support_set/pert_counts_Validation.csv").target_gene
pert_names = np.array(adata.obs.target_gene.unique().tolist() + validation_genes.tolist())

In [6]:
def metadata_cb(ad: anndata.AnnData, metadata: dict):
    metadata["num_genes"] = num_genes
    metadata["num_classes"] = output_size
    metadata["control_pert"] = control_pert
    metadata["pert_names"] = pert_names

## Create trainer & train perturbation model

In [7]:
trainer = RayTrainRunner(
    ExampleModel,
    PerturbAnnDataset,
    model_keys = ["num_genes",
                  "num_classes",
                  "control_pert",
                  "pert_names"],
    metadata_cb = metadata_cb,
    sparse_key = "X"
)

2025-09-28 13:23:20,084	INFO worker.py:1951 -- Started a local Ray instance.
[36m(pid=2601474)[0m   import pynvml  # type: ignore[import]


[36m(TrainTrainable pid=2601474)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(TrainTrainable pid=2601474)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=2601762)[0m   import pynvml  # type: ignore[import]
[36m(RayTrainWorker pid=2601762)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=2601474)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=2601474)[0m - (node_id=dead53e42ca5e556bc76dfd97e8f424f74a97e8028e458b770249e18, ip=192.168.1.226, pid=2601762) world_rank=0, local_rank=0, node_rank=0


[36m(RayTrainWorker pid=2601762)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!
[36m(RayTrainWorker pid=2601762)[0m ✓ Applied AnnDataFileManager patch, AnnData cannot be imported after the patch!


[36m(RayTrainWorker pid=2601762)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(RayTrainWorker pid=2601762)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=2601762)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=2601762)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=2601762)[0m /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.1 ...
[36m(RayTrainWorker pid=2601762)[0m /mnt/hdd2/nam/miniconda3/envs/test/lib/python3.11/site-packages/lightning/pytorch/trainer/configura

Epoch 0:   0%|          | 0/96 [00:00<?, ?it/s]


[36m(RayTrainWorker pid=2601762)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=2601762)[0m   return torch.sparse_csr_tensor(
[36m(RayTrainWorker pid=2601762)[0m   return torch.sparse_compressed_tensor(


Epoch 0:   1%|          | 1/96 [00:03<06:11,  0.26it/s]
Epoch 0:   1%|          | 1/96 [00:03<06:12,  0.25it/s, v_num=0, train_loss=1.680]
Epoch 0:   2%|▏         | 2/96 [00:04<03:12,  0.49it/s, v_num=0, train_loss=1.680]
Epoch 0:   2%|▏         | 2/96 [00:04<03:20,  0.47it/s, v_num=0, train_loss=19.30]
Epoch 0:   3%|▎         | 3/96 [00:04<02:17,  0.68it/s, v_num=0, train_loss=19.30]
Epoch 0:   3%|▎         | 3/96 [00:04<02:22,  0.65it/s, v_num=0, train_loss=2.570]
Epoch 0:   4%|▍         | 4/96 [00:04<01:49,  0.84it/s, v_num=0, train_loss=2.570]
Epoch 0:   4%|▍         | 4/96 [00:04<01:53,  0.81it/s, v_num=0, train_loss=7.710]
Epoch 0:   5%|▌         | 5/96 [00:05<01:33,  0.98it/s, v_num=0, train_loss=7.710]
Epoch 0:   5%|▌         | 5/96 [00:05<01:36,  0.95it/s, v_num=0, train_loss=11.10]
Epoch 0:   6%|▋         | 6/96 [00:05<01:21,  1.10it/s, v_num=0, train_loss=11.10]
Epoch 0:   6%|▋         | 6/96 [00:05<01:24,  1.07it/s, v_num=0, train_loss=6.330]
Epoch 0:   7%|▋         | 7/96 

[36m(RayTrainWorker pid=2601762)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/nam/protoplast_results/TorchTrainer_2025-09-28_13-23-24/TorchTrainer_4f753_00000_0_2025-09-28_13-23-24/checkpoint_000000)


Epoch 0: 100%|██████████| 96/96 [00:53<00:00,  1.81it/s, v_num=0, train_loss=0.185]
Epoch 0: 100%|██████████| 96/96 [01:04<00:00,  1.48it/s, v_num=0, train_loss=0.185]


[36m(RayTrainWorker pid=2601762)[0m `Trainer.fit` stopped: `max_epochs=1` reached.


In [8]:
thread_per_worker = 2
batch_size = 2000
test_size = 0.0
val_size = 0.0

result = trainer.train(
    file_paths = file_paths,
    batch_size = batch_size,
    test_size = test_size,
    val_size = val_size,
    thread_per_worker = thread_per_worker,  # 2
)

Using 1 workers with {'CPU': 2} each


2025-09-28 13:23:24,183	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


Data splitting time: 0.24 seconds
Spawning Ray worker and initiating distributed training
== Status ==
Current time: 2025-09-28 13:23:24 (running for 00:00:00.14)
Using FIFO scheduling algorithm.
Logical resource usage: 0/96 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-28_13-23-17_060244_2592720/artifacts/2025-09-28_13-23-24/TorchTrainer_2025-09-28_13-23-24/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-28 13:23:29 (running for 00:00:05.16)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-28_13-23-17_060244_2592720/artifacts/2025-09-28_13-23-24/TorchTrainer_2025-09-28_13-23-24/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-28 13:23:34 (running for 00:00:10.20)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)

2025-09-28 13:24:58,921	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/nam/protoplast_results/TorchTrainer_2025-09-28_13-23-24' in 0.0099s.
2025-09-28 13:24:58,925	INFO tune.py:1041 -- Total run time: 94.74 seconds (94.70 seconds for the tuning loop).


== Status ==
Current time: 2025-09-28 13:24:58 (running for 00:01:34.71)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/96 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-09-28_13-23-17_060244_2592720/artifacts/2025-09-28_13-23-24/TorchTrainer_2025-09-28_13-23-24/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)




In [9]:
ray.shutdown()

## Perturbation Prediction Using the Trained Model

With the model trained, we can now generate perturbation predictions. We'll start by loading the validation set, which provides the target genes and the number of predictions required for each gene.

In [10]:
validation_set = pd.read_csv("/mnt/hdd2/tan/competition_support_set/pert_counts_Validation.csv")
validation_set.head(n = 5)

Unnamed: 0,target_gene,n_cells,median_umi_per_cell
0,SH3BP4,2925,54551.0
1,ZNF581,2502,53803.5
2,ANXA6,2496,55175.0
3,PACSIN3,2101,54088.0
4,MGST1,2096,54217.5


In [11]:
X = torch.tensor([]).long()
for row in validation_set.iloc:
    tmp = torch.tensor([np.where(pert_names == row.target_gene)][0] * row.n_cells)
    X = torch.cat((X, tmp), dim = 0)

  tmp = torch.tensor([np.where(pert_names == row.target_gene)][0] * row.n_cells)


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")        
X = X.to(device)

Load the trained model from the checkpoint path in previous `train()` step & make prediction

In [13]:
model = ExampleModel.load_from_checkpoint(result.checkpoint.path + "/checkpoint.ckpt", 
                                          num_genes = num_genes, 
                                          num_classes = output_size, 
                                          control_pert = control_pert, 
                                          pert_names = pert_names)
model.eval()  # set to eval mode

ExampleModel(
  (embedding): Embedding(201, 18080)
  (output): Linear(in_features=18080, out_features=18080, bias=True)
)

In [14]:
pred = model(X)
print(pred)

tensor([[ 0.3546,  0.6941,  0.3020,  ...,  0.7363,  0.4245, -0.1082],
        [ 0.3546,  0.6941,  0.3020,  ...,  0.7363,  0.4245, -0.1082],
        [ 0.3546,  0.6941,  0.3020,  ...,  0.7363,  0.4245, -0.1082],
        ...,
        [-0.3894, -0.5620, -0.1625,  ...,  0.2070,  0.5451,  0.2905],
        [-0.3894, -0.5620, -0.1625,  ...,  0.2070,  0.5451,  0.2905],
        [-0.3894, -0.5620, -0.1625,  ...,  0.2070,  0.5451,  0.2905]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)


## Creating the Predicted AnnData
Next, we generate an `AnnData` object from the model’s predicted gene perturbations. To complete the dataset, we also include the non-targeting control group. Since these controls aren't part of the prediction, we'll copy them directly from the training set.

In [15]:
sample_predicted_adata = anndata.AnnData(
    X = pred.cpu().detach().numpy(),
    obs = pd.DataFrame(
        {
            "target_gene": np.repeat(validation_set.target_gene, validation_set.n_cells).tolist(),
        },
        index = np.arange(validation_set.n_cells.sum()).astype(str),
    ),
    var = pd.DataFrame(index = list(adata.var_names)),
)

In [16]:
sample_submission = anndata.concat([adata[adata.obs["target_gene"] == "non-targeting"], 
                                    sample_predicted_adata])


In [19]:
sample_submission.write_h5ad("~/result/prediction.h5ad")

## Running `cell-eval`

We’ll now use `cell-eval` to process the `AnnData` object and prepare it for submission to the competition.

In [None]:
!cell-eval prep -i ~/result/prediction.h5ad -g /mnt/hdd2/tan/competition_support_set/gene_names.csv

INFO:cell_eval._cli._prep:Reading input anndata
INFO:cell_eval._cli._prep:Reading gene list
INFO:cell_eval._cli._prep:Preparing anndata
INFO:cell_eval._cli._prep:Using 32-bit float encoding
INFO:cell_eval._cli._prep:Setting data to sparse if not already
