In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import sys
import pickle
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL.Image
from sklearn.metrics import accuracy_score
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import timm
from statistics import mean
import random
import gc
from scipy.signal import savgol_filter

from tqdm import tqdm
# from shared_interest.datasets.imagenet import ImageNet
from shared_interest.shared_interest import shared_interest
from shared_interest.util import flatten, normalize_0to1, binarize_std, binarize_percentile
from interpretability_methods.vanilla_gradients import VanillaGradients

import xml.etree.ElementTree as ET
from torchvision.datasets import ImageFolder

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
# retrieved from Shared Interest Repository (edited for debugging purposes) 

"""Dataset for ImageNet with annotations."""

import os
import xml.etree.ElementTree as ET
import torch
from torchvision.datasets import ImageFolder


class ImageNet(ImageFolder):
    """Extends ImageFolder dataset to include ground truth annotations."""

    def __init__(self, image_path, ground_truth_path, image_transform=None,
                 ground_truth_transform=None):
        """
        Extends the parent class with annotation information.

        Additional Args:
        image_path: the path to the ImageNet images. This folder must be
            formatted in ImageFolder style (i.e. label/imagename.jpeg)
        ground_truth_path: the path to the ImageNet annotations. This folder
            must be formatted in ImageFolder style (i.e., label/imagename.xml).
        image_transform: a pytorch transform to apply to the images or None.
            Defaults to None.
        ground_truth_transform: a pytorch transform to apply to the ground
            truth annotations or None. Defaults to None.

        """
        super().__init__(image_path, transform=image_transform)
        self.ground_truth_transform = ground_truth_transform
        self.ground_truth_path = ground_truth_path

    def __getitem__(self, index):
        """Return one sample together with its bounding-box mask.

        Returns a 6-tuple:
            image_path: the image path exactly as stored in ``self.imgs``.
            ground_truth_file: the matching annotation path,
                ``<ground_truth_path>/<label>/<image stem>.xml``.
            image: the image after ``image_transform`` (via ImageFolder).
            ground_truth: the box mask from ``_create_ground_truth``; when a
                ``ground_truth_transform`` is set, its output with the leading
                dimension squeezed.
            label: the name of the image's parent folder converted to int
                (ImageFolder's own class index is discarded).
            zero_ground: ``index`` if the mask has no non-zero pixel, else 0.
                NOTE(review): an empty mask at ``index == 0`` is reported as 0,
                i.e. indistinguishable from a non-empty mask. Callers needing
                an exact flag should test the mask itself.
        """
        image, _ = super().__getitem__(index)
        image_path, _ = self.imgs[index]

        # Use os.path rather than split('/') so the platform separator is
        # honoured and dots inside the file stem are kept (only the final
        # extension is removed).
        clean_path = image_path.strip()
        image_name = os.path.splitext(os.path.basename(clean_path))[0]
        label = os.path.basename(os.path.dirname(clean_path))

        ground_truth_file = os.path.join(self.ground_truth_path, label, '%s.xml' % image_name)
        ground_truth = self._create_ground_truth(ground_truth_file)
        if self.ground_truth_transform is not None:
            ground_truth = self.ground_truth_transform(ground_truth).squeeze(0)

        zero_ground = 0 if torch.any(ground_truth) else index
        return image_path, ground_truth_file, image, ground_truth, int(label), zero_ground

    def _create_ground_truth(self, ground_truth_file):
        """Create a binary ground truth mask from an ImageNet annotation file.

        The mask has the annotated (height, width) and is 1 inside every
        bounding box, 0 elsewhere. ``xmax``/``ymax`` are used as exclusive
        slice ends, so a box with ``xmin == xmax`` or ``ymin == ymax`` adds
        no pixels.
        """
        annotation = self._parse_xml(ground_truth_file)
        height, width = int(annotation['height']), int(annotation['width'])
        ground_truth = torch.zeros((height, width))
        for coordinate in annotation['coordinates']:
            y_min, y_max = int(coordinate['ymin']), int(coordinate['ymax'])
            x_min, x_max = int(coordinate['xmin']), int(coordinate['xmax'])
            ground_truth[y_min:y_max, x_min:x_max] = 1
        return ground_truth

    def _parse_xml(self, ground_truth_file):
        """Parse an ImageNet annotation XML file.

        Returns a dict with 'coordinates' (one dict of int xmin/ymin/xmax/ymax
        per <object>'s <bndbox>) and 'height'/'width' (the <size> values as
        strings). Raises IOError if the file does not exist.
        """
        if not os.path.isfile(ground_truth_file):
            raise IOError('No annotation data for %s.' %(ground_truth_file))
        tree = ET.parse(ground_truth_file)
        root = tree.getroot()
        bboxes = [obj.find('bndbox') for obj in root.findall('object')]
        coords = [{'xmin': int(bbox.find('xmin').text),
                   'ymin': int(bbox.find('ymin').text),
                   'xmax': int(bbox.find('xmax').text),
                   'ymax': int(bbox.find('ymax').text), } for bbox in bboxes]
        height = root.find('size').find('height').text
        width = root.find('size').find('width').text
        return {'coordinates': coords, 'height': height, 'width': width}
    

