# In [None]:  (cell prompt left over from the notebook export)
# superres_colab.py
# Colab-ready full script for ZoomLDM Super-Resolution pipeline

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from datasets import load_dataset
from omegaconf import OmegaConf
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from tqdm.auto import tqdm

from ldm.util import instantiate_from_config
from utils import collate_fn
from large_image_gen.resizer import Resizer
from large_image_gen.utils import model_pred, decode_large_image
from large_image_gen.postprocess import postprocess_image

# -------------------------
# Setup device
# -------------------------
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Only select a CUDA device when CUDA is actually usable; calling
# torch.cuda.set_device unconditionally raises on CPU-only hosts and would
# defeat the CPU fallback above.
if device.type == "cuda":
    torch.cuda.set_device(device)

# -------------------------
# Parameters
# -------------------------
# Magnification names listed from highest to lowest; each name's position in
# this tuple is the integer id stored for it in MAG_DICT.
_MAG_NAMES = ("20x", "10x", "5x", "2_5x", "1_25x")
MAG_DICT = {mag_name: mag_id for mag_id, mag_name in enumerate(_MAG_NAMES)}

# Magnification of the low-resolution input; must be a key of MAG_DICT.
MAG = "5x"

# -------------------------
# Load dataset
# -------------------------
# Pull the demo split for the chosen magnification.
# NOTE: trust_remote_code=True executes the dataset's loading script from the
# Hub; only use it with repositories you trust.
ds = load_dataset(
    "StonyBrook-CVLab/ZoomLDM-demo-dataset",
    name=MAG,
    split="train",
    trust_remote_code=True,
)

# A single randomly drawn sample (shuffle=True, batch_size=1) is the
# low-resolution input for the rest of the script.
dl = DataLoader(ds, collate_fn=collate_fn, batch_size=1, shuffle=True)
batch = next(iter(dl))
lr_image = batch["image"]

# Quick sanity print of what the collated batch contains.
image_shape = lr_image.shape
print("Images:", image_shape)
batch_mags = batch["mag"]
print("Magnifications:", batch_mags)
ssl_feat_shape = batch["ssl_feat"].shape
print("SSL Features:", ssl_feat_shape)

# -------------------------
# Load model & config
# -------------------------
# Download (or reuse the cached copy of) the BRCA checkpoint and its config.
ckpt_path = hf_hub_download(repo_id="StonyBrook-CVLab/ZoomLDM", filename="brca/weights.ckpt")
# Load the weights onto the CPU: load_state_dict() below copies them into the
# model's own (device-resident) parameters, so mapping the checkpoint straight
# to the GPU would keep a second full copy of the weights in GPU memory for as
# long as `state_dict` is alive.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. If the file is a plain tensor state dict, passing
# weights_only=True would be safer - confirm against the checkpoint format.
state_dict = torch.load(ckpt_path, map_location="cpu")

config_path = hf_hub_download(repo_id="StonyBrook-CVLab/ZoomLDM", filename="brca/config.yaml")
config = OmegaConf.load(config_path)

# Build the model from the config, move it to the target device in eval mode,
# then populate it with the checkpoint weights.
model = instantiate_from_config(config.model)
model = model.to(device).eval()
model.load_state_dict(state_dict)
model.cond_stage_model.p_uncond = 0  # disable null token

# -------------------------
# Display low-res image
# -------------------------
# Preview the sampled low-resolution image on its own 5x5-inch figure.
fig_preview, ax_preview = plt.subplots(figsize=(5, 5))
ax_preview.imshow(lr_image[0])
ax_preview.set_title("Low-res image")
ax_preview.axis("off")
plt.show()

# -------------------------
# Encode low-res image
# -------------------------
# Move the last axis to position 1 (presumably NHWC -> NCHW; confirm against
# collate_fn) and map pixel values from the 0..255 range onto [-1, 1].
lr_channels_first = lr_image.permute([0, 3, 1, 2])
lr_scaled = (2 * (lr_channels_first / 255. - 0.5)).float().to(device)

# Run the first-stage encoder and convert its output to the latent encoding
# the diffusion model works on.
lr_encoded = model.encode_first_stage(lr_scaled)
img_latent = model.get_first_stage_encoding(lr_encoded)
print("Latent:", img_latent.shape)

# -------------------------
# Set conditioning parameters
# -------------------------
# Channel dimension of the conditioning embedding that is being learned.
cond_dim = 1024

