In [1]:
import yaml
import click
import torch
import tqdm.auto
import numpy as np
from torchvision.models import vit_b_16

from matplotlib import pyplot as plt


import sys
import os
import re
import math
#project_root = "C:/Users/elmop/deep_feature_reweighting/deep_feature_reweighting/external/pruning_by_explaining"
project_root = "/home/primmere/ide/external/pruning_by_explaining"
sys.path.insert(0, project_root)                 
sys.path.insert(0, os.path.dirname(project_root))

from pruning_by_explaining.models import ModelLoader
from pruning_by_explaining.metrics import compute_accuracy
from pruning_by_explaining.my_metrics import compute_worst_accuracy
from pruning_by_explaining.my_datasets import WaterBirds, get_sample_indices_for_group, WaterBirdSubset, ISIC, ISICSubset
from pruning_by_explaining.utils import (
    initialize_random_seed,
    initialize_wandb_logger,
)


from pruning_by_explaining.pxp import (
    ModelLayerUtils,
    get_cnn_composite,
    get_vit_composite,
)

from pruning_by_explaining.pxp import GlobalPruningOperations
from pruning_by_explaining.pxp import ComponentAttibution


In [2]:
def visualise(global_pruning_mask_combined, prune_r, layertype = "Linear", save = True):
    if layertype == "Linear":
        count=0
        ratios = np.zeros((12,2))
        i = 0
        for n, t in global_pruning_mask_combined.items():
            param_total = t['Linear']['weight'].numel()
            param_nonzero = t['Linear']['weight'].nonzero().size(0)
            param_shape = t['Linear']['weight'].shape
        
            pruned = (param_total-param_nonzero)/param_shape[1]
            total = param_total/param_shape[1]
        
            if 'mlp.0' in n:
                ratios[i][0] = pruned/total
            if 'mlp.3' in n:
                ratios[i][1] = pruned/total
                i += 1
        
        #print(array)
        avgs = np.sum(ratios, axis=0)/12
        
        bar_width = 0.23
        x = np.arange(12)
        
        labels = [
            f"fc1 r={avgs[0]:.4f}",
            f"fc2 r={avgs[1]:.4f}",
        ]
        offsets = (np.arange(ratios.shape[1]) - (ratios.shape[1] - 1) / 2) * bar_width
        
        for j, offset in enumerate(offsets):
            plt.bar(x + offset, ratios[:, j], width=bar_width, label=labels[j])
        
        plt.xlabel("Transformer block idx")
        plt.ylabel("Prune ratio")
        plt.title(f'r = {prune_r}')
        plt.xticks(x, [str(i) for i in range(12)])
        plt.legend()
        plt.tight_layout()
        
        save_dir = ""
        save_path = os.path.join(save_dir, f"prune_ratios{prune_r}.png")
        if save:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()
    if layertype == "Softmax":
        ratios = np.zeros((12,1))
        i = 0
        for v in global_pruning_mask_combined.values():
            ratios[i][0] = len(v.detach().numpy())/12
            i += 1

        bar_width = 0.23
        x = np.arange(12)
        
        labels = [
            f"Attention head",
        ]
        offsets = (np.arange(ratios.shape[1]) - (ratios.shape[1] - 1) / 2) * bar_width
        
        for j, offset in enumerate(offsets):
            plt.bar(x + offset, ratios[:, j], width=bar_width, label=labels[j])
        
        plt.xlabel("Transformer block idx")
        plt.ylabel("Prune ratio")
        plt.title(f'r = {prune_r}')
        plt.xticks(x, [str(i) for i in range(12)])
        plt.legend()
        plt.tight_layout()
        save_dir = ""
        save_path = os.path.join(save_dir, f"prune_ratios{prune_r}.png")
                if save:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 76)

In [3]:
seed = 1
initialize_random_seed(seed)
num_workers = 12
device_string = "cuda"
device = torch.device(device_string)
isic = ISIC(
    "/scratch_shared/primmere/isic/isic_224/raw_224_with_selected", 
    metadata_path='/scratch_shared/primmere/isic/metadata_w_split_v2_w_elmos_modifications.csv', 
    seed=seed, 
    num_workers=num_workers
        )
least_rel_first = True
abs_flag = True
#least_rel_first2 = False
#abs_flag2 = False
#Zplus_flag = True

#scale_bool = True

prune_r = [0.85, 0.88, 0.9, 0.92, 0.95]



layer_type = 'Softmax'

train_set = isic.get_train_set()
val_set = isic.get_valid_set()
test_set = isic.get_test_set()

