# Libs and pre-definitions

In [1]:
import os
# Root of the local repository checkout. The default is the original author's
# machine; set the UNCEST_REPOSITORY_PATH environment variable to run the
# notebook elsewhere without editing this cell.
REPOSITORY_PATH = os.environ.get('UNCEST_REPOSITORY_PATH') or r'/home/luis-felipe/UncEst'

# Dataset directory (passed as `data_dir` to the dataset constructor).
DATA_PATH = os.path.join(REPOSITORY_PATH, 'data')

# Saved model weights, and the trainer sub-directory inside it.
PATH_MODELS = os.path.join(REPOSITORY_PATH, 'torch_models')
PATH_TRAINER = os.path.join(PATH_MODELS, 'trainer')

# Directory for the figures of this evaluation.
PATH_FIGS = os.path.join(REPOSITORY_PATH, 'figs', 'EvalMIMO')

### Standard Python libraries and PyTorch utilities

In [2]:
import random
from collections import defaultdict

import matplotlib
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt

In [4]:
# Select the compute device: CUDA (GPU) when available, otherwise the CPU.
# The availability flag is printed so the cell output records which one was used.
print(torch.cuda.is_available())
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

True


### Project libraries

https://github.com/lfpc/Uncertainty_Estimation

In [5]:
import NN_models as models
import NN_utils as utils
import NN_utils.train_and_eval as TE
import torch_data

In [40]:
import uncertainty.metrics as metrics
from uncertainty.ensemble import Ensemble
from uncertainty.mimo import MIMO_ensemble

# Analysis

## Definitions

In [7]:
# Architecture and dataset under evaluation. Both names are used to look up the
# corresponding classes by name and to build checkpoint paths.
MODEL_ARC = 'VGG_16'
DATA = 'Cifar100'

# Resolve the dataset wrapper and the model class from their names.
data = vars(torch_data)[DATA](data_dir=DATA_PATH)
model_class = vars(models)[MODEL_ARC]

# MIMO checkpoints live in <PATH_MODELS>/<architecture>/<dataset>/MIMO.
weights_path = os.path.join(PATH_MODELS, MODEL_ARC, DATA, 'MIMO')
# Checkpoint selector handed to upload_weights for the M = 1 model.
# (Name is misspelled but referenced by other cells, so it is kept as is.)
WEIGTH_FILE = 0

Files already downloaded and verified
Files already downloaded and verified


## Load Models

In [None]:
def weights_files(weights_path):
    """Return the sorted names of the ``.pt`` files directly inside ``weights_path``.

    Only regular files whose name ends with ".pt" are kept; directories are
    skipped and sub-directories are not searched. The returned names do not
    include the directory prefix.
    """
    checkpoint_names = []
    for entry in os.listdir(weights_path):
        if not os.path.isfile(os.path.join(weights_path, entry)):
            continue
        if entry.endswith(".pt"):
            checkpoint_names.append(entry)
    checkpoint_names.sort()
    return checkpoint_names

def upload_weights(file, weights_path, map_location=None):
    """Load a saved checkpoint from ``weights_path`` with ``torch.load``.

    Args:
        file: Selects which ``.pt`` file to load:
            * ``int`` - index into the sorted list given by ``weights_files``;
            * ``'random'`` - one of the files, chosen at random;
            * ``'max'`` - currently an alias for index 0 (first file in
              sorted order);
            * any other ``str`` - a file name, with or without the ``.pt``
              suffix.
        weights_path: Directory that contains the ``.pt`` files.
        map_location: Forwarded unchanged to ``torch.load``. The default
            (``None``) keeps the previous behaviour; pass a device to remap
            the stored tensors while loading.

    Returns:
        Whatever ``torch.load`` returns for the selected file (callers in this
        notebook pass it to ``load_state_dict``).

    Raises:
        IndexError: ``file`` is an out-of-range index, or ``'random'`` was
            requested and the directory holds no ``.pt`` file.
        FileNotFoundError: ``file`` is a name that matches no ``.pt`` file.
        TypeError: ``file`` is neither an ``int`` nor a ``str``.
    """
    files = weights_files(weights_path)

    if isinstance(file, int):
        try:
            chosen = files[file]
        except IndexError:
            raise IndexError(
                f"Checkpoint index {file} is out of range: "
                f"{len(files)} '.pt' file(s) found in {weights_path!r}"
            ) from None
    elif file == 'random':
        if not files:
            # Same exception type random.choice raises on an empty sequence,
            # but with a message that points at the directory.
            raise IndexError(f"No '.pt' files found in {weights_path!r}")
        chosen = random.choice(files)
    elif file == 'max':
        return upload_weights(0, weights_path, map_location=map_location)
    elif isinstance(file, str):
        if file in files:
            chosen = file
        elif file + '.pt' in files:
            chosen = file + '.pt'
        else:
            raise FileNotFoundError(
                f"No file named {file!r} (or {file + '.pt'!r}) in {weights_path!r}"
            )
    else:
        # Previously this fell through to torch.load with an unbound variable.
        raise TypeError(
            f"file must be an int index or a str selector, got {type(file).__name__}"
        )

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(os.path.join(weights_path, chosen), map_location=map_location)
    return state_dict


