In [6]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from captum.insights import AttributionVisualizer, Batch
from captum.insights.features import ImageFeature

from models import *

In [7]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Let cuDNN benchmark and cache the fastest convolution algorithms
# (this pays off when input shapes stay the same between calls).
cudnn.benchmark = True

In [16]:
def get_model(checkpoint_path='./checkpoint/basic_training_with_softmax', map_location='cpu'):
    """Build a ``CNN`` and load trained weights from a checkpoint, ready for inference.

    Args:
        checkpoint_path: Path to a file written with ``torch.save``. It must
            deserialize to a mapping that holds the model's state dict under
            the key ``'net'``.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'``.
            Without it, a checkpoint saved from a GPU run cannot be
            deserialized on a CPU-only machine. ``CNN()`` is built on the CPU
            and ``load_state_dict`` copies values into its existing
            parameters, so the model itself stays on the CPU either way.

    Returns:
        The ``CNN`` instance with the checkpoint weights loaded, in eval mode.
    """
    net = CNN()
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    net.load_state_dict(checkpoint['net'])
    # Switch off training-only behaviour (dropout / batch-norm statistic
    # updates, if the CNN has any) so attributions are reproducible.
    net.eval()
    return net

def baseline_func(input):
    """Return a baseline of zeros shaped like ``input``, used as the reference for attribution.

    The result has the same shape and dtype as ``input``. It is built by
    multiplying by zero, so finite inputs map to zeros, while NaN/inf
    entries would stay NaN.
    """
    zeroed = 0 * input
    return zeroed

def formatted_data_iter(root="./data", batch_size=10, num_workers=2):
    """Yield CIFAR-10 test-set batches wrapped as captum ``Batch`` objects.

    Images are converted with ``transforms.ToTensor()`` only; no
    normalization is applied here. Batches come in dataset order
    (``shuffle=False``).

    Args:
        root: Directory where the CIFAR-10 data is stored (downloaded if missing).
        batch_size: Number of images per yielded batch.
        num_workers: DataLoader worker processes.

    Yields:
        ``Batch(inputs=images, labels=labels)`` for each DataLoader batch.
        The generator ends normally once the test set is exhausted.
    """
    dataset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transforms.ToTensor()
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    # A for-loop stops cleanly when the loader runs out. Calling next() inside
    # `while True` would let StopIteration escape from this generator, which
    # Python 3.7+ converts into RuntimeError (PEP 479). Consumers that expect
    # StopIteration at the end of the data would then crash instead.
    for images, labels in dataloader:
        yield Batch(inputs=images, labels=labels)
        
def get_classes():
    """Return the display names of the ten CIFAR-10 classes.

    Order matters: entry ``i`` is meant to be the name for label index ``i``.
    A fresh list is created on every call, so callers may mutate the result.
    """
    names = (
        "Plane", "Car", "Bird", "Cat", "Deer",
        "Dog", "Frog", "Horse", "Ship", "Truck",
    )
    return list(names)

In [17]:
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
model = get_model()

# The single image input: attributions are measured against the zero
# baseline from baseline_func, and images are normalized before the model sees them.
input_feature = ImageFeature(
    "Input",
    baseline_transforms=[baseline_func],
    input_transforms=[normalize],
)

# NOTE(review): no score_func is passed (an earlier softmax score_func was
# commented out); the checkpoint name suggests softmax may already be part
# of the model - confirm before re-adding one.
visualizer = AttributionVisualizer(
    models=[model],
    classes=get_classes(),
    features=[input_feature],
    dataset=formatted_data_iter(),
)

visualizer.render()

CaptumInsights(insights_config={'classes': ['Plane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Sh…

Output()