In [8]:
import torch
import torch.nn.functional as F
import numpy as np
from models.vit_qa import ViT

from PIL import Image
from torchvision import transforms


model = ViT(img_size=32,patch_size=4,num_classes=100,dim=192,
           mlp_dim_ratio=2,depth=9,heads=12,dim_head=192//12,stochastic_depth=0.1).cuda()

model.eval() 


image_path = "car.jpg"
image = Image.open(image_path)

transform = transforms.Compose([
    transforms.Resize((32, 32)),  
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

x = transform(image).unsqueeze(0).cuda()
print(x.shape)
blocks = model.transformer.layers 
print(blocks[0][0].fn)



torch.Size([1, 3, 32, 32])
Attention(
  (attend): Softmax(dim=-1)
  (to_qkv): Linear(in_features=192, out_features=576, bias=False)
  (to_out): Sequential(
    (0): Linear(in_features=192, out_features=192, bias=True)
    (1): Dropout(p=0.0, inplace=False)
  )
)


In [9]:
import torch
import torchvision.models as models


checkpoint = torch.load("best-vit.pth")

model_weights = checkpoint["model_state_dict"]

model_weights = {k.replace("module.", ""): v for k, v in model_weights.items()}

model.load_state_dict(model_weights, strict=False)

<All keys matched successfully>

In [10]:
def compute_attention_entropy(attn_weights):

    attn_weights = attn_weights.mean(dim=1)  
    entropy = -torch.sum(attn_weights * torch.log(attn_weights + 1e-6), dim=-1)
    return entropy.mean().item()  

In [11]:

attention_entropies = []


with torch.no_grad():
    tokens = model.to_patch_embedding(x) 
    print(f"Patch Embedding shape: {tokens.shape}")  

    for i, block in enumerate(blocks):
   
        attention_layer = block[0].fn  
        norm_layer = block[0].norm  
        tokens = norm_layer(tokens)  
        qkv = attention_layer.to_qkv(tokens).chunk(3, dim=-1)  
        q, k, v = qkv

        
        attn_weights = (q @ k.transpose(-2, -1)) / q.shape[-1]**0.5  
        attn_weights = attn_weights.softmax(dim=-1)  

       
        entropy = compute_attention_entropy(attn_weights)
        attention_entropies.append(entropy)

print("attention_entropies:", attention_entropies)

Patch Embedding shape: torch.Size([1, 64, 192])
attention_entropies: [4.099283218383789, 4.120974540710449, 4.093417167663574, 4.044940948486328, 4.042279243469238, 3.7835772037506104, 3.8343920707702637, 4.019201278686523, 4.102091312408447]


In [12]:
gradient_norms=[]
output = model(x)
loss = output.norm() 
loss.backward()  

for i, block in enumerate(blocks):
    grad_norm = sum(p.grad.norm().item() for p in block.parameters() if p.grad is not None)
    gradient_norms.append(grad_norm)
print("gradient_norms:",gradient_norms)


gradient_norms: [73.81246173381805, 39.30472505092621, 34.22625553607941, 34.91710036993027, 34.91826021671295, 38.420780420303345, 38.29824906587601, 42.031699538230896, 30.326500833034515]


In [13]:
k = 1.0  
mean_entropy, std_entropy = np.mean(attention_entropies), np.std(attention_entropies)
mean_grad, std_grad = np.mean(gradient_norms), np.std(gradient_norms)

T_entropy = mean_entropy - k * std_entropy
T_grad = mean_grad - k * std_grad
print("T_entropy:",T_entropy)
print("T_grad:",T_grad)

prune_blocks = [i for i in range(len(blocks)) if (attention_entropies[i] < T_entropy or gradient_norms[i] < T_grad)]
prune_mask = (attention_entropies < T_entropy) | (gradient_norms < T_grad)
print(prune_mask)
print(f"prune: entropy < {T_entropy:.6f} or grad < {T_grad:.6f}")
print(f"prune block: {prune_blocks}")

T_entropy: 3.9001513263548473
T_grad: 28.55012187405655
[False False False False False  True  True False False]
prune: entropy < 3.900151 or grad < 28.550122
prune block: [5, 6]
