In [5]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')


[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))



  from .autonotebook import tqdm as notebook_tqdm


[init_weights] VAR with init_std=0.0180422
prepare finished.


In [8]:
var_checkpoint_path = 'var_d16.pth'

# Load the state dictionary directly, as identified previously
state_dict = torch.load(var_checkpoint_path, map_location='cpu')
var.load_state_dict(state_dict, strict=True)
print("Successfully loaded weights into the VAR model.")

# --- 4. Prepare for Inference ---
# Set the model to evaluation mode to disable dropout, etc.
var.eval()

# The model's parameters should not require gradients for inference
for p in var.parameters():
    p.requires_grad_(False)

print("Model is ready for inference.")

Successfully loaded weights into the VAR model.
Model is ready for inference.


In [29]:
import torch
from collections import defaultdict

def extract_var_activations(var_model, input_tokens, target_layers=None):
    """
    Extract activations from VAR model layers during forward pass
    
    Args:
        var_model: The loaded VAR model
        input_tokens: Input tokens for the model
        target_layers: List of layer names to extract (if None, extracts all FFN layers)
    
    Returns:
        Dictionary of layer_name -> activation tensor
    """
    activations = {}
    hooks = []
    
    if target_layers is None:
        # Extract all FFN layers by default
        target_layers = []
        for name, module in var_model.named_modules():
            if 'ffn.fc1' in name or 'ffn.fc2' in name or name == 'head':
                target_layers.append(name)
    
    def make_hook(layer_name):
        def hook_fn(module, input, output):
            # Store activation, handling different output shapes
            if isinstance(output, torch.Tensor):
                if len(output.shape) == 3:  # [batch, sequence, features]
                    # Average across sequence dimension for transformers
                    activations[layer_name] = output.mean(dim=1).detach()
                else:
                    activations[layer_name] = output.detach()
            else:
                # Handle tuple outputs (some layers return multiple values)
                activations[layer_name] = output[0].detach()
        return hook_fn
    
    # Register hooks
    for layer_name in target_layers:
        try:
            layer = dict(var_model.named_modules())[layer_name]
            hook = layer.register_forward_hook(make_hook(layer_name))
            hooks.append(hook)
        except KeyError:
            print(f"Warning: Layer {layer_name} not found")
    
    # Perform forward pass
    try:
        with torch.no_grad():
            _ = var_model(torch.tensor([1,]), input_tokens)
    finally:
        # Clean up hooks
        for hook in hooks:
            hook.remove()
    
    return activations


In [30]:
from PIL import Image
import torchvision.transforms as transforms
image_path = "../data/imagenette/n01440764/ILSVRC2012_val_00009111.JPEG"
# Load and preprocess image
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

image = Image.open(image_path).convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)

# Encode image to tokens using VAE
with torch.no_grad():
    # This depends on your VAE implementation
    # You may need to adjust based on your specific VAE interface
    #encoded_tokens = vae.encode(image_tensor)  # Adjust this line
    vae_out = vae.img_to_idxBl(image_tensor)
    var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)
print("shapes")
for el in vae_out:
    print(el.shape)
print(vae_out[0])
print(var_in.shape)

# Extract activations from VAR
activations = extract_var_activations(var, var_in)



shapes
torch.Size([1, 1])
torch.Size([1, 4])
torch.Size([1, 9])
torch.Size([1, 16])
torch.Size([1, 25])
torch.Size([1, 36])
torch.Size([1, 64])
torch.Size([1, 100])
torch.Size([1, 169])
torch.Size([1, 256])
tensor([[2248]])
torch.Size([1, 679, 32])


  with torch.cuda.amp.autocast(enabled=False):


In [35]:
for key in activations.keys():
    print(f"Layer: {key}, Activation shape: {activations[key].shape}")
    # Optionally, you can save or visualize the activations here
    # For example, convert to numpy and save as image if needed
    activation_image = activations[key].cpu().numpy()
    print(activation_image)
    # Save or visualize activation_image as needed

Layer: blocks.0.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-0.74638313 -2.9305754  -3.013914   ... -0.52451587 -2.7052476
  -0.49910337]]
Layer: blocks.0.ffn.fc2, Activation shape: torch.Size([1, 1024])
[[ 1.4226226e-04  3.6634475e-02 -6.1239654e-01 ... -8.0569319e-02
   7.6975203e-03  1.5197721e-01]]
Layer: blocks.1.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-0.33342698 -0.1417509  -1.0790021  ... -2.398618   -1.7207677
  -0.43498006]]
Layer: blocks.1.ffn.fc2, Activation shape: torch.Size([1, 1024])
[[ 0.12657642  0.08855741 -0.22504513 ... -0.0940779   0.3302756
  -0.2608851 ]]
Layer: blocks.2.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-0.6254736 -1.173401  -0.6255016 ... -1.00569   -2.9676414 -1.9144794]]
Layer: blocks.2.ffn.fc2, Activation shape: torch.Size([1, 1024])
[[ 0.08841949 -0.15493914  0.06704925 ... -0.09750675  0.12207684
   0.01863836]]
Layer: blocks.3.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-1.5254282  -1.3557339  -0.42768756 ... -0.4