# VAR HF 推理与后验采样 Demo

本 Notebook 展示了如何：

1. 从 Hugging Face (`FoundationVision/var`) 下载官方发布的 VAE 与 VAR 权重；
2. 构建 `var_inv` 中自带的 VAR/VQVAE 模型，并使用 `VAR.autoregressive_infer_cfg` 进行标准 CFG 采样；
3. 借助 `GradientGuidedVARSampler` + 测量算子，将 VAR 视作“离散扩散”，对部分观测数据执行后验重建，并可视化每个尺度的重建过程。

> ⚠️ 运行前请确保你拥有 Hugging Face 的下载权限（可能需要 `huggingface-cli login`）。

## 1. 准备依赖与工具函数

以下代码导入 `var_inv` 中封装好的模块，并提供若干辅助函数用于读写图片、可视化采样结果。

In [None]:
import math
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from var_inv.measurements import MaskingMeasurement
from var_inv.posterior_sampling import (
    GradientGuidedVARSampler,
    MeasurementModel,
    PosteriorGuidanceConfig,
)
from var_inv.var_models.models import build_vae_var


def load_rgb(path: Path, target_hw: Optional[int]) -> torch.Tensor:
    """Load an image file as a float RGB tensor scaled to [0, 1].

    Args:
        path: Location of the image file on disk.
        target_hw: Side length of the square output. When not ``None`` the
            image is resized to ``(target_hw, target_hw)`` with bicubic
            resampling (aspect ratio is not preserved). ``None`` keeps the
            native resolution.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, H, W)``.
    """
    # Use the context manager so the underlying file handle is released
    # deterministically; ``convert`` decodes the pixels and returns a new
    # image that no longer depends on the open file.
    with Image.open(path) as src:
        img = src.convert("RGB")
    if target_hw is not None:
        img = img.resize((target_hw, target_hw), Image.BICUBIC)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    # HWC -> CHW, then add a leading batch axis.
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)


