In [5]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.models.inception import BasicConv2d, Inception3
from torchvision.models.resnet import BasicBlock, ResNet

In [6]:
# Hyperparameters
n_epochs = 30
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 100  # not referenced by the training code shown in this notebook
width = 128  # not referenced by any cell shown in this notebook
random_seed = 1

# Seed torch's global RNG. It drives weight initialisation, DataLoader shuffling
# and dropout masks, so seeding lets Restart & Run All reproduce the same results.
torch.manual_seed(random_seed)

# Only affects GPU execution. NOTE(review): presumably disabled to avoid
# nondeterministic cuDNN kernels; confirm.
torch.backends.cudnn.enabled = False

In [8]:
# Load training and testing sets.
# Both splits share a single preprocessing pipeline: ToTensor, then Normalize
# with mean 0.1307 / std 0.3081 (presumably the MNIST pixel statistics).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Change download to True if not already downloaded
mnist_train_set = torchvision.datasets.MNIST('./MNIST/', train=True, download=False,
                                             transform=mnist_transform)
mnist_test_set = torchvision.datasets.MNIST('./MNIST/', train=False, download=False,
                                            transform=mnist_transform)

train_loader = torch.utils.data.DataLoader(mnist_train_set, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test_set, batch_size=batch_size_test, shuffle=True)