pruning_indices = get_sample_indices_for_group(val_set, 30, device_string, [3])
print("pruning indices:" , len(pruning_indices))

custom_pruning_set = ISICSubset(val_set, pruning_indices)

train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=num_workers)
prune_dataloader = torch.utils.data.DataLoader(custom_pruning_set, batch_size=1, shuffle=True, num_workers=num_workers)
#prune_dataloader2 = torch.utils.data.DataLoader(custom_pruning_set2, batch_size=8, shuffle=True, num_workers=num_workers)
#prune_dataloader3 = torch.utils.data.DataLoader(custom_pruning_set3, batch_size=8, shuffle=True, num_workers=num_workers)
val_dataloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=num_workers)

[0 1 2]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 6314
group 1: 5526
group 2: 1571
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 60
group 1: 60
group 2: 60
group 3: 60
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 3158
group 1: 2763
group 2: 882
group 3: 761
target groups: [3]
pruning indices: 30


In [4]:
suggested_composite = {
        "low_level_hidden_layer_rule": "Epsilon",
        "mid_level_hidden_layer_rule":"Epsilon",
        "high_level_hidden_layer_rule": "Epsilon",
        "fully_connected_layers_rule": "Epsilon",
        "softmax_rule": "Epsilon",
    }
composite = get_vit_composite("vit_b_16", suggested_composite)
layer_types = {
        "Softmax": torch.nn.Softmax,
        "Linear": torch.nn.Linear,
        "Conv2d": torch.nn.Conv2d,
    }

In [5]:
model = ModelLoader.get_basic_model("vit_b_16", '/home/primmere/logs/isic_logs_4/vit_isic_v2.pt', device, num_classes=2)
#model2 = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)
#model3 = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)



Arch:vit_b_16


  loaded_checkpoint = torch.load(checkpoint_path, map_location=device)


In [6]:
"""
acc, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
for i in range(len(acc_groups)):
    print(f'{i}: {acc_groups[i]}')
"""



"""
0.8175568482284505
0: 0.9050031665611147
1: 0.9996380745566413
2: 0.4557823129251701
3: 0.21287779237844942
"""


'\n0.8175568482284505\n0: 0.9050031665611147\n1: 0.9996380745566413\n2: 0.4557823129251701\n3: 0.21287779237844942\n'

In [7]:
component_attributor = ComponentAttibution(
        "Relevance",
        "ViT",
        layer_types[layer_type],
        least_rel_first
    )

components_relevances = component_attributor.attribute(
        model,
        prune_dataloader,
        composite,
        abs_flag=abs_flag,
        Zplus_flag=False,
        device=device,
    )

layer_names = component_attributor.layer_names
pruner = GlobalPruningOperations(
        layer_types[layer_type],
        layer_names,
    )

print("done!")



  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


done!


In [8]:
i = 0
for k in components_relevances.keys():
    #print(torch.mean(v))
    a = torch.mean(components_relevances[k]).item()
    print(f'{i} {a:.3f}')
    i += 1

0 2.739
1 3.752
2 2.459
3 4.225
4 7.457
5 4.093
6 4.980
7 3.282
8 1.355
9 0.000
10 0.000
11 0.011


In [9]:
for name, R in components_relevances.items():
    print(name, torch.sum(R).item())

encoder.layers.encoder_layer_0.self_attention.softmax 32.86621856689453
encoder.layers.encoder_layer_1.self_attention.softmax 45.02645492553711
encoder.layers.encoder_layer_2.self_attention.softmax 29.50483512878418
encoder.layers.encoder_layer_3.self_attention.softmax 50.70082473754883
encoder.layers.encoder_layer_4.self_attention.softmax 89.48729705810547
encoder.layers.encoder_layer_5.self_attention.softmax 49.11297607421875
encoder.layers.encoder_layer_6.self_attention.softmax 59.76432800292969
encoder.layers.encoder_layer_7.self_attention.softmax 39.38343048095703
encoder.layers.encoder_layer_8.self_attention.softmax 16.256092071533203
encoder.layers.encoder_layer_9.self_attention.softmax 1.8475186891464546e-07
encoder.layers.encoder_layer_10.self_attention.softmax 5.622020580631215e-06
encoder.layers.encoder_layer_11.self_attention.softmax 0.12606577575206757


In [10]:
"""
combined_relevances = {}
for (k, v), (k2, v2) in zip(components_relevances.items(), components_relevances2.items()):
    combined_relevances[k] = v-v2

"""
    

