In [1]:
import os
# Ask PyTorch's CUDA caching allocator to use expandable segments.
# NOTE(review): presumably this must be set before torch initialises CUDA to
# have any effect, so keep this cell ahead of every torch import — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [2]:
import os
from pathlib import Path

from cell_load.data_modules import PerturbationDataModule

# Directory holding the competition support files. The default is the original
# location; set the VCC_SUPPORT_DIR environment variable to run the notebook on
# a machine with a different layout.
# NOTE(review): the dataset locations themselves appear to come from
# starter.toml and are not affected by this variable — confirm they agree.
SUPPORT_SET_DIR = Path(
    os.environ.get(
        "VCC_SUPPORT_DIR",
        "/home/tphan/state/state/competition_support_set",
    )
)

# Build the data module from the TOML split definition and materialise the
# train/val/test datasets (setup() prints the per-dataset split sizes).
dm = PerturbationDataModule(
    toml_config_path="starter.toml",
    embed_key=None,
    num_workers=8,
    batch_col="batch_var",
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    use_scplode=True,
    perturbation_features_file=str(SUPPORT_SET_DIR / "ESM2_pert_features.pt"),
    output_space="gene",
    basal_mapping_strategy="random",
    n_basal_samples=1,
    should_yield_control_cells=True,
    batch_size=16,
)
dm.setup()

Dataset path does not exist: /home/tphan/state/state/competition_support_set/{competition_train,k562_gwps,rpe1,jurkat,k562,hepg2}.h5


/home/tphan/state/state/competition_support_set/{competition_train,k562_gwps,rpe1,jurkat,k562,hepg2}.h5