def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """Convert an image tensor with values in [0, 1] to an 8-bit PIL image.

    Args:
        t: Tensor laid out as ``(C, H, W)`` or ``(B, C, H, W)``. For a 4-D
            input only the first batch element is converted. Values outside
            [0, 1] are clamped.

    Returns:
        The image built from the ``(H, W, C)`` ``uint8`` array.
    """
    # Cast to float32 first so reduced-precision tensors (e.g. bfloat16,
    # which NumPy cannot represent) convert cleanly.
    t = t.detach().cpu().float().clamp(0, 1)
    if t.ndim == 4:
        t = t[0]
    # Round to the nearest level instead of truncating: a bare uint8 cast
    # floors the value, which darkens the image slightly and can turn a
    # value a hair below k (after k / 255 * 255 in float32) into k - 1.
    arr = np.rint(t.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
    return Image.fromarray(arr)


def show_tensor(t: torch.Tensor, title: str = "") -> None:
    """Display one image tensor in its own 4x4-inch figure without axes.

    Args:
        t: Image tensor accepted by ``tensor_to_pil``.
        title: Optional caption; omitted when empty.
    """
    picture = tensor_to_pil(t)
    fig = plt.figure(figsize=(4, 4))
    ax = fig.gca()
    ax.imshow(picture)
    ax.axis("off")
    if title:
        ax.set_title(title)
    plt.show()


def visualize_stages(stage_imgs: List[torch.Tensor], ncols: int = 5) -> None:
    """Plot a sequence of per-stage images on a grid, one panel per stage.

    Args:
        stage_imgs: Image tensors accepted by ``tensor_to_pil``, in stage
            order. Nothing is drawn when the sequence is empty.
        ncols: Number of grid columns; rows are added as needed.
    """
    if not stage_imgs:
        return
    n_stages = len(stage_imgs)
    nrows = math.ceil(n_stages / ncols)
    # Each panel gets a 3x3-inch cell.
    fig = plt.figure(figsize=(3 * ncols, 3 * nrows))
    for idx, stage in enumerate(stage_imgs):
        panel = fig.add_subplot(nrows, ncols, idx + 1)
        panel.imshow(tensor_to_pil(stage))
        panel.set_title(f"Stage {idx+1}")
        # Hide tick marks while keeping the panel frame.
        panel.set_xticks([])
        panel.set_yticks([])
    fig.tight_layout()
    plt.show()


## 2. 下载 Hugging Face 权重

我们以 `VAR-d16` 为例，需要同时下载 `vae_ch160v4096z32.pth` 与 `var_d16.pth`。下载后的文件会缓存在 `var_inv/hf_assets/` 目录，方便重复运行。

In [None]:
# Model depth to fetch; 20/24/30/36 can be substituted to select the other model sizes.
MODEL_DEPTH = 16
REPO_ID = "FoundationVision/var"

# Local directory handed to hf_hub_download as its cache location.
HF_CACHE = Path("var_inv/hf_assets")
HF_CACHE.mkdir(parents=True, exist_ok=True)

# Download the VAE checkpoint first, then the VAR checkpoint for the chosen depth.
vae_path, var_path = [
    hf_hub_download(repo_id=REPO_ID, filename=ckpt_name, cache_dir=HF_CACHE)
    for ckpt_name in ("vae_ch160v4096z32.pth", f"var_d{MODEL_DEPTH}.pth")
]
print(f"VAE ckpt: {vae_path}\nVAR ckpt: {var_path}")

## 3. 构建 VAE / VAR 并加载权重

`build_vae_var` 会创建和官方实现一致的结构。加载权重时，我们针对不同 checkpoint 格式（`model`/`state_dict`/纯 `state_dict`）做了兼容处理。

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Patch-grid side length for each of the ten scales, coarse to fine.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

# Instantiate the VAE and VAR modules; weights are loaded separately afterwards.
vae, var = build_vae_var(
    depth=MODEL_DEPTH,
    num_classes=1000,
    patch_nums=patch_nums,
    device=device,
    attn_l2_norm=True,
    shared_aln=False,
)

def load_module_state(module: torch.nn.Module, ckpt_path: str) -> None:
    """Load checkpoint weights from ``ckpt_path`` into ``module`` in place.

    Accepts a bare state dict as well as checkpoints that nest it under a
    ``"state_dict"`` and/or ``"model"`` key (unwrapped in that order).

    Args:
        module: Module whose parameters and buffers are overwritten.
        ckpt_path: Path to a file written by ``torch.save``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint keys do
            not match the module exactly (``strict=True``).
    """
    # NOTE(security): torch.load unpickles the file; depending on the
    # installed PyTorch version's ``weights_only`` default this can execute
    # arbitrary code, so only load checkpoints from a trusted source.
    # Deserialize onto the CPU: ``load_state_dict`` copies every tensor into
    # the module's own parameters (wherever they live), so staging the whole
    # checkpoint on the accelerator would only double peak device memory.
    # This also keeps the helper independent of any module-level ``device``.
    payload = torch.load(ckpt_path, map_location="cpu")
    for wrapper_key in ("state_dict", "model"):
        # Only unwrap when the entry is itself a mapping, so a flat state
        # dict that merely contains such a key is left untouched.
        if isinstance(payload, dict) and isinstance(payload.get(wrapper_key), dict):
            payload = payload[wrapper_key]
    module.load_state_dict(payload, strict=True)

# Populate both freshly built modules from their downloaded checkpoints.
load_module_state(vae, vae_path)
load_module_state(var, var_path)

# Inference only: switch to eval mode and stop tracking parameter gradients.
for net in (vae, var):
    net.eval()
    net.requires_grad_(False)
print("Models are ready on", device)

## 4. 直接使用 CFG 采样

首先展示最基础的用法：根据 ImageNet 分类标签直接生成图像。

In [None]:
seed = 1234
torch.manual_seed(seed)
# ImageNet class index to condition on, e.g. "golden retriever".
label = torch.tensor([207], device=device)

# Mixed precision (float16) is only switched on when running on CUDA.
autocast = torch.autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == "cuda")
with torch.inference_mode(), autocast:
    sample = var.autoregressive_infer_cfg(
        B=label.shape[0],
        label_B=label,
        cfg=4.0,
        top_k=900,
        top_p=0.96,
        g_seed=seed,
        more_smooth=False,
    )