'\ncombined_relevances = {}\nfor (k, v), (k2, v2) in zip(components_relevances.items(), components_relevances2.items()):\n    combined_relevances[k] = v-v2\n\n'

In [11]:
for r in prune_r:
    global_pruning_mask = pruner.generate_global_pruning_mask(
                    model,
                    components_relevances,
                    r,
                    subsequent_layer_pruning=layer_type,
                    least_relevant_first=True,
                    device=device,
                )
    hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask,)
    acc, acc_groups = compute_worst_accuracy(
            model,
            val_dataloader,
            device,
        )
    print(f'{r} & {acc:.3f} & {acc_groups[0]:.3f} & {acc_groups[1]:.3f} & {acc_groups[2]:.3f} & {acc_groups[3]:.3f} \\\\')
    #visualise(global_pruning_mask, r, layer_type)

evaluating group acc:   0%|          | 0/119 [00:00<?, ?it/s]

0.85 & 0.817 & 0.899 & 0.999 & 0.456 & 0.238 \\


evaluating group acc:   0%|          | 0/119 [00:00<?, ?it/s]

0.88 & 0.804 & 0.858 & 0.996 & 0.483 & 0.252 \\


evaluating group acc:   0%|          | 0/119 [00:00<?, ?it/s]

0.9 & 0.796 & 0.834 & 0.953 & 0.505 & 0.410 \\


evaluating group acc:   0%|          | 0/119 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14da277a30a0>
Traceback (most recent call last):
  File "/home/primmere/.conda/envs/dfr2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x14da277a30a0>    
self._shutdown_workers()Traceback (most recent call last):

  File "/home/primmere/.conda/envs/dfr2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/home/primmere/.conda/envs/dfr2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():

  File "/home/primmere/.conda/envs/dfr2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
  File "/home/primmere/.conda/envs/dfr2/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), '

0.92 & 0.740 & 0.734 & 0.804 & 0.630 & 0.658 \\


evaluating group acc:   0%|          | 0/119 [00:00<?, ?it/s]

0.95 & 0.786 & 0.844 & 0.949 & 0.382 & 0.419 \\


In [12]:
"""
0.9 & 0.816 & 0.894 & 1.000 & 0.472 & 0.227 \\
0.92 & 0.820 & 0.892 & 0.999 & 0.488 & 0.255 \\
0.94 & 0.829 & 0.880 & 0.999 & 0.517 & 0.368 \\
0.95 & 0.837 & 0.879 & 0.998 & 0.536 & 0.430 \\
0.96 & 0.838 & 0.864 & 0.985 & 0.558 & 0.524 \\
0.98 & 0.710 & 0.772 & 0.630 & 0.664 & 0.796 \\
"""

'\n0.9 & 0.816 & 0.894 & 1.000 & 0.472 & 0.227 \\\n0.92 & 0.820 & 0.892 & 0.999 & 0.488 & 0.255 \\\n0.94 & 0.829 & 0.880 & 0.999 & 0.517 & 0.368 \\\n0.95 & 0.837 & 0.879 & 0.998 & 0.536 & 0.430 \\\n0.96 & 0.838 & 0.864 & 0.985 & 0.558 & 0.524 \\\n0.98 & 0.710 & 0.772 & 0.630 & 0.664 & 0.796 \\\n'

In [13]:
count=0
for n, t in global_pruning_mask.items():
    param_total = t['Linear']['weight'].numel()
    param_nonzero = t['Linear']['weight'].nonzero().size(0)
    param_shape = t['Linear']['weight'].shape
    
    pruned = (param_total-param_nonzero)/param_shape[1]
    total = param_total/param_shape[1]
    #print(pruned, "/", total, "pruned,", pruned/total)
    print(f'{100*pruned/total:.1f}% - {pruned} / {total} pruned')
    

IndexError: too many indices for tensor of dimension 1

In [None]:
"""
acc, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""

In [None]:
#global_pruning_mask

In [None]:
#global_pruning_mask2

In [None]:
for v in components_relevances.values():
    print(v[0:6])

for v in components_relevances.values():
    print(v[0:7])

In [None]:
print(f'{prune_r},{acc:.3f},{acc_groups[0]:.3f},{acc_groups[1]:.3f},{acc_groups[2]:.3f},{acc_groups[3]:.3f}')

In [None]:
print(f'{r} & {acc:.3f} & {acc_groups[0]:.3f} & {acc_groups[1]:.3f} & {acc_groups[2]:.3f} & {acc_groups[3]:.3f} \\\\')

In [None]:
layer_type

In [None]:
visualise(global_pruning_mask_combined, r, layer_type)