In [1]:
import yaml
import click
import torch
import tqdm.auto
import numpy as np
from torchvision.models import vit_b_16

from matplotlib import pyplot as plt


import sys
import os
import re
import math
#project_root = "C:/Users/elmop/deep_feature_reweighting/deep_feature_reweighting/external/pruning_by_explaining"
project_root = "/home/primmere/ide/external/pruning_by_explaining"
sys.path.insert(0, project_root)                 
sys.path.insert(0, os.path.dirname(project_root))

from pruning_by_explaining.models import ModelLoader
from pruning_by_explaining.metrics import compute_accuracy
from pruning_by_explaining.my_metrics import compute_worst_accuracy
from pruning_by_explaining.my_datasets import WaterBirds, get_sample_indices_for_group, WaterBirdSubset, ISIC, ISICSubset
from pruning_by_explaining.utils import (
    initialize_random_seed,
    initialize_wandb_logger,
)


from pruning_by_explaining.pxp import (
    ModelLayerUtils,
    get_cnn_composite,
    get_vit_composite,
)

from pruning_by_explaining.pxp import GlobalPruningOperations
from pruning_by_explaining.pxp import ComponentAttibution

from pruning_by_explaining.my_experiments.utils import visualise, plot_layer_head_heatmap, plot_layer_head_pruned, plot_r_accuracy_lines


In [2]:
# --- Run configuration and data loading -------------------------------------
# The statements below are order-sensitive: the seed is initialised first and
# everything after it is created in a fixed sequence.
seed = 1
initialize_random_seed(seed)  # presumably seeds the RNGs used downstream — confirm in utils
num_workers = 12
device_string = "cuda"
device = torch.device(device_string)
# ISIC dataset wrapper built from an image directory and a metadata CSV.
isic = ISIC(
    "/scratch_shared/primmere/isic/isic_224/raw_224_with_selected", 
    metadata_path='/scratch_shared/primmere/isic/metadata_w_split_v2_w_elmos_modifications.csv', 
    seed=seed, 
    num_workers=num_workers
        )
# least_rel_first is handed to ComponentAttibution and abs_flag to its
# attribute() call in the attribution cell.
least_rel_first = True
abs_flag = True
#least_rel_first2 = False
#abs_flag2 = False
#Zplus_flag = True

#scale_bool = True

# Pruning rates swept (in this order) by the evaluation loop.
prune_r = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.92,0.94,0.96,0.98]



# Key into the layer_types table; also passed as subsequent_layer_pruning when
# the global pruning mask is generated.
layer_type = 'Linear'

train_set = isic.get_train_set()
val_set = isic.get_valid_set()
test_set = isic.get_test_set()

# Indices of the validation samples used as the pruning/attribution set.
# NOTE(review): presumably 30 samples per group for groups 0-3 — confirm
# against get_sample_indices_for_group (the recorded output reports 120).
pruning_indices = get_sample_indices_for_group(val_set, 30, device_string, [0,1,2,3])
print("pruning indices:" , len(pruning_indices))

custom_pruning_set = ISICSubset(val_set, pruning_indices)

# train_dataloader is not referenced by any other cell shown in this notebook.
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=num_workers)
# The pruning set is fed one sample per batch.
prune_dataloader = torch.utils.data.DataLoader(custom_pruning_set, batch_size=1, shuffle=True, num_workers=num_workers)
#prune_dataloader2 = torch.utils.data.DataLoader(custom_pruning_set2, batch_size=8, shuffle=True, num_workers=num_workers)
#prune_dataloader3 = torch.utils.data.DataLoader(custom_pruning_set3, batch_size=8, shuffle=True, num_workers=num_workers)
# NOTE: despite its name, val_dataloader iterates test_set, not val_set, so the
# accuracies computed from it are test-split numbers.
val_dataloader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=True, num_workers=num_workers)

