In [None]:
import numpy as np
import torchvision
import torch
from torch import Tensor
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Draw matplotlib figures inline in the notebook output.
%matplotlib inline
# Display whether PyTorch can see a CUDA device (this is the cell's output value).
torch.cuda.is_available()

In [None]:
class DistInvert(nn.Module):
    """Similarity module: reciprocal of the p-norm distance between two tensors.

    ``forward(a, b)`` returns ``1 / max(a.dist(b, p), eps)``.  ``Tensor.dist``
    takes the norm over all elements of ``a - b``, so the result is a
    single-element tensor.

    Args:
        p: order of the norm passed to ``Tensor.dist`` (default 2, Euclidean).
        eps: floor applied to the distance before inverting.  Without it,
            identical inputs (e.g. two all-zero ReLU feature vectors) give
            ``inf``, which turns a downstream softmax / cross-entropy into NaN.
    """
    # Class-level defaults also serve instances unpickled from checkpoints
    # that were saved before these attributes existed.
    p = 2
    eps = 1e-12

    def __init__(self, p=2, eps=1e-12):
        super(DistInvert, self).__init__()
        self.p = p
        self.eps = eps

    def forward(self, a, b):
        return 1 / a.dist(b, self.p).clamp(min=self.eps)


d = DistInvert()
# Quick check on a 3-4-5 triangle: distance 5, so the score is 0.2.
a, b = Variable(Tensor([[8, 11]])), Variable(Tensor([[11, 15]]))
d(a, b)

In [None]:
# Convolutional trunk: two 5x5 convs, a 2x2 max-pool, then two 3x3 convs, all
# with 8 output channels (the final conv has no activation of its own).
# For a 1x28x28 input: 28 -> 24 -> 20 -> pool -> 10 -> 8 -> 6, i.e. 8x6x6 out.
early = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5),
    nn.ReLU(),
    nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3),
)

In [None]:
# Images are converted to tensors; both splits are downloaded into ./data.
transform = transforms.ToTensor()


