# Set up an environment

In [None]:
# Upgrade the core packaging tools (pip, setuptools, wheel, packaging) first.
!pip install -U pip setuptools wheel packaging
# The session will automatically restart to use the newly installed package versions.

In [None]:
# Install pinned PyTorch / torchvision / torchaudio releases from the cu121 wheel index.
!pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121

In [None]:
# bring requirement file from google drive to content folder in colab local drive
!pip install -q gdown
!gdown --id 1OCq1jiHBm5kPrMO4AVVvgfuoXThDfJJJ -O /content/requirements-colab.txt
!pip install -r requirements-colab.txt

In [None]:
# Clone the IMPA GitHub repository into the current working directory.
!git clone https://github.com/theislab/IMPA.git

In [None]:
# Download the dataset tar archive from Google Drive into Colab's /content folder.
!pip install -q gdown
!gdown --id 11Cu4nm64ZOaJyKMEORoehM2sJulGV5FP -O /content/bbbc021_all.tar.gz

In [None]:
# Create the dataset folder inside the cloned repository and extract the
# gzip-compressed tar archive into it.
!mkdir -p /content/IMPA/project_folder/datasets
!tar -xvzf /content/bbbc021_all.tar.gz -C /content/IMPA/project_folder/datasets

In [None]:
# Create the checkpoint folder and download the checkpoint files (a Google Drive
# folder) into /content/IMPA/checkpoints/bbbc021_all.
!mkdir -p /content/IMPA/checkpoints
!gdown --folder 1W9zMjJYyRfdqYfhZLqi-V-wGER27nL5Z -O /content/IMPA/checkpoints/bbbc021_all

In [None]:
# Install the IMPA package (setup.py) in editable mode from the cloned repository.
# %cd changes the notebook's working directory to /content/IMPA for all later cells.
%cd /content/IMPA
%pip install -e .

# The tutorial by Palma et al. (2024) was adapted, and the model was retrained by fine-tuning

Original tutorial: "Use IMPA for unseen drug prediction BBBC021"

https://github.com/theislab/IMPA/blob/main/tutorial/transform_cells_bbbc021_all_unseen_prediction.ipynb

In [None]:
## Import libraries
# Standard library imports
import os
from pathlib import Path

# IMPA package imports (the package installed from the cloned repository)
from IMPA.dataset.data_loader import CellDataLoader
from IMPA.solver import IMPAmodule
import IMPA.eval.gan_metrics.fid as fid_mod
from IMPA.eval.gan_metrics.density_and_coverage import compute_d_c

# More standard library imports
import re
import sys
# Make the tutorial helper module in the cloned repository importable
sys.path.append("/content/IMPA/tutorial")
from tutorial_utils import t2np
# Third-party library imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import linalg
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import inception_v3
from omegaconf import OmegaConf
from tqdm import tqdm
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
##Utils functions
def transform_by_emb(solver, dataloader, y, n_average, args):
    """
    Generate perturbed images for one batch of controls from a perturbation embedding.

    Only the first batch of ``dataloader.train_dataloader()`` is processed: the
    loop exits after a single iteration.

    Parameters:
        solver: Object exposing ``nets.mapping_network`` and ``nets.generator``.
        dataloader: Object whose ``train_dataloader()`` yields batches in which
            ``batch["X"][0]`` holds the control images.
        y (torch.Tensor): Embedding of the target perturbation. A leading axis is
            added and the row is repeated once per image in the batch.
        n_average (int): Size of the axis that is averaged away when building ``z``.
        args: Configuration object; ``args.z_dimension`` sets the width of ``z``.

    Returns:
        tuple: ``(controls, transformed)`` NumPy arrays, each concatenated along axis 0.

    Notes:
        Tensors are moved with ``.cuda()``, so a CUDA device is required.
        NOTE(review): ``z`` is built from ``torch.ones``, so after averaging it is
        the same constant vector for every image and ``n_average`` has no effect
        on its values — confirm that random noise was not intended here.
    """
    control_batches, generated_batches = [], []

    target_embedding = y.unsqueeze(0)
    with torch.no_grad():
        for batch in tqdm(dataloader.train_dataloader()):
            X_ctr = batch["X"][0]
            n_images = X_ctr.shape[0]
            noise = torch.ones(n_images, n_average, args.z_dimension).cuda().mean(1)

            # One copy of the perturbation embedding per image, joined with the noise
            tiled = target_embedding.repeat((noise.shape[0], 1)).cuda()
            style = solver.nets.mapping_network(torch.cat([tiled, noise], dim=1))

            _, X_generated = solver.nets.generator(X_ctr, style)
            generated_batches.append(t2np(X_generated.detach().cpu(), batch_dim=True))
            control_batches.append(t2np(X_ctr.detach().cpu(), batch_dim=True))
            break  # only the first batch is transformed
    return (
        np.concatenate(control_batches, axis=0),
        np.concatenate(generated_batches, axis=0),
    )

