1. Figure out which examples googlenet gets right, and which ones it gets wrong from the test set
2. Use TCAV to generate explanations for the wrong ones
3. Analyze the explanations

In [1]:
# ..........torch imports............
from pathlib import Path

import numpy as np
import torch
import torchvision
import os
import glob

from PIL import Image
from torchvision.datasets import ImageFolder
# .... Captum imports..................
from captum.attr import LayerIntegratedGradients
from captum.concept import TCAV
from captum.concept._utils.common import concepts_to_str

# .... Local imports..................
from joblib import load, dump

from HierarchicalExplanation import HierarchicalExplanation
from generate_data.hierarchy import Hierarchy
from utils import assemble_all_concepts_from_hierarchy, assemble_random_concepts, generate_experiments
from utils import load_image_tensors, transform, plot_tcav_scores, assemble_scores, get_pval, show_boxplots

# Load Hierarchy
HIERARCHY_JSON_PATH = 'generate_data/hierarchy.json'
HIERARCHY_WORDNET_PATH = 'generate_data/wordnet_labels.txt'
IMAGENET_IDX_TO_LABELS = 'generate_data/imagenet1000_clsidx_to_labels.txt'
h = Hierarchy(json_path=HIERARCHY_JSON_PATH, wordnet_labels_path=HIERARCHY_WORDNET_PATH,
              imagenet_idx_labels_path=IMAGENET_IDX_TO_LABELS)

###################################################
# Assemble Concepts
# Let's assemble concepts into Concept instances using Concept class and concept images stored in `concepts_path`.
###################################################

# concepts_path = "/home/devvrit/ishann/data/captum/tcav/concepts"
concepts_path = "../data"

# Assemble non-random concepts
concepts = assemble_all_concepts_from_hierarchy(h=h, num_images=100, concepts_path=concepts_path,
                                                recreate_if_exists=True)  # Only 100 images for testing, can increase later

# Assemble all random concepts
random_concepts = assemble_random_concepts(concepts_path=concepts_path)


# Defining GoogleNet Model
model = torchvision.models.googlenet(pretrained=True)
model = model.eval()
layers=['fc']

Assembling concepts...


Creating concepts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:05<00:00,  4.52it/s]
Creating Random concepts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 102.81it/s]


In [2]:
# 1. Figure out which examples googlenet gets right, and which ones it gets wrong from the test set
import json
import tqdm

data_path = '../data'
IMAGENET_IDX_TO_LABELS = 'generate_data/imagenet1000_clsidx_to_labels.txt'
with open(IMAGENET_IDX_TO_LABELS, 'r') as f:
    idx2label = json.load(fp=f)

class_name = 'volcano' #'Siberian husky'
class_idx = h.imagenet_label2idx[class_name]
wrong_count = 1000

path = os.path.join(data_path, class_name)
filenames = glob.glob(path + '/*.JPEG')
correct_files = []
wrong_files = []
for i, filename in tqdm.notebook.tqdm(enumerate(filenames)):
    if len(wrong_files) >= wrong_count:
        break

    img = transform(Image.open(filename).convert('RGB'))
    img_batch = torch.unsqueeze(img, 0)
    pred = model(img_batch)
    if torch.argmax(pred) != class_idx:
        wrong_files.append(filename)
    else:
        correct_files.append(filename)

0it [00:00, ?it/s]

In [3]:
len(wrong_files)

150

In [4]:
# Load the files into tensors
def load_image_tensors_from_list(list_images, do_transform=True, count=100):

    tensors = []
    for i, filename in enumerate(list_images):
        if i >= count:
            break
        print(filename)
        img = Image.open(filename).convert('RGB')
        tensors.append(transform(img) if do_transform else img)
        
    return torch.stack(tensors)

class_images_correct = load_image_tensors_from_list(correct_files, count=200)
print("-------------")
class_images_wrong = load_image_tensors_from_list(wrong_files, count=200)

