In [1]:
import yaml
import click
import torch
import tqdm.auto
import numpy as np
from torchvision.models import vit_b_16

from matplotlib import pyplot as plt


import sys
import os
import re
import math
#project_root = "C:/Users/elmop/deep_feature_reweighting/deep_feature_reweighting/external/pruning_by_explaining"
project_root = "/home/primmere/ide/external/pruning_by_explaining"
sys.path.insert(0, project_root)                 
sys.path.insert(0, os.path.dirname(project_root))

from pruning_by_explaining.models import ModelLoader
from pruning_by_explaining.metrics import compute_accuracy
from pruning_by_explaining.my_metrics import compute_worst_accuracy
from pruning_by_explaining.my_datasets import WaterBirds, get_sample_indices_for_group, WaterBirdSubset, ISIC, ISICSubset
from pruning_by_explaining.utils import (
    initialize_random_seed,
    initialize_wandb_logger,
)


from pruning_by_explaining.pxp import (
    ModelLayerUtils,
    get_cnn_composite,
    get_vit_composite,
)

from pruning_by_explaining.pxp import GlobalPruningOperations
from pruning_by_explaining.pxp import ComponentAttibution


In [2]:
def visualise(global_pruning_mask_combined, prune_r, layertype = "Linear"):
    if layertype == "Linear":
        count=0
        ratios = np.zeros((12,2))
        i = 0
        for n, t in global_pruning_mask_combined.items():
            param_total = t['Linear']['weight'].numel()
            param_nonzero = t['Linear']['weight'].nonzero().size(0)
            param_shape = t['Linear']['weight'].shape
        
            pruned = (param_total-param_nonzero)/param_shape[1]
            total = param_total/param_shape[1]
        
            if 'mlp.0' in n:
                ratios[i][0] = pruned/total
            if 'mlp.3' in n:
                ratios[i][1] = pruned/total
                i += 1
        
        #print(array)
        avgs = np.sum(ratios, axis=0)/12
        
        bar_width = 0.23
        x = np.arange(12)
        
        labels = [
            f"fc1 r={avgs[0]:.4f}",
            f"fc2 r={avgs[1]:.4f}",
        ]
        offsets = (np.arange(ratios.shape[1]) - (ratios.shape[1] - 1) / 2) * bar_width
        
        for j, offset in enumerate(offsets):
            plt.bar(x + offset, ratios[:, j], width=bar_width, label=labels[j])
        
        plt.xlabel("Transformer block idx")
        plt.ylabel("Prune ratio")
        plt.title(f'r = {prune_r}')
        plt.xticks(x, [str(i) for i in range(12)])
        plt.legend()
        plt.tight_layout()
        
        save_dir = ""
        save_path = os.path.join(save_dir, f"prune_ratios{prune_r}.png")
        plt.savefig(save_path)
        #print(f"Plot saved to {save_path}")
        #plt.show()
        plt.close()
    if layertype == "Softmax":
        ratios = np.zeros((12,1))
        i = 0
        for v in global_pruning_mask_combined.values():
            ratios[i][0] = len(v.detach().numpy())/12
            i += 1

        bar_width = 0.23
        x = np.arange(12)
        
        labels = [
            f"Attention head",
        ]
        offsets = (np.arange(ratios.shape[1]) - (ratios.shape[1] - 1) / 2) * bar_width
        
        for j, offset in enumerate(offsets):
            plt.bar(x + offset, ratios[:, j], width=bar_width, label=labels[j])
        
        plt.xlabel("Transformer block idx")
        plt.ylabel("Prune ratio")
        plt.title(f'r = {prune_r}')
        plt.xticks(x, [str(i) for i in range(12)])
        plt.legend()
        plt.tight_layout()
        
        save_dir = ""
        save_path = os.path.join(save_dir, f"prune_ratios{prune_r}.png")
        plt.savefig(save_path)
        #print(f"Plot saved to {save_path}")
        plt.close()

In [3]:
initialize_random_seed(1)
num_workers = 8
device_string = "cuda"
device = torch.device(device_string)
waterbirds = WaterBirds('/scratch_shared/primmere/waterbird', seed = 1, num_workers = num_workers)
least_rel_first = True
abs_flag = True
least_rel_first2 = False
abs_flag2 = False
Zplus_flag = True

scale_bool = True

prune_r = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]



layer_type = 'Softmax'

train_set = waterbirds.get_train_set()
val_set = waterbirds.get_valid_set()
test_set = waterbirds.get_test_set()

