In [1]:
# Import needed files and basic setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

import data_gen2
import tropical

from ipywidgets import Output
from IPython.display import display, Markdown, Latex, Math, clear_output

from sklearn.neighbors import NearestNeighbors

import math

from cvxopt import solvers, matrix

import time

%matplotlib notebook
#plt.ion()

In [2]:
# Training configuration
# -- optimisation schedule (consumed by the SGD optimizer and the epoch loop)
n_epochs = 100
learning_rate = 0.01
momentum = 0.5

# -- mini-batch sizes for the train / test DataLoaders
batch_size_train = 64
batch_size_test = 1000

# -- train() reports progress every `log_interval` batches
log_interval = 100

# Reproducibility: fix the torch RNG seed and switch the cuDNN backend off.
# NOTE(review): cuDNN is presumably disabled for determinism -- confirm intent.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f66f46517b0>

In [3]:
# Build the MNIST train / test DataLoaders.
def _mnist_loader(is_train, batch_size):
    """Return a shuffled DataLoader over one MNIST split stored under 'files/'.

    is_train   -- True selects the training split, False the test split.
    batch_size -- number of samples per mini-batch.

    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081; the dataset is downloaded if it is not already present.
    """
    preprocessing = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    split = torchvision.datasets.MNIST('files/', train=is_train, download=True,
                                       transform=preprocessing)
    return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)


train_loader = _mnist_loader(True, batch_size_train)

test_loader = _mnist_loader(False, batch_size_test)

In [4]:
class Net(nn.Module):
    """Two-layer fully connected classifier for 28x28 inputs.

    Architecture: Linear(784 -> 128) -> ReLU -> dropout -> Linear(128 -> 10)
    -> log-softmax over the 10 output scores.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10).

        `x` is flattened to (N, 784), so any input whose element count is a
        multiple of 784 is accepted.
        """
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        # Dropout (F.dropout's default probability) is only active in train mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise across the class dimension explicitly; relying on the
        # implicit dimension is deprecated and emits a UserWarning.
        return F.log_softmax(x, dim=1)

In [5]:
# Instantiate the classifier and an SGD optimizer over all of its parameters.
network = Net()
optimizer = optim.SGD(
    params=network.parameters(),
    lr=learning_rate,
    momentum=momentum,
    # NOTE(review): the L2 weight-decay coefficient reuses the learning-rate
    # value (0.01) -- confirm this coupling is intentional.
    weight_decay=learning_rate,
)

In [6]:
# Histories filled in by train(): loss, accuracy and samples-seen at each log point.
train_losses, train_acc, train_counter = [], [], []
# Histories filled in by test(): average loss and accuracy per evaluation.
test_losses, test_acc = [], []
# x-coordinates for the test metrics: samples seen after 0, 1, ..., n_epochs epochs.
test_counter = [
    epochs_done * len(train_loader.dataset)
    for epochs_done in range(n_epochs + 1)
]

In [7]:
def train(epoch):
    """Run one training epoch over `train_loader` and log progress.

    epoch -- 1-based epoch number; used for the log line and to offset the
             samples-seen counter.

    Uses the module-level `network`, `optimizer`, `train_loader`,
    `log_interval` and `batch_size_train`. Every `log_interval` batches the
    current batch loss and batch accuracy (in percent) are printed and
    appended to `train_losses` / `train_acc`, with the number of training
    samples seen so far appended to `train_counter`.
    """
    network.train()
    dataset_size = len(train_loader.dataset)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            pred = output.data.max(1, keepdim=True)[1]
            # .item() keeps the metric a plain Python number instead of a tensor,
            # so the percentage below is ordinary float arithmetic.
            correct = pred.eq(target.data.view_as(pred)).sum().item()
            # Divide by the real size of this batch rather than a hard-coded 64,
            # so the value stays correct if the batch size changes or is short.
            accuracy = 100. * correct / len(data)

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(
                  epoch, batch_idx * len(data), dataset_size,
                  100. * batch_idx / len(train_loader), loss.item(), accuracy))

            train_losses.append(loss.item())
            train_acc.append(accuracy)
            # Samples seen = full batches in this epoch + all previous epochs.
            train_counter.append((batch_idx * batch_size_train) + ((epoch - 1) * dataset_size))

