In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
# import seaborn as sns
# import scipy
# from netcal.metrics import ECE, ACE

from dataloader import get_dataloaders
from models.msdnet_ge import MSDNet
from utils import parse_args
from utils_notebook import *

from collections import OrderedDict, Counter
import random
from typing import Dict

In [2]:
MODEL = "model1"
EPOCH = "099"

In [3]:
# CIFAR-100
ARGS = parse_args()
ARGS.data_root = 'data'
ARGS.data = 'cifar100'
ARGS.save= f'/home/metod/Desktop/PhD/year1/PoE/IMTA/_models/{ARGS.data}/{MODEL}'
ARGS.arch = 'msdnet_ge'
ARGS.grFactor = [1, 2, 4]
ARGS.bnFactor = [1, 2, 4]
ARGS.growthRate = 6
ARGS.batch_size = 64
ARGS.epochs = 300
ARGS.nBlocks = 7
ARGS.stepmode = 'even'
ARGS.base = 4
ARGS.nChannels = 16
ARGS.num_classes = 100
ARGS.step = 2
ARGS.use_valid = True
ARGS.splits = ['train', 'val', 'test']
ARGS.nScales = len(ARGS.grFactor)

In [4]:
# load pre-trained model
model = MSDNet(args=ARGS)
MODEL_PATH = f'_models/{ARGS.data}/{MODEL}/save_models/checkpoint_{EPOCH}.pth.tar'
state = torch.load(MODEL_PATH)
params = OrderedDict()
for params_name, params_val in state['state_dict'].items():
    params[params_name.replace('module.', '')] = params_val
    # state['state_dict'][params_name.replace('module.', '')] = state['state_dict'].pop(params_name)
model.load_state_dict(params)
model = model.cuda()
model.eval()

building network of steps: 
[4, 2, 2, 2, 2, 2, 2] 16
 ********************** Block 1  **********************
|		inScales 3 outScales 3 inChannels 16 outChannels 6		|

|		inScales 3 outScales 3 inChannels 22 outChannels 6		|

|		inScales 3 outScales 3 inChannels 28 outChannels 6		|

|		inScales 3 outScales 3 inChannels 34 outChannels 6		|

 ********************** Block 2  **********************
|		inScales 3 outScales 3 inChannels 40 outChannels 6		|

|		inScales 3 outScales 3 inChannels 46 outChannels 6		|

 ********************** Block 3  **********************
|		inScales 3 outScales 2 inChannels 52 outChannels 6		|
|		Transition layer inserted! (max), inChannels 58, outChannels 29	|

|		inScales 2 outScales 2 inChannels 29 outChannels 6		|

 ********************** Block 4  **********************
|		inScales 2 outScales 2 inChannels 35 outChannels 6		|

|		inScales 2 outScales 2 inChannels 41 outChannels 6		|

 ********************** Block 5  **********************
|		inScales 2 outS

MSDNet(
  (blocks): ModuleList(
    (0): Sequential(
      (0): MSDNFirstLayer(
        (layers): ModuleList(
          (0): ConvBasic(
            (net): Sequential(
              (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
          (1): ConvBasic(
            (net): Sequential(
              (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
          (2): ConvBasic(
            (net): Sequential(
              (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

In [5]:
# data
_, _, test_loader = get_dataloaders(ARGS)

logits = []
targets = []
with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        y = y.cuda(device=None)
        x = x.cuda()

        input_var = torch.autograd.Variable(x)
        target_var = torch.autograd.Variable(y)

        output, _ = model(input_var)
        if not isinstance(output, list):
            output = [output]

        logits.append(torch.stack(output))
        targets.append(target_var)

logits = torch.cat(logits, dim=1).cpu()
targets = torch.cat(targets).cpu()

Files already downloaded and verified
Files already downloaded and verified
!!!!!! Load train_set_index !!!!!!
------------------------------------
split num_sample_valid: 5000
------------------------------------


In [6]:
logits.shape

torch.Size([7, 10000, 100])

In [7]:
targets.shape

torch.Size([10000])

In [8]:
L = len(logits)
N = logits.shape[1]

In [9]:
probs = torch.softmax(logits, dim=2)
preds = {i: torch.argmax(probs, dim=2)[i, :] for i in range(L)}
acc = [(targets == preds[i]).sum() / len(targets) for i in range(L)]
msp = {i: torch.max(probs, dim=2).values[i, :] for i in range(L)}

In [10]:
acc

[tensor(0.6131),
 tensor(0.6390),
 tensor(0.6679),
 tensor(0.6838),
 tensor(0.6867),
 tensor(0.6888),
 tensor(0.6922)]

In [11]:
modal_probs_decreasing(preds, probs, L, N=N)

{0.01: 75.99000000000001,
 0.05: 63.89,
 0.1: 55.61000000000001,
 0.2: 43.97,
 0.5: 16.939999999999998}

In [12]:
probs_poe_ovr_break_ties = torch.tensor(f_probs_ovr_poe_logits_weighted(logits))
preds_poe_ovr_break_ties = {i: torch.argmax(probs_poe_ovr_break_ties, dim=2)[i, :] for i in range(L)}
acc_poe_ovr_break_ties = [(targets == preds_poe_ovr_break_ties[i]).sum() / len(targets) for i in range(L)]

In [13]:
acc_poe_ovr_break_ties

[tensor(0.6131),
 tensor(0.6564),
 tensor(0.6882),
 tensor(0.7040),
 tensor(0.7104),
 tensor(0.7163),
 tensor(0.7212)]

In [14]:
modal_probs_decreasing(preds_poe_ovr_break_ties, probs_poe_ovr_break_ties, L, N=N)

{0.01: 3.02, 0.05: 0.45999999999999996, 0.1: 0.08, 0.2: 0.0, 0.5: 0.0}