In [1]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import FastUpdateNet, TeacherNet
import torch.nn.functional as F
import time
import pickle

In [5]:
n_epochs = 8
# NOTE(review): batch_size_train / batch_size_test are not passed to
# get_train_test_loader below, so they may not reflect the real loader batch sizes.
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # record a training loss every `log_interval` batches

random_seed = 1
torch.backends.cudnn.enabled = False  # presumably disabled for reproducibility
# Seed torch's global RNG so weight initialisation (and loader shuffling, if the
# loaders shuffle) is reproducible under Restart & Run All.
torch.manual_seed(random_seed)

In [6]:
# Build the MNIST train/test loaders with the project helper.
# NOTE(review): batch_size_train / batch_size_test from the config cell are not
# passed here; the batch sizes are whatever get_train_test_loader chooses — confirm.
train_loader, test_loader = get_train_test_loader('mnist')

In [7]:
# Metric histories; the training and evaluation helpers append to these lists.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []
# Examples-seen positions of the epoch boundaries 0..n_epochs (x-axis for test metrics).
test_counter = [len(train_loader.dataset) * k for k in range(n_epochs + 1)]

In [8]:
from threading import Thread

def do(network):
  """Call `backwardHidden()` on the network's `mNet` submodule.

  Helper used by `train_fastNet` after `loss.backward()`; presumably the extra
  hidden-layer update defined in model.network (TODO confirm). Returns None.
  """
  hidden = network.mNet
  hidden.backwardHidden()

def train_fastNet(epoch, network):
  """Train a FastUpdateNet-style `network` for one epoch.

  Per batch: forward pass, NLL loss, `loss.backward()`, then the extra
  `network.mNet.backwardHidden()` step (via `do`), then `optimizer.step()`.
  Reads the notebook-level `train_loader`, `optimizer` and `log_interval`.
  Every `log_interval` batches it appends the batch loss to `train_losses` and
  the number of training examples seen so far (across epochs) to `train_counter`.

  Args:
    epoch: 1-based epoch index; offsets the logged example counts.
    network: model whose output is fed to `F.nll_loss` (log-probabilities
      expected) and which exposes `mNet.backwardHidden()`.
  """
  network.train()
  epoch_offset = (epoch - 1) * len(train_loader.dataset)
  samples_seen = 0  # examples consumed earlier in this epoch
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    # Run the hidden-layer update synchronously. Starting a Thread and joining it
    # immediately gave no parallelism, and an exception raised inside
    # backwardHidden() did not propagate, so optimizer.step() still ran.
    do(network)
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Count the examples actually seen instead of assuming batches of 64.
      train_counter.append(epoch_offset + samples_seen)
    samples_seen += len(data)


def train_teacherNet(epoch, network):
  """Train `network` for one epoch with plain backprop and `optimizer.step()`.

  Reads the notebook-level `train_loader`, `optimizer` and `log_interval`.
  Every `log_interval` batches it appends the batch loss to `train_losses` and
  the number of training examples seen so far (across epochs) to `train_counter`.

  Args:
    epoch: 1-based epoch index; offsets the logged example counts.
    network: model whose output is fed to `F.nll_loss` (log-probabilities expected).
  """
  network.train()
  epoch_offset = (epoch - 1) * len(train_loader.dataset)
  samples_seen = 0  # examples consumed earlier in this epoch
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Count the examples actually seen instead of assuming batches of 64.
      train_counter.append(epoch_offset + samples_seen)
    samples_seen += len(data)

In [9]:
def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  test_accuracies.append(100. * correct / len(test_loader.dataset))
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [7]:
# tNet = torch.load('tNet.pt')
# fNet = FastUpdateNet(teacherNet=tNet)
# test(fNet)

In [10]:
# Build a FastUpdateNet and an SGD optimizer over the tensors returned by its
# get_parameters() method (not .parameters(); presumably excludes weights that
# backwardHidden() updates itself — TODO confirm). `optimizer` is the global
# that train_fastNet steps.
fNet = FastUpdateNet()
optimizer = optim.SGD(fNet.get_parameters(), lr=learning_rate, momentum=momentum)

with inv
error tensor(0.0003, grad_fn=<SumBackward0>)
0.018805503845214844