In [None]:
## Read the configuration of interest
# Path of the bbbc021_all YAML configuration inside the cloned IMPA repository
path_to_config = "/content/IMPA/config_hydra/config/bbbc021_all.yaml"

In [None]:
# Read the YAML file into a plain Python object
with open(path_to_config, 'r') as file:
    config = yaml.safe_load(file)  # Use safe_load to avoid executing arbitrary code

# Correct the paths to the images, the data index and the embedding table.
# These are relative paths; a separate cell prepends "/content/" to the in-memory copy.
config["image_path"] = "IMPA/project_folder/datasets/bbbc021_all"
config["data_index_path"] = "IMPA/project_folder/datasets/bbbc021_all/metadata/bbbc021_df_all.csv"
config["embedding_path"] = "IMPA/embeddings/csv/emb_fp_all.csv"

# Write the modified configuration back to the same file (the original file is
# overwritten; it is re-serialised from the dict, so YAML comments are not kept).
with open(path_to_config, "w") as file:
    yaml.safe_dump(config, file)

# Show the loaded (and modified) configuration
print(config)

In [None]:
# Turn the relative dataset/embedding paths into absolute paths under /content.
# The prefix check keeps the cell idempotent: re-running it no longer produces
# paths such as "/content//content/IMPA/...".
for path_key in ("image_path", "data_index_path", "embedding_path"):
    if not config[path_key].startswith("/content/"):
        config[path_key] = "/content/" + config[path_key]

In [None]:
## Create an OmegaConf config from the dict so that settings can be read as attributes
## (e.g. args.z_dimension, args.total_epochs in the cells that follow)
args = OmegaConf.create(config)

In [None]:
## Initialize the IMPA data loader and take its train / validation loaders.
## These loaders are only used as the source of datasets and batch sizes for the
## resizing loaders built in the next cell.
dataloader = CellDataLoader(args)
train_dataloader = dataloader.train_dataloader()
val_dataloader = dataloader.val_dataloader()

In [None]:
##Wrapper for resizing images (96x96 --> 128x128: to use original checkpoint file)
# imported images have 96x96 size --> 128x128 resize wrapper

class ResizeWrapDataset(torch.utils.data.Dataset):
    """Dataset wrapper that bilinearly resizes the ``"X"`` entry of every sample.

    Each sample of the wrapped dataset must be a mapping with an ``"X"`` key.
    ``"X"`` may be a single image or a two-element list/tuple of images; a pair
    is always returned as a tuple. All other keys are passed through unchanged.

    Parameters:
        base_ds: The dataset to wrap (must support ``len()`` and integer indexing).
        size (int): Target height and width in pixels.
    """

    def __init__(self, base_ds, size=128):
        self.base = base_ds
        self.size = size

    def __len__(self):
        return len(self.base)

    def _resize_img(self, t):
        """Resize a ``[C, H, W]`` or ``[H, W]`` tensor to ``size`` x ``size``.

        Anything else (non-tensors, tensors of another rank) is returned unchanged.
        """
        if not isinstance(t, torch.Tensor) or t.ndim not in (2, 3):
            return t
        # F.interpolate expects a batched [N, C, H, W] input: prepend the missing
        # axes, resize, then strip the same number of leading axes again.
        n_extra = 4 - t.ndim
        batched = t[(None,) * n_extra]
        resized = F.interpolate(
            batched,
            size=(self.size, self.size),
            mode="bilinear",
            align_corners=False,
        )
        return resized[(0,) * n_extra]

    def __getitem__(self, idx):
        # Work on a shallow copy so the resized images are stored on a new dict
        # and the sample object owned by the wrapped dataset is left untouched.
        item = dict(self.base[idx])
        X = item["X"]
        # IMPA has ("ctrl", "trt") format - resize both
        if isinstance(X, (list, tuple)) and len(X) == 2:
            item["X"] = (self._resize_img(X[0]), self._resize_img(X[1]))
        else:
            item["X"] = self._resize_img(X)
        return item

