In [1]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import FastUpdateNet, TeacherNet
import torch.nn.functional as F
import time
import pickle

In [2]:
# Number of training epochs used by the training driver cells.
n_epochs = 4
# Batch-size constants (the loader call in this notebook does not receive them
# explicitly -- TODO confirm the data helper uses the same values).
batch_size_train, batch_size_test = 64, 1000

# SGD hyperparameters handed to optim.SGD.
learning_rate, momentum = 0.01, 0.5

# The training loss is recorded once every `log_interval` batches.
log_interval = 10

# Reproducibility: cuDNN is switched off and torch's global RNG is seeded.
torch.backends.cudnn.enabled = False
random_seed = 1
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f9e040b4a50>

In [3]:
# Obtain the train/test loaders for the 'mnist' dataset from the project helper.
(train_loader, test_loader) = get_train_test_loader("mnist")

In [4]:
# Histories appended to by the train_* helpers and by test().
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []
# One entry per evaluation point: 0, 1, ..., n_epochs full passes over the
# training set, expressed as a number of samples.
test_counter = [
  epoch_idx * len(train_loader.dataset)
  for epoch_idx in range(n_epochs + 1)
]

In [5]:
from threading import Thread

def do(network):
  """Call ``backwardHidden()`` on the ``mNet`` attribute of ``network``.

  Used as the thread target in ``train_fastNet``; returns ``None``.
  """
  hidden_net = network.mNet
  hidden_net.backwardHidden()

def train_fastNet(epoch, network):
  """Train ``network`` for one pass over ``train_loader`` with the extra hidden backward step.

  Per batch: zero the gradients, forward pass, NLL loss, ``loss.backward()``,
  then ``do(network)`` (i.e. ``network.mNet.backwardHidden()``) is run on a
  worker thread that is joined before ``optimizer.step()``.

  Relies on notebook globals: ``train_loader``, ``optimizer``, ``do``,
  ``log_interval``, ``batch_size_train``, ``train_losses`` and
  ``train_counter``.

  Args:
    epoch: 1-based epoch number; only used to offset the logged sample count.
    network: model exposing ``train()``, ``__call__`` and an ``mNet`` attribute
      with a ``backwardHidden()`` method.
  """
  network.train()
  samples_per_epoch = len(train_loader.dataset)  # loop-invariant
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    # The worker is started and joined immediately, so the hidden backward
    # pass has finished before the optimizer step below.
    worker = Thread(target=do, args=[network])
    worker.start()
    worker.join()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Approximate number of samples seen so far; assumes the loader's batch
      # size equals batch_size_train -- TODO confirm against the data helper.
      train_counter.append(
        batch_idx * batch_size_train + (epoch - 1) * samples_per_epoch)


def train_teacherNet(epoch, network):
  """Train ``network`` for one pass over ``train_loader`` with plain backprop.

  Per batch: zero the gradients, forward pass, NLL loss, ``loss.backward()``
  and ``optimizer.step()``.

  Relies on notebook globals: ``train_loader``, ``optimizer``,
  ``log_interval``, ``batch_size_train``, ``train_losses`` and
  ``train_counter``.

  Args:
    epoch: 1-based epoch number; only used to offset the logged sample count.
    network: model exposing ``train()`` and ``__call__``.
  """
  network.train()
  samples_per_epoch = len(train_loader.dataset)  # loop-invariant
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Approximate number of samples seen so far; assumes the loader's batch
      # size equals batch_size_train -- TODO confirm against the data helper.
      train_counter.append(
        batch_idx * batch_size_train + (epoch - 1) * samples_per_epoch)

In [6]:
def test(network):
  """Evaluate ``network`` on ``test_loader`` and record the results.

  Sums the NLL loss and the number of correct top-1 predictions over all
  test batches (gradients disabled), then appends the per-sample average
  loss to ``test_losses`` and the accuracy in percent to ``test_accuracies``
  (both as Python floats) and prints a one-line summary.

  Relies on notebook globals: ``test_loader``, ``test_losses`` and
  ``test_accuracies``.

  Args:
    network: model exposing ``eval()`` and ``__call__``; its output is passed
      to ``F.nll_loss``, so it is expected to be log-probabilities.
  """
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # reduction='sum' is the supported spelling of the deprecated
      # size_average=False: add up the per-sample losses of the batch.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      # Index of the largest entry along dim 1, kept as a column.
      pred = output.max(1, keepdim=True)[1]
      # .item() keeps the running count a plain int instead of a tensor.
      correct += pred.eq(target.view_as(pred)).sum().item()
  num_samples = len(test_loader.dataset)
  test_loss /= num_samples
  accuracy = 100. * correct / num_samples
  test_losses.append(test_loss)
  test_accuracies.append(accuracy)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_samples, accuracy))

