In [17]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}

# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')


[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))



  from .autonotebook import tqdm as notebook_tqdm


[init_weights] VAR with init_std=0.0180422
prepare finished.


In [18]:
from collections import defaultdict

def extract_var_activations(var_model, input_tokens, target_layers=None):
    """
    Extract activations from VAR model layers during forward pass
    
    Args:
        var_model: The loaded VAR model
        input_tokens: Input tokens for the model
        target_layers: List of layer names to extract (if None, extracts all FFN layers)
    
    Returns:
        Dictionary of layer_name -> activation tensor
    """
    activations = {}
    hooks = []
    
    if target_layers is None:
        # Extract all FFN layers by default
        target_layers = []
        for name, module in var_model.named_modules():
            if 'ffn.fc1' in name or 'ffn.fc2' in name or name == 'head':
                target_layers.append(name)
    
    def make_hook(layer_name):
        def hook_fn(module, input, output):
            # Store activation, handling different output shapes
            if isinstance(output, torch.Tensor):
                if len(output.shape) == 3:  # [batch, sequence, features]
                    # Average across sequence dimension for transformers
                    activations[layer_name] = output.mean(dim=1).detach()
                else:
                    activations[layer_name] = output.detach()
            else:
                # Handle tuple outputs (some layers return multiple values)
                activations[layer_name] = output[0].detach()
        return hook_fn
    
    # Register hooks
    for layer_name in target_layers:
        try:
            layer = dict(var_model.named_modules())[layer_name]
            hook = layer.register_forward_hook(make_hook(layer_name))
            hooks.append(hook)
        except KeyError:
            print(f"Warning: Layer {layer_name} not found")
    
    # Perform forward pass
    try:
        with torch.no_grad():
            _ = var_model(torch.tensor([1,]), input_tokens)
    finally:
        # Clean up hooks
        for hook in hooks:
            hook.remove()
    
    return activations


In [19]:
def extract_var_ada_lin_activations(var_model, input_tokens, target_layers=None):
    """
    Extract activations from all ada_lin.1 (Linear) layers at the end of each block in the VAR model.

    Args:
        var_model: The loaded VAR model
        input_tokens: Input tokens for the model
        target_layers: List of layer names to extract (if None, extracts all ada_lin.1 layers in blocks and head_nm)

    Returns:
        Dictionary of layer_name -> activation tensor
    """
    activations = {}
    hooks = []

    if target_layers is None:
        # Collect all ada_lin.1 (Linear) layers in blocks and head_nm
        target_layers = []
        for name, module in var_model.named_modules():
            if name.endswith('ada_lin.1'):
                target_layers.append(name)

    def make_hook(layer_name):
        def hook_fn(module, input, output):
            if isinstance(output, torch.Tensor):
                if len(output.shape) == 3:
                    activations[layer_name] = output.mean(dim=1).detach()
                else:
                    activations[layer_name] = output.detach()
            else:
                activations[layer_name] = output[0].detach()
        return hook_fn

    for layer_name in target_layers:
        try:
            layer = dict(var_model.named_modules())[layer_name]
            hook = layer.register_forward_hook(make_hook(layer_name))
            hooks.append(hook)
        except KeyError:
            print(f"Warning: Layer {layer_name} not found")

    try:
        with torch.no_grad():
            _ = var_model(torch.tensor([1,]), input_tokens)
    finally:
        for hook in hooks:
            hook.remove()

    return activations


In [20]:
def extract_all_linear_activations(var_model, input_tokens):
    """
    Extract activations from all Linear (fully connected) layers in the VAR model.

    Args:
        var_model: The loaded VAR model
        input_tokens: Input tokens for the model

    Returns:
        Dictionary of layer_name -> activation tensor
    """
    activations = {}
    hooks = []

    # Find all Linear layers
    for name, module in var_model.named_modules():
        if isinstance(module, torch.nn.Linear):
            def make_hook(layer_name):
                def hook_fn(module, input, output):
                    if isinstance(output, torch.Tensor):
                        if len(output.shape) == 3:
                            activations[layer_name] = output.mean(dim=1).detach()
                        else:
                            activations[layer_name] = output.detach()
                    else:
                        activations[layer_name] = output[0].detach()
                return hook_fn
            hook = module.register_forward_hook(make_hook(name))
            hooks.append(hook)

    try:
        with torch.no_grad():
            _ = var_model(torch.tensor([1,]), input_tokens)
    finally:
        for hook in hooks:
            hook.remove()

    return activations

In [21]:
### UNIT MEM MEASUREMENT 
'''
datapath = './data'
import torchvision.transforms as transforms
import torchvision
import torch
from torchvision.transforms import ToTensor, Normalize
import scipy.io
from tqdm import tqdm
import numpy as np
s = 1
color_jitter = transforms.ColorJitter(
        0.9 * s, 0.9 * s, 0.9 * s, 0.1 * s)
flip = transforms.RandomHorizontalFlip()
Aug = transforms.Compose(
    [
    transforms.RandomResizedCrop(size=32),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomApply([color_jitter], p=0.9),
    transforms.RandomGrayscale(p=0.1)
    ])
data_transforms = transforms.Compose(
            [
                ToTensor(),
                Normalize(0.5, 0.5)
            ])
Imagenette = torchvision.datasets.Imagenette("./data/imagenette2/", download=True, transform=data_transforms)
#CIFAR_10_Dataset = torchvision.datasets.ImageNet("./data/imagenette/", train=True, download=False,)
sublist = list(range(0, 2, 1))
subset = torch.utils.data.Subset(Imagenette, sublist)
dataloader = torch.utils.data.DataLoader(subset, 1, shuffle=False, num_workers=2)

for image in dataloader:
    image = image[0]
    image = image.to(device)
    print(image.shape)
    print(image)
    # Extract activations from all FFN layers
    with torch.no_grad():
        # This depends on your VAE implementation
        # You may need to adjust based on your specific VAE interface
        #encoded_tokens = vae.encode(image_tensor)  # Adjust this line
        # Apply augmentations to ensure input is 16x16 as expected by VAE
        #image_aug = Aug(image)
        vae_out = vae.img_to_idxBl(image_aug)
        var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)    
        activations = extract_var_activations(var, var_in)
    # Extract activations from VAR
    activations = extract_var_activations(var, var_in)

    print("Extracted activations from layers:")
    for layer_name, activation in activations.items():
        print(f"Layer: {layer_name}, Activation shape: {activation.shape}")
        # Optionally, you can save or visualize the activations here
        # For example, convert to numpy and save as image if needed
        activation_image = activation.cpu().numpy()
        print(activation_image)
        # Save or visualize activation_image as needed
'''

