### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟

In [1]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import torch.hub
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}

torch.cuda.empty_cache()

# Download checkpoints if they are not present locally.
# The Hugging Face repo stores the files at its root, so the remote name is the
# basename of the local path. The file is written straight to the local path, so the
# existence check finds it on the next run. (Previously `wget {hf_home}/{local_path}`
# requested a non-existent URL and saved into the current directory.)
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'memorization/checkpoints/vae_ch160v4096z32.pth', f'memorization/checkpoints/var_d{MODEL_DEPTH}.pth'
for ckpt_path in (vae_ckpt, var_ckpt):
    if not osp.exists(ckpt_path):
        os.makedirs(osp.dirname(ckpt_path) or '.', exist_ok=True)
        # download_url_to_file writes to a temporary file and moves it into place,
        # so an interrupted download does not leave a truncated checkpoint behind.
        torch.hub.download_url_to_file(f'{hf_home}/{osp.basename(ckpt_path)}', ckpt_path)

# Build VAE + VAR. The globals() guard avoids rebuilding when the cell is re-run in the same kernel.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# Load VAR weights but drop the checkpoint's attention mask, presumably so the model
# keeps its own (possibly modified) `attn_bias_for_masking` buffer.
ckpt = torch.load(var_ckpt, map_location='cpu')
ckpt.pop('attn_bias_for_masking', None)
missing, unexpected = var.load_state_dict(ckpt, strict=False)
# strict=False is only meant to tolerate the popped mask buffer; surface anything else.
unexpected_missing = [k for k in missing if k != 'attn_bias_for_masking']
if unexpected_missing or unexpected:
    print(f'[warning] VAR checkpoint mismatch: missing={unexpected_missing}, unexpected={list(unexpected)}')

# Load VAE weights (must match exactly), then freeze both models for inference.
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
vae.eval()
var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print('prepare finished.')

  from .autonotebook import tqdm as notebook_tqdm



[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422
prepare finished.


In [2]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from memorization.data_prep.subset_imagenet import get_balanced_subset

class IndexedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that also yields each sample's position.

    ``__getitem__`` returns ``(img, label, idx)``. ``idx`` is the position within the
    wrapped dataset (i.e. after any subsetting), so a sample can be traced back later.
    The wrapped dataset must yield ``(img, label)`` pairs.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        img, label = sample
        return img, label, idx


# ImageNet (ILSVRC2012) root. Set the IMAGENET_ROOT environment variable to use another
# location instead of editing this hard-coded cluster path (kept as the default).
imagenet_root = os.environ.get("IMAGENET_ROOT", "/scratch/inf0/user/mparcham/ILSVRC2012")
base_dataset = ImageFolder(
    root=os.path.join(imagenet_root, "val_categorized"),
    transform=None,  # keep raw PIL images; augmentation is applied per batch later
)
# Deterministic subset of 12800 images (shuffled with a fixed seed).
dataset = get_balanced_subset(
    dataset=base_dataset,
    total_samples=12800,
    shuffle=True,
    seed=0,
)

# Wrap to expose indices: each item becomes (img, label, position in the subset).
dataset = IndexedDataset(dataset)

# ---------------------------
# Collate function (keeps PIL + index)
# ---------------------------
def collate_pil(batch):
    """Collate ``(image, label, index)`` samples without converting the images.

    Args:
        batch: list of ``(PIL.Image, int label, int index)`` tuples.

    Returns:
        A tuple ``(images, labels, indices)``:
        ``images`` is a plain list of the PIL images, while ``labels`` and ``indices``
        are ``torch.long`` tensors of shape ``(B,)``. ``indices`` holds the stable
        dataset positions.
    """
    pil_images, label_ids, dataset_idx = zip(*batch)
    label_t = torch.tensor(label_ids, dtype=torch.long)
    index_t = torch.tensor(dataset_idx, dtype=torch.long)
    return list(pil_images), label_t, index_t

# Single-process loader over the whole indexed subset. No (distributed) sampler is used
# and shuffle=False, so batches follow the subset's order. `collate_pil` keeps the PIL
# images and returns the dataset indices alongside the labels.
loader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=collate_pil)


In [3]:

import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from utils.data import normalize_01_into_pm1

# Resize the shorter side to 1.125 * 256 = 288 px (LANCZOS), take a random 256x256 crop,
# convert to a [0, 1] tensor, then apply normalize_01_into_pm1 (by its name: [0, 1] -> [-1, 1]).
# NOTE(review): RandomCrop draws from torch's global RNG, which is not seeded before this
# is first used, so crops differ between runs.
mid_px = round(256 * 1.125)
aug_transform = T.Compose(
    [
        T.Resize(mid_px, interpolation=InterpolationMode.LANCZOS),
        T.RandomCrop((256, 256)),
        T.ToTensor(),
        normalize_01_into_pm1,
    ]
)

