In [1]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import FastUpdateNet, TeacherNet
import torch.nn.functional as F
import time
import pickle

In [2]:
# --- Experiment configuration ---------------------------------------------
# Optimisation schedule.
n_epochs = 20
learning_rate, momentum = 0.01, 0.5

# Mini-batch sizes for the training and evaluation loaders.
batch_size_train, batch_size_test = 512, 1000

# The training loops record the loss once every `log_interval` batches.
log_interval = 10

# cuDNN is switched off for this run.
torch.backends.cudnn.enabled = False

# Seeding is currently disabled, so results differ between runs; uncomment
# the call below to fix PyTorch's RNG state.
random_seed = 1
# torch.manual_seed(random_seed)

In [3]:
# Build the MNIST loaders; only the training batch size is passed explicitly.
train_loader, test_loader = get_train_test_loader(
    'mnist',
    batch_size_train=batch_size_train,
)

In [4]:
# Histories that the train/test helpers append to.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []
# Multiples of the training-set size, from 0 up to n_epochs full passes
# (n_epochs + 1 entries in total).
test_counter = [
    epoch_idx * len(train_loader.dataset) for epoch_idx in range(n_epochs + 1)
]

In [5]:
from threading import Thread

def do(network):
  """Call ``backwardHidden()`` on the ``mNet`` attribute of ``network``.

  Used as the thread target in ``train_fastNet``; returns nothing.
  """
  middle = network.mNet
  middle.backwardHidden()

def train_fastNet(epoch, network):
  """Train ``network`` for one pass over ``train_loader``.

  For every batch: zero the gradients, run the forward pass, back-propagate
  the NLL loss, call ``network.mNet.backwardHidden()`` (through ``do`` on a
  helper thread) and finally take an optimizer step. Every ``log_interval``
  batches the current loss and the number of training samples consumed so
  far are appended to ``train_losses`` / ``train_counter``.

  Relies on the notebook-level ``train_loader``, ``optimizer``,
  ``log_interval``, ``train_losses`` and ``train_counter``.

  Args:
    epoch: 1-based epoch number; only used to offset the sample counter.
    network: model to train; must expose ``mNet.backwardHidden()``.
  """
  network.train()
  # Samples consumed before the current batch, across all epochs so far.
  samples_seen = (epoch - 1) * len(train_loader.dataset)
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    # The helper thread is joined right away, so the hidden-layer backward
    # step always finishes before the optimizer update below.
    hidden_update = Thread(target=do, args=[network])
    hidden_update.start()
    hidden_update.join()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Use the real sample count rather than assuming 64 samples per batch.
      train_counter.append(samples_seen)
    samples_seen += len(data)


def train_teacherNet(epoch, network):
  """Train ``network`` for one pass over ``train_loader`` with plain backprop.

  For every batch: zero the gradients, run the forward pass, back-propagate
  the NLL loss and take an optimizer step. Every ``log_interval`` batches the
  current loss and the number of training samples consumed so far are
  appended to ``train_losses`` / ``train_counter``.

  Relies on the notebook-level ``train_loader``, ``optimizer``,
  ``log_interval``, ``train_losses`` and ``train_counter``.

  Args:
    epoch: 1-based epoch number; only used to offset the sample counter.
    network: model to train.
  """
  network.train()
  # Samples consumed before the current batch, across all epochs so far.
  samples_seen = (epoch - 1) * len(train_loader.dataset)
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Use the real sample count rather than assuming 64 samples per batch.
      train_counter.append(samples_seen)
    samples_seen += len(data)

In [6]:
def test(network):
  """Evaluate ``network`` on ``test_loader`` and record the results.

  Appends the mean per-sample NLL loss to ``test_losses`` and the accuracy in
  percent to ``test_accuracies`` (both as plain Python floats), then prints a
  one-line summary.

  Relies on the notebook-level ``test_loader``, ``test_losses`` and
  ``test_accuracies``.

  Args:
    network: model to evaluate; its output is fed to ``F.nll_loss``.
  """
  network.eval()
  test_loss = 0.0
  correct = 0
  num_samples = len(test_loader.dataset)
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # Sum per batch (reduction='sum' replaces the deprecated
      # size_average=False) so the division below gives the per-sample mean.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.max(1, keepdim=True)[1]
      # .item() keeps the running count a Python int instead of a tensor.
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= num_samples
  accuracy = 100. * correct / num_samples
  test_losses.append(test_loss)
  test_accuracies.append(accuracy)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_samples, accuracy))