../data\volcano\n09472597_10018.JPEG
../data\volcano\n09472597_10029.JPEG
../data\volcano\n09472597_10031.JPEG
../data\volcano\n09472597_10043.JPEG
../data\volcano\n09472597_10084.JPEG
../data\volcano\n09472597_10162.JPEG
../data\volcano\n09472597_10230.JPEG
../data\volcano\n09472597_10231.JPEG
../data\volcano\n09472597_10288.JPEG
../data\volcano\n09472597_10298.JPEG
../data\volcano\n09472597_10361.JPEG
../data\volcano\n09472597_10388.JPEG
../data\volcano\n09472597_10394.JPEG
../data\volcano\n09472597_10417.JPEG
../data\volcano\n09472597_10470.JPEG
../data\volcano\n09472597_10474.JPEG
../data\volcano\n09472597_10505.JPEG
../data\volcano\n09472597_10532.JPEG
../data\volcano\n09472597_10539.JPEG
../data\volcano\n09472597_10555.JPEG
../data\volcano\n09472597_1057.JPEG
../data\volcano\n09472597_10593.JPEG
../data\volcano\n09472597_10600.JPEG
../data\volcano\n09472597_10612.JPEG
../data\volcano\n09472597_1062.JPEG
../data\volcano\n09472597_10621.JPEG
../data\volcano\n09472597_10649.JPEG
../

In [5]:
# Explain the right and wrong predictions
he = HierarchicalExplanation(h=h, model=model, layer='fc', n_steps=5, load_save=False, latex_output=True)

# Right
explanations = he.explain(input_tensors=class_images_correct, input_class_name=class_name, input_idx=class_idx, get_concepts_from_name=lambda x: concepts[x] if x in concepts else random_concepts[int(x.replace("random_", ""))])
print(explanations)
long_form = he.long_form_explanations(explanations, class_name)
print(long_form)



[{'level_name': 'entity', 'children': [('equipment', tensor(-0.1484)), ('geological formation', tensor(0.7627)), ('organism', tensor(-0.2419)), ('random_0', tensor(-0.5496))], 'pval': 0.0008888608871426576}, {'level_name': 'geological formation', 'children': [('valley', tensor(-0.0473)), ('volcano', tensor(0.8725)), ('random_1', tensor(-0.4181))], 'pval': 0.02700740482108153}]
volcano  &  entity $\rightarrow$ geological formation (0.00089) $\rightarrow$ volcano (0.02701) \\
The input is predicted to be a(n) volcano (p-value: 0.0270).
It is a(n) volcano because out of all geological formations, volcano has the highest score among sub-classes:  (p-value: 0.02701)

                \begin{table}[H]
                \begin{tabular}{l|r}
                \toprule
                \textbf{Concept Name} & \textbf{CAV Score}\\
                \midrule
                valley & -0.04731\\ 
volcano & 0.87246\\ 
random_1 & -0.41811\\ 

                \bottomrule 
                \end{tabular}
       

In [6]:
# Wrong
he = HierarchicalExplanation(h=h, model=model, layer='fc', n_steps=5, load_save=False, latex_output=True)

explanations = he.explain(input_tensors=class_images_wrong, input_class_name=class_name, input_idx=class_idx, get_concepts_from_name=lambda x: concepts[x] if x in concepts else random_concepts[int(x.replace("random_", ""))])
print(explanations)
long_form = he.long_form_explanations(explanations, class_name)
print(long_form)

[{'level_name': 'entity', 'children': [('equipment', tensor(-0.1484)), ('geological formation', tensor(0.7627)), ('organism', tensor(-0.2419)), ('random_0', tensor(-0.5496))], 'pval': 0.0008888608301931007}, {'level_name': 'geological formation', 'children': [('valley', tensor(-0.0473)), ('volcano', tensor(0.8725)), ('random_1', tensor(-0.4181))], 'pval': 0.027007407617277324}]
volcano  &  entity $\rightarrow$ geological formation (0.00089) $\rightarrow$ volcano (0.02701) \\
The input is predicted to be a(n) volcano (p-value: 0.0270).
It is a(n) volcano because out of all geological formations, volcano has the highest score among sub-classes:  (p-value: 0.02701)

                \begin{table}[H]
                \begin{tabular}{l|r}
                \toprule
                \textbf{Concept Name} & \textbf{CAV Score}\\
                \midrule
                valley & -0.04731\\ 
