# VAR HF 推理与后验采样 Demo

本 Notebook 展示了如何：

1. 从 Hugging Face (`FoundationVision/var`) 下载官方发布的 VAE 与 VAR 权重；
2. 构建 `var_inv` 中自带的 VAR/VQVAE 模型，并使用 `VAR.autoregressive_infer_cfg` 进行标准 CFG 采样；
3. 借助 `GradientGuidedVARSampler` + 测量算子，将 VAR 视作“离散扩散”，对部分观测数据执行后验重建，并可视化每个尺度的重建过程。

> ⚠️ 运行前请确保你拥有 Hugging Face 的下载权限（可能需要 `huggingface-cli login`）。

## 1. 准备依赖与工具函数

以下代码导入 `var_inv` 中封装好的模块，并提供若干辅助函数用于读写图片、可视化采样结果。

In [None]:
import math
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from var_inv.measurements import MaskingMeasurement
from var_inv.posterior_sampling import (
    GradientGuidedVARSampler,
    MeasurementModel,
    PosteriorGuidanceConfig,
)
from var_inv.var_models.models import build_vae_var


def load_rgb(path: Path, target_hw: int) -> torch.Tensor:
    """Read an image file into a float32 tensor of shape (1, 3, H, W) in [0, 1].

    Args:
        path: Location of the image; it is converted to RGB after opening.
        target_hw: Side length of the square the image is resized to with
            bicubic resampling. Despite the ``int`` annotation, ``None`` is
            accepted and skips the resize.

    Returns:
        The pixel data scaled by 1/255, rearranged from HWC to CHW, with a
        leading batch dimension of size 1.
    """
    picture = Image.open(path).convert("RGB")
    if target_hw is not None:
        square = (target_hw, target_hw)
        picture = picture.resize(square, Image.BICUBIC)
    pixels_hwc = np.asarray(picture, dtype=np.float32) / 255.0
    pixels_chw = torch.from_numpy(pixels_hwc).permute(2, 0, 1)
    return pixels_chw.unsqueeze(0)


def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """Convert a CHW (or batched NCHW) tensor into a PIL image.

    The tensor is detached, moved to the CPU and clamped to [0, 1]. For a
    4-D input only the first batch element is used. Values are scaled by 255
    and cast to ``uint8`` (the cast drops the fractional part).
    """
    frame = t.detach().cpu().clamp(0, 1)
    frame = frame[0] if frame.ndim == 4 else frame
    # PIL expects channels last.
    hwc = frame.permute(1, 2, 0).numpy()
    return Image.fromarray((hwc * 255).astype(np.uint8))


def show_tensor(t: torch.Tensor, title: str = "") -> None:
    """Render a tensor as an image in its own 4x4-inch figure.

    Args:
        t: Tensor handed to ``tensor_to_pil`` for conversion.
        title: Optional caption; an empty string leaves the plot untitled.
    """
    # Convert first so a bad tensor fails before any figure is created.
    picture = tensor_to_pil(t)
    _, axes = plt.subplots(figsize=(4, 4))
    axes.imshow(picture)
    axes.axis("off")
    if title:
        axes.set_title(title)
    plt.show()


def visualize_stages(stage_imgs: List[torch.Tensor], ncols: int = 5) -> None:
    """Plot a sequence of stage images on a grid, one panel per stage.

    Args:
        stage_imgs: Tensors to display, in order; each is converted with
            ``tensor_to_pil``. Nothing is drawn when the sequence is empty.
        ncols: Number of grid columns; the row count is derived from it.
    """
    if not stage_imgs:
        return

    nrows = math.ceil(len(stage_imgs) / ncols)
    # Each panel gets a 3x3-inch cell.
    fig = plt.figure(figsize=(3 * ncols, 3 * nrows))
    for idx, stage in enumerate(stage_imgs):
        panel = fig.add_subplot(nrows, ncols, idx + 1)
        panel.imshow(tensor_to_pil(stage))
        panel.set(xticks=[], yticks=[])
        panel.set_title(f"Stage {idx+1}")
    fig.tight_layout()
    plt.show()


## 2. 下载 Hugging Face 权重

我们以 `VAR-d16` 为例，需要同时下载 `vae_ch160v4096z32.pth` 与 `var_d16.pth`。下载后的文件会缓存在 `var_inv/hf_assets/` 目录，方便重复运行。

In [None]:
# Depth of the VAR model; it selects which `var_d{MODEL_DEPTH}.pth` file is downloaded below
# and is passed to `build_vae_var` later on.
MODEL_DEPTH = 16  # can be switched to 20/24/30/36, corresponding to models of different sizes
REPO_ID = "FoundationVision/var"
# Local directory passed to `hf_hub_download` as `cache_dir`; created up front if missing.
HF_CACHE = Path("var_inv/hf_assets")
HF_CACHE.mkdir(parents=True, exist_ok=True)

# Fetch the VAE and VAR checkpoints; the returned values are used as checkpoint paths later.
vae_path = hf_hub_download(repo_id=REPO_ID, filename="vae_ch160v4096z32.pth", cache_dir=HF_CACHE)
var_path = hf_hub_download(repo_id=REPO_ID, filename=f"var_d{MODEL_DEPTH}.pth", cache_dir=HF_CACHE)
print(f"VAE ckpt: {vae_path}\nVAR ckpt: {var_path}")

## 3. 构建 VAE / VAR 并加载权重