'\ndatapath = \'./data\'\nimport torchvision.transforms as transforms\nimport torchvision\nimport torch\nfrom torchvision.transforms import ToTensor, Normalize\nimport scipy.io\nfrom tqdm import tqdm\nimport numpy as np\ns = 1\ncolor_jitter = transforms.ColorJitter(\n        0.9 * s, 0.9 * s, 0.9 * s, 0.1 * s)\nflip = transforms.RandomHorizontalFlip()\nAug = transforms.Compose(\n    [\n    transforms.RandomResizedCrop(size=32),\n    transforms.RandomVerticalFlip(p=0.5),\n    transforms.RandomApply([color_jitter], p=0.9),\n    transforms.RandomGrayscale(p=0.1)\n    ])\ndata_transforms = transforms.Compose(\n            [\n                ToTensor(),\n                Normalize(0.5, 0.5)\n            ])\nImagenette = torchvision.datasets.Imagenette("./data/imagenette2/", download=True, transform=data_transforms)\n#CIFAR_10_Dataset = torchvision.datasets.ImageNet("./data/imagenette/", train=True, download=False,)\nsublist = list(range(0, 2, 1))\nsubset = torch.utils.data.Subset(Imagenette,

In [22]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
# Create dataset
dataset = datasets.ImageFolder(
    root='./data/imagenette2/',
    transform=transform
)
# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1
)

### WITHOUT AUG

In [23]:
nsamples = 2
for i, image in enumerate(dataloader): 
    if i >= nsamples:
        break
    image = image[0]
    print(image.shape)
    image = image.to(device)
    with torch.no_grad():
        # This depends on your VAE implementation
        # You may need to adjust based on your specific VAE interface
        #encoded_tokens = vae.encode(image_tensor)  # Adjust this line
        vae_out = vae.img_to_idxBl(image)
        var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)    
    activations = extract_var_activations(var, var_in)

    '''print("Extracted activations from layers:")
    for layer_name, activation in activations.items():
        print(f"Layer: {layer_name}, Activation shape: {activation.shape}")
        # Optionally, you can save or visualize the activations here
        # For example, convert to numpy and save as image if needed
        activation_image = activation.cpu().numpy()
        print(activation_image)
        # Save or visualize activation_image as needed'''

torch.Size([1, 3, 256, 256])


  with torch.cuda.amp.autocast(enabled=False):


torch.Size([1, 3, 256, 256])


In [24]:
# Take a single image from the dataloader
for image, _ in dataloader:
    image = image.to(device)
    break  # just take the first image

# Forward pass through VAE to get tokens for VAR
with torch.no_grad():
    vae_out = vae.img_to_idxBl(image)
    var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)
    # Forward pass through VAR
    var_output = var(torch.tensor([1,]), var_in)
    print(var_output.shape)  # Check the shape of the output
    nimg = vae.embed_to_img(var_output, all_to_max_scale=True, last_one=False)
    print(nimg.shape)  # Check the shape of the reconstructed image
# To reconstruct the image from VAE tokens:
'''
with torch.no_grad():
    # vae.decode expects the quantized indices (vae_out)
    # If vae.decode expects a list, pass vae_out directly
    recon_image = vae.idxBl_to_img(var_output, same_shape=True)
    # recon_image shape: [batch, channels, height, width]
    recon_image = recon_image.clamp(0, 1).cpu()

import matplotlib.pyplot as plt
plt.imshow(recon_image[0].permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
'''

torch.Size([1, 680, 4096])


ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [] and output size of (16, 16). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

###  LETS NOW TRY IT WITH AUG 

In [25]:
import torchvision.transforms as transforms
import torchvision
s = 1
color_jitter = transforms.ColorJitter(
        0.9 * s, 0.9 * s, 0.9 * s, 0.1 * s)
flip = transforms.RandomHorizontalFlip()
Aug = transforms.Compose(
    [
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomApply([color_jitter], p=0.9),
    transforms.RandomGrayscale(p=0.1)
    ])

In [26]:

nsamples = 2
for i, image in enumerate(dataloader): 
    if i >= nsamples:
        break
    image = image[0]
    print(image.shape)
    image = image.to(device)
    with torch.no_grad():
        # This depends on your VAE implementation
        # You may need to adjust based on your specific VAE interface
        #encoded_tokens = vae.encode(image_tensor)  # Adjust this line
        vae_out = vae.img_to_idxBl(image)
        var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)    
    activations = extract_var_activations(var, var_in)

    print("Extracted activations from layers:")
    for layer_name, activation in activations.items():
        print(f"Layer: {layer_name}, Activation shape: {activation.shape}")
        # Optionally, you can save or visualize the activations here
        # For example, convert to numpy and save as image if needed
        activation_image = activation.cpu().numpy()
        print(activation_image)
        # Save or visualize activation_image as needed

torch.Size([1, 3, 256, 256])
Extracted activations from layers:
Layer: blocks.0.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-0.77200174 -2.0322385  -2.6662674  ... -0.84804887 -2.5241601
  -0.4672901 ]]
Layer: blocks.0.ffn.fc2, Activation shape: torch.Size([1, 1024])
[[-0.11943099  0.00771889 -0.28072923 ... -0.21622857  0.01911305
  -0.09246563]]
Layer: blocks.1.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-1.073757    0.13516061 -1.1871307  ... -0.36688873 -1.8417604
  -1.4245058 ]]