In [34]:
# Root of the annotated ImageNet validation split. Set the IMAGENET_DIR
# environment variable to point elsewhere instead of editing this cell.
# The folder must contain 'images/<label>/*.JPEG' and
# 'annotations/<label>/*.xml' (see the ImageNet dataset class).
imagenet_dir = os.environ.get('IMAGENET_DIR', '/nobackup/users/hbang/data/imagenet/val_old/')
image_dir = os.path.join(imagenet_dir, 'images')
annotation_dir = os.path.join(imagenet_dir, 'annotations')

In [35]:
# ImageNet normalisation statistics, shared by the forward and reverse transforms.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Images: shorter side resized to 256, centre 224x224 crop, tensor, normalise.
image_transform = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=IMAGENET_MEAN,
                                                           std=IMAGENET_STD),
                                     ])

# Ground-truth masks: same geometry as the images, but nearest-neighbour
# resampling so no interpolated in-between values appear in the mask.
# InterpolationMode replaces the deprecated PIL.Image.NEAREST argument.
ground_truth_transform = transforms.Compose([transforms.ToPILImage(),
                                             transforms.Resize(256, transforms.InterpolationMode.NEAREST),
                                             transforms.CenterCrop(224),
                                             transforms.ToTensor()])

# Undo the normalisation for display: first multiply by std (dividing by
# 1/std), then add the mean back. The reciprocals are computed exactly rather
# than hard-coded as rounded literals.
reverse_image_transform = transforms.Compose([transforms.Normalize(mean=[0, 0, 0],
                                                                   std=[1 / s for s in IMAGENET_STD]),
                                              transforms.Normalize(mean=[-m for m in IMAGENET_MEAN],
                                                                   std=[1, 1, 1]),
                                              transforms.ToPILImage(),])

