In [2]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import TeacherNet
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Training hyperparameters shared by every experiment in this notebook.
n_epochs = 3  # number of train/evaluate rounds run by the training loops below
# NOTE(review): the two batch sizes are not passed to get_train_test_loader in
# this notebook; presumably the loaders pick their batch sizes elsewhere - confirm.
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01  # SGD learning rate
momentum = 0.5  # SGD momentum
log_interval = 10  # the training loss is recorded every `log_interval` batches

# Fixed seed for torch's RNG; cuDNN is switched off as well
# (presumably for reproducibility - confirm).
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f34186261d0>

In [4]:
# Build the train/test loaders for the chosen dataset with the project helper.
# Both loaders are module-level names read directly by the train/test functions.
ds_name = 'mnist'
train_loader, test_loader = get_train_test_loader(ds_name)

### train teacher net

In [11]:
# Fresh teacher network trained with SGD + momentum.
# `optimizer` is a module-level name that train() reads implicitly; it is
# re-bound further down for the MProp model, so cell execution order matters.
network = TeacherNet()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

In [12]:
# Histories filled in by train() and test().
train_losses = []   # training loss at every logged batch
train_counter = []  # number of training samples seen at each logged batch
test_losses = []    # average test loss per evaluation
# Sample counts matching the evaluations: one before training plus one per epoch.
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

In [13]:
def train(epoch, network):
  """Train `network` for one epoch over the module-level `train_loader`.

  Uses the module-level `optimizer` and records progress in the module-level
  `train_losses` / `train_counter` lists every `log_interval` batches.

  Args:
    epoch: 1-based epoch number; only used to offset the sample counter.
    network: model whose output is fed to `F.nll_loss` (i.e. it is expected
      to return log-probabilities).

  Returns:
    The loss tensor of the last batch, or None if the loader yielded no batch.
  """
  network.train()
  loss = None  # stays None when the loader is empty instead of being unbound
  # Samples processed before the current batch, counted across epochs. This
  # replaces the hard-coded `batch_idx * 64`, which was only right for a batch
  # size of exactly 64.
  samples_seen = (epoch - 1) * len(train_loader.dataset)
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      train_counter.append(samples_seen)
    samples_seen += len(data)
  return loss

In [14]:
def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [15]:
# Baseline evaluation of the untrained teacher, then alternate one training
# epoch with one evaluation (n_epochs + 1 evaluations in total).
test(network)
for epoch in range(1, n_epochs + 1):
  train(epoch, network)
  test(network)

  return F.log_softmax(x)



Test set: Avg. loss: 2.3036, Accuracy: 1135/10000 (11%)


Test set: Avg. loss: 0.3453, Accuracy: 9001/10000 (90%)


Test set: Avg. loss: 0.2073, Accuracy: 9397/10000 (94%)


Test set: Avg. loss: 0.1489, Accuracy: 9558/10000 (96%)



### define MProp