# Side length of the embedding grid for each supported *input* magnification.
# "20x" is deliberately absent: it is the target magnification, so there is
# nothing to super-resolve (the original if/elif chain silently left
# emb_h/emb_w undefined for it and failed later with a NameError).
_EMB_GRID_SIDE = {
    "10x": 2,
    "5x": 4,
    "2_5x": 8,
    "1_25x": 16,
}
if MAG not in _EMB_GRID_SIDE:
    raise ValueError(
        f"Unsupported MAG {MAG!r} for super-resolution; "
        f"expected one of {sorted(_EMB_GRID_SIDE)}"
    )

# Integer magnification id as a 1-element long tensor on the target device.
mag_cond = torch.tensor([MAG_DICT[MAG]]).long().view(-1).to(device)
emb_h = emb_w = _EMB_GRID_SIDE[MAG]

# Randomly initialised embedding grid, optimised directly with Adam in the
# inversion loop that follows.
cond_feat_learned = torch.randn((cond_dim, emb_h, emb_w), device=device, requires_grad=True)
opt = torch.optim.Adam([cond_feat_learned], lr=1e-1)

# -------------------------
# Learn embeddings via inversion
# -------------------------
# Optimise `cond_feat_learned` so that, when used as conditioning, the
# denoiser predicts the noise that was added to the encoded low-res latent.
num_steps = 200
# One optimisation step per timestep, sweeping from 950 down to 50.
timesteps = np.linspace(950, 50, num_steps)
# Mixed precision is only enabled on CUDA. torch.autocast replaces the
# deprecated, CUDA-only torch.cuda.amp.autocast and keeps the CPU fallback
# free of warnings.
use_amp = device.type == "cuda"
for t in tqdm(timesteps):
    t_int = int(t)
    # Forward-diffuse the clean latent to timestep t with fresh noise.
    at = model.alphas_cumprod[t_int-1].view(1,1,1,1)
    eps = torch.randn_like(img_latent)
    xt = torch.sqrt(at) * img_latent + torch.sqrt(1-at) * eps

    t_cond = torch.tensor([t_int]).float().to(device).view(1)
    with torch.autocast(device_type=device.type, enabled=use_amp):
        # Standardise each spatial position across the channel axis before
        # handing the embedding to the conditioning model.
        ssl_feat_norm = (cond_feat_learned - cond_feat_learned.mean(0, keepdim=True)) / cond_feat_learned.std(0, keepdim=True)
        cond_inp = dict(ssl_feat=[ssl_feat_norm], mag=mag_cond)
        cond = model.cond_stage_model(cond_inp)
        pred_eps = model.model.diffusion_model(xt, t_cond, cond)

    # Timestep-dependent weight sqrt(at) * (1 - at) on the noise-prediction error.
    w = at ** 0.5 * (1 - at)
    loss = (w*(pred_eps - eps)**2).mean()

    # Regulariser: cosine similarity between every pair of spatial positions
    # of the embedding grid. The negative sign rewards similar embeddings.
    embs = cond_feat_learned
    embs = embs / torch.norm(embs, dim=0, keepdim=True)
    embs = embs.view(cond_dim,-1)
    sim = (embs.view(cond_dim,-1,1) * embs.view(cond_dim,1,-1)).sum(0)
    loss = loss - 0.05 * sim.mean()

    opt.zero_grad()
    loss.backward()
    opt.step()

# -------------------------
# Patch generation (first stage)
# -------------------------
# Sampling hyper-parameters: start timestep, timestep stride, guidance weight
# passed to model_pred, and number of patches denoised per forward pass.
t0 = 1000
stride = 50
guidance = 2
batch_size = 2

# Standardise the learned embedding grid across channels and split it into one
# (d, 1, 1) embedding per 20x patch (emb_h * emb_w patches in total).
cond_feat_norm = (cond_feat_learned - cond_feat_learned.mean(0, keepdim=True)) / cond_feat_learned.std(0, keepdim=True)
ssl_feat_20x = rearrange(cond_feat_norm.detach(), 'd (p1 h) (p2 w) -> (p1 p2) d h w', p1=emb_h, p2=emb_w)
cond_dict_20x = dict(
    ssl_feat=[ssl_feat_20x[i] for i in range(ssl_feat_20x.shape[0])],
    mag=torch.tensor([MAG_DICT["20x"]]).long().tile(ssl_feat_20x.shape[0]).to(device)
)
cond_20x_all = model.get_learned_conditioning(cond_dict_20x)