In [None]:
# Loaded models, keyed by the ensemble size M.
models_dict = dict()

### M = 1

In [None]:
# Single-model baseline: treated as an "ensemble" of size 1.
N_ENS = 1

# Model name follows the pattern MIMO_<M>_<architecture>_<dataset>.
NAME = f'MIMO_{N_ENS}_{MODEL_ARC}_{DATA}'

In [None]:
# Build the plain (non-MIMO) network and move it to the selected device.
model = model_class(num_classes=data.n_classes, name=NAME, softmax=True)
model = model.to(dev)

# The M = 1 checkpoint is read from the parent directory of the MIMO weights folder.
model.load_state_dict(
    upload_weights(WEIGTH_FILE, os.path.dirname(weights_path))
)
model.eval()
models_dict[N_ENS] = model

# Test accuracy and the corresponding risk. The 0.01 factor implies that
# model_acc reports a percentage - TODO confirm against TE.model_acc.
acc_model = TE.model_acc(model, data.test_dataloader)
rk_model = 1 - 0.01 * acc_model
print(f'Acurácia M={N_ENS} : {acc_model}')

### M = 2 - 6

In [None]:
# Load the MIMO ensembles with M = 2..6 members, register each one in
# models_dict and print its test accuracy.
for N_ENS in range(2, 7):
    NAME = f'MIMO_{N_ENS}_{MODEL_ARC}_{DATA}'
    model = MIMO_ensemble(
        model_class,
        data.n_classes,
        n_ensembles=N_ENS,
        name=NAME,
        softmax=True,
    )
    model = model.to(dev)
    # The checkpoint is looked up by model name inside the MIMO weights folder.
    model.load_state_dict(upload_weights(NAME, weights_path))
    model.eval()
    models_dict[N_ENS] = model
    print(f'Acurácia M={N_ENS} : {TE.model_acc(model,data.test_dataloader)}')

## Plots

In [None]:
# Selective-classification metrics for every loaded model, keyed by the
# ensemble size M.
RC_dict = {}
for N_ENS, model in models_dict.items():
    # Rebuild the name for this specific model. Reading the global NAME here
    # would label every entry with whatever name the last model-loading cell
    # left behind (the largest M), regardless of the model being evaluated.
    NAME = f'MIMO_{N_ENS}_{MODEL_ARC}_{DATA}'
    RC = metrics.selective_metrics(model, data.test_dataloader, name=NAME)
    RC_dict[N_ENS] = RC

#### Risk x Coverage

In [35]:
# One figure per ensemble size, drawn by plot_ROC_and_RC and titled with M.
for N_ENS in RC_dict:
    RC = RC_dict[N_ENS]
    # NAME is refreshed for the current M; it is not read inside this cell.
    NAME = f'MIMO_{N_ENS}_{MODEL_ARC}_{DATA}'
    RC.plot_ROC_and_RC(aurc=True)
    plt.suptitle(f'M = {N_ENS}')
    plt.show()

72.16

#### Risk x M

In [None]:
# Uncertainty measures to compare: one subplot per measure, one curve per M.
uncs = ['MCP', 'MI', 'Best']
f, axes = plt.subplots(nrows=1, ncols=len(uncs), sharey=True, figsize=(14, 6), dpi=80)

for N_ENS, RC in RC_dict.items():
    for i, unc in enumerate(uncs):
        if unc in RC.risk.keys():
            axes[i].plot(RC.c_list, RC.risk[unc], label=f'M = {N_ENS}')
            continue
        if unc == 'Best':
            # 'Best' is not a key of RC.risk; its curve comes from get_best().
            axes[i].plot(RC.c_list, RC.get_best(), label=f'M = {N_ENS}')
        # Any other measure missing from RC.risk is skipped for this M.
    NAME = f'MIMO_{N_ENS}_{MODEL_ARC}_{DATA}'

# Axis decoration. Font sizes are read from RC, i.e. the last object iterated above.
for i, unc in enumerate(uncs):
    ax = axes[i]
    ax.set_title(unc)
    ax.set_xlabel("Coverage", fontsize=RC.LABEL_FONTSIZE*0.7)
    ax.tick_params(axis="x", labelsize=RC.TICKS_FONTSIZE)
    ax.grid()
    ax.legend()
# The y axis is shared, so only the left-most subplot gets a label.
axes[0].set_ylabel("Risk", fontsize=RC.LABEL_FONTSIZE*0.7)
plt.suptitle('Risk x Coverage')
plt.show()