[0 1 2]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 6314
group 1: 5526
group 2: 1571
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 60
group 1: 60
group 2: 60
group 3: 60
[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 3158
group 1: 2763
group 2: 882
group 3: 761
target groups: [0, 1, 2, 3]
pruning indices: 120


In [3]:
# Every rule slot of the composite configuration uses the "Epsilon" rule.
suggested_composite = dict.fromkeys(
    (
        "low_level_hidden_layer_rule",
        "mid_level_hidden_layer_rule",
        "high_level_hidden_layer_rule",
        "fully_connected_layers_rule",
        "softmax_rule",
    ),
    "Epsilon",
)
composite = get_vit_composite("vit_b_16", suggested_composite)

# Lookup table from a layer-type name to the torch module class it denotes.
layer_types = dict(
    Softmax=torch.nn.Softmax,
    Linear=torch.nn.Linear,
    Conv2d=torch.nn.Conv2d,
)

In [4]:
# Load a two-class vit_b_16 model from the ISIC checkpoint onto `device`.
isic_checkpoint_path = '/home/primmere/logs/isic_logs_4/vit_isic_v2.pt'
model = ModelLoader.get_basic_model(
    "vit_b_16",
    isic_checkpoint_path,
    device,
    num_classes=2,
)
#model2 = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)
#model3 = ModelLoader.get_basic_model("vit_b_16", "/home/primmere/ide/dfr/logs/vit_waterbirds.pth", device, num_classes=2)



Arch:vit_b_16


  loaded_checkpoint = torch.load(checkpoint_path, map_location=device)


In [5]:
"""
acc, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
for i in range(len(acc_groups)):
    print(f'{i}: {acc_groups[i]}')
"""



"""
0.8175568482284505
0: 0.9050031665611147
1: 0.9996380745566413
2: 0.4557823129251701
3: 0.21287779237844942
"""


'\n0.8175568482284505\n0: 0.9050031665611147\n1: 0.9996380745566413\n2: 0.4557823129251701\n3: 0.21287779237844942\n'

In [6]:
# Attribute relevance to the components of the selected layer type using the
# pruning set, then build the pruner that operates on the same layers.
target_layer_cls = layer_types[layer_type]

component_attributor = ComponentAttibution(
    "Relevance", "ViT", target_layer_cls, least_rel_first
)
components_relevances = component_attributor.attribute(
    model,
    prune_dataloader,
    composite,
    abs_flag=abs_flag,
    Zplus_flag=False,
    device=device,
)

# The attributor exposes the names of the layers it attributed; the pruner is
# constructed with that same list.
layer_names = component_attributor.layer_names
pruner = GlobalPruningOperations(target_layer_cls, layer_names)

print("done!")



zennit


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


done!


In [7]:
#plot_layer_head_heatmap(components_relevances, normalise = len(pruning_indices))

In [8]:
# Print the mean relevance of each attributed layer, prefixed by the layer's
# position in components_relevances. enumerate() supplies the running index,
# and iterating the values avoids a second dictionary lookup per key.
for i, layer_relevance in enumerate(components_relevances.values()):
    a = torch.mean(layer_relevance).item()
    print(f'{i} {a:.3f}')

0 1.213
1 4.591
2 1.298
3 4.264
4 1.569
5 3.732
6 0.344
7 1.128
8 0.003
9 0.098
10 0.000
11 0.067
12 0.000
13 0.031
14 0.000
15 0.012
16 0.000
17 0.005
18 0.000
19 0.005
20 0.000
21 0.005
22 0.000
23 0.005


In [9]:
# Total relevance of each layer, printed next to the layer's name.
for layer, relevance in components_relevances.items():
    total = relevance.sum().item()
    print(layer, total)

encoder.layers.encoder_layer_0.mlp.0 3726.03515625
encoder.layers.encoder_layer_0.mlp.3 3525.6142578125
encoder.layers.encoder_layer_1.mlp.0 3987.47021484375
encoder.layers.encoder_layer_1.mlp.3 3274.826416015625
encoder.layers.encoder_layer_2.mlp.0 4819.919921875
encoder.layers.encoder_layer_2.mlp.3 2866.365478515625
encoder.layers.encoder_layer_3.mlp.0 1057.2669677734375
encoder.layers.encoder_layer_3.mlp.3 866.6635131835938
encoder.layers.encoder_layer_4.mlp.0 8.745113372802734
encoder.layers.encoder_layer_4.mlp.3 75.50364685058594
encoder.layers.encoder_layer_5.mlp.0 0.283525288105011
encoder.layers.encoder_layer_5.mlp.3 51.424766540527344
encoder.layers.encoder_layer_6.mlp.0 0.00014410133007913828
encoder.layers.encoder_layer_6.mlp.3 23.728965759277344
encoder.layers.encoder_layer_7.mlp.0 0.00014098797691985965
encoder.layers.encoder_layer_7.mlp.3 9.548144340515137
encoder.layers.encoder_layer_8.mlp.0 0.00017486786236986518
encoder.layers.encoder_layer_8.mlp.3 3.8210458755493164
e

In [10]:
"""
combined_relevances = {}
for (k, v), (k2, v2) in zip(components_relevances.items(), components_relevances2.items()):
    combined_relevances[k] = v-v2

"""
    

'\ncombined_relevances = {}\nfor (k, v), (k2, v2) in zip(components_relevances.items(), components_relevances2.items()):\n    combined_relevances[k] = v-v2\n\n'

In [None]:
# Sweep the pruning rates in prune_r and record the accuracies for each rate.
# Row layout of accs: [r, acc, acc_groups[0], acc_groups[1], acc_groups[2], acc_groups[3]].
accs = np.zeros((len(prune_r), 6))

for i, r in enumerate(prune_r):
    global_pruning_mask = pruner.generate_global_pruning_mask(
        model,
        components_relevances,
        r,
        subsequent_layer_pruning=layer_type,
        # Use the same ordering flag that was given to ComponentAttibution so
        # attribution and mask generation cannot silently disagree.
        least_relevant_first=least_rel_first,
        device=device,
    )
    # NOTE(review): hook_handles is never used or released anywhere in this
    # notebook, so whatever fit_pruning_mask applies for one rate presumably
    # stays on the model when the next rate is evaluated — confirm that this
    # accumulation across iterations is intended.
    hook_handles = pruner.fit_pruning_mask(model, global_pruning_mask)

    # Evaluated on val_dataloader as defined in the configuration cell.
    acc, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
    accs[i] = np.array([r, acc] + [acc_groups[g] for g in range(4)])
    # One LaTeX table row per pruning rate.
    print(f'{r} & {acc:.3f} & {acc_groups[0]:.3f} & {acc_groups[1]:.3f} & {acc_groups[2]:.3f} & {acc_groups[3]:.3f} \\\\')
    #visualise(global_pruning_mask, r, layer_type)
    


evaluating group acc:   0%|          | 0/237 [00:00<?, ?it/s]

0.1 & 0.818 & 0.905 & 1.000 & 0.456 & 0.213 \\


evaluating group acc:   0%|          | 0/237 [00:00<?, ?it/s]

0.2 & 0.818 & 0.905 & 1.000 & 0.456 & 0.213 \\


evaluating group acc:   0%|          | 0/237 [00:00<?, ?it/s]

In [None]:
# Re-emit every collected row of accs as a LaTeX table line; each row unpacks
# into the rate followed by the five recorded accuracy values.
for rate, acc_all, acc_g0, acc_g1, acc_g2, acc_g3 in accs:
    print(f'{rate} & {acc_all:.3f} & {acc_g0:.3f} & {acc_g1:.3f} & {acc_g2:.3f} & {acc_g3:.3f} \\\\')

In [None]:
# Plot only rows 7 onward of accs — with the prune_r list defined in this
# notebook those are the rates from 0.8 upward — and pass a save_path for the
# resulting figure.
plot_r_accuracy_lines(accs[7:], save_path = 'pxp viz/isic/acc_linear_0123.png')

In [None]:
"""
acc, acc_groups = compute_worst_accuracy(
        model,
        val_dataloader,
        device,
    )
print(acc)
for i in range(4):
    print(f'{i}: {acc_groups[i]}')

"""

In [None]:
#global_pruning_mask

In [None]:
#global_pruning_mask2

In [None]:
#for v in components_relevances.values():
#    print(v[0:6])

# Show the leading seven entries (along the first axis) of each layer's
# relevance values.
for layer_relevance in components_relevances.values():
    head = layer_relevance[:7]
    print(head)

In [None]:
# CSV-style summary of the most recent evaluation. acc and acc_groups hold the
# results for the last rate `r` processed by the sweep loop, so report that
# single rate rather than the entire prune_r list.
print(f'{r},{acc:.3f},{acc_groups[0]:.3f},{acc_groups[1]:.3f},{acc_groups[2]:.3f},{acc_groups[3]:.3f}')

In [None]:
# LaTeX table row for the most recent evaluation: r, acc and acc_groups are
# whatever the sweep loop assigned in its last completed iteration.
print(f'{r} & {acc:.3f} & {acc_groups[0]:.3f} & {acc_groups[1]:.3f} & {acc_groups[2]:.3f} & {acc_groups[3]:.3f} \\\\')

In [None]:
# Bare expression: echoes the current value of layer_type as the cell output.
layer_type

In [None]:
#visualise(global_pruning_mask_combined, r, layer_type)

In [None]:
#torch.save(model.state_dict(), "checkpoints/isic_30_0123_r0.96.pth")