In [1]:
import yaml
import click
import torch
import tqdm.auto
import numpy as np
from torchvision.models import vit_b_16
 
import sys
import os
import re
import math
#project_root = "C:/Users/elmop/deep_feature_reweighting/deep_feature_reweighting/external/pruning_by_explaining"
project_root = "/home/primmere/ide/external/pruning_by_explaining"
sys.path.insert(0, project_root)                 
sys.path.insert(0, os.path.dirname(project_root))

from pruning_by_explaining.models import ModelLoader
from pruning_by_explaining.metrics import compute_accuracy
from pruning_by_explaining.my_metrics import compute_worst_accuracy
from pruning_by_explaining.my_datasets import WaterBirds, get_sample_indices_for_group, WaterBirdSubset, ISIC, ISICSubset
from pruning_by_explaining.utils import (
    initialize_random_seed,
    initialize_wandb_logger,
)

from pruning_by_explaining.pxp import (
    ModelLayerUtils,
    get_cnn_composite,
    get_vit_composite,
)

from pruning_by_explaining.pxp import GlobalPruningOperations
from pruning_by_explaining.pxp import ComponentAttibution

In [2]:
# --- Reproducibility and hardware ---------------------------------------------
seed = 1
initialize_random_seed(seed)
num_workers = 12
device_string = "cuda"
device = torch.device(device_string)

# Reuse `seed` here so the dataset and the global RNG state cannot drift apart
# when the seed above is changed.
waterbirds = WaterBirds('/scratch_shared/primmere/waterbird', seed=seed, num_workers=num_workers)

# --- Pruning configuration ----------------------------------------------------
least_rel_first = True
abs_flag = True

prune_r = [0.01, 0.05, 0.1, 0.2, 0.3]

# Key into the `layer_types` lookup defined in a later cell.
layer_type = 'Softmax'

# --- Dataset splits -----------------------------------------------------------
train_set = waterbirds.get_train_set()
val_set = waterbirds.get_valid_set()
test_set = waterbirds.get_test_set()

# Reference samples used for attribution, drawn from the validation split.
# presumably: second argument = number of samples, last argument = group ids to
# draw from (the recorded output reports "target groups: [1]" and 30 indices).
n_pruning_samples = 30
target_groups = [1]
pruning_indices = get_sample_indices_for_group(val_set, n_pruning_samples, device_string, target_groups)
print("pruning indices:" , len(pruning_indices))

# NOTE(review): ISICSubset wraps a WaterBirds split here, while WaterBirdSubset is
# imported but never used -- confirm that this is intentional.
custom_pruning_set = ISICSubset(val_set, pruning_indices)

# --- Data loaders -------------------------------------------------------------
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=num_workers)
prune_dataloader = torch.utils.data.DataLoader(custom_pruning_set, batch_size=1, shuffle=True, num_workers=num_workers)
# NOTE(review): despite its name, `val_dataloader` iterates over the *test* split.
val_dataloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=num_workers)

[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 3518
group 1: 185
group 2: 55
group 3: 1037
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 456
group 1: 456
group 2: 143
group 3: 144
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 2255
group 1: 2255
group 2: 642
group 3: 642
target groups: [1]
pruning indices: 30


In [3]:
# Every layer group of the composite is assigned the same "Epsilon" rule.
suggested_composite = dict.fromkeys(
    (
        "low_level_hidden_layer_rule",
        "mid_level_hidden_layer_rule",
        "high_level_hidden_layer_rule",
        "fully_connected_layers_rule",
        "softmax_rule",
    ),
    "Epsilon",
)
composite = get_vit_composite("vit_b_16", suggested_composite)




In [4]:
# Load the two-class ViT-B/16 checkpoint onto `device`.
model = ModelLoader.get_basic_model(
    "vit_b_16",
    "/home/primmere/ide/dfr/logs/vit_waterbirds.pth",
    device,
    num_classes=2,
)


Arch:vit_b_16


  loaded_checkpoint = torch.load(checkpoint_path, map_location=device)


In [5]:
# Echo the module tree of the loaded model to inspect its layer names.
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [6]:
# Lookup from a layer-type name (see `layer_type`) to the torch module class it denotes.
layer_types = dict(
    zip(
        ("Softmax", "Linear", "Conv2d"),
        (torch.nn.Softmax, torch.nn.Linear, torch.nn.Conv2d),
    )
)

In [7]:
"""
acc_worst, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""


"""
0: 0.9937915742793791
1: 0.7835920177383592
2: 0.7461059190031153
3: 0.956386292834891
"""


'\n0: 0.9937915742793791\n1: 0.7835920177383592\n2: 0.7461059190031153\n3: 0.956386292834891\n'

In [8]:
# Pruning rate handed to the mask generator for this first pass.
initial_pruning_rate = 0.05

# Attribute relevance to the components of the selected layer type.
# NOTE(review): the meaning of the fourth positional argument (True) is not
# visible in this notebook -- check the ComponentAttibution signature.
component_attributor = ComponentAttibution(
    "Relevance",
    "ViT",
    layer_types[layer_type],
    True,
)
print("done")

# Use the flags from the configuration cell instead of hard-coded literals so
# that changing the configuration actually changes this run.
components_relevances = component_attributor.attribute(
    model,
    prune_dataloader,
    composite,
    abs_flag=abs_flag,
    device=device,
)
print("done!")

# The pruner works on the layer names discovered by the attributor.
layer_names = component_attributor.layer_names
pruner = GlobalPruningOperations(
    layer_types[layer_type],
    layer_names,
)

global_pruning_mask = pruner.generate_global_pruning_mask(
    model,
    components_relevances,
    initial_pruning_rate,
    subsequent_layer_pruning="Both",
    least_relevant_first=least_rel_first,
    device=device,
)

# Keep the returned handles so the pruning hooks can be removed again later.
hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask)
print(hook_handles)

done


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


done!
[<torch.utils.hooks.RemovableHandle object at 0x151ed103f100>]


In [9]:
# Echo the per-layer pruning mask (layer name -> tensor of selected indices).
global_pruning_mask

OrderedDict([('encoder.layers.encoder_layer_0.self_attention.softmax',
              tensor([11,  4])),
             ('encoder.layers.encoder_layer_1.self_attention.softmax',
              tensor([5, 3])),
             ('encoder.layers.encoder_layer_2.self_attention.softmax',
              tensor([2, 7, 0])),
             ('encoder.layers.encoder_layer_3.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_4.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_5.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_6.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_7.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_8.self_attention.softmax',
              tensor([], dtype=to

In [10]:
# Second, independent copy of the model loaded from the same checkpoint.
model2 = ModelLoader.get_basic_model(
    "vit_b_16",
    "/home/primmere/ide/dfr/logs/vit_waterbirds.pth",
    device,
    num_classes=2,
)


Arch:vit_b_16


In [11]:
# Print, layer by layer, the indices selected by the pruning mask.
for pruned_indices in global_pruning_mask.values():
    # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op for tensors that
    # already live on the CPU and keeps this cell working when the mask is on the GPU.
    print(pruned_indices.detach().cpu().numpy())

[11  4]
[5 3]
[2 7 0]
[]
[]
[]
[]
[]
[]
[]
[]
[]


In [12]:
# List the attention output projection of every encoder block.
for encoder_block in model2.encoder.layers:
    attention = encoder_block.self_attention
    print(attention.out_proj)

NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)


In [13]:
# Echo the first encoder block of the second model copy.
model2.encoder.layers[0]

EncoderBlock(
  (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (self_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLPBlock(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
)

In [14]:
"""
0: 0.9868421052631579
1: 0.7982456140350878
2: 0.7272727272727273
3: 0.9791666666666666
"""

'\n0: 0.9868421052631579\n1: 0.7982456140350878\n2: 0.7272727272727273\n3: 0.9791666666666666\n'

In [15]:
def mask_head(head):
    """Build a forward hook that zeroes one head-sized channel slice of a module's output.

    Args:
        head: Index of the slice to clear. Channels
            ``[head * head_dim, (head + 1) * head_dim)`` of the last dimension are
            set to zero.

    Returns:
        A ``hook(module, inputs, output)`` callable for ``register_forward_hook``.
        ``module`` must expose a ``head_dim`` attribute. The value returned by the
        hook has the same structure as the incoming output: a tuple stays a tuple
        (only its first element is modified), a bare tensor stays a tensor.
    """
    def hook(module, _input, output):
        # Decide by type, not by length: for a bare tensor `output[0]` would pick the
        # first sample of the batch and `len(output)` would be the batch size.
        is_tuple = isinstance(output, tuple)
        y = output[0] if is_tuple else output
        hdim = module.head_dim
        # NOTE(review): torch.nn.MultiheadAttention returns its result *after* the
        # output projection, so this channel slice may not isolate a single
        # attention head -- confirm that this is the intended masking point.
        y[..., head * hdim:(head + 1) * hdim] = 0.0  # modifies the output in place
        return (y, *output[1:]) if is_tuple else y
    return hook


def apply_mask_softmax(model, mask):
    """Register one `mask_head` forward hook per masked index on each encoder block.

    Encoder blocks and mask entries are paired by position (the mask keys are not
    consulted), so `mask` has to be ordered like `model.encoder.layers`.

    Args:
        model: Model exposing `encoder.layers`, each block having `self_attention`.
        mask: Ordered mapping whose values are tensors of indices to mask.

    Returns:
        The list of hook handles, in registration order.
    """
    hook_handles = []
    for block, head_indices in zip(model.encoder.layers, mask.values()):
        attention = block.self_attention
        hook_handles.extend(
            attention.register_forward_hook(mask_head(head_idx))
            for head_idx in head_indices.detach().cpu().tolist()
        )
    return hook_handles

In [16]:
"""
def debug_head(layer, head):
    def dbg_hook(module, _input, output):
        y = output[0]
        hdim = module.head_dim
        sl = y[..., head*hdim:(head+1)*hdim]
        print(f'{layer} head-{head}: max |value| =', sl.abs().max().item())
    return dbg_hook
"""

"\ndef debug_head(layer, head):\n    def dbg_hook(module, _input, output):\n        y = output[0]\n        hdim = module.head_dim\n        sl = y[..., head*hdim:(head+1)*hdim]\n        print(f'{layer} head-{head}: max |value| =', sl.abs().max().item())\n    return dbg_hook\n"

In [17]:
"""
dbg_handles = []
i = 0
for layer in model2.encoder.layers:
    for head in range(12):
        dbg_handles.append(layer.self_attention.register_forward_hook(debug_head(i, head)))
    i += 1

print(len(dbg_handles))
"""

'\ndbg_handles = []\ni = 0\nfor layer in model2.encoder.layers:\n    for head in range(12):\n        dbg_handles.append(layer.self_attention.register_forward_hook(debug_head(i, head)))\n    i += 1\n\nprint(len(dbg_handles))\n'

In [18]:
# Iterative pruning experiment: for each pruning rate, reload the model, re-apply the
# previous mask through forward hooks, recompute relevances, build a new mask and
# report the worst-group and per-group accuracy.

# Baseline relevances from the earlier attribution cell. This snapshot is never
# updated inside the loop, so every `diff` below is measured against this baseline.
prev_rel = {k: v.detach().clone() for k, v in components_relevances.items()}
for r in [0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
    print(r)
    # Start each round from a freshly loaded model.
    # NOTE(review): "pruned_vit_b16_2cls_clean.pth" is only written by the cell that is
    # currently disabled (wrapped in a string literal), so the file must already exist
    # on disk for this loop to run on a fresh kernel.
    del model
    model = ModelLoader.get_basic_model("vit_b_16", "pruned_vit_b16_2cls_clean.pth", device, num_classes=2)

    # Re-apply the mask from the previous round (on the first round: the mask built
    # in the earlier 0.05 pruning cell).
    # NOTE(review): these handles are overwritten further down and therefore never removed.
    hook_handles = apply_mask_softmax(model, global_pruning_mask)
    
    component_attributor = ComponentAttibution(
            "Relevance",
            "ViT",
            layer_types[layer_type],
            True
        )
    
    components_relevances = component_attributor.attribute(
            model,
            prune_dataloader,
            composite,
            abs_flag=True,
            device=device,
        )
    
    
    # `pruner` is reused from the earlier cell; `r` is the pruning rate of this round.
    global_pruning_mask = pruner.generate_global_pruning_mask(
                    model,
                    components_relevances,
                    r,
                    subsequent_layer_pruning="Linear",
                    least_relevant_first=True,
                    device=device,
                )
    # Per-layer change of the relevances relative to the baseline snapshot.
    # NOTE(review): the recorded differences are around 1e-7 (float32 noise), which
    # suggests the hooks registered above do not influence the attribution -- verify.
    for k in components_relevances:
        diff = components_relevances[k] - prev_rel[k]
        print(diff)
    
    hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask)
    
    # `val_dataloader` wraps the test split (see the data cell).
    acc, acc_groups = compute_worst_accuracy(
            model,
            val_dataloader,
            device,
        )
    print(acc)
    # The dataset reports 4 groups in total (see the recorded output of the data cell).
    for i in range(4):
        print(f'{i}: {acc_groups[i]}')

0.1
Arch:vit_b_16
tensor([ 1.1921e-07,  5.9605e-08,  5.9605e-08,  2.3842e-07, -5.9605e-08,
         5.9605e-08, -4.7684e-07,  2.3842e-07,  0.0000e+00,  1.1921e-07,
         0.0000e+00, -5.9605e-08])
tensor([ 1.1921e-07,  5.9605e-08,  0.0000e+00,  0.0000e+00,  1.1921e-07,
        -2.9802e-08, -3.5763e-07, -5.9605e-08,  0.0000e+00, -5.9605e-08,
        -5.9605e-08,  0.0000e+00])
tensor([-5.9605e-08, -2.3842e-07, -5.9605e-08,  0.0000e+00,  1.1921e-07,
         9.5367e-07,  1.1921e-07, -5.9605e-08,  0.0000e+00,  1.1921e-07,
         0.0000e+00, -4.7684e-07])
tensor([-2.3842e-07,  0.0000e+00, -2.3842e-07,  1.1921e-07, -2.3842e-07,
        -9.5367e-07,  0.0000e+00,  4.7684e-07,  0.0000e+00, -2.3842e-07,
        -2.3842e-07,  2.3842e-07])
tensor([ 4.7684e-07,  4.7684e-07,  4.7684e-07,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  4.7684e-07, -4.7684e-07,  4.7684e-07,
         2.3842e-07,  2.3842e-07])
tensor([ 0.0000e+00,  0.0000e+00,  2.3842e-07, -2.3842e-07,  4.7684e-07,
    

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.882119433897135
0: 0.9933481152993349
1: 0.7906873614190687
2: 0.7383177570093458
3: 0.956386292834891
0.15
Arch:vit_b_16
tensor([ 1.1921e-07,  5.9605e-08, -5.9605e-08,  2.3842e-07, -5.9605e-08,
         1.7881e-07,  0.0000e+00,  2.3842e-07,  1.1921e-07,  0.0000e+00,
         0.0000e+00,  0.0000e+00])
tensor([ 1.1921e-07,  0.0000e+00,  4.7684e-07, -5.9605e-08,  0.0000e+00,
         0.0000e+00, -2.3842e-07,  0.0000e+00,  4.7684e-07, -5.9605e-08,
         5.9605e-08,  0.0000e+00])
tensor([-1.1921e-07, -1.1921e-07,  2.9802e-08,  0.0000e+00,  0.0000e+00,
         4.7684e-07,  0.0000e+00,  5.9605e-08, -2.3842e-07, -1.1921e-07,
        -4.7684e-07,  0.0000e+00])
tensor([-2.3842e-07,  0.0000e+00, -2.3842e-07,  2.3842e-07,  0.0000e+00,
        -1.4305e-06, -9.5367e-07,  0.0000e+00, -1.1921e-07, -2.3842e-07,
        -2.3842e-07,  0.0000e+00])
tensor([4.7684e-07, 4.7684e-07, 0.0000e+00, 0.0000e+00, 2.3842e-07, 1.1921e-07,
        9.5367e-07, 4.7684e-07, 4.7684e-07, 4.7684e-07, 2.3842e-07, 0.00

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.8860890576458406
0: 0.9942350332594235
1: 0.8088691796008869
2: 0.7133956386292835
3: 0.9501557632398754
0.2
Arch:vit_b_16
tensor([ 2.3842e-07,  5.9605e-08,  5.9605e-08,  2.3842e-07, -5.9605e-08,
         0.0000e+00, -4.7684e-07,  2.3842e-07,  1.1921e-07,  1.1921e-07,
        -2.3842e-07,  0.0000e+00])
tensor([ 0.0000e+00,  5.9605e-08,  0.0000e+00,  5.9605e-08,  1.1921e-07,
         2.9802e-08, -2.3842e-07,  0.0000e+00,  0.0000e+00, -1.7881e-07,
         5.9605e-08, -2.3842e-07])
tensor([-5.9605e-08, -2.3842e-07,  0.0000e+00,  0.0000e+00,  1.1921e-07,
         4.7684e-07, -1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
        -4.7684e-07, -9.5367e-07])
tensor([-2.3842e-07,  1.1921e-07,  2.3842e-07,  1.1921e-07, -2.3842e-07,
        -1.4305e-06,  0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,
        -4.7684e-07,  2.3842e-07])
tensor([ 4.7684e-07,  2.3842e-07,  2.3842e-07, -4.7684e-07,  0.0000e+00,
         1.1921e-07,  4.7684e-07,  4.7684e-07,  0.0000e+00,  9.5367e-07,
     

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.8705557473248188
0: 0.9929046563192905
1: 0.7600886917960089
2: 0.7383177570093458
3: 0.9610591900311527
0.3
Arch:vit_b_16
tensor([ 1.1921e-07,  1.7881e-07,  5.9605e-08,  2.3842e-07,  0.0000e+00,
         5.9605e-08,  0.0000e+00,  2.3842e-07,  2.3842e-07, -1.1921e-07,
         2.3842e-07,  0.0000e+00])
tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1921e-07,
         0.0000e+00, -2.3842e-07,  0.0000e+00,  4.7684e-07, -5.9605e-08,
         5.9605e-08,  0.0000e+00])
tensor([ 0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00,  1.1921e-07,
         0.0000e+00, -1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -4.7684e-07])
tensor([ 0.0000e+00, -1.1921e-07,  2.3842e-07, -1.1921e-07, -2.3842e-07,
        -9.5367e-07, -4.7684e-07,  0.0000e+00,  1.1921e-07,  0.0000e+00,
         0.0000e+00,  4.7684e-07])
tensor([ 4.7684e-07,  2.3842e-07,  0.0000e+00, -4.7684e-07,  0.0000e+00,
        -1.1921e-07,  4.7684e-07,  4.7684e-07,  0.0000e+00,  4.7684e-07,
     

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.8822920262340352
0: 0.9937915742793791
1: 0.8053215077605321
2: 0.6900311526479751
3: 0.9532710280373832
0.4
Arch:vit_b_16
tensor([ 1.1921e-07,  5.9605e-08,  0.0000e+00,  1.1921e-07, -5.9605e-08,
         0.0000e+00, -4.7684e-07,  2.3842e-07,  0.0000e+00,  0.0000e+00,
        -2.3842e-07,  0.0000e+00])
tensor([ 0.0000e+00, -1.1921e-07,  0.0000e+00,  0.0000e+00,  1.1921e-07,
         0.0000e+00, -3.5763e-07,  0.0000e+00,  4.7684e-07,  5.9605e-08,
         0.0000e+00, -2.3842e-07])
tensor([ 0.0000e+00, -1.1921e-07, -5.9605e-08,  1.1921e-07, -1.1921e-07,
         9.5367e-07, -1.1921e-07,  0.0000e+00, -2.3842e-07,  0.0000e+00,
        -9.5367e-07,  0.0000e+00])
tensor([ 0.0000e+00, -1.1921e-07,  0.0000e+00, -2.3842e-07, -2.3842e-07,
        -9.5367e-07, -4.7684e-07, -9.5367e-07, -1.1921e-07,  2.3842e-07,
        -2.3842e-07,  2.3842e-07])
tensor([9.5367e-07, 9.5367e-07, 0.0000e+00, 1.4305e-06, 4.7684e-07, 1.1921e-07,
        0.0000e+00, 4.7684e-07, 0.0000e+00, 9.5367e-07, 0.0000e+00, 2.3

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.8717638936831205
0: 0.9915742793791574
1: 0.7849223946784922
2: 0.67601246105919
3: 0.9517133956386293
0.5
Arch:vit_b_16
tensor([ 1.1921e-07,  1.1921e-07, -5.9605e-08,  2.3842e-07, -1.1921e-07,
         0.0000e+00, -4.7684e-07,  0.0000e+00,  1.1921e-07,  1.1921e-07,
        -2.3842e-07, -5.9605e-08])
tensor([ 5.9605e-08,  0.0000e+00, -4.7684e-07, -5.9605e-08,  1.1921e-07,
         0.0000e+00, -2.3842e-07,  5.9605e-08, -4.7684e-07,  5.9605e-08,
         0.0000e+00,  1.1921e-07])
tensor([-5.9605e-08, -1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         9.5367e-07,  0.0000e+00,  5.9605e-08,  0.0000e+00, -1.1921e-07,
        -9.5367e-07, -4.7684e-07])
tensor([ 0.0000e+00,  0.0000e+00, -2.3842e-07,  1.1921e-07, -2.3842e-07,
        -9.5367e-07, -4.7684e-07,  0.0000e+00,  0.0000e+00, -4.7684e-07,
        -2.3842e-07,  9.5367e-07])
tensor([ 4.7684e-07,  7.1526e-07,  2.3842e-07,  0.0000e+00,  0.0000e+00,
        -1.1921e-07,  4.7684e-07,  4.7684e-07, -4.7684e-07,  9.5367e-07,
       

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.8816016568864342
0: 0.9831485587583149
1: 0.7955654101995565
2: 0.7632398753894081
3: 0.9454828660436138
0.6
Arch:vit_b_16
tensor([ 1.1921e-07,  1.1921e-07,  1.1921e-07,  2.3842e-07, -5.9605e-08,
         0.0000e+00,  0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00])
tensor([ 0.0000e+00,  0.0000e+00,  4.7684e-07,  0.0000e+00,  1.1921e-07,
         2.9802e-08, -1.1921e-07,  0.0000e+00,  4.7684e-07,  0.0000e+00,
         0.0000e+00, -2.3842e-07])
tensor([ 5.9605e-08, -1.1921e-07, -5.9605e-08,  0.0000e+00,  1.1921e-07,
         9.5367e-07,  0.0000e+00,  0.0000e+00,  2.3842e-07, -1.1921e-07,
         0.0000e+00, -4.7684e-07])
tensor([-4.7684e-07,  0.0000e+00, -2.3842e-07, -1.1921e-07,  2.3842e-07,
        -4.7684e-07,  0.0000e+00,  0.0000e+00, -1.1921e-07, -2.3842e-07,
        -2.3842e-07,  2.3842e-07])
tensor([ 4.7684e-07,  4.7684e-07,  0.0000e+00,  4.7684e-07,  0.0000e+00,
         0.0000e+00,  4.7684e-07,  4.7684e-07, -4.7684e-07,  4.7684e-07,
     

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.877114256127028
0: 0.9849223946784922
1: 0.8062084257206208
2: 0.6947040498442367
3: 0.9299065420560748
0.7
Arch:vit_b_16
tensor([ 0.0000e+00,  1.7881e-07,  5.9605e-08,  1.1921e-07, -5.9605e-08,
        -5.9605e-08,  0.0000e+00,  2.3842e-07,  0.0000e+00,  2.3842e-07,
        -2.3842e-07, -5.9605e-08])
tensor([ 0.0000e+00, -1.1921e-07, -9.5367e-07,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -1.1921e-07,  5.9605e-08,  0.0000e+00, -1.1921e-07,
         0.0000e+00,  0.0000e+00])
tensor([ 0.0000e+00,  0.0000e+00, -2.9802e-08,  1.1921e-07, -1.1921e-07,
         0.0000e+00,  0.0000e+00, -5.9605e-08,  2.3842e-07, -2.3842e-07,
        -9.5367e-07, -4.7684e-07])
tensor([-4.7684e-07,  1.1921e-07,  0.0000e+00,  0.0000e+00, -2.3842e-07,
        -9.5367e-07, -4.7684e-07,  4.7684e-07,  1.1921e-07, -4.7684e-07,
         2.3842e-07,  4.7684e-07])
tensor([ 4.7684e-07,  2.3842e-07,  0.0000e+00,  0.0000e+00, -2.3842e-07,
        -2.3842e-07,  0.0000e+00,  4.7684e-07, -9.5367e-07,  0.0000e+00,
      

evaluating group acc:   0%|          | 0/91 [00:00<?, ?it/s]

0.3781498101484294
0: 0.3964523281596452
1: 0.05942350332594235
2: 0.82398753894081
3: 0.9875389408099688


In [19]:
"""
for r in [0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
    print(r)
# 1. Grab your pruned model’s wrapped state dict
    pruned_sd = model.state_dict()
    
    # 2. Strip off any wrapper-module names
    clean_sd = {}
    for k, v in pruned_sd.items():
        # Drop functorch wrapper layers named “module”
        new_k = re.sub(r'\.module\.', '.', k)
        clean_sd[new_k] = v
    
    # 3. For each attention layer, rebuild the combined in_proj and out_proj keys
    D = 768
    num_layers = 12
    for i in range(num_layers):
        prefix = f"encoder.layers.encoder_layer_{i}.self_attention"
    
        # stack Q/K/V weights
        w_q = clean_sd.pop(f"{prefix}.q_proj.proj_weight")
        w_k = clean_sd.pop(f"{prefix}.k_proj.proj_weight")
        w_v = clean_sd.pop(f"{prefix}.v_proj.proj_weight")
        clean_sd[f"{prefix}.in_proj_weight"] = torch.cat([w_q, w_k, w_v], dim=0)
    
        # stack Q/K/V biases
        b_q = clean_sd.pop(f"{prefix}.q_proj.proj_bias")
        b_k = clean_sd.pop(f"{prefix}.k_proj.proj_bias")
        b_v = clean_sd.pop(f"{prefix}.v_proj.proj_bias")
        clean_sd[f"{prefix}.in_proj_bias"] = torch.cat([b_q, b_k, b_v], dim=0)
    
        # rename out_proj
        clean_sd[f"{prefix}.out_proj.weight"] = clean_sd.pop(f"{prefix}.out_proj.proj_weight")
        clean_sd[f"{prefix}.out_proj.bias"]   = clean_sd.pop(f"{prefix}.out_proj.proj_bias")
    
    # 4. Instantiate fresh ViT-B/16 with 2‐way head
    del model
    model = vit_b_16(weights=False, num_classes=2)
    
    # 5. Load your rebuilt dict (strict=True now that everything matches)
    missing, unexpected = model.load_state_dict(clean_sd, strict=True)

    
    # 6. Done—now you can save and re‐load without wrappers
    torch.save(model.state_dict(), "pruned_vit_b16_2cls_clean.pth")
    
    model = ModelLoader.get_basic_model("vit_b_16", "pruned_vit_b16_2cls_clean.pth", device, num_classes=2)
    
    component_attributor2 = ComponentAttibution(
            "Relevance",
            "ViT",
            layer_types[layer_type],
            True
        )
    
    components_relevances2 = component_attributor2.attribute(
            model,
            prune_dataloader,
            composite,
            abs_flag=True,
            device=device,
        )
    
    
    global_pruning_mask2 = pruner.generate_global_pruning_mask(
                    model,
                    components_relevances2,
                    r,
                    subsequent_layer_pruning="Linear",
                    least_relevant_first=True,
                    device=device,
                )
    hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask2,)
    
    acc, acc_groups = compute_worst_accuracy(
            model,
            val_dataloader,
            device,
        )
    print(acc)
    for i in range(4):
        print(f'{i}: {acc_groups[i]}')
"""

'\nfor r in [0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:\n    print(r)\n# 1. Grab your pruned model’s wrapped state dict\n    pruned_sd = model.state_dict()\n    \n    # 2. Strip off any wrapper-module names\n    clean_sd = {}\n    for k, v in pruned_sd.items():\n        # Drop functorch wrapper layers named “module”\n        new_k = re.sub(r\'\\.module\\.\', \'.\', k)\n        clean_sd[new_k] = v\n    \n    # 3. For each attention layer, rebuild the combined in_proj and out_proj keys\n    D = 768\n    num_layers = 12\n    for i in range(num_layers):\n        prefix = f"encoder.layers.encoder_layer_{i}.self_attention"\n    \n        # stack Q/K/V weights\n        w_q = clean_sd.pop(f"{prefix}.q_proj.proj_weight")\n        w_k = clean_sd.pop(f"{prefix}.k_proj.proj_weight")\n        w_v = clean_sd.pop(f"{prefix}.v_proj.proj_weight")\n        clean_sd[f"{prefix}.in_proj_weight"] = torch.cat([w_q, w_k, w_v], dim=0)\n    \n        # stack Q/K/V biases\n        b_q = clean_sd.pop(f"{prefix}.q_

In [20]:
#model2.load_state_dict(torch.load("intermediate_model.pth"))

In [21]:
# NOTE(review): `global_pruning_mask2` is only assigned inside the cell that is currently
# disabled (wrapped in a string literal), so this raises NameError on a fresh run.
global_pruning_mask2

NameError: name 'global_pruning_mask2' is not defined

In [None]:
## pruned 10% with prune set 1 and 10% with prune set 2

In [None]:
# Worst-group and per-group accuracy of `model2`.
# NOTE(review): no pruning mask is applied to `model2` in the visible cells -- confirm
# that this is the model the preceding heading refers to.
acc_worst, acc_groups = compute_worst_accuracy(model2, val_dataloader, device)
for group_idx in range(4):
    group_acc = acc_groups[group_idx]
    print(f'{group_idx}: {group_acc}')


In [None]:
# Rebuild the global mask from the current relevances with pruning rate 0.15,
# then attach it to `model`.
global_pruning_mask = pruner.generate_global_pruning_mask(
    model,
    components_relevances,
    0.15,
    subsequent_layer_pruning="Linear",
    least_relevant_first=True,
    device=device,
)
hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask)


In [None]:
## pruned with prune set 1

In [None]:
# Worst-group and per-group accuracy of the pruned `model`.
acc_worst, acc_groups = compute_worst_accuracy(model, val_dataloader, device)
for group_idx in range(4):
    group_acc = acc_groups[group_idx]
    print(f'{group_idx}: {group_acc}')

In [None]:
# Echo the most recently computed per-layer relevances.
components_relevances

In [None]:
# Echo the first sub-module of the MLP in the first encoder block.
model.encoder.layers[0].mlp[0]

In [None]:
# NOTE(review): `components_relevances2` is only assigned inside the cell that is currently
# disabled (wrapped in a string literal); on a fresh kernel this raises NameError.
components_relevances2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sizes of the toy network and of the single-sample input batch.
in_dim, hidden_dim, out_dim = 2, 3, 2
batch_size = 1


class SimpleMLP(nn.Module):
    """Two-layer perceptron that remembers its activations from the last forward pass.

    Both linear layers are followed by a ReLU. The post-ReLU activations are stored
    on the instance as `a1` (hidden layer) and `a2` (output layer); both are None
    until `forward` has been called.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=True)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=out_dim, bias=True)

        # Activation caches, overwritten on every forward pass.
        self.a1, self.a2 = None, None

    def forward(self, x):
        """Run the network on `x`, cache both activations and return the last one."""
        hidden_pre = self.fc1(x)
        self.a1 = F.relu(hidden_pre)
        output_pre = self.fc2(self.a1)
        self.a2 = F.relu(output_pre)
        return self.a2

# Seed first so that both the layer initialisation and the sampled input are reproducible.
torch.manual_seed(42)
net = SimpleMLP(in_dim=in_dim, hidden_dim=hidden_dim, out_dim=out_dim)

x = torch.randn((batch_size, in_dim))
logits = net(x)


In [None]:
# Re-run the forward pass so the cached activations belong to the current `x`.
logits = net(x)

# Input first, then weights, bias and cached activation of each layer;
# every group is followed by an empty line.
sections = [
    [('x = ', x)],
    [('w1 = ', net.fc1.weight), ('b1 = ', net.fc1.bias), ('a1 = ', net.a1)],
    [('w2 = ', net.fc2.weight), ('b2 = ', net.fc2.bias), ('a2 = ', net.a2)],
]
for section in sections:
    for label, value in section:
        print(label, value)
    print()

print(logits)