In [1]:
# https://nextjournal.com/gkoehler/pytorch-mnist

In [2]:
# Import needed files and basic setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

import data_gen2
import tropical

from ipywidgets import Output
from IPython.display import display, Markdown, Latex, Math, clear_output

from sklearn.neighbors import NearestNeighbors

import math

from cvxopt import solvers, matrix

import time

%matplotlib notebook
#plt.ion()

In [3]:
# Hyperparameters
n_epochs = 100
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 100

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f9baa91f7f0>

In [4]:
# Load training and testing sets
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [5]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

In [7]:
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

In [8]:
train_losses = []
train_acc = []
train_counter = []
test_losses = []
test_acc = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

In [12]:
def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            pred = output.data.max(1, keepdim=True)[1]
            correct = pred.eq(target.data.view_as(pred)).sum()
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(
                  epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader), loss.item(), 100. * correct / 64))
            
            train_losses.append(loss.item())
            train_acc.append(100. * correct / 64)
            train_counter.append((batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))

In [13]:
def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    test_acc.append(100. * correct / len(test_loader.dataset))
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
          test_loss, correct, len(test_loader.dataset),
          100. * correct / len(test_loader.dataset)))

In [14]:
test()
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

  if sys.path[0] == '':



Test set: Avg. loss: 0.3481, Accuracy: 9031/10000 (90%)


Test set: Avg. loss: 0.2346, Accuracy: 9309/10000 (93%)


Test set: Avg. loss: 0.1906, Accuracy: 9450/10000 (94%)


Test set: Avg. loss: 0.1619, Accuracy: 9516/10000 (95%)


Test set: Avg. loss: 0.1461, Accuracy: 9571/10000 (95%)


Test set: Avg. loss: 0.1338, Accuracy: 9584/10000 (95%)


Test set: Avg. loss: 0.1239, Accuracy: 9626/10000 (96%)


Test set: Avg. loss: 0.1177, Accuracy: 9646/10000 (96%)


Test set: Avg. loss: 0.1116, Accuracy: 9673/10000 (96%)


Test set: Avg. loss: 0.1072, Accuracy: 9684/10000 (96%)


Test set: Avg. loss: 0.1029, Accuracy: 9688/10000 (96%)




Test set: Avg. loss: 0.1001, Accuracy: 9691/10000 (96%)


Test set: Avg. loss: 0.0966, Accuracy: 9697/10000 (96%)


Test set: Avg. loss: 0.0949, Accuracy: 9721/10000 (97%)


Test set: Avg. loss: 0.0918, Accuracy: 9720/10000 (97%)


Test set: Avg. loss: 0.0897, Accuracy: 9724/10000 (97%)


Test set: Avg. loss: 0.0891, Accuracy: 9730/10000 (97%)


Test set: Avg. loss: 0.0867, Accuracy: 9739/10000 (97%)


Test set: Avg. loss: 0.0862, Accuracy: 9737/10000 (97%)


Test set: Avg. loss: 0.0849, Accuracy: 9749/10000 (97%)


Test set: Avg. loss: 0.0827, Accuracy: 9757/10000 (97%)


Test set: Avg. loss: 0.0832, Accuracy: 9745/10000 (97%)




Test set: Avg. loss: 0.0826, Accuracy: 9752/10000 (97%)


Test set: Avg. loss: 0.0806, Accuracy: 9746/10000 (97%)


Test set: Avg. loss: 0.0805, Accuracy: 9758/10000 (97%)


Test set: Avg. loss: 0.0790, Accuracy: 9765/10000 (97%)


Test set: Avg. loss: 0.0794, Accuracy: 9754/10000 (97%)


Test set: Avg. loss: 0.0789, Accuracy: 9757/10000 (97%)


Test set: Avg. loss: 0.0776, Accuracy: 9763/10000 (97%)


Test set: Avg. loss: 0.0778, Accuracy: 9760/10000 (97%)


Test set: Avg. loss: 0.0783, Accuracy: 9767/10000 (97%)


Test set: Avg. loss: 0.0772, Accuracy: 9764/10000 (97%)


Test set: Avg. loss: 0.0761, Accuracy: 9757/10000 (97%)




Test set: Avg. loss: 0.0748, Accuracy: 9776/10000 (97%)


