In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt
# import seaborn as sns
# import scipy
# from netcal.metrics import ECE, ACE

from dataloader import get_dataloaders
from models.msdnet_ge import MSDNet
from models.msdnet_imta import IMTA_MSDNet
from utils import parse_args
from utils_notebook import f_probs_ovr_poe_logits_weighted, modal_probs_decreasing, f_probs_ovr_poe_logits_weighted_generalized, anytime_caching

import torchvision.transforms as transforms
import torchvision.datasets as datasets

from collections import OrderedDict, Counter
import random
import os
from typing import Dict

from tqdm import tqdm

In [2]:
# Experiment selection: which run / dataset / checkpoint epoch to evaluate.
MODEL = "model1"
IMTA = True
DATASET = 'ImageNet'
EPOCH = '059'

# IMTA runs share the base run name with an "_IMTA" suffix (MODEL is used to build checkpoint paths).
MODEL = f"{MODEL}_IMTA" if IMTA else MODEL

In [3]:
ARGS = parse_args()

# Filesystem locations. Both roots can be overridden through environment variables so the
# notebook can be re-run on another machine; the defaults are the original absolute paths.
_DATA_ROOT_DEFAULT = '/home/metod/Desktop/PhD/year1/PoE/MSDNet-PyTorch/data/image_net'
_MODELS_ROOT_DEFAULT = '/home/metod/Desktop/PhD/year1/PoE/IMTA/_models'
MODELS_ROOT = os.environ.get('IMTA_MODELS_ROOT', _MODELS_ROOT_DEFAULT)

# NOTE(review): the default data root is the ImageNet folder regardless of DATASET.
ARGS.data_root = os.environ.get('IMTA_DATA_ROOT', _DATA_ROOT_DEFAULT)
ARGS.data = DATASET
ARGS.save = f'{MODELS_ROOT}/{ARGS.data}/{MODEL}'
ARGS.arch = 'IMTA_MSDNet' if IMTA else 'msdnet_ge'

# MSDNet architecture / training hyper-parameters (must match the trained checkpoint).
ARGS.grFactor = [1, 2, 4, 4]
ARGS.bnFactor = [1, 2, 4, 4]
ARGS.growthRate = 16
ARGS.batch_size = 350
ARGS.epochs = 90
ARGS.nBlocks = 5
ARGS.stepmode = 'even'
ARGS.base = 4
ARGS.nChannels = 32

# Number of output classes per supported dataset.
_NUM_CLASSES = {'cifar10': 10, 'cifar100': 100, 'ImageNet': 1000}
if ARGS.data not in _NUM_CLASSES:
    raise ValueError('Unknown dataset')
ARGS.num_classes = _NUM_CLASSES[ARGS.data]

ARGS.step = 4
ARGS.use_valid = True
ARGS.splits = ['train', 'val', 'test']
ARGS.nScales = len(ARGS.grFactor)

if IMTA:
    # IMTA-specific settings; T and gamma are consumed by IMTA_MSDNet (semantics defined there).
    ARGS.T = 1.0
    ARGS.gamma = 0.1
    # Checkpoint of the base (non-IMTA) run, epoch 089 -- presumably the last of ARGS.epochs=90.
    _MODEL = "model1"
    ARGS.pretrained = f'{MODELS_ROOT}/{ARGS.data}/{_MODEL}/checkpoint_089.pth.tar'

In [4]:
# Build the network and load the weights of the checkpoint selected by MODEL / EPOCH.
# Parameter-name prefix to strip from the checkpoint (typically added when a model is saved
# while wrapped in nn.DataParallel / DistributedDataParallel).
problematic_prefix = 'module.'

# load pre-trained model
if IMTA:
    model = IMTA_MSDNet(args=ARGS)
else:
    model = MSDNet(args=ARGS)

# NOTE(review): this path is relative to the notebook's working directory, unlike ARGS.save.
MODEL_PATH = f'_models/{ARGS.data}/{MODEL}/checkpoint_{EPOCH}.pth.tar'
# MODEL_PATH = f'_models/{ARGS.data}/{MODEL}/save_models/model_best.pth.tar'  # TODO: investigate why using this results in poor accuracy of baseline model
print(MODEL_PATH)