In [16]:
class TeacherNet_MProp(torch.nn.Module):
    """Wrap a teacher network so that its `distillable` block is bypassed by a matrix `m`.

    The forward pass produces two outputs that share `fc1` and `fc3`:

    * path a: fc1 -> (multiply by `m`) -> fc3. `m` is detached, so the loss on
      this output reaches fc3 and fc1 but not `distillable`.
    * path b: fc1 -> distillable -> fc3. The input to `distillable` is
      detached, so the loss on this output reaches fc3 and `distillable` but
      not fc1.

    Args:
        teacherNet: module exposing `fc1`, `distillable` and `fc3`; the wrapper
            re-uses (does not copy) these submodules.
    """

    def __init__(self, teacherNet):
        super(TeacherNet_MProp, self).__init__()
        self.fc1 = teacherNet.fc1
        self.distillable = teacherNet.distillable
        self.fc3 = teacherNet.fc3
        # Keeps a reference to the whole teacher (registered as a submodule).
        self.distilledLayer = teacherNet

        # Placeholder; replaced by a tensor on every forward pass.
        self.m = []

    def forward(self, x):
        """Run both paths on a batch of 28x28 inputs.

        Args:
            x: tensor that can be reshaped to (batch, 28*28).

        Returns:
            Tuple `(log_probs_a, log_probs_b)`: log-softmax over dim 1 of the
            `fc3` output for path a (through `m`) and path b (through
            `distillable`).
        """
        flat = torch.reshape(x, (x.shape[0], 28*28))
        hidden = F.relu(self.fc1(flat))

        # Path b: detach the input so this path's gradient stops before fc1.
        self.o_3b = self.distillable(hidden.detach())

        # m = pinv(hidden) @ o_3b, i.e. the least-squares solution of
        # hidden @ m ~= o_3b. It is detached, so it acts as a constant.
        hidden_pinv = torch.linalg.pinv(hidden)
        self.m = F.linear(hidden_pinv, torch.transpose(self.o_3b.detach(), 0, 1)).detach()

        # Path a: hidden @ m; the gradient flows through `hidden` into fc1.
        self.o_3a = F.linear(hidden, torch.transpose(self.m, 0, 1))
        logits_a = self.fc3(self.o_3a)
        logits_b = self.fc3(self.o_3b)
        # An explicit dim avoids the "implicit dimension" deprecation warning;
        # dim=1 is what the implicit choice resolves to for 2-D input.
        return F.log_softmax(logits_a, dim=1), F.log_softmax(logits_b, dim=1)

In [24]:
# Wrap a fresh, untrained teacher in the MProp model.
# This re-binds the module-level `optimizer`, so every training function that
# reads that global now updates mProp's parameters, not the teacher's.
mProp = TeacherNet_MProp(TeacherNet()) # init untrained model
optimizer = optim.SGD(mProp.parameters(), lr=learning_rate, momentum=momentum)

### multiprocessing work

In [45]:
# Two toy linear models whose losses are back-propagated by the threading
# experiment in the next cell.
x1 = torch.ones(8)  # input tensor
y1 = torch.zeros(10)  # expected output
W1 = torch.randn(8, 10, requires_grad=True) # weights
b1 = torch.randn(10, requires_grad=True) # bias vector
z1 = torch.matmul(x1, W1) + b1 # output
# loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss1 = torch.nn.CrossEntropyLoss()(z1,y1)

x2 = torch.ones(8)  # input tensor
y2 = torch.zeros(10)  # expected output
W2 = torch.randn(8, 10, requires_grad=True) # weights
b2 = torch.randn(10, requires_grad=True) # bias vector
# NOTE(review): z2 is computed with W1, not W2, so W2 is never used and both
# losses depend on W1. This may be deliberate (the next cell has disabled
# prints of W1.grad), but confirm it is not a copy-paste slip.
z2 = torch.matmul(x2, W1) + b2 # output
# loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss2 = torch.nn.CrossEntropyLoss()(z2,y2)

In [53]:
import torch.multiprocessing as mp
from threading import Thread
import time

def compute_back(loss):
    """Sleep for five seconds, then back-propagate `loss`, keeping its graph alive.

    Used as the thread target in the experiment below.
    """
    # Presumably the delay keeps the thread running long enough for the
    # is_alive() check below to observe it - confirm.
    time.sleep(5)
    loss.backward(retain_graph=True)

# Experiment: run the two backward passes on separate threads.
# Open question from the author: is memory shared? A multiprocessing variant
# is kept disabled here.
# p1 = mp.Process(target=compute_back, args=(loss1,))
# p2 = mp.Process(target=compute_back, args=(loss2,))
p1 = Thread(target=compute_back, args=[loss1])
print(p1)
p2 = Thread(target=compute_back, args=[loss2])
print(p2)
p1.start()
p2.start()

