In [2]:
%load_ext autoreload
%autoreload 2
from protoplast.scrna.anndata.dataloader import PerturbDataset
from protoplast.scrna.models.baseline import BaselinePerturbModel

✓ Applied AnnDataFileManager patch


In [3]:
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F

def train_baseline_epoch(model, dataloader, optimizer, device="cuda"):
    model.train()
    total_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)
    for x, y, b, xp, x_ctrl_match in pbar:
        x, y, xp, x_ctrl_match = x.to(device), y.to(device), xp.to(device), x_ctrl_match.to(device)
        
        x_ctrl_pred, delta_pred, x_pred = model(y, xp)
        
        is_ctrl = (xp.sum() == 0)
        is_pert = ~is_ctrl
        
        # Losses
        # .squeeze to remove the extra dimension getting from adata
        loss_ctrl = F.mse_loss(x_ctrl_pred[is_ctrl], x[is_ctrl].squeeze(2)) if is_ctrl.any() else 0.0
        loss_pert = F.mse_loss(x_pred[is_pert], x[is_pert].squeeze(2)) if is_pert.any() else 0.0
        if is_pert.any():
            true_delta = x[is_pert] - x_ctrl_match[is_pert]
            loss_delta = F.mse_loss(delta_pred[is_pert], true_delta.squeeze(2))
        else:
            loss_delta = 0.0
        
        # Total loss
        loss = loss_ctrl + loss_pert + loss_delta
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += float(loss.item())
    avg_loss = total_loss / len(dataloader)
    pbar.set_postfix({"loss": f"{avg_loss:.4f}"})
    return avg_loss

In [4]:
import glob
data_dir = "/home/tphan/state/state/competition_support_set/*.h5"
ds = PerturbDataset(glob.glob(data_dir))

