## Setup

In [None]:
import os

# Project root: the parent directory of the notebook's working directory.
# os.path.join is used instead of string concatenation, which would produce
# "//.." (normalised to "//") when the working directory is the filesystem root.
base_dir = os.path.normpath(os.path.join(os.getcwd(), os.pardir))

In [None]:
import PIL.Image as Image
import os
import matplotlib.pyplot as plt
import numpy as np
import glob
import torch
import torchvision
from torchvision import transforms

from captum.concept import TCAV
from captum.concept import Concept
from captum.concept._utils.data_iterator import dataset_to_dataloader, CustomIterableDataset
from captum.concept._utils.common import concepts_to_str
from captum.attr import LayerGradientXActivation, LayerIntegratedGradients

In [None]:
# Folder that holds one sub-directory of images per concept
# (the concept names used below, e.g. "cat", "car", "random500_0").
concepts_path = base_dir + "/data/test-images/"

## Utilities

In [None]:
# Geometric preprocessing shared by all images (normalisation is applied
# separately via `transform_normalize`).
transform = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float tensor
    ]
)

In [None]:
# Per-channel (R, G, B) mean/std normalisation with the standard ImageNet
# statistics; applied on top of the tensor produced by `transform`.
transform_normalize = transforms.Normalize(
    [0.485, 0.456, 0.406],  # mean per channel
    [0.229, 0.224, 0.225],  # std per channel
)

In [None]:
# Per-file image loader used by assemble_concept().
def get_tensor_from_filename(filename):
    """Load one image file as a model-ready tensor for a concept data iterator.

    The image is converted to RGB, passed through ``transform`` (resize 256,
    centre-crop 224, ``ToTensor``) and then ``transform_normalize``. This is
    exactly the preprocessing the TCAV input images receive
    (``transform_normalize(transform(img))``), so concept activations and input
    activations come from the same input distribution.

    NOTE(review): activations/CAVs cached under the TCAV ``save_path`` by an
    earlier run were computed from un-normalised concept images; delete that
    cache (or force retraining) so the CAVs are regenerated.

    Args:
        filename: Path to an image file readable by PIL.

    Returns:
        The preprocessed image tensor.
    """
    # The context manager releases the file handle deterministically instead
    # of relying on PIL's lazy-loading behaviour.
    with Image.open(filename) as img:
        return transform_normalize(transform(img.convert("RGB")))

In [None]:
def assemble_concept(name, id, concepts_path=f"{base_dir}/Files/data/"):
    """Build a captum ``Concept`` from the images in ``<concepts_path>/<name>/``.

    Args:
        name: Concept name; also the sub-directory that holds its images.
        id: Integer concept id (appears in captum's saved CAV file names,
            e.g. ``0-1-inception4e.pkl``). Shadows the builtin ``id`` but is
            kept for caller compatibility.
        concepts_path: Parent directory containing one folder per concept.

    Returns:
        A ``Concept`` whose data iterator loads each file through
        ``get_tensor_from_filename``.
    """
    image_dir = os.path.join(concepts_path, name) + "/"
    return Concept(
        id=id,
        name=name,
        data_iter=dataset_to_dataloader(
            CustomIterableDataset(get_tensor_from_filename, image_dir)
        ),
    )

In [None]:
# Load sample images from folder
def load_image_tensors(class_name, root_path, transform=True):
    """Load every ``*.jpg`` image in ``<root_path>/<class_name>/``.

    Args:
        class_name: Sub-directory of ``root_path`` holding the images.
        root_path: Directory containing the per-class folders.
        transform: ``False`` returns the RGB PIL images unchanged. ``True``
            returns tensors preprocessed exactly as ``get_tensor_from_filename``
            does. A callable is applied to each RGB PIL image instead.

    Returns:
        A list with one entry per image, in sorted filename order, so that
        indexing (e.g. ``imgs[0]``) is deterministic across runs.
    """
    # NOTE: the ``transform`` parameter shadows the module-level ``transform``
    # pipeline. The previous ``transform(img)`` therefore tried to call the
    # bool ``True`` and raised TypeError. The default pipeline is now reached
    # through get_tensor_from_filename instead.
    pattern = os.path.join(glob.escape(os.path.join(root_path, class_name)), '*.jpg')
    filenames = sorted(glob.glob(pattern))

    images = []
    for filename in filenames:
        if transform and not callable(transform):
            images.append(get_tensor_from_filename(filename))
            continue
        with Image.open(filename) as img:
            rgb = img.convert('RGB')
        images.append(transform(rgb) if callable(transform) else rgb)

    return images

