In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from models.t2t import T2T_ViT

from PIL import Image
from torchvision import transforms


model = T2T_ViT(img_size=32, num_classes=100,num_heads=4,depth=12, drop_path_rate=0.1).cuda()
model.eval()  



image_path = "car.jpg"
image = Image.open(image_path)

transform = transforms.Compose([
    transforms.Resize((32, 32)), 
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

x = transform(image).unsqueeze(0).cuda()
print(x.shape)


  from .autonotebook import tqdm as notebook_tqdm


adopt transformer encoder for tokens-to-token
torch.Size([1, 3, 32, 32])


In [2]:
import torch
import torchvision.models as models


checkpoint = torch.load("best-t2t.pth")

model_weights = checkpoint["model_state_dict"]

model_weights = {k.replace("module.", ""): v for k, v in model_weights.items()}

model.load_state_dict(model_weights, strict=False)

_IncompatibleKeys(missing_keys=['query_head.decoder_blocks.0.self_attn.in_proj_weight', 'query_head.decoder_blocks.0.self_attn.in_proj_bias', 'query_head.decoder_blocks.0.self_attn.out_proj.weight', 'query_head.decoder_blocks.0.self_attn.out_proj.bias'], unexpected_keys=['query_head.norm.weight', 'query_head.norm.bias', 'query_head.decoder_blocks.0.fc1.weight', 'query_head.decoder_blocks.0.fc1.bias', 'query_head.decoder_blocks.0.fc2.weight', 'query_head.decoder_blocks.0.fc2.bias'])

In [3]:

def compute_entropy(tensor):
   
    tensor = F.softmax(tensor, dim=-1) 
    entropy = -torch.sum(tensor * torch.log(tensor + 1e-9), dim=-1)  
    return entropy.mean().item()  

In [4]:
def hook_fn(module, input, output, entropies):
   

    B, N, C = input[0].shape  # B: batch size, N: number of tokens, C: channels
    try:
        # Adjust the way we handle the qkv computation
        qkv = module.qkv(input[0])  # Directly compute QKV without reshaping
        #print(f"QKV shape: {qkv.shape}")  
        
        # Split the qkv tensor directly
        q, k, v = qkv.chunk(3, dim=-1)
        #print(f"Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}") 

       
        attn_matrix = (q @ k.transpose(-2, -1)) / (C // 3) ** 0.5
        attn_matrix = attn_matrix.softmax(dim=-1) 

       
        entropy = compute_entropy(attn_matrix)
        
       
        if not isinstance(entropy, torch.Tensor):
            entropy = torch.tensor(entropy, dtype=torch.float32)

        entropies.append(entropy.detach().cpu().numpy()) 
    except Exception as e:
        print(f"QKV computation error: {e}")  

In [5]:
# Hook function used to gather attention entropy and gradients
def compute_block_metrics(model, input_tensor, criterion, target):
    entropies, gradient_norms = [], []

    hooks = []
    for name, module in model.named_modules():
        #if isinstance(module,Attention)and 'blocks' in name:
        if "blocks" in name and type(module).__name__ == "Attention":
            hooks.append(module.register_forward_hook(lambda module, input, output: hook_fn(module, input, output, entropies)))

    
    output = model(input_tensor)  

   
    loss = criterion(output, target)
    loss.backward()

   
    for block in model.blocks:  
        grad_norm = sum(torch.norm(param.grad).item() for param in block.parameters() if param.grad is not None)
        gradient_norms.append(grad_norm)

    
    for hook in hooks:
        hook.remove()

    return np.array(entropies), np.array(gradient_norms)

In [6]:
def prune_t2t(model, input_tensor, criterion, target, k=1.0):
   
    attention_entropies, gradient_norms = compute_block_metrics(model, input_tensor, criterion, target)
    print("Attention Entropies:", attention_entropies)
    print("Gradient Norms:", gradient_norms)

    
    mean_entropy, std_entropy = np.mean(attention_entropies), np.std(attention_entropies)
    mean_grad, std_grad = np.mean(gradient_norms), np.std(gradient_norms)

    T_entropy = mean_entropy - k * std_entropy
    T_grad = mean_grad - k * std_grad
    print("T_entropy:",T_entropy)
    print("T_grad:",T_grad)

   
    prune_mask = (attention_entropies < T_entropy) | (gradient_norms < T_grad)
    print(prune_mask)

    
    block_idx = 0
    for layer_idx, layer in enumerate(model.blocks):  
        if prune_mask[block_idx]: 
           
            model.blocks[layer_idx] = torch.nn.Identity()  
            print(f"Pruning Block {block_idx} in Layer {layer_idx}")  
        block_idx += 1

    return model

In [7]:

from utils.losses import LabelSmoothingCrossEntropy
target = torch.tensor([19], dtype=torch.long).cuda() 
criterion = LabelSmoothingCrossEntropy()



pruned_model = prune_t2t(model, x, criterion, target, k=1.0)
print(pruned_model)

Attention Entropies: [4.172654  4.171084  4.1708627 4.1707306 4.1712217 4.1710362 4.171399
 4.172064  4.1728396 4.1726694 4.17182   4.171971 ]
Gradient Norms: [18.97220428 13.66780362 10.21212132  8.30563916  7.57110652  5.90074862
  5.48550645  4.84765517  4.62953659  4.09240712  3.70700439  3.86960079]
T_entropy: 4.17098034278024
T_grad: 3.1460474519318584
[False False  True  True False False False False False False False False]
Pruning Block 2 in Layer 2
Pruning Block 3 in Layer 3
T2T_ViT(
  (tokens_to_token): T2T_module(
    (soft_split0): Unfold(kernel_size=(3, 3), dilation=1, padding=(1, 1), stride=(2, 2))
    (soft_split1): Unfold(kernel_size=(3, 3), dilation=1, padding=(1, 1), stride=(2, 2))
    (attention1): Token_transformer(
      (norm1): LayerNorm((27,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=27, out_features=192, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=64, out