# Take the datasets and batch sizes from the original loaders and build new
# loaders around the resizing wrapper.
train_base = train_dataloader.dataset
val_base   = val_dataloader.dataset

# Training loader: shuffled, incomplete last batch dropped
wrapped_train_loader = DataLoader(
    ResizeWrapDataset(train_base, size=128),
    batch_size=train_dataloader.batch_size,
    shuffle=True,
    num_workers=getattr(args, "num_workers", 2),
    pin_memory=True,
    drop_last=True,
)

# Validation loader: fixed order, every sample kept
wrapped_val_loader = DataLoader(
    ResizeWrapDataset(val_base, size=128),
    batch_size=val_dataloader.batch_size,
    shuffle=False,
    num_workers=getattr(args, "num_workers", 2),
    pin_memory=True,
    drop_last=False,
)

# Sanity check: fetch one training batch and print the control / treated tensor shapes
b = next(iter(wrapped_train_loader))
xc, xt = b["X"]
print("wrapped ctrl/trt shapes:", tuple(xc.shape), tuple(xt.shape))  # expected: (*, 3, 128, 128)

In [None]:
## Initialize model
# Folder holding the downloaded checkpoints; the sample/ and checkpoint/
# sub-folders are created under it before training.
checkpoint_dir = "/content/IMPA/checkpoints/bbbc021_all/"

In [None]:
##Create model
# Build the IMPA module from the config, the checkpoint folder and the IMPA data loader
solver = IMPAmodule(args, checkpoint_dir, dataloader)
# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
## Avoid shape mismatch (domain 88→24): load the checkpoint except the discriminator head

ckpt_path = "/content/IMPA/checkpoints/bbbc021_all/checkpoint/000200_nets.ckpt"
# Load onto the CPU first. torch.load unpickles the file, so only use
# checkpoints that come from a trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")

def get_substate(ckpt_dict, keys):
    """Find the first sub-dictionary stored under one of ``keys``.

    The top level of ``ckpt_dict`` is searched first, in the order given by
    ``keys``. If nothing matches and ``ckpt_dict["nets"]`` is itself a dict, the
    same search is repeated inside it. Values that are not dicts never match.

    Parameters:
        ckpt_dict (dict): Loaded checkpoint.
        keys: Candidate key names, in order of preference.

    Returns:
        dict or None: The matching sub-dictionary, or ``None`` if there is none.
    """
    top_level = next(
        (ckpt_dict[name] for name in keys if isinstance(ckpt_dict.get(name), dict)),
        None,
    )
    if top_level is not None:
        return top_level

    nested = ckpt_dict.get("nets")
    if not isinstance(nested, dict):
        return None
    return next(
        (nested[name] for name in keys if isinstance(nested.get(name), dict)),
        None,
    )

def strip_module_prefix(sd):
    """Return a copy of ``sd`` in which a leading ``"module."`` is removed from each key.

    Only one leading occurrence is stripped per key; keys without the prefix and
    all values are kept as they are, in the original order.
    """
    prefix = "module."
    cleaned = {}
    for key, value in sd.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned

def safe_load_module(module, state_dict, drop_head=False):
    """Load ``state_dict`` into ``module`` non-strictly and report key mismatches.

    Parameters:
        module: Target network. If it has a ``module`` attribute (a wrapper), the
            wrapped network is loaded instead.
        state_dict (dict): Weights to load; a leading ``"module."`` is stripped
            from its keys first.
        drop_head (bool): When true, entries whose key starts with ``"head."`` or
            contains ``".head."`` are left out, so those weights are not loaded.

    Returns:
        tuple: ``(missing, unexpected)`` key lists as returned by
        ``load_state_dict(strict=False)``. A one-line summary is also printed.
    """
    weights = strip_module_prefix(state_dict)
    if drop_head:
        weights = {
            name: tensor
            for name, tensor in weights.items()
            if not name.startswith("head.") and ".head." not in name
        }

    target = getattr(module, "module", module)
    missing, unexpected = target.load_state_dict(weights, strict=False)
    print(f"[{target.__class__.__name__}] missing={len(missing)}, unexpected={len(unexpected)}")
    return missing, unexpected

