In [1]:
import argparse
import logging
import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from tqdm import tqdm

import utils
import model.net as net
import model.resnet as resnet
import model.data_loader as data_loader
from evaluate import evaluate, evaluate_kd

In [2]:
params = utils.Params('./experiments/cnn_distill_subset/params.json')
params.cuda = torch.cuda.is_available()
params.subset_percent = 0.05
random.seed(1)
torch.manual_seed(1)
if params.cuda: 
    torch.cuda.manual_seed(1)
utils.set_logger(os.path.join('./experiments/cnn_distill_subset', 'train.log'))
logging.info("Loading the datasets...")
train_dl = data_loader.fetch_subset_dataloader('train', params)
dev_dl = data_loader.fetch_subset_dataloader('dev', params)
logging.info("- done.")

Loading the datasets...


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


- done.


In [3]:
# Student model : Conv-5 network
model = net.Net(params).cuda() if params.cuda else net.Net(params)
optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
metrics = net.metrics

# Teacher model : ResNet-18 network (pretrained)
teacher_model = resnet.ResNet18()
teacher_checkpoint = './experiments/base_resnet18/best.pth.tar'
teacher_model = teacher_model.cuda() if params.cuda else teacher_model
utils.load_checkpoint(teacher_checkpoint, teacher_model)

logging.info("Experiment - model version: {}".format(params.model_version))
logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
logging.info("First, loading the teacher model and computing its outputs...")

Experiment - model version: cnn_distill
Starting training for 100 epoch(s)
First, loading the teacher model and computing its outputs...


In [4]:
def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha

    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    alpha = params.alpha
    T = params.temperature
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

In [None]:
best_val_acc = 0.0
# Teacher model은 학습하는게 아님. -> eval mode로 변경
teacher_model.eval()
scheduler = StepLR(optimizer, step_size=100, gamma=0.2)
val_acc_distill = []

for epoch in range(params.num_epochs):
    logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
    model.train()
    summ = []
    loss_avg = utils.RunningAverage()
    for i, (train_batch, labels_batch) in enumerate(train_dl):
        if params.cuda:
            train_batch, labels_batch = train_batch.cuda(), labels_batch.cuda()
            # teacher model의 output 구하기
            with torch.no_grad():
                teacher_outputs = teacher_model(train_batch)
            # student model의 output 구하기
            output_batch = model(train_batch)
            # KD Loss
            loss = loss_fn_kd(output_batch, labels_batch, teacher_outputs, params)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_avg.update(loss.item())
    scheduler.step()
    val_metrics = evaluate_kd(model, dev_dl, metrics, params)     
    val_acc = val_metrics['accuracy']
    val_acc_distill.append(val_acc)
    is_best = val_acc>=best_val_acc
    #print("Epoch {}/{} | Val-Acc {:.2f}".format(epoch+1, params.num_epochs, val_acc))

Epoch 1/100
- Eval metrics : accuracy: 0.259 ; loss: 0.000
Epoch 2/100
- Eval metrics : accuracy: 0.363 ; loss: 0.000
Epoch 3/100
- Eval metrics : accuracy: 0.400 ; loss: 0.000
Epoch 4/100
- Eval metrics : accuracy: 0.459 ; loss: 0.000
Epoch 5/100
- Eval metrics : accuracy: 0.516 ; loss: 0.000
Epoch 6/100
- Eval metrics : accuracy: 0.535 ; loss: 0.000
Epoch 7/100
- Eval metrics : accuracy: 0.560 ; loss: 0.000
Epoch 8/100
- Eval metrics : accuracy: 0.591 ; loss: 0.000
Epoch 9/100
- Eval metrics : accuracy: 0.573 ; loss: 0.000
Epoch 10/100
- Eval metrics : accuracy: 0.579 ; loss: 0.000
Epoch 11/100
- Eval metrics : accuracy: 0.610 ; loss: 0.000
Epoch 12/100
- Eval metrics : accuracy: 0.608 ; loss: 0.000
Epoch 13/100
- Eval metrics : accuracy: 0.591 ; loss: 0.000
Epoch 14/100
- Eval metrics : accuracy: 0.596 ; loss: 0.000
Epoch 15/100
- Eval metrics : accuracy: 0.586 ; loss: 0.000
Epoch 16/100
- Eval metrics : accuracy: 0.616 ; loss: 0.000
Epoch 17/100
- Eval metrics : accuracy: 0.608 ; l

In [None]:
params = utils.Params('./experiments/base_cnn_subset/params.json')
params.cuda = True
params.subset_percent = 0.05
utils.set_logger(os.path.join('./experiments/base_cnn_subset', 'train.log'))
logging.info("Loading the datasets...")
train_dl = data_loader.fetch_subset_dataloader('train', params)
dev_dl = data_loader.fetch_subset_dataloader('dev', params)
logging.info("- done.")

In [None]:
# Student model : Conv-5 network
model = net.Net(params).cuda() if params.cuda else net.Net(params)
optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
metrics = net.metrics

best_val_acc = 0.0
scheduler = StepLR(optimizer, step_size=100, gamma=0.2)
val_acc_nodistill = []

for epoch in range(params.num_epochs):
    logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
    model.train()
    summ = []
    loss_avg = utils.RunningAverage()
    for i, (train_batch, labels_batch) in enumerate(train_dl):
        if params.cuda:
            train_batch, labels_batch = train_batch.cuda(), labels_batch.cuda()
            output_batch = model(train_batch)
            # CE Loss
            loss = nn.CrossEntropyLoss()(output_batch, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_avg.update(loss.item())
    scheduler.step()
    val_metrics = evaluate_kd(model, dev_dl, metrics, params)     
    val_acc = val_metrics['accuracy']
    val_acc_nodistill.append(val_acc)
    is_best = val_acc>=best_val_acc

In [None]:
import matplotlib.pyplot as plt
plt.plot(range(1, 51), val_acc_distill)
plt.plot(range(1, 51), val_acc_nodistill)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(['Distill', 'No Distill'])