In [1]:
import torch
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.models.inception import Inception3, BasicConv2d
import time

import numpy as np

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Convert each image to a tensor, then normalise it with mean 0.1307 and
# standard deviation 0.3081.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Both loaders yield one image per batch, in dataset order: computeRegions
# stores its results in rows indexed by the batch number.
trainloader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./MNIST',
        train=True,
        download=True,
        transform=transform,
    ),
    batch_size=1,
    shuffle=False,
)

testloader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./MNIST',
        train=False,
        download=True,
        transform=transform,
    ),
    batch_size=1,
    shuffle=False,
)

In [3]:
class MnistResNet(ResNet):
    """torchvision ResNet built from BasicBlocks for 10-class, 1-channel input.

    The parent is constructed with the layer configuration [18, 18, 18, 18]
    and its first convolution is then replaced by a 7x7, stride-2, padding-3
    convolution (no bias) that accepts a single input channel.
    """

    def __init__(self):
        super().__init__(BasicBlock, [18, 18, 18, 18], num_classes=10)
        # Re-create the first layer so it takes 1 input channel.
        self.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

    def forward(self, x):
        """Delegate to the inherited ResNet forward pass without changes."""
        return super().forward(x)

In [4]:
class MnistInception(Inception3):
    """torchvision Inception3 for 10-class, single-channel input.

    The parent is built without auxiliary logits, and its first convolution
    block is replaced by one that accepts a single input channel.  Inputs are
    bilinearly resized to 229x229 before the inherited forward pass.
    """
    def __init__(self):
        super(MnistInception, self).__init__(num_classes=10, aux_logits=False)
        # Re-create the first conv block so it takes 1 input channel.
        self.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2)

    def forward(self, x):
        """Resize ``x`` to 229x229 and run the inherited Inception3 forward pass."""
        # NOTE(review): 229 may be a typo for Inception v3's customary 299.
        # It is kept as-is because changing it would alter the behaviour of
        # any already-saved model that is loaded with this class -- confirm.
        #
        # align_corners=False is the value the default resolves to; passing it
        # explicitly keeps results identical and silences the warning that
        # interpolate emits when the argument is omitted.
        x = torch.nn.functional.interpolate(
            x, size=(229, 229), mode='bilinear', align_corners=False)
        return super(MnistInception, self).forward(x)

In [5]:
# Calculate the whichOut row of the Jacobians of net for each element of data_loader
def computeRegions(net, data_loader, whichOut):
    """Linearise output ``whichOut`` of ``net`` around every sample.

    For each sample ``x`` this computes the gradient ``g`` of
    ``net(x)[:, whichOut]`` with respect to the 28*28 input values, plus the
    offset ``net(x)[:, whichOut] - <x, g>``.  Together they describe the affine
    map ``x' -> <g, x'> + offset`` that matches the selected output at ``x``.

    Args:
        net: torch.nn.Module whose output is indexed as ``output[:, whichOut]``.
            Its train/eval mode is left untouched.
            NOTE(review): callers presumably want ``net.eval()`` beforehand.
        data_loader: loader with a ``dataset`` attribute that yields
            ``(data, target)`` pairs.  Every batch must hold exactly one
            sample of 28*28 floating-point values: rows are indexed by the
            batch number and autograd needs a single-element output.
        whichOut: index of the network output to differentiate.

    Returns:
        ``(outputs, grads)`` as NumPy float64 arrays.  ``outputs`` has shape
        ``(N,)`` and holds the selected output per sample.  ``grads`` has
        shape ``(N, 28*28 + 1)``: the gradient followed by the offset.
        ``N`` is ``len(data_loader.dataset)``.
    """
    numSamples = len(data_loader.dataset)
    numInputs = 28 * 28
    outputs = np.zeros(numSamples)
    grads = np.zeros((numSamples, numInputs + 1))

    # Feed the data to wherever the model's parameters live, so the function
    # does not depend on a notebook-level global.  A parameter-free model
    # leaves CPU data on the CPU.
    netDevice = next((p.device for p in net.parameters()), torch.device("cpu"))

    for batch_idx, (data, _) in enumerate(data_loader):
        torch.cuda.empty_cache()

        # Track gradients with respect to the input itself.
        data.requires_grad_()
        net.zero_grad()

        output = net(data.to(netDevice))
        selected = output[:, whichOut]

        # Bring everything back to the CPU as plain, graph-free tensors.
        g = torch.autograd.grad(selected, data)[0].detach().cpu().reshape(-1, numInputs)
        x = data.detach().cpu().reshape(-1, numInputs)
        value = selected.detach().cpu()

        # Offset of the local affine map: f(x) - <x, g>.
        offset = value - torch.dot(torch.squeeze(x), torch.squeeze(g))

        # Store explicit Python scalars / NumPy rows, not tensors.
        outputs[batch_idx] = value.item()
        grads[batch_idx, :numInputs] = g[0].numpy()
        grads[batch_idx, numInputs] = offset.item()

    return outputs, grads

