In [None]:
!rm -rf ./ML4CV_XAI
!git clone https://github.com/liuktc/ML4CV_XAI.git
!pip install captum grad_cam Craft-xai torcheval
%load_ext autoreload

%autoreload 2

import sys
sys.path.append('/kaggle/working/ML4CV_XAI')

In [1]:
# Enable IPython's autoreload extension so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload

# Mode 2: reload all modules before executing each cell.
%autoreload 2

import numpy as np
import torch
from torch.utils.data import DataLoader

from utils import _DeepLiftShap, _GradCAMPlusPlus, SimpleUpsampling,ERFUpsamplingFast, min_max_normalize
from data import PascalVOC2007
from results.results_metrics import ResultMetrics
from models import vgg11_PascalVOC, vgg_preprocess

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed the PyTorch and NumPy global generators with the same fixed value so
# that random draws made later in the notebook are repeatable.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)




In [2]:
# Instantiate the network and move its parameters to the selected device.
model = vgg11_PascalVOC()
model.to(device)
# Load the pretrained weights. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state dict needs, so a tampered
# checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load('VGG11_PascalVOC.pt', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference.
model.eval()

# Preprocessing transform handed to the datasets created later.
preprocess = vgg_preprocess

In [3]:
# Build the "test" and "trainval" splits (in that order), both using the same
# preprocessing transform.
test_data, train_data = (
    PascalVOC2007(split_name, transform=preprocess)
    for split_name in ("test", "trainval")
)

Using downloaded and verified file: data\VOCtest_06-Nov-2007.tar
Extracting data\VOCtest_06-Nov-2007.tar to data
Using downloaded and verified file: data\VOCtrainval_06-Nov-2007.tar
Extracting data\VOCtrainval_06-Nov-2007.tar to data


In [4]:
import torch.utils
from torch.utils.data import Subset

# Evaluation subset: NUM_TEST randomly chosen test images, one per batch.
BATCH_SIZE_TEST = 1
NUM_TEST = 128

# Baseline subset: NUM_TRAIN randomly chosen trainval images.
BATCH_SIZE_TRAIN = 2
NUM_TRAIN = 16

# The test permutation is drawn before the train permutation; keep this order,
# since both draws come from the same global PyTorch generator.
test_indices = torch.randperm(len(test_data))[:NUM_TEST]
test_subset = Subset(test_data, test_indices)
dl_test = DataLoader(test_subset, batch_size=BATCH_SIZE_TEST, shuffle=False)

train_indices = torch.randperm(len(train_data))[:NUM_TRAIN]
train_subset = Subset(train_data, train_indices)
dl_train = DataLoader(train_subset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

In [5]:
# Collect the images of every training batch and stack them along the batch
# dimension into a single tensor on the selected device.
train_image_batches = []
for batch_images, _batch_labels in dl_train:
    train_image_batches.append(batch_images)
baseline_dist_16 = torch.cat(train_image_batches, dim=0).to(device)

# Smaller baseline sets: independent copies of the first 8 and first 4 images.
baseline_dist_8 = baseline_dist_16[:8].clone()
baseline_dist_4 = baseline_dist_16[:4].clone()

In [6]:
# Indices into model.features of the layers to attribute, in descending order.
LAYERS = [
    20,
    15,
    10,
    5,
]

In [None]:
from metrics import RoadCombined
from utils import MultiplierMix
from tqdm.auto import tqdm

# Evaluation loop: for every test image, attribution method and upsampling
# method, compute a saliency map per layer, then score both the single-layer
# map and the mix of all maps gathered so far, and record the scores.

# Number of test images to process before stopping. The value 1 reproduces the
# previous behaviour (the loop used to break unconditionally after the first
# image); set to None to evaluate every image in dl_test.
MAX_IMAGES = 1

# A saliency map whose max-min range is below this is treated as constant.
CONSTANT_MAP_EPS = 1e-6

results = ResultMetrics("./results_mixed.csv")

# Visit layers in descending index order. Sorting a copy once here avoids
# mutating the shared LAYERS list on every image.
layers_desc = sorted(LAYERS, reverse=True)

for index, (images, labels) in enumerate(tqdm(dl_test)):
    images = images.to(device)
    labels = labels.to(device).reshape(-1)

    # _GradCAMPlusPlus is instantiated like _DeepLiftShap: the bare class was
    # passed before, and an instance method `.attribute(...)` cannot be called
    # on a class without `self`.
    # NOTE(review): assumes a no-argument constructor - confirm against utils.
    for attribution_method in [
        _DeepLiftShap(baseline_dist_8, name="DeepLiftShap8"),
        _GradCAMPlusPlus(),
    ]:
        for upsample_method in [ERFUpsamplingFast(), SimpleUpsampling((224, 224))]:
            # Normalized maps of the layers processed so far for this
            # (attribution method, upsampling method) pair.
            attributions_per_layer = []
            for layer_index in layers_desc:
                layer = model.features[layer_index]
                attribution_map = attribution_method.attribute(
                    input_tensor=images,
                    model=model,
                    layer=layer,
                    target=labels,
                )

                attribution_map = upsample_method(
                    attribution=attribution_map,
                    image=images,
                    device=device,
                    model=model,
                    layer=layer,
                )

                # Min-max normalization is degenerate for a constant map, so
                # such a layer is skipped and does not contribute to the mix.
                value_range = (
                    attribution_map.amax(dim=(2, 3), keepdim=True)
                    - attribution_map.amin(dim=(2, 3), keepdim=True)
                )
                if (torch.abs(value_range) < CONSTANT_MAP_EPS).any():
                    print(f"A saliency map is constant, skipping layer features.{layer_index}")
                    # Only the local map is released: images and labels are
                    # still needed by the remaining layers and methods.
                    del attribution_map
                    torch.cuda.empty_cache()
                    continue

                attribution_map = min_max_normalize(attribution_map)
                attributions_per_layer.append(attribution_map)

                # Mix the attribution maps gathered so far
                mix = MultiplierMix(layers_to_combine="all")
                mixed_attribution_map = mix(attributions_per_layer)

                for metric in [RoadCombined()]:
                    # Score the single-layer map first, then the mixed map,
                    # recording each result right after it is computed.
                    for saliency_maps, mixing_name in (
                        (attribution_map, "None"),
                        (mixed_attribution_map, mix.name),
                    ):
                        metric_value = metric(
                            model=model,
                            test_images=images,
                            saliency_maps=saliency_maps,
                            class_idx=labels,
                            attribution_method=attribution_method,
                            device=device,
                            layer=layer,
                        )

                        results.add_result(
                            model="VGG11",
                            attribution_method=attribution_method.name,
                            dataset="PascalVOC",
                            layer=f"features.{layer_index}",
                            metric=metric.name,
                            upscale_method=upsample_method.name,
                            mixing_method=mixing_name,
                            value=metric_value,
                            image_index=index,
                        )

    # Stop once the requested number of images has been processed.
    if MAX_IMAGES is not None and index + 1 >= MAX_IMAGES:
        break

Results file not found. Creating new results file ./results_mixed.csv.


  0%|          | 0/128 [00:00<?, ?it/s]

               activations. The hooks and attributes will be removed
            after the attribution is finished
  self.results = pd.concat(


KeyboardInterrupt: 