# Pull the per-network state dicts out of the checkpoint (alternative spellings
# of the key are tried for the style encoder and the mapping network).
G_sd = get_substate(ckpt, ["generator"])
D_sd = get_substate(ckpt, ["discriminator"])
S_sd = get_substate(ckpt, ["style_encoder","styleencoder"])
M_sd = get_substate(ckpt, ["mapping_network","mappingNetwork"])

# Load whatever was found; networks missing from the checkpoint are silently skipped.
if G_sd is not None: safe_load_module(solver.generator, G_sd)
if D_sd is not None: safe_load_module(solver.discriminator, D_sd, drop_head=True)  # ★ key point: exclude the head
if S_sd is not None: safe_load_module(solver.style_encoder, S_sd)
if M_sd is not None: safe_load_module(solver.mapping_network, M_sd)

In [None]:
# 1) Linear probing: train only head
def freeze_except_D_head(solver):
    """Freeze everything except the discriminator head (linear probing).

    In the discriminator, a parameter stays trainable only if its name starts
    with ``"head"`` (or ``"module.head"``). All parameters of the generator,
    style encoder and mapping network are frozen. Networks that have a
    ``module`` attribute are unwrapped first.
    """
    disc = getattr(solver.discriminator, "module", solver.discriminator)
    for name, param in disc.named_parameters():
        param.requires_grad = name.startswith(("module.head", "head"))

    # Freeze the remaining networks (G, S, M) entirely
    for net in (solver.generator, solver.style_encoder, solver.mapping_network):
        getattr(net, "module", net).requires_grad_(False)

# 2) Fine tuning
def unfreeze_all(solver):
    """Make every parameter of the four networks trainable (full fine-tuning).

    Applies to the discriminator, generator, style encoder and mapping network;
    networks that have a ``module`` attribute are unwrapped first.
    """
    networks = (
        solver.discriminator,
        solver.generator,
        solver.style_encoder,
        solver.mapping_network,
    )
    for net in networks:
        getattr(net, "module", net).requires_grad_(True)

In [None]:
## Training details
# 1) Select mode — "lp" (linear probing) or "ft" (fine-tuning).
#    Any value other than "lp" falls through to full fine-tuning.
mode = "ft"

if mode == "lp":
    freeze_except_D_head(solver)
else:
    unfreeze_all(solver)

# FID tracking: keep only the checkpoint with the lowest "fid_transformations" value
checkpoint_cb = ModelCheckpoint(
    monitor="fid_transformations",  # based on yaml
    mode="min",
    save_top_k=1,                   # only one best model
    filename="{epoch:04d}-fid{fid_transformations:.2f}",
)

# 2) Run training with PL Trainer (use original training_step/optimizers as-is)
# NOTE(review): on a CPU-only runtime `devices` is passed as None — confirm that
# the installed pytorch_lightning version accepts None for this argument.
trainer = Trainer(
    max_epochs=args.total_epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1 if torch.cuda.is_available() else None,
    log_every_n_steps=getattr(args, "log_every_n_steps", 10),
    callbacks=[checkpoint_cb],
)

### Setting directories according to the changed condition (wrapper)
# 1) Prepare sample & checkpoint folders (no error if they already exist)
os.makedirs(os.path.join(checkpoint_dir, args.sample_dir), exist_ok=True)      # /content/IMPA/checkpoints/bbbc021_all/sample
os.makedirs(os.path.join(checkpoint_dir, args.checkpoint_dir), exist_ok=True)  # /content/IMPA/checkpoints/bbbc021_all/checkpoint

# 2) Replace validation loader used by solver for internal evaluation/visualization with 128x128 wrapping loader
solver.loader_test = wrapped_val_loader