In [36]:
# Annotated validation set. shuffle=False keeps dataset order, so position p
# of batch b corresponds to dataset index b * 64 + p.
dataset = ImageNet(image_path=image_dir,
                   ground_truth_path=annotation_dir,
                   image_transform=image_transform,
                   ground_truth_transform=ground_truth_transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

In [37]:
def parse_xml(ground_truth_file):
    """Read the bounding boxes and image size from an ImageNet annotation XML.

    Args:
        ground_truth_file: path to the annotation file.

    Returns:
        dict with
          'coordinates': one dict per <object>, holding the int values of the
              'xmin', 'ymin', 'xmax' and 'ymax' children of its <bndbox>;
          'height', 'width': the <size> entries, left as strings.

    Raises:
        IOError: if ``ground_truth_file`` is not an existing file.
    """
    if os.path.isfile(ground_truth_file):
        root = ET.parse(ground_truth_file).getroot()
    else:
        raise IOError('No annotation data for %s.' %(ground_truth_file))

    boxes = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        boxes.append({key: int(bndbox.find(key).text)
                      for key in ('xmin', 'ymin', 'xmax', 'ymax')})

    size = root.find('size')
    return {'coordinates': boxes,
            'height': size.find('height').text,
            'width': size.find('width').text}

In [None]:
# Set up the model used while checking what causes the all-zero ground_truth masks.

# Raise the print threshold so large tensors are printed in full while debugging.
torch.set_printoptions(threshold=900_000)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model is the last name timm reports for '*coatnet*'; which one that is
# depends on the installed timm version.
model = timm.create_model(timm.list_models('*coatnet*')[-1], pretrained=True).to(device)

saliency_method =  VanillaGradients(model)
model.eval()

# For every batch, check that the images whose ground-truth coverage (GTC) is
# NaN are exactly the images whose ground-truth mask is empty. For each of them,
# try to explain why the annotated boxes vanished after Resize + CenterCrop.

RESIZE_SIZE = 256    # must match transforms.Resize(...) in the transforms cell
CROP_SIZE = 224      # must match transforms.CenterCrop(...) in the transforms cell
EDGE_TOLERANCE = 2   # px; boxes overlapping the crop by less than this are reported


def _resized_size(width, height, size=RESIZE_SIZE):
    """Return the (width, height) produced by ``transforms.Resize(size)``.

    The shorter side becomes ``size``. The longer side is scaled by the same
    ratio and truncated to an int, mirroring torchvision's size computation.
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def _center_crop_span(length, crop=CROP_SIZE):
    """Return the half-open [start, end) range kept by ``CenterCrop(crop)`` on one axis."""
    start = int(round((length - crop) / 2.0))
    return start, start + crop


def _explain_empty_ground_truth(ground_truth_file, tol=EDGE_TOLERANCE):
    """Print why the boxes in ``ground_truth_file`` leave no pixels after resize + crop.

    Boxes that lie entirely outside the crop are the expected cause and are
    skipped silently. The following are printed for inspection:
      - degenerate boxes;
      - boxes that collapse when resized;
      - boxes that overlap the crop by less than ``tol`` pixels.
    A box that should be visible is reported and stops the check for this file.
    """
    annotation = parse_xml(ground_truth_file)
    width, height = int(annotation['width']), int(annotation['height'])
    new_width, new_height = _resized_size(width, height)
    scale_x, scale_y = new_width / width, new_height / height
    left, right = _center_crop_span(new_width)
    top, bottom = _center_crop_span(new_height)

    for coordinate in annotation['coordinates']:
        if coordinate['xmax'] <= coordinate['xmin'] or coordinate['ymax'] <= coordinate['ymin']:
            # The mask uses exclusive max slices, so this box is empty even
            # before any resizing.
            print("degenerate box:", ground_truth_file, coordinate)
            continue

        # Scale first, then round to the nearest resized pixel.
        xmin = int(np.round(coordinate['xmin'] * scale_x))
        xmax = int(np.round(coordinate['xmax'] * scale_x))
        ymin = int(np.round(coordinate['ymin'] * scale_y))
        ymax = int(np.round(coordinate['ymax'] * scale_y))

        if xmax <= xmin or ymax <= ymin:
            print("box collapses after resize:", ground_truth_file, coordinate, (xmin, xmax, ymin, ymax))
            continue

        fully_outside = xmax <= left or xmin >= right or ymax <= top or ymin >= bottom
        if fully_outside:
            continue

        near_edge = (xmax - left < tol or right - xmin < tol
                     or ymax - top < tol or bottom - ymin < tol)
        if near_edge:
            print(ground_truth_file)
            print(new_width, new_height)
            print(annotation)
            print(coordinate)
            print(xmin, xmax)
            print(ymin, ymax)
            print(xmax - left, xmin - right)
            print(ymax - top, ymin - bottom)
            print("----------------")
        else:
            print("not wrong?")
            print(ground_truth_file, coordinate, (xmin, xmax, ymin, ymax), (left, right, top, bottom))
            return


for image_path, ground_truth_file, images, ground_truth, labels, zero_ground in tqdm(
        dataloader, position=0, leave=True):
    images = images.to(device)  # needed by the saliency computation (see NOTE below)
    ground_truth = ground_truth.numpy()

    # In-batch positions whose mask is all zeros, computed from the masks
    # themselves. The dataset's `zero_ground` field cannot flag an empty mask
    # at dataset index 0.
    empty_positions = np.flatnonzero(
        ~ground_truth.reshape(len(ground_truth), -1).any(axis=1)).tolist()

    # NOTE(review): `saliency_masks` is not computed anywhere in this cell, so
    # on a fresh kernel this raises NameError (or silently reuses a stale value
    # from another cell). Compute it here from `images` with `saliency_method`,
    # outside torch.no_grad(), since vanilla gradients need autograd.
    # Only ground_truth_coverage is inspected, so only that score is computed.
    gtc_scores = shared_interest(ground_truth, saliency_masks, score='ground_truth_coverage')
    nan_positions = np.flatnonzero(np.isnan(gtc_scores)).tolist()
    if not nan_positions:
        continue

    # Presumably GTC divides by the ground-truth area, so NaN should coincide
    # with empty masks; this check verifies that assumption.
    if set(nan_positions) != set(empty_positions):
        print("not equal")
        print(nan_positions)
        print(empty_positions)
        print("-------------")

    for j in nan_positions:
        _explain_empty_ground_truth(ground_truth_file[j])

In [17]:
# Load the Shared Interest scores saved for the coatnet model and check the
# ground-truth-coverage (GTC) scores for NaNs.
# NOTE: pickle.load can execute arbitrary code; only load files you produced.
coatnet_name = timm.list_models('*coatnet*')[-1]
model_path = os.path.abspath("/home/hbang/data/vanilla_gradient/" + coatnet_name + "/")
scores_file = os.path.join(model_path, "shared_interest_scores.pickle")
with open(scores_file, 'rb') as scores_handle:
    shared_interest_scores = pickle.load(scores_handle)

gtc_coverage = shared_interest_scores['ground_truth_coverage']
print(gtc_coverage.size)
gtc_has_nan = np.isnan(gtc_coverage).any()
if gtc_has_nan:
    print("NAN")
print(gtc_coverage)

50000
NAN
[0.21652526 0.17744452 0.11775148 ... 0.23425196 0.17495775 0.13681567]


In [18]:
# Dataset indices whose ground-truth coverage is NaN, found with one vectorised
# call instead of growing an array element by element (np.append copies the
# whole array on every call, which is quadratic overall).
# Previously observed: 185 NaN GTC scores in the saved results.
nans = np.flatnonzero(np.isnan(gtc_coverage)).astype(int)
n = nans.size
print(nans)

[ 5463  7509 15345 16151 16194 20090 20265 20273 20276 20295 20297 20323
 20637 20800 20803 20804 20806 20808 20809 20814 20820 20821 20825 20835
 20839 20842 20845 20859 20894 20986 21360 21379 21451 21469 21473 21481
 21482 21494 21501 21504 21510 21518 21520 21521 21522 21535 21542 21548
 21549 21618 21673 21694 21860 21866 21875 21965 22754 22897 23156 23242
 24487 24971 25492 25720 25729 25753 26105 26114 26118 26136 26550 26559
 26812 26818 26829 27852 27891 27929 28488 29075 29250 29274 30107 30111
 30133 30135 30138 30141 30145 30148 30587 30765 31579 32114 32138 32557
 32765 33352 33364 33876 34010 34131 34695 34987 35125 35441 35758 35762
 35763 35764 35772 35773 35776 35777 35781 35782 35785 35786 35787 35789
 35793 35794 35796 35797 36119 36584 36698 37103 37180 37213 37314 37316
 37318 37323 37335 37336 37340 37342 37347 37348 37349 37804 38357 38432
 38441 38777 39263 39680 39697 40280 40939 40941 40944 41617 42043 42149
 42296 42617 42620 42622 42626 42631 43107 44504 44