Processing replogle_h1:   0%|                                                                                        | 0/6 [00:00<?, ?it/s][INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
Processing replogle_h1:   0%|                                                                                        | 0/6 [00:00<?, ?it/s][INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
Processing replogle_h1:  33%|██████████████████████████▋                                                     | 2/6 [00:00<00:00, 13.70it/s][INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
Processing replogle_h1:  33%|██████████████████████████▋                                                     | 2/6 [00:00<00:00, 13.70it/s][INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
Processing replogle_h1:  33%|██████████████████████████▋                        

Processed competition_train: 221273 train, 0 val, 0 test
Processed k562_gwps: 111605 train, 0 val, 0 test
Processed rpe1: 22317 train, 0 val, 0 test
Processed jurkat: 21412 train, 0 val, 0 test
Processed k562: 18465 train, 0 val, 0 test
Processed hepg2: 0 train, 0 val, 9386 test





In [3]:
# Debug aid (disabled): print the tensor shapes of one training batch.
# train_loader = dm.train_dataloader()
# for batch in train_loader:
#     print(batch["cell_type_onehot"].shape)
#     print(batch["pert_cell_emb"].shape)
#     print(batch["ctrl_cell_emb"].shape)
#     print(batch["pert_emb"].shape)
#     break

In [4]:
def _to_BG(t):
    if t.ndim == 3 and t.size(1) == 1: t = t.squeeze(1)
    if t.ndim == 3 and t.size(-1) == 1: t = t.squeeze(-1)
    if t.ndim == 1: t = t.unsqueeze(0)
    if t.ndim != 2: t = t.view(t.size(0), -1)
    return t.contiguous()


In [5]:
import wandb

# Open a Weights & Biases run for this notebook and keep its handle in `run`.
run = wandb.init(
    # Project the run is filed under.
    project="vcc-simple",
    # Entity (usually the team name) that owns the project.
    entity="tanphan-dxt-dataxight",
    # Human-readable name shown for this run.
    name="baseline-delta-pert-emb",
    # Hyperparameters and run metadata stored with the run.
    config={
        "learning_rate": 0.001,
        "architecture": "baseline-delta",
        "dataset": "competition_support",
        "epochs": 30,
        "gpu": "rtx-3080",
        "loss": "ctrl+pert+delta",
    },
)

[34m[1mwandb[0m: Currently logged in as: [33mtanphan-dxt[0m ([33mtanphan-dxt-dataxight[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from tqdm import tqdm

def train_baseline_epoch(
    model,
    dataloader,
    optimizer,
    epoch: int,
    device: str = "cuda",
    use_amp: bool = True
):
    """Train ``model`` for one pass over ``dataloader`` and return the mean loss.

    Every batch must provide the keys ``"pert_cell_emb"``, ``"ctrl_cell_emb"``,
    ``"cell_type_onehot"`` and ``"pert_emb"``. The model is called as
    ``model(cell_type_onehot, pert_emb)`` and must return the tuple
    ``(x_ctrl_pred, delta_pred, x_pred)``. The optimised loss is the sum of
    three MSE terms: predicted vs. matched control expression, predicted vs.
    observed perturbed expression, and predicted vs. observed delta
    (perturbed minus control).

    The per-step loss (``train/loss_step``) and the epoch mean
    (``train/loss_epoch``) are sent to wandb via ``wandb.log``.

    Args:
        model: module returning ``(x_ctrl_pred, delta_pred, x_pred)``.
        dataloader: sized iterable of batch dicts (``len()`` is used for the
            global step counter and the epoch average).
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: 1-based epoch index, used for logging and the progress bar.
        device: device the batch tensors are moved to.
        use_amp: enable CUDA mixed precision. Only takes effect when CUDA is
            available and ``device`` is a CUDA device.

    Returns:
        Mean of the per-step losses over the epoch (0.0 for an empty loader).

    Raises:
        AssertionError: if a prediction and its target differ in shape after
            both have been normalised to 2-D.
    """
    model.train()
    total_loss = 0.0

    # Mixed precision (and gradient scaling) only applies to tensors that
    # actually live on a CUDA device, so gate on the requested device as well.
    amp_enabled = (
        use_amp
        and torch.cuda.is_available()
        and torch.device(device).type == "cuda"
    )
    amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext

    # The GradScaler is cached on the function object so that its loss scale
    # carries over from one epoch to the next.
    scaler = getattr(train_baseline_epoch, "_scaler", None) if amp_enabled else None
    if amp_enabled and scaler is None:
        scaler = torch.cuda.amp.GradScaler()
        train_baseline_epoch._scaler = scaler

    steps_per_epoch = len(dataloader)
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)
    for step, batch in enumerate(pbar, 1):
        x, y, xp, x_ctrl_match = batch["pert_cell_emb"], batch["cell_type_onehot"], batch["pert_emb"], batch["ctrl_cell_emb"]
        x = x.squeeze(1)
        x_ctrl_match = x_ctrl_match.squeeze(1)
        x, y, xp, x_ctrl_match = x.to(device), y.to(device), xp.to(device), x_ctrl_match.to(device)

        with amp_ctx():
            x_ctrl_pred, delta_pred, x_pred = model(y, xp)

            # Bring targets and predictions to the same 2-D layout.
            x            = _to_BG(x)
            x_ctrl_match = _to_BG(x_ctrl_match)
            x_ctrl_pred  = _to_BG(x_ctrl_pred)
            x_pred       = _to_BG(x_pred)
            delta_pred   = _to_BG(delta_pred)

            # Computed once and left in the targets' own dtype.
            true_delta = x - x_ctrl_match

            # Assert exact shape equality; fail fast if not
            assert x_ctrl_pred.shape == x_ctrl_match.shape, f"{x_ctrl_pred.shape=} vs {x_ctrl_match.shape=}"
            assert x_pred.shape      == x.shape,            f"{x_pred.shape=} vs {x.shape=}"
            assert delta_pred.shape  == true_delta.shape,   f"{delta_pred.shape=} vs {true_delta.shape=}"

            # Losses
            loss_ctrl = F.mse_loss(x_ctrl_pred, x_ctrl_match)
            loss_pert = F.mse_loss(x_pred, x)
            loss_delta = F.mse_loss(delta_pred, true_delta)

            loss = loss_ctrl + loss_pert + loss_delta

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Read the scalar once and reuse it for logging and the running mean.
        loss_value = float(loss.detach().item())
        wandb.log({
            "train/loss_step": loss_value,
            "epoch": epoch,
            # Global step derived from the loader that is actually iterated.
            "step": step + (epoch-1)*steps_per_epoch
        })

        total_loss += loss_value
        # Divide by the true step count; tqdm's own counter lags behind it.
        pbar.set_postfix({"loss": f"{total_loss / step:.4f}"})

    avg_loss = total_loss / max(1, steps_per_epoch)
    wandb.log({
        "train/loss_epoch": avg_loss,
        "epoch": epoch
    })
    return avg_loss

In [7]:
def save_checkpoint(model, optimizer, epoch, model_dir):
    """Write a training checkpoint to ``<model_dir>/epoch=<epoch>.pt``.

    The file holds a dict with the keys ``"epoch"``, ``"model_state"`` and
    ``"optimizer_state"``.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch number stored in the checkpoint and used in the file name.
        model_dir: target directory; created (with parents) if it is missing.

    Returns:
        Path of the checkpoint file that was written.
    """
    # torch.save does not create missing parent directories.
    os.makedirs(model_dir or ".", exist_ok=True)
    ckpt_path = os.path.join(model_dir, f"epoch={epoch}.pt")
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, ckpt_path)
    return ckpt_path

In [8]:
import os
import torch
from protoplast.scrna.models.baseline import BaselinePerturbModel

G = 18080           # genes
n_cell_lines = 5
# Width of the perturbation feature vector fed to the model.
# NOTE(review): presumably the dimension of the ESM2 perturbation embeddings —
# confirm against the perturbation features file.
pert_d = 5120

device = "cuda" if torch.cuda.is_available() else "cpu"
max_epoch = 30
ckpt_dir = "baseline-delta-pert-emb"
resume_epoch = 10   # epoch whose checkpoint is loaded when it exists
# Look for the checkpoint in the directory checkpoints are written to.
last_ck = os.path.join(ckpt_dir, f"epoch={resume_epoch}.pt")

model = BaselinePerturbModel(G, n_cell_lines, pert_d).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
train_loader = dm.train_dataloader()

if os.path.exists(last_ck):
    ckpt = torch.load(last_ck, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    # The stored epoch is already finished; continue with the next one.
    start_epoch = ckpt["epoch"] + 1
else:
    start_epoch = 1

# wandb.watch(model, log="all")

# Make sure the checkpoint directory exists before the first save.
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(start_epoch, max_epoch + 1):
    # Pass the device explicitly so a CPU-only machine does not fall back to
    # the function's "cuda" default.
    loss = train_baseline_epoch(model, train_loader, optimizer, epoch, device=device)
    if epoch % 10 == 0:
        save_checkpoint(model, optimizer, epoch, ckpt_dir)

                                                                                                                                           

[1;34mwandb[0m: 
[1;34mwandb[0m: 🚀 View run [33mbaseline-delta-pert-emb[0m at: [34mhttps://wandb.ai/tanphan-dxt-dataxight/vcc-simple/runs/7opznwob[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20250902_165052-7opznwob/logs[0m