In [8]:
def test():
    """Evaluate `network` on the whole of `test_loader` and log the result.

    Uses the module-level `network` and `test_loader`. Appends the average
    per-sample negative log-likelihood to `test_losses` and the accuracy (in
    percent, as a Python float) to `test_acc`, then prints a summary line.
    """
    network.eval()
    test_loss = 0
    correct = 0
    dataset_size = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum (not mean) per batch so dividing by the dataset size below gives
            # the true per-sample average; `reduction='sum'` replaces the
            # deprecated `size_average=False`.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            # .item() keeps `correct` a Python int instead of a tensor.
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    test_loss /= dataset_size
    accuracy = 100. * correct / dataset_size
    test_losses.append(test_loss)
    test_acc.append(accuracy)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
          test_loss, correct, dataset_size, accuracy))

In [9]:
# Evaluate once before any training so the histories start with the
# untrained-network baseline (test_counter has n_epochs + 1 entries).
test()
# Then alternate: one training epoch followed by one evaluation on the test set.
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

  if sys.path[0] == '':



Test set: Avg. loss: 2.2975, Accuracy: 1076/10000 (10%)


Test set: Avg. loss: 0.2884, Accuracy: 9213/10000 (92%)


Test set: Avg. loss: 0.2341, Accuracy: 9359/10000 (93%)


Test set: Avg. loss: 0.2095, Accuracy: 9426/10000 (94%)


Test set: Avg. loss: 0.1926, Accuracy: 9477/10000 (94%)


Test set: Avg. loss: 0.1845, Accuracy: 9481/10000 (94%)


Test set: Avg. loss: 0.1766, Accuracy: 9520/10000 (95%)


Test set: Avg. loss: 0.1718, Accuracy: 9542/10000 (95%)


Test set: Avg. loss: 0.1707, Accuracy: 9525/10000 (95%)


Test set: Avg. loss: 0.1656, Accuracy: 9540/10000 (95%)


Test set: Avg. loss: 0.1632, Accuracy: 9556/10000 (95%)




Test set: Avg. loss: 0.1608, Accuracy: 9570/10000 (95%)


Test set: Avg. loss: 0.1583, Accuracy: 9570/10000 (95%)


Test set: Avg. loss: 0.1611, Accuracy: 9571/10000 (95%)


Test set: Avg. loss: 0.1586, Accuracy: 9583/10000 (95%)


Test set: Avg. loss: 0.1582, Accuracy: 9575/10000 (95%)


Test set: Avg. loss: 0.1570, Accuracy: 9578/10000 (95%)


Test set: Avg. loss: 0.1566, Accuracy: 9582/10000 (95%)


Test set: Avg. loss: 0.1546, Accuracy: 9593/10000 (95%)


Test set: Avg. loss: 0.1554, Accuracy: 9582/10000 (95%)


Test set: Avg. loss: 0.1535, Accuracy: 9592/10000 (95%)


Test set: Avg. loss: 0.1538, Accuracy: 9593/10000 (95%)




Test set: Avg. loss: 0.1540, Accuracy: 9594/10000 (95%)


Test set: Avg. loss: 0.1519, Accuracy: 9605/10000 (96%)


Test set: Avg. loss: 0.1519, Accuracy: 9592/10000 (95%)


Test set: Avg. loss: 0.1512, Accuracy: 9595/10000 (95%)


Test set: Avg. loss: 0.1507, Accuracy: 9597/10000 (95%)


Test set: Avg. loss: 0.1542, Accuracy: 9590/10000 (95%)


