In [1]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import FastUpdateNet, FastUpdateNetLarge, TeacherNetLarge
import torch.nn.functional as F

import pickle
import time 

In [22]:
n_epochs = 20
# NOTE(review): the two batch sizes below are not passed to get_train_test_loader();
# confirm the loaders actually use these sizes.
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # record one training-loss sample every `log_interval` batches

random_seed = 1
torch.backends.cudnn.enabled = False
# Seed torch's global RNG once so weight initialisation (and any torch-based
# shuffling in the loaders) is reproducible across Restart & Run All. The
# multi-trial loop still gets distinct trials because the RNG keeps advancing.
torch.manual_seed(random_seed)

In [3]:
# Project helper returning the (train, test) loaders for FashionMNIST.
# NOTE(review): later cells use `.dataset` and iterate (data, target) pairs,
# so these are presumably torch DataLoaders — confirm in data/data.py.
(train_loader, test_loader) = get_train_test_loader('fashionmnist')

In [4]:
# Module-level logs that train*/test append to.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []
# Examples seen after 0, 1, ..., n_epochs full epochs (same units as
# train_counter); not used by the visible cells, presumably for plotting.
test_counter = [len(train_loader.dataset) * epoch_idx for epoch_idx in range(n_epochs + 1)]

In [5]:
from threading import Thread


def compute_back_M1(network):
  """Run the custom hidden-layer backward pass of ``network.mNet1``.

  Takes the whole network (not the sub-net) so it can be handed directly to
  ``threading.Thread(target=..., args=[network])``.
  """
  first_m = network.mNet1
  first_m.backwardHidden()


def compute_back_M2(network):
  """Same as ``compute_back_M1``, but for ``network.mNet2``."""
  second_m = network.mNet2
  second_m.backwardHidden()

def train(epoch, network):
  """Train ``network`` for one epoch with the M-network update scheme.

  After the ordinary backward pass on the NLL loss, the custom
  ``backwardHidden`` of both M sub-networks (``mNet1`` then ``mNet2``) runs
  on the main thread, then the optimizer steps.

  Reads the module-level ``train_loader``, ``optimizer`` and ``log_interval``.
  Every ``log_interval`` batches it appends the batch loss to ``train_losses``
  and the number of training examples seen so far to ``train_counter``.

  Args:
    epoch: 1-based epoch index; only used to offset the logged example count.
    network: model exposing ``mNet1.backwardHidden()`` and
      ``mNet2.backwardHidden()``. Its output is scored with ``F.nll_loss``, so
      it presumably returns log-probabilities — TODO confirm.
  """
  network.train()
  n_train = len(train_loader.dataset)
  samples_seen = 0  # examples consumed this epoch before the current batch
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    # Sequential equivalent of running compute_back_M1/M2 in two threads.
    network.mNet1.backwardHidden()
    network.mNet2.backwardHidden()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Count real examples instead of assuming a batch size of 64.
      train_counter.append(samples_seen + (epoch - 1) * n_train)
    samples_seen += len(data)


In [6]:
def test(network):
  """Evaluate ``network`` on ``test_loader`` and record the results.

  Appends the mean per-example NLL loss to the module-level ``test_losses``
  and the accuracy in percent (a Python float) to ``test_accuracies``, then
  prints a one-line summary.

  NOTE(review): ``test`` is re-defined in a later cell; the later definition
  is the one in effect once that cell runs — keep the two in sync or delete one.

  Args:
    network: model whose output is scored with ``F.nll_loss`` (presumably
      log-probabilities of shape (batch, classes) — TODO confirm).
  """
  network.eval()
  n_test = len(test_loader.dataset)
  test_loss = 0.0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # reduction='sum' replaces the deprecated size_average=False.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.argmax(dim=1)
      correct += (pred == target).sum().item()
  test_loss /= n_test
  accuracy = 100. * correct / n_test
  test_losses.append(test_loss)
  test_accuracies.append(accuracy)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, n_test, accuracy))

