In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

from tools.data_setter import cifar_100_setter
from models import cifar, imagenet

---

# Accuracy

In [None]:
# Collect the best train accuracy of every distillation run and average it
# per (temperature, alpha) configuration.
# Logs are read from results/<dataset>/<teacher>/<student>/alp_<alpha>_T_<T>/*.csv
dirname = os.getcwd()
dataset = 'cifar10'
teacher = 'wrn-28-4'
student = 'wrn-16-2'

result_dirname = os.path.join(dirname, 'results', dataset, teacher, student)

# Full sweep grid (kept for reference).
t_list = ["1.0", "3.0", "5.0", "20.0", "inf"]
alpha_list = ["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9", "1.0"]

# Currently restricted to a single configuration; remove these two lines to
# aggregate the full grid above.
t_list = ["20.0"]
alpha_list = ["1.0"]

# Allocate only after the final grid is known so that rows/columns line up with
# t_list/alpha_list (these lists are also used as tick labels when plotting).
accuracy_list = np.zeros([len(t_list), len(alpha_list)])

for t_enum, t in enumerate(t_list):
    for alpha_enum, alpha in enumerate(alpha_list):
        accuracy = []
        current_dirname = os.path.join(result_dirname, 'alp_{}_T_{}'.format(alpha, t))
        current_files = os.listdir(current_dirname)

        current_files = [f for f in current_files if 'csv' in f]
        # Keep only runs whose 3rd and 6th '_'-separated name fields are "1.0"
        # (presumably the zeta/delta ratios of the run -- TODO confirm).
        # Names with too few fields are skipped instead of raising IndexError.
        current_files = [
            f for f in current_files
            if len(f.split('_')) > 5 and f.split('_')[2] == "1.0" and f.split('_')[5] == "1.0"
        ]
        for file in current_files:
            # Best (maximum) train accuracy reached during this run.
            accuracy.append(pd.read_csv(os.path.join(current_dirname, file))['train_accuracy'].max())

        # Must be inside the alpha loop: one cell per (T, alpha) pair.
        # NaN marks a configuration for which no matching log was found.
        accuracy_list[t_enum, alpha_enum] = np.mean(accuracy) if accuracy else np.nan

In [None]:
# Heatmap of the mean best train accuracy: one row per temperature, one column
# per alpha.
matplotlib.rcParams.update({'font.size': 20})

fig, ax = plt.subplots(figsize=(10, 4))
mesh = ax.pcolor(accuracy_list, cmap=plt.get_cmap('Blues'))

# Ticks sit at the centre of each cell (cell i spans [i, i + 1]).
ax.set_xticks(np.arange(0.5, len(alpha_list), 1))
ax.set_xticklabels(alpha_list)
ax.set_yticks(np.arange(0.5, len(t_list), 1))
ax.set_yticklabels(t_list)
ax.set_xlabel('Alpha')
ax.set_ylabel('Temperature')

fig.colorbar(mesh, ax=ax)

# Layout tweaks and saving have to happen before show(): once the figure has
# been shown, later adjustments no longer affect the rendered output.
fig.subplots_adjust(wspace=0.5)
# fig.savefig('train_accuracy.pdf', bbox_inches='tight', format='pdf')
plt.show()
plt.close(fig)

---

# Entropy heatmap & 4 cases analysis