def _mnist_loader(train, shuffle):
    """Build one MNIST split and a 4-image-batch DataLoader over it.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = torchvision.datasets.MNIST(root='./data', train=train,
                                         download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                         shuffle=shuffle, num_workers=2)
    return dataset, loader


trainset, trainloader = _mnist_loader(train=True, shuffle=True)
testset, testloader = _mnist_loader(train=False, shuffle=False)

In [None]:
# Get one random training batch.
dataiter = iter(trainloader)
# ``next(it)`` works on every PyTorch version; the ``it.next()`` method is no
# longer provided by DataLoader iterators in newer releases.
images, labels = next(dataiter)


def imshow(img):
    """Print the batch shape, then draw each item of the batch in its own figure.

    Args:
        img: batch tensor; only channel 0 of each item is drawn
            (assumes a (batch, channel, H, W) layout, as the MNIST loader yields).
    """
    npimg = img.numpy()
    print(npimg.shape)
    for i in npimg:
        plt.figure()
        plt.imshow(i[0], cmap= 'Greys')


imshow(images)

In [None]:
# Load the ten saved prototype ("platonic") arrays, files indexed 0-9.
platonic_np = list(map(lambda idx: np.load("platonic_{}_.npy".format(idx)), range(10)))

In [None]:
def imshowArr(img):
    """Open a new greyscale figure for every entry of ``img`` that is not None."""
    for arr in img:
        if arr is None:
            continue
        plt.figure()
        plt.imshow(arr, cmap='Greys')
imshowArr(platonic_np)

In [None]:
class MNISTFeatureExtractor2(nn.Module):
    """Map images to 84-dimensional non-negative feature vectors.

    Pipeline: conv trunk ``early`` -> ReLU -> flatten -> Linear(288, 120)
    -> ReLU -> Linear(120, 84) -> ReLU.  ``lin1`` expects the trunk to emit
    8 * 6 * 6 = 288 values per sample.
    """

    def __init__(self):
        super(MNISTFeatureExtractor2, self).__init__()
        # Reuses the notebook-level ``early`` Sequential object (not a copy).
        self.early = early
        self.lin1 = nn.Linear(6 * 6 * 8, 120)
        self.lin2 = nn.Linear(120, 84)

    def forward(self, x):
        trunk = F.relu(self.early(x))
        flat = trunk.view(-1, self.num_flat_features(trunk))
        return F.relu(self.lin2(F.relu(self.lin1(flat))))

    def num_flat_features(self, x):
        """Product of every dimension of ``x`` except the leading (batch) one."""
        total = 1
        for extent in list(x.size())[1:]:
            total *= extent
        return total
# One shared feature-extractor instance; NNLayerNet uses it as ``self.fe``.
feature_extractor = MNISTFeatureExtractor2()

In [None]:
class NNLayerNet(nn.Module):
    """Nearest-prototype classifier on top of a shared feature extractor.

    Every prototype image and every input image go through the same ``fe``;
    the score of input i for prototype k is ``di(fe(x)[i], fe(prototype_k)[0])``.

    Args:
        fe: feature-extractor module; defaults to the notebook-level
            ``feature_extractor``.
        di: module called as ``di(a, b)`` returning a one-element score;
            defaults to the notebook-level ``d``.
        prototypes: list of prototype input batches, one per class.  When
            None, the notebook-level ``platonic`` list is read on every call.
    """
    def __init__(self, fe=None, di=None, prototypes=None):
        super(NNLayerNet, self).__init__()
        self.fe = feature_extractor if fe is None else fe
        self.di = d if di is None else di
        self.prototypes = prototypes

    def forward(self, x):
        """Score a batch against every prototype.

        Args:
            x: input batch accepted by ``self.fe``.

        Returns:
            Tensor of shape (batch, n_prototypes) holding raw scores.  No
            softmax is applied: ``nn.CrossEntropyLoss`` applies log-softmax
            itself, and softmax here as well would squash the scores and
            flatten the gradients.  The argmax prediction is unaffected.
        """
        # getattr: instances unpickled from older checkpoints lack this attribute.
        prototypes = getattr(self, 'prototypes', None)
        if prototypes is None:
            # Notebook-level list assigned in a later cell; it must exist
            # before the first forward call.
            prototypes = platonic
        # Prototypes are re-embedded on every call because ``fe``'s weights
        # change during training.
        prototype_hidden = [self.fe(p) for p in prototypes]
        x_hidden = self.fe(x)
        columns = [
            torch.cat([self.di(x_hidden[i], ph[0]).view(1, 1)
                       for i in range(x_hidden.size(0))], 0)
            for ph in prototype_hidden
        ]
        return torch.cat(columns, 1)
# Build the prototype classifier from the shared modules; the bare ``net``
# expression displays its module tree as the cell output.
net = NNLayerNet()
net

In [None]:
# Wrap each prototype array as a one-image batch of shape (1, 1, *array.shape)
# so it can pass through the feature extractor.  NNLayerNet.forward reads this
# global, so this cell must run before ``net`` is first called.
platonic = list(map(lambda proto: Variable(Tensor(np.array([[proto]]))), platonic_np))

In [None]:
# Sanity check: expected (1, 1, H, W) if the saved prototype arrays are 2-D.
platonic[0].size()

In [None]:
# Plain SGD (learning rate 0.01, default momentum) over every parameter of ``net``;
# the bare name on the last line displays the optimizer settings.
optimizer = optim.SGD(params=net.parameters(), lr=0.01)

optimizer

In [None]:
# Multi-class loss on the network's per-class scores.  nn.CrossEntropyLoss
# applies log-softmax internally, so it expects unnormalized scores.
criterion = nn.CrossEntropyLoss()

In [None]:
# Train for N_EPOCHS epochs of at most STEPS_PER_EPOCH mini-batches each.
# Every 100 steps, print the step index and the mean loss since the start of
# the epoch; after each epoch, print the accuracy on the test set.
N_EPOCHS = 10
STEPS_PER_EPOCH = 13000

net.cpu()
for epoch in range(N_EPOCHS):
    hist = []
    # Iterating the loader directly (instead of calling ``.next()`` a fixed
    # number of times) works on every PyTorch version and stops cleanly if
    # the loader has fewer than STEPS_PER_EPOCH batches.
    for step, (inputs, target) in enumerate(trainloader):
        if step >= STEPS_PER_EPOCH:
            break
        if step % 100 == 99:
            print(step)
            print(sum(hist) / len(hist))
        inputs, target = Variable(inputs), Variable(target)
        optimizer.zero_grad()   # zero the gradient buffers
        output = net(inputs)
        loss = criterion(output, target)
        loss.backward()
        # ``loss.data`` holds one element (1-D in old PyTorch, 0-D in newer
        # versions, where ``[0]`` fails); flattening first works on both, and
        # float() stores a plain Python number instead of a tensor.
        hist.append(float(loss.data.view(-1)[0]))
        optimizer.step()

    correct = 0
    total = 0
    for images, labels in testloader:
        outputs = net(Variable(images))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # int(): newer PyTorch returns a tensor from ``.sum()``; integer-tensor
        # division then truncates (or raises) depending on the version.
        correct += int((predicted == labels).sum())

    print('Accuracy of the network on the test images: ' + str(
        float(correct) / total))

In [None]:
# Reload a previously saved model; presumably the whole pickled nn.Module
# (it is used as ``net`` below), so the class cells above must have run.
# SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
# map_location puts every storage on the CPU, so a checkpoint written from a
# GPU session also loads on a CPU-only machine.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full pickled modules; on those versions pass weights_only=False (trusted
# files only).
net = torch.load("model1_full", map_location=lambda storage, loc: storage)

In [None]:
# Evaluate ``net`` (e.g. the checkpoint loaded above) on the full test set.
net.cpu()
correct = 0
total = 0
for images, labels in testloader:
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    # int(): newer PyTorch returns a tensor from ``.sum()``; integer-tensor
    # division then truncates (or raises) depending on the version.
    correct += int((predicted == labels).sum())

print('Accuracy of the network on the test images: ' + str(
    float(correct) / total))

In [None]:
# Take the first test batch, keep its first image as a one-image batch ``x``
# for the inspection cell below, plot it, and show its label.
# ``next(...)`` replaces the ``.next()`` method, which newer PyTorch
# DataLoader iterators no longer provide.
n = next(iter(testloader))
x = Variable(n[0][0:1])
imshow(n[0][0:1])
n[1][0]

In [None]:
# For the image(s) in ``x``, recompute the per-prototype score columns that
# NNLayerNet.forward concatenates: one list of (1, 1) scores per prototype,
# each list concatenated along dim 0 on the last line.
platonic_hidden = [net.fe(proto) for proto in platonic]
x_hidden = net.fe(x)
ret = [[net.di(x_hidden[row], ph[0]).view(1, 1) for row in range(x_hidden.size(0))]
       for ph in platonic_hidden]
[torch.cat(per_proto, 0) for per_proto in ret]