In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from concept.model import ImageModelWrapper
from concept.activation_generator import ImageActivationGenerator
from concept.tcav import TCAV
from concept.ace import ACE
import concept.utils as utils
import pickle

In [3]:
class Net(nn.Module):
    """Small two-convolution CNN that maps an image batch to 10 class logits.

    With 3x3 stride-1 unpadded convolutions and a 2x2 max-pool, ``fc1``'s
    9216 input features (64 channels * 12 * 12) only line up for inputs of
    shape (N, 3, 28, 28).

    The attribute names (conv1, conv2, dropout1, dropout2, fc1, fc2) form the
    keys of the saved state_dict and are referenced as TCAV bottlenecks
    (e.g. 'conv2'). Do not rename them.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Feature extractor: two unpadded 3x3 convolutions (3 -> 32 -> 64 channels).
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout after pooling and after the hidden fully-connected layer.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Classifier head: 9216 = 64 * 12 * 12 flattened features -> 128 -> 10.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return un-normalised logits of shape (N, 10) for the batch ``x``."""
        feature_map = F.relu(self.conv1(x))
        feature_map = F.relu(self.conv2(feature_map))
        feature_map = self.dropout1(F.max_pool2d(feature_map, 2))
        hidden = F.relu(self.fc1(torch.flatten(feature_map, 1)))
        # No softmax here: callers receive raw logits.
        return self.fc2(self.dropout2(hidden))
    

In [4]:
net = Net()

# Single source of truth for the (height, width) the model expects; used both
# for resizing inputs and for the model wrapper so the two cannot drift apart.
IMAGE_SHAPE = (28, 28)

# torchvision's Resize takes the target size as its FIRST argument only; a
# second positional argument is the interpolation mode. The previous
# ``Resize(28, 28)`` therefore meant "shorter edge -> 28 with interpolation=28"
# and only worked because the MNIST images were already 28x28. Passing a
# (h, w) tuple forces an exact 28x28 output.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SHAPE),
    transforms.ToTensor()
])

model = ImageModelWrapper(
    model=net,
    state_dict_path='data/models/simple_conv_net_mnist_ep19.pkl',
    # NOTE(review): presumably must match the transform's output size — confirm.
    image_shape=IMAGE_SHAPE,
    labels_path='data/MNIST/data/classes.txt'
)

generator = ImageActivationGenerator(
    model=model,
    source_dir='data/MNIST/data',
    working_dir='data/MNIST',
    # NOTE(review): presumably caps the number of images loaded per concept — confirm.
    max_examples=50,
    transform=transform,
)

Loaded model data/models/simple_conv_net_mnist_ep19.pkl


In [5]:
# Names of all ten MNIST digit concepts (not used in this cell; kept for later use).
all_concepts = 'zero one two three four five six seven eight nine'.split()

# Test how the 'one' and 'two' concepts relate to the 'one' target class at the
# conv2 bottleneck, repeating against 10 random counterpart concepts.
tcav = TCAV(
    target='one',
    concepts=['one', 'two'],
    bottlenecks=['conv2'],
    activation_generator=generator,
    alphas=[0.01],
    random_counterpart=None,
    working_dir='data/MNIST',
    num_random_exp=10,
    random_concepts=None,
)
tcav_result = tcav.run(num_workers=10, overwrite=True)

# Persist the raw results, then reload them so the report below is produced from
# the on-disk copy. Only unpickle files this notebook wrote itself: pickle.load
# can execute arbitrary code embedded in the file.
with open('result.pkl', 'wb') as result_file:
    pickle.dump(tcav_result, result_file)
with open('result.pkl', 'rb') as result_file:
    tcav_result = pickle.load(result_file)

# NOTE(review): min_p_val presumably is the significance threshold used when
# comparing concept runs to random runs — confirm in concept.utils.
utils.print_results(
    results=tcav_result,
    random_counterpart=None,
    random_concepts=None,
    min_p_val=0.05,
)

conv2 ['one', 'random500_0'] one 0.01
conv2 ['one', 'random500_1'] one 0.01
conv2 ['one', 'random500_2'] one 0.01
conv2 ['one', 'random500_3'] one 0.01
conv2 ['one', 'random500_4'] one 0.01
conv2 ['one', 'random500_5'] one 0.01
conv2 ['one', 'random500_6'] one 0.01
conv2 ['one', 'random500_7'] one 0.01
conv2 ['one', 'random500_8'] one 0.01
conv2 ['one', 'random500_9'] one 0.01
conv2 ['two', 'random500_0'] one 0.01
conv2 ['two', 'random500_1'] one 0.01
conv2 ['two', 'random500_2'] one 0.01
conv2 ['two', 'random500_3'] one 0.01
conv2 ['two', 'random500_4'] one 0.01
conv2 ['two', 'random500_5'] one 0.01
conv2 ['two', 'random500_6'] one 0.01
conv2 ['two', 'random500_7'] one 0.01
conv2 ['two', 'random500_8'] one 0.01
conv2 ['two', 'random500_9'] one 0.01
conv2 ['random500_0', 'random500_1'] one 0.01
conv2 ['random500_0', 'random500_2'] one 0.01
conv2 ['random500_0', 'random500_3'] one 0.01
conv2 ['random500_0', 'random500_4'] one 0.01
conv2 ['random500_0', 'random500_5'] one 0.01
conv2 ['ra