In [1]:
from pathlib import Path
from torchvision import transforms
import PIL
import captum as cp
import captum.concept._utils.data_iterator as di
import glob
import torch
import torchvision

In [8]:
def load_tensor(filename):
    """Load one image file and preprocess it into a model-ready tensor.

    Args:
        filename: Path (str or path-like) of the image to read.

    Returns:
        The output of ``transform`` applied to the image converted to RGB.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # guarantees it is released. ``convert`` returns a new, fully loaded image,
    # so it stays usable after the file is closed.
    with PIL.Image.open(filename) as img:
        rgb_img = img.convert("RGB")
    return transform(rgb_img)

def transform(img):
    """Apply the ImageNet-style preprocessing pipeline to ``img``.

    Steps: resize (shorter side to 256), center-crop to 224x224, convert to a
    tensor, then normalize each channel with the mean/std constants below.

    Args:
        img: A PIL image (``load_tensor`` passes an RGB-converted image).

    Returns:
        The tensor produced by the composed pipeline.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(img)

def load_tensors(class_name, root_path="data/concepts/", transform=True):
    """Load every ``.jpg`` image found in ``<root_path>/<class_name>``.

    Args:
        class_name: Sub-directory of ``root_path`` holding the images.
        root_path: Base directory (str or Path).
        transform: If truthy, return preprocessed tensors (via ``load_tensor``);
            otherwise return the RGB-converted PIL images unchanged.

    Returns:
        A list with one entry per image, in sorted filename order. The list is
        empty if no ``.jpg`` files are found.
    """
    path = Path(root_path) / class_name
    # glob's order is filesystem-dependent; sort so results are reproducible.
    filenames = sorted(glob.glob(str(path / "*.jpg")))

    # The ``transform`` parameter shadows the module-level ``transform``
    # function, so calling ``transform(img)`` here would call a bool and raise
    # TypeError. ``load_tensor`` resolves ``transform`` at module scope instead.
    if transform:
        return [load_tensor(filename) for filename in filenames]

    images = []
    for filename in filenames:
        # Close each file handle; ``convert`` returns a new, loaded image.
        with PIL.Image.open(filename) as img:
            images.append(img.convert("RGB"))
    return images

def assemble_concept(name, id, concept_path):
    """Wrap the example images in ``<concept_path>/<name>/`` as a captum Concept.

    Args:
        name: Concept name; also the sub-directory holding its example images.
        id: Identifier passed to ``Concept`` (presumably must be unique per
            concept — confirm against captum's TCAV docs).
        concept_path: Base directory (str or Path) containing concept folders.

    Returns:
        A ``captum.concept.Concept`` whose data iterator is built from a
        ``CustomIterableDataset`` that loads files with ``load_tensor``.
    """
    # Path() accepts both str and Path, so callers may pass either.
    # The trailing "/" is kept from the original; presumably the dataset
    # appends a file pattern directly to this string — confirm before removing.
    concept_dir = f"{Path(concept_path) / name}/"
    dataset = di.CustomIterableDataset(load_tensor, concept_dir)
    concept_iter = di.dataset_to_dataloader(dataset)
    return cp.concept.Concept(id=id, name=name, data_iter=concept_iter)

In [9]:
concept_path = Path("data/concepts/")

# Two texture concepts followed by three random-image sets, each with its own id.
stripes_concept = assemble_concept(name="striped", id=0, concept_path=concept_path)
dotted_concept = assemble_concept(name="dotted", id=1, concept_path=concept_path)
random0_concept = assemble_concept(name="random500_0", id=2, concept_path=concept_path)
random1_concept = assemble_concept(name="random500_1", id=3, concept_path=concept_path)
random2_concept = assemble_concept(name="random500_2", id=4, concept_path=concept_path)

In [10]:
# Pretrained GoogLeNet put into inference mode (nn.Module.eval() returns the module itself).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=...`; left unchanged because the torchvision version is not pinned here.
model = torchvision.models.googlenet(pretrained=True).eval()

In [11]:
# GoogLeNet layers whose activations are used to learn concept vectors.
layers = ["inception4c", "inception4d", "inception4e"]

mytcav = cp.concept.TCAV(
    model=model,
    layers=layers,
    # The layer argument is None; presumably TCAV sets each entry of `layers`
    # on the attribution method during interpret — confirm in captum's TCAV docs.
    # multiply_by_inputs=False: attributions are not scaled by (input - baseline).
    layer_attr_method=cp.attr.LayerIntegratedGradients(
        model, None, multiply_by_inputs=False
    ),
)

In [12]:
# Load the raw RGB images, then preprocess each with the module-level
# `transform` and batch them into a single tensor.
zebra_images = load_tensors('zebra', transform=False)
# torch.stack raises an opaque error on an empty list; fail early with a clear message.
assert zebra_images, "no .jpg images found under data/concepts/zebra"
zebra_tensors = torch.stack([transform(img) for img in zebra_images])

In [None]:
# Experiment: "striped" concept vs. a random-image set.
classification_data = [[stripes_concept, random0_concept]]

ix = 340  # NOTE(review): presumably the ImageNet "zebra" class index — confirm against the label map
tcav_scores = mytcav.interpret(
    inputs=zebra_tensors,
    experimental_sets=classification_data,
    target=ix,
    # Presumably forwarded to LayerIntegratedGradients.attribute — confirm.
    n_steps=5,
)

In [None]:
# Display the TCAV results from the interpret call above; as the last expression it is rendered as cell output.
# NOTE(review): presumably a nested dict (experimental-set key -> layer name -> {"sign_count", "magnitude"}) — confirm against captum's TCAV.interpret docs.
tcav_scores

In [16]:
# Experiment: "striped" concept vs. "dotted" concept on the same zebra inputs.
classification_data = [[stripes_concept, dotted_concept]]

ix = 340  # NOTE(review): presumably the ImageNet "zebra" class index — confirm against the label map
tcav_scores = mytcav.interpret(
    inputs=zebra_tensors,
    experimental_sets=classification_data,
    target=ix,
    # Presumably forwarded to LayerIntegratedGradients.attribute — confirm.
    n_steps=5,
)

In [None]:
# Display the striped-vs-dotted TCAV results; as the last expression it is rendered as cell output.
# NOTE(review): presumably a nested dict (experimental-set key -> layer name -> {"sign_count", "magnitude"}) — confirm against captum's TCAV.interpret docs.
tcav_scores