In [4]:
# Pull one batch to sanity-check the collate output.
batch_iter = iter(loader)
images, labels, indices = next(batch_iter)
print(f'images len: {len(images)}, labels shape: {labels.shape}, indices shape: {indices.shape}')

images len: 4, labels shape: torch.Size([4]), indices shape: torch.Size([4])


In [6]:
# Prepare VAE-derived teacher-forcing inputs for VAR.
images_aug = torch.stack([aug_transform(img) for img in images], dim=0).to(device)
# Multi-scale token ids: a list with one (B, pn*pn) tensor per scale.
# NOTE: this rebinds `indices`, shadowing the dataset indices returned by the loader.
indices = var.vae_proxy[0].img_to_idxBl(images_aug)
# The VAR input at scale k must be built only from tokens of scales < k: the running
# residual sum f_hat, resized to scale k's resolution. `idxBl_to_var_input` builds that
# sequence (assumed to match VAR's training pipeline - confirm in models/quant.py).
# Concatenating the raw codebook embeddings of indices[1:] instead would hand every
# position the embedding of its own target token.
emb = var.vae_quant_proxy[0].idxBl_to_var_input(indices)

# Forward pass (teacher forcing); the second return value is inspected in the next cells.
_, test_atte = var(label_B=labels.to(device), x_BLCv_wo_first_l=emb)

  with torch.cuda.amp.autocast(enabled=False):


In [7]:
# Print selected query rows of the returned mask: (query row, number of key columns).
# Scale ranges follow var.begin_ends:
#   row 0 -> scale 0 (1x1), rows 1-4 -> scale 1 (2x2), rows 5-13 -> scale 2 (3x3),
#   rows 14-29 -> scale 3 (4x4), rows 30-54 -> scale 4 (5x5).
print(test_atte.shape)
for q_row, n_keys in ((0, 10), (1, 10), (2, 10), (3, 10), (4, 10), (5, 10), (30, 20)):
    print(test_atte[0][0][q_row][:n_keys])

torch.Size([1, 1, 680, 680])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
tensor([0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], device='cuda:0')
tensor([0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], device='cuda:0')
tensor([0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], device='cuda:0')
tensor([0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')


In [None]:
# `indices` from img_to_idxBl: one tensor per scale (10 scales), each of shape (B, pn*pn);
# every entry is a codebook index.
print(len(indices))
for scale_idx in (0, 1, 2, 9):
    print(indices[scale_idx].shape)
print(indices[9])

10
torch.Size([4, 1])
torch.Size([4, 4])
torch.Size([4, 9])
torch.Size([4, 256])
tensor([[1997, 1362, 3823,  ..., 3324, 4040, 3737],
        [2741, 1917,  312,  ..., 1439, 2563, 1071],
        [2211, 2241,  342,  ..., 1080, 1081,  297],
        [3448, 1625,  860,  ..., 2183, 2804, 1052]], device='cuda:0')


In [None]:
print(emb.size())  # previously observed: torch.Size([4, 679, 32]) -> batch 4, 679 tokens, embedding dim 32

torch.Size([4, 679, 32])


In [2]:
# Print selected query rows of the model's attention bias: (query row, number of key columns).
# Scale ranges follow var.begin_ends:
#   row 0 -> scale 0 (1x1), rows 1-4 -> scale 1 (2x2), rows 5-13 -> scale 2 (3x3),
#   rows 14-29 -> scale 3 (4x4), rows 30-54 -> scale 4 (5x5).
print(var.attn_bias_for_masking.shape)
for q_row, n_keys in ((0, 10), (1, 10), (2, 10), (3, 10), (4, 10), (5, 10), (30, 20)):
    print(var.attn_bias_for_masking[0][0][q_row][:n_keys])

torch.Size([1, 1, 680, 680])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       device='cuda:0')
tensor([-inf, 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], device='cuda:0')
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0.],
       device='cuda:0')


In [6]:
# Show where VAR.__init__ is defined, to confirm which models/var.py the kernel imported.
for code_attr in ('co_firstlineno', 'co_filename'):
    print(getattr(var.__init__.__code__, code_attr))

# Which scale levels may query token q attend to under the model's current mask?
q = 30
mask = var.attn_bias_for_masking[0, 0]
allowed = torch.isfinite(mask[q])
print(torch.unique(var.lvl_1L.squeeze()[allowed]))



22
/BS/scene_repre/work/VAR/models/var.py
tensor([0, 1, 2, 3, 4], device='cuda:0')


In [3]:
q = 30  # token index to inspect (first token of scale 4 according to var.begin_ends)
row_bias = var.attn_bias_for_masking[0, 0, q]
allowed = torch.isfinite(row_bias)  # True where the key is not masked with -inf
print(allowed.shape)
key_levels = var.lvl_1L.squeeze()[allowed]
print(torch.unique(key_levels))

