In [1]:
import warnings
warnings.filterwarnings('ignore')

import os

import os.path as p
from evaluation import evaluate_base, evaluate_ensemble
from cnn_be import CNN_be
from cnn import CNN
import torch
import torchvision
from torch.nn.utils import prune
from torchvision import datasets, transforms

In [3]:
# Evaluation pipeline: resize every image to 32x32, then convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Fashion-MNIST test split, iterated in a fixed order so results are repeatable.
test_dataset = datasets.FashionMNIST(
    '~/.pytorch/F_MNIST_data/',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# Prefer the first GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Single Model & Ensemble Teacher Model 

In [3]:
# model update: load the four teacher CNNs (seeds 42, 43, 44, 45) from their checkpoints.
tmodel = []
ckpt = '/home/chaeyoon-jang/test/fashion_mnist/ckpt2'
for num in range(4):
    seed = 42 + num
    model = CNN()
    # Use os.path.join rather than the short `p` alias: `p` is easily shadowed by a
    # loop variable in another cell, which would break this cell on re-run.
    ckpt_p = os.path.join(ckpt, f"{seed}_teacher_model_checkpoint.pt")
    # map_location="cpu" lets CUDA-saved checkpoints load on CPU-only machines too
    # (`device` already falls back to CPU); load_state_dict copies into the model's
    # own parameters, so where the model ends up is unchanged.
    model.load_state_dict(torch.load(ckpt_p, map_location="cpu")['model_state_dict'])
    tmodel.append(model)

# evaluate single model
evaluate_base(tmodel[0], test_loader, device)

# evaluate ensemble teacher model
# NOTE(review): the meaning of the third argument (0 here) is defined in evaluation.py — confirm.
evaluate_ensemble(tmodel, test_loader, 0, device)

79it [00:08,  9.81it/s]


0.8845925632911392
0.4163937359821947
0.062311239540576935
0.4461145848262159
0.09331009536981583


79it [00:02, 36.45it/s]

0.9081289556962026
0.32838801776306537
0.05616796016693115
0.3927274647845498
0.10199322551488876





# Knowledge Distillation Model

In [4]:
# model update: load the knowledge-distillation (KD) model checkpoint.
ckpt = "/home/chaeyoon-jang/test/fashion_mnist/ckpt2/KD_model_checkpoint_epoch_166.pt"
kd = CNN()
# map_location="cpu" lets a CUDA-saved checkpoint load on a CPU-only machine;
# `device` already falls back to CPU, so loading should as well.
kd.load_state_dict(torch.load(ckpt, map_location="cpu")['model_state_dict'])

# evaluate kd model
evaluate_base(kd, test_loader, device)

79it [00:07,  9.98it/s]

0.8855814873417721
0.40195695001867754
0.06054038181900978
0.4642007240464416
0.1084030494093895





# General BatchEnsemble & LatentBE

In [6]:
# model update
ckpt = "/home/chaeyoon-jang/test/fashion_mnist/ckpt/LatentBE_model_checkpoint_epoch_974.pt"

# Deserialize the checkpoint once and load it into both models. load_state_dict copies
# the values into each model's own parameters, so editing `latentbe` below leaves
# `generalbe` untouched. map_location="cpu" lets a CUDA-saved checkpoint load without a GPU.
state_dict = torch.load(ckpt, map_location="cpu")['model_state_dict']

latentbe = CNN_be(inference=True, bias_is=True)
generalbe = CNN_be(bias_is=True)

latentbe.load_state_dict(state_dict)
generalbe.load_state_dict(state_dict)

# make latentBE: fold the BatchEnsemble factors into single shared weights.
#   * each weight is multiplied by mean(r_factor) along weight dim 1 and by
#     mean(s_factor) along weight dim 0 (means taken over factor dim 0);
#   * each bias is replaced by its mean over dim 0.
# Dim 0 of the factors/biases is presumably the ensemble-member axis — TODO confirm in cnn_be.
# Local names are used deliberately: the old loop variable `p` shadowed the
# `os.path` alias imported as `p`.
# NOTE(review): mean(s) * mean(r)^T equals the mean over members of s_i * r_i^T only
# when the factors agree across members — confirm this is the intended averaging.
be_layers = [  # (parameter prefix, weight sub-module, trailing singleton dims of the factor views)
    ("layer1.0", "conv", (1, 1)),
    ("layer2.0", "conv", (1, 1)),
    ("fc1", "linear", ()),
    ("fc2", "linear", ()),
    ("fc3", "linear", ()),
]
be_params = dict(latentbe.named_parameters())
with torch.no_grad():
    for prefix, weight_module, trailing in be_layers:
        weight_name = f"{prefix}.{weight_module}.weight"
        if weight_name in be_params:
            weight = be_params[weight_name]
            r_mean = be_params[f"{prefix}.r_factor"].mean(dim=0).view(1, -1, *trailing)
            s_mean = be_params[f"{prefix}.s_factor"].mean(dim=0).view(-1, 1, *trailing)
            weight.data = weight.data.mul(r_mean).mul(s_mean)
        bias_name = f"{prefix}.bias"
        if bias_name in be_params:
            bias = be_params[bias_name]
            # Reassign .data (not copy_) because the reduced bias has a different shape.
            bias.data = bias.data.mean(dim=0)

# evaluate general BE
# NOTE(review): the meaning of the third argument (4 here) is defined in evaluation.py — confirm.
evaluate_ensemble(generalbe, test_loader, 4, device)

# evaluate latentBE
evaluate_base(latentbe, test_loader, device)

79it [00:02, 34.00it/s]


0.8587816455696202
0.5114394671403909
0.0753195583820343
0.7353108087370667
0.21555544435977936


79it [00:06, 11.44it/s]

0.858682753164557
0.5113549387605885
0.0753621831536293
0.7345952287504944
0.21516306698322296





# LatentBE + div

In [4]:
# model update
ckpt = "/home/chaeyoon-jang/test/fashion_mnist/ckpt2/LatentBE_div_model_checkpoint_epoch_99.pt"

latentbe = CNN_be(inference=True, bias_is=True)

# map_location="cpu" lets a CUDA-saved checkpoint load on a CPU-only machine.
latentbe.load_state_dict(torch.load(ckpt, map_location="cpu")['model_state_dict'])

# make latentBE: fold the BatchEnsemble factors into single shared weights.
#   * each weight is multiplied by mean(r_factor) along weight dim 1 and by
#     mean(s_factor) along weight dim 0 (means taken over factor dim 0);
#   * each bias is replaced by its mean over dim 0.
# Dim 0 of the factors/biases is presumably the ensemble-member axis — TODO confirm in cnn_be.
# Local names are used deliberately: the old loop variable `p` shadowed the
# `os.path` alias imported as `p`.
# NOTE(review): mean(s) * mean(r)^T equals the mean over members of s_i * r_i^T only
# when the factors agree across members — confirm this is the intended averaging.
be_layers = [  # (parameter prefix, weight sub-module, trailing singleton dims of the factor views)
    ("layer1.0", "conv", (1, 1)),
    ("layer2.0", "conv", (1, 1)),
    ("fc1", "linear", ()),
    ("fc2", "linear", ()),
    ("fc3", "linear", ()),
]
be_params = dict(latentbe.named_parameters())
with torch.no_grad():
    for prefix, weight_module, trailing in be_layers:
        weight_name = f"{prefix}.{weight_module}.weight"
        if weight_name in be_params:
            weight = be_params[weight_name]
            r_mean = be_params[f"{prefix}.r_factor"].mean(dim=0).view(1, -1, *trailing)
            s_mean = be_params[f"{prefix}.s_factor"].mean(dim=0).view(-1, 1, *trailing)
            weight.data = weight.data.mul(r_mean).mul(s_mean)
        bias_name = f"{prefix}.bias"
        if bias_name in be_params:
            bias = be_params[bias_name]
            # Reassign .data (not copy_) because the reduced bias has a different shape.
            bias.data = bias.data.mean(dim=0)

# evaluate latentBE
evaluate_base(latentbe, test_loader, device)

79it [00:07, 10.92it/s]

0.8520569620253164
0.5166575372671779
0.07369819283485413
0.7391640448268455
0.21054105460643768



