In [2]:
import glob
import pathlib

import captum as cp
import PIL
import PIL.Image
import torch
import torch.utils.data as dt
import torchvision
from torchvision import transforms

In [None]:
def load_tensor(filename):
    """Load one image file and return it preprocessed by ``transform``.

    Args:
        filename: Path (str or path-like) to an image file readable by PIL.

    Returns:
        Whatever the notebook's ``transform`` function returns for the RGB
        version of the image.
    """
    # The context manager closes the underlying file deterministically rather
    # than leaving it to garbage collection; convert() yields a separate image
    # object, so it remains usable after the source image is closed.
    with PIL.Image.open(filename) as img:
        rgb_img = img.convert("RGB")
    return transform(rgb_img)

def transform(img):
    """Preprocess a PIL image into a normalized tensor.

    Pipeline: resize (``Resize(256)``), center-crop to 224x224, convert to a
    tensor, then normalize each channel with the standard ImageNet mean/std
    values listed below.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(img)

def load_tensors(class_name, root_path='data/tcav/image/imagenet/', transform=True):
    """Load every ``*.jpg`` image in ``<root_path>/<class_name>/``.

    Args:
        class_name: Sub-directory of ``root_path`` that holds the images.
        root_path: Directory (str or path-like) with one sub-directory per class.
        transform: How to post-process each image:
            - ``True`` (default): return the preprocessed tensor produced by
              ``load_tensor``;
            - ``False``: return the RGB ``PIL.Image`` unchanged;
            - a callable: return ``transform(rgb_image)``.

    Returns:
        A list with one entry per ``*.jpg`` file, in sorted path order.
    """
    class_dir = pathlib.Path(root_path) / class_name
    # glob.glob() rejects pathlib.Path arguments; Path.glob() works natively.
    # Sorting makes the image order (and thus the stacked batch) reproducible.
    filenames = sorted(class_dir.glob('*.jpg'))

    images = []
    for filename in filenames:
        if callable(transform):
            with PIL.Image.open(filename) as img:
                images.append(transform(img.convert('RGB')))
        elif transform:
            # The `transform` parameter shadows the module-level `transform`
            # function here, so delegate the default preprocessing to
            # `load_tensor`, which still sees the module-level function.
            images.append(load_tensor(filename))
        else:
            with PIL.Image.open(filename) as img:
                images.append(img.convert('RGB'))
    return images

def assemble_concept(name, id, concept_path=None, *, concepts_path=None):
    """Build a captum ``Concept`` from the images in ``<root>/<name>/``.

    Args:
        name: Concept name; also the sub-directory holding its example images.
        id: Integer identifier passed through to ``captum.concept.Concept``.
        concept_path: Root directory (str or path-like) with one sub-directory
            per concept.
        concepts_path: Keyword-only alias for ``concept_path``.

    Returns:
        A ``captum.concept.Concept`` whose ``data_iter`` is built by captum's
        ``dataset_to_dataloader`` over the files in the concept directory,
        each loaded with ``load_tensor``.

    Raises:
        TypeError: if neither or both of ``concept_path`` / ``concepts_path``
            are supplied.
    """
    if (concept_path is None) == (concepts_path is None):
        raise TypeError(
            "assemble_concept() requires exactly one of 'concept_path' or 'concepts_path'"
        )
    root = concept_path if concept_path is not None else concepts_path

    import os

    # These helpers live in captum, not in torch.utils.data.
    from captum.concept._utils.data_iterator import (
        CustomIterableDataset,
        dataset_to_dataloader,
    )

    # captum's CustomIterableDataset builds its file glob by string
    # concatenation, so hand it a str directory path with a trailing separator
    # instead of a pathlib.Path.
    concept_dir = os.path.join(root, name, "")
    dataset = CustomIterableDataset(load_tensor, concept_dir)
    concept_iter = dataset_to_dataloader(dataset)
    return cp.concept.Concept(id=id, name=name, data_iter=concept_iter)

In [None]:
# Root folder containing one sub-directory of example images per concept.
concepts_path = pathlib.Path("data/tcav/image/concepts/")
# `assemble_concept` names its root-directory parameter `concept_path`
# (singular); each concept gets a distinct integer id.
stripes_concept = assemble_concept("striped", 0, concept_path=concepts_path)
dotted_concept = assemble_concept("dotted", 1, concept_path=concepts_path)
random_concept = assemble_concept("random", 2, concept_path=concepts_path)

In [None]:
# Pretrained GoogLeNet; .eval() puts the module into evaluation mode and returns
# the same module, so the call can be chained onto the constructor.
model = torchvision.models.googlenet(pretrained=True).eval()

In [None]:
# GoogLeNet sub-modules at which TCAV will learn concept activation vectors.
layers = ['inception4c', 'inception4d', 'inception4e']

# LayerIntegratedGradients lives in captum.attr; it is not exported at the top
# level of the captum package. The layer argument is None here — presumably
# TCAV assigns each layer from `layers` itself (TODO confirm in captum docs).
mytcav = cp.concept.TCAV(
    model=model,
    layers=layers,
    layer_attr_method=cp.attr.LayerIntegratedGradients(model, None, multiply_by_inputs=False),
)

In [None]:
# Load the raw RGB images, then apply the notebook's `transform` to each one and
# stack the results along a new leading dimension into a single input batch.
zebra_images = load_tensors('zebra', transform=False)
# torch.stack fails with an opaque error on an empty list; fail clearly instead.
assert zebra_images, "no '*.jpg' images found for class 'zebra'"
zebra_tensors = torch.stack([transform(img) for img in zebra_images])

In [None]:
# Experimental set: the "striped" concept against the random counter-concept.
classification_data = [[stripes_concept, random_concept]]

# Target output class for the attributions: index 340 is "zebra" in the
# ImageNet-1k label set (index 0 would be "tench", unrelated to these images).
ix = 340
tcav_scores_w_random = mytcav.interpret(
    inputs=zebra_tensors,
    experimental_sets=classification_data,
    target=ix,
)

In [None]:
# Experimental set: the "striped" concept directly against the "dotted" concept.
classification_data = [[stripes_concept, dotted_concept]]
# Stored under its own name so the striped-vs-random scores computed in the
# previous cell (`tcav_scores_w_random`) are not overwritten.
tcav_scores_w_dotted = mytcav.interpret(
    inputs=zebra_tensors,
    experimental_sets=classification_data,
    target=ix,
)