In [None]:
# Patch FID sqrtm handling (to solve FID instability issue)
def _sqrtm_robust(A):
    out = linalg.sqrtm(A, disp=False)
    if isinstance(out, tuple):
        S, _ = out
    else:
        S = out
    S = np.asarray(S)
    # On numerical instability, add small jitter and recompute
    if not np.isfinite(S).all():
        jitter = np.eye(A.shape[0]) * 1e-6
        out2 = linalg.sqrtm((A + jitter).astype(np.float64), disp=False)
        S = out2[0] if isinstance(out2, tuple) else out2
    return S.real

def cal_frechet_distance_patched(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between two Gaussians, with a stabilised matrix square root.

    Computes ``|mu1 - mu2|^2 + tr(sigma1) + tr(sigma2) - 2 * tr(covmean)``, where
    ``covmean`` is the square root of ``(sigma1 + eps*I) @ (sigma2 + eps*I)``.
    The ``eps`` term is used only inside the square root; the two trace terms
    use the covariances as given.

    Parameters:
        mu1, mu2: Mean vectors (promoted to at least 1-D).
        sigma1, sigma2: Covariance matrices (promoted to at least 2-D).
        eps (float): Value added to the diagonals before the square root.

    Returns:
        float: Real part of the distance.
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    # Small diagonal term that stabilises the square root of the covariance product
    ridge = np.eye(cov_a.shape[0]) * eps
    covmean = _sqrtm_robust(np.dot(cov_a + ridge, cov_b + ridge))

    delta = mean_a - mean_b
    distance = delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2.0 * np.trace(covmean)
    return float(np.real(distance))

# Apply patch: rebind the function on the imported IMPA FID module.
# NOTE(review): this only affects callers that look the function up on the
# module at call time; code that imported it by name beforehand keeps the
# original — confirm how IMPA calls it.
fid_mod.cal_frechet_distance = cal_frechet_distance_patched

In [None]:
## Training
# Not a LightningDataModule → pass train/val dataloaders directly to fit
# Use wrapped_*_loader here (the loaders that resize images to 128x128)
trainer.fit(
    solver,
    train_dataloaders=wrapped_train_loader,
    val_dataloaders=wrapped_val_loader
)

# Performance metrics

In [None]:
# Print the best (lowest) monitored FID and the path of the checkpoint that achieved it.
# NOTE(review): presumably best_model_score is None when the monitored metric was
# never logged, in which case float() fails — confirm before relying on this cell.
best_fid = float(checkpoint_cb.best_model_score)
best_ckpt = checkpoint_cb.best_model_path
print("Best FID:", best_fid, "| path:", best_ckpt)

In [None]:
# Function: load checkpoint with the best FID
#           and calculate coverage/precision/recall

class InceptionPool3(nn.Module):
    """Frozen torchvision Inception-v3 trunk that maps images to flat feature rows.

    The trunk consists of every child module of ``inception_v3`` except the last
    one. Gradients are disabled and the module is put in eval mode.

    NOTE(review): the network is created with ``weights=None``, i.e. no
    pretrained weights are loaded, so the features come from a randomly
    initialised network and change from one instantiation to the next — confirm
    whether pretrained weights were intended for the P/R/D/C metrics.
    """

    def __init__(self):
        super().__init__()
        backbone = inception_v3(weights=None, aux_logits=False, transform_input=False)
        trunk = list(backbone.children())[:-1]  # drop only the final child module
        self.features = nn.Sequential(*trunk)
        self.features.requires_grad_(False)
        self.eval()

    @torch.no_grad()
    def forward(self, x):
        """Return the features of image batch ``x`` as a NumPy array, one row per image.

        Inputs are resized to 299x299 when needed and clamped to ``[0, 1]``.
        NOTE(review): if the images are in ``[-1, 1]``, the clamp discards the
        negative half; rescale with ``(x + 1) / 2`` before calling — confirm the
        value range produced by the data loader and the generator.
        """
        if tuple(x.shape[-2:]) != (299, 299):
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
        x = torch.clamp(x, 0, 1)
        flat = self.features(x).view(x.size(0), -1)
        return flat.cpu().numpy()

@torch.no_grad()
def compute_coverage_prec_recall(solver, val_loader, device, k=None, max_batches=None):
    """Compare real treated images with generated ones via ``compute_d_c``.

    For every batch, the control images are transformed towards the batch's
    target perturbations (``batch["mols"]``). ``InceptionPool3`` features of the
    real treated images and of the generated images are collected, and both
    feature sets are passed to ``compute_d_c``.

    Parameters:
        solver: Object exposing ``encode_label`` and ``nets.generator``.
        val_loader: Iterable of batches with ``"X"`` (a control/treated image
            pair) and ``"mols"`` (target labels).
        device: Device the images and labels are moved to.
        k (int, optional): ``nearest_k`` for ``compute_d_c``. When ``None``,
            ``max(3, int(sqrt(N)))`` is used, with ``N`` the number of samples.
        max_batches (int, optional): Stop after this many batches; ``None``
            means the whole loader is used.

    Returns:
        Whatever ``compute_d_c`` returns for the two feature arrays.

    Notes:
        Images reach the feature extractor as they come from the loader and the
        generator. NOTE(review): if they are in ``[-1, 1]``, convert them with
        ``torch.clamp((x + 1) / 2, 0, 1)`` first — confirm the value range.
    """
    extractor = InceptionPool3().to(device)
    real_chunks, fake_chunks = [], []

    for n_seen, batch in enumerate(val_loader, start=1):
        controls, treated = batch["X"]
        targets = batch["mols"].long().to(device)
        controls, treated = controls.to(device), treated.to(device)

        # Create target style (assuming single modality)
        style = solver.encode_label(controls, targets, None, None)
        _, generated = solver.nets.generator(controls, style)

        real_chunks.append(extractor(treated))
        fake_chunks.append(extractor(generated))

        if max_batches is not None and n_seen >= max_batches:
            break

    real_feats = np.concatenate(real_chunks, axis=0)
    fake_feats = np.concatenate(fake_chunks, axis=0)
    if k is None:
        n_samples = min(len(real_feats), len(fake_feats))
        k = max(3, int(np.sqrt(n_samples)))

    return compute_d_c(real_feats, fake_feats, nearest_k=k)


In [None]:
## Compute P/R/D/C (precision / recall / density / coverage)

print(f"Best FID: {best_fid:.4f} | path: {best_ckpt}")

# 1) Parse epoch number from Lightning ckpt → IMPA step = epoch + 1
m = re.search(r"epoch=(\d+)", os.path.basename(best_ckpt))
assert m, f"Cannot parse epoch from path: {best_ckpt}"
epoch_num = int(m.group(1))
impa_step = epoch_num + 1  # IMPA saves at end of epoch with step = epoch + 1

# 2) Require the matching IMPA checkpoint file to exist (the cell stops here otherwise)
impa_nets = os.path.join(checkpoint_dir, args.checkpoint_dir, f"{impa_step:06d}_nets.ckpt")
assert os.path.exists(impa_nets), f"IMPA ckpt not found: {impa_nets}"

# 3) Load the best checkpoint (IMPA format)
solver._load_checkpoint(impa_step)

# 4) Device alignment: move the solver and each of its networks (unwrapped if
#    they have a .module attribute) to the selected device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
solver.to(device)
for net in solver.nets.values():
    (net.module if hasattr(net, "module") else net).to(device)

# The embedding matrix may be a single object or a list/tuple of them; move it as well
em = solver.embedding_matrix
if isinstance(em, (list, tuple)):
    solver.embedding_matrix = type(em)([e.to(device) for e in em])
else:
    solver.embedding_matrix = em.to(device)

# 5) Compute D/C/P/R
dc = compute_coverage_prec_recall(
    solver,
    wrapped_val_loader,   # Recommend using 128x128 loader
    device,
    k=None,
    max_batches=None
)

print(
    f"[Best-FID (Lightning epoch={epoch_num}) -> IMPA step={impa_step:06d}] "
    f"precision: {dc['precision']:.4f} | "
    f"recall: {dc['recall']:.4f} | "
    f"density: {dc['density']:.4f} | "
    f"coverage: {dc['coverage']:.4f}"
)


# Application

In [None]:
#### Transform controls to perturbed
# Initialize the containers: a list for the control images and a dict mapping
# each drug name to its list of generated batches
controls = []
transformed = {}

# NOTE(review): this loop reads dataloader.train_dataloader(), i.e. not the
# loaders that resize images to 128x128 which were used for training — confirm
# that this is the intended input for the fine-tuned model.
with torch.no_grad():
    for i, (drug, drug_id) in enumerate(dataloader.mol2id.items()):
        print(f"Transforming images for {drug}")
        transformed[drug] = []
        for j, batch in tqdm(enumerate(dataloader.train_dataloader())):
            X_ctr = batch["X"][0]
            # Noise vector: the mean of 100 standard-normal draws per image
            z = torch.randn(X_ctr.shape[0], 100, args.z_dimension).cuda().mean(1)

            # Perturbation ID repeated for every image in the batch
            # (dataloader.mol2id[drug] is the same value as the unused loop variable drug_id)
            id_pert = dataloader.mol2id[drug] * torch.ones(X_ctr.shape[0]).long().cuda()
            y = solver.embedding_matrix(id_pert)
            y = torch.cat([y, z], dim=1)
            y = solver.nets.mapping_network(y)

            _, X_generated = solver.nets.generator(X_ctr, y)

            # Store the controls only while processing the first drug
            if i==0:
                controls.append(t2np(X_ctr.detach().cpu(), batch_dim=True))
            transformed[drug].append(t2np(X_generated.detach().cpu(), batch_dim=True))
            # Stop after four batches per drug
            if j==3:
                break

# Merge the per-batch arrays into one array for the controls and one per drug
controls = np.concatenate(controls, axis=0)
transformed = {key: np.concatenate(val, axis=0) for key, val in transformed.items()}

In [None]:
# Show at most the first four control images, each as a small figure without axes
for i in range(min(len(controls), 4)):
    print(f"Control {i}")
    plt.figure(figsize=(1, 1))
    plt.imshow(controls[i])
    plt.axis("off")
    plt.show()

In [None]:
# For every perturbation, show at most the first four generated images
for pert in transformed:
    print(f"Perturbation {pert}")
    for i in range(min(len(transformed[pert]), 4)):
        plt.figure(figsize=(1, 1))
        plt.imshow(transformed[pert][i])
        plt.axis("off")
        plt.show()

In [None]:
#### Predict on unseen perturbations
# Drug names to predict; they are used as row labels when the embeddings are
# selected from the embedding CSV.
ood_drugs = ["taxol",
             "ALLN",
             "bryostatin",
             "simvastatin",
             "MG-132",
             "methotrexate",
             "colchicine",
             "cytochalasin B",
             "AZ258",
             "cisplatin"]

In [None]:
# Read the embedding table (first column is the index) and keep only the rows
# of the selected drugs; a name missing from the index raises a KeyError.
ood_drug_embeddings = pd.read_csv("/content/IMPA/embeddings/csv/emb_fp_all.csv", index_col=0).loc[ood_drugs]

In [None]:
# Generated images and the corresponding controls, keyed by drug name.
# Note: this cell rebinds the names `controls` and `transformed`, which hold
# other results from the seen-drug cells.
drugs = {}
controls = {}

for drug in ood_drugs:
    print(f"Transform into {drug}")
    # Embedding row of this drug as a torch tensor
    emb_drug = torch.Tensor(ood_drug_embeddings.loc[drug])
    control, transformed = transform_by_emb(solver, dataloader, emb_drug, 100, args)
    drugs[drug] = transformed
    controls[drug] = control

In [None]:
# For every unseen drug, show at most the first four generated images
for pert in controls:
    print(f"Perturbation {pert}")
    for i in range(min(len(drugs[pert]), 4)):
        plt.figure(figsize=(1, 1))
        plt.imshow(drugs[pert][i])
        plt.axis("off")
        plt.show()

# References
OpenAI. 2025. GPT-5, ChatGPT model. OpenAI, San Francisco, CA. Available at: https://chat.openai.com/ (accessed August-September 2025).

Palma A, Theis FJ, Lotfollahi M. 2025. Predicting cell morphological responses to perturbations using generative modeling. Nat Commun 16:505. DOI: 10.1038/s41467-024-55707-8.

Palma A, Theis FJ, Lotfollahi M. 2025. IMPA: Predicting cell morphological responses to perturbations (GitHub repository). GitHub. Available at: https://github.com/theislab/impa (accessed August-September 2025).