Test set: Avg. loss: 0.0747, Accuracy: 9778/10000 (97%)


Test set: Avg. loss: 0.0749, Accuracy: 9770/10000 (97%)


Test set: Avg. loss: 0.0751, Accuracy: 9776/10000 (97%)


Test set: Avg. loss: 0.0747, Accuracy: 9762/10000 (97%)


Test set: Avg. loss: 0.0753, Accuracy: 9772/10000 (97%)


Test set: Avg. loss: 0.0739, Accuracy: 9774/10000 (97%)


Test set: Avg. loss: 0.0725, Accuracy: 9782/10000 (97%)


Test set: Avg. loss: 0.0734, Accuracy: 9782/10000 (97%)


Test set: Avg. loss: 0.0743, Accuracy: 9772/10000 (97%)


Test set: Avg. loss: 0.0729, Accuracy: 9788/10000 (97%)




Test set: Avg. loss: 0.0726, Accuracy: 9782/10000 (97%)


Test set: Avg. loss: 0.0724, Accuracy: 9781/10000 (97%)


Test set: Avg. loss: 0.0742, Accuracy: 9776/10000 (97%)


Test set: Avg. loss: 0.0721, Accuracy: 9785/10000 (97%)


Test set: Avg. loss: 0.0721, Accuracy: 9783/10000 (97%)


Test set: Avg. loss: 0.0736, Accuracy: 9777/10000 (97%)


Test set: Avg. loss: 0.0719, Accuracy: 9785/10000 (97%)


Test set: Avg. loss: 0.0733, Accuracy: 9781/10000 (97%)


Test set: Avg. loss: 0.0704, Accuracy: 9792/10000 (97%)


Test set: Avg. loss: 0.0730, Accuracy: 9782/10000 (97%)


Test set: Avg. loss: 0.0711, Accuracy: 9789/10000 (97%)




Test set: Avg. loss: 0.0718, Accuracy: 9787/10000 (97%)


Test set: Avg. loss: 0.0722, Accuracy: 9792/10000 (97%)


Test set: Avg. loss: 0.0724, Accuracy: 9785/10000 (97%)


Test set: Avg. loss: 0.0710, Accuracy: 9783/10000 (97%)


Test set: Avg. loss: 0.0721, Accuracy: 9780/10000 (97%)



KeyboardInterrupt: 

In [12]:
# fig = plt.figure()
# plt.plot(train_counter, train_acc, color='blue')
# plt.plot(test_counter, test_acc, 'o', color='red')
# plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
# plt.xlabel('number of training examples seen')
# plt.ylabel('negative log likelihood loss')

In [15]:
params = []
for param in network.parameters():
    params.append(param.detach().numpy())

A1 = params[0]
b1 = params[1]
A2 = params[2][0,:].reshape((1, -1))
b2 = params[3][:1]

#Fterms, Gterms = tropical.getTropCoeffs(A1, b1, A2, b2, doTime=True)

In [14]:
Ftermsfull = []
Gtermsfull = []
for i in range(0, np.size(A1, axis = 0), 10):
    Fterms, Gterms = tropical.getTropCoeffs(A1[i:i+10, :], b1[i:i+10], A2[:, i:i+10], b2)
    Ftermsfull.append(Fterms)
    Gtermsfull.append(Gterms)

In [15]:
prodF = 1
prodG = 1
for i in range(len(Ftermsfull)):
    print(len(Ftermsfull[i]), len(Gtermsfull[i]))
    prodF *= len(Ftermsfull[i])
    prodG *= len(Gtermsfull[i])
    
print(math.log(prodF, 2), math.log(prodG, 2))

32 32
8 128
32 32
128 8
64 16
32 32
16 64
16 64
16 64
32 32
4 256
8 128
16 16
57.0 71.0


In [16]:
def multiplyTrops(trop1, trop2):
    temp = trop1 + trop2[0, :]
    for i in range(1, trop2.shape[0]):
        temp = np.vstack((temp, trop1 + trop2[i, :]))
    return temp