In [None]:
def get_entropy_list(dataloaders, dataloaders_mode, alpha=None, t=None, mode=None, device=None):
    """Compute per-sample teacher entropy and student predictions.

    Loads a WideResNet-28-4 teacher from a fixed checkpoint and the
    WideResNet-16-2 student checkpoint stored under ``alp_{alpha}_T_{t}``,
    then runs both over ``dataloaders[dataloaders_mode]``.

    Args:
        dataloaders: mapping from split name to an iterable of batches, where
            ``batch[0]`` holds the images and ``batch[1]`` the labels.
        dataloaders_mode: key of the split to evaluate (e.g. ``"train"``).
        alpha: alpha value (as formatted in the checkpoint directory name).
        t: temperature value (as formatted in the checkpoint directory name).
        mode: unused; kept for backward compatibility with existing callers.
        device: optional ``torch.device`` (or string). Defaults to ``cuda:1``
            when a second GPU exists, otherwise ``cuda:0``, otherwise CPU.

    Returns:
        Tuple ``(labels, entropy_list, student_labels)`` of equal-length
        Python lists: ground-truth labels, entropy (in nats) of the teacher's
        softmax output, and the student's arg-max class for every sample.
    """
    if device is None:
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            device = torch.device('cuda:1')
        elif torch.cuda.is_available():
            device = torch.device('cuda:0')
        else:
            device = torch.device('cpu')
    else:
        device = torch.device(device)

    # Checkpoint files are indexed by a string key; '199' is presumably the
    # final training epoch -- TODO confirm.
    teacher = cifar.WideResNet(depth=28, widen_factor=4, num_classes=100)
    filename = './model_checkpoints/cifar100/None/wrn-28-4/alp_0.1_T_1.0/random_highest_1.0_random_highest_1.0_seed9999.t1'
    checkpoint = torch.load(filename, map_location=device)['199']
    teacher.load_state_dict(checkpoint, strict=True)

    student = cifar.WideResNet(depth=16, widen_factor=2, num_classes=100)
    filename = './model_checkpoints/cifar100/wrn-28-4/wrn-16-2/alp_{}_T_{}/random_highest_1.0_random_highest_1.0_seed9999.t1'.format(alpha, t)
    checkpoint = torch.load(filename, map_location=device)['199']
    student.load_state_dict(checkpoint, strict=True)

    teacher.eval().to(device)
    student.eval().to(device)

    labels = []
    entropy_list = []
    student_labels = []

    # Pure inference: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for data in dataloaders[dataloaders_mode]:
            image = data[0].type(torch.FloatTensor).to(device)
            label = data[1].type(torch.LongTensor).to(device)

            teacher_logits = teacher(image)
            student_logits = student(image)

            # Entropy from log-softmax: p * log(p) stays finite (-> 0) when a
            # probability underflows, whereas softmax followed by log would
            # yield 0 * -inf = NaN.
            teacher_log_prob = F.log_softmax(teacher_logits, dim=1)
            entropy = -(teacher_log_prob.exp() * teacher_log_prob).sum(dim=1)

            student_label = torch.max(student_logits, dim=1)[1]

            labels += label.tolist()
            entropy_list += entropy.tolist()
            student_labels += student_label.tolist()

    return labels, entropy_list, student_labels

def check_tf_entropy(labels, entropy_list, student_labels):
    """Average the entropies separately over correct and incorrect predictions.

    Args:
        labels: ground-truth class per sample.
        entropy_list: entropy value per sample (same ordering as ``labels``).
        student_labels: predicted class per sample.

    Returns:
        Tuple ``(mean_entropy_correct, mean_entropy_wrong)``. A group with no
        samples yields NaN (the mean of an empty list).
    """
    # Element-wise correctness of each prediction, as plain Python bools.
    is_correct = (np.array(student_labels) == np.array(labels)).tolist()

    hit_entropy = [entropy_list[pos] for pos, ok in enumerate(is_correct) if ok]
    miss_entropy = [entropy_list[pos] for pos, ok in enumerate(is_correct) if not ok]

    mean_hit = np.mean(hit_entropy)
    mean_miss = np.mean(miss_entropy)
    return mean_hit, mean_miss