# Disabled sequential variant that prints W1.grad after each pass.
# p1.run()
# print(W1.grad)
# p2.run()
# print(W1.grad)
# Checked right after start(): True means p1 has not finished yet.
print(p1.is_alive())
# p1.join()
# p2.join()

# Block until both backward passes have completed.
p1.join()
p2.join()
print('test')

# Single-threaded equivalent, kept for reference.
# loss1.backward() # thread 1
# loss2.backward() # thread 2

<Thread(Thread-50 (compute_back), initial)>
<Thread(Thread-51 (compute_back), initial)>
True
test


### train using two forward passes, and split backward computation into two threads

In [25]:
from threading import Thread

def compute_back(loss):
    """Back-propagate `loss`; the thread target used by the training loop below.

    This redefines the `compute_back` of the earlier experiment cell without
    the sleep and with the default `retain_graph`.
    """
    loss.backward()  # a retain_graph=True variant was tried here and disabled

def train_threads(epoch, network):
  """Train a two-output `network` for one epoch, running each backward pass on its own thread.

  Uses the module-level `train_loader`, `optimizer` and `compute_back`, and
  records the first output's loss in the module-level `train_losses` /
  `train_counter` lists every `log_interval` batches.

  NOTE(review): the original comment stated that Python threads cannot speed
  this up because only one thread runs at a time. PyTorch's autograd notes
  describe concurrent `backward()` calls from several CPU threads, so measure
  before relying on either assumption. When both outputs share parameters (as
  in TeacherNet_MProp, where both heads go through fc3), both threads
  accumulate into the same `.grad` tensors.

  Args:
    epoch: 1-based epoch number; only used to offset the sample counter.
    network: model returning two log-probability tensors per batch.

  Returns:
    The first output's loss tensor for the last batch, or None if the loader
    yielded no batch.
  """
  network.train()
  loss1 = None  # stays None when the loader is empty instead of being unbound
  # Samples processed before the current batch, counted across epochs. This
  # replaces the hard-coded `batch_idx * 64`, which was only right for a batch
  # size of exactly 64.
  samples_seen = (epoch - 1) * len(train_loader.dataset)
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output1, output2 = network(data)

    loss1 = F.nll_loss(output1, target)
    loss2 = F.nll_loss(output2, target)

    # One thread per loss; both must finish before the optimizer step.
    workers = [Thread(target=compute_back, args=[loss1]),
               Thread(target=compute_back, args=[loss2])]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    optimizer.step()

    if batch_idx % log_interval == 0:
      train_losses.append(loss1.item())
      train_counter.append(samples_seen)
    samples_seen += len(data)
  return loss1

In [None]:
import torch.multiprocessing as mp

def compute_back(loss):
    """Back-propagate `loss`; the thread target used by the training loop below.

    Identical to the `compute_back` defined in the previous training cell;
    running this cell simply re-binds the name.
    """
    loss.backward()  # a retain_graph=True variant was tried here and disabled

def train_parallel(epoch, network):
  """Train a two-output `network` for one epoch with the two backward passes in parallel.

  The previous body was a line-for-line copy of `train_threads` (it also used
  `Thread`, not `torch.multiprocessing`), so this now delegates to it rather
  than keeping two copies in sync. Swap in a process-based implementation
  here once one exists.

  Args:
    epoch: 1-based epoch number, forwarded unchanged.
    network: model returning two log-probability tensors per batch.

  Returns:
    Whatever `train_threads` returns (the first output's last-batch loss).
  """
  return train_threads(epoch, network)

In [26]:
# Reset the histories so the MProp run does not append to the teacher's curves.
train_losses = []   # training loss at every logged batch
train_counter = []  # number of training samples seen at each logged batch
test_losses = []    # average test loss per evaluation
# Sample counts matching the evaluations: one before training plus one per epoch.
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]


