# In [None]:
import torch
import numpy as np
from datasets import load_dataset
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader
from einops import rearrange
import matplotlib.pyplot as plt

from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from utils import collate_fn

# -----------------------------
# 1. CONFIG
# -----------------------------
MAG = "3x"  # Options: 1x, 2x, 3x, 4x

# Maps each supported magnification name to an integer index.
# NOTE(review): presumably mirrors how batch["mag"] is encoded -- confirm
# against the dataset / collate_fn.
MAG_DICT = {
    "1x": 0,
    "2x": 1,
    "3x": 2,
    "4x": 3,
}

# Fail fast with a readable message instead of a dataset-config lookup error.
if MAG not in MAG_DICT:
    raise ValueError(
        f"Unknown magnification {MAG!r}; expected one of {sorted(MAG_DICT)}"
    )

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# -----------------------------
# 2. LOAD DATASET
# -----------------------------
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to run locally; only keep it for repositories you trust.
ds = load_dataset(
    "StonyBrook-CVLab/ZoomLDM-demo-dataset-NAIP",
    name=MAG,
    trust_remote_code=True,
    split="train",
)

# Only the first batch of four examples is used below.
dl = DataLoader(ds, batch_size=4, collate_fn=collate_fn)
batch = next(iter(dl))

print("Images:", batch["image"].shape)
print("Magnifications:", batch["mag"])
print("SSL Features:", batch["ssl_feat"].shape)

# -----------------------------
# 3. DOWNLOAD & LOAD MODEL
# -----------------------------
# Fetch the checkpoint from the Hugging Face Hub and deserialize it on the CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only acceptable because the checkpoint comes from a trusted repository.
ckpt_path = hf_hub_download(repo_id="StonyBrook-CVLab/ZoomLDM", filename="naip/weights.ckpt")
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# Fetch the matching model configuration.
config_path = hf_hub_download(repo_id="StonyBrook-CVLab/ZoomLDM", filename="naip/config.yaml")
config = OmegaConf.load(config_path)

# Build the model from the config, place it on the target device in eval mode,
# then populate it with the downloaded weights.
model = instantiate_from_config(config.model)
model = model.to(device)
model = model.eval()
model.load_state_dict(state_dict)

# disable null token
# NOTE(review): p_uncond is presumably the probability of replacing the
# conditioning with the null token -- confirm in the conditioning model.
model.cond_stage_model.p_uncond = 0

sampler = PLMSSampler(model)

# -----------------------------
# 4. SAMPLING PARAMS
# -----------------------------
ddim_steps = 50  # number of sampler steps (passed as S; the sampler here is PLMS)
cfg_scale = 2  # passed as unconditional_guidance_scale
shape = [3, 64, 64]  # per-sample shape handed to the sampler

# -----------------------------
# 5. INFERENCE
# -----------------------------
# Autocast follows the selected device rather than assuming CUDA: float16
# mixed precision is enabled on GPU and explicitly disabled on the CPU
# fallback, which then runs in full precision.
use_amp = device.type == "cuda"

with torch.no_grad(), model.ema_scope(), torch.autocast(
    device_type=device.type, dtype=torch.float16, enabled=use_amp
):
    # Move the conditioning inputs to the same device as the model.
    for k in ["ssl_feat", "mag"]:
        batch[k] = batch[k].to(device)

    # Conditioning from the batch, plus an all-zero tensor of the same shape
    # used as the unconditional branch for guidance.
    cc = model.get_learned_conditioning(batch)
    uc = torch.zeros_like(cc)

    samples_ddim, _ = sampler.sample(
        S=ddim_steps,
        conditioning=cc,
        batch_size=len(batch["ssl_feat"]),
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=cfg_scale,
        unconditional_conditioning=uc,
    )

    # Decode the samples, map from [-1, 1] to [0, 1], then scale to uint8.
    x_samples_ddim = model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    # Cast to float32 first so the 0-255 scaling is not done in half precision
    # when autocast produced a float16 tensor.
    x_samples_ddim = (255 * x_samples_ddim.float().cpu().numpy()).astype(np.uint8)

# -----------------------------
# 6. VISUALIZE REAL vs SYNTHETIC
# -----------------------------
# Tile each batch horizontally into one wide image: the real images are
# rearranged from channels-last, the generated ones from channels-first.
real = rearrange(batch["image"], "n h w c -> h (n w) c")
syn = rearrange(x_samples_ddim, "n c h w -> h (n w) c")

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: real strip; right panel: synthetic strip.
for ax, strip, title in zip(axs, (real, syn), ("Real", "Synthetic")):
    ax.imshow(strip)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()