In [7]:
# tNet = torch.load('tNet.pt')
# fNet = FastUpdateNet(teacherNet=tNet)
# test(fNet)

In [8]:
# Fresh fast-update network; SGD optimises whatever get_parameters() returns.
fNet = FastUpdateNet()
optimizer = optim.SGD(
  fNet.get_parameters(),
  lr=learning_rate,
  momentum=momentum,
)

with inv
error tensor(0.0003, grad_fn=<SumBackward0>)
0.018805503845214844

with pinv
error tensor(0.00004, grad_fn=<SumBackward0>)
0.02195906639099121

0.01661086082458496

In [9]:
# Start this run with empty histories.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []

# Baseline evaluation first, then one evaluation after every training epoch.
test(fNet)
for epoch in range(1, 1 + n_epochs):
  train_fastNet(epoch, fNet)
  test(fNet)

  return F.log_softmax(o_4)



Test set: Avg. loss: 2.3013, Accuracy: 1388/10000 (14%)


Test set: Avg. loss: 0.5196, Accuracy: 8230/10000 (82%)


Test set: Avg. loss: 0.3500, Accuracy: 8850/10000 (88%)


Test set: Avg. loss: 0.3488, Accuracy: 8849/10000 (88%)


Test set: Avg. loss: 2.3020, Accuracy: 1013/10000 (10%)



In [10]:
# Sanity check on a small Linear layer: compare the weight gradient produced
# by autograd for an externally supplied output gradient with manual products.
a = torch.nn.Linear(2, 3)
x = torch.rand(2, 2)
y = a(x)
external_grad = torch.rand(2, 3)
y.backward(external_grad)

# Weight gradient as computed by autograd.
print(a.weight.grad)
# Product of the weight with the external gradient, shown for comparison.
print(a.weight @ external_grad)
# external_grad^T @ x, the manual form of the weight gradient.
print(external_grad.t() @ x)
print(external_grad)

tensor([[0.4140, 0.5576],
        [0.2882, 0.2776],
        [0.4426, 0.6020]])
tensor([[ 0.2217, -0.0106,  0.2459],
        [-0.2190,  0.1012, -0.2478],
        [-0.4186,  0.0808, -0.4676]], grad_fn=<MmBackward0>)
tensor([[0.4140, 0.5576],
        [0.2882, 0.2776],
        [0.4426, 0.6020]])
tensor([[0.8787, 0.9273, 0.9224],
        [0.8646, 0.0360, 0.9548]])


In [11]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load 'tNet.pt' if it comes from a trusted source.
tNet = torch.load('tNet.pt')
fNet = FastUpdateNet(teacherNet=tNet)
# Replace fc1 with a freshly initialised 784 -> 392 linear layer.
fNet.fc1 = torch.nn.Linear(784, 392)
# Only fc1's parameters are given to the optimizer; it is built after the
# replacement above so that it refers to the new layer.
optimizer = optim.SGD(fNet.fc1.parameters(), lr=learning_rate, momentum=momentum)
# Baseline evaluation before any training.
test(fNet)
# range(1, 4) runs three epochs; this is hard-coded and independent of n_epochs.
for epoch in range(1, 4):
  train_teacherNet(epoch, fNet)
  test(fNet)


Test set: Avg. loss: 2.7455, Accuracy: 1196/10000 (12%)



KeyboardInterrupt: 

In [None]:
import torch.nn as nn

In [None]:
# New network whose mNet is replaced by a plain stack of Linear + ReLU layers
# (392 -> 196 -> 98 -> 49); the optimizer then covers all of fNet.parameters().
fNet = FastUpdateNet()
fNet.mNet = torch.nn.Sequential(
  nn.Linear(392, 196), nn.ReLU(),
  nn.Linear(196, 98), nn.ReLU(),
  nn.Linear(98, 49), nn.ReLU(),
)
optimizer = optim.SGD(
  fNet.parameters(),
  lr=learning_rate,
  momentum=momentum,
)

In [None]:
# Start this run with empty histories.
train_losses, train_counter = [], []
test_losses, test_accuracies = [], []

# Baseline evaluation first, then one evaluation after every training epoch.
test(fNet)
for epoch in range(1, 1 + n_epochs):
  train_teacherNet(epoch, fNet)
  test(fNet)


Test set: Avg. loss: 2.3068, Accuracy: 872/10000 (9%)


Test set: Avg. loss: 0.3542, Accuracy: 8966/10000 (90%)


Test set: Avg. loss: 0.1977, Accuracy: 9398/10000 (94%)


Test set: Avg. loss: 0.1527, Accuracy: 9541/10000 (95%)


Test set: Avg. loss: 0.1202, Accuracy: 9646/10000 (96%)

