### This notebook demonstrates basic inference with VAR

In [None]:
################## 1. Set up environment and load configuration ##################
import os
import os.path as osp
import sys
import time
import torch, torchvision
import random
import argparse
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
import importlib
# Enter the project directory exactly once per kernel session. Notebook cells
# are routinely re-run; a second os.chdir('VAR-Q') would be resolved relative
# to the already-changed working directory and either fail or descend one
# level too deep. The flag lives in the notebook's globals, so it is reset
# together with the working directory when the kernel restarts.
project_root = 'VAR-Q'
if not globals().get('_varq_chdir_done', False):
    os.chdir(project_root)
    _varq_chdir_done = True  # set only after chdir succeeded, so a failed attempt can be retried

# Add the VAR-Q directory to the Python path. The entry is relative, so it is
# looked up against the working directory set above. Skip it when already
# present so re-runs do not pile up duplicates.
if 'VAR-Q' not in sys.path:
    sys.path.append('VAR-Q')

# Replace the default initialisers with no-ops: the weights are overwritten by
# the checkpoints anyway, so skipping the init makes model construction faster.
torch.nn.Linear.reset_parameters = lambda self: None     # disable default parameter init for faster speed
torch.nn.LayerNorm.reset_parameters = lambda self: None  # disable default parameter init for faster speed

# Import configuration loader and model builder (project modules, resolvable
# only after the sys.path change above).
from VAR_Q.config_loader import load_varq_config
from VAR.models import build_vae_var_from_config

# Load configuration
config = load_varq_config('VAR-Q/VAR_Q/VAR-raw.json')

# Get model depth from config
model_depth = config.get_model_config()['depth']

In [None]:
################## 2. Build models ##################
# This cell fetches any missing checkpoint files, builds the VAE/VAR pair from
# the configuration, loads the weights and freezes both models for inference.
import subprocess  # needed only for the checkpoint download in this cell

assert model_depth in {16, 20, 24, 30, 36}

# Get checkpoint paths from config
vae_ckpt, var_ckpt = config.get_checkpoint_paths(model_depth)
hf_home = config.get_checkpoint_config()['hf_home']

# Download checkpoints if they don't exist.
# wget is invoked with an argument list (no shell) and check=True, so a failed
# download stops the cell here with a clear error instead of surfacing later
# as a confusing torch.load failure.
for ckpt_name, ckpt_path in (('VAE', vae_ckpt), ('VAR', var_ckpt)):
    if osp.exists(ckpt_path):
        continue
    print(f"Downloading {ckpt_name} checkpoint from {hf_home}/{ckpt_path}")
    subprocess.run(['wget', f'{hf_home}/{ckpt_path}'], check=True)

# Get device from config
device = config.get_device()
print(f"Using device: {device}")

# Build models using configuration. Skipped when both names already exist in
# the notebook's globals, so re-running the cell does not rebuild the models.
if 'vae' not in globals() or 'var' not in globals():
    print("Building models from configuration...")
    vae, var = build_vae_var_from_config(config.config, device=device)

# Load checkpoints
# NOTE(review): torch.load unpickles the file. If the checkpoints are plain
# state dicts and the installed torch supports it, consider passing
# weights_only=True — confirm before changing.
print("Loading checkpoints...")
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)

# Inference only: switch to eval mode and turn off gradient tracking.
vae.eval()
var.eval()
for p in vae.parameters():
    p.requires_grad_(False)
for p in var.parameters():
    p.requires_grad_(False)
print('prepare finished.')

In [None]:
############################# 3. Sample with classifier-free guidance #############################
# Pull the sampling hyper-parameters out of the loaded configuration.
inference_config = config.get_inference_config()
seed, cfg, top_k, top_p, more_smooth = (
    inference_config[key] for key in ('seed', 'cfg', 'top_k', 'top_p', 'more_smooth')
)
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}

# Echo the settings so the run is self-describing.
print(
    "Inference parameters:",
    f"  Seed: {seed}",
    f"  CFG: {cfg}",
    f"  Top-k: {top_k}",
    f"  Top-p: {top_p}",
    f"  More smooth: {more_smooth}",
    sep="\n",
)

# Seed every random number generator in use and pin cuDNN to deterministic
# kernels (no autotuning).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Speed knob: allow TF32 for cuDNN and CUDA matmuls.
tf32 = True
torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
if tf32:
    torch.set_float32_matmul_precision('high')
else:
    torch.set_float32_matmul_precision('highest')

In [None]:
############################# 4. Generate images #############################
# Sample one image per entry of class_labels, report the wall-clock time of
# the sampling call, and save all images as a single grid PNG.
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

# CUDA work is launched asynchronously: without a synchronisation point the
# timer can stop before the GPU has finished, under-reporting the duration.
# NOTE(review): assumes config.get_device() returns something torch.device()
# accepts (e.g. 'cuda', 'cuda:0', 'cpu' or a torch.device) — confirm.
timing_device = torch.device(device)
timing_on_cuda = timing_device.type == 'cuda'

with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        if timing_on_cuda:
            torch.cuda.synchronize(timing_device)  # drain pending work so it is not billed to sampling
        start_time = time.perf_counter()  # monotonic clock, suited to measuring intervals
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=top_k, top_p=top_p, g_seed=seed, more_smooth=more_smooth)
        if timing_on_cuda:
            torch.cuda.synchronize(timing_device)  # wait until the sampled images are actually computed
        end_time = time.perf_counter()
        print(f"Time taken: {end_time - start_time} seconds")

# Tile the batch into one row-major grid (8 images per row, no padding).
chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# Channels-first -> channels-last, scale by 255, then move to host memory.
# NOTE(review): presumably the sampled pixel values lie in [0, 1] — confirm.
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))  # astype truncates towards zero rather than rounding
chw.save('recon_B3HW.png')