# map_location='cpu': materialise the checkpoint on the host, so a GPU-saved checkpoint does not
# allocate a second copy of the weights on the GPU; the model is moved to CUDA below.
# torch.load unpickles the file -- only load trusted checkpoints.
state = torch.load(MODEL_PATH, map_location='cpu')
params = OrderedDict(
    (name[len(problematic_prefix):] if name.startswith(problematic_prefix) else name, value)
    for name, value in state['state_dict'].items()
)
model.load_state_dict(params)
model = model.cuda()
model.eval()

building network of steps: 
[4, 4, 4, 4, 4] 20
 ********************** Block 1  **********************
|		inScales 4 outScales 4 inChannels 32 outChannels 16		|

|		inScales 4 outScales 4 inChannels 48 outChannels 16		|

|		inScales 4 outScales 4 inChannels 64 outChannels 16		|

|		inScales 4 outScales 4 inChannels 80 outChannels 16		|

 ********************** Block 2  **********************
|		inScales 4 outScales 4 inChannels 96 outChannels 16		|

|		inScales 4 outScales 3 inChannels 112 outChannels 16		|
|		Transition layer inserted! (max), inChannels 128, outChannels 64	|

|		inScales 3 outScales 3 inChannels 64 outChannels 16		|

|		inScales 3 outScales 3 inChannels 80 outChannels 16		|

 ********************** Block 3  **********************
|		inScales 3 outScales 3 inChannels 96 outChannels 16		|

|		inScales 3 outScales 3 inChannels 112 outChannels 16		|

|		inScales 3 outScales 2 inChannels 128 outChannels 16		|
|		Transition layer inserted! (max), inChannels 144, outChannels

