In [1]:
# https://nextjournal.com/gkoehler/pytorch-mnist

In [2]:
# Import needed files and basic setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

import data_gen2
import tropical

from ipywidgets import Output
from IPython.display import display, Markdown, Latex, Math, clear_output

from sklearn.neighbors import NearestNeighbors

import math

from cvxopt import solvers, matrix

import time

%matplotlib notebook
#plt.ion()

In [3]:
# Hyperparameters
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 100

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f3e880b6b90>

In [4]:
# Load training and testing sets
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [5]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

In [7]:
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

In [8]:
train_losses = []
train_acc = []
train_counter = []
test_losses = []
test_acc = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

In [9]:
def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            pred = output.data.max(1, keepdim=True)[1]
            correct = pred.eq(target.data.view_as(pred)).sum()
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(
                  epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader), loss.item(), 100. * correct / 64))
            
            train_losses.append(loss.item())
            train_acc.append(100. * correct / 64)
            train_counter.append((batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))

In [10]:
def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    test_acc.append(100. * correct / len(test_loader.dataset))
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
          test_loss, correct, len(test_loader.dataset),
          100. * correct / len(test_loader.dataset)))

In [11]:
test()
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

  if sys.path[0] == '':



Test set: Avg. loss: 2.3197, Accuracy: 921/10000 (9%)


Test set: Avg. loss: 0.2731, Accuracy: 9240/10000 (92%)


Test set: Avg. loss: 0.2078, Accuracy: 9387/10000 (93%)


Test set: Avg. loss: 0.1739, Accuracy: 9491/10000 (94%)



In [12]:
# fig = plt.figure()
# plt.plot(train_counter, train_acc, color='blue')
# plt.plot(test_counter, test_acc, 'o', color='red')
# plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
# plt.xlabel('number of training examples seen')
# plt.ylabel('negative log likelihood loss')

In [13]:
params = []
for param in network.parameters():
    params.append(param.detach().numpy())

A1 = params[0]
b1 = params[1]
A2 = params[2][0,:].reshape((1, -1))
b2 = params[3][:1]

#Fterms, Gterms = tropical.getTropCoeffs(A1, b1, A2, b2, doTime=True)

In [14]:
Ftermsfull = []
Gtermsfull = []
for i in range(0, np.size(A1, axis = 0), 10):
    Fterms, Gterms = tropical.getTropCoeffs(A1[i:i+10, :], b1[i:i+10], A2[:, i:i+10], b2)
    Ftermsfull.append(Fterms)
    Gtermsfull.append(Gterms)

In [15]:
prodF = 1
prodG = 1
for i in range(len(Ftermsfull)):
    print(len(Ftermsfull[i]), len(Gtermsfull[i]))
    prodF *= len(Ftermsfull[i])
    prodG *= len(Gtermsfull[i])
    
print(math.log(prodF, 2), math.log(prodG, 2))

32 32
8 128
32 32
128 8
64 16
32 32
16 64
16 64
16 64
32 32
4 256
8 128
16 16
57.0 71.0


In [16]:
def multiplyTrops(trop1, trop2):
    temp = trop1 + trop2[0, :]
    for i in range(1, trop2.shape[0]):
        temp = np.vstack((temp, trop1 + trop2[i, :]))
    return temp

In [48]:
solvers.options['show_progress'] = False
def computeW(i, k, lam, gam, nbrs, temp):
    _, indices = nbrs.kneighbors(temp[i:i+1, :])
    indices = indices[0,1:]
    neighbors = temp[indices, :]

    # Set up quadratic programming problem
    Qtild = np.dot(temp[indices, :], temp[indices, :].T)
    Etild = np.eye(k)
    E = gam*np.bmat([[Etild, -Etild], [-Etild, Etild]])

    A = np.ones((2*k, 1))
    A[k:, :] = -1

    G = -np.eye(2*k)
    h = np.zeros((2*k, 1))

    b = np.array([1.0]).reshape(1,1)
    Q = np.bmat([[Qtild, -Qtild], [-Qtild, Qtild]]) + E
    
    c = np.dot(temp[indices, :], temp[i, :])
    c = c.reshape(-1, 1)
    c = np.bmat([[-c], [c]]) + lam
    
    # Solve quadratic programming problem
    out = solvers.qp(matrix(Q), matrix(c), matrix(G), matrix(h), matrix(A.T), matrix(b))
    
    # If there are negative weight values, it's near the convex hull
    w = np.array(out['x'])[:k] - np.array(out['x'])[k:]
    return w

def computeWs(points, k=1000, lam=1e-3, gam=1e-6):
    # Rescale points to roughly 0 to 1
    shift = np.amin(points)
    points = points - shift
    scale = np.amax(points)
    points = points/scale
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(points)
    goodIndices = []
    ws = np.zeros((points.shape[0], k))

    start = time.time()
    end = points.shape[0]
    for i in range(end):
        w = computeW(i, k, lam, gam, nbrs, points)
        ws[i, :] = w.T
        if np.sum(w < 0) > 0:
            goodIndices.append(i)
        if (i+1) % 100 == 0 or i+1 == end or i == 0:
            print('i = {}/{}\tverts={}\ttime={}'.format(i+1, end, len(goodIndices), time.time()-start))
            
    return ws, goodIndices