Test set: Avg. loss: 0.1520, Accuracy: 9611/10000 (96%)


Test set: Avg. loss: 0.1505, Accuracy: 9604/10000 (96%)


Test set: Avg. loss: 0.1526, Accuracy: 9591/10000 (95%)


Test set: Avg. loss: 0.1509, Accuracy: 9596/10000 (95%)


Test set: Avg. loss: 0.1512, Accuracy: 9591/10000 (95%)




Test set: Avg. loss: 0.1517, Accuracy: 9604/10000 (96%)


Test set: Avg. loss: 0.1502, Accuracy: 9595/10000 (95%)


Test set: Avg. loss: 0.1477, Accuracy: 9602/10000 (96%)


Test set: Avg. loss: 0.1504, Accuracy: 9616/10000 (96%)


Test set: Avg. loss: 0.1490, Accuracy: 9610/10000 (96%)


Test set: Avg. loss: 0.1495, Accuracy: 9595/10000 (95%)


Test set: Avg. loss: 0.1492, Accuracy: 9602/10000 (96%)


Test set: Avg. loss: 0.1481, Accuracy: 9616/10000 (96%)


Test set: Avg. loss: 0.1486, Accuracy: 9602/10000 (96%)


Test set: Avg. loss: 0.1514, Accuracy: 9603/10000 (96%)


Test set: Avg. loss: 0.1486, Accuracy: 9612/10000 (96%)




Test set: Avg. loss: 0.1488, Accuracy: 9609/10000 (96%)


Test set: Avg. loss: 0.1497, Accuracy: 9612/10000 (96%)


Test set: Avg. loss: 0.1503, Accuracy: 9602/10000 (96%)


Test set: Avg. loss: 0.1497, Accuracy: 9604/10000 (96%)


Test set: Avg. loss: 0.1491, Accuracy: 9604/10000 (96%)


Test set: Avg. loss: 0.1492, Accuracy: 9599/10000 (95%)


Test set: Avg. loss: 0.1478, Accuracy: 9611/10000 (96%)


Test set: Avg. loss: 0.1464, Accuracy: 9619/10000 (96%)


Test set: Avg. loss: 0.1485, Accuracy: 9610/10000 (96%)


Test set: Avg. loss: 0.1471, Accuracy: 9625/10000 (96%)


Test set: Avg. loss: 0.1477, Accuracy: 9615/10000 (96%)




Test set: Avg. loss: 0.1469, Accuracy: 9618/10000 (96%)


Test set: Avg. loss: 0.1465, Accuracy: 9620/10000 (96%)


Test set: Avg. loss: 0.1466, Accuracy: 9613/10000 (96%)


Test set: Avg. loss: 0.1484, Accuracy: 9610/10000 (96%)


Test set: Avg. loss: 0.1472, Accuracy: 9606/10000 (96%)


Test set: Avg. loss: 0.1481, Accuracy: 9617/10000 (96%)


Test set: Avg. loss: 0.1478, Accuracy: 9619/10000 (96%)


Test set: Avg. loss: 0.1485, Accuracy: 9620/10000 (96%)


Test set: Avg. loss: 0.1476, Accuracy: 9617/10000 (96%)


Test set: Avg. loss: 0.1486, Accuracy: 9607/10000 (96%)




Test set: Avg. loss: 0.1472, Accuracy: 9615/10000 (96%)


Test set: Avg. loss: 0.1463, Accuracy: 9624/10000 (96%)


Test set: Avg. loss: 0.1480, Accuracy: 9617/10000 (96%)


Test set: Avg. loss: 0.1472, Accuracy: 9614/10000 (96%)


Test set: Avg. loss: 0.1470, Accuracy: 9603/10000 (96%)


Test set: Avg. loss: 0.1482, Accuracy: 9618/10000 (96%)


Test set: Avg. loss: 0.1466, Accuracy: 9612/10000 (96%)