In [7]:
def train_teacherNet(epoch, network):
  """Train ``network`` for one epoch with plain backprop + the global optimizer.

  Reads the module-level ``train_loader``, ``optimizer`` and ``log_interval``.
  Every ``log_interval`` batches it appends the batch loss to ``train_losses``
  and the number of training examples seen so far to ``train_counter``.

  Args:
    epoch: 1-based epoch index; only used to offset the logged example count.
    network: model whose output is scored with ``F.nll_loss`` (presumably
      log-probabilities — TODO confirm).
  """
  network.train()
  n_train = len(train_loader.dataset)
  samples_seen = 0  # examples consumed this epoch before the current batch
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Count real examples instead of assuming a batch size of 64.
      train_counter.append(samples_seen + (epoch - 1) * n_train)
    samples_seen += len(data)

def test(network):
  """Evaluate ``network`` on ``test_loader`` and record the results.

  Appends the mean per-example NLL loss to the module-level ``test_losses``
  and the accuracy in percent (a Python float) to ``test_accuracies``, then
  prints a one-line summary.

  NOTE(review): ``test`` is also defined in an earlier cell; this later
  definition shadows it once this cell runs — keep the two in sync or delete one.

  Args:
    network: model whose output is scored with ``F.nll_loss`` (presumably
      log-probabilities of shape (batch, classes) — TODO confirm).
  """
  network.eval()
  n_test = len(test_loader.dataset)
  test_loss = 0.0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # reduction='sum' replaces the deprecated size_average=False.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.argmax(dim=1)
      correct += (pred == target).sum().item()
  test_loss /= n_test
  accuracy = 100. * correct / n_test
  test_losses.append(test_loss)
  test_accuracies.append(accuracy)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, n_test, accuracy))

## Test using two Ms (the hidden-layer backward passes of both Ms run sequentially on the main thread)

In [8]:

import torch.nn as nn

# Stand-alone FastUpdateNetLarge with SGD over its get_parameters().
# NOTE(review): the experiment cell below builds its own net and optimizer for
# every trial, so this instance is not what that loop trains.
fNet = FastUpdateNetLarge()
optimizer = optim.SGD(
  fNet.get_parameters(),
  lr=learning_rate,
  momentum=momentum,
)

In [23]:
# Run `trials` independent training runs per model type and collect the curves.
# 'M'  -> FastUpdateNetLarge trained with train() (optimizer over get_parameters()).
# else -> TeacherNetLarge trained with train_teacherNet() (optimizer over parameters()).
net_types = ['M']
results = {}
trials = 5
for nt in net_types:
  results[nt] = []
  for i in range(trials):
    print('model type', nt, 'trial', i, 'starts')
    # train*/test append to these module-level names; rebind fresh lists per trial.
    train_losses = []
    train_counter = []
    test_losses = []
    test_accuracies = []
    if nt == 'M':
      net = FastUpdateNetLarge()
      optimizer = optim.SGD(net.get_parameters(), lr=learning_rate, momentum=momentum)
      train_fn = train
    else:
      net = TeacherNetLarge()
      optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
      train_fn = train_teacherNet
    for epoch in range(1, n_epochs + 1):
      train_fn(epoch, net)
      test(net)
    results[nt].append({'train_losses':train_losses, 'train_counter':train_counter, 'test_losses':test_losses, 'test_accuracies':test_accuracies})

# NOTE(review): saving is disabled. The summary cell loads './FMNIST-result.pt'
# and reads both 'base' and 'M'; saving here with net_types = ['M'] alone would
# overwrite that file without 'base' and break the summary.
# torch.save(results, './FMNIST-result.pt')

model type M trail 0 starts

Test set: Avg. loss: 2.3024, Accuracy: 1000/10000 (10%)


Test set: Avg. loss: 2.3002, Accuracy: 2265/10000 (23%)


Test set: Avg. loss: 1.6839, Accuracy: 3158/10000 (32%)


Test set: Avg. loss: 0.9905, Accuracy: 6330/10000 (63%)


Test set: Avg. loss: 0.7703, Accuracy: 7459/10000 (75%)


Test set: Avg. loss: 0.6502, Accuracy: 7739/10000 (77%)


Test set: Avg. loss: 0.6509, Accuracy: 7629/10000 (76%)


Test set: Avg. loss: 0.9252, Accuracy: 6512/10000 (65%)


Test set: Avg. loss: 0.6284, Accuracy: 7799/10000 (78%)


Test set: Avg. loss: 0.6807, Accuracy: 7637/10000 (76%)


Test set: Avg. loss: 0.6868, Accuracy: 7572/10000 (76%)


Test set: Avg. loss: 0.6001, Accuracy: 7935/10000 (79%)