torch.Size([680])
tensor([0, 1, 2, 3, 4], device='cuda:0')


In [10]:
# Rebuild VAR's level-wise (block-causal) attention mask locally.
# Sequence length = total number of tokens over all scales (680 for the default patch_nums).
seq_len = sum(pn * pn for pn in patch_nums)
# d: (1, seq_len, 1) - scale (level) index of each token, laid out scale by scale.
d: torch.Tensor = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(patch_nums)]).view(1, seq_len, 1)
dT = d.transpose(1, 2)  # (1, 1, seq_len): key levels
lvl_1L = dT[:, 0].contiguous()  # (1, seq_len)
print(lvl_1L.shape)
# A query at level s can attend to any key from the same or an earlier level
# (query_level >= key_level); all other positions get -inf.
# NOTE: this only builds a local tensor; var.attn_bias_for_masking is not modified here.
attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, seq_len, seq_len)


torch.Size([1, 680])


In [50]:
# Alternative mask: each query sees only keys of the immediately preceding scale (level s-1);
# level-0 queries see the level-0 key. Queries at level >= 1 therefore cannot attend to their
# own scale.
# NOTE: this only builds a local tensor; var.attn_bias_for_masking is not modified here.
prev_lvl = (d == (dT + 1))                # keys at level s-1
first_lvl = (d == 0) & (dT == 0)          # level-0 queries attend to level-0 keys
mask_bool = prev_lvl | first_lvl          # (1, L, L) by broadcasting d against dT
# Add the head dimension -> (1, 1, L, L); L is taken from d instead of hard-coding 680.
attn_bias_for_masking = torch.where(mask_bool, 0., -torch.inf).unsqueeze(1)

In [51]:
# Print selected query rows of the locally built mask: (query row, number of key columns).
# Scale ranges follow var.begin_ends:
#   row 0 -> scale 0 (1x1), rows 1-4 -> scale 1 (2x2), rows 5-13 -> scale 2 (3x3),
#   rows 14-29 -> scale 3 (4x4), rows 30-54 -> scale 4 (5x5).
print(attn_bias_for_masking.shape)
for q_row, n_keys in ((0, 10), (1, 10), (2, 10), (3, 10), (4, 10), (5, 10), (30, 20)):
    print(attn_bias_for_masking[0][0][q_row][:n_keys])

torch.Size([1, 1, 680, 680])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
tensor([0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
tensor([-inf, 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf])
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0.])


In [43]:
# Compare the locally built level tensors with the model's own lvl_1L.
for level_tensor in (d, d.transpose(1, 2), var.lvl_1L):
    print(level_tensor.shape)

torch.Size([1, 680, 1])
torch.Size([1, 1, 680])
torch.Size([1, 680])


In [37]:
# Scale (level) index of every token in the sequence, as stored on the VAR module (shape (1, L)).
print(var.lvl_1L)

tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
         4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8,
         8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
         8, 8, 8, 8, 8, 8, 8

In [2]:
# begin_ends: [start, end) token range of each scale; blocks[0]: the first transformer block.
for model_item in (var.begin_ends, var.blocks[0]):
    print(model_item)

[(0, 1), (1, 5), (5, 14), (14, 30), (30, 55), (55, 91), (91, 155), (155, 255), (255, 424), (424, 680)]
AdaLNSelfAttn(
  shared_aln=False
  (drop_path): Identity()
  (attn): SelfAttention(
    using_flash=False, using_xform=False, attn_l2_norm=True
    (mat_qkv): Linear(in_features=1024, out_features=3072, bias=False)
    (proj): Linear(in_features=1024, out_features=1024, bias=True)
    (proj_drop): Identity()
  )
  (ffn): FFN(
    fused_mlp_func=False
    (fc1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELU(approximate='tanh')
    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    (drop): Identity()
  )
  (ln_wo_grad): LayerNorm((1024,), eps=1e-06, elementwise_affine=False)
  (ada_lin): Sequential(
    (0): SiLU()
    (1): Linear(in_features=1024, out_features=6144, bias=True)
  )
)


In [None]:
############################# 2. Sample with classifier-free guidance

# set args
seed = 0 #@param {type:"number"}
# NOTE: num_sampling_steps is not passed to autoregressive_infer_cfg below (kept for the form UI only).
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# Seed every RNG once, before any random work.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
# fp16 CUDA autocast only applies when the models live on the GPU; on the CPU fallback run in fp32.
use_amp = device == 'cuda'
with torch.inference_mode():
    with torch.autocast('cuda', enabled=use_amp, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

# Tile the samples into a grid (8 per row), convert CHW -> HWC in [0, 255], and show it as a PIL image.
chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()