Test set: Avg. loss: 0.1467, Accuracy: 9620/10000 (96%)


Test set: Avg. loss: 0.1452, Accuracy: 9612/10000 (96%)


Test set: Avg. loss: 0.1456, Accuracy: 9616/10000 (96%)


Test set: Avg. loss: 0.1474, Accuracy: 9615/10000 (96%)




Test set: Avg. loss: 0.1449, Accuracy: 9624/10000 (96%)


Test set: Avg. loss: 0.1461, Accuracy: 9622/10000 (96%)


Test set: Avg. loss: 0.1488, Accuracy: 9613/10000 (96%)


Test set: Avg. loss: 0.1463, Accuracy: 9615/10000 (96%)


Test set: Avg. loss: 0.1461, Accuracy: 9621/10000 (96%)


Test set: Avg. loss: 0.1507, Accuracy: 9607/10000 (96%)


Test set: Avg. loss: 0.1480, Accuracy: 9602/10000 (96%)


Test set: Avg. loss: 0.1484, Accuracy: 9616/10000 (96%)


Test set: Avg. loss: 0.1463, Accuracy: 9605/10000 (96%)


Test set: Avg. loss: 0.1471, Accuracy: 9621/10000 (96%)


Test set: Avg. loss: 0.1469, Accuracy: 9614/10000 (96%)




Test set: Avg. loss: 0.1477, Accuracy: 9608/10000 (96%)


Test set: Avg. loss: 0.1470, Accuracy: 9620/10000 (96%)


Test set: Avg. loss: 0.1462, Accuracy: 9605/10000 (96%)


Test set: Avg. loss: 0.1481, Accuracy: 9606/10000 (96%)


Test set: Avg. loss: 0.1488, Accuracy: 9611/10000 (96%)


Test set: Avg. loss: 0.1472, Accuracy: 9601/10000 (96%)


Test set: Avg. loss: 0.1460, Accuracy: 9624/10000 (96%)


Test set: Avg. loss: 0.1460, Accuracy: 9621/10000 (96%)


Test set: Avg. loss: 0.1486, Accuracy: 9614/10000 (96%)


Test set: Avg. loss: 0.1493, Accuracy: 9606/10000 (96%)


Test set: Avg. loss: 0.1464, Accuracy: 9616/10000 (96%)




Test set: Avg. loss: 0.1462, Accuracy: 9612/10000 (96%)


Test set: Avg. loss: 0.1478, Accuracy: 9619/10000 (96%)


Test set: Avg. loss: 0.1462, Accuracy: 9607/10000 (96%)



In [10]:
# Plot training vs. test accuracy against the number of training samples seen.
# The plotted series are the accuracy histories, so the legend and y-axis are
# labelled as accuracy (they previously said "loss").
fig = plt.figure()
plt.plot(train_counter, train_acc, color='blue')
plt.plot(test_counter, test_acc, 'o', color='red')
# Accuracy curves rise towards the top of the axes, so keep the legend low.
plt.legend(['Train Accuracy', 'Test Accuracy'], loc='lower right')
plt.xlabel('number of training examples seen')
plt.ylabel('accuracy (%)')

<IPython.core.display.Javascript object>

Text(0,0.5,'negative log likelihood loss')

In [11]:
# Snapshot every learnable tensor of the network as a NumPy array
# (detached from the autograd graph first).
params = [tensor.detach().numpy() for tensor in network.parameters()]

# The first four entries are bound as (A1, b1) and (A2, b2).
# NOTE(review): expected to be fc1 weight/bias followed by fc2 weight/bias,
# i.e. the order the layers are registered in Net -- confirm.
A1, b1, A2, b2 = params[:4]

In [12]:
# Persist the four extracted arrays under the keys A1, b1, A2, b2 in an
# uncompressed NumPy archive (np.savez appends '.npz' to the given name).
np.savez('parametersRegularizedMore', A1=A1, b1=b1, A2=A2, b2=b2)