Test set: Avg. loss: 0.5648, Accuracy: 8004/10000 (80%)


Test set: Avg. loss: 0.5688, Accuracy: 7995/10000 (80%)


Test set: Avg. loss: 0.5815, Accuracy: 7925/10000 (79%)


Test set: Avg. loss: 0.5677, Accuracy: 8014/10000 (80%)


Test set: Avg. loss: 0.5550, Accuracy: 8021

In [2]:
import torch

# Summarise the saved runs: final-epoch test accuracy of each trial, then the
# mean and (Bessel-corrected, torch.std default) standard deviation per type.
net_types = ['base', 'M']
results = torch.load('./FMNIST-result.pt')
for nt in net_types:
    all_r = [run['test_accuracies'][-1] for run in results[nt]]
    total = sum(all_r)
    print(all_r)
    print(nt, total / len(all_r), torch.std(torch.tensor(all_r), dim = 0))

[tensor(87.6800), tensor(86.7100), tensor(85.2100), tensor(87.9900), tensor(86.4400)]
base tensor(86.8060) tensor(1.1018)
[tensor(85.4800), tensor(76.0700), tensor(85.8100), tensor(84.9400), tensor(85.9600)]
M tensor(83.6520) tensor(4.2565)


## Test without the Ms' custom backward pass (M sub-networks replaced by plain MLPs; main thread only)

In [None]:
def train_no_M(epoch, network):
  """Train ``network`` for one epoch with plain backprop only.

  Unlike ``train``, no M ``backwardHidden`` pass is run. Reads the module-level
  ``train_loader``, ``optimizer`` and ``log_interval``. Every ``log_interval``
  batches it appends the batch loss to ``train_losses`` and the number of
  training examples seen so far to ``train_counter``.

  Args:
    epoch: 1-based epoch index; only used to offset the logged example count.
    network: model whose output is scored with ``F.nll_loss`` (presumably
      log-probabilities — TODO confirm).
  """
  network.train()
  n_train = len(train_loader.dataset)
  samples_seen = 0  # examples consumed this epoch before the current batch
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Count real examples instead of assuming a batch size of 64.
      train_counter.append(samples_seen + (epoch - 1) * n_train)
    samples_seen += len(data)

In [None]:

import torch.nn as nn

# Replace both M sub-networks with plain 392 -> 196 -> 98 -> 49 ReLU MLPs and
# optimise every parameter of fNet (parameters(), not get_parameters()).
fNet = FastUpdateNetLarge()
fNet.mNet1 = torch.nn.Sequential(
  nn.Linear(392, 196), nn.ReLU(),
  nn.Linear(196, 98), nn.ReLU(),
  nn.Linear(98, 49), nn.ReLU(),
)
fNet.mNet2 = torch.nn.Sequential(
  nn.Linear(392, 196), nn.ReLU(),
  nn.Linear(196, 98), nn.ReLU(),
  nn.Linear(98, 49), nn.ReLU(),
)
optimizer = optim.SGD(fNet.parameters(), lr=learning_rate, momentum=momentum)

In [None]:
# Fresh logs for this run (train_no_M/test append to these module-level lists).
train_losses = []
train_counter = []
test_losses = []
test_accuracies = []

# NOTE: the timed span covers an initial evaluation plus n_epochs of
# train_no_M + test, i.e. it includes evaluation time, not training alone.
s = time.perf_counter()

test(fNet)
for epoch in range(1, n_epochs + 1):
  train_no_M(epoch, fNet)
  test(fNet)

e = time.perf_counter()
# Report n_epochs, not the loop variable: `epoch` is unbound (or stale from an
# earlier cell) when n_epochs == 0.
print(f"Training time for {n_epochs} epochs: {e-s} seconds")


Test set: Avg. loss: 2.3051, Accuracy: 1000/10000 (10%)


Test set: Avg. loss: 2.3021, Accuracy: 1000/10000 (10%)


Test set: Avg. loss: 2.2995, Accuracy: 1000/10000 (10%)


Test set: Avg. loss: 1.6677, Accuracy: 2402/10000 (24%)


Test set: Avg. loss: 0.8659, Accuracy: 6525/10000 (65%)


Test set: Avg. loss: 1.0468, Accuracy: 6144/10000 (61%)

Training time for 5 epochs: 49.55497955996543 seconds
