In [41]:
# Mount Google Drive; the checkpoint paths used further down live under this
# mount point.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import torch
import matplotlib.pyplot as plt 
import numpy as np
from utils.cnn_duq import CNN_DUQ
from utils.datasets import all_datasets
from utils.cnn_duq import SoftmaxModel as CNN

from utils.resnet import ResNet
from utils.resnet_duq import ResNet_DUQ
from utils.evaluate_ood import get_cifar_svhn_ood, get_auroc_classification

# Benchmark selector. 'FMnist' pairs FashionMNIST with MNIST; every other value
# (including the default 'CIFAR10') takes the CIFAR10-vs-SVHN path.
mod = 'CIFAR10'  # ['CIFAR10','FMnist']

if mod != 'FMnist':
    # --- CIFAR10 (ds1) vs. SVHN (ds2), ResNet-based models -------------------
    ds1 = all_datasets["CIFAR10"]()
    ds2 = all_datasets["SVHN"]()
    input_size, num_classes, dataset, test_dataset = ds1

    length_scale = 0.1
    gamma = 0.999
    centroid_size = 512
    model_output_size = 512
    # Channel count and image side length, used later to reshape single images.
    c = 3
    d = 32

    model = ResNet_DUQ(
        input_size,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model.load_state_dict(torch.load('/content/gdrive/My Drive/Colab Notebooks/DUQ_CIFAR_75.pt'))

    # Five ensemble members, each moved to the GPU, wrapped in one ModuleList so
    # a single state dict restores all of them.
    ensemble = torch.nn.ModuleList(
        [ResNet(input_size, num_classes).cuda() for _ in range(5)]
    )
    ensemble.load_state_dict(torch.load('/content/gdrive/My Drive/Colab Notebooks/CIFAR10_5_ensemble.pt'))

else:
    # --- FashionMNIST (ds1) vs. MNIST (ds2), CNN-based models ----------------
    ds1 = all_datasets["FashionMNIST"]()
    ds2 = all_datasets["MNIST"]()

    input_size = 28
    num_classes = 10
    embedding_size = 256
    learnable_length_scale = False
    length_scale = 0.1
    gamma = 0.999
    # Channel count and image side length, used later to reshape single images.
    c = 1
    d = 28

    model = CNN_DUQ(
        input_size,
        num_classes,
        embedding_size,
        learnable_length_scale,
        length_scale,
        gamma,
    )
    # NOTE(review): both checkpoint loads below are disabled, so on this path the
    # models keep the weights their constructors give them.
    #model.load_state_dict(torch.load('/content/gdrive/My Drive/Colab Notebooks/DUQ_FM_30_FULL.pt'))

    ensemble = torch.nn.ModuleList(
        [CNN(input_size, num_classes).cuda() for _ in range(5)]
    )
    #ensemble.load_state_dict(torch.load('/content/gdrive/My Drive/Colab Notebooks/FM_5_ensemble_30.pt'))

# Switch both the DUQ model and the ensemble to evaluation mode.
model.eval()
ensemble.eval()

In [43]:
# Ensemble size; not read by any of the visible cells.
args = dict(ensemble=5)

# Each dataset entry unpacks into four values; the last one is taken as the
# evaluation split (presumably the test set -- confirm against utils.datasets).
input_size, num_classes, _, test_dataset_n = ds1
_, _, _, test_dataset_o = ds2


def _out_of_distribution_label(_original_label):
    """Replace any label of the second dataset with the constant 100.

    The evaluation cells compare this label with the predicted class index, so
    (assuming the models have fewer than 101 classes) these samples always
    count as wrong.
    """
    return 100


test_dataset_o.target_transform = _out_of_distribution_label

# For a quick run, a small interleaved subset can be built instead, e.g. by
# appending test_dataset_o[i] and test_dataset_n[i] for i in range(200, 500).
Data = test_dataset_n + test_dataset_o
num = len(Data)

# Fractions of the data to reject: 0.1, 0.2, ..., 0.9.
rejection_list = [tenth / 10 for tenth in range(1, 10)]


In [None]:
# --- DUQ: accuracy as a function of the rejected fraction --------------------
# For every sample in `Data` store the label, the DUQ prediction and the DUQ
# confidence (the largest entry of the model's second output, with its index as
# the predicted class). Samples are then sorted by confidence and, for each
# fraction in `rejection_list`, the least-confident part is dropped before
# measuring accuracy on what remains.
n_samples = len(Data)
target = np.zeros((n_samples,))
confidence_DUQ = np.zeros((n_samples,))
pred_DUQ = np.zeros((n_samples,))

# `c` (channels) and `d` (image side length) are deliberately NOT redefined
# here: they come from the configuration cell, so the reshape matches whichever
# dataset was selected there (3x32x32 for CIFAR10, 1x28x28 for FashionMNIST).
with torch.no_grad():
    for i in range(n_samples):
        # Index the dataset once; every lookup re-runs loading and transforms.
        sample = Data[i]
        image, label = sample[0], sample[1]

        _, output = model(image.reshape(1, c, d, d))
        best_score, best_class = output.max(1)

        target[i] = label
        # .item() makes the tensor -> Python scalar conversion explicit.
        confidence_DUQ[i] = best_score.item()
        pred_DUQ[i] = best_class.item()

        if i % 500 == 0:
            print(i)

# Columns: true label, predicted class, confidence. Rows sorted by ascending
# confidence, so the first rows are the ones rejected first.
results_DUQ = np.column_stack((target, pred_DUQ, confidence_DUQ))
results_DUQ = results_DUQ[results_DUQ[:, -1].argsort()]

accuracy_DUQ = np.zeros((len(rejection_list), 1))
rejected_DUQ = np.zeros((len(rejection_list), 1))
for k, reject in enumerate(rejection_list):
    kept = results_DUQ[int(reject * n_samples):]
    # Divide by the number of rows actually kept. The float (1 - reject) * n is
    # not the same number once int() has truncated the cut-off index.
    if len(kept):
        accuracy_DUQ[k] = (kept[:, 0] == kept[:, 1]).sum() / len(kept)
    else:
        accuracy_DUQ[k] = np.nan
    rejected_DUQ[k] = reject * 100



In [None]:
# --- Deep ensemble: accuracy as a function of the rejected fraction ----------
# Each member's output is exponentiated and the results are averaged
# (presumably the members return log-probabilities -- confirm against
# utils.resnet / utils.cnn_duq). The predicted class is the arg-max of that
# average; the confidence is sum(p * log p), i.e. the negative entropy of the
# averaged distribution, so larger values mean more confident.
n_samples = len(Data)
# Allocated here so this cell does not depend on the DUQ cell having run.
target = np.zeros((n_samples,))
confidence_DE = np.zeros((n_samples,))
pred_DE = np.zeros((n_samples,))

with torch.no_grad():
    for i in range(n_samples):
        # Index the dataset once; every lookup re-runs loading and transforms.
        sample = Data[i]
        image, label = sample[0], sample[1]

        # Move the image to the GPU once and share it between all members.
        batch = image.reshape(1, c, d, d).cuda()
        predictions = torch.stack([member(batch) for member in ensemble])
        mean_prediction = torch.mean(predictions.exp(), dim=0)

        # A probability that underflowed to exactly 0 would give 0 * log(0) =
        # NaN, and NaN sorts last (treated as most confident). Clamping only
        # the argument of log keeps that term at 0 instead.
        tiny = torch.finfo(mean_prediction.dtype).tiny
        neg_entropy = torch.sum(
            mean_prediction * torch.log(mean_prediction.clamp_min(tiny)), dim=1
        )

        target[i] = label
        # .item() copies the one-element CUDA tensors to host scalars before
        # they are written into the NumPy arrays.
        pred_DE[i] = mean_prediction.max(1)[1].item()
        confidence_DE[i] = neg_entropy.item()

        if i % 500 == 0:
            print(i)

# Columns: true label, predicted class, confidence. Rows sorted by ascending
# confidence, so the first rows are the ones rejected first.
results_DE = np.column_stack((target, pred_DE, confidence_DE))
results_DE = results_DE[results_DE[:, -1].argsort()]

accuracy_DE = np.zeros((len(rejection_list), 1))
rejected_DE = np.zeros((len(rejection_list), 1))
for k, reject in enumerate(rejection_list):
    kept = results_DE[int(reject * n_samples):]
    # Divide by the number of rows actually kept. The float (1 - reject) * n is
    # not the same number once int() has truncated the cut-off index.
    if len(kept):
        accuracy_DE[k] = (kept[:, 0] == kept[:, 1]).sum() / len(kept)
    else:
        accuracy_DE[k] = np.nan
    rejected_DE[k] = reject * 100

# Accuracy on the retained samples against the percentage rejected, one curve
# per uncertainty method.
ax = plt.gca()

curves = (
    (rejected_DUQ, accuracy_DUQ, 'blue', 'DUQ'),
    (rejected_DE, accuracy_DE, 'orange', '5-Deep Ensemble'),
)
for rejected_pct, accuracy, colour, name in curves:
    ax.plot(
        rejected_pct,
        accuracy,
        color=colour,
        linewidth=2,
        marker='o',
        markerfacecolor=colour,
        markersize=5,
        label=name,
    )

ax.set_ylim(0.4, 1.01)
ax.set_xlim(0, 100)

# NOTE(review): "uncertainity" in the label below is a typo for "uncertainty".
ax.set_xlabel('Percent of data rejected by uncertainity')
ax.set_ylabel('Accuracy')

ax.legend()

plt.show()