In [7]:
# tNet = torch.load('tNet.pt')
# fNet = FastUpdateNet(teacherNet=tNet)
# test(fNet)

In [8]:
# Fresh FastUpdateNet, optimised with SGD + momentum over the parameters the
# model exposes through get_parameters().
fNet = FastUpdateNet()
optimizer = optim.SGD(
    params=fNet.get_parameters(),
    lr=learning_rate,
    momentum=momentum,
)

with inv
error tensor(0.0003, grad_fn=<SumBackward0>)
0.018805503845214844

with pinv
error tensor(0.00004, grad_fn=<SumBackward0>)
0.02195906639099121

0.01661086082458496

In [9]:
# Start from empty histories so this run's curves are not mixed with
# anything recorded earlier in the session.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []

# Evaluate once before training, then after every epoch.
test(fNet)
for epoch in range(1, n_epochs + 1):
  train_fastNet(epoch, fNet)
  test(fNet)

  return F.log_softmax(o_4)



Test set: Avg. loss: 2.3045, Accuracy: 1036/10000 (10%)


Test set: Avg. loss: 2.2702, Accuracy: 2049/10000 (20%)


Test set: Avg. loss: 2.1180, Accuracy: 4517/10000 (45%)


Test set: Avg. loss: 1.1156, Accuracy: 6614/10000 (66%)


Test set: Avg. loss: 0.7082, Accuracy: 7745/10000 (77%)


Test set: Avg. loss: 0.5238, Accuracy: 8395/10000 (84%)


Test set: Avg. loss: 0.4253, Accuracy: 8754/10000 (88%)


Test set: Avg. loss: 0.3713, Accuracy: 8907/10000 (89%)


Test set: Avg. loss: 0.3352, Accuracy: 9013/10000 (90%)


Test set: Avg. loss: 0.3078, Accuracy: 9093/10000 (91%)


Test set: Avg. loss: 0.2835, Accuracy: 9150/10000 (92%)


Test set: Avg. loss: 0.2661, Accuracy: 9203/10000 (92%)


Test set: Avg. loss: 0.2461, Accuracy: 9267/10000 (93%)


Test set: Avg. loss: 0.2398, Accuracy: 9303/10000 (93%)


Test set: Avg. loss: 0.2302, Accuracy: 9307/10000 (93%)


Test set: Avg. loss: 0.2111, Accuracy: 9358/10000 (94%)


Test set: Avg. loss: 0.1966, Accuracy: 9400/10000 (94%)


Test set: Avg

In [10]:
# Baseline results without our method (batch size 512):
Test set: Avg. loss: 2.3055, Accuracy: 984/10000 (10%)


Test set: Avg. loss: 2.2689, Accuracy: 1683/10000 (17%)


Test set: Avg. loss: 2.0823, Accuracy: 3535/10000 (35%)


Test set: Avg. loss: 1.1515, Accuracy: 6768/10000 (68%)


Test set: Avg. loss: 0.6305, Accuracy: 8208/10000 (82%)


Test set: Avg. loss: 0.4735, Accuracy: 8594/10000 (86%)


Test set: Avg. loss: 0.4033, Accuracy: 8820/10000 (88%)


Test set: Avg. loss: 0.3620, Accuracy: 8917/10000 (89%)


Test set: Avg. loss: 0.3314, Accuracy: 9027/10000 (90%)


Test set: Avg. loss: 0.3024, Accuracy: 9092/10000 (91%)


Test set: Avg. loss: 0.2778, Accuracy: 9175/10000 (92%)


Test set: Avg. loss: 0.2556, Accuracy: 9236/10000 (92%)