# One noise latent per patch, plus the resizers that map between a decoded
# 256x256 patch and the matching (256 // emb_h)-pixel tile of the low-res image.
xt_20x_all = torch.randn((emb_h*emb_w,3,64,64), device=device)
down_operator = Resizer((1,3,256,256), scale_factor=1/emb_h).to(device)
up_operator = Resizer((1,3,256//emb_h,256//emb_w), scale_factor=emb_h).to(device)
img_guide = 2*(lr_image.permute([0,3,1,2])/255. - 0.5).float().to(device)
# The low-res image cut into one guide tile per patch. This does not depend on
# the timestep or the batch, so it is computed once instead of per iteration.
img_patches_guide_all = rearrange(img_guide, 'b c (p1 h) (p2 w) -> (b p1 p2) c h w', p1=emb_h, p2=emb_w)

n_steps = 5
lr_patch = 0.5
for t in tqdm(range(t0, 0, -stride)):
    # Noise-schedule terms for the current and the next (t - stride) timestep.
    atbar = model.alphas_cumprod[t-1].view(1,1,1,1).to(device)
    atbar_prev = model.alphas_cumprod[max(t-1-stride,0)].view(1,1,1,1).to(device)
    beta_tilde = (model.betas[t-1] * (1 - atbar_prev) / (1 - atbar)).view(1,1,1,1).to(device)

    x0_pred_20x_all_list = []
    eps_20x_all_list = []

    for idx_20x in range(0, xt_20x_all.shape[0], batch_size):
        batch_slice = slice(idx_20x, idx_20x + batch_size)
        xt_20x = xt_20x_all[batch_slice]
        cond_20x = cond_20x_all[batch_slice]
        img_patches_guide = img_patches_guide_all[batch_slice]

        # Predict the noise and the corresponding clean latent estimate.
        epsilon_20x = model_pred(model, xt_20x, t, cond_20x, w=guidance)
        x0_pred_20x = xt_20x / torch.sqrt(atbar) - epsilon_20x * torch.sqrt((1-atbar)/atbar)

        for k in range(n_steps):
            with torch.no_grad():
                img_pred = model.differentiable_decode_first_stage(x0_pred_20x)
            # Compare the decoded patch with its low-res guide tile at the
            # guide's resolution. The decoded patch is downsampled with
            # down_operator (built for 256x256 inputs); the original applied
            # up_operator, which is built for (256 // emb_h)-pixel inputs and
            # cannot produce a tensor shaped like the guide tile.
            error_dir = down_operator(img_pred) - img_patches_guide
            # Scale each sample by its own max-abs value. flatten(1) keeps the
            # real batch dimension, so a short final batch no longer breaks
            # the reshape (the original hard-coded `batch_size` here).
            error_dir = error_dir / (error_dir.abs().flatten(1).max(1)[0].view(-1,1,1,1) + 1e-6)
            # NOTE(review): the pixel-space residual is added directly to the
            # latent, which only matches shapes when 256 // emb_h == 64 (MAG
            # "5x"); the sign (+ residual) also looks like it moves away from
            # the guide. Both are kept as in the original - confirm against
            # the reference ZoomLDM super-resolution notebook.
            xt_20x = xt_20x + lr_patch*error_dir
            epsilon_20x = model_pred(model, xt_20x, t, cond_20x, w=guidance)
            x0_pred_20x = xt_20x / torch.sqrt(atbar) - epsilon_20x * torch.sqrt((1-atbar)/atbar)

        x0_pred_20x_all_list.append(x0_pred_20x)
        eps_20x_all_list.append(epsilon_20x)

    # Step every patch to the next timestep from its clean estimate and
    # predicted noise, plus fresh noise scaled by sqrt(beta_tilde).
    x0_pred_20x_all = torch.cat(x0_pred_20x_all_list, dim=0)
    eps_20x_all = torch.cat(eps_20x_all_list, dim=0)
    xt_20x_all = torch.sqrt(atbar_prev)*x0_pred_20x_all + torch.sqrt(1-atbar_prev)*eps_20x_all + torch.sqrt(beta_tilde)*torch.randn_like(xt_20x_all)

# -------------------------
# Decode patches
# -------------------------
# Decode the final latents `batch_size` patches at a time and stitch the
# decoded batches back into one tensor along the patch axis.
decoded_chunks = [
    model.decode_first_stage(latent_chunk)
    for latent_chunk in xt_20x_all.split(batch_size)
]
images_20x = torch.cat(decoded_chunks, dim=0)

# Display the low-res image (left column, spanning all rows) next to the
# emb_h x emb_w grid of generated high-res patches.
fig = plt.figure(figsize=(15,10))
# The grid follows emb_h/emb_w instead of a hard-coded 4x4 layout, so every
# magnification is laid out correctly (the fixed 4x5 GridSpec mis-placed the
# 2x2 case and raised IndexError for grids larger than 4x4). The left column
# keeps half the total width of the patch columns, which is exactly the
# original [2,1,1,1,1] ratio when emb_w == 4.
gs = GridSpec(
    nrows=emb_h,
    ncols=emb_w + 1,
    width_ratios=[emb_w / 2] + [1] * emb_w,
    wspace=0.1,
    hspace=0.1,
)
ax_lr = fig.add_subplot(gs[:,0])
ax_lr.imshow(lr_image[0])
ax_lr.axis("off")
ax_lr.set_title("Low-res image")

for idx in range(emb_h*emb_w):
    # Patches are stored row-major, with emb_w patches per row.
    row, col = divmod(idx, emb_w)
    ax = fig.add_subplot(gs[row,col+1])
    # Map the decoded patch from [-1, 1] to [0, 1] and put channels last for imshow.
    ax.imshow((0.5*(images_20x[idx]+1)).clamp(0,1).cpu().numpy().transpose([1,2,0]))
    ax.axis("off")
plt.tight_layout()
plt.show()

# -------------------------
# Reconstruct large image
# -------------------------
# Tile the per-patch latents back into one large latent (emb_h x emb_w patch
# grid) and decode it as a single image.
xt_20x = rearrange(xt_20x_all, '(b p1 p2) c h w -> b c (p1 h) (p2 w)', p1=emb_h, p2=emb_w)
image_20x = decode_large_image(xt_20x, model)

plt.figure(figsize=(10,10))
plt.imshow(image_20x[0])
plt.title("20x image")
plt.axis("off")
plt.show()

# Side-by-side comparison with upscaled low-res image
# The low-res image is resized to the generated image's own size rather than a
# hard-coded 1024x1024, which only matched the output for MAG "5x".
# image_20x[0] is indexed as (height, width, ...); PIL's resize takes (width, height).
hr_height, hr_width = image_20x[0].shape[:2]
fig, ax = plt.subplots(1, 2, figsize=(20,10))
ax[0].imshow(Image.fromarray(lr_image[0].numpy()).resize((hr_width, hr_height)))
ax[0].set_title(f'Lower magnification image at {MAG}')
ax[0].axis("off")
ax[1].imshow(Image.fromarray(image_20x[0]))
ax[1].set_title('20x image')
ax[1].axis("off")
plt.show()

# -------------------------
# Optional postprocessing
# -------------------------
# Postprocessing hyper-parameters: start timestep, timestep stride and
# guidance weight handed to postprocess_image.
t0_post = 300
stride_post = 50
guidance_post = 2

# Inferred embeddings. Pass the channel-standardised, detached features: every
# other place in this script feeds the standardised embedding to the
# conditioning model, whereas the raw `cond_feat_learned` is unnormalised and
# still has requires_grad=True. Standardising is idempotent, so this stays
# correct even if postprocess_image normalises internally.
# NOTE(review): confirm the expected input against postprocess_image.
ssl_feat = cond_feat_norm.detach()

xt_20x_all_postprocessed = postprocess_image(
    model, xt_20x_all, ssl_feat,
    t0_post, stride_post, guidance_post,
    sliding_window_size=16, emb_h=emb_h, emb_w=emb_w, batch_size=16
)

# Tile the postprocessed patch latents into one large latent and decode it.
xt_20x_postprocessed = rearrange(xt_20x_all_postprocessed, '(b p1 p2) c h w -> b c (p1 h) (p2 w)', p1=emb_h, p2=emb_w)
image_20x_postprocessed = decode_large_image(xt_20x_postprocessed, model)

# Side-by-side comparison after postprocessing
# Resize the low-res image to the generated image's own size instead of a
# hard-coded 1024x1024 (which only matched the output for MAG "5x").
# The image is indexed as (height, width, ...); PIL's resize takes (width, height).
post_height, post_width = image_20x_postprocessed[0].shape[:2]
fig, ax = plt.subplots(1, 2, figsize=(20,10))
ax[0].imshow(Image.fromarray(lr_image[0].numpy()).resize((post_width, post_height)))
ax[0].set_title(f'Lower magnification image at {MAG}')
ax[0].axis("off")
ax[1].imshow(Image.fromarray(image_20x_postprocessed[0]))
ax[1].set_title('20x image - postprocessed')
ax[1].axis("off")
plt.show()