def check_changed_index(labels, student1_labels, student2_labels):
    """Group sample indices by how correctness changes between two students.

    Args:
        labels: ground-truth class per sample.
        student1_labels: predictions of the first student.
        student2_labels: predictions of the second student.

    Returns:
        Dict with keys ``'t->t'``, ``'t->f'``, ``'f->t'`` and ``'f->f'``; each
        maps to the list of sample indices whose correctness went from the
        first letter (student 1) to the second letter (student 2), where
        ``t`` means correct and ``f`` means wrong.
    """
    truth = np.array(labels)
    before = (np.array(student1_labels) == truth).tolist()
    after = (np.array(student2_labels) == truth).tolist()

    transitions = {name: [] for name in ('t->t', 't->f', 'f->t', 'f->f')}
    tag = {True: 't', False: 'f'}

    for position, (was_right, is_right) in enumerate(zip(before, after)):
        transitions[tag[was_right] + '->' + tag[is_right]].append(position)

    return transitions

In [None]:
# dataset fix
# Iterate the train loader once, keep the batches it produced, and replace it
# with a non-shuffled loader over those stored tensors, so every later pass
# sees the same samples in the same order. (Presumably the original loader
# shuffles and/or augments, given mode='crop' -- TODO confirm in cifar_100_setter.)
# NOTE: this rebinds the notebook-level name `teacher` to a model object.
teacher = cifar.WideResNet(depth=28, widen_factor=4, num_classes=100)
dataloaders, dataset_size = cifar_100_setter(teacher=teacher,
                                             mode='crop',
                                             batch_size=128,
                                             root='/home/osilab7/hdd/cifar',
                                             model_name='cifar100/wrn-28-4/wrn-16-2/alp_1.0_T_20.0/random_highest_1.0_random_highest_1.0_seed9999.t1',
                                             cls_acq='random',
                                             cls_order='highest',
                                             zeta=1.0,
                                             sample_acq='random',
                                             sample_order='highest',
                                             delta=1.0)

# Gather the batches first and concatenate once; calling torch.cat inside the
# loop would re-copy the ever-growing tensor on every iteration.
image_batches = []
target_batches = []
for image, target in dataloaders['train']:
    image_batches.append(image.cpu())
    target_batches.append(target.cpu())

images = torch.cat(image_batches, dim=0)
targets = torch.cat(target_batches, dim=0)

# TensorDataset yields (image, target) pairs indexed along dim 0, without
# materialising one Python tuple per sample up front.
dataloaders['train'] = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(images, targets), batch_size=128, shuffle=False)

In [None]:
# For every (temperature, alpha) student checkpoint, record the mean teacher
# entropy over the samples the student classifies correctly and incorrectly.
dataloaders_mode = "train"
t_list = ["1.0", "3.0", "5.0", "20.0", "inf"]
alpha_list = ["0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9", "1.0"]

# Rows follow t_list, columns follow alpha_list.
true_entropy_list = np.zeros((len(t_list), len(alpha_list)))
false_entropy_list = np.zeros_like(true_entropy_list)

for t_enum, t in enumerate(t_list):
    for alpha_enum, alpha in enumerate(alpha_list):
        labels, entropy_list, student_labels = get_entropy_list(
            dataloaders, dataloaders_mode, alpha, t
        )
        true_entropy, false_entropy = check_tf_entropy(
            labels, entropy_list, student_labels
        )
        cell = (t_enum, alpha_enum)
        true_entropy_list[cell] = true_entropy
        false_entropy_list[cell] = false_entropy

In [None]:
# Heatmap of the mean teacher entropy over correctly classified samples:
# one row per temperature, one column per alpha.
fig, ax = plt.subplots(figsize=(10, 4))
mesh = ax.pcolor(true_entropy_list, cmap=plt.get_cmap('Blues'))

# Ticks sit at the centre of each cell (cell i spans [i, i + 1]).
ax.set_xticks(np.arange(0.5, len(alpha_list), 1))
ax.set_xticklabels(alpha_list)
ax.set_yticks(np.arange(0.5, len(t_list), 1))
ax.set_yticklabels(t_list)
ax.set_xlabel('Alpha')
ax.set_ylabel('Temperature')

fig.colorbar(mesh, ax=ax)
# To export, save before show():
# fig.savefig('train_entropy.pdf', bbox_inches='tight', format='pdf')
plt.show()
plt.close(fig)

---