In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from models.swin import SwinTransformer
from PIL import Image
from torchvision import transforms


# Architecture hyper-parameters forwarded to SwinTransformer below.
depths = [2, 6, 4]
num_heads = [3, 6, 12]
mlp_ratio = 2
window_size = 4

model = SwinTransformer(
    img_size=32,
    embed_dim=96,
    window_size=window_size,
    drop_path_rate=0.1,
    patch_size=2,
    mlp_ratio=mlp_ratio,
    depths=depths,
    num_heads=num_heads,
    num_classes=100,
    is_SPT=False,
    is_LSA=False,
).cuda()

# Switch to evaluation mode before running the probe image through the network.
model.eval()

image_path = "car.jpg"
# Force a 3-channel RGB image: grayscale / RGBA / CMYK files would otherwise
# break the 3-channel Normalize below and the model's input layer.
# The context manager closes the underlying file once the pixels are copied.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# Resize to the model's 32x32 input, convert to a float tensor and normalise
# per channel. The mean/std values are the commonly used ImageNet statistics.
# NOTE(review): confirm the checkpoint was trained with this same normalisation.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Add a leading batch dimension and move the input to the GPU alongside the model.
x = transform(image).unsqueeze(0).cuda()
print(x.shape)



  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 3, 32, 32])


In [3]:
import torch
import torchvision.models as models


# NOTE(security): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; load_state_dict copies the values into the model's own parameters.
checkpoint = torch.load("best-swin.pth", map_location="cpu")

model_weights = checkpoint["model_state_dict"]

# Strip only a *leading* "module." (typically added when the model was wrapped
# in a data-parallel container). A plain str.replace would also mangle keys that
# merely contain the substring, e.g. "submodule.weight" -> "subweight".
wrapper_prefix = "module."
model_weights = {
    (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
    for k, v in model_weights.items()
}

# strict=False tolerates missing/unexpected keys instead of raising; the returned
# report is shown as the cell output, so check that all keys actually matched.
model.load_state_dict(model_weights, strict=False)

<All keys matched successfully>

In [4]:
import torch
import torch.nn.functional as F
import numpy as np

def compute_entropy(tensor):
    """Return the mean entropy of ``softmax(tensor)`` taken over the last dimension.

    The input is turned into a probability distribution along ``dim=-1`` with a
    softmax, the entropy ``-sum(p * log(p))`` (natural log) is computed for each
    distribution, and the average over all remaining dimensions is returned.

    Args:
        tensor: A torch tensor of raw scores; the last dimension is normalised.

    Returns:
        The averaged entropy as a Python float.
    """
    probs = F.softmax(tensor, dim=-1)
    # The small epsilon keeps the logarithm finite when a probability is zero.
    log_probs = torch.log(probs + 1e-9)
    per_distribution = -(probs * log_probs).sum(dim=-1)
    return per_distribution.mean().item()

In [5]:
def compute_block_metrics(model, input_tensor, criterion, target):
    """Collect a per-block output entropy and gradient magnitude for ``model``.

    Every block found under ``model.layers[i].blocks`` gets a temporary forward
    hook that records ``compute_entropy(block_output)``. One forward/backward
    pass is then run with ``criterion(model(input_tensor), target)``, and for
    each block the L2 norms of its parameters' gradients are summed.

    Side effects: existing gradients on ``model`` are cleared before the pass
    (so repeated calls do not accumulate into ``.grad``), the gradients of this
    pass are left on the parameters afterwards, and both result lists are
    printed.

    Args:
        model: Module exposing an iterable ``layers`` attribute; layers without
            a ``blocks`` attribute are skipped.
        input_tensor: Input batch passed to ``model``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss that supports ``backward()``.
        target: Target passed to ``criterion``.

    Returns:
        Tuple ``(entropies, gradient_norms)`` of NumPy arrays. Gradient norms
        follow block iteration order; entropies follow the order in which the
        hooks fire during the forward pass (presumably the same order — verify
        if the model executes blocks non-sequentially).
    """
    entropies = []

    def hook_fn(module, input, output):
        # Returning None leaves the block's output unchanged.
        entropies.append(compute_entropy(output))

    blocks = [
        block
        for layer in model.layers
        if hasattr(layer, "blocks")
        for block in layer.blocks
    ]

    hooks = []
    try:
        for block in blocks:
            hooks.append(block.register_forward_hook(hook_fn))

        # backward() accumulates into .grad, so start from a clean slate;
        # otherwise a second call would report inflated gradient norms.
        model.zero_grad()

        output = model(input_tensor)
        loss = criterion(output, target)
        loss.backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass failed,
        # so they do not leak onto later uses of the model.
        for hook in hooks:
            hook.remove()

    # Per block: sum of the individual parameter-gradient L2 norms.
    gradient_norms = [
        sum(
            (torch.norm(param.grad).item()
             for param in block.parameters()
             if param.grad is not None),
            0.0,
        )
        for block in blocks
    ]

    print("entropies",entropies)
    print("gradient_norms",gradient_norms)

    return np.array(entropies), np.array(gradient_norms)

In [6]:

from utils.losses import LabelSmoothingCrossEntropy
# Ground-truth class index for the probe image, created directly on the GPU.
target = torch.tensor([29], dtype=torch.long, device="cuda")
criterion = LabelSmoothingCrossEntropy()
# Despite the name, `attention_entropies` holds the entropies that
# compute_block_metrics records from each block's forward output.
attention_entropies, gradient_norms = compute_block_metrics(
    model=model,
    input_tensor=x,
    criterion=criterion,
    target=target,
)


entropies [4.445015907287598, 4.449002265930176, 5.062499046325684, 5.0653815269470215, 5.02440071105957, 4.959166049957275, 4.834014892578125, 4.724241733551025, 5.688312530517578, 5.625767707824707, 5.506977081298828, 5.411611557006836]
gradient_norms [55.633842304348946, 41.452643886208534, 39.26008752733469, 40.059071816504, 39.80225479602814, 36.138844415545464, 35.21963840723038, 27.96538368612528, 73.30160504579544, 69.61833009123802, 79.04520133137703, 81.44851377606392]


In [7]:
# Number of standard deviations below the mean used as the cut-off.
k = 1.0

# Summary statistics of the two per-block metrics.
mean_entropy = attention_entropies.mean()
std_entropy = attention_entropies.std()
mean_grad = gradient_norms.mean()
std_grad = gradient_norms.std()

# Thresholds: mean - k * std for each metric.
T_entropy = mean_entropy - k * std_entropy
T_grad = mean_grad - k * std_grad
print("T_entropy:",T_entropy)
print("T_grad:",T_grad)

# A block is flagged when EITHER metric falls below its threshold.
low_entropy = attention_entropies < T_entropy
low_gradient = gradient_norms < T_grad
prune_mask = np.logical_or(low_entropy, low_gradient)
print("prune_mask:",prune_mask)



T_entropy: 4.661746839141438
T_grad: 33.20842826505502
prune_mask: [ True  True False False False False False  True False False False False]
