In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision.transforms as transforms

from tools.data_setter import cifar_100_setter
from models import cifar, imagenet

---

In [None]:
# `teacher` is never moved to `device` in this notebook: load_state_dict copies tensors into
# the model's existing (CPU) parameters, and all inputs passed to it below are CPU tensors.
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

teacher = cifar.WideResNet(depth=16, widen_factor=4, num_classes=100)
filename = './model_checkpoints/cifar100/None/wrn-16-4/alp_0.1_T_1.0/random_highest_1.0_random_highest_1.0_seed9999.t1'
# The checkpoint file is a dict whose '199' entry is the state dict to evaluate
# (presumably the last epoch of the run — TODO confirm). Load it straight onto the CPU,
# where the model lives, instead of staging it on the GPU.
checkpoint = torch.load(filename, map_location='cpu')['199']
teacher.load_state_dict(checkpoint, strict=True)
# Inference mode for the whole notebook; `;` suppresses the long model repr.
teacher.eval();

In [None]:
# Root directory of the CIFAR-100 data. The default is machine-specific; set the
# CIFAR_ROOT environment variable rather than editing the notebook.
CIFAR_ROOT = os.environ.get('CIFAR_ROOT', '/home/osilab7/hdd/cifar')

# NOTE(review): `model_name` below ends in `seed1`, while the teacher checkpoint loaded in the
# previous cell ends in `seed9999` — confirm which run is intended.
# NOTE(review): if the 'train' loader applies random augmentation, the "None" baseline panel
# is already augmented; consider a non-augmented split if cifar_100_setter provides one.
dataloaders, dataset_size = cifar_100_setter(teacher=teacher,
                                             mode=None,
                                             batch_size=10,  # one batch of samples is plotted below
                                             root=CIFAR_ROOT,
                                             model_name='cifar100/None/wrn-16-4/alp_0.1_T_1.0/random_highest_1.0_random_highest_1.0_seed1.t1',
                                             cls_acq='random',
                                             cls_order='highest',
                                             zeta=1.0,
                                             sample_acq='random',
                                             sample_order='highest',
                                             delta=1.0)

# Take a single (images, labels) batch to visualize.
sample_lst = next(iter(dataloaders['train']))
img_lst, label_lst = sample_lst[0], sample_lst[1]

# Per-channel normalization statistics, shaped (3, 1, 1) to broadcast over CHW images.
# They must match the values used by nrm_trans in plot_per_img_trans.
mean = torch.tensor([0.5071, 0.4865, 0.4409]).view(3, 1, 1)
std = torch.tensor([0.2673, 0.2564, 0.2762]).view(3, 1, 1)
pil_trans = transforms.ToPILImage()

# Undo the normalization and convert back to PIL so PIL-based transforms can be applied.
# Quantize explicitly: ToPILImage on a float tensor truncates x*255 to a byte, so float
# round-trip error (e.g. 126.99998) would silently drop a pixel value to 126.
original_img_lst = [
    pil_trans((s * std + mean).clamp(0, 1).mul(255).round().to(torch.uint8))
    for s in img_lst
]

---

In [None]:
def get_entropy(img, model=None):
    """Classify a single image and return the entropy of the predicted distribution.

    Args:
        img: one input tensor without a batch axis; a batch axis of size 1 is prepended
            before the forward pass (presumably a normalized (3, 32, 32) CIFAR image —
            TODO confirm for other callers).
        model: callable mapping the batched input to logits with classes on dim 1.
            Defaults to the global ``teacher``.

    Returns:
        (pred_label, entropy): the argmax class index as a Python int, and the Shannon
        entropy (natural log, i.e. nats) of the softmax distribution as a 0-dim tensor.
    """
    if model is None:
        model = teacher

    # Pure inference: no autograd graph is needed for these statistics.
    with torch.no_grad():
        # log_softmax stays finite where softmax underflows to 0; log(0) = -inf would
        # otherwise make the 0 * -inf term (and thus the entropy) NaN.
        log_probs = torch.log_softmax(model(img.unsqueeze(0)), dim=1)
        probs = log_probs.exp()
        entropy = -torch.sum(probs * log_probs)

    pred_label = torch.argmax(log_probs).item()
    return pred_label, entropy

In [None]:
def plot_per_img_trans(img_trans=None, images=None, labels=None):
    """Plot each sample image (optionally transformed), titled with the model's entropy on it.

    Args:
        img_trans: optional transform applied to each PIL image before it is displayed and
            classified; ``None`` plots the images unchanged.
        images: PIL images to plot; defaults to the global ``original_img_lst``.
        labels: ground-truth labels aligned with ``images`` (ints or 0-dim tensors);
            defaults to the global ``label_lst``.

    Each panel title reads ``'<entropy> / <prediction correct>'``, where both values come
    from ``get_entropy`` on the transformed, re-normalized image.
    """
    if images is None:
        images = original_img_lst
    if labels is None:
        labels = label_lst

    n_images = len(images)
    if n_images == 0:
        return

    # One panel per image, 2.5 inches wide each; squeeze=False keeps `axes` 2-D even
    # when there is only a single image.
    fig, axes = plt.subplots(1, n_images, figsize=(2.5 * n_images, 10), squeeze=False)

    # Re-normalize with the same per-channel statistics used to un-normalize the samples.
    nrm_trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])])

    for ax, img, label in zip(axes[0], images, labels):
        # The extra transform is applied before normalization; the displayed image is
        # exactly what gets normalized and classified.
        shown = img if img_trans is None else img_trans(img)
        ax.imshow(shown)

        pred_label, entropy = get_entropy(nrm_trans(shown))
        # int() turns a 0-dim label tensor into a plain int, so the title shows
        # True/False rather than 'tensor(True)'.
        ax.set_title('{:.6f} / {}'.format(entropy.item(), int(label) == pred_label))

    plt.show()
    plt.close(fig)

In [None]:
# Baseline: the sample images with no extra transform applied.
plot_per_img_trans(img_trans=None)

In [None]:
# Horizontal flip: p=1.0 flips every image unconditionally, so this panel is deterministic.
img_trans = transforms.Compose([
    transforms.RandomHorizontalFlip(p=1.0),
])
plot_per_img_trans(img_trans=img_trans)

In [None]:
# Crop: random 32x32 crop of each image after padding it by 4 pixels on every side.
# torchvision's random transforms draw from torch's global RNG; seed it so this panel
# is reproducible across re-runs.
torch.manual_seed(0)
img_trans = transforms.Compose([transforms.RandomCrop(32, padding=4)])
plot_per_img_trans(img_trans)

In [None]:
# Flip and Crop: random padded crop followed by a horizontal flip with the default
# probability (p=0.5). Seed torch's global RNG, which these transforms draw from,
# so this panel is reproducible across re-runs.
torch.manual_seed(0)
img_trans = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                transforms.RandomHorizontalFlip()])
plot_per_img_trans(img_trans)