In [48]:
solvers.options['show_progress'] = False
def computeW(i, k, lam, gam, nbrs, temp):
    _, indices = nbrs.kneighbors(temp[i:i+1, :])
    indices = indices[0,1:]
    neighbors = temp[indices, :]

    # Set up quadratic programming problem
    Qtild = np.dot(temp[indices, :], temp[indices, :].T)
    Etild = np.eye(k)
    E = gam*np.bmat([[Etild, -Etild], [-Etild, Etild]])

    A = np.ones((2*k, 1))
    A[k:, :] = -1

    G = -np.eye(2*k)
    h = np.zeros((2*k, 1))

    b = np.array([1.0]).reshape(1,1)
    Q = np.bmat([[Qtild, -Qtild], [-Qtild, Qtild]]) + E
    
    c = np.dot(temp[indices, :], temp[i, :])
    c = c.reshape(-1, 1)
    c = np.bmat([[-c], [c]]) + lam
    
    # Solve quadratic programming problem
    out = solvers.qp(matrix(Q), matrix(c), matrix(G), matrix(h), matrix(A.T), matrix(b))
    
    # If there are negative weight values, it's near the convex hull
    w = np.array(out['x'])[:k] - np.array(out['x'])[k:]
    return w

def computeWs(points, k=1000, lam=1e-3, gam=1e-6):
    # Rescale points to roughly 0 to 1
    shift = np.amin(points)
    points = points - shift
    scale = np.amax(points)
    points = points/scale
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(points)
    goodIndices = []
    ws = np.zeros((points.shape[0], k))

    start = time.time()
    end = points.shape[0]
    for i in range(end):
        w = computeW(i, k, lam, gam, nbrs, points)
        ws[i, :] = w.T
        if np.sum(w < 0) > 0:
            goodIndices.append(i)
        if (i+1) % 100 == 0 or i+1 == end or i == 0:
            print('i = {}/{}\tverts={}\ttime={}'.format(i+1, end, len(goodIndices), time.time()-start))
            
    return ws, goodIndices

from joblib import Parallel, delayed
def computeWsParallel(points, k=1000, lam=1e-3, gam=1e-6):
    shift = np.amin(points)
    points = points - shift
    scale = np.amax(points)
    points = points/scale
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(points)
    goodIndices = []
    ws = np.zeros((points.shape[0], k))

    start = time.time()
    end = points.shape[0]
    for i in range(0, end, 100):
        wsT = Parallel(n_jobs=4)(delayed(computeW)(i, k, lam, gam, nbrs, points) for i in range(i, max(i+100, end)))
        for j in range(len(wsT)):
            ws[i, :] = wsT[j].T
            if np.sum(wsT[j] < 0) > 0:
                goodIndices.append(i + j)
            print('i = {}/{}\tverts={}\ttime={}'.format(i+j+1, end, len(goodIndices), time.time()-start))
            
    return ws, goodIndices

In [22]:
# from joblib import Parallel, delayed
# start = time.time()
# ws = Parallel(n_jobs=4)(delayed(computeW)(i) for i in range(16))
# print(start - time.time())

In [57]:
newFterms = []
goodIndicesFull = []
fullWs = []
for i in range(0, len(Ftermsfull)-1, 2):
    tempAdd = np.array(list(Ftermsfull[i]))
    tempAdd2 = np.array(list(Ftermsfull[i+1]))
    temp = multiplyTrops(tempAdd, tempAdd2)
    
    # There may be an odd number of terms - in that case, wrap the last in with the previous 2
    if len(Ftermsfull) - 1 <= i+2:
        tempAdd3 = np.array(list(Ftermsfull[i+2]))
        temp = multiplyTrops(temp, tempAdd3)
    
    newFterms.append(temp)
    ws, goodIndices = computeWs(temp, k=min(786, temp.shape[0]-1))
    goodIndicesFull.append(goodIndices)
    fullWs.append(ws)
    
print([val.shape[0] for val in newFterms])
print([len(val) for val in goodIndicesFull])