pruning_indices = get_sample_indices_for_group(val_set, 30, device_string, [0,1,2,3])
pruning_indices2 = get_sample_indices_for_group(val_set, 30, device_string, [1])
pruning_indices3 = get_sample_indices_for_group(val_set, 30, device_string, [2])
validation_indices = get_sample_indices_for_group(test_set, 'all', device_string)


custom_pruning_set = WaterBirdSubset(val_set, pruning_indices)
custom_pruning_set2 = WaterBirdSubset(val_set, pruning_indices2)
custom_pruning_set3 = WaterBirdSubset(val_set, pruning_indices3)

custom_val_set = WaterBirdSubset(test_set, validation_indices)

train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=num_workers)
prune_dataloader = torch.utils.data.DataLoader(custom_pruning_set, batch_size=8, shuffle=True, num_workers=num_workers)
prune_dataloader2 = torch.utils.data.DataLoader(custom_pruning_set2, batch_size=8, shuffle=True, num_workers=num_workers)
prune_dataloader3 = torch.utils.data.DataLoader(custom_pruning_set3, batch_size=8, shuffle=True, num_workers=num_workers)
val_dataloader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True, num_workers=num_workers)

[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 3518
group 1: 185
group 2: 55
group 3: 1037
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 456
group 1: 456
group 2: 143
group 3: 144
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 2255
group 1: 2255
group 2: 642
group 3: 642
target groups: [0, 1, 2, 3]
target groups: [1]
target groups: [2]
target groups: [0, 1, 2, 3]


In [4]:
suggested_composite = {
        "low_level_hidden_layer_rule": "Epsilon",
        "mid_level_hidden_layer_rule":"Epsilon",
        "high_level_hidden_layer_rule": "Epsilon",
        "fully_connected_layers_rule": "Epsilon",
        "softmax_rule": "Epsilon",
    }
composite = get_vit_composite("vit_b_16", suggested_composite)
layer_types = {
        "Softmax": torch.nn.Softmax,
        "Linear": torch.nn.Linear,
        "Conv2d": torch.nn.Conv2d,
    }

In [5]:
model = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)
model2 = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)
model3 = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)



Arch:vit_b_16


  loaded_checkpoint = torch.load(checkpoint_path, map_location=device)


Arch:vit_b_16
Arch:vit_b_16


In [6]:
"""
acc, acc_groups = compute_worst_accuracy(
        model,
        train_dataloader,
        device,
    )
print(acc)
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""


"""
0: 0.9937915742793791
1: 0.7835920177383592
2: 0.7461059190031153
3: 0.956386292834891
"""


'\n0: 0.9937915742793791\n1: 0.7835920177383592\n2: 0.7461059190031153\n3: 0.956386292834891\n'

In [7]:
"""
acc = compute_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
"""

'\nacc = compute_accuracy(\n        model,\n        val_dataloader,\n        device,\n    )\nprint(acc)\n'

In [8]:
component_attributor = ComponentAttibution(
        "Relevance",
        "ViT",
        layer_types[layer_type],
        least_rel_first
    )

components_relevances = component_attributor.attribute(
        model,
        prune_dataloader,
        composite,
        abs_flag=abs_flag,
        Zplus_flag=False,
        device=device,
    )

layer_names = component_attributor.layer_names
pruner = GlobalPruningOperations(
        layer_types[layer_type],
        layer_names,
    )

global_pruning_mask = pruner.generate_global_pruning_mask(
                model,
                components_relevances,
                0.1,
                subsequent_layer_pruning=layer_type,
                least_relevant_first=least_rel_first,
                device=device,
            )
print("done!")



  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


done!


In [9]:
#hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask,)
#print(hook_handles)

In [10]:
component_attributor2 = ComponentAttibution(
        "Relevance",
        "ViT",
        layer_types[layer_type],
        least_rel_first2
    )

components_relevances2 = component_attributor2.attribute(
        model2,
        prune_dataloader2,
        composite,
        abs_flag=abs_flag2,
        Zplus_flag=Zplus_flag,
        device=device,
    )
layer_names2 = component_attributor.layer_names
pruner2 = GlobalPruningOperations(
        layer_types[layer_type],
        layer_names2,
    )

global_pruning_mask2 = pruner2.generate_global_pruning_mask(
                model2,
                components_relevances2,
                0.1,
                subsequent_layer_pruning=layer_type,
                least_relevant_first=least_rel_first,
                device=device,
            )
print("done!")

done!


In [11]:
component_attributor3 = ComponentAttibution(
        "Relevance",
        "ViT",
        layer_types[layer_type],
        least_rel_first2
    )

components_relevances3 = component_attributor3.attribute(
        model3,
        prune_dataloader3,
        composite,
        abs_flag=abs_flag2,
        Zplus_flag=Zplus_flag,
        device=device,
    )
layer_names3 = component_attributor.layer_names
pruner3 = GlobalPruningOperations(
        layer_types[layer_type],
        layer_names3,
    )

global_pruning_mask3 = pruner2.generate_global_pruning_mask(
                model3,
                components_relevances3,
                0.1,
                subsequent_layer_pruning=layer_type,
                least_relevant_first=least_rel_first,
                device=device,
            )
print("done!")

done!


In [12]:
#hook_handles = pruner.fit_pruning_mask(model2, global_pruning_mask2,)
#print(hook_handles)

In [13]:
"""acc_worst, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
for i in range(4):
    print(f'{i}: {acc_groups[i]}')
    """


"acc_worst, acc_groups = compute_worst_accuracy(\n        model,\n        val_dataloader,\n        device,\n    )\nfor i in range(4):\n    print(f'{i}: {acc_groups[i]}')\n    "

In [14]:
"""
0: 0.9868421052631579
1: 0.7982456140350878
2: 0.7272727272727273
3: 0.9791666666666666
"""

'\n0: 0.9868421052631579\n1: 0.7982456140350878\n2: 0.7272727272727273\n3: 0.9791666666666666\n'

In [15]:
"""
acc_worst, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
for i in range(4):
    print(f'{i}: {acc_groups[i]}')
