# Testing with Concept Activation Vectors (TCAV)

Introduced by [Kim et al. (2018)](https://arxiv.org/pdf/1711.11279.pdf)

In [None]:
from torchvision.models import resnet50

# ImageNet-pretrained ResNet-50 used as the network under analysis.
# A freshly built module is in training mode, where BatchNorm normalises with
# per-batch statistics and updates its running averages on every forward pass.
# Switch to evaluation mode so the activations read from it are deterministic
# and do not depend on which images share a batch. `eval()` returns the module.
model = resnet50(pretrained=True).eval()

In [None]:
# concept: COCO category name whose images form the concept examples.
# k_class: COCO category name whose co-occurrence with the concept is
#          filtered out of the concept examples.
# Alternative pairing to try: concept = "cup", k_class = "bottle".
concept, k_class = "chair", "dining table"

## Get images for classifier

We will use the [pycocotools](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoDemo.ipynb) API to retrieve some examples of chairs:

In [None]:
from pycocotools.coco import COCO

# Annotation file for the COCO val2017 instances split, read from the local
# datasets folder; the COCO object indexes it for the lookups below.
annotation_file = "./datasets/instances_val2017.json"
coco = COCO(annotation_file)

Let's get the IDs of the images in the COCO dataset that contain the concept (here, _chair_):

In [None]:
# Resolve the concept's category name to its COCO category id(s), then collect
# the ids of the images annotated with that category.
concept_cat_id = coco.getCatIds(catNms=[concept])
concept_imgs_ids = coco.getImgIds(catIds=concept_cat_id)
# Notebook output: number of images found for the concept.
len(concept_imgs_ids)

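Let's remove the images where chairs are accompanied by a dining table, so that the concept examples do not also contain the class under study. First we find the images that contain both: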

In [None]:
# Category ids for both the concept and the class under study.
overlap_cats_ids = coco.getCatIds(catNms=[concept, k_class])
# Images matching the given category ids together.
# NOTE(review): pycocotools is expected to return only images that contain
# *all* of the categories passed in (an intersection) - confirm against the
# installed version.
overlap_imgs_ids = coco.getImgIds(catIds=overlap_cats_ids)
# Notebook output: number of overlapping images.
len(overlap_imgs_ids)

In [None]:
# Drop from the concept examples every image that also appears in the overlap
# list. The overlap ids are put in a set once so each membership test is O(1)
# instead of a linear scan of the list per image; order is preserved.
overlap_imgs_ids_set = set(overlap_imgs_ids)
concept_imgs_ids = [
    img_id for img_id in concept_imgs_ids if img_id not in overlap_imgs_ids_set
]
# Notebook output: number of concept images left after filtering.
len(concept_imgs_ids)

Let's grab random images for the other class of the classifier:

In [None]:
import random

# Every image id in the dataset is a candidate for the "not concept" class.
all_imgs_ids = list(coco.imgs.keys())

# Exclude *every* image annotated with the concept, not just the ones still in
# concept_imgs_ids: that list has already had the concept + k_class images
# removed, so filtering against it alone would let images that do contain the
# concept slip into the negative class.
imgs_with_concept_ids = set(
    coco.getImgIds(catIds=coco.getCatIds(catNms=[concept]))
)
not_concept_imgs_ids = [
    img_id for img_id in all_imgs_ids if img_id not in imgs_with_concept_ids
]

# Fixed seed so the sampled negatives are reproducible; keep as many negatives
# as there are concept images to get a balanced two-class set.
random.Random(0).shuffle(not_concept_imgs_ids)
not_concept_imgs_ids = not_concept_imgs_ids[:len(concept_imgs_ids)]

# Sanity check: the two classes must not share any image.
assert set(concept_imgs_ids).isdisjoint(not_concept_imgs_ids)

## Extract activations from images

In [None]:
from torchvision import transforms

# Per-channel mean and standard deviation used to standardise the input tensor.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

# Pipeline applied to every PIL image, in order: resize to 256, crop the
# central 224x224 region, convert to a tensor, then standardise the channels.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
preprocessing = transforms.Compose(preprocessing_steps)

In [None]:
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt

# Preview one concept image and run it through the preprocessing pipeline.
img_id = concept_imgs_ids[18]
img_url =  coco.loadImgs(img_id)[0]["coco_url"]
img = io.imread(img_url)  # downloads the image as a numpy array
# Force three channels before preprocessing: a grayscale image would otherwise
# yield a single-channel tensor, which the 3-channel normalisation rejects.
img_pil = Image.fromarray(img).convert("RGB")
prepro_img = preprocessing(img_pil)

# Show the raw (un-preprocessed) image without axes.
plt.axis('off')
plt.imshow(img)
plt.show()

In [None]:
import numpy as np
from PIL import Image
import skimage.io as io

# Labels aligned with the image order built below: 0.0 for every concept image
# followed by 1.0 for every not-concept image.
class_ids = np.repeat(
    [0.0, 1.0],
    [len(concept_imgs_ids), len(not_concept_imgs_ids)],
)

# Concept ids first, then the not-concept ids, as a single numpy array.
all_imgs_ids = np.concatenate([concept_imgs_ids, not_concept_imgs_ids])

# Download and preprocess every image; each stored tensor gets a leading
# batch axis of size 1.
imgs = []
for img_id in all_imgs_ids:
    # loadImgs is given a plain Python int rather than a numpy scalar.
    img_info = coco.loadImgs(int(img_id))[0]
    img_url = img_info["coco_url"]
    img = io.imread(img_url)
    # Convert to RGB so every image reaches the pipeline with three channels.
    img_pil = Image.fromarray(img).convert("RGB")
    img_prepro = preprocessing(img_pil)
    img_unsq = img_prepro[None]  # same as unsqueeze(0)
    imgs.append(img_unsq)

In [None]:
import matplotlib.pyplot as plt

img_id = 18
# Every entry of imgs carries a leading batch axis of size 1, so it is 4-D and
# a three-index permute fails on it. Drop that axis first, then move the
# channels last, which is the layout imshow expects.
# The tensor is still normalised, so the displayed colours are not the
# original ones (values outside the displayable range get clipped).
plt.imshow(imgs[img_id].squeeze(0).permute(1, 2, 0))

In [None]:
import torch

# Join the per-image tensors (each with a leading batch axis of size 1) along
# that first axis to form a single batch.
imgs_tensor = torch.cat(imgs, dim=0)
# Notebook output: shape of the assembled batch.
imgs_tensor.shape

In [None]:
from collections import OrderedDict

# Insertion-ordered stores keyed by a caller-chosen name.
# Only repr_output is written to in this cell; repr_input starts out empty.
repr_input = OrderedDict()
repr_output = OrderedDict()


def get_representation(name):
    """Return a callback that saves a layer output under ``name``.

    The returned callable takes ``(module, inputs, outputs)`` positionally
    and stores ``outputs.detach()`` in the module-level ``repr_output``
    dictionary, replacing any value already stored under ``name``.
    It returns ``None``.

    NOTE(review): the three-argument signature matches what
    ``nn.Module.register_forward_hook`` passes; the registration itself is
    not part of this cell.
    """
    def save_output(_module, _inputs, outputs):
        # Detach so the stored tensor is cut off from the autograd graph.
        repr_output[name] = outputs.detach()

    return save_output