In [None]:
import torch
import torch.nn as nn
import random

import numpy as np
import matplotlib.pyplot as plt
from tools.data_setter import cifar_100_setter
from models import cifar

In [None]:
def fix_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic
    behaviour.

    Args:
        seed (int): Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one; this is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick a different convolution algorithm on each
    # run, so it has to be off for results to be repeatable.
    torch.backends.cudnn.benchmark = False

In [None]:
def _load_wrn(arch_str, checkpoint_path, device, epoch_key='199'):
    """Build a 100-class WideResNet and load its weights from a checkpoint.

    Args:
        arch_str (str): Architecture string such as ``'wrn-16-4'``; the second
            and third dash-separated fields are read as depth and widen factor.
        checkpoint_path (str): File passed to ``torch.load``. The loaded object
            is indexed with ``epoch_key`` to obtain the state dict.
        device (torch.device): Device the model is moved to.
        epoch_key (str): Key of the state dict inside the checkpoint.

    Returns:
        The model in eval mode, on ``device``, with all parameters frozen.
    """
    fields = arch_str.split('-')
    model = cifar.WideResNet(depth=int(fields[1]),
                             widen_factor=int(fields[2]),
                             num_classes=100)
    state_dict = torch.load(checkpoint_path, map_location='cpu')[epoch_key]
    model.load_state_dict(state_dict, strict=True)
    model.eval().to(device)
    # Only gradients w.r.t. the input are needed, so skip weight gradients.
    for param in model.parameters():
        param.requires_grad_(False)
    return model


def _input_gradient(model, criterion, image, label):
    """Return the gradient of ``criterion(model(image), label)`` w.r.t. ``image``."""
    loss = criterion(model(image), label)
    (grad,) = torch.autograd.grad(loss, image)
    return grad


def calc_dist(dataloaders, teacher_str, student_str, mode, device, seed):
    """Measure how far apart teacher and student input gradients are.

    For every batch of ``dataloaders['train']`` the cross-entropy loss of the
    teacher and of the student is differentiated with respect to the input
    images, and the L2 norm of the difference is recorded per sample.

    Args:
        dataloaders (dict): Must contain a ``'train'`` iterable yielding
            batches whose first two items are images and labels.
        teacher_str (str): Teacher architecture string, e.g. ``'wrn-28-4'``.
        student_str (str): Student architecture string, e.g. ``'wrn-16-4'``.
        mode (str): ``'original'`` loads the student checkpoint trained
            without a teacher; ``'distill'`` loads the student checkpoint
            stored under the teacher's directory.
        device (str or torch.device): Device used for both models.
        seed (int): Seed passed to ``fix_seed``.

    Returns:
        list[float]: One distance per training sample, in iteration order.
        Values can differ from a per-sample ``torch.norm`` at floating-point
        rounding level because the norm is reduced batch-wise.

    Raises:
        ValueError: If ``mode`` is neither ``'original'`` nor ``'distill'``.
    """
    # Validate before doing any work so a typo fails fast with a clear message
    # instead of an UnboundLocalError on the checkpoint filename.
    if mode not in ('original', 'distill'):
        raise ValueError("mode must be 'original' or 'distill', got {!r}".format(mode))

    fix_seed(seed)
    device = torch.device(device)

    # Checkpoint of a network trained without a teacher; used for the teacher
    # in both modes and for the student in 'original' mode.
    baseline_template = './model_checkpoints/cifar100/None/{}/alp_0.1_T_1.0/random_highest_1.0_random_highest_1.0_seed9999_none.t1'
    teacher_filename = baseline_template.format(teacher_str)
    if mode == 'original':
        student_filename = baseline_template.format(student_str)
    else:
        student_filename = './model_checkpoints/cifar100/{}/{}/alp_1.0_T_20.0/random_highest_1.0_random_highest_1.0_seed9999_none.t1'.format(teacher_str, student_str)

    teacher = _load_wrn(teacher_str, teacher_filename, device)
    student = _load_wrn(student_str, student_filename, device)

    # NOTE(review): the default 'mean' reduction scales every sample's input
    # gradient by 1 / batch_size, so a smaller final batch yields
    # proportionally larger distances. Kept as-is to preserve the scale the
    # existing plots rely on; confirm whether the loader drops the last batch.
    criterion = nn.CrossEntropyLoss()

    distance_input_grad_list = []
    for data in dataloaders['train']:
        image = data[0].to(device=device, dtype=torch.float32).detach().requires_grad_(True)
        label = data[1].to(device=device, dtype=torch.long)

        teacher_grad = _input_gradient(teacher, criterion, image, label)
        student_grad = _input_gradient(student, criterion, image, label)

        # Reduce to one scalar per sample right away instead of keeping every
        # gradient tensor of the whole dataset in memory.
        per_sample = (teacher_grad - student_grad).flatten(start_dim=1).norm(p=2, dim=1)
        distance_input_grad_list.extend(per_sample.cpu().tolist())

    return distance_input_grad_list

---

In [None]:
# Arguments for the CIFAR-100 loader factory, collected in one place.
# NOTE(review): `root` is a machine-specific absolute path and must be adjusted
# when running elsewhere. `model_name` refers to wrn-16-2/wrn-16-4 while the
# analysis below uses other architectures; how the setter uses it when
# `teacher` and `mode` are None is not visible here - confirm.
_setter_kwargs = dict(
    teacher=None,
    mode=None,
    batch_size=128,
    root='/home/osilab7/hdd/cifar',
    model_name='cifar100/wrn-16-2/wrn-16-4/alp_1.0_T_20.0/random_highest_1.0_random_highest_1.0_seed9999_none',
    cls_acq='random',
    cls_order='highest',
    zeta=1.0,
    sample_acq='random',
    sample_order='highest',
    delta=1.0,
)
dataloaders, dataset_size = cifar_100_setter(**_setter_kwargs)

In [None]:
# Architectures to compare and the device both networks run on.
teacher_str = 'wrn-28-4'
student_str = 'wrn-16-4'
device = 'cuda:0'

# Histogram settings: `bin_num` equal-width bins spanning [0, max_val].
max_val = 0.3
bin_num = 30

# One panel per student variant: trained alone vs. distilled from the teacher.
fig, axes = plt.subplots(1, 2, figsize=(16,6))
for idx, mode in enumerate(['original', 'distill']):
    distance_input_grad_list = calc_dist(dataloaders=dataloaders,
                                         teacher_str=teacher_str,
                                         student_str=student_str,
                                         mode=mode,
                                         device=device,
                                         seed=9999)

    # NOTE: np.histogram ignores values outside `range`, so distances above
    # max_val are not drawn in any bar.
    hist, bins = np.histogram(distance_input_grad_list, bins=bin_num, range=(0, max_val))
    # Midpoint of every bin, computed from consecutive edges.
    bins_center = (bins[:-1] + bins[1:]) / 2

    axes[idx].grid()
    axes[idx].bar(bins_center, hist, width=max_val/bin_num)
    # Fixed y-range so both panels share a scale; presumably the training-set
    # size plus a little headroom - confirm.
    axes[idx].set_ylim(0, 50000+500.)
    axes[idx].set_xlabel('L2 distance between teacher and student input gradients')
    axes[idx].set_ylabel('number of samples')
    axes[idx].set_title('{} student'.format(mode))

fig.suptitle('teacher: {} & student: {}'.format(teacher_str, student_str))
fig.tight_layout()
# Leave room above the panels for the suptitle.
fig.subplots_adjust(top=0.88)
plt.show()
plt.close(fig)