In [3]:
import yaml
import click
import torch
import tqdm.auto
import numpy as np
 
import sys
import os
import math
# Location of the `pruning_by_explaining` checkout. Set the PXP_PROJECT_ROOT
# environment variable to point at a checkout on another machine; when it is
# unset (or empty) the original hard-coded location is used, so existing
# setups keep working unchanged.
project_root = os.environ.get("PXP_PROJECT_ROOT") or (
    "C:/Users/elmop/deep_feature_reweighting/deep_feature_reweighting/external/pruning_by_explaining"
)
# Prepend the checkout directory and then its parent, so the parent ends up
# first on sys.path, ahead of the `pruning_by_explaining` imports that follow.
sys.path.insert(0, project_root)
sys.path.insert(0, os.path.dirname(project_root))

from pruning_by_explaining.models import ModelLoader
from pruning_by_explaining.metrics import compute_accuracy
from pruning_by_explaining.my_metrics import compute_worst_accuracy
from pruning_by_explaining.my_datasets import WaterBirds, get_sample_indices_for_group, WaterBirdSubset, ISIC, ISICSubset


from pruning_by_explaining.pxp import (
    ModelLayerUtils,
    get_cnn_composite,
    get_vit_composite,
)

from pruning_by_explaining.pxp import GlobalPruningOperations
from pruning_by_explaining.pxp import ComponentAttibution

In [2]:
# Build the Waterbirds validation split and the two loaders used below:
# one over a small pruning subset, one over the evaluation subset.
num_workers = 2
waterbirds = WaterBirds("F:/model_stuff/waterbirds/waterbird", seed=1, num_workers=num_workers)
val_set = waterbirds.get_valid_set()

# Pruning indices are restricted to group 3. The second argument (10) is
# presumably the number of samples to pick -- TODO confirm against
# get_sample_indices_for_group.
pruning_indices = get_sample_indices_for_group(val_set, 10, "cpu", [3])
# 'all' with no explicit group list: the indices used for evaluation.
validation_indices = get_sample_indices_for_group(val_set, 'all', "cpu")

custom_pruning_set = WaterBirdSubset(val_set, pruning_indices)
custom_val_set = WaterBirdSubset(val_set, validation_indices)

# The pruning loader keeps its original settings (single-sample batches, shuffled).
prune_dataloader = torch.utils.data.DataLoader(
    custom_pruning_set, batch_size=1, shuffle=True, num_workers=num_workers
)
# The evaluation loader is only consumed by metric computations in this
# notebook, so shuffling adds nothing; a fixed order makes repeated
# evaluations visit the samples identically from run to run.
val_dataloader = torch.utils.data.DataLoader(
    custom_val_set, batch_size=8, shuffle=False, num_workers=num_workers
)

[0 1 2 3]
Number of unique labels: 2, Number of unique places: 2, Total groups: 4
group 0: 456
group 1: 456
group 2: 143
group 3: 144
target groups: [3]
target groups: [0, 1, 2, 3]


In [4]:
# Every rule slot of the composite is assigned the same rule name, so the
# mapping is built from the slot names with a shared value.
suggested_composite = dict.fromkeys(
    (
        "low_level_hidden_layer_rule",
        "mid_level_hidden_layer_rule",
        "high_level_hidden_layer_rule",
        "fully_connected_layers_rule",
        "softmax_rule",
    ),
    "Epsilon",
)
composite = get_vit_composite("vit_b_16", suggested_composite)

# Load the two-class vit_b_16 checkpoint onto the CPU.
model = ModelLoader.get_basic_model(
    "vit_b_16",
    "F:/model_stuff/vit_waterbirds.pth",
    torch.device("cpu"),
    num_classes=2,
)


Arch:vit_b_16


  loaded_checkpoint = torch.load(checkpoint_path, map_location=device)


In [4]:
# Lookup table from a layer-type name to the torch.nn class of the same name.
layer_types = {
    type_name: getattr(torch.nn, type_name)
    for type_name in ("Softmax", "Linear", "Conv2d")
}

In [7]:
# Evaluate the loaded model over val_dataloader on the CPU and print the
# value returned by compute_accuracy.
device = torch.device("cpu")
acc_top1 = compute_accuracy(model, val_dataloader, device)
print(acc_top1)

evaluating acc:   0%|          | 0/150 [00:13<?, ?it/s]

0.8824020016680567


In [6]:
# Echo how many indices were collected for the evaluation subset.
len(validation_indices)

1199

In [8]:
# Put the model in evaluation mode. The call is the cell's last expression,
# so its return value (the module itself) is echoed as the cell output.
model.eval()

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [9]:
# Re-run the accuracy measurement over val_dataloader; `device` is the one
# defined in the earlier accuracy cell.
acc_top1 = compute_accuracy(model, val_dataloader, device)
print(acc_top1)

evaluating acc:   0%|          | 0/150 [00:17<?, ?it/s]

0.8824020016680567


In [10]:
# compute_worst_accuracy returns two values: a single score and a per-group
# mapping. worst_acc is presumably the lowest of the group_acc entries --
# TODO confirm against compute_worst_accuracy.
worst_acc, group_acc = compute_worst_accuracy(model, val_dataloader, device)
# One value per line, same as two consecutive print calls.
print(worst_acc, group_acc, sep="\n")

0.7272727272727273
{2: 0.7272727272727273, 0: 0.9868421052631579, 1: 0.7960526315789473, 3: 0.9791666666666666}


In [5]:
# Echo the model so the cell output shows its module structure.
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a