# Compress then explain: example with Expected Gradients

Example of CTE with the [`captum` Python package](https://github.com/pytorch/captum) explaining a CNN model trained on the `CIFAR_10` dataset.

#### load packages

In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from goodpoints import compress
import captum

#### load the dataset and model

Following the PyTorch tutorial https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 512

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

In [4]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

In [5]:
for epoch in range(20):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'epoch: {epoch + 1} | loss: {running_loss / len(trainloader):.3f}')
print('Finished Training')

epoch: 1 | loss: 1.919
epoch: 2 | loss: 1.610
epoch: 3 | loss: 1.511
epoch: 4 | loss: 1.442
epoch: 5 | loss: 1.385
epoch: 6 | loss: 1.335
epoch: 7 | loss: 1.297
epoch: 8 | loss: 1.259
epoch: 9 | loss: 1.228
epoch: 10 | loss: 1.200
epoch: 11 | loss: 1.173
epoch: 12 | loss: 1.145
epoch: 13 | loss: 1.128
epoch: 14 | loss: 1.106
epoch: 15 | loss: 1.089
epoch: 16 | loss: 1.067
epoch: 17 | loss: 1.055
epoch: 18 | loss: 1.046
epoch: 19 | loss: 1.028
epoch: 20 | loss: 1.011
Finished Training


In [6]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

Accuracy of the network on the 10000 test images: 60 %


#### compress background data

In [7]:
X_test_2d = testloader.dataset.data.astype(float) / 255
X_test = X_test_2d.reshape(10000, 3*32**2)
n = X_test.shape[0]
d = X_test.shape[1]
sigma = np.sqrt(2 * d)

In [8]:
id_cte = compress.compresspp_kt(X_test, kernel_type=b"gaussian", k_params=np.array([sigma**2]), g=4, seed=0)

#### then explain

In [9]:
explainer = captum.attr.IntegratedGradients(net)
inputs = torch.movedim(torch.as_tensor(X_test_2d, dtype=torch.float), 3, 1)

In [10]:
baselines_cte = torch.movedim(torch.as_tensor(X_test_2d[id_cte], dtype=torch.float), 3, 1)
results = []
for data in testloader:
    inputs, _ = data
    results += [explainer.attribute(inputs, baselines_cte[[i]], target=1) for i in range(baselines_cte.shape[0])]
    break

In [11]:
explanation_cte = torch.mean(torch.stack(results), dim=0)

#### compare with iid sampling

In [12]:
np.random.seed(0)
id_iid = np.random.choice(n, size=len(id_cte))
baselines_iid = torch.movedim(torch.as_tensor(X_test_2d[id_iid], dtype=torch.float), 3, 1)
results = []
for data in testloader:
    inputs, _ = data
    results += [explainer.attribute(inputs, baselines_iid[[i]], target=1) for i in range(baselines_iid.shape[0])]
    break

In [13]:
explanation_iid = torch.mean(torch.stack(results), dim=0)

#### calculate "ground truth"

In [14]:
np.random.seed(0)
id_gt = np.random.choice(n, size=20*len(id_cte))
baselines_gt = torch.movedim(torch.as_tensor(X_test_2d[id_gt], dtype=torch.float), 3, 1)
results = []
for data in testloader:
    inputs, _ = data
    results += [explainer.attribute(inputs, baselines_gt[[i]], target=1) for i in range(baselines_gt.shape[0])]
    break

In [15]:
explanation_gt = torch.mean(torch.stack(results), dim=0)

#### evaluate

In [16]:
def metric_mae(x, y):
    return torch.mean(torch.abs(x-y))

In [17]:
print(f'Explanation approximation error introduced by iid sampling:\
      {metric_mae(explanation_gt, explanation_iid):.4f}')
print(f'Relative improvement by CTE:\
      {100*(metric_mae(explanation_gt, explanation_iid) - metric_mae(explanation_gt, explanation_cte)) / metric_mae(explanation_gt, explanation_iid):.2f}%')

Explanation approximation error introduced by iid sampling:      0.0017
Relative improvement by CTE:      14.16%