In [27]:
def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output, _ = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [28]:
# Baseline evaluation of the untrained MProp model, then alternate one
# training epoch with one evaluation.
# NOTE(review): the saved output of this cell ends in an AttributeError about
# `Thread.duration`, which no visible cell references - the output looks
# stale; re-run to confirm.
test(mProp)
for epoch in range(1, n_epochs + 1):
  train_parallel(epoch, mProp)
  test(mProp)
  # break

  return F.log_softmax(o_4a), F.log_softmax(o_4b) # return two outputs to propagate the loss on two threads



Test set: Avg. loss: 2.3020, Accuracy: 1036/10000 (10%)



AttributeError: 'Thread' object has no attribute 'duration'

In [114]:
# Check which parameters the loss can reach, to ensure it is not propagated
# through the layers of the sequential (`distillable`) module.

# Known parameter tensors, keyed by a short label printed when a match is found.
# NOTE(review): `mGen` is not defined in any visible cell - presumably an
# MProp-style model left over from another session; confirm before re-running.
layers = {
    'fc1' : mGen.fc1.weight.data,
    'fc1b' : mGen.fc1.bias.data,
    'distil0' : mGen.distillable[0].weight.data,
    'distil2' : mGen.distillable[2].weight.data,
    'distil4' : mGen.distillable[4].weight.data,
    'fc3' : mGen.fc3.weight.data,
    'fc3b' : mGen.fc3.bias.data,
}
# Number of leaf tensors in the graph that match none of the entries above.
nf_ct = 0
def getBack(var_grad_fn):
    """Recursively print the autograd graph below `var_grad_fn` and label its leaf tensors.

    Every visited node is printed. For nodes that expose a `variable`
    attribute (the leaf accumulation nodes), the tensor is compared against
    the module-level `layers` dict: the matching label is printed, or
    'unknown' plus the shape if nothing matches, in which case the
    module-level counter `nf_ct` is incremented.

    Args:
        var_grad_fn: a `grad_fn` node, e.g. `loss.grad_fn`.
    """
    # The counter lives at module level. Without this declaration `nf_ct += 1`
    # makes the name local and raises UnboundLocalError on the first unknown
    # tensor.
    global nf_ct
    print(var_grad_fn)
    for n in var_grad_fn.next_functions:
        node = n[0]
        if not node:
            continue  # inputs that do not require grad have no node
        if not hasattr(node, 'variable'):
            # Interior node: keep walking. Testing only for the attribute
            # (rather than wrapping the whole body in try/except
            # AttributeError) keeps unrelated errors visible.
            getBack(node)
            continue
        tensor = node.variable
        print(node)
        found = False
        for k in layers.keys():
            if tensor.data.shape == layers[k].shape and tensor.data.eq(layers[k]).all():
                found = True
                print(k)
                break
        if not found:
            nf_ct += 1
            print('unknown ', tensor.data.shape)
        print()


# Walk the graph from the loss and report whether every leaf was recognised.
# NOTE(review): no visible cell defines a module-level `loss` (the training
# functions only bind it locally) - presumably left over from an interactive
# session; confirm before re-running.
getBack(loss.grad_fn)
if nf_ct < 1:
    print('Found all expected tensors in computation graph')

<NllLossBackward0 object at 0x7fd47022e7a0>
<LogSoftmaxBackward0 object at 0x7fd30030c1f0>
<AddmmBackward0 object at 0x7fd30030ca60>
<AccumulateGrad object at 0x7fd30030d780>
fc3b

<MmBackward0 object at 0x7fd30030d7e0>
<ReluBackward0 object at 0x7fd30232c820>
<AddmmBackward0 object at 0x7fd3180f5420>
<AccumulateGrad object at 0x7fd3023240a0>
fc1b

<TBackward0 object at 0x7fd302324040>
<AccumulateGrad object at 0x7fd331c54490>
fc1

<TBackward0 object at 0x7fd30030e290>
<AccumulateGrad object at 0x7fd30232c820>
fc3

Found all expected tensors in computation graph