## Define concepts, read images

In [None]:
# Two named concepts (ids 0 and 1) plus two baseline concepts whose folder
# names indicate random image sets (ids 2 and 3).
cat_concept = assemble_concept(name="cat", id=0, concepts_path=concepts_path)
car_concept = assemble_concept(name="car", id=1, concepts_path=concepts_path)

random_0_concept = assemble_concept(name="random500_0", id=2, concepts_path=concepts_path)
random_1_concept = assemble_concept(name="random500_1", id=3, concepts_path=concepts_path)

In [None]:
# Grid preview: one row per concept. Column 0 holds the concept name;
# columns 1..n_figs show up to n_figs of its image files.
n_figs = 5
concepts_to_show = [cat_concept, car_concept, random_0_concept, random_1_concept]
n_concepts = len(concepts_to_show)

fig, axs = plt.subplots(n_concepts, n_figs + 1, figsize=(25, 4 * n_concepts), squeeze=False)

for row, concept in enumerate(concepts_to_show):
    # Hide every frame in the row, including cells left empty when a concept
    # has fewer than n_figs images.
    for ax in axs[row]:
        ax.axis('off')
    # Draw the label once per row (previously it was redrawn for every image).
    axs[row, 0].text(1.0, 0.5, str(concept.name), ha='right', va='center',
                     family='sans-serif', size=24)

    concept_dir = os.path.join(concepts_path, concept.name)
    # Keep only regular files before taking the first n_figs, so that
    # sub-directories do not use up image slots. Sort for a stable preview.
    img_files = sorted(
        path for path in glob.glob(os.path.join(glob.escape(concept_dir), '*'))
        if os.path.isfile(path)
    )
    for col, img_file in enumerate(img_files[:n_figs], start=1):
        axs[row, col].imshow(plt.imread(img_file))

plt.show()

In [None]:
# Zebra images as RGB PIL images (transform=False skips tensor conversion).
zebra_imgs = load_image_tensors('zebra', f"{base_dir}/Files/data/", transform=False)

In [None]:
# Show the first five zebra images side by side. Panels without an image
# stay blank instead of raising IndexError when fewer than five exist.
fig, axs = plt.subplots(1, 5, figsize=(25, 5))
for idx, ax in enumerate(axs):
    if idx < len(zebra_imgs):
        ax.imshow(zebra_imgs[idx])
    ax.axis('off')

plt.show()

In [None]:
# Preprocess each zebra PIL image (resize/crop/ToTensor, then mean/std
# normalisation) and stack the results into one batch tensor for TCAV.
zebra_tensors = torch.stack(
    [transform_normalize(transform(pil_img)) for pil_img in zebra_imgs]
)

In [None]:
# Output-class index for "zebra"; used as the TCAV attribution target.
zebra_ind = 340

## TCAV

In [None]:
# Pretrained torchvision GoogLeNet, switched to inference mode (eval()
# returns the module itself).
model = torchvision.models.googlenet(pretrained=True).eval()

In [None]:
# GoogLeNet blocks whose activations TCAV analyses.
layers = ["inception4c", "inception4d", "inception4e"]

In [None]:
# TCAV over the selected layers. Layer attributions use
# LayerIntegratedGradients (layer argument None, multiply_by_inputs=False);
# activations and CAVs are stored under save_path.
mytcav = TCAV(
    model=model,
    layers=layers,
    layer_attr_method=LayerIntegratedGradients(model, None, multiply_by_inputs=False),
    save_path=f"{base_dir}/Files/cav/",
)

In [None]:
# Concept pairs for which TCAV trains and scores a CAV:
# cat vs random500_0, car vs random500_1, and cat vs car.
experimental_set_rand = [
    [cat_concept, random_0_concept],
    [car_concept, random_1_concept],
    [cat_concept, car_concept],
]
experimental_set_rand

In [None]:
# Display the TCAV object in the notebook output.
mytcav

In [None]:
# TCAV scores of the zebra batch w.r.t. the zebra class, for every
# experimental set and layer. n_steps is presumably forwarded to
# LayerIntegratedGradients as its number of integration steps.
tcav_scores_w_random = mytcav.interpret(
    inputs=zebra_tensors,
    experimental_sets=experimental_set_rand,
    target=zebra_ind,
    n_steps=5,
)
tcav_scores_w_random