i = 1/256	verts=1	time=0.1315152645111084
i = 100/256	verts=100	time=10.98108959197998
i = 200/256	verts=200	time=21.91749143600464
i = 256/256	verts=256	time=27.982080936431885
i = 1/4096	verts=1	time=1.5399019718170166
i = 100/4096	verts=100	time=155.4342246055603
i = 200/4096	verts=200	time=319.23856377601624
i = 300/4096	verts=300	time=481.1208655834198
i = 400/4096	verts=400	time=637.0060093402863
i = 500/4096	verts=500	time=791.4978244304657
i = 600/4096	verts=600	time=945.5785973072052
i = 700/4096	verts=700	time=1106.5557553768158
i = 800/4096	verts=800	time=1260.2444641590118
i = 900/4096	verts=900	time=1414.8863384723663
i = 1000/4096	verts=1000	time=1572.0981800556183
i = 1100/4096	verts=1100	time=1735.3753430843353
i = 1200/4096	verts=1200	time=1897.5189082622528
i = 1300/4096	verts=1300	time=2062.4106800556183
i = 1400/4096	verts=1400	time=2216.1803669929504
i = 1500/4096	verts=1500	time=2371.6837339401245
i = 1600/4096	verts=1600	time=2533.5381050109863
i = 1700/4096	vert

In [None]:
newGterms = []
goodIndicesFullG = []
fullWsG = []
for i in range(0, len(Gtermsfull)-1, 2):
    tempAdd = np.array(list(Gtermsfull[i]))
    tempAdd2 = np.array(list(Gtermsfull[i+1]))
    temp = multiplyTrops(tempAdd, tempAdd2)
    
    # There may be an odd number of terms - in that case, wrap the last in with the previous 2
    if len(Ftermsfull) - 1 <= i+2:
        tempAdd2 = np.array(list(Gtermsfull[i+2]))
        temp = multiplyTrops(temp, tempAdd2)
    
    newGterms.append(temp)
    ws, goodIndices = computeWs(temp, k=min(1000, temp.shape[0]-1))
    goodIndicesFullG.append(goodIndices)
    fullWsG.append(ws)
    
print([val.shape[0] for val in newGterms])
print([len(val) for val in goodIndicesFullG])

i = 1/4096	verts=1	time=3.5673720836639404
i = 100/4096	verts=100	time=365.65286684036255
i = 200/4096	verts=200	time=720.0447235107422
i = 300/4096	verts=300	time=1084.5185689926147
i = 400/4096	verts=400	time=1458.4401633739471
i = 500/4096	verts=500	time=1819.8820161819458
i = 600/4096	verts=600	time=2168.7433104515076
i = 700/4096	verts=700	time=2530.6758663654327
i = 800/4096	verts=800	time=2898.9037623405457
i = 900/4096	verts=900	time=3258.5181336402893
i = 1000/4096	verts=1000	time=3621.5378308296204
i = 1100/4096	verts=1100	time=3984.6942200660706
i = 1200/4096	verts=1200	time=4337.238765239716
i = 1300/4096	verts=1300	time=4695.868378639221
i = 1400/4096	verts=1400	time=5045.953987836838
i = 1500/4096	verts=1500	time=5405.806816577911
i = 1600/4096	verts=1600	time=5756.345547914505
i = 1700/4096	verts=1700	time=6119.721143722534
i = 1800/4096	verts=1800	time=6479.777388811111
i = 1900/4096	verts=1900	time=6840.853674173355
i = 2000/4096	verts=2000	time=7207.907062768936
i = 2

In [50]:
temp = multiplyTrops(newFterms[0], newFterms[-1])
ws, goodIndices = computeWs(temp, k=min(1000, temp.shape[0]-1))

i = 1/65536	verts=1	time=3.182948589324951
i = 100/65536	verts=97	time=313.57209968566895
i = 200/65536	verts=193	time=626.9816176891327
i = 300/65536	verts=291	time=934.8976018428802
i = 400/65536	verts=385	time=1242.9539551734924
i = 500/65536	verts=485	time=1558.6873517036438
i = 600/65536	verts=582	time=1856.3139126300812
i = 700/65536	verts=677	time=2165.697219848633
i = 800/65536	verts=777	time=2468.375405550003
i = 900/65536	verts=871	time=2776.136768579483
i = 1000/65536	verts=969	time=3092.1326949596405
i = 1100/65536	verts=1006	time=3395.798807144165
i = 1200/65536	verts=1024	time=3698.798754930496
i = 1300/65536	verts=1034	time=3996.8488540649414
i = 1400/65536	verts=1054	time=4291.096279621124
i = 1500/65536	verts=1069	time=4595.8072056770325
i = 1600/65536	verts=1084	time=4886.637312173843
i = 1700/65536	verts=1102	time=5192.1670796871185
i = 1800/65536	verts=1113	time=5486.083515882492
i = 1900/65536	verts=1133	time=5793.051257133484
i = 2000/65536	verts=1147	time=6102.01

