In [1]:
import torch
import torchvision
import torch.optim as optim
from data.data import get_train_test_loader
from model.network import TeacherNet
import torch.nn.functional as F

In [2]:
# --- Reproducibility ---------------------------------------------------
random_seed = 1

# --- Optimisation hyperparameters ---------------------------------------
n_epochs = 3
learning_rate = 0.01
momentum = 0.5

# --- Batch sizes and logging cadence ------------------------------------
batch_size_train = 64
batch_size_test = 1000
log_interval = 10  # train() records the loss every `log_interval` batches

# cuDNN is switched off before seeding.
# NOTE(review): presumably to avoid non-deterministic cuDNN kernels — confirm.
torch.backends.cudnn.enabled = False

# Kept as the final expression of the cell so the returned Generator is
# still echoed as the cell output.
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f1f8891aeb0>

In [3]:
# Build the training and test data loaders with the project helper.
# NOTE(review): the helper is called without arguments, so batch_size_train /
# batch_size_test from the config cell are not passed in here — confirm the
# batch sizes used inside get_train_test_loader() match those constants.
train_loader, test_loader = get_train_test_loader()

In [4]:
# Instantiate the teacher model and an SGD optimizer (with momentum) that
# updates every parameter of the model.
network = TeacherNet()
optimizer = optim.SGD(
  network.parameters(),
  lr=learning_rate,
  momentum=momentum,
)

In [5]:
# Metric histories filled in by train() and test().
train_losses, train_counter, test_losses = [], [], []

# Sample counts at which the test set is evaluated: 0 (before training) and
# then once after each of the n_epochs epochs.
test_counter = [
  epoch_idx * len(train_loader.dataset)
  for epoch_idx in range(n_epochs + 1)
]

In [6]:
def train(epoch):
  """Run one training epoch over `train_loader`, updating `network` in place.

  Uses the notebook-level `network`, `optimizer` and `train_loader`. Every
  `log_interval` batches the current batch loss is appended to
  `train_losses` and the matching sample count to `train_counter`.

  Args:
    epoch: 1-based epoch number; only used to offset the sample counter.
  """
  network.train()
  samples_per_epoch = len(train_loader.dataset)
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    # nll_loss expects log-probabilities.
    # NOTE(review): assumes the network's forward ends in log_softmax — confirm.
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      train_losses.append(loss.item())
      # Samples seen so far = samples in this epoch + all previous epochs.
      # Uses the configured batch size instead of a hard-coded 64 so the
      # counter stays correct if the config constant is changed.
      # NOTE(review): assumes train_loader was built with batch_size_train.
      train_counter.append(
        (batch_idx * batch_size_train) + ((epoch - 1) * samples_per_epoch))

In [7]:
def test():
  """Evaluate `network` on `test_loader`, then record and print the metrics.

  Uses the notebook-level `network` and `test_loader`. The per-sample
  average NLL loss is appended to `test_losses`, and the average loss and
  top-1 accuracy are printed.
  """
  network.eval()
  test_loss = 0
  correct = 0
  num_samples = len(test_loader.dataset)
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # Sum (not mean) within each batch so that dividing by the dataset
      # size below yields a true per-sample average. `reduction='sum'`
      # replaces the deprecated `size_average=False`.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      # Index of the largest score per row is the predicted class; no
      # `.data` needed because gradients are already disabled here.
      pred = output.max(1, keepdim=True)[1]
      # .item() keeps `correct` a plain Python int instead of a tensor.
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= num_samples
  test_losses.append(test_loss)
  print('Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_samples,
    100. * correct / num_samples))

In [8]:
# Train the teacher network: evaluate once before any training, then train
# for n_epochs epochs, re-evaluating on the test set after each epoch.

test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

  return F.log_softmax(x)


Test set: Avg. loss: 2.3013, Accuracy: 1388/10000 (14%)

Test set: Avg. loss: 0.3622, Accuracy: 8947/10000 (89%)

Test set: Avg. loss: 0.2077, Accuracy: 9382/10000 (94%)

Test set: Avg. loss: 0.1572, Accuracy: 9524/10000 (95%)



In [9]:
# Distillation phase: set the teacher network's `distill` flag and optimize
# only the parameters of `network.distilledLayer`.
# NOTE(review): what `distill = True` changes is defined inside TeacherNet,
# which is not visible in this notebook — confirm its effect on forward().
network.distill = True
# train() reads the global `optimizer` at call time, so rebinding it here is
# what restricts the updates below to distilledLayer's parameters.
optimizer = optim.SGD(network.distilledLayer.parameters(), lr=learning_rate, momentum=momentum)
# NOTE(review): in the recorded output the test loss and accuracy are identical
# before and after the first distillation epoch, and the run was interrupted —
# verify that distilledLayer actually contributes to the output being evaluated.
test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

Test set: Avg. loss: 0.1572, Accuracy: 9524/10000 (95%)

Test set: Avg. loss: 0.1572, Accuracy: 9524/10000 (95%)



KeyboardInterrupt: 