Layer: blocks.1.ffn.fc2, Activation shape: torch.Size([1, 1024])
[[-0.07402807 -0.11192346 -0.2510957  ...  0.00666183  0.10459945
  -0.04361888]]
Layer: blocks.2.ffn.fc1, Activation shape: torch.Size([1, 4096])
[[-0.775347   -1.2153834  -1.3380681  ... -0.6721959  -3.0689251
  -0.96198535]]
Layer: blocks.2.ffn.fc2, Activation shape: torch.Size([1, 1024])
[[ 0.00215805 -0.16506933 -0.11634263 ... -0.16431956  0.13905723
   0.00282809]]
Layer: blocks.3.ffn.fc1, Activation shape: torch.Size([1,

### ITERATE OVER DIFFERENT RANDOM AUGS

In [27]:
nsamples = 2
for i, image in enumerate(dataloader): 
    print(i)
    if i >= nsamples:
        break
    image = image[0]
    print(image.shape)
    image = image.to(device)
    all_activations = []
    for _ in range(10):  # Iterate over different random augmentations
        image = Aug(image)  # Apply random augmentation
        with torch.no_grad():
            # This depends on your VAE implementation
            # You may need to adjust based on your specific VAE interface
            #encoded_tokens = vae.encode(image_tensor)  # Adjust this line
            vae_out = vae.img_to_idxBl(image)
            var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)    
        activations = extract_var_activations(var, var_in)
        all_activations.append(activations)
    

0
torch.Size([1, 3, 256, 256])
1
torch.Size([1, 3, 256, 256])
2


##### Lets see how the activations look like

In [28]:
import numpy as np

In [29]:
for activation in all_activations:
    print("Extracted activations from layers:")
    print(activation.keys())  # Print layer names

    for layer_name, activation_tensor in activation.items():
        print(f"Layer: {layer_name}, Activation shape: {activation_tensor.shape}")

Extracted activations from layers:
dict_keys(['blocks.0.ffn.fc1', 'blocks.0.ffn.fc2', 'blocks.1.ffn.fc1', 'blocks.1.ffn.fc2', 'blocks.2.ffn.fc1', 'blocks.2.ffn.fc2', 'blocks.3.ffn.fc1', 'blocks.3.ffn.fc2', 'blocks.4.ffn.fc1', 'blocks.4.ffn.fc2', 'blocks.5.ffn.fc1', 'blocks.5.ffn.fc2', 'blocks.6.ffn.fc1', 'blocks.6.ffn.fc2', 'blocks.7.ffn.fc1', 'blocks.7.ffn.fc2', 'blocks.8.ffn.fc1', 'blocks.8.ffn.fc2', 'blocks.9.ffn.fc1', 'blocks.9.ffn.fc2', 'blocks.10.ffn.fc1', 'blocks.10.ffn.fc2', 'blocks.11.ffn.fc1', 'blocks.11.ffn.fc2', 'blocks.12.ffn.fc1', 'blocks.12.ffn.fc2', 'blocks.13.ffn.fc1', 'blocks.13.ffn.fc2', 'blocks.14.ffn.fc1', 'blocks.14.ffn.fc2', 'blocks.15.ffn.fc1', 'blocks.15.ffn.fc2', 'head'])
Layer: blocks.0.ffn.fc1, Activation shape: torch.Size([1, 4096])
Layer: blocks.0.ffn.fc2, Activation shape: torch.Size([1, 1024])
Layer: blocks.1.ffn.fc1, Activation shape: torch.Size([1, 4096])
Layer: blocks.1.ffn.fc2, Activation shape: torch.Size([1, 1024])
Layer: blocks.2.ffn.fc1, Activati

In [30]:
block_fc1_means = []
sum = torch.zeros_like(all_activations[0]['blocks.0.ffn.fc1'])
for activation in all_activations:
    tensor = activation['blocks.0.ffn.fc1']
    print(tensor)
    sum += tensor
print("Result:")
sum = sum/len(all_activations)  # Average across all activations
print(sum)

tensor([[-0.8707, -0.4671, -2.0508,  ..., -1.2212, -2.1438, -0.2104]])
tensor([[-0.9984, -1.0831, -1.9471,  ..., -1.2571, -2.1063, -0.3768]])
tensor([[-0.8748, -1.1327, -2.1649,  ..., -1.1381, -2.3447, -0.3765]])
tensor([[-0.7767, -0.7913, -2.2565,  ..., -1.1536, -2.2775, -0.3059]])
tensor([[-0.7865, -0.5853, -2.2865,  ..., -1.1081, -2.3486, -0.2026]])
tensor([[-0.7632, -1.8862, -2.3982,  ..., -1.0351, -2.4850, -0.0863]])
tensor([[-0.6865, -3.0539, -2.4657,  ..., -1.0316, -2.5141, -0.0775]])
tensor([[-0.7863, -3.0650, -2.0223,  ..., -1.2369, -2.3291, -0.0354]])
tensor([[-0.7993, -3.2522, -1.9859,  ..., -1.3331, -2.2886,  0.0123]])
tensor([[-0.8590, -3.4707, -1.9141,  ..., -1.3095, -2.2968, -0.0912]])
Result:
tensor([[-0.8201, -1.8788, -2.1492,  ..., -1.1824, -2.3135, -0.1750]])


In [31]:
# Compute the mean activation across all augmentations for each layer in all_activations
layer_means = {}
if all_activations:
    # Get all layer names from the first activation dict
    layer_names = all_activations[0].keys()
    for lname in layer_names:
        # Stack tensors for this layer across all augmentations
        stacked = torch.stack([a[lname] for a in all_activations], dim=0)
        # Compute mean across augmentations (dim=0)
        layer_means[lname] = stacked.mean(dim=0)
        print(f"Layer: {lname}, Mean activation shape: {layer_means[lname].shape}")

Layer: blocks.0.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.0.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.1.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.1.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.2.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.2.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.3.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.3.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.4.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.4.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.5.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.5.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.6.ffn.fc1, Mean activation shape: torch.Size([1, 4096])
Layer: blocks.6.ffn.fc2, Mean activation shape: torch.Size([1, 1024])
Layer: blocks.7.ffn.

In [32]:
print(layer_means['blocks.0.ffn.fc1'])

tensor([[-0.8201, -1.8788, -2.1492,  ..., -1.1824, -2.3135, -0.1750]])


### NOW ALSO ITERATE OVER DIFFERENT IMAGES

In [79]:
nsamples = 12
all_means = []
for i, image in enumerate(dataloader): 
    print(i)
    if i >= nsamples:
        break
    image = image[0]
    print(image.shape)
    image = image.to(device)
    this_imgs_activations = []
    for _ in range(10):  # Iterate over different random augmentations
        image = Aug(image)  # Apply random augmentation
        with torch.no_grad():
            # This depends on your VAE implementation
            # You may need to adjust based on your specific VAE interface
            #encoded_tokens = vae.encode(image_tensor)  # Adjust this line
            vae_out = vae.img_to_idxBl(image)
            var_in = var.vae_quant_proxy[0].idxBl_to_var_input(vae_out)    
        activations = extract_all_linear_activations(var, var_in)
        this_imgs_activations.append(activations)
    # Compute the mean activation across all augmentations for each layer in all_activations
    layer_means = {}
    if this_imgs_activations:
        # Get all layer names from the first activation dict
        layer_names = this_imgs_activations[0].keys()
        for lname in layer_names:
            # Stack tensors for this layer across all augmentations
            stacked = torch.stack([a[lname] for a in this_imgs_activations], dim=0)
            # Compute mean across augmentations (dim=0)
            layer_means[lname] = stacked.mean(dim=0)
            #print(f"Layer: {lname}, Mean activation shape: {layer_means[lname].shape}")
    all_means.append(layer_means)

0
torch.Size([1, 3, 256, 256])


  with torch.cuda.amp.autocast(enabled=False):


1
torch.Size([1, 3, 256, 256])
2
torch.Size([1, 3, 256, 256])
3
torch.Size([1, 3, 256, 256])
4
torch.Size([1, 3, 256, 256])
5
torch.Size([1, 3, 256, 256])
6
torch.Size([1, 3, 256, 256])
7
torch.Size([1, 3, 256, 256])
8
torch.Size([1, 3, 256, 256])
9
torch.Size([1, 3, 256, 256])
10
torch.Size([1, 3, 256, 256])
11
torch.Size([1, 3, 256, 256])
12


##### Lets try to take the max for each unit across all_means

In [80]:
# Get the name of the very first layer
first_layer = list(all_means[0].keys())[0]

# Stack all mean activations for the first layer across images
first_layer_means = torch.stack([mean[first_layer] for mean in all_means], dim=0)

# Compute the maximum for each unit across all images (dim=0)
max_per_unit = first_layer_means.max(dim=0).values

print(f"Maximum per unit in the first layer ({first_layer}):")
print(max_per_unit)

Maximum per unit in the first layer (word_embed):
tensor([[-0.0202, -0.0096,  0.0225,  ...,  0.1172,  0.0011,  0.0291]])


In [81]:
# Get the name of the very first layer
first_layer = list(all_means[0].keys())[0]

# Stack all mean activations for the first layer across images
first_layer_means = torch.stack([mean[first_layer] for mean in all_means], dim=0)

# Compute the maximum for each unit across all images (dim=0)
max_per_unit = first_layer_means.max(dim=0).values

print(f"Maximum per unit in the first layer ({first_layer}):")
print(max_per_unit)

Maximum per unit in the first layer (word_embed):
tensor([[-0.0202, -0.0096,  0.0225,  ...,  0.1172,  0.0011,  0.0291]])


In [82]:
# For each layer, stack all mean activations across images and compute the maximum per unit
max_per_unit_all_layers = {}

# Assume all_means is a list of dicts, each dict: layer_name -> mean activation tensor
if all_means:
    layer_names = all_means[0].keys()
    for lname in layer_names:
        # Stack all mean activations for this layer across images
        stacked_means = torch.stack([mean[lname] for mean in all_means], dim=0)
        # Compute the maximum for each unit across all images (dim=0)
        max_per_unit_all_layers[lname] = stacked_means.max(dim=0).values
        print(f"Maximum per unit in layer {lname}:")
        print(max_per_unit_all_layers[lname])

Maximum per unit in layer word_embed:
tensor([[-0.0202, -0.0096,  0.0225,  ...,  0.1172,  0.0011,  0.0291]])
Maximum per unit in layer blocks.0.ada_lin.1:
tensor([[ 0.0460, -0.0281, -0.3529,  ...,  0.0516,  0.0199, -0.2571]])
Maximum per unit in layer blocks.0.attn.proj:
tensor([[ 0.1944, -0.1951,  0.2203,  ..., -0.2221,  0.2507,  0.1262]])
Maximum per unit in layer blocks.0.ffn.fc1:
tensor([[-0.8780, -1.7730, -1.7301,  ..., -0.9808, -1.9117, -0.0549]])
Maximum per unit in layer blocks.0.ffn.fc2:
tensor([[ 0.2409, -0.0698,  0.5170,  ..., -0.0554,  0.0205,  0.0233]])
Maximum per unit in layer blocks.1.ada_lin.1:
tensor([[ 0.1872,  0.1702, -0.3267,  ..., -0.1727, -0.0815, -0.1083]])
Maximum per unit in layer blocks.1.attn.proj:
tensor([[-0.0521,  0.0197,  0.2376,  ...,  0.2205,  0.0357,  0.1480]])
Maximum per unit in layer blocks.1.ffn.fc1:
tensor([[-0.7941,  0.2659, -1.2367,  ..., -0.5028, -2.3751, -1.5170]])
Maximum per unit in layer blocks.1.ffn.fc2:
tensor([[-0.0456, -0.0974,  0.5530

In [83]:
for means in all_means:

    # Get the name of the very first layer
    first_layer = list(means.keys())[0]

    # Stack all mean activations for the first layer across images
    print(means[first_layer].shape)
    print(means[first_layer][0][0])
    

torch.Size([1, 1024])
tensor(-0.0300)
torch.Size([1, 1024])
tensor(-0.0202)
torch.Size([1, 1024])
tensor(-0.0273)
torch.Size([1, 1024])
tensor(-0.0324)
torch.Size([1, 1024])
tensor(-0.0258)
torch.Size([1, 1024])
tensor(-0.0219)
torch.Size([1, 1024])
tensor(-0.0278)
torch.Size([1, 1024])
tensor(-0.0277)
torch.Size([1, 1024])
tensor(-0.0243)
torch.Size([1, 1024])
tensor(-0.0258)
torch.Size([1, 1024])
tensor(-0.0206)
torch.Size([1, 1024])
tensor(-0.0252)


In [84]:
# For each layer, stack all mean activations across images and compute the maximum per unit
max_per_unit_all_layers = {}

# Assume all_means is a list of dicts, each dict: layer_name -> mean activation tensor
if all_means:
    layer_names = all_means[0].keys()
    for lname in layer_names:
        # Stack all mean activations for this layer across images
        stacked_means = torch.stack([mean[lname] for mean in all_means], dim=0)
        # Compute the maximum for each unit across all images (dim=0)
        stacked_means = stacked_means.abs()
        max_per_unit_all_layers[lname] = stacked_means.max(dim=0).values
        print(f"Maximum per unit in layer {lname}:")
        print(max_per_unit_all_layers[lname])

Maximum per unit in layer word_embed:
tensor([[0.0324, 0.0155, 0.0225,  ..., 0.1172, 0.0043, 0.0291]])
Maximum per unit in layer blocks.0.ada_lin.1:
tensor([[0.0460, 0.0314, 0.3555,  ..., 0.0516, 0.0199, 0.2588]])
Maximum per unit in layer blocks.0.attn.proj:
tensor([[0.1944, 0.3769, 0.2203,  ..., 0.4112, 0.2507, 0.1262]])
Maximum per unit in layer blocks.0.ffn.fc1:
tensor([[1.0460, 2.4901, 2.2251,  ..., 1.5721, 2.4204, 0.2974]])
Maximum per unit in layer blocks.0.ffn.fc2:
tensor([[0.2409, 0.1868, 0.5170,  ..., 0.2783, 0.3611, 0.0682]])
Maximum per unit in layer blocks.1.ada_lin.1:
tensor([[0.1872, 0.1702, 0.3296,  ..., 0.1782, 0.1196, 0.1190]])
Maximum per unit in layer blocks.1.attn.proj:
tensor([[0.1438, 0.0723, 0.2376,  ..., 0.2205, 0.0850, 0.1480]])
Maximum per unit in layer blocks.1.ffn.fc1:
tensor([[1.7746, 0.2659, 1.9885,  ..., 1.7509, 3.3767, 2.5543]])
Maximum per unit in layer blocks.1.ffn.fc2:
tensor([[0.1427, 0.2680, 0.5530,  ..., 0.3982, 0.5664, 0.0881]])
Maximum per unit 

In [85]:
#For each layer, stack all mean activations across images and compute the maximum and argmax per unit
max_per_unit_all_layers = {}
argmax_per_unit_all_layers = {}

if all_means:
    layer_names = all_means[0].keys()
    for lname in layer_names:
        # Stack all mean activations for this layer across images
        stacked_means = torch.stack([mean[lname] for mean in all_means], dim=0)
        # Compute the maximum and argmax for each unit across all images (dim=0)
        stacked_means = stacked_means.abs()
        max_per_unit_all_layers[lname] = stacked_means.max(dim=0).values
        argmax_per_unit_all_layers[lname] = stacked_means.argmax(dim=0)
        print(f"Layer: {lname}")
        print(f"Maximum per unit: {max_per_unit_all_layers[lname]}")
        print(f"Argmax per unit: {argmax_per_unit_all_layers[lname]}")

Layer: word_embed
Maximum per unit: tensor([[0.0324, 0.0155, 0.0225,  ..., 0.1172, 0.0043, 0.0291]])
Argmax per unit: tensor([[ 3,  3, 10,  ...,  1,  1,  8]])
Layer: blocks.0.ada_lin.1
Maximum per unit: tensor([[0.0460, 0.0314, 0.3555,  ..., 0.0516, 0.0199, 0.2588]])
Argmax per unit: tensor([[1, 1, 3,  ..., 3, 3, 3]])
Layer: blocks.0.attn.proj
Maximum per unit: tensor([[0.1944, 0.3769, 0.2203,  ..., 0.4112, 0.2507, 0.1262]])
Argmax per unit: tensor([[ 3, 11,  4,  ...,  1,  3,  5]])
Layer: blocks.0.ffn.fc1
Maximum per unit: tensor([[1.0460, 2.4901, 2.2251,  ..., 1.5721, 2.4204, 0.2974]])
Argmax per unit: tensor([[ 3,  8,  8,  ...,  4, 11,  8]])
Layer: blocks.0.ffn.fc2
Maximum per unit: tensor([[0.2409, 0.1868, 0.5170,  ..., 0.2783, 0.3611, 0.0682]])
Argmax per unit: tensor([[6, 4, 4,  ..., 8, 4, 4]])
Layer: blocks.1.ada_lin.1
Maximum per unit: tensor([[0.1872, 0.1702, 0.3296,  ..., 0.1782, 0.1196, 0.1190]])
Argmax per unit: tensor([[3, 3, 1,  ..., 1, 1, 1]])
Layer: blocks.1.attn.proj
Ma

In [86]:
# For each layer, compute the mean activation across all images except the one at argmax for each unit
mean_minus_max_per_unit_all_layers = {}

if all_means:
    layer_names = all_means[0].keys()
    n_imgs = len(all_means)
    for lname in layer_names:
        # Stack all mean activations for this layer across images: shape [n_imgs, 1, n_units]
        stacked_means = torch.stack([mean[lname] for mean in all_means], dim=0)
        # Find argmax index for each unit
        stacked_means = stacked_means.abs()
        argmax_idx = stacked_means.argmax(dim=0)  # shape [1, n_units]
        # For each unit, exclude the image at argmax and compute mean over the rest
        means = []
        for unit in range(stacked_means.shape[2]):
            # Get indices of all images except the one at argmax for this unit
            mask = torch.ones(n_imgs, dtype=torch.bool)
            mask[argmax_idx[0, unit]] = False
            # Compute mean over the remaining images for this unit
            mean_val = stacked_means[mask, 0, unit].mean()
            means.append(mean_val)
        mean_minus_max_per_unit_all_layers[lname] = torch.stack(means)
        print(f"Layer: {lname}, µ-max per unit shape: {mean_minus_max_per_unit_all_layers[lname]}")

Layer: word_embed, µ-max per unit shape: tensor([0.0252, 0.0134, 0.0067,  ..., 0.0978, 0.0010, 0.0268])
Layer: blocks.0.ada_lin.1, µ-max per unit shape: tensor([0.0422, 0.0298, 0.3540,  ..., 0.0442, 0.0189, 0.2578])
Layer: blocks.0.attn.proj, µ-max per unit shape: tensor([0.1170, 0.2848, 0.1034,  ..., 0.3321, 0.1761, 0.0251])
Layer: blocks.0.ffn.fc1, µ-max per unit shape: tensor([0.9642, 2.1175, 1.9599,  ..., 1.2164, 2.1719, 0.1451])
Layer: blocks.0.ffn.fc2, µ-max per unit shape: tensor([0.1503, 0.1039, 0.1953,  ..., 0.1001, 0.1713, 0.0280])
Layer: blocks.1.ada_lin.1, µ-max per unit shape: tensor([0.1866, 0.1672, 0.3281,  ..., 0.1755, 0.1005, 0.1137])
Layer: blocks.1.attn.proj, µ-max per unit shape: tensor([0.0750, 0.0418, 0.1707,  ..., 0.1340, 0.0252, 0.1013])
Layer: blocks.1.ffn.fc1, µ-max per unit shape: tensor([1.3273, 0.1490, 1.6628,  ..., 0.8801, 2.9533, 2.0469])
Layer: blocks.1.ffn.fc2, µ-max per unit shape: tensor([0.0732, 0.1980, 0.2299,  ..., 0.2863, 0.4124, 0.0331])
Layer: b

In [87]:
# Compute the difference between max_per_unit_all_layers and mean_minus_max_per_unit_all_layers for each layer
diff_max_minus_mean_all_layers = {}

for lname in max_per_unit_all_layers.keys():
    diff_max_minus_mean_all_layers[lname] = max_per_unit_all_layers[lname] - mean_minus_max_per_unit_all_layers[lname]
    print(f"Layer: {lname}, Difference shape: {diff_max_minus_mean_all_layers[lname].shape}")
    print(diff_max_minus_mean_all_layers[lname])

Layer: word_embed, Difference shape: torch.Size([1, 1024])
tensor([[0.0072, 0.0021, 0.0158,  ..., 0.0195, 0.0033, 0.0023]])
Layer: blocks.0.ada_lin.1, Difference shape: torch.Size([1, 6144])
tensor([[0.0038, 0.0017, 0.0016,  ..., 0.0075, 0.0010, 0.0010]])
Layer: blocks.0.attn.proj, Difference shape: torch.Size([1, 1024])
tensor([[0.0775, 0.0922, 0.1168,  ..., 0.0791, 0.0746, 0.1010]])
Layer: blocks.0.ffn.fc1, Difference shape: torch.Size([1, 4096])
tensor([[0.0818, 0.3726, 0.2652,  ..., 0.3557, 0.2485, 0.1523]])
Layer: blocks.0.ffn.fc2, Difference shape: torch.Size([1, 1024])
tensor([[0.0906, 0.0829, 0.3218,  ..., 0.1781, 0.1898, 0.0402]])
Layer: blocks.1.ada_lin.1, Difference shape: torch.Size([1, 6144])
tensor([[0.0006, 0.0031, 0.0014,  ..., 0.0027, 0.0190, 0.0053]])
Layer: blocks.1.attn.proj, Difference shape: torch.Size([1, 1024])
tensor([[0.0688, 0.0306, 0.0669,  ..., 0.0865, 0.0598, 0.0467]])
Layer: blocks.1.ffn.fc1, Difference shape: torch.Size([1, 4096])
tensor([[0.4473, 0.1169

In [88]:
# Compute the normalized difference: (max - mean) / (max + mean) for each layer
normalized_diff_all_layers = {}

for lname in diff_max_minus_mean_all_layers.keys():
    max_vals = max_per_unit_all_layers[lname]
    mean_vals = mean_minus_max_per_unit_all_layers[lname]
    denom = max_vals + mean_vals
    # Avoid division by zero
    normalized_diff = diff_max_minus_mean_all_layers[lname] / (denom + 1e-8)
    normalized_diff_all_layers[lname] = normalized_diff
    print(f"Layer: {lname}, Normalized difference shape: {normalized_diff.shape}")
    print(normalized_diff)

Layer: word_embed, Normalized difference shape: torch.Size([1, 1024])
tensor([[0.1257, 0.0727, 0.5409,  ..., 0.0905, 0.6283, 0.0415]])
Layer: blocks.0.ada_lin.1, Normalized difference shape: torch.Size([1, 6144])
tensor([[0.0431, 0.0274, 0.0022,  ..., 0.0779, 0.0248, 0.0020]])
Layer: blocks.0.attn.proj, Normalized difference shape: torch.Size([1, 1024])
tensor([[0.2487, 0.1393, 0.3609,  ..., 0.1065, 0.1748, 0.6679]])
Layer: blocks.0.ffn.fc1, Normalized difference shape: torch.Size([1, 4096])
tensor([[0.0407, 0.0809, 0.0634,  ..., 0.1276, 0.0541, 0.3443]])
Layer: blocks.0.ffn.fc2, Normalized difference shape: torch.Size([1, 1024])
tensor([[0.2316, 0.2851, 0.4517,  ..., 0.4708, 0.3565, 0.4179]])
Layer: blocks.1.ada_lin.1, Normalized difference shape: torch.Size([1, 6144])
tensor([[0.0017, 0.0090, 0.0021,  ..., 0.0077, 0.0865, 0.0229]])
Layer: blocks.1.attn.proj, Normalized difference shape: torch.Size([1, 1024])
tensor([[0.3142, 0.2681, 0.1639,  ..., 0.2439, 0.5428, 0.1874]])
Layer: bloc

In [89]:
# Flatten all normalized differences across all layers into a single tensor
all_norm_diffs = []
unit_layer_map = []

for lname, norm_diff in normalized_diff_all_layers.items():
    # norm_diff shape: [1, n_units]
    n_units = norm_diff.shape[1]
    for unit_idx in range(n_units):
        all_norm_diffs.append(norm_diff[0, unit_idx].item())
        unit_layer_map.append((lname, unit_idx))

all_norm_diffs = torch.tensor(all_norm_diffs)

# Compute the number of top units (10%)
topk_percent = 0.1
k = max(1, int(len(all_norm_diffs) * topk_percent))

# Get the indices of the top 10% highest values
topk_vals, topk_indices = torch.topk(all_norm_diffs, k)

# Map indices back to (layer, unit) pairs
topk_units_overall = [unit_layer_map[idx] for idx in topk_indices.tolist()]

print(f"Top {k} units overall (layer, unit): {topk_units_overall}")

Top 20377 units overall (layer, unit): [('blocks.2.ffn.fc2', 791), ('blocks.6.ffn.fc2', 979), ('blocks.8.attn.proj', 1), ('blocks.2.attn.proj', 512), ('blocks.2.ffn.fc2', 34), ('blocks.9.ffn.fc2', 369), ('blocks.10.attn.proj', 102), ('blocks.9.ffn.fc2', 521), ('blocks.5.ffn.fc2', 263), ('blocks.12.ffn.fc1', 4024), ('blocks.13.attn.proj', 752), ('blocks.6.ffn.fc2', 318), ('blocks.3.ffn.fc2', 296), ('blocks.5.attn.proj', 729), ('blocks.6.attn.proj', 361), ('blocks.2.ffn.fc2', 303), ('blocks.10.ffn.fc1', 3378), ('blocks.6.ffn.fc1', 4080), ('blocks.10.attn.proj', 82), ('blocks.12.attn.proj', 767), ('blocks.3.attn.proj', 339), ('blocks.3.ffn.fc2', 219), ('blocks.6.attn.proj', 471), ('blocks.0.attn.proj', 366), ('blocks.6.ffn.fc2', 937), ('blocks.0.ffn.fc2', 378), ('blocks.3.ffn.fc1', 453), ('blocks.10.attn.proj', 980), ('blocks.0.attn.proj', 650), ('blocks.7.ffn.fc1', 2680), ('blocks.3.attn.proj', 676), ('blocks.9.ffn.fc2', 448), ('blocks.4.ffn.fc1', 415), ('blocks.8.ffn.fc2', 764), ('block

In [90]:
def display_topk_units_nicely(topk_units_overall, topk_vals, k=10):
    """
    Display the coordinates and act'ivation values of the top-k units in a readable table format.

    Args:
        topk_units_overall (list): List of (layer_name, unit_idx) tuples.
        topk_vals (torch.Tensor): Activation values for the top units.
        k (int): Number of top units to display.
    """
    print(f"{'Rank':<5} {'Layer':<20} {'Unit Index':<10} {'Activation':<10}")
    print("-" * 55)
    for i, ((layer, unit_idx), val) in enumerate(zip(topk_units_overall[:k], topk_vals[:k]), 1):
        print(f"{i:<5} {layer:<20} {unit_idx:<10} {val.item():<10.4f}")

# Example usage: show top 20 units
display_topk_units_nicely(topk_units_overall, topk_vals, k=20000)

Rank  Layer                Unit Index Activation
-------------------------------------------------------
1     blocks.2.ffn.fc2     791        0.8903    
2     blocks.6.ffn.fc2     979        0.8364    
3     blocks.8.attn.proj   1          0.8350    
4     blocks.2.attn.proj   512        0.8319    
5     blocks.2.ffn.fc2     34         0.8259    
6     blocks.9.ffn.fc2     369        0.8224    
7     blocks.10.attn.proj  102        0.8054    
8     blocks.9.ffn.fc2     521        0.8027    
9     blocks.5.ffn.fc2     263        0.7968    
10    blocks.12.ffn.fc1    4024       0.7948    
11    blocks.13.attn.proj  752        0.7919    
12    blocks.6.ffn.fc2     318        0.7892    
13    blocks.3.ffn.fc2     296        0.7867    
14    blocks.5.attn.proj   729        0.7849    
15    blocks.6.attn.proj   361        0.7824    
16    blocks.2.ffn.fc2     303        0.7818    
17    blocks.10.ffn.fc1    3378       0.7816    
18    blocks.6.ffn.fc1     4080       0.7810    
19    blocks.

In [91]:
# Choose the 10% highest activating units overall (across all layers)
all_units = []
unit_layer_map = []

for lname, norm_diff in normalized_diff_all_layers.items():
    # norm_diff shape: [1, n_units]
    n_units = norm_diff.shape[1]
    for unit_idx in range(n_units):
        all_units.append(norm_diff[0, unit_idx].item())
        unit_layer_map.append((lname, unit_idx))

all_units = torch.tensor(all_units)
k = max(1, int(len(all_units) * topk_percent))
topk_vals, topk_indices = torch.topk(all_units, k)

# Map back to (layer, unit) pairs
# Exclude units from the 'head' layer
topk_units_overall = [unit_layer_map[idx] for idx in topk_indices.tolist() if unit_layer_map[idx][0] != 'head']
print(f"Top {k} units overall (layer, unit) excluding 'head': {topk_units_overall}")

Top 20377 units overall (layer, unit) excluding 'head': [('blocks.2.ffn.fc2', 791), ('blocks.6.ffn.fc2', 979), ('blocks.8.attn.proj', 1), ('blocks.2.attn.proj', 512), ('blocks.2.ffn.fc2', 34), ('blocks.9.ffn.fc2', 369), ('blocks.10.attn.proj', 102), ('blocks.9.ffn.fc2', 521), ('blocks.5.ffn.fc2', 263), ('blocks.12.ffn.fc1', 4024), ('blocks.13.attn.proj', 752), ('blocks.6.ffn.fc2', 318), ('blocks.3.ffn.fc2', 296), ('blocks.5.attn.proj', 729), ('blocks.6.attn.proj', 361), ('blocks.2.ffn.fc2', 303), ('blocks.10.ffn.fc1', 3378), ('blocks.6.ffn.fc1', 4080), ('blocks.10.attn.proj', 82), ('blocks.12.attn.proj', 767), ('blocks.3.attn.proj', 339), ('blocks.3.ffn.fc2', 219), ('blocks.6.attn.proj', 471), ('blocks.0.attn.proj', 366), ('blocks.6.ffn.fc2', 937), ('blocks.0.ffn.fc2', 378), ('blocks.3.ffn.fc1', 453), ('blocks.10.attn.proj', 980), ('blocks.0.attn.proj', 650), ('blocks.7.ffn.fc1', 2680), ('blocks.3.attn.proj', 676), ('blocks.9.ffn.fc2', 448), ('blocks.4.ffn.fc1', 415), ('blocks.8.ffn.fc

In [92]:
def display_topk_units_nicely(topk_units_overall, topk_vals, k=20):
    """
    Display the coordinates and activation values of the top-k units in a readable table format.

    Args:
        topk_units_overall (list): List of (layer_name, unit_idx) tuples.
        topk_vals (torch.Tensor): Activation values for the top units.
        k (int): Number of top units to display.
    """
    print(f"{'Rank':<5} {'Layer':<25} {'Unit Index':<10} {'Activation':<12}")
    print("-" * 60)
    for i, ((layer, unit_idx), val) in enumerate(zip(topk_units_overall[:k], topk_vals[:k]), 1):
        print(f"{i:<5} {layer:<25} {unit_idx:<10} {val.item():<12.4f}")

# Display the top 10% highest activating neurons
k = len(topk_units_overall)
display_topk_units_nicely(topk_units_overall, topk_vals, k=min(20, k))  # Show top 20 for readability

Rank  Layer                     Unit Index Activation  
------------------------------------------------------------
1     blocks.2.ffn.fc2          791        0.8903      
2     blocks.6.ffn.fc2          979        0.8364      
3     blocks.8.attn.proj        1          0.8350      
4     blocks.2.attn.proj        512        0.8319      
5     blocks.2.ffn.fc2          34         0.8259      
6     blocks.9.ffn.fc2          369        0.8224      
7     blocks.10.attn.proj       102        0.8054      
8     blocks.9.ffn.fc2          521        0.8027      
9     blocks.5.ffn.fc2          263        0.7968      
10    blocks.12.ffn.fc1         4024       0.7948      
11    blocks.13.attn.proj       752        0.7919      
12    blocks.6.ffn.fc2          318        0.7892      
13    blocks.3.ffn.fc2          296        0.7867      
14    blocks.5.attn.proj        729        0.7849      
15    blocks.6.attn.proj        361        0.7824      
16    blocks.2.ffn.fc2          303        

In [93]:
import pandas as pd

# Create a DataFrame showing the distribution of top-k units across layers, ordered by layer position

# Extract just the layer names from topk_units_overall (ignore unit index)
topk_layer_names = [layer for layer, unit in topk_units_overall]

# Count occurrences of each layer in the top-k units
layer_counts = pd.Series(topk_layer_names).value_counts().rename_axis('Layer').reset_index(name='TopK_Count')

# To order by layer position, extract the numeric index from the layer name (e.g., 'blocks.3.ffn.fc1' -> 3)
def layer_sort_key(layer_name):
    if layer_name.startswith('blocks.'):
        # Extract the block number
        parts = layer_name.split('.')
        try:
            block_num = int(parts[1])
        except Exception:
            block_num = -1
        # Optionally, sort by sublayer type as well
        sublayer_priority = {'ffn.fc1': 0, 'ffn.fc2': 1, 'attn.proj': 2, 'ada_lin.1': 3}
        sublayer = '.'.join(parts[2:]) if len(parts) > 2 else ''
        sublayer_idx = sublayer_priority.get(sublayer, 99)
        return (block_num, sublayer_idx)
    elif layer_name == 'word_embed':
        return (-2, 0)
    elif layer_name == 'head':
        return (999, 0)
    else:
        return (998, 0)

layer_counts['sort_key'] = layer_counts['Layer'].apply(layer_sort_key)
layer_counts = layer_counts.sort_values('sort_key').drop(columns='sort_key').reset_index(drop=True)

# Show the full DataFrame without truncation
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', None):
    display(layer_counts)


Unnamed: 0,Layer,TopK_Count
0,word_embed,179
1,blocks.0.ffn.fc1,333
2,blocks.0.ffn.fc2,304
3,blocks.0.attn.proj,274
4,blocks.0.ada_lin.1,31
5,blocks.1.ffn.fc1,478
6,blocks.1.ffn.fc2,324
7,blocks.1.attn.proj,315
8,blocks.1.ada_lin.1,33
9,blocks.2.ffn.fc1,393


In [94]:
# Export the layer_counts DataFrame to CSV
layer_counts.to_csv('topk_layer_distribution.csv', index=False)
print("Exported topk layer distribution to 'topk_layer_distribution.csv'")

Exported topk layer distribution to 'topk_layer_distribution.csv'