## Plots

In [None]:
def format_float(f):
    """Round ``f`` for display.

    Values with ``|f| >= 0.0005`` are rounded to 3 decimal places. Smaller
    values are rounded to 4 significant digits via scientific notation, so
    they do not collapse to 0.
    """
    spec = '{:.3f}' if abs(f) >= 0.0005 else '{:.3e}'
    return float(spec.format(f))

def plot_tcav_scores(experimental_sets, tcav_scores):
    """Draw one grouped bar chart of TCAV ``sign_count`` scores per experimental set.

    Each subplot has one bar group per layer and one bar per concept in the set.

    Args:
        experimental_sets: List of concept lists, as passed to ``TCAV.interpret``.
        tcav_scores: Result of ``TCAV.interpret``, indexed here as
            ``tcav_scores[concepts_to_str(concepts)][layer]['sign_count'][i]``.
    """
    if not experimental_sets:
        return  # nothing to plot; plt.subplots(1, 0) would raise

    # squeeze=False always returns a 2-D array, so a single set needs no special case.
    _, axes = plt.subplots(1, len(experimental_sets), figsize=(40, 7), squeeze=False)

    for idx_es, (concepts, _ax) in enumerate(zip(experimental_sets, axes[0])):
        scores_by_layer = tcav_scores[concepts_to_str(concepts)]
        # Take layer names from the scores themselves, not from a global,
        # so bar values and tick labels always correspond.
        layer_names = list(scores_by_layer.keys())
        group_pos = np.arange(len(layer_names))
        bar_width = 1 / (len(concepts) + 1)  # per set, in case sizes differ

        for i, concept in enumerate(concepts):
            vals = [format_float(scores['sign_count'][i]) for scores in scores_by_layer.values()]
            _ax.bar(group_pos + i * bar_width, vals, width=bar_width,
                    edgecolor='white', label=concept.name)

        _ax.set_xlabel('Set {}'.format(idx_es), fontweight='bold', fontsize=16)
        # Bars are centred at group_pos + i * bar_width, so the middle of a
        # group lies halfway between its first and last bar.
        _ax.set_xticks(group_pos + (len(concepts) - 1) * bar_width / 2)
        _ax.set_xticklabels(layer_names, fontsize=16)

        _ax.legend(fontsize=16)

    plt.show()

In [None]:
# Bar charts of the sign-count TCAV scores for every experimental set.
plot_tcav_scores(experimental_sets=experimental_set_rand, tcav_scores=tcav_scores_w_random)

## Activations

In [None]:
# Activation file saved under the TCAV save_path for the "car" concept (id 1)
# at layer inception4e (file 0.pt) — presumably written by mytcav above.
activations_path = f'{base_dir}/Files/cav/av/default_model_id/car-1/inception4e/0.pt'
# Bug fix: the path used to be bound to `activation` while the undefined name
# `activations` was loaded (NameError). The next cell also reads `activations`.
# torch.load unpickles the file, so only load files this notebook produced.
activations = torch.load(activations_path)
activations
# Observed activations[0].shape per layer:
# 4c torch.Size([100352]) 1*(128+256+64+64)*14*14
# 4e torch.Size([163072]) 1*(256+320+128+128)*14*14

In [None]:
# Shape of the first entry of the loaded activations.
activations[0].shape

In [None]:
# Saved CAVs for layer inception4e. File names are "<id>-<id>-<layer>.pkl":
# 0-1 = cat/car, 0-2 = cat/random500_0, 1-3 = car/random500_1.
# torch.load unpickles, so only load files this notebook produced.
cav_dir = f'{base_dir}/Files/cav/default_model_id'
cav_0_1 = torch.load(f'{cav_dir}/0-1-inception4e.pkl')
cav_0_2 = torch.load(f'{cav_dir}/0-2-inception4e.pkl')
cav_1_3 = torch.load(f'{cav_dir}/1-3-inception4e.pkl')

In [None]:
# Display the loaded cat (id 0) vs car (id 1) CAV for inception4e.
cav_0_1

In [None]:
# Display the loaded cat (id 0) vs random500_0 (id 2) CAV for inception4e.
cav_0_2

In [None]:
# Display the loaded car (id 1) vs random500_1 (id 3) CAV for inception4e.
cav_1_3