In [1]:
import warnings
warnings.filterwarnings('ignore')

import os

import os.path as p
from evaluation import evaluate_base, evaluate_ensemble
from cnn_be import CNN_be
from cnn import CNN
import torch
import torchvision
from torch.nn.utils import prune
from torchvision import datasets, transforms

In [2]:
# Every image is resized to 32x32 and converted to a tensor.
data_root = '~/.pytorch/F_MNIST_data/'
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

full_train_dataset = datasets.FashionMNIST(data_root, train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(data_root, train=False, download=True, transform=transform)

# 80 % of the official training set is kept for training, the rest for validation.
train_size = int(0.8 * len(full_train_dataset))
valid_size = len(full_train_dataset) - train_size

# Fixed seed so the train/validation split is the same on every run.
generator = torch.Generator().manual_seed(0)
train_dataset, valid_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [train_size, valid_size],
    generator=generator,
)

# The validation loader yields the whole validation split as one batch.
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=valid_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Single Model & Ensemble Teacher Model 

In [3]:
# model update
# Build four CNN teachers and restore each one from its own checkpoint.
# Checkpoint files are prefixed with 42 + num (held in `seed`), i.e. 42..45.
tmodel = []
ckpt = '/home/chaeyoon-jang/test/l/ckpt2'
for num in range(4):
    seed = 42 + num
    model = CNN()
    # Use os.path directly: the short `p` alias is easy to clobber with a loop
    # variable in a notebook, which would break this cell when it is re-run.
    ckpt_p = os.path.join(ckpt, str(seed)+"_teacher_model_checkpoint.pt")
    # map_location="cpu" lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine; load_state_dict copies the values into the model's own
    # parameters, so the restored weights are unchanged.
    checkpoint = torch.load(ckpt_p, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    tmodel.append(model)

# evaluate single model
evaluate_base(tmodel[0], test_loader, valid_loader, device)

# evaluate ensemble teacher model
evaluate_ensemble(tmodel, test_loader, valid_loader, 0, device)

79it [00:01, 57.36it/s]


0.8845925632911392
0.331267992529688
0.062311239540576935
0.3181883368311049
0.0593358613550663


79it [00:01, 50.59it/s]

0.9081289556962026
0.25864631088474127
0.05616796016693115
0.25673888887785656
0.05613895133137703





# Knowledge Distillation Model

In [4]:
# model update
# Restore the knowledge-distillation student (a plain CNN) from its checkpoint.
ckpt = "/home/chaeyoon-jang/test/l/ckpt2/KD_model_checkpoint_epoch_166.pt"
kd = CNN()
# map_location="cpu" lets a checkpoint saved on a GPU be restored on a CPU-only
# machine; load_state_dict copies the values into the model's own parameters.
kd.load_state_dict(torch.load(ckpt, map_location="cpu")['model_state_dict'])

# evaluate kd model
evaluate_base(kd, test_loader, valid_loader, device)

79it [00:01, 57.75it/s]

0.8855814873417721
0.31700533892534954
0.06054038181900978
0.313682463161553
0.06377920508384705





# General BatchEnsemble & LatentBE

In [6]:
# model update
ckpt = "/home/chaeyoon-jang/test/l/ckpt/LatentBE_model_checkpoint_epoch_194.pt"

latentbe = CNN_be(inference=True, bias_is=True)
generalbe = CNN_be(bias_is=True)

# Deserialize the checkpoint once and reuse it for both models; load_state_dict
# copies the values, so the two models do not end up sharing tensors.
# map_location="cpu" lets a GPU-saved checkpoint be restored on a CPU-only machine.
state_dict = torch.load(ckpt, map_location="cpu")['model_state_dict']
latentbe.load_state_dict(state_dict)
generalbe.load_state_dict(state_dict)

# make latentBE
# For every BatchEnsemble layer: scale the shared weight by the mean (over dim 0)
# of its r_factor and of its s_factor, and replace the bias by its mean over dim 0.
# layer prefix -> (wrapped sub-module name, view shape for mean r_factor, view shape for mean s_factor)
be_layers = {
    "layer1.0": ("conv", (1, -1, 1, 1), (-1, 1, 1, 1)),
    "layer2.0": ("conv", (1, -1, 1, 1), (-1, 1, 1, 1)),
    "fc1": ("linear", (1, -1), (-1, 1)),
    "fc2": ("linear", (1, -1), (-1, 1)),
    "fc3": ("linear", (1, -1), (-1, 1)),
}
weight_targets = {
    prefix + "." + submodule + ".weight": (prefix, r_shape, s_shape)
    for prefix, (submodule, r_shape, s_shape) in be_layers.items()
}
bias_targets = {prefix + ".bias" for prefix in be_layers}

# The loop variable is deliberately NOT called `p`: that name is the notebook's
# `os.path` alias, and rebinding it breaks any cell that calls `p.join` afterwards.
with torch.no_grad():  # pure weight surgery, no autograd graph needed
    for name, param in latentbe.named_parameters():
        if name in weight_targets:
            prefix, r_shape, s_shape = weight_targets[name]
            r_mean = torch.mean(latentbe.get_parameter(prefix + ".r_factor"), dim=0)
            s_mean = torch.mean(latentbe.get_parameter(prefix + ".s_factor"), dim=0)
            param.data = param.data.mul(r_mean.view(*r_shape))
            param.data = param.data.mul(s_mean.view(*s_shape))
        elif name in bias_targets:
            param.data = torch.mean(param.data, dim=0)

# evaluate general BE
evaluate_base(generalbe, test_loader, valid_loader, device)

# evaluate latentBE
evaluate_base(latentbe, test_loader, valid_loader, device)

79it [00:01, 56.77it/s]


0.8571993670886076
0.40551711675486984
0.07501891255378723
0.40276484097106546
0.06937651336193085


79it [00:01, 56.80it/s]

0.858682753164557
0.40433049692383294
0.0753621831536293
0.4015984510696387
0.06946507841348648





# LatentBE + div

In [7]:
# model update
ckpt = "/home/chaeyoon-jang/test/l/ckpt2/LatentBE_div_model_checkpoint_epoch_99.pt"

latentbe = CNN_be(inference=True, bias_is=True)

# map_location="cpu" lets a GPU-saved checkpoint be restored on a CPU-only machine;
# load_state_dict copies the values into the model's own parameters.
latentbe.load_state_dict(torch.load(ckpt, map_location="cpu")['model_state_dict'])

# make latentBE
# For every BatchEnsemble layer: scale the shared weight by the mean (over dim 0)
# of its r_factor and of its s_factor, and replace the bias by its mean over dim 0.
# layer prefix -> (wrapped sub-module name, view shape for mean r_factor, view shape for mean s_factor)
be_layers = {
    "layer1.0": ("conv", (1, -1, 1, 1), (-1, 1, 1, 1)),
    "layer2.0": ("conv", (1, -1, 1, 1), (-1, 1, 1, 1)),
    "fc1": ("linear", (1, -1), (-1, 1)),
    "fc2": ("linear", (1, -1), (-1, 1)),
    "fc3": ("linear", (1, -1), (-1, 1)),
}
weight_targets = {
    prefix + "." + submodule + ".weight": (prefix, r_shape, s_shape)
    for prefix, (submodule, r_shape, s_shape) in be_layers.items()
}
bias_targets = {prefix + ".bias" for prefix in be_layers}

# The loop variable is deliberately NOT called `p`: that name is the notebook's
# `os.path` alias, and rebinding it breaks any cell that calls `p.join` afterwards.
with torch.no_grad():  # pure weight surgery, no autograd graph needed
    for name, param in latentbe.named_parameters():
        if name in weight_targets:
            prefix, r_shape, s_shape = weight_targets[name]
            r_mean = torch.mean(latentbe.get_parameter(prefix + ".r_factor"), dim=0)
            s_mean = torch.mean(latentbe.get_parameter(prefix + ".s_factor"), dim=0)
            param.data = param.data.mul(r_mean.view(*r_shape))
            param.data = param.data.mul(s_mean.view(*s_shape))
        elif name in bias_targets:
            param.data = torch.mean(param.data, dim=0)

# evaluate latentBE
evaluate_base(latentbe, test_loader, valid_loader, device)

79it [00:01, 56.36it/s]

0.8537381329113924
0.41122505974166
0.07770675420761108
0.40810731732392613
0.06784618645906448