with pinv
error tensor(0.00004, grad_fn=<SumBackward0>)
0.02195906639099121

0.01661086082458496

In [29]:
# Run `trials` independent training runs per model type and collect the metric
# histories. train_fastNet / train_teacherNet / test read the globals `optimizer`
# and append to the module-level metric lists, so both are rebound per trial.
net = None
optimizer = None
train_losses = []
train_counter = []
test_losses = []
test_accuracies = []
net_types = ['as_is']
results = {}
trials = 5
for nt in net_types:
  results[nt] = []
  for i in range(trials):
    print('model type', nt, 'trial', i, 'starts')
    # Per-trial seed: each run is reproducible, runs still differ from each other,
    # and trial i of every model type starts from the same seed.
    torch.manual_seed(random_seed + i)
    train_losses = []
    train_counter = []
    test_losses = []
    test_accuracies = []
    if nt == 'M':
      net = FastUpdateNet()
      optimizer = optim.SGD(net.get_parameters(), lr=learning_rate, momentum=momentum)
    else:
      # Any type other than 'M' trains the TeacherNet baseline.
      net = TeacherNet()
      optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in range(1, n_epochs + 1):
      if nt == 'M':
        train_fastNet(epoch, net)
      else:
        train_teacherNet(epoch, net)
      test(net)
    results[nt].append({'train_losses':train_losses, 'train_counter':train_counter, 'test_losses':test_losses, 'test_accuracies':test_accuracies})

# File name follows the model types actually run ('./MNIST-result-as_is.pt' for
# the current config). NOTE(review): the summary cell loads './MNIST-result.pt'.
torch.save(results, './MNIST-result-{}.pt'.format('-'.join(net_types)))

model type base trail 0 starts


  return F.log_softmax(x)



Test set: Avg. loss: 0.3878, Accuracy: 8837/10000 (88%)


Test set: Avg. loss: 0.2021, Accuracy: 9394/10000 (94%)


Test set: Avg. loss: 0.1521, Accuracy: 9545/10000 (95%)


Test set: Avg. loss: 0.1350, Accuracy: 9610/10000 (96%)


Test set: Avg. loss: 0.1055, Accuracy: 9677/10000 (97%)


Test set: Avg. loss: 0.0971, Accuracy: 9703/10000 (97%)


Test set: Avg. loss: 0.0977, Accuracy: 9697/10000 (97%)


Test set: Avg. loss: 0.0960, Accuracy: 9705/10000 (97%)

model type base trail 1 starts

Test set: Avg. loss: 0.3429, Accuracy: 9038/10000 (90%)


Test set: Avg. loss: 0.1969, Accuracy: 9434/10000 (94%)


Test set: Avg. loss: 0.1461, Accuracy: 9565/10000 (96%)


Test set: Avg. loss: 0.1184, Accuracy: 9649/10000 (96%)


Test set: Avg. loss: 0.1018, Accuracy: 9701/10000 (97%)


Test set: Avg. loss: 0.0851, Accuracy: 9733/10000 (97%)


Test set: Avg. loss: 0.0771, Accuracy: 9754/10000 (98%)


Test set: Avg. loss: 0.0749, Accuracy: 9775/10000 (98%)

model type base trail 2 starts

Test set:

  return F.log_softmax(o_4)



Test set: Avg. loss: 0.5021, Accuracy: 8555/10000 (86%)


Test set: Avg. loss: 0.2503, Accuracy: 9230/10000 (92%)


Test set: Avg. loss: 0.1642, Accuracy: 9501/10000 (95%)


Test set: Avg. loss: 0.1371, Accuracy: 9582/10000 (96%)


Test set: Avg. loss: 0.1129, Accuracy: 9644/10000 (96%)


Test set: Avg. loss: 0.1113, Accuracy: 9677/10000 (97%)


Test set: Avg. loss: 0.0950, Accuracy: 9699/10000 (97%)


Test set: Avg. loss: 0.0956, Accuracy: 9709/10000 (97%)

model type M trail 1 starts

Test set: Avg. loss: 0.4086, Accuracy: 8776/10000 (88%)


