In [1]:
import math
import sys
import os

from functools import reduce

import re

import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path

ROOT = Path.cwd().resolve().parent
print(ROOT)
sys.path.append(str(ROOT))
sys.path.append(os.path.join(ROOT, 'external', 'wdpruning'))
sys.path.append(os.path.join(ROOT, 'external', 'pruning_by_explaining'))

from my_scripts.utils import visualize


from external.wdpruning.vit_wdpruning import VisionTransformerWithWDPruning
from external.pruning_by_explaining.pxp import GlobalPruningOperations, ComponentAttibution, get_vit_composite, \
    ModelLayerUtils
from external.pruning_by_explaining.models import ModelLoader
from external.pruning_by_explaining.my_datasets import WaterBirds, WaterBirdDataset

from my_utils import evaluate, evaluate_gradients
from ISIC_ViT.isic_data import ISICDataset
from external.dfr.wb_data import WaterBirdsDataset

import argparse

import torchvision
import torch
from torchvision.models import vit_b_16
import torch.nn as nn

/home/primmere/ide


In [2]:
pruning_mask = '/home/primmere/logs/pxp/results/2L/0.01.pth'
data_dir = '/scratch_shared/primmere/waterbird'

batch_size = 32
num_workers = 4
device = torch.device("cuda")
model_path = '/home/primmere/ide/dfr/logs/vit_waterbirds.pth'

In [3]:
model = vit_b_16(weights=None, num_classes=2)
model.load_state_dict(torch.load(model_path, weights_only=False, map_location=device))
model.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [4]:
transform = torchvision.transforms.Compose([
                torchvision.transforms.Resize(256),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

In [5]:
test_dataset = WaterBirdsDataset(basedir=data_dir,
                                     transform=transform,
                                     split="test")
val_dataset = WaterBirdsDataset(basedir=data_dir,
                                    transform=transform,
                                    split="val")
train_dataset = WaterBirdsDataset(basedir=data_dir,
                                      transform=transform,
                                      split="train")

loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=False, **loader_kwargs)
mask = torch.load(pruning_mask, map_location=device)


tensor([2255., 2255.,  642.,  642.])
tensor([456., 456., 143., 144.])
tensor([3518.,  185.,   55., 1037.])


  mask = torch.load(pruning_mask, map_location=device)


In [6]:
for k in mask.keys():
    print(mask[k]['Linear']['weight'])

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1

In [7]:
def count_zero_weights_mask(mask):
    """
    return a dict mapping layer names → (zero_count, total_params, zero_ratio).
    """
    zero_stats = {}
    for k in mask.keys():
        tensor = mask[k]['Linear']['weight']
        tensor = tensor.detach()
        total = tensor.numel()
        zeros = int((tensor == 0).sum().item())
        ratio = zeros / total
        zero_stats[k] = (zeros, total, ratio)
    return zero_stats

In [8]:
zero_stats = count_zero_weights_mask(mask)
for k in zero_stats.keys():
    print(k)

encoder.layers.encoder_layer_0.mlp.0
encoder.layers.encoder_layer_0.mlp.3
encoder.layers.encoder_layer_1.mlp.0
encoder.layers.encoder_layer_1.mlp.3
encoder.layers.encoder_layer_2.mlp.0
encoder.layers.encoder_layer_2.mlp.3
encoder.layers.encoder_layer_3.mlp.0
encoder.layers.encoder_layer_3.mlp.3
encoder.layers.encoder_layer_4.mlp.0
encoder.layers.encoder_layer_4.mlp.3
encoder.layers.encoder_layer_5.mlp.0
encoder.layers.encoder_layer_5.mlp.3
encoder.layers.encoder_layer_6.mlp.0
encoder.layers.encoder_layer_6.mlp.3
encoder.layers.encoder_layer_7.mlp.0
encoder.layers.encoder_layer_7.mlp.3
encoder.layers.encoder_layer_8.mlp.0
encoder.layers.encoder_layer_8.mlp.3
encoder.layers.encoder_layer_9.mlp.0
encoder.layers.encoder_layer_9.mlp.3
encoder.layers.encoder_layer_10.mlp.0
encoder.layers.encoder_layer_10.mlp.3
encoder.layers.encoder_layer_11.mlp.0
encoder.layers.encoder_layer_11.mlp.3


In [9]:
visualize(mask, pruning_mask, prune_framework="pxp")

PXP zero-weight ratios:
 [[0.         0.         0.         0.        ]
 [0.         0.         0.10026042 0.        ]
 [0.         0.         0.00358073 0.        ]
 [0.         0.         0.02636719 0.        ]
 [0.         0.         0.00325521 0.        ]
 [0.         0.         0.00032552 0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.00032552 0.        ]
 [0.         0.         0.0061849  0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.00195312 0.        ]
 [0.         0.         0.00748698 0.        ]]
Plot saved to /home/primmere/logs/pxp/results/2L/prune_ratios.png