from joblib import Parallel, delayed
def computeWsParallel(points, k=1000, lam=1e-3, gam=1e-6):
    shift = np.amin(points)
    points = points - shift
    scale = np.amax(points)
    points = points/scale
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(points)
    goodIndices = []
    ws = np.zeros((points.shape[0], k))

    start = time.time()
    end = points.shape[0]
    for i in range(0, end, 100):
        wsT = Parallel(n_jobs=4)(delayed(computeW)(i, k, lam, gam, nbrs, points) for i in range(i, max(i+100, end)))
        for j in range(len(wsT)):
            ws[i, :] = wsT[j].T
            if np.sum(wsT[j] < 0) > 0:
                goodIndices.append(i + j)
            print('i = {}/{}\tverts={}\ttime={}'.format(i+j+1, end, len(goodIndices), time.time()-start))
            
    return ws, goodIndices

In [22]:
# from joblib import Parallel, delayed
# start = time.time()
# ws = Parallel(n_jobs=4)(delayed(computeW)(i) for i in range(16))
# print(start - time.time())

In [33]:
newFterms = []
goodIndicesFull = []
fullWs = []
for i in range(0, len(Ftermsfull)-1, 2):
    tempAdd = np.array(list(Ftermsfull[i]))
    tempAdd2 = np.array(list(Ftermsfull[i+1]))
    temp = multiplyTrops(tempAdd, tempAdd2)
    
    # There may be an odd number of terms - in that case, wrap the last in with the previous 2
    if len(Ftermsfull) - 1 <= i+2:
        tempAdd2 = np.array(list(Ftermsfull[i+1]))
        temp = multiplyTrops(temp, tempAdd2)
    
    newFterms.append(temp)
    ws, goodIndices = computeWs(temp, k=min(1000, temp.shape[0]-1))
    goodIndicesFull.append(goodIndices)
    fullWs.append(ws)
    
print([val.shape[0] for val in newFterms])
print([len(val) for val in goodIndicesFull])

i = 1/256	verts=1	time=0.13802289962768555
i = 100/256	verts=100	time=10.865399837493896
i = 200/256	verts=200	time=21.694053173065186
i = 256/256	verts=256	time=27.740741729736328
i = 1/4096	verts=1	time=3.161743640899658
i = 100/4096	verts=100	time=333.84939908981323
i = 200/4096	verts=200	time=675.5911660194397
i = 300/4096	verts=300	time=1008.7682416439056
i = 400/4096	verts=400	time=1343.7411675453186
i = 500/4096	verts=500	time=1678.43390417099
i = 600/4096	verts=600	time=2010.0175104141235
i = 700/4096	verts=700	time=2335.753485441208
i = 800/4096	verts=800	time=2655.792400598526
i = 900/4096	verts=900	time=3001.0851793289185
i = 1000/4096	verts=1000	time=3330.0499637126923
i = 1100/4096	verts=1100	time=3664.6881988048553
i = 1200/4096	verts=1200	time=4005.0956881046295
i = 1300/4096	verts=1300	time=4342.645967245102
i = 1400/4096	verts=1400	time=4685.877069711685
i = 1500/4096	verts=1500	time=5022.865285158157
i = 1600/4096	verts=1600	time=5357.232897281647
i = 1700/4096	verts=

In [None]:
temp = multiplyTrops(newFterms[0], newFterms[-1])
ws, goodIndices = computeWs(temp, k=min(1000, temp.shape[0]-1))

i = 1/65536	verts=1	time=3.182948589324951
i = 100/65536	verts=97	time=313.57209968566895
i = 200/65536	verts=193	time=626.9816176891327
i = 300/65536	verts=291	time=934.8976018428802
i = 400/65536	verts=385	time=1242.9539551734924
i = 500/65536	verts=485	time=1558.6873517036438
i = 600/65536	verts=582	time=1856.3139126300812
i = 700/65536	verts=677	time=2165.697219848633
i = 800/65536	verts=777	time=2468.375405550003
i = 900/65536	verts=871	time=2776.136768579483
i = 1000/65536	verts=969	time=3092.1326949596405
i = 1100/65536	verts=1006	time=3395.798807144165
i = 1200/65536	verts=1024	time=3698.798754930496
i = 1300/65536	verts=1034	time=3996.8488540649414
i = 1400/65536	verts=1054	time=4291.096279621124
i = 1500/65536	verts=1069	time=4595.8072056770325
i = 1600/65536	verts=1084	time=4886.637312173843
i = 1700/65536	verts=1102	time=5192.1670796871185
i = 1800/65536	verts=1113	time=5486.083515882492
i = 1900/65536	verts=1133	time=5793.051257133484
i = 2000/65536	verts=1147	time=6102.01