In [9]:
def train(network, lossFunc, optimizer, data_loader=None):
    """Run one epoch of training on `network`, updating its parameters in place.

    Parameters
    ----------
    network : torch.nn.Module
        Model to train. It is switched to training mode (dropout active).
    lossFunc : callable
        Loss called as ``lossFunc(output, target)``, e.g. ``torch.nn.CrossEntropyLoss()``.
    optimizer : torch.optim.Optimizer
        Optimizer over ``network``'s parameters; stepped once per batch.
    data_loader : iterable of (data, target), optional
        Batches to train on. Defaults to the module-level ``train_loader``,
        looked up at call time. A later cell rebinds ``train_loader`` to a
        batch-size-1, unshuffled loader, so pass this explicitly when re-training
        after that cell has run.

    Returns
    -------
    float
        Mean loss over the batches seen, or nan if the loader yields no batches.
    """
    if data_loader is None:
        data_loader = train_loader
    network.train()
    total_loss = 0.0
    num_batches = 0
    for data, target in data_loader:
        optimizer.zero_grad()
        loss = lossFunc(network(data), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

In [10]:
def test(network, epoch):
    network.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    print('Epoch {}, Test set: Accuracy: {}/{} ({:.0f}%)'.format(
          epoch, correct, len(test_loader.dataset),
          100. * correct / len(test_loader.dataset)))

In [11]:
class NetDense(nn.Module):
    """Fully connected classifier: 28*28 inputs -> 128 hidden (ReLU, dropout) -> 10 scores."""

    def __init__(self):
        super().__init__()
        # Registration order (fc1, then fc2) fixes the order of parameters().
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten each sample to 28*28 values and return 10 unnormalized class scores."""
        flat = x.view(-1, 28 * 28)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return self.fc2(hidden)

class NetConv(nn.Module):
    """Small CNN: two (5x5 conv -> 2x2 max-pool -> ReLU) stages, then two linear layers.

    conv1 takes 1 input channel. For 28x28 inputs the second stage yields
    20 x 4 x 4 = 320 features per sample, which is what fc1 expects.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as they are applied.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return 10 unnormalized class scores per sample."""
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(stage1)), 2))
        features = stage2.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(features)), training=self.training)
        return self.fc2(hidden)

In [12]:
def _sgd_for(net):
    """Build an SGD optimizer over all of `net`'s parameters from the notebook hyperparameters.

    weight_decay reuses the learning-rate value. NOTE(review): confirm this
    coupling is intended rather than a separate regularization constant.
    """
    return optim.SGD(net.parameters(), lr=learning_rate,
                     momentum=momentum, weight_decay=learning_rate)


# Networks are created in a fixed order so their random initialisation is reproducible.
networkDense = NetDense()
optimizerDense = _sgd_for(networkDense)

# Two independently initialised copies of the same convolutional architecture.
networkConv1 = NetConv()
optimizerConv1 = _sgd_for(networkConv1)

networkConv2 = NetConv()
optimizerConv2 = _sgd_for(networkConv2)

lossFunc = torch.nn.CrossEntropyLoss()

In [13]:
# Epoch 0 reports the untrained baseline; each later epoch trains once, then evaluates.
for epoch in range(n_epochs + 1):
    if epoch > 0:
        train(networkDense, lossFunc, optimizerDense)
    test(networkDense, epoch)

Epoch 0, Test set: Accuracy: 1358/10000 (13%)
Epoch 1, Test set: Accuracy: 9210/10000 (92%)
Epoch 2, Test set: Accuracy: 9388/10000 (93%)
Epoch 3, Test set: Accuracy: 9455/10000 (94%)
Epoch 4, Test set: Accuracy: 9486/10000 (94%)
Epoch 5, Test set: Accuracy: 9493/10000 (94%)
Epoch 6, Test set: Accuracy: 9531/10000 (95%)
Epoch 7, Test set: Accuracy: 9544/10000 (95%)
Epoch 8, Test set: Accuracy: 9545/10000 (95%)
Epoch 9, Test set: Accuracy: 9542/10000 (95%)
Epoch 10, Test set: Accuracy: 9560/10000 (95%)
Epoch 11, Test set: Accuracy: 9563/10000 (95%)
Epoch 12, Test set: Accuracy: 9571/10000 (95%)
Epoch 13, Test set: Accuracy: 9586/10000 (95%)
Epoch 14, Test set: Accuracy: 9586/10000 (95%)
Epoch 15, Test set: Accuracy: 9585/10000 (95%)
Epoch 16, Test set: Accuracy: 9589/10000 (95%)
Epoch 17, Test set: Accuracy: 9587/10000 (95%)
Epoch 18, Test set: Accuracy: 9590/10000 (95%)
Epoch 19, Test set: Accuracy: 9600/10000 (96%)
Epoch 20, Test set: Accuracy: 9597/10000 (95%)
Epoch 21, Test set: Acc

In [14]:
# Epoch 0 reports the untrained baseline; each later epoch trains once, then evaluates.
for epoch in range(n_epochs + 1):
    if epoch > 0:
        train(networkConv1, lossFunc, optimizerConv1)
    test(networkConv1, epoch)

Epoch 0, Test set: Accuracy: 1000/10000 (10%)
Epoch 1, Test set: Accuracy: 9370/10000 (93%)
Epoch 2, Test set: Accuracy: 9589/10000 (95%)
Epoch 3, Test set: Accuracy: 9652/10000 (96%)
Epoch 4, Test set: Accuracy: 9679/10000 (96%)
Epoch 5, Test set: Accuracy: 9733/10000 (97%)
Epoch 6, Test set: Accuracy: 9741/10000 (97%)
Epoch 7, Test set: Accuracy: 9763/10000 (97%)
Epoch 8, Test set: Accuracy: 9771/10000 (97%)
Epoch 9, Test set: Accuracy: 9774/10000 (97%)
Epoch 10, Test set: Accuracy: 9765/10000 (97%)
Epoch 11, Test set: Accuracy: 9774/10000 (97%)
Epoch 12, Test set: Accuracy: 9776/10000 (97%)
Epoch 13, Test set: Accuracy: 9787/10000 (97%)
Epoch 14, Test set: Accuracy: 9779/10000 (97%)
Epoch 15, Test set: Accuracy: 9770/10000 (97%)
Epoch 16, Test set: Accuracy: 9802/10000 (98%)
Epoch 17, Test set: Accuracy: 9794/10000 (97%)
Epoch 18, Test set: Accuracy: 9796/10000 (97%)
Epoch 19, Test set: Accuracy: 9795/10000 (97%)
Epoch 20, Test set: Accuracy: 9798/10000 (97%)
Epoch 21, Test set: Acc

In [15]:
# Epoch 0 reports the untrained baseline; each later epoch trains once, then evaluates.
for epoch in range(n_epochs + 1):
    if epoch > 0:
        train(networkConv2, lossFunc, optimizerConv2)
    test(networkConv2, epoch)

Epoch 0, Test set: Accuracy: 998/10000 (9%)
Epoch 1, Test set: Accuracy: 9387/10000 (93%)
Epoch 2, Test set: Accuracy: 9601/10000 (96%)
Epoch 3, Test set: Accuracy: 9677/10000 (96%)
Epoch 4, Test set: Accuracy: 9703/10000 (97%)
Epoch 5, Test set: Accuracy: 9707/10000 (97%)
Epoch 6, Test set: Accuracy: 9742/10000 (97%)
Epoch 7, Test set: Accuracy: 9746/10000 (97%)
Epoch 8, Test set: Accuracy: 9764/10000 (97%)
Epoch 9, Test set: Accuracy: 9766/10000 (97%)
Epoch 10, Test set: Accuracy: 9750/10000 (97%)
Epoch 11, Test set: Accuracy: 9752/10000 (97%)
Epoch 12, Test set: Accuracy: 9760/10000 (97%)
Epoch 13, Test set: Accuracy: 9792/10000 (97%)
Epoch 14, Test set: Accuracy: 9779/10000 (97%)
Epoch 15, Test set: Accuracy: 9784/10000 (97%)
Epoch 16, Test set: Accuracy: 9778/10000 (97%)
Epoch 17, Test set: Accuracy: 9776/10000 (97%)
Epoch 18, Test set: Accuracy: 9778/10000 (97%)
Epoch 19, Test set: Accuracy: 9787/10000 (97%)
Epoch 20, Test set: Accuracy: 9793/10000 (97%)
Epoch 21, Test set: Accur

In [17]:
# Load training and testing with batch size 1.
# batch_size=1 with shuffle=False means the k-th batch is dataset sample k, so
# per-sample results computed from these loaders line up with dataset indices.
# Note: this rebinds train_loader / test_loader, replacing the training loaders.
_single_sample_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader, test_loader = (
    torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./MNIST/', train=is_train, download=False,
                                   transform=_single_sample_transform),
        batch_size=1, shuffle=False)
    for is_train in (True, False)
)

In [18]:
def computeAverage(net, data_loader, whichOut):
    """Compute a per-sample local affine model of one network output.

    Despite the name, nothing is averaged. For every sample x this returns:

    - the gradient g = d out[whichOut] / d x (the ``whichOut`` row of the Jacobian);
    - the offset b = out[whichOut] - g . x.

    For a piecewise-linear network such as NetDense / NetConv (linear, conv,
    ReLU and max-pool layers, evaluated with dropout off), out[whichOut] == g . x + b
    on the linear region containing x. So (g, b) identify that region's affine map.

    Parameters
    ----------
    net : torch.nn.Module
        Network to probe. It is put in eval mode for the duration of the call,
        so dropout cannot randomize the gradients. Its previous train/eval mode
        is restored afterwards.
    data_loader : iterable of (data, target) batches
        Any batch size is accepted. Samples within a batch are assumed not to
        interact, which holds for NetDense / NetConv in eval mode. Rows follow
        iteration order, so use an unshuffled loader to align rows with
        dataset indices.
    whichOut : int
        Index of the output unit (column of the network output) to analyse.

    Returns
    -------
    outputs : np.ndarray, shape (numSamples,), float64
        out[whichOut] for each sample.
    grads : np.ndarray, shape (numSamples, D + 1), float64
        Columns [:D] hold the flattened gradient (D = values per sample,
        784 for MNIST). The last column holds the offset b. An empty loader
        yields shape (0, 28*28 + 1).
    """
    was_training = net.training
    net.eval()
    output_chunks = []
    grad_chunks = []
    try:
        for data, _target in data_loader:
            inputs = data.detach().requires_grad_(True)
            selected = net(inputs)[:, whichOut]
            # Samples do not interact, so the gradient of the batch sum w.r.t.
            # the inputs is each sample's own gradient (also for batch_size=1).
            (g,) = torch.autograd.grad(selected.sum(), inputs)
            batch = inputs.shape[0]
            flat_x = inputs.detach().reshape(batch, -1)
            flat_g = g.reshape(batch, -1)
            out = selected.detach()
            offset = out - (flat_x * flat_g).sum(dim=1)
            output_chunks.append(out.cpu().numpy())
            grad_chunks.append(torch.cat([flat_g, offset.unsqueeze(1)], dim=1).cpu().numpy())
    finally:
        net.train(was_training)

    if not grad_chunks:
        return np.zeros(0), np.zeros((0, 28 * 28 + 1))
    outputs = np.concatenate(output_chunks).astype(np.float64)
    grads = np.concatenate(grad_chunks).astype(np.float64)
    return outputs, grads

In [19]:
# Set LINEAR_REGIONS_DIR to a preferred location.
# Warning - large file sizes
# Each file stores the `grads` array returned by computeAverage for one
# (network, split, output index): one row per sample, holding the flattened
# gradient followed by the offset. Files are named e.g. denseMNISTTrainRegion0.npy.
LINEAR_REGIONS_DIR = '/s/red/b/nobackup/data/bsattelb/linearRegions/'
os.makedirs(LINEAR_REGIONS_DIR, exist_ok=True)

_region_networks = [('dense', networkDense), ('conv1', networkConv1), ('conv2', networkConv2)]
_region_splits = [('Train', train_loader), ('Test', test_loader)]

for i in range(10):
    for split_name, region_loader in _region_splits:
        for net_name, region_net in _region_networks:
            _, linearRegions = computeAverage(region_net, region_loader, i)
            np.save(os.path.join(LINEAR_REGIONS_DIR,
                                 '{}MNIST{}Region{}.npy'.format(net_name, split_name, i)),
                    linearRegions)
            # Free the large array before computing the next one.
            del linearRegions