IMTA_MSDNet(
  (net): WrappedModel(
    (module): MSDNet(
      (blocks): ModuleList(
        (0): Sequential(
          (0): MSDNFirstLayer(
            (layers): ModuleList(
              (0): Sequential(
                (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
                (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
                (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
              )
              (1): ConvBasic(
                (net): Sequential(
                  (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
                  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (2): ReLU(inplace=True)
                )
              )
              (2): ConvBasic(
                (net): Sequential(
                  (0): Conv2d(64, 128, kerne

In [5]:
# Validation split of the dataset under ARGS.data_root, evaluated with a
# resize-256 / center-crop-224 pipeline and ImageNet channel mean/std normalisation.
valdir = os.path.join(ARGS.data_root, 'valid')
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
eval_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
val_set = datasets.ImageFolder(valdir, eval_transform)

# shuffle=False keeps a fixed sample order, so collected logits line up with targets.
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=ARGS.batch_size,
    shuffle=False,
    num_workers=ARGS.workers,
    pin_memory=True,
)

In [6]:
# Run the model once over the validation set and collect the logits of every exit.
# Each batch is moved to the CPU immediately, so GPU memory does not grow with the dataset
# (the full result is n_exits x N x n_classes floats -- about 1 GB for ImageNet val).
logits = []
targets = []
with torch.no_grad():
    for x, y in tqdm(val_loader, desc="val inference"):
        # pin_memory=True in the loader makes the non-blocking host->device copy effective.
        x = x.cuda(non_blocking=True)

        if IMTA:
            output = model(x)
        else:
            # the non-IMTA MSDNet returns a pair; only the first element is used here
            output, _ = model(x)

        if not isinstance(output, list):
            output = [output]

        # stack the per-exit outputs along a new leading "exit" dimension
        logits.append(torch.stack(output).cpu())
        # labels never need to visit the GPU
        targets.append(y)

logits = torch.cat(logits, dim=1)  # concatenate batches along the sample dimension
targets = torch.cat(targets)

In [7]:
logits.size()  # (n_exits, n_samples, n_classes)

torch.Size([5, 50000, 1000])

# 1 Anytime performance

In [8]:
# L: number of exits (dim 0 of the stacked logits); N: number of validation samples (dim 1).
L, N = logits.shape[:2]

In [9]:
# Per-exit softmax probabilities, top-1 predictions and top-1 accuracy.
probs = torch.softmax(logits, dim=2)
# argmax once over all exits, rather than over the whole (L, N, C) tensor on every iteration.
all_preds = torch.argmax(probs, dim=2)
preds = {i: all_preds[i, :] for i in range(L)}
acc = [(targets == preds[i]).sum() / len(targets) for i in range(L)]

In [10]:
[round(v, 4) for v in map(float, acc)]  # top-1 accuracy of each exit

[0.5716, 0.6585, 0.6998, 0.7186, 0.7275]

In [11]:
# Statistic from utils_notebook.modal_probs_decreasing at several negative thresholds
# (presumably the share of samples whose modal-class probability drops across exits -- confirm there).
modal_drop_stats = modal_probs_decreasing(
    preds,
    probs,
    L,
    N=N,
    thresholds=[-0.0001, -0.01, -0.05, -0.1, -0.2, -0.25, -0.33, -0.5],
    diffs_type="all",
)
[round(v, 4) for v in modal_drop_stats.values()]

[79.162, 61.17, 42.818, 30.452, 15.556, 10.884, 5.71, 1.184]

In [13]:
# Weighted combination of the exits' logits via utils_notebook (a one-vs-rest product of
# experts, per the helper's name). Exit i (1-based) gets weight T * i / L.
T = 1.
poe_weights = (np.arange(1, L + 1, 1, dtype=float) / L) * T
# as_tensor accepts either a numpy array or a tensor from the helper, without the extra copy
# (and warning) that torch.tensor() makes for tensor inputs.
probs_poe_ovr_break_ties_generalized = torch.as_tensor(
    f_probs_ovr_poe_logits_weighted_generalized(logits, weights=poe_weights)
)
# argmax once over all exits instead of once per exit.
poe_all_preds = torch.argmax(probs_poe_ovr_break_ties_generalized, dim=2)
preds_poe_ovr_break_ties_generalized = {i: poe_all_preds[i, :] for i in range(L)}
acc_poe_ovr_break_ties_generalized = [(targets == preds_poe_ovr_break_ties_generalized[i]).sum() / len(targets) for i in range(L)]

In [14]:
[round(v, 4) for v in map(float, acc_poe_ovr_break_ties_generalized)]  # top-1 accuracy per exit (PoE)

[0.5716, 0.6541, 0.6931, 0.7142, 0.7304]

In [15]:
# Same modal-probability statistic as for the plain exits, now for the PoE-combined predictions.
modal_drop_stats_poe = modal_probs_decreasing(
    preds_poe_ovr_break_ties_generalized,
    probs_poe_ovr_break_ties_generalized,
    L,
    N=N,
    thresholds=[-0.0001, -0.01, -0.05, -0.1, -0.2, -0.25, -0.33, -0.5],
    diffs_type="all",
)
[round(v, 4) for v in modal_drop_stats_poe.values()]

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [16]:
# Caching baseline from utils_notebook applied to the per-exit probabilities.
probs_stateful = anytime_caching(probs, N=N, L=L)
# argmax once over all exits instead of once per exit.
stateful_all_preds = torch.argmax(probs_stateful, dim=2)
preds_stateful = {i: stateful_all_preds[i, :] for i in range(L)}
acc_stateful = [(targets == preds_stateful[i]).sum() / len(targets) for i in range(L)]

In [17]:
[round(v, 4) for v in map(float, acc_stateful)]  # top-1 accuracy per exit (caching baseline)

[0.5716, 0.6467, 0.6868, 0.7078, 0.7219]

In [18]:
# Modal-probability statistic for the caching baseline's predictions.
modal_drop_stats_stateful = modal_probs_decreasing(
    preds_stateful,
    probs_stateful,
    L,
    N=N,
    thresholds=[-0.0001, -0.01, -0.05, -0.1, -0.2, -0.25, -0.33, -0.5],
    diffs_type="all",
)
[round(v, 4) for v in modal_drop_stats_stateful.values()]

[4.704, 4.006, 2.592, 1.588, 0.594, 0.31, 0.122, 0.01]