In [6]:
def test(net, testloader):
    net.eval()
    numRight = 0.0
    for inputs, labels in testloader:
        torch.cuda.empty_cache()
        output = net(inputs.to(device))
        _, pred = torch.max(output, 1)
        numRight += torch.sum(pred == labels.to(device))
    return numRight.item()/len(testloader.dataset)

In [7]:
# Load the saved Inception model object; it is passed directly to test() and
# computeRegions() below.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.  The file presumably
# holds an MnistInception instance, so that class must be defined first.
inception = torch.load('inceptionMNIST.pt')
#resnet = torch.load('resnet110CIFAR.pt')

In [8]:
# Fraction of the MNIST test split that the loaded Inception model classifies correctly.
test(inception, testloader)

  "See the documentation of nn.Upsample for details.".format(mode))


0.9908

In [9]:
# Linearise output 0 of the Inception model on the training set and save it.
# NOTE(review): the following cell, as written, loops over outputs 0-9 of the
# training set, so output 0 is computed again there.
start = time.time()
i = 0
outputs, linearRegions = computeRegions(inception, trainloader, i)
# Modify the save location to match preferred output.  Large file sizes
savePath = '/s/red/b/nobackup/data/bsattelb/linearRegions/inceptionMNISTTrainRegion' + str(i) + '.npy'
np.save(savePath, linearRegions)
print(i, 'done', time.time() - start, 'seconds')

0 done 2197.5527062416077 seconds


In [9]:
# Linearise all ten outputs of the Inception model, first on the training
# set and then on the test set, saving one .npy file per output.
# Modify the save locations to match preferred output.  Large file sizes
regionJobs = (
    (trainloader, '/s/red/b/nobackup/data/bsattelb/linearRegions/inceptionMNISTTrainRegion'),
    (testloader, '/s/red/b/nobackup/data/bsattelb/linearRegions/inceptionMNISTTestRegion'),
)

for loader, savePrefix in regionJobs:
    # The timer restarts for each data split.
    start = time.time()
    for i in range(10):
        outputs, linearRegions = computeRegions(inception, loader, i)
        np.save(savePrefix + str(i) + '.npy', linearRegions)
        print(i, 'done', time.time() - start, 'seconds')

1 done 2194.2757756710052 seconds
2 done 4390.52343583107 seconds
3 done 6585.8518624305725 seconds
4 done 8773.426416397095 seconds
5 done 10973.865422010422 seconds
6 done 13167.919077157974 seconds
7 done 15359.212124109268 seconds
8 done 17546.686780691147 seconds
9 done 19740.423147201538 seconds
0 done 367.2597146034241 seconds
1 done 733.1342158317566 seconds
2 done 1101.5757358074188 seconds
3 done 1469.7197341918945 seconds
4 done 1836.3726260662079 seconds
5 done 2203.178708791733 seconds
6 done 2569.786785364151 seconds
7 done 2933.3795988559723 seconds
8 done 3298.0682969093323 seconds
9 done 3665.441531419754 seconds


In [10]:
# Drop the notebook's reference to the Inception model, then ask PyTorch to
# release its cached, unused CUDA memory before the next model is loaded.
del inception
torch.cuda.empty_cache()

In [11]:
# Load the saved ResNet model object; it is passed directly to test() and
# computeRegions() below.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.  The file presumably
# holds an MnistResNet instance, so that class must be defined first.
resnet = torch.load('resnet110MNIST.pt')

In [12]:
# Fraction of the MNIST test split that the loaded ResNet model classifies correctly.
test(resnet, testloader)

0.9892

In [13]:
# Linearise all ten outputs of the ResNet model, first on the training set
# and then on the test set, saving one .npy file per output.
# Modify the save locations to match preferred output.  Large file sizes
regionJobs = (
    (trainloader, '/s/red/b/nobackup/data/bsattelb/linearRegions/resnetMNISTTrainRegion'),
    (testloader, '/s/red/b/nobackup/data/bsattelb/linearRegions/resnetMNISTTestRegion'),
)

for loader, savePrefix in regionJobs:
    # The timer restarts for each data split.
    start = time.time()
    for i in range(10):
        outputs, linearRegions = computeRegions(resnet, loader, i)
        np.save(savePrefix + str(i) + '.npy', linearRegions)
        print(i, 'done', time.time() - start, 'seconds')

0 done 3068.696019411087 seconds
1 done 6147.640231847763 seconds
2 done 9227.18725681305 seconds
3 done 12301.23743224144 seconds
4 done 15377.643580913544 seconds
5 done 18450.97530388832 seconds
6 done 21529.149015665054 seconds
7 done 24604.543788194656 seconds
8 done 27680.62859773636 seconds
9 done 30755.660396575928 seconds
0 done 512.361946105957 seconds
1 done 1025.9252994060516 seconds
2 done 1540.211095571518 seconds
3 done 2053.8930168151855 seconds
4 done 2568.3460199832916 seconds
5 done 3082.062150478363 seconds
6 done 3595.60719537735 seconds
7 done 4109.745201826096 seconds
8 done 4625.90789604187 seconds
9 done 5141.268374681473 seconds


In [14]:
# Drop the notebook's reference to the ResNet model, then ask PyTorch to
# release its cached, unused CUDA memory.
del resnet
torch.cuda.empty_cache()