In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from concept.model import ImageModelWrapper
from concept.activation_generator import ImageActivationGenerator
from concept.tcav import TCAV
from concept.ace import ACE
import concept.utils as utils
import pickle

In [3]:
class Net(nn.Module):
    """Small two-convolution CNN mapping an image batch to 10 logits.

    The attribute names below (conv1, conv2, dropout1, dropout2, fc1, fc2)
    become the state_dict keys, so they must not be renamed: the weights are
    presumably restored from a checkpoint by ImageModelWrapper, and ACE
    addresses the 'conv2' layer by name.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12: two unpadded 3x3 convs take 28 -> 26 -> 24,
        # then 2x2 max-pooling gives 12, so this assumes 28x28 inputs.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return unnormalised logits of shape (N, 10) for the batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        flat = torch.flatten(self.dropout1(features), 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
    

In [4]:
import random

# Seed every RNG up front so that Restart & Run All is reproducible. The patch
# counts reported later (615 in one cell, 612 in the next) show that some step
# is stochastic.
# NOTE(review): confirm which RNG the concept/ACE code actually draws from.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# (height, width) fed to the network. Net.fc1's 9216 inputs assume exactly 28x28.
IMAGE_SIZE = (28, 28)

net = Net()

# Pass the size as a single (h, w) tuple. A bare int rescales only the shorter
# edge (keeping the aspect ratio), and a second positional argument to Resize
# is interpreted as `interpolation`, not as a width.
# NOTE(review): Net.conv1 expects 3 input channels, while MNIST digits are
# grayscale; presumably the generator loads images as RGB — TODO confirm.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# NOTE(review): the checkpoint is a .pkl file; if it is unpickled, only load it
# from a trusted source.
model = ImageModelWrapper(
    model=net,
    state_dict_path='data/models/simple_conv_net_mnist_ep19.pkl',
    image_shape=IMAGE_SIZE,
    labels_path='data/MNIST/data/classes.txt',
)

# NOTE(review): this source_dir is 'data/MNIST/concept' (singular), while the
# ACE cell uses 'data/MNIST/concepts'; confirm which directory is intended.
generator = ImageActivationGenerator(
    model=model,
    source_dir='data/MNIST/concept',
    working_dir='data/MNIST',
    max_examples=50,
    transform=transform,
)

Loaded model data/models/simple_conv_net_mnist_ep19.pkl


In [5]:
# Configure ACE concept discovery for `target_class` at Net's 'conv2' layer,
# using the activation generator built above.
# NOTE(review): 'random500_0' looks like a random-image set rather than a digit
# class; confirm it is really the class to explain.
# NOTE(review): source_dir here is 'data/MNIST/concepts' but the generator uses
# 'data/MNIST/concept'; confirm which directory holds the intended images.
ace = ACE(
    activation_generator=generator,
    target_class='random500_0',
    random_concept=None,
    bottlenecks=['conv2'],
    source_dir='data/MNIST/concepts',
    working_dir='data/MNIST/ace',
    num_random_exp=5,
    num_discovery_imgs=5,
    min_imgs=20,
    max_imgs=40,
    num_workers=3,
    channel_mean=True,
    average_image_value=117,  # NOTE(review): unexplained constant; presumably a fill/mean pixel value — confirm
)

In [6]:
# Create image patches for concept discovery (the output reports a slic
# segmentation and the patch count).
# NOTE(review): discover_concepts() in the next cell prints its own
# "Created ... patches" line (612 there vs. 615 here), so it appears to
# re-create patches itself; this call may be redundant — confirm against ACE.
ace.create_patches()

  sigma=sigmas[i])


Created 615 patches using slic segmentation method


In [7]:
# Discover candidate concepts. Per its output, this creates patches, runs KM
# clustering on their activations ("Created 25 clusters"), and prints arrays
# that look like per-cluster patch indices (long output, truncated below).
ace.discover_concepts() 

Created 612 patches using slic segmentation method
Starting clustering with KM for 612 activations
Created 25 clusters
[ 78  95 179 182 193 307 360 517 560 571 600 603]
[533 534 535 539 540 541 542 543 544 545 548 549 550 551 552 582 587 589
 590 591 593 594 595 596 597 598 609 610]
[134 137 138 152 157 158 159 162 163 164 166 167 168 169 172 173 176 177
 198 199 206 207 210 211 213 214 217 218 222 223 224 225 234 236 238 239
 240 289 531]
[  7  10  66  67  68  73  85 108 109 114 115 126 184 189 251 255 257 309
 316 317 320 326 334 354 355 357 359 365 383 389 390 391 392 394 397 399
 400 401 402 410 436 439 441 443 446 447 449 450 453 455 460 462 468 472
 478 480 485 505 506 528 529 538 558]
[ 11  74 122 148 178 201 227 265 276 285 413 425 465 470 471 473 536 581
 607]
[ 20  26  27  84 113 141 160 194 422 427 474 498]
[514 515 572 573 576 602]
[156 165 174 175 205 212 219 220 221 233 237 283 287 290 291 292 293 294
 296 297 298 299 300 301 302 339 340 342 343 344 345 346 347 348 349 36

In [11]:
# Run TCAV over the concepts ACE discovered. The displayed result appears to be
# the list of kept concept names (e.g. 'random500_0_concept1').
ace.generate_tcavs()

['random500_0_concept1',
 'random500_0_concept2',
 'random500_0_concept3',
 'random500_0_concept4']