"""


"\nacc_worst, acc_groups = compute_worst_accuracy(\n        model,\n        val_dataloader,\n        device,\n    )\nfor i in range(4):\n    print(f'{i}: {acc_groups[i]}')\n"

In [16]:
"""
acc_worst, acc_groups = compute_worst_accuracy(
        model2,
        val_dataloader,
        device,
    )
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""

"\nacc_worst, acc_groups = compute_worst_accuracy(\n        model2,\n        val_dataloader,\n        device,\n    )\nfor i in range(4):\n    print(f'{i}: {acc_groups[i]}')\n\n"

In [17]:
scale = len(pruning_indices)
if scale_bool:
    for t in components_relevances.values():
        t.div_(scale)

In [18]:
"""
for k,v in components_relevances.items():
    if "mlp.3" in k:
        v.div_(4)

for k,v in components_relevances2.items():
    if "mlp.3" in k:
        v.div_(4)

for k,v in components_relevances3.items():
    if "mlp.3" in k:
        v.div_(4)

"""

'\nfor k,v in components_relevances.items():\n    if "mlp.3" in k:\n        v.div_(4)\n\nfor k,v in components_relevances2.items():\n    if "mlp.3" in k:\n        v.div_(4)\n\nfor k,v in components_relevances3.items():\n    if "mlp.3" in k:\n        v.div_(4)\n\n'

In [19]:
scale2 = len(pruning_indices2)
scale3 = len(pruning_indices3)
if scale_bool :
    for t in components_relevances2.values():
        t.div_(scale2*2)
    for t in components_relevances3.values():
        t.div_(scale3*2)

In [20]:
print(scale, scale2)

120 30


In [21]:
for k in components_relevances2.keys():
    #print(torch.mean(v))
    a = torch.mean(components_relevances[k]).item()
    b = torch.mean(components_relevances2[k]).item()
    c = torch.mean(components_relevances3[k]).item()
    print(f'{a:.3f}, {b:.3f}, {c:.3f}')

0.055, 0.016, 0.018
0.069, 0.015, 0.020
0.111, 0.027, 0.027
0.143, 0.030, 0.033
0.154, 0.037, 0.037
0.146, 0.035, 0.038
0.162, 0.033, 0.037
0.126, 0.027, 0.030
0.140, 0.027, 0.033
0.080, 0.016, 0.022
0.054, 0.013, 0.013
0.080, 0.013, 0.018


In [22]:
for name, R in components_relevances.items():
    print(name, torch.sum(R).item())