2025-09-02 10:39:29,646 - protoplast.scrna.anndata.dataloader - INFO - write mmap file for /home/tphan/state/state/competition_support_set/rpe1.h5
[INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
2025-09-02 10:39:29,656 - protoplast.scrna.anndata.dataloader - INFO - n_obs for /home/tphan/state/state/competition_support_set/rpe1.h5: 22317
2025-09-02 10:39:29,657 - protoplast.scrna.anndata.dataloader - INFO - write mmap file for /home/tphan/state/state/competition_support_set/k562.h5
[INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
2025-09-02 10:39:29,667 - protoplast.scrna.anndata.dataloader - INFO - n_obs for /home/tphan/state/state/competition_support_set/k562.h5: 18465
2025-09-02 10:39:29,667 - protoplast.scrna.anndata.dataloader - INFO - write mmap file for /home/tphan/state/state/competition_support_set/k562_gwps.h5
[INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (impli

In [5]:
import wandb

# Start a new wandb run to track this script.
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="tanphan-dxt-dataxight",
    # Set the wandb project where this run will be logged.
    project="vcc-simple",
    # Track hyperparameters and run metadata.
    name="baseline-delta",
    config={
        "learning_rate": 0.001,
        "architecture": "baseline-delta",
        "dataset": "competition_support",
        "epochs": 10,
        "gpu":"rtx-3080"
    },
)

[34m[1mwandb[0m: Currently logged in as: [33mtanphan-dxt[0m ([33mtanphan-dxt-dataxight[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
def save_checkpoint(model, optimizer, epoch, model_dir):
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, f"{model_dir}/epoch={epoch}.pt")

In [17]:
from torch.utils.data import DataLoader
dataloader = DataLoader(ds, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=8)

In [18]:
import os
import torch

G = 18080           # genes
n_cell_lines = 5
n_targets = 373   # genes + control

device = "cuda" if torch.cuda.is_available() else "cpu"
start_epoch = 10
max_epoch = 200
last_ck = f"epoch={start_epoch}.pt"

model = BaselinePerturbModel(G, n_cell_lines, len(ds.perturbs_vocab)).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

if os.path.exists(last_ck):
    ckpt = torch.load(last_ck, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
else:
    start_epoch = 0
    
    
wandb.watch(model, log="all")

for epoch in range(start_epoch, max_epoch + 1):
    loss = train_baseline_epoch(model, dataloader, optimizer, device)
    wandb.log({"train/loss": loss, "epoch": epoch})
    if not epoch % 5:
        save_checkpoint(model, optimizer, epoch, "baseline-delta")


Exception in thread Thread-64 (_pin_memory_loop):                                                                                                           
Traceback (most recent call last):
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 772, in run_closure
    _threading_Thread_run(self)
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 61, in _pin_memory_loop
    do_one_step()
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 37, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^

KeyboardInterrupt: 

    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/site-packages/torch/multiprocessing/reductions.py", line 541, in rebuild_storage_fd
    fd = df.detach()
         ^^^^^^^^^^^
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/multiprocessing/connection.py", line 519, in Client
    c = SocketClient(address)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tphan/miniconda3/envs/python312/lib/python3.12/multiprocessing/

In [None]:
ds._get_file_idx(228314)

# Inference

In [19]:
ckpt = torch.load("baseline-delta/epoch=75.pt", map_location=device)
model = BaselinePerturbModel(G, n_cell_lines, n_targets).to(device)
model.load_state_dict(ckpt["model_state"])

<All keys matched successfully>

In [20]:
y_test = torch.tensor([3], device=device)
xp_test = torch.tensor([123], device=device)  # gene 123
x_ctrl, delta, x_pred = model(y_test, xp_test)

print("Pred control:", x_ctrl.shape)  # [1, G]
print("Pred perturbed:", x_pred.shape)  # [1, G]

Pred control: torch.Size([1, 18080])
Pred perturbed: torch.Size([1, 18080])


In [21]:
x_pred

tensor([[-5.9167e-04,  5.3036e-01,  2.5000e-02,  ...,  3.4181e+00,
          1.9216e+00,  5.0787e+00]], device='cuda:0', grad_fn=<AddBackward0>)

In [27]:
x_pred.detach().cpu().numpy()

array([[-5.9167179e-04,  5.3035599e-01,  2.5000028e-02, ...,
         3.4181166e+00,  1.9216139e+00,  5.0787258e+00]],
      shape=(1, 18080), dtype=float32)

In [28]:
ds.cell_types

array(['ARC_H1', 'hepg2', 'jurkat', 'k562', 'rpe1'], dtype='<U6')

In [6]:
ds.sp_adatas[0].obs

Unnamed: 0_level_0,gem_group,gene,gene_id,transcript,gene_transcript,sgID_AB,mitopercent,UMI_count,z_gemgroup_UMI,core_scale_factor,core_adjusted_UMI_count,batch_var,target_gene,cell_type
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
AAACCCAAGAGAGAAC-35,35,non-targeting,non-targeting,non-targeting,11209_non-targeting_non-targeting_non-targeting,non-targeting_02989|non-targeting_02406,0.056552,8753.0,0.368225,0.600710,14571.081055,rpe135,non-targeting,rpe1
AAACCCAAGCGCACAA-2,2,TFAM,ENSG00000108064,P1P2,8832_TFAM_P1P2_ENSG00000108064,TFAM_+_60145205.23-P1P2|TFAM_-_60145223.23-P1P2,0.022038,11934.0,0.083506,0.936004,12749.941406,rpe12,TFAM,rpe1
AAACCCAAGCGGACAT-34,34,non-targeting,non-targeting,non-targeting,11238_non-targeting_non-targeting_non-targeting,non-targeting_03171|non-targeting_03032,0.058144,12469.0,0.970058,0.709628,17571.187500,rpe134,non-targeting,rpe1
AAACCCAAGCTCGACC-34,34,TAZ,ENSG00000102125,P1P2,8703_TAZ_P1P2_ENSG00000102125,TAZ_-_153640125.23-P1P2|TAZ_+_153639918.23-P1P2,0.053064,16075.0,1.957035,0.709628,22652.726562,rpe134,TAZ,rpe1
AAACCCAAGTGCACCC-42,42,non-targeting,non-targeting,non-targeting,10829_non-targeting_non-targeting_non-targeting,non-targeting_00483|non-targeting_03391,0.058406,12961.0,1.024948,0.725025,17876.626953,rpe142,non-targeting,rpe1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGTCCATCTCG-35,35,non-targeting,non-targeting,non-targeting,11025_non-targeting_non-targeting_non-targeting,non-targeting_01813|non-targeting_00122,0.008165,2327.0,-1.579949,0.600710,3873.746826,rpe135,non-targeting,rpe1
TTTGTTGTCGGCTCTT-36,36,non-targeting,non-targeting,non-targeting,11208_non-targeting_non-targeting_non-targeting,non-targeting_02977|non-targeting_01146,0.061335,11853.0,1.030208,0.644705,18385.166016,rpe136,non-targeting,rpe1
TTTGTTGTCTGCACCT-44,44,MAX,ENSG00000125952,P1P2,4871_MAX_P1P2_ENSG00000125952,MAX_+_65569008.23-P1P2|MAX_-_65568906.23-P1P2,0.079629,22228.0,0.767826,1.321505,16820.220703,rpe144,MAX,rpe1
TTTGTTGTCTGGGCAC-32,32,ATP6V0C,ENSG00000185883,P1P2,675_ATP6V0C_P1P2_ENSG00000185883,ATP6V0C_+_2564168.23-P1P2|ATP6V0C_-_2563995.23...,0.049527,12377.0,-0.004215,0.989524,12508.036133,rpe132,ATP6V0C,rpe1


In [7]:
ds.perturbs.sort()

In [8]:
ds.perturbs

array(['ACAT2', 'ACAT2_P1P2_A|ACAT2_P1P2_B', 'ACLY', 'ACVR1B',
       'ACVR1B_P1P2_A|ACVR1B_P1P2_B', 'AKT2', 'AKT2_P1P2_A|AKT2_P1P2_B',
       'ANTXR1_P1P2_A|ANTXR1_P1P2_B', 'ANXA6', 'ARID1A',
       'ARID1A_P1P2_A|ARID1A_P1P2_B', 'ARPC2', 'ATM',
       'ATM_P1P2_A|ATM_P1P2_B', 'ATP1B1_P1P2_A|ATP1B1_P1P2_B', 'ATP6V0B',
       'ATP6V0B_P1P2_A|ATP6V0B_P1P2_B', 'ATP6V0C',
       'ATP6V0C_P1P2_A|ATP6V0C_P1P2_B', 'BIRC2',
       'BIRC2_P1P2_A|BIRC2_P1P2_B', 'BRD9', 'BRD9_P1P2_A|BRD9_P1P2_B',
       'C1QBP', 'C1QBP_P1P2_A|C1QBP_P1P2_B', 'CALM3',
       'CALM3_ENST00000291295.9_A|CALM3_ENST00000291295.9_B',
       'CAMSAP2_P1P2_A|CAMSAP2_P1P2_B', 'CASP2',
       'CASP2_P1P2_A|CASP2_P1P2_B', 'CASP3', 'CASP3_P1P2_A|CASP3_P1P2_B',
       'CAST', 'CAST_P1_A|CAST_P1_B', 'CAST_P2_A|CAST_P2_B', 'CDCA2',
       'CDCA2_P1P2_A|CDCA2_P1P2_B', 'CENPB', 'CENPO',
       'CENPO_P1P2_A|CENPO_P1P2_B', 'CHMP3', 'CHMP3_P1P2_A|CHMP3_P1P2_B',
       'CLDN6', 'CLDN6_P1P2_A|CLDN6_P1P2_B', 'CLDN7',
       'CLDN7_P1P