`build_vae_var` 会创建和官方实现一致的结构。加载权重时，我们兼容三种 checkpoint 格式：权重包在 `model` 键下、包在 `state_dict` 键下，或文件本身就是纯 `state_dict`。

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-scale sizes handed to `build_vae_var`; the last entry (16) is later multiplied by
# `vae.downsample` to get the image side length.
# NOTE(review): presumably token-map side lengths from coarse to fine — confirm against build_vae_var.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

# Instantiate the VAE/VAR pair on `device`; checkpoint weights are loaded in a separate step.
vae, var = build_vae_var(
    device=device,
    patch_nums=patch_nums,
    num_classes=1000,
    depth=MODEL_DEPTH,
    shared_aln=False,
    attn_l2_norm=False,
)

def load_module_state(module: torch.nn.Module, ckpt_path: str) -> None:
    """Load a checkpoint file into ``module``, unwrapping common containers.

    Three layouts are accepted: a dict holding the weights under
    ``"state_dict"``, a dict holding them under ``"model"``, or a bare state
    dict. The wrappers are peeled in that order.

    Args:
        module: Network whose parameters and buffers are overwritten in place.
        ckpt_path: Path to a file written by ``torch.save``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the
            checkpoint keys or shapes do not match ``module``.
    """
    # Deserialize onto the CPU: `load_state_dict` copies each tensor into the
    # module's existing parameters on whatever device they live on, so mapping
    # the checkpoint to the accelerator first would only hold a second,
    # temporary copy of all weights there. This also keeps the helper
    # independent of any notebook-level `device` variable.
    # NOTE(security): `torch.load` unpickles the file. Only load checkpoints
    # from a source you trust, or pass `weights_only=True` on PyTorch versions
    # that support it.
    payload = torch.load(ckpt_path, map_location="cpu")
    for wrapper_key in ("state_dict", "model"):
        if isinstance(payload, dict) and wrapper_key in payload:
            payload = payload[wrapper_key]
    module.load_state_dict(payload, strict=True)

# Populate both networks from the downloaded checkpoints.
load_module_state(vae, vae_path)
load_module_state(var, var_path)

# Inference only: switch to eval mode and turn off gradient tracking for all parameters.
vae.eval().requires_grad_(False)
var.eval().requires_grad_(False)
print("Models are ready on", device)

## 4. 直接使用 CFG 采样

首先展示最基础的用法：根据 ImageNet 分类标签直接生成图像。

In [None]:
seed = 1234
# Seed the global torch RNG; the same value is also handed to the sampler as `g_seed`.
torch.manual_seed(seed)
label = torch.tensor([207], device=device)  # e.g. "golden retriever"

# Plain sampling needs no autograd bookkeeping.
with torch.inference_mode():
    # float16 autocast is enabled only on CUDA; on any other device the context is a no-op.
    autocast = torch.autocast(device_type=device.type, enabled=device.type == "cuda", dtype=torch.float16)
    with autocast:
        # One image per label; batch size is taken from the label tensor.
        sample = var.autoregressive_infer_cfg(
            B=label.shape[0],
            label_B=label,
            g_seed=seed,
            cfg=4.0,
            top_k=900,
            top_p=0.96,
            more_smooth=False,
        )

show_tensor(sample, title="CFG Sampling Result")

## 5. 观测约束 + 梯度引导后验采样

下面示例模拟一个“随机遮挡”的测量：我们取一张现有图片 $x$，对其施加随机掩膜 $M$，得到观测 $y = M \odot x$。`GradientGuidedVARSampler` 会在每个尺度上通过 Gumbel-Softmax 计算梯度，引导 logits 满足测量约束，并记录每个尺度的重建结果。

In [None]:
reference_path = Path("diffusion-posterior-sampling/data/samples/00014.png")  # can be replaced with any RGB image
# Image side length = finest patch count times the VAE downsampling factor.
target_hw = patch_nums[-1] * vae.downsample  # 16 * 16 = 256
reference = load_rgb(reference_path, target_hw)

# Random mask with a fixed generator seed so the same mask is drawn on every run.
mask = MaskingMeasurement.random(
    shape=(1, 1, target_hw, target_hw),
    keep_ratio=0.35,
    generator=torch.Generator().manual_seed(42),
)
measurement_model = MeasurementModel(operator=mask.operator)
# Simulate the observation by applying the measurement model to the reference image.
measurement = measurement_model.measure(reference)

cfg_params = PosteriorGuidanceConfig(
    cfg_scale=2.0,
    grad_scale_min=0.05,
    grad_scale_max=0.35,
    grad_start_ratio=0.1,
    top_p=0.95,
    top_k=900,
)
sampler = GradientGuidedVARSampler(var_model=var, measurement_model=measurement_model, config=cfg_params)

# `label` and `seed` are reused from the CFG sampling cell, so that cell must have been run first.
# NOTE(review): the mask and measurement are created on the CPU and only the measurement is moved
# to `device` here — confirm the measurement operator handles tensors on `device`.
result = sampler.sample(
    measurement=measurement.to(device),
    label_B=label,
    g_seed=seed,
    capture_intermediate=True,
)
# With capture_intermediate=True the result is unpacked as (final reconstruction, per-stage images).
recon, stage_imgs = result

show_tensor(reference, title="Ground Truth (x)")
show_tensor(measurement, title="Measurement (y = M ⊙ x)")
show_tensor(recon, title="Posterior Reconstruction")
visualize_stages(stage_imgs, ncols=5)

以上可视化了从 coarse-to-fine 的每个尺度输出，呈现 VAR 作为“离散扩散”时的逐步细化效果。你也可以尝试其他测量（模糊、超分辨率等）或替换标签、随机种子，探索梯度引导对后验采样的影响。