show_tensor(sample, title="CFG Sampling Result")

## 5. 观测约束 + 梯度引导后验采样

下面示例模拟一个“随机遮挡”的测量：我们取一张现有图片，当作观测 `y = M \odot x`，其中 `M` 是随机掩膜。`GradientGuidedVARSampler` 会在每个尺度上通过 Gumbel-Softmax 计算梯度，指导 logits 满足测量约束，并记录每个尺度的重建结果。

In [None]:
# Reference image; any RGB image can be substituted here.
reference_path = Path("diffusion-posterior-sampling/data/samples/00014.png")
# Finest patch grid times the VAE downsample factor (16 * 16 = 256 per the original note).
target_hw = vae.downsample * patch_nums[-1]
reference = load_rgb(reference_path, target_hw).to(device)

# Random keep/drop mask built from a fixed-seed generator, then used as the measurement operator.
mask = MaskingMeasurement.random(
    shape=(1, 1, target_hw, target_hw),
    keep_ratio=0.35,
    generator=torch.Generator().manual_seed(42),
).to(device)
measurement_model = MeasurementModel(operator=mask.operator, filter_fn=lambda x, _: x)
measurement = measurement_model.measure(reference)

# Pre-compute the ground-truth (anchor) tokens of the reference; they are meant
# to be force-substituted only where the mask equals 1 (known pixels).
ref_norm = 2 * reference - 1  # rescale [0, 1] -> [-1, 1] before tokenizing
anchor_tokens = var.vae_proxy[0].img_to_idxBl(ref_norm, v_patch_nums=var.vae_quant_proxy[0].v_patch_nums)
mask_tensor = mask.mask  # (B,1,H,W) per the original note
# One constraint per scale: area-pool the mask down to the pn x pn grid, keep
# cells that are more than half known, and flatten
# (B x 1 x pn x pn -> B x pn*pn) next to that scale's anchor token indices.
token_constraints = [
    {
        "mask": (torch.nn.functional.interpolate(mask_tensor, size=(pn, pn), mode="area") > 0.5).flatten(1),
        "idx": anchor_tokens[si],
    }
    for si, pn in enumerate(var.vae_quant_proxy[0].v_patch_nums)
]

cfg_params = PosteriorGuidanceConfig(
    cfg_scale=0.0,  # original note: CFG disabled, rely entirely on the observation
    grad_scale_min=0.5,
    grad_scale_max=2.0,
    grad_start_ratio=0.0,
    grad_steps=5,
    # original note: no gradient guidance in the first half; token injection fixes the structure
    grad_stop_before_stage=5,
    top_p=0.0,
    top_k=0,
)
sampler = GradientGuidedVARSampler(var_model=var, measurement_model=measurement_model, config=cfg_params)

result = sampler.sample(
    measurement=measurement,
    label_B=-1,  # original note: -1 means fully unconditional
    g_seed=seed,
    capture_intermediate=True,
    capture_token_trace=True,
    token_constraints=token_constraints,
)
# The sampler may or may not append the token trace; fall back to an empty one.
if len(result) != 3:
    recon, stage_imgs = result
    token_trace = []
else:
    recon, stage_imgs, token_trace = result

# Data-fidelity check: push the reconstruction through the same operator and compare to y.
masked_recon = mask.operator(recon.to(device))
mse_mask = torch.nn.functional.mse_loss(masked_recon, measurement)
print(f"Masked-region MSE: {mse_mask.item():.5f}")

# Per-stage report: how many entries differ between the "cfg_top" and "guided_top" trace fields.
for info in token_trace:
    total = info["cfg_top"].numel()
    diff = (info["cfg_top"] != info["guided_top"]).sum().item()
    forced = info.get("forced_tokens", 0)
    print(
        f"Stage {info['stage']}: changed {diff}/{total} tokens, "
        f"grad_norm={info['grad_norm']:.4f}, max_delta={info['max_delta']:.4f}, forced={forced}"
    )