i = 16600/65536	verts=4139	time=51732.70277619362
i = 16700/65536	verts=4159	time=52065.92827820778
i = 16800/65536	verts=4183	time=52392.397146224976
i = 16900/65536	verts=4207	time=52714.68871951103
i = 17000/65536	verts=4220	time=53046.77878499031
i = 17100/65536	verts=4235	time=53373.85959839821
i = 17200/65536	verts=4247	time=53707.44296312332
i = 17300/65536	verts=4268	time=54029.14952707291
i = 17400/65536	verts=4289	time=54349.36410212517
i = 17500/65536	verts=4320	time=54658.121846199036
i = 17600/65536	verts=4351	time=54955.162558078766
i = 17700/65536	verts=4384	time=55264.399864435196
i = 17800/65536	verts=4417	time=55569.08432006836
i = 17900/65536	verts=4447	time=55875.18064188957
i = 18000/65536	verts=4479	time=56168.32724428177
i = 18100/65536	verts=4508	time=56470.907829999924
i = 18200/65536	verts=4540	time=56776.23694586754
i = 18300/65536	verts=4571	time=57081.98174905777
i = 18400/65536	verts=4607	time=57392.35288286209
i = 18500/65536	verts=4682	time=57693.4439933

i = 33000/65536	verts=8105	time=102860.89912629128
i = 33100/65536	verts=8111	time=103185.72646927834
i = 33200/65536	verts=8117	time=103513.46689987183
i = 33300/65536	verts=8119	time=103840.80813908577
i = 33400/65536	verts=8124	time=104166.36252427101
i = 33500/65536	verts=8129	time=104497.8814690113
i = 33600/65536	verts=8131	time=104825.99758267403
i = 33700/65536	verts=8134	time=105164.43039369583
i = 33800/65536	verts=8141	time=105488.67664980888
i = 33900/65536	verts=8149	time=105796.97510075569
i = 34000/65536	verts=8157	time=106100.3727042675
i = 34100/65536	verts=8171	time=106393.37444472313
i = 34200/65536	verts=8184	time=106697.68911647797
i = 34300/65536	verts=8196	time=106999.70196080208
i = 34400/65536	verts=8207	time=107292.44409823418
i = 34500/65536	verts=8217	time=107583.57112121582
i = 34600/65536	verts=8233	time=107876.2711865902
i = 34700/65536	verts=8247	time=108184.33309721947
i = 34800/65536	verts=8258	time=108479.59762525558
i = 34900/65536	verts=8276	time=10

i = 49000/65536	verts=11961	time=152627.6687707901
i = 49100/65536	verts=11977	time=152932.67748308182
i = 49200/65536	verts=11998	time=153226.7618534565
i = 49300/65536	verts=12029	time=153518.14371919632
i = 49400/65536	verts=12060	time=153826.40931868553
i = 49500/65536	verts=12094	time=154133.5251071453
i = 49600/65536	verts=12125	time=154442.6230957508
i = 49700/65536	verts=12159	time=154747.76614904404
i = 49800/65536	verts=12188	time=155040.4963645935
i = 49900/65536	verts=12220	time=155337.50387215614
i = 50000/65536	verts=12252	time=155631.7814874649
i = 50100/65536	verts=12282	time=155927.15117931366
i = 50200/65536	verts=12314	time=156233.93265509605
i = 50300/65536	verts=12325	time=156569.09540748596
i = 50400/65536	verts=12335	time=156893.60634493828
i = 50500/65536	verts=12357	time=157220.07145881653
i = 50600/65536	verts=12374	time=157535.17072701454
i = 50700/65536	verts=12393	time=157858.68368148804
i = 50800/65536	verts=12420	time=158180.27321600914
i = 50900/65536	ve

i = 64900/65536	verts=15375	time=204925.15432286263
i = 65000/65536	verts=15472	time=205268.48317170143
i = 65100/65536	verts=15567	time=205605.23093509674
i = 65200/65536	verts=15656	time=205938.86098003387
i = 65300/65536	verts=15751	time=206275.19729804993
i = 65400/65536	verts=15848	time=206612.5327425003
i = 65500/65536	verts=15942	time=206949.18040013313
i = 65536/65536	verts=15978	time=207071.59475588799
