In [None]:
import os
import torch
import scipy
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import scipy.stats
from tqdm import tqdm

from tools.data_setter import cifar_100_setter
from models import cifar, imagenet

---

In [None]:
# Freeze one pass over the training split. The 'crop' loader is iterated exactly once
# and its batches are cached in memory. The replacement loader below uses
# shuffle=False, so every later pass over dataloaders['train'] sees the same images
# in the same order.
# NOTE(review): presumably needed because 'crop' mode applies random augmentation
# and/or shuffling — confirm in tools.data_setter.cifar_100_setter.
dataloaders, dataset_size = cifar_100_setter(teacher=None,
                                             mode='crop',
                                             batch_size=128,
                                             root='./data',
                                             model_name='cifar100/None/wrn-28-4/alp_0.1_T_1.0/random_0.0-1.0_random_0.0-1.0_seed9999_none_noclas.t1',
                                             per_class=False,
                                             cls_acq='random',
                                             cls_lower_qnt=0.0,
                                             cls_upper_qnt=1.0,
                                             sample_acq='random',
                                             sample_lower_qnt=0.0,
                                             sample_upper_qnt=1.0)

image_batches, target_batches = [], []
for image, target in dataloaders['train']:
    image_batches.append(image.cpu())
    target_batches.append(target.cpu())

if not image_batches:
    # Without this check the torch.cat below would fail with an unhelpful error.
    raise RuntimeError("training dataloader yielded no batches")

# Concatenate once. Calling torch.cat inside the loop re-copies everything
# accumulated so far on every batch (quadratic in dataset size).
images = torch.cat(image_batches, dim=0)
targets = torch.cat(target_batches, dim=0)

dataloaders['train'] = torch.utils.data.DataLoader(list(zip(images, targets)), batch_size=128, shuffle=False)

In [None]:
# (teacher architecture, student architecture) combinations; pick one via `pair = pairs[i]`.
pairs = [
    ('wrn-28-4', 'wrn-16-2'),  # deeper/wider teacher -> smaller student
    ('wrn-16-2', 'wrn-16-2'),  # teacher and student share the architecture
    ('wrn-16-2', 'wrn-28-4'),  # smaller teacher -> deeper/wider student
]

In [None]:
# Signed logit-margin distributions of a teacher and of the students distilled from it.
# The margin of a sample depends on whether the model classifies it correctly:
#   - correctly classified: top-1 logit - top-2 logit (>= 0);
#   - misclassified:        true-class logit - top-1 logit (<= 0).
# Densities are estimated with fixed-width histograms over the frozen training loader.
matplotlib.rcParams.update({'font.size': 16})
pair = pairs[0]  # (teacher architecture, student architecture)
t_list = ["1.0", "3.0", "5.0", "20.0"]
alpha_list = ["1.0"]

# Alternative sweep: vary alpha at a fixed temperature (also switch the legend label below).
# t_list = ["20.0"]
# alpha_list = ["0.1", "0.4", "0.7", "1.0"]

device = torch.device('cuda:0')

CKPT_ROOT = './model_checkpoints/cifar100'
CKPT_TAG = 'random_0.0-1.0_random_0.0-1.0_seed9999'
CKPT_KEY = '199'  # entry of the saved checkpoint dict that is loaded as the state_dict
HIST_RANGE = (-10, 20)
HIST_BINS = 300
N_GRID = 10000

# WideResNet hyper-parameters for each architecture name used in `pairs`.
WRN_ARCHS = {
    'wrn-28-4': dict(depth=28, widen_factor=4),
    'wrn-16-2': dict(depth=16, widen_factor=2),
}


def load_wrn(arch, path, device):
    """Build the 100-class WideResNet `arch`, load its weights from `path`, return it in eval mode on `device`.

    Raises KeyError for an architecture missing from WRN_ARCHS. An if/elif chain
    with no matching branch would instead silently reuse a previously built model.
    """
    model = cifar.WideResNet(num_classes=100, **WRN_ARCHS[arch])
    state_dict = torch.load(path, map_location=device)[CKPT_KEY]
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model.to(device)


@torch.no_grad()  # inference only: don't build an autograd graph for every batch
def collect_logits(model, loader, device):
    """Run `model` over every batch of `loader` in order; return (logits, labels) as CPU tensors."""
    logit_batches, label_batches = [], []
    for image, label in loader:
        logit_batches.append(model(image.float().to(device)).cpu())
        label_batches.append(label.long().cpu())
    # One concatenation at the end instead of a quadratic torch.cat per batch.
    return torch.cat(logit_batches, dim=0), torch.cat(label_batches, dim=0)


def signed_margin(logits, labels):
    """Per-sample signed margin: top1 - top2 if argmax == label, else true-class logit - top1."""
    top2 = torch.topk(logits, k=2, dim=1).values
    true_logit = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
    correct = labels == logits.argmax(dim=1)
    return torch.where(correct, top2[:, 0] - top2[:, 1], true_logit - top2[:, 0])


def plot_margin_pdf(ax, margins, x, label):
    """Plot the histogram-based density of `margins`, evaluated at `x`, on `ax`.

    np.histogram ignores values outside HIST_RANGE, so margins beyond that window
    are excluded from the estimated density.
    """
    hist = np.histogram(margins.double().numpy(), bins=HIST_BINS, range=HIST_RANGE)
    ax.plot(x, scipy.stats.rv_histogram(hist).pdf(x), label=label)


teacher_arch, student_arch = pair
X = np.linspace(HIST_RANGE[0], HIST_RANGE[1], N_GRID)
fig, ax = plt.subplots(figsize=(10, 7))

teacher = load_wrn(
    teacher_arch,
    '{}/None/{}/alp_0.1_T_1.0/{}_none_noclas.t1'.format(CKPT_ROOT, teacher_arch, CKPT_TAG),
    device)
teachers_labels, labels = collect_logits(teacher, dataloaders['train'], device)
plot_margin_pdf(ax, signed_margin(teachers_labels, labels), X, 'teacher')

for alpha in alpha_list:
    for t in t_list:
        # Student checkpoints are stored under <teacher>/<student>/alp_<alpha>_T_<t>/.
        student = load_wrn(
            student_arch,
            '{}/{}/{}/alp_{}_T_{}/{}_none.t1'.format(CKPT_ROOT, teacher_arch, student_arch, alpha, t, CKPT_TAG),
            device)
        students_labels, labels = collect_logits(student, dataloaders['train'], device)
        plot_margin_pdf(ax, signed_margin(students_labels, labels), X,
                        'student_τ{}'.format(t.split('.')[0]))
        # For the alpha sweep use: label='student_α{}'.format(alpha)

ax.set_xlabel('signed logit margin')
ax.set_ylabel('density')
ax.set_yticks([0, 0.05, 0.10, 0.15, 0.20])
ax.legend()
plt.show()
# fig.savefig('{}_{}_top1_top2_tau.pdf'.format(teacher_arch, student_arch), bbox_inches='tight', format='pdf')
plt.close(fig)