Test set: Avg. loss: 0.2465, Accuracy: 9245/10000 (92%)


Test set: Avg. loss: 0.2244, Accuracy: 9322/10000 (93%)


Test set: Avg. loss: 0.2122, Accuracy: 9360/10000 (94%)


Test set: Avg. loss: 0.2006, Accuracy: 9403/10000 (94%)


Test set: Avg. loss: 0.1868, Accuracy: 9436/10000 (94%)


Test set: Avg. loss: 0.1802, Accuracy: 9443/10000 (94%)


Test set: Avg. loss: 0.1739, Accuracy: 9460/10000 (95%)


Test set: Avg. loss: 0.1652, Accuracy: 9514/10000 (95%)


Test set: Avg. loss: 0.1587, Accuracy: 9521/10000 (95%)


SyntaxError: invalid syntax (3724229338.py, line 2)

In [None]:
# Scratch check of what nn.Linear stores in weight.grad when backward() is
# driven by an externally supplied gradient.
a = torch.nn.Linear(2, 3)
x = torch.rand(2, 2)
y = a(x)
external_grad = torch.rand(2, 3)
y.backward(gradient=external_grad)

# Gradient autograd accumulated on the weight.
print(a.weight.grad)
# Weight times the upstream gradient, printed for comparison.
print(a.weight @ external_grad)

# Upstream gradient (transposed) times the input, printed for comparison.
print(external_grad.t() @ x)
print(external_grad)

tensor([[0.8751, 0.7492],
        [0.5353, 0.2133],
        [0.8195, 0.2682]])
tensor([[0.5097, 0.3156, 0.4840],
        [0.1640, 0.0265, 0.0229],
        [0.3872, 0.1735, 0.2506]], grad_fn=<MmBackward0>)
tensor([[0.8751, 0.7492],
        [0.5353, 0.2133],
        [0.8195, 0.2682]])
tensor([[7.4310e-01, 5.4924e-01, 8.6340e-01],
        [5.6135e-01, 4.6440e-02, 2.4873e-04]])


In [None]:
# NOTE(review): torch.load unpickles the file, so 'tNet.pt' must come from a
# trusted source and must exist before this cell is run.
tNet = torch.load('tNet.pt')
fNet = FastUpdateNet(teacherNet=tNet)

# Swap in a freshly initialised first layer and optimise only that layer.
fNet.fc1 = torch.nn.Linear(in_features=784, out_features=392)
optimizer = optim.SGD(
    params=fNet.fc1.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

# Evaluate once before training, then after each of the three epochs.
test(fNet)
for epoch in range(1, 4):
  train_teacherNet(epoch, fNet)
  test(fNet)

FileNotFoundError: [Errno 2] No such file or directory: 'tNet.pt'

In [None]:
import torch.nn as nn

In [None]:
fNet = FastUpdateNet()
# Replace mNet with a plain 392 -> 196 -> 98 -> 49 stack of Linear + ReLU
# layers, then optimise every parameter of the model with SGD + momentum.
fNet.mNet = torch.nn.Sequential(
    nn.Linear(392, 196), nn.ReLU(),
    nn.Linear(196, 98), nn.ReLU(),
    nn.Linear(98, 49), nn.ReLU(),
)
optimizer = optim.SGD(
    params=fNet.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [None]:
# Start from empty histories so this run's curves are not mixed with
# anything recorded earlier in the session.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []

# Evaluate once before training, then after every epoch.
test(fNet)
for epoch in range(1, n_epochs + 1):
  train_teacherNet(epoch, fNet)
  test(fNet)

  return F.log_softmax(o_4)



Test set: Avg. loss: 2.3096, Accuracy: 891/10000 (9%)


Test set: Avg. loss: 0.3962, Accuracy: 8793/10000 (88%)


Test set: Avg. loss: 0.2001, Accuracy: 9407/10000 (94%)


Test set: Avg. loss: 0.1858, Accuracy: 9440/10000 (94%)


Test set: Avg. loss: 0.1283, Accuracy: 9602/10000 (96%)