Test set: Avg. loss: 0.2364, Accuracy: 9283/10000 (93%)


Test set: Avg. loss: 0.1678, Accuracy: 9501/10000 (95%)


Test set: Avg. loss: 0.1372, Accuracy: 9589/10000 (96%)


Test set: Avg. loss: 0.1274, Accuracy: 9598/10000 (96%)


Test set: Avg. loss: 0.1057, Accuracy: 9687/10000 (97%)


Test set: Avg. loss: 0.1052, Accuracy: 9671/10000 (97%)


Test set: Avg. loss: 0.0973, Accuracy: 9705/10000 (97%)

model type M trail 2 starts

Test set: Avg. 

In [10]:
import torch

net_types = ['base', 'M']
results = torch.load('./MNIST-result.pt')
for nt in net_types:
  # Final-epoch test accuracy of every trial of this model type.
  final_accs = [run['test_accuracies'][-1] for run in results[nt]]
  mean_acc = sum(final_accs) / len(results[nt])
  print(nt, mean_acc, torch.std(torch.tensor(final_accs), dim = 0))

base tensor(97.4100) tensor(0.2647)
M tensor(97.0200) tensor(0.2088)


In [27]:
# Spot check: test accuracy after the penultimate epoch of the second-to-last 'base' trial.
results['base'][-2]['test_accuracies'][-2]

tensor(97.1500)

In [11]:
# Fine-tune only a freshly initialised first layer on top of a saved teacher:
# build a FastUpdateNet from the loaded network, replace fc1 with a new
# 784 -> 392 Linear, and optimise fc1's parameters alone.
# NOTE(review): 'tNet.pt' presumably holds a pickled TeacherNet module; torch.load
# unpickles it, so only load trusted files (PyTorch >= 2.6 also needs weights_only=False).
tNet = torch.load('tNet.pt')
fNet = FastUpdateNet(teacherNet=tNet)
fNet.fc1 = torch.nn.Linear(784, 392)
optimizer = optim.SGD(fNet.fc1.parameters(), lr=learning_rate, momentum=momentum)

# train_teacherNet / test append to these globals; start from empty lists so
# this run's curves are not mixed into the previous experiment's histories.
train_losses = []
train_counter = []
test_losses = []
test_accuracies = []

n_finetune_epochs = 3
test(fNet)  # accuracy before fine-tuning
for epoch in range(1, n_finetune_epochs + 1):
  train_teacherNet(epoch, fNet)
  test(fNet)


Test set: Avg. loss: 2.9786, Accuracy: 1234/10000 (12%)


Test set: Avg. loss: 0.2798, Accuracy: 9188/10000 (92%)



KeyboardInterrupt: 

In [None]:
import torch.nn as nn

In [None]:
# FastUpdateNet whose mNet is replaced by an ordinary ReLU MLP (392 -> 196 -> 98 -> 49).
# NOTE(review): earlier cells optimise fNet.get_parameters(); this one uses
# fNet.parameters(), presumably so the replacement mNet is trained by ordinary
# backprop too — confirm this is intended.
fNet = FastUpdateNet()
fNet.mNet = torch.nn.Sequential(nn.Linear(392, 196), nn.ReLU(), nn.Linear(196, 98), nn.ReLU(), nn.Linear(98, 49), nn.ReLU())
optimizer = optim.SGD(fNet.parameters(), lr=learning_rate, momentum=momentum)

In [None]:
# Start this run with empty metric histories; test() and train_teacherNet()
# append to these module-level lists.
train_losses, train_counter, test_losses, test_accuracies = [], [], [], []

# Evaluate once before training, then after every epoch.
test(fNet)
for epoch in range(1, n_epochs + 1):
  train_teacherNet(epoch, fNet)
  test(fNet)

  return F.log_softmax(o_4)



Test set: Avg. loss: 2.3096, Accuracy: 891/10000 (9%)


Test set: Avg. loss: 0.3962, Accuracy: 8793/10000 (88%)


Test set: Avg. loss: 0.2001, Accuracy: 9407/10000 (94%)


Test set: Avg. loss: 0.1858, Accuracy: 9440/10000 (94%)


Test set: Avg. loss: 0.1283, Accuracy: 9602/10000 (96%)