volcano & 0.87246\\ 
random_1 & -0.41811\\ 

                \bottomrule 
                \end{tabular}
      

In [9]:
h.get_leaf_nodes()

['tiger',
 'tabby',
 'Siberian husky',
 'dalmatian',
 'golden retriever',
 'white wolf',
 'volcano',
 'valley',
 'pay-phone',
 'computer keyboard',
 'oscilloscope',
 'jigsaw puzzle',
 'crossword puzzle',
 'soccer ball',
 'basketball',
 'golf ball']

In [10]:
# Explain images from another class entirely
he = HierarchicalExplanation(h=h, model=model, layer='fc', n_steps=5, load_save=False, latex_output=True)
incorrect_class_name = 'computer keyboard'
incorrect_idx = h.imagenet_label2idx[class_name]

class_name = 'volcano'
class_images = load_image_tensors(class_name, root_path=concepts_path, transform=False, count=200)
class_tensors = torch.stack([transform(img) for img in class_images])
class_idx = h.imagenet_label2idx[class_name]


explanations = he.explain(input_tensors=class_tensors, input_class_name=incorrect_class_name, input_idx=incorrect_idx, get_concepts_from_name=lambda x: concepts[x] if x in concepts else random_concepts[int(x.replace("random_", ""))])
print(explanations)
long_form = he.long_form_explanations(explanations, incorrect_class_name)
print(long_form)



[{'level_name': 'entity', 'children': [('equipment', tensor(-0.1484)), ('geological formation', tensor(0.7627)), ('organism', tensor(-0.2419)), ('random_0', tensor(-0.5496))], 'pval': 0.7437523459559103}, {'level_name': 'equipment', 'children': [('ball', tensor(-0.1442)), ('puzzle', tensor(-0.1303)), ('electronics', tensor(-0.0219)), ('random_1', tensor(0.3427))], 'pval': 0.7939860789540382}, {'level_name': 'electronics', 'children': [('oscilloscope', tensor(-0.0126)), ('computer keyboard', tensor(-0.0726)), ('pay-phone', tensor(0.0200)), ('random_2', tensor(0.0372))], 'pval': 0.0039249725761268586}]
computer keyboard  &  entity $\rightarrow$ equipment (0.74375) $\rightarrow$ electronics (0.79399) $\rightarrow$ computer keyboard (0.00392) \\
The input is predicted to be a(n) computer keyboard (p-value: 0.0039).
It is a(n) computer keyboard because out of all electronicss, computer keyboard has the highest score among sub-classes:  (p-value: 0.00392)

                \begin{table}[H]
  

The output you provided is from a trained GoogLeNet model that has been given images of a valley (the landform) and asked to explain how it is a Siberian husky. The model uses a hierarchical classification system, in which it first assigns scores to different classes at the highest level of the hierarchy, then repeats the process for sub-classes within those classes, and so on until it reaches the final prediction.

In this case, the model predicts that the input images are of a Siberian husky with a low degree of confidence (p-value: 0.9409). The model arrives at this prediction by first classifying the input as an organism, then a canine, then a dog, and finally as a Siberian husky. For each classification, the model assigns scores to different sub-classes within that category, and the sub-class with the highest score is chosen as the final prediction.

For example, within the category of canines, the model assigns a score of 0.0435 to the sub-class of canine and a score of -0.0337 to the sub-class of feline. Since the score for canine is higher, it is chosen as the final prediction for that category. This process is repeated for each classification until the final prediction of Siberian husky is reached.

However, the high p-values for each classification indicate that the model is not very confident in its predictions. A p-value of 0.9409 for the final prediction of Siberian husky is quite high, which means that there is a high probability that the observed result (the prediction of a Siberian husky) occurred by chance. This indicates that the model is not very confident in its prediction, and the result should be interpreted with caution.

Overall, this output shows how the trained GoogLeNet model uses a hierarchical classification system to arrive at a prediction, but the high p-values suggest that the model is not very confident in its final prediction.

In [None]:
class_images_correct.shape

In [None]:
class_images_wrong.shape