for img, caption in (
    (reference, "Ground Truth (x)"),
    (measurement, "Measurement (y = M ⊙ x)"),
    (recon, "Posterior Reconstruction"),
):
    show_tensor(img, title=caption)
visualize_stages(stage_imgs, ncols=5)

以上可视化了从 coarse-to-fine 的每个尺度输出，呈现 VAR 作为“离散扩散”时的逐步细化效果。你也可以尝试其他测量（模糊、超分辨率等）或替换标签、随机种子，探索梯度引导对后验采样的影响。

## 6. 其他测量算子快速测试

下面再用模糊、超分辨率两种测量跑一遍，以验证梯度引导是否对不同 operator 生效。与上一节相同，前半段尺度硬注入 GT token（这里使用全 1 掩膜，即所有位置都有 anchor），后半段用梯度引导微调，并使用无条件标签（`label_B=-1`）。

In [None]:
from var_inv.measurements import GaussianBlurMeasurement, SuperResolutionMeasurement

def run_operator_test(measure_op, name: str):
    """Run guided posterior sampling against one measurement operator and report the fit.

    Measures the notebook-level ``reference`` image with ``measure_op``,
    samples a reconstruction with ``GradientGuidedVARSampler`` under
    all-position token constraints, then prints the measurement-space MSE
    and a per-stage token trace, and displays the measurement and the
    reconstruction.

    Args:
        measure_op: Object exposing an ``operator`` callable that maps an
            image tensor to its measurement.
        name: Label used in the printed lines and figure titles.
    """
    # Original note: for blur / super-resolution the whole image counts as known,
    # so GT tokens are hard-injected early and gradients refine the later stages.
    meas_model = MeasurementModel(operator=measure_op.operator, filter_fn=lambda x, _: x)
    meas = meas_model.measure(reference)

    # All-ones single-channel mask: every position has an anchor.
    known = torch.ones_like(reference[:, :1])
    batch = known.shape[0]
    scales = var.vae_quant_proxy[0].v_patch_nums
    gt_tokens = var.vae_proxy[0].img_to_idxBl(reference * 2 - 1, v_patch_nums=scales)
    # One constraint per scale: pooled mask flattened to (batch, pn*pn) plus that scale's tokens.
    constraints = [
        {
            "mask": (torch.nn.functional.interpolate(known, size=(pn, pn), mode="area") > 0.5)
            .squeeze(1)
            .reshape(batch, -1),
            "idx": gt_tokens[si],
        }
        for si, pn in enumerate(scales)
    ]

    guidance = PosteriorGuidanceConfig(
        cfg_scale=0.0,
        grad_scale_min=0.5,
        grad_scale_max=2.0,
        grad_start_ratio=0.0,
        grad_steps=5,
        grad_stop_before_stage=5,
        top_p=0.0,
        top_k=0,
    )
    guided_sampler = GradientGuidedVARSampler(var_model=var, measurement_model=meas_model, config=guidance)
    outcome = guided_sampler.sample(
        measurement=meas,
        label_B=-1,
        g_seed=seed,
        capture_intermediate=False,
        capture_token_trace=True,
        token_constraints=constraints,
    )
    # The sampler returns either (recon, <unused>, trace) or (recon, trace).
    if len(outcome) != 3:
        recon, trace = outcome
    else:
        recon, _, trace = outcome

    # Re-measure the reconstruction and compare it with the original measurement.
    remeasured = measure_op.operator(recon.to(device))
    mse_mask = torch.nn.functional.mse_loss(remeasured, meas)
    print(f"[{name}] MSE: {mse_mask.item():.5f}")
    for info in trace:
        total = info['cfg_top'].numel()
        diff = (info['cfg_top'] != info['guided_top']).sum().item()
        forced = info.get('forced_tokens', 0)
        print(
            f"    Stage {info['stage']}: changed {diff}/{total} tokens, "
            f"grad_norm={info['grad_norm']:.4f}, max_delta={info['max_delta']:.4f}, forced={forced}"
        )
    show_tensor(meas, title=f"{name} Measurement")
    show_tensor(recon, title=f"{name} Recon")

# Gaussian blur measurement.
blur_op = GaussianBlurMeasurement(kernel_size=11, sigma=2.0, downsample=1)
# Super-resolution measurement (original note: downsample, then re-align).
sr_op = SuperResolutionMeasurement(scale=4)

# Run the blur test first, then the super-resolution test.
for op, op_name in ((blur_op, 'GaussianBlur'), (sr_op, 'SuperResolution')):
    run_operator_test(op, op_name)


## 7. 矩形遮挡 Inpainting 测试

再测试一个矩形缺失的 inpainting 任务：掩膜区域为 0 表示未知，其余为已知（硬注入），高频阶段用梯度微调。

In [None]:
# Build a mask with a centered rectangular hole, one third of the image per side.
rect_h = rect_w = target_hw // 3
top, left = (target_hw - rect_h) // 2, (target_hw - rect_w) // 2
rect_mask = MaskingMeasurement.rectangular(
    h=target_hw, w=target_hw, top=top, left=left, height=rect_h, width=rect_w, device=device
)
measurement_model = MeasurementModel(operator=rect_mask.operator, filter_fn=lambda x, _: x)
measurement = measurement_model.measure(reference)

# Anchor tokens are meant to be hard-injected only in the known region (mask == 1).
ref_norm = 2 * reference - 1  # rescale [0, 1] -> [-1, 1] before tokenizing
anchor_tokens = var.vae_proxy[0].img_to_idxBl(ref_norm, v_patch_nums=var.vae_quant_proxy[0].v_patch_nums)
# One constraint per scale: area-pool the mask to pn x pn, threshold at 0.5,
# flatten to (B, pn*pn), and pair it with that scale's anchor token indices.
token_constraints = [
    {
        "mask": (torch.nn.functional.interpolate(rect_mask.mask, size=(pn, pn), mode="area") > 0.5)
        .squeeze(1)
        .reshape(rect_mask.mask.shape[0], -1),
        "idx": anchor_tokens[si],
    }
    for si, pn in enumerate(var.vae_quant_proxy[0].v_patch_nums)
]

cfg_params = PosteriorGuidanceConfig(
    cfg_scale=0.0,
    grad_scale_min=0.5,
    grad_scale_max=2.0,
    grad_start_ratio=0.0,
    grad_steps=5,
    grad_stop_before_stage=5,
    top_p=0.0,
    top_k=0,
)
sampler = GradientGuidedVARSampler(var_model=var, measurement_model=measurement_model, config=cfg_params)
result = sampler.sample(
    measurement=measurement,
    label_B=-1,
    g_seed=seed,
    capture_intermediate=False,
    capture_token_trace=True,
    token_constraints=token_constraints,
)
# The sampler returns either (recon, <unused>, trace) or (recon, trace).
if len(result) != 3:
    recon, trace = result
else:
    recon, _, trace = result

# Re-apply the mask operator to the reconstruction and compare it with the measurement.
masked_recon = rect_mask.operator(recon.to(device))
mse_mask = torch.nn.functional.mse_loss(masked_recon, measurement)
print(f'[Rect Inpaint] MSE: {mse_mask.item():.5f}')
for info in trace:
    total = info['cfg_top'].numel()
    diff = (info['cfg_top'] != info['guided_top']).sum().item()
    forced = info.get('forced_tokens', 0)
    print(
        f"    Stage {info['stage']}: changed {diff}/{total} tokens, "
        f"grad_norm={info['grad_norm']:.4f}, max_delta={info['max_delta']:.4f}, forced={forced}"
    )
for img, caption in ((measurement, 'Rect Measurement'), (recon, 'Rect Recon')):
    show_tensor(img, title=caption)
