In [1]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import TeacherNet
import torch.nn.functional as F

In [2]:
# --- Training hyperparameters ---
n_epochs = 3
learning_rate, momentum = 0.01, 0.5  # passed to optim.SGD
# NOTE(review): the two batch sizes below are not passed to
# get_train_test_loader() in this notebook — confirm they match the loaders.
batch_size_train, batch_size_test = 64, 1000
log_interval = 10  # train() records the loss every `log_interval` batches

# --- Reproducibility ---
random_seed = 1
torch.backends.cudnn.enabled = False  # presumably to avoid cuDNN non-determinism
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f3648120e30>

In [3]:
# Build the training and test data loaders with the project helper.
# NOTE(review): batch_size_train / batch_size_test from the config cell are not
# passed here, so the loaders' batch sizes presumably come from the helper's
# own defaults — confirm against data/data.py.
train_loader, test_loader = get_train_test_loader()

In [13]:
# Teacher model plus an SGD optimizer covering all of its parameters.
# train() reads this module-level `optimizer`, so rebinding the name later
# changes which parameters get updated.
network = TeacherNet()
optimizer = optim.SGD(
  network.parameters(),
  lr=learning_rate,
  momentum=momentum,
)

In [5]:
# Histories filled in by train() and test(): the loss values and the matching
# x-axis positions, expressed as number of training samples seen.
train_losses, train_counter, test_losses = [], [], []
# One evaluation before training plus one after each epoch.
test_counter = [len(train_loader.dataset) * k for k in range(n_epochs + 1)]

In [17]:
def train(epoch, network):
  """Run one training epoch of `network` over `train_loader`.

  Args:
    epoch: 1-based epoch number; only used to offset the sample counts
      appended to `train_counter`.
    network: module to optimize. It is called as `network(data)` and its
      output is fed to `F.nll_loss`, which expects log-probabilities.

  Side effects:
    Steps the module-level `optimizer` (whichever object that name is bound
    to when this function is called) and, every `log_interval` batches,
    appends the batch loss to `train_losses` and the number of training
    samples seen so far to `train_counter`.
  """
  network.train()
  dataset_size = len(train_loader.dataset)
  samples_seen = 0  # samples consumed in this epoch before the current batch
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Count the samples actually seen instead of assuming a batch size of
      # 64, so the x-axis stays correct for any loader configuration.
      train_counter.append(samples_seen + (epoch - 1) * dataset_size)
    samples_seen += len(data)

In [16]:
def test(network):
  """Evaluate `network` on `test_loader` and report loss and accuracy.

  Args:
    network: module to evaluate. It is called as `network(data)` and its
      output is fed to `F.nll_loss`, which expects log-probabilities.

  Side effects:
    Appends the per-sample average loss to the module-level `test_losses`
    and prints a one-line summary with the average loss and accuracy.
  """
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # reduction='sum' is the supported spelling of the deprecated
      # size_average=False: sum per-sample losses, divide by the total below.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.max(1, keepdim=True)[1]  # index of the highest score
      # .item() keeps `correct` a plain Python int rather than a tensor.
      correct += pred.eq(target.view_as(pred)).sum().item()
  num_samples = len(test_loader.dataset)
  test_loss /= num_samples
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_samples,
    100. * correct / num_samples))

In [8]:
# Evaluate once before any training to get a baseline, then alternate one
# training epoch with one evaluation pass for n_epochs epochs.
test(network)
for epoch in range(1, n_epochs + 1):
  train(epoch, network)
  test(network)

  return F.log_softmax(x)



Test set: Avg. loss: 2.3013, Accuracy: 1388/10000 (14%)


Test set: Avg. loss: 0.3622, Accuracy: 8947/10000 (89%)


Test set: Avg. loss: 0.2077, Accuracy: 9382/10000 (94%)


Test set: Avg. loss: 0.1572, Accuracy: 9524/10000 (95%)



In [9]:
# Distillation phase for the teacher network.
# Setting `distill` presumably switches TeacherNet's forward pass over to
# `distilledLayer` — confirm in model/network.py.
network.distill = True
# Rebind the module-level `optimizer` (the one train() uses) so that only the
# parameters of `distilledLayer` are updated from here on.
optimizer = optim.SGD(network.distilledLayer.parameters(), lr=learning_rate, momentum=momentum)
# NOTE(review): train()/test() append to the same train_losses / train_counter /
# test_losses lists as any earlier run, so if the first training cell was run
# before this one the histories of both phases end up mixed together.
test(network)
for epoch in range(1, n_epochs + 1):
  train(epoch, network)
  test(network)

  return F.log_softmax(x)



Test set: Avg. loss: 2.3936, Accuracy: 715/10000 (7%)


Test set: Avg. loss: 0.2006, Accuracy: 9462/10000 (95%)


Test set: Avg. loss: 0.1727, Accuracy: 9522/10000 (95%)


Test set: Avg. loss: 0.1598, Accuracy: 9548/10000 (95%)



In [74]:
class TeacherNet_MGen(torch.nn.Module):
    """Wraps a trained teacher and replaces its `distillable` block with a
    single linear map that is fitted on the fly for every input batch.

    For each forward pass the map `m` is the least-squares solution of
    `hidden @ m ~= distillable(hidden)`, computed with the pseudo-inverse of
    the batch's hidden activations. No parameters of this wrapper are trained
    by that fit; `m` is detached from the autograd graph.

    Args:
        teacherNet: object exposing `fc1`, `distillable` and `fc3` callables.
            They are reused as-is (shared with the teacher, not copied).
    """

    def __init__(self, teacherNet):
        super().__init__()
        self.fc1 = teacherNet.fc1
        self.distillable = teacherNet.distillable
        self.fc3 = teacherNet.fc3
        # Holds a reference to the entire teacher object, not a single layer.
        self.distilledLayer = teacherNet

        # Never written by forward(); kept so existing attribute access works.
        self.m = []

    def forward(self, x):
        """Return log-probabilities for batch `x`.

        `x` is flattened to (batch, features) before `fc1`, so any input whose
        trailing dimensions multiply to `fc1`'s input size is accepted.
        """
        flat = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(flat))

        # Target activations the linear map has to reproduce.
        target = self.distillable(hidden).detach()
        # m = pinv(hidden) @ target, shape (hidden_features, target_features).
        m = (torch.linalg.pinv(hidden) @ target).detach()
        approx = hidden @ m

        logits = self.fc3(approx)
        # Explicit dim avoids the implicit-dimension deprecation warning;
        # dim=1 is the class axis of the (batch, classes) logits.
        return F.log_softmax(logits, dim=1)

In [75]:

# Wrap the trained teacher; the wrapper reuses the teacher's fc1, distillable
# and fc3 modules rather than copying them.
mGen = TeacherNet_MGen(network)

In [76]:
# Evaluate the wrapper on the test set. Its forward pass refits the linear
# replacement for `distillable` on each test batch.
test(mGen)

  return F.log_softmax(o_4)



Test set: Avg. loss: 0.1623, Accuracy: 9517/10000 (95%)