encoder.layers.encoder_layer_0.self_attention.softmax 0.6642364263534546
encoder.layers.encoder_layer_1.self_attention.softmax 0.833003044128418
encoder.layers.encoder_layer_2.self_attention.softmax 1.3342777490615845
encoder.layers.encoder_layer_3.self_attention.softmax 1.7167727947235107
encoder.layers.encoder_layer_4.self_attention.softmax 1.8532968759536743
encoder.layers.encoder_layer_5.self_attention.softmax 1.7472929954528809
encoder.layers.encoder_layer_6.self_attention.softmax 1.9488774538040161
encoder.layers.encoder_layer_7.self_attention.softmax 1.5074082612991333
encoder.layers.encoder_layer_8.self_attention.softmax 1.6787011623382568
encoder.layers.encoder_layer_9.self_attention.softmax 0.9655097723007202
encoder.layers.encoder_layer_10.self_attention.softmax 0.6515893936157227
encoder.layers.encoder_layer_11.self_attention.softmax 0.9574383497238159


In [23]:
"""
combined_relevances = {}
for (k, v), (k2, v2) in zip(components_relevances.items(), components_relevances2.items()):
    combined_relevances[k] = v-v2

"""
    

'\ncombined_relevances = {}\nfor (k, v), (k2, v2) in zip(components_relevances.items(), components_relevances2.items()):\n    combined_relevances[k] = v-v2\n\n'

In [24]:
combined_relevances = {}
check = True
for (k1, v1), (k2, v2), (k3, v3) in zip(
        components_relevances.items(), components_relevances2.items(), components_relevances3.items()
):
    check = check & (k1==k2==k3)
    combined_relevances[k1] = v1 - v3
print(check)

True


In [25]:
for r in prune_r:
    global_pruning_mask_combined = pruner.generate_global_pruning_mask(
                    model,
                    combined_relevances,
                    r,
                    subsequent_layer_pruning=layer_type,
                    least_relevant_first=True,
                    device=device,
                )
    hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask_combined,)
    acc, acc_groups = compute_worst_accuracy(
            model,
            val_dataloader,
            device,
        )
    print(f'{r} & {acc:.3f} & {acc_groups[0]:.3f} & {acc_groups[1]:.3f} & {acc_groups[2]:.3f} & {acc_groups[3]:.3f} \\\\')
    visualise(global_pruning_mask_combined, r, layer_type)

evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.01 & 0.876 & 0.993 & 0.766 & 0.763 & 0.960 \\


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.05 & 0.875 & 0.993 & 0.766 & 0.760 & 0.958 \\


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.1 & 0.862 & 0.995 & 0.756 & 0.687 & 0.949 \\


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.2 & 0.849 & 0.992 & 0.706 & 0.737 & 0.956 \\


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.3 & 0.866 & 0.989 & 0.750 & 0.749 & 0.958 \\


evaluating group acc:   0%|          | 0/46 [00:00<?, ?it/s]

0.4 & 0.876 & 0.991 & 0.805 & 0.654 & 0.945 \\


evaluating group acc:   0%|          | 0/46 [00:40<?, ?it/s]

0.5 & 0.870 & 0.990 & 0.795 & 0.646 & 0.939 \\


In [26]:
hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask_combined,)

In [27]:
count=0
for n, t in global_pruning_mask_combined.items():
    param_total = t['Linear']['weight'].numel()
    param_nonzero = t['Linear']['weight'].nonzero().size(0)
    param_shape = t['Linear']['weight'].shape
    
    pruned = (param_total-param_nonzero)/param_shape[1]
    total = param_total/param_shape[1]
    #print(pruned, "/", total, "pruned,", pruned/total)
    print(f'{100*pruned/total:.1f}% - {pruned} / {total} pruned')
    

IndexError: too many indices for tensor of dimension 1

In [None]:
"""
acc, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""

In [None]:
#global_pruning_mask

In [None]:
#global_pruning_mask2

In [None]:
for v, v2 in zip(components_relevances.values(), components_relevances2.values()):
    print(v[0:8])
    print(v2[0:8])
    print(v[0:8]-v2[0:8])
    print("--")

In [None]:
print(f'{prune_r},{acc:.3f},{acc_groups[0]:.3f},{acc_groups[1]:.3f},{acc_groups[2]:.3f},{acc_groups[3]:.3f}')

In [None]:
print(f'{r} & {acc:.3f} & {acc_groups[0]:.3f} & {acc_groups[1]:.3f} & {acc_groups[2]:.3f} & {acc_groups[3]:.3f} \\\\')

In [None]:
layer_type

In [None]:
visualise(global_pruning_mask_combined, r, layer_type)