In [1]:
import yaml
import click
import torch
import tqdm.auto
import numpy as np
from torchvision.models import vit_b_16
 
import sys
import os
import re
import math
# Location of the pruning_by_explaining checkout. Set the PXP_PROJECT_ROOT
# environment variable to run on another machine without editing this cell,
# e.g. "C:/Users/elmop/deep_feature_reweighting/deep_feature_reweighting/external/pruning_by_explaining".
project_root = os.environ.get(
    "PXP_PROJECT_ROOT",
    "/home/primmere/ide/external/pruning_by_explaining",
)
# The parent directory makes `pruning_by_explaining` importable as a package;
# the root itself is added as well (presumably for the package's own
# top-level imports - confirm). The parent ends up first on sys.path.
sys.path.insert(0, project_root)
sys.path.insert(0, os.path.dirname(project_root))

from pruning_by_explaining.models import ModelLoader
from pruning_by_explaining.metrics import compute_accuracy
from pruning_by_explaining.my_metrics import compute_worst_accuracy
from pruning_by_explaining.my_datasets import WaterBirds, get_sample_indices_for_group, WaterBirdSubset, ISIC, ISICSubset


from pruning_by_explaining.pxp import (
    ModelLayerUtils,
    get_cnn_composite,
    get_vit_composite,
)

from pruning_by_explaining.pxp import GlobalPruningOperations
from pruning_by_explaining.pxp import ComponentAttibution

In [2]:
# --- Runtime configuration ---------------------------------------------------
num_workers = 4
device_string = "cuda"
device = torch.device(device_string)

# --- Waterbirds splits ---------------------------------------------------------
waterbirds = WaterBirds(
    '/scratch_shared/primmere/waterbird',
    seed=1,
    num_workers=num_workers,
)
train_set = waterbirds.get_train_set()
val_set = waterbirds.get_valid_set()
test_set = waterbirds.get_test_set()

# --- Sample selection ----------------------------------------------------------
# Two index sets drawn from the validation split with identical arguments
# (10, groups [1, 2]); presumably each call samples independently - confirm.
# The third call selects 'all' samples of the test split.
pruning_indices = get_sample_indices_for_group(val_set, 10, device_string, [1, 2])
pruning_indices2 = get_sample_indices_for_group(val_set, 10, device_string, [1, 2])
validation_indices = get_sample_indices_for_group(test_set, 'all', device_string)

custom_pruning_set = WaterBirdSubset(val_set, pruning_indices)
custom_pruning_set2 = WaterBirdSubset(val_set, pruning_indices2)
# NOTE(review): custom_val_set is built here, but val_dataloader below wraps
# test_set directly - confirm which one is intended for evaluation.
custom_val_set = WaterBirdSubset(test_set, validation_indices)

# --- Data loaders --------------------------------------------------------------
# Every loader shuffles and shares the same worker count; only the dataset and
# batch size differ (128 for train/eval, 1 for the pruning reference sets).
_loader_options = {"shuffle": True, "num_workers": num_workers}
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=128, **_loader_options)
prune_dataloader = torch.utils.data.DataLoader(custom_pruning_set, batch_size=1, **_loader_options)
prune_dataloader2 = torch.utils.data.DataLoader(custom_pruning_set2, batch_size=1, **_loader_options)
val_dataloader = torch.utils.data.DataLoader(test_set, batch_size=128, **_loader_options)

[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 3518
group 1: 185
group 2: 55
group 3: 1037
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 456
group 1: 456
group 2: 143
group 3: 144
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 2255
group 1: 2255
group 2: 642
group 3: 642
target groups: [1, 2]
target groups: [1, 2]
target groups: [0, 1, 2, 3]


In [3]:
# Every rule slot of the composite is set to the same "Epsilon" rule.
_composite_rule_slots = (
    "low_level_hidden_layer_rule",
    "mid_level_hidden_layer_rule",
    "high_level_hidden_layer_rule",
    "fully_connected_layers_rule",
    "softmax_rule",
)
suggested_composite = dict.fromkeys(_composite_rule_slots, "Epsilon")
composite = get_vit_composite("vit_b_16", suggested_composite)

# Load the 2-class ViT-B/16 checkpoint onto the configured device.
model = ModelLoader.get_basic_model(
    "vit_b_16",
    "/home/primmere/ide/dfr/logs/vit_waterbirds.pth",
    device,
    num_classes=2,
)



Arch:vit_b_16


  loaded_checkpoint = torch.load(checkpoint_path, map_location=device)


In [4]:
# Lookup from a layer-type name to the corresponding torch.nn module class;
# later cells select the class to attribute/prune via layer_types["Softmax"].
layer_types = dict(
    Softmax=torch.nn.Softmax,
    Linear=torch.nn.Linear,
    Conv2d=torch.nn.Conv2d,
)

In [5]:
# Per-group evaluation of `model` before any pruning mask is fitted. The code
# is disabled by wrapping it in a string literal, which is a no-op expression.
"""
acc_worst, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""


# Per-group accuracies kept for reference (presumably copied from an earlier
# run of the snippet above). As the cell's last expression, this string is
# what the notebook echoes as the cell result.
"""
0: 0.9937915742793791
1: 0.7835920177383592
2: 0.7461059190031153
3: 0.956386292834891
"""


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0: 0.9937915742793791
1: 0.7835920177383592
2: 0.7461059190031153
3: 0.956386292834891


'\n0: 0.9937915742793791\n1: 0.7835920177383592\n2: 0.7461059190031153\n3: 0.956386292834891\n'

In [6]:
# Overall-accuracy evaluation of `model` on val_dataloader, disabled by
# wrapping it in a string literal; the string is echoed as the cell result
# instead of running the evaluation.
"""
acc = compute_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
"""

evaluating acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.8803935105281325


In [7]:
# Attribution helper for the first pass: "Relevance" attribution on a "ViT",
# targeting softmax layers. The trailing positional True is passed through
# unchanged - its meaning is defined by ComponentAttibution (TODO confirm).
component_attributor = ComponentAttibution("Relevance", "ViT", layer_types['Softmax'], True)
print("done")

done


In [8]:
# Attribute relevance per component of `model` using the first pruning
# reference set (prune_dataloader) and the composite built above.
components_relevances = component_attributor.attribute(
    model, prune_dataloader, composite, abs_flag=True, device=device
)
print("done!")

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


done!


In [9]:
# The pruner works on the same layer names the attributor collected and
# targets the same layer type (softmax).
layer_names = component_attributor.layer_names
pruner = GlobalPruningOperations(layer_types["Softmax"], layer_names)

In [10]:
# First mask for `model`: pruning rate 0.05, least relevant components first,
# with subsequent_layer_pruning set to "Both".
global_pruning_mask = pruner.generate_global_pruning_mask(
    model, components_relevances, 0.05,
    subsequent_layer_pruning="Both", least_relevant_first=True, device=device,
)

In [11]:
# Display the generated mask: an ordered mapping from softmax layer name to a
# tensor of selected indices (presumably attention-head indices - confirm).
global_pruning_mask

OrderedDict([('encoder.layers.encoder_layer_0.self_attention.softmax',
              tensor([11,  4,  5])),
             ('encoder.layers.encoder_layer_1.self_attention.softmax',
              tensor([5, 3])),
             ('encoder.layers.encoder_layer_2.self_attention.softmax',
              tensor([0])),
             ('encoder.layers.encoder_layer_3.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_4.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_5.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_6.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_7.self_attention.softmax',
              tensor([], dtype=torch.int64)),
             ('encoder.layers.encoder_layer_8.self_attention.softmax',
              tensor([], dtype=torc

In [12]:
# Attach the mask to `model`; the returned handles are kept and printed.
hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask)
print(hook_handles)

[<torch.utils.hooks.RemovableHandle object at 0x14d0d006cdf0>]


In [13]:
# Per-group evaluation of `model` after the first mask was fitted. Disabled by
# wrapping it in a string literal, so the cell only echoes the string.
"""acc_worst, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
for i in range(4):
    print(f'{i}: {acc_groups[i]}')
    """


"acc_worst, acc_groups = compute_worst_accuracy(\n        model,\n        val_dataloader,\n        device,\n    )\nfor i in range(4):\n    print(f'{i}: {acc_groups[i]}')\n    "

In [14]:
# Hand-recorded per-group accuracies kept as a string literal for reference.
# NOTE(review): the run and split that produced these numbers are not shown in
# this notebook - confirm before citing them.
"""
0: 0.9868421052631579
1: 0.7982456140350878
2: 0.7272727272727273
3: 0.9791666666666666
"""

'\n0: 0.9868421052631579\n1: 0.7982456140350878\n2: 0.7272727272727273\n3: 0.9791666666666666\n'

In [15]:
# Second attribution helper, configured identically to the first one
# ("Relevance", "ViT", softmax layers, True); used for the model2 pass.
component_attributor2 = ComponentAttibution("Relevance", "ViT", layer_types['Softmax'], True)
print("done")

done


In [16]:

# Export the weights of the wrapped `model` as a plain torchvision ViT-B/16
# state dict, so the checkpoint can be reloaded without the wrapper modules.
#
# NOTE(review): fit_pruning_mask returns hook handles, which suggests the
# pruning is applied through forward hooks. A state dict does not capture
# hooks - confirm whether the exported weights actually reflect the pruning.

# 1. Grab the wrapped model's state dict.
pruned_sd = model.state_dict()

# 2. Drop wrapper path components named "module" from every key.
clean_sd = {re.sub(r'\.module\.', '.', key): value for key, value in pruned_sd.items()}

# 3. For each attention layer, rebuild the combined in_proj tensors and rename
#    the out_proj tensors to the names torchvision's ViT expects.
D = 768          # hidden size of ViT-B/16; not used below, kept for reference
num_layers = 12  # number of encoder layers whose attention keys are rewritten
for i in range(num_layers):
    prefix = f"encoder.layers.encoder_layer_{i}.self_attention"

    # Concatenate Q, K, V (in that order) along dim 0, weights first, then biases.
    for kind in ("weight", "bias"):
        qkv_parts = [
            clean_sd.pop(f"{prefix}.{name}_proj.proj_{kind}") for name in ("q", "k", "v")
        ]
        clean_sd[f"{prefix}.in_proj_{kind}"] = torch.cat(qkv_parts, dim=0)

    # Rename out_proj.proj_weight / proj_bias to out_proj.weight / bias.
    for kind in ("weight", "bias"):
        clean_sd[f"{prefix}.out_proj.{kind}"] = clean_sd.pop(f"{prefix}.out_proj.proj_{kind}")

# 4. Instantiate a fresh ViT-B/16 with a 2-way head. `weights=None` is the
#    supported way to request no pretrained weights; the previous
#    `weights=False` only worked through torchvision's deprecated handling of
#    legacy arguments.
model2 = vit_b_16(weights=None, num_classes=2)

# 5. Load the rebuilt dict. With strict=True a key mismatch raises, so both
#    lists printed below are empty whenever this line succeeds.
missing, unexpected = model2.load_state_dict(clean_sd, strict=True)
print("missing:", missing)         # should be empty
print("unexpected:", unexpected)   # should be empty

# 6. Save the wrapper-free checkpoint for later reloading.
torch.save(model2.state_dict(), "pruned_vit_b16_2cls_clean.pth")




missing: []
unexpected: []


In [17]:
# Reload the exported checkpoint through the project's loader; this rebinds
# model2 to the loader's version of the architecture.
model2 = ModelLoader.get_basic_model(
    "vit_b_16",
    "pruned_vit_b16_2cls_clean.pth",
    device,
    num_classes=2,
)
#model2.load_state_dict(torch.load("intermediate_model.pth"))

Arch:vit_b_16


In [18]:
# Attribute relevance per component of model2 using the second pruning
# reference set (prune_dataloader2) and the same composite as the first pass.
components_relevances2 = component_attributor2.attribute(
    model2, prune_dataloader2, composite, abs_flag=True, device=device
)
print("done!")

done!


In [19]:
# Mask for model2: pruning rate 0.15, least relevant components first, with
# subsequent_layer_pruning set to "Softmax".
global_pruning_mask2 = pruner.generate_global_pruning_mask(
    model2, components_relevances2, 0.15,
    subsequent_layer_pruning="Softmax", least_relevant_first=True, device=device,
)
# NOTE(review): this rebinds hook_handles, so handles returned by an earlier
# fit_pruning_mask call are no longer reachable through this name.
hook_handles = pruner.fit_pruning_mask(model2, global_pruning_mask2)


In [20]:
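## model2: pruned with prune set 1, then with prune set 2.
## NOTE: the original note said 10% + 10%, but the code uses pruning rate 0.05 for the first pass (on `model`) and 0.15 for the pass on `model2` - confirm which is intended.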

In [21]:
# Per-group accuracy of model2 on val_dataloader; one line is printed for each
# of the four groups (indices 0-3).
acc_worst, acc_groups = compute_worst_accuracy(model2, val_dataloader, device)
for i in range(4):
    print(f'{i}: {acc_groups[i]}')


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0: 0.9920177383592018
1: 0.765410199556541
2: 0.7492211838006231
3: 0.9626168224299065


In [22]:
# Regenerate the mask for the original `model` from the first-pass relevances,
# now at pruning rate 0.15 with subsequent_layer_pruning set to "Softmax".
global_pruning_mask = pruner.generate_global_pruning_mask(
    model, components_relevances, 0.15,
    subsequent_layer_pruning="Softmax", least_relevant_first=True, device=device,
)
# NOTE(review): `model` may still carry hooks from an earlier fit_pruning_mask
# call; whether fitting again replaces or stacks them depends on the pruner -
# confirm.
hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask)


In [23]:
## `model`: mask regenerated at pruning rate 0.15 from the prune set 1 relevances (`components_relevances`)

In [24]:
# Per-group accuracy of `model` on val_dataloader; one line is printed for
# each of the four groups (indices 0-3).
acc_worst, acc_groups = compute_worst_accuracy(model, val_dataloader, device)
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0: 0.9942350332594235
1: 0.7906873614190687
2: 0.7258566978193146
3: 0.9595015576323987


In [25]:
# Display the first-pass relevances: an ordered mapping from softmax layer
# name to a tensor with one score per component of that layer.
components_relevances

OrderedDict([('encoder.layers.encoder_layer_0.self_attention.softmax',
              tensor([0.7201, 0.4474, 0.5789, 0.9859, 0.3713, 0.3811, 2.8528, 1.5384, 0.6803,
                      0.7912, 1.8463, 0.3318])),
             ('encoder.layers.encoder_layer_1.self_attention.softmax',
              tensor([0.6992, 0.9015, 3.2379, 0.3770, 1.3863, 0.1734, 1.1828, 0.6211, 3.5716,
                      0.5637, 0.5189, 0.7685])),
             ('encoder.layers.encoder_layer_2.self_attention.softmax',
              tensor([0.3301, 1.2126, 0.4874, 0.5708, 1.0122, 3.6033, 1.1563, 0.4195, 1.8890,
                      1.0220, 3.4656, 4.7292])),
             ('encoder.layers.encoder_layer_3.self_attention.softmax',
              tensor([2.0331, 0.8815, 1.7531, 0.8091, 1.7675, 4.7154, 3.8396, 4.4762, 0.7078,
                      1.6967, 1.4860, 1.5935])),
             ('encoder.layers.encoder_layer_4.self_attention.softmax',
              tensor([2.6898, 2.5679, 1.7581, 2.6600, 2.0909, 1.2153, 2.6

In [26]:
# Inspect the first sub-module of the MLP in encoder layer 0 (a Linear layer
# in the recorded output).
model.encoder.layers[0].mlp[0]

Linear(in_features=768, out_features=3072, bias=True)

In [27]:
# Display the second-pass relevances computed on model2 with prune_dataloader2,
# for comparison with components_relevances.
components_relevances2

OrderedDict([('encoder.layers.encoder_layer_0.self_attention.softmax',
              tensor([0.5133, 0.2549, 0.2424, 0.6850, 0.2456, 0.3346, 1.8293, 0.8924, 0.5772,
                      0.5133, 1.0366, 0.1920])),
             ('encoder.layers.encoder_layer_1.self_attention.softmax',
              tensor([0.2676, 0.4054, 2.0262, 0.2375, 0.7000, 0.0943, 0.5374, 0.3284, 2.0870,
                      0.3281, 0.2365, 0.4580])),
             ('encoder.layers.encoder_layer_2.self_attention.softmax',
              tensor([0.2033, 0.8005, 0.2194, 0.4339, 0.6586, 2.3185, 0.6027, 0.2238, 1.1260,
                      0.6274, 2.4814, 3.0106])),
             ('encoder.layers.encoder_layer_3.self_attention.softmax',
              tensor([1.3536, 0.5563, 1.1859, 0.5837, 1.0808, 3.0699, 2.0863, 2.3576, 0.5441,
                      1.2550, 0.9403, 0.9709])),
             ('encoder.layers.encoder_layer_4.self_attention.softmax',
              tensor([1.7165, 1.4442, 1.1895, 1.6159, 1.2547, 0.5685, 1.6