In [1]:
import os
import sys
import torch
import random
import argparse
from torch import nn
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn

import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
# from scipy.special import softmax
import datetime
import argparse


import data.cifar10 as cifar10
import data.cifar100 as cifar100
import calibration as cal
import calibration.metric as metric
import calibration.tace as tace
from xautodl.datasets.get_dataset_with_transform import get_datasets
from torch.utils.data import DataLoader
# import calibration.ece_kde as ece_kde
import inspect

from calibration.temperature_scaling import ModelWithTemperature
from calibration.temp_scale import accuracy
from torch.utils.data.sampler import SubsetRandomSampler

from Net.resnet_tiny_imagenet import resnet50 as resnet50_ti
from Net.resnet import resnet18,resnet34,resnet50, resnet110
from Net.wide_resnet import wide_resnet_cifar
from Net.densenet import densenet121

import os

# Import metrics to compute

In [2]:
def get_preds_and_targets(model, dataloader, device):
    preds, pred_classes, targets = [], [], []

    model.eval()  # Set model to evaluation mode
    model.to(device)  # Move model to the selected device (CPU or GPU)

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output_tuple = model(data)

            output = output_tuple[1]

            prob = F.softmax(output, dim = 1)  # Compute probabilities
            _, pred = torch.max(prob, 1)  # Get predicted class

            preds.extend(prob.cpu().numpy())  # Move probabilities to CPU and convert to numpy array
            pred_classes.extend(pred.cpu().numpy())  # Move predictions to CPU and convert to numpy array
            targets.extend(target.cpu().numpy())  # Move targets to CPU and convert to numpy array

    return np.array(preds), np.array(pred_classes), np.array(targets)

def get_preds_and_targets2(model, dataloader, device):
    #use this when using temp scale since the network output is logits not tuple 

    preds, pred_classes, targets = [], [], []

    model.eval()  
    model.to(device)

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output_tuple = model(data)

            output = output_tuple
            
            prob = F.softmax(output, dim=1)
            _, pred = torch.max(prob, 1)

            preds.extend(prob.cpu().numpy())  # Move probabilities to CPU and convert to numpy array
            targets.extend(target.cpu().numpy())  # Move targets to CPU and convert to numpy array

    return np.array(preds), np.array(pred_classes), np.array(targets)

def get_param_dict(func, *args, **kwargs):
    result = func(*args, **kwargs)
    
    # Get the function's signature and parameters
    signature = inspect.signature(func)
    params = signature.parameters

    # Create a dictionary with default parameter values
    default_params = {k: v.default for k, v in params.items() if v.default != inspect.Parameter.empty}

    # Update the default parameter values with the provided kwargs
    all_params = {**default_params, **kwargs}
    
    all_params['result'] = result
    return all_params

In [3]:
def get_logits_labels(data_loader, net,device):
    logits_list = []
    labels_list = []
    net.eval()
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            logits = net(data)
            logits_list.append(logits)
            labels_list.append(label)
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)
    return logits, labels

In [4]:
models = {
    'resnet18': resnet18
    # ,
    # 'resnet34': resnet34,
    # 'resnet50': resnet50,
    # 'resnet110': resnet110,
    # 'wide_resnet': wide_resnet_cifar,
    # 'densenet121': densenet121
}

dataset_num_classes = {
    'cifar10': 10,
    'cifar100': 100,
    'ImageNet16-120': 120
}

if __name__ == "__main__":
    # Argument parser setup
    # parser = argparse.ArgumentParser(description="Model evaluation and result storage")
    # parser.add_argument("--csv_file", type=str, default="results.csv", help="CSV file name to store or update results")
    # parser.add_argument("--bin_sizes", type=int, default=[5, 10,15,20,25,50,100,200,500], help="Number of bins for calibration metrics")

    # parser.add_argument("--image_dataset", type=str, default="cifar10", help="CIFAR-10, CIFAR-100, and ImageNet16-120")
    # parser.add_argument("--post_temp", type=str, default='False', help="if using temp scale")
    # parser.add_argument("--device", type=str, default='cuda:0', help="device")
    # parser.add_argument("--dataset", type=str, default='cifar10',
    #                     dest="dataset", help='dataset to test on')
    # parser.add_argument("--save-path", type=str, default=save_loc,
    #                     dest="save_loc",
    #                     help='Path to import the model')
    # parser.add_argument("--saved_model_name", type=str, default=saved_model_name,
    #                     dest="saved_model_name", help="file name of the pre-trained model")


    # args = parser.parse_args()

    # Use the parsed arguments
    csv_file = "results.csv"
    bin_sizes = [5, 10, 15, 20, 25, 50, 100, 200, 500]
    image_dataset = "cifar10"
    post_temp = 'False'
    device = 'cuda:2' if torch.cuda.is_available() else "cpu"
    dataset = 'cifar100'
    save_loc = './'
    saved_model_name = 'resnet18_cross_entropy.model'

    num_classes = dataset_num_classes[dataset]

    model_dir = "/home/younan/project_calibration/project_calibration/MODEL_DIRECTORY/"

    for model_name in models:
        model = models[model_name]

        net = model(num_classes=num_classes, temp=1.0)
        # net.cuda()
        net.to(device)
        # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True
        # Load the state dict from the file
        state_dict = torch.load(model_dir + saved_model_name)

        # Remove the 'module.' prefix
        new_state_dict = {k[len("module."):]: v for k, v in state_dict.items()}

        # Load the adjusted state dict into your model
        net.load_state_dict(new_state_dict)

In [5]:
post_temp = 'True'

if dataset == 'cifar10':
    if post_temp == 'True':
        test_loader, val_loader = cifar10.get_test_valid_loader(batch_size = 256,
                            random_seed = 42,
                            valid_size=0.2,
                            shuffle=True,
                            num_workers=4, pin_memory=False)
    else:
        test_loader = cifar10.get_test_loader(batch_size=256, shuffle=False, num_workers=4, pin_memory=False)
elif dataset == 'cifar100':
    if post_temp == 'True':
        test_loader, val_loader = cifar100.get_test_valid_loader(batch_size = 256,
                            random_seed = 42,
                            valid_size=0.2,
                            shuffle=True,
                            num_workers=4, pin_memory=False)
    else:
        test_loader = cifar100.get_test_loader(batch_size=256, shuffle=False, num_workers=4, pin_memory=False)
elif dataset == 'ImageNet16-120':
    

    root = './datasets/ImagenNet16'
    train_data, test_data, xshape, class_num = get_datasets(image_dataset, root, 0)

    if post_temp == 'True':
        def imagenet_get_test_valid_loader(batch_size = 256, random_seed= 42, valid_size = 0.2, shuffle = True,
                                    num_workers=4, pin_memory=False,
                                test_dataset=test_data):
            num_test = len(test_dataset)
            indices = list(range(num_test))
            split = int(np.floor(valid_size * num_test))

            if shuffle:
                np.random.seed(random_seed)
                np.random.shuffle(indices)

            test_idx, valid_idx = indices[split:], indices[:split]
            

            test_sampler = SubsetRandomSampler(test_idx)
            valid_sampler = SubsetRandomSampler(valid_idx)

            test_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=batch_size, sampler=test_sampler,
                num_workers=num_workers, pin_memory=pin_memory,
            )
            valid_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=batch_size, sampler=valid_sampler,
                num_workers=num_workers, pin_memory=pin_memory,
            )
            return test_loader, valid_loader
        test_loader, val_loader = imagenet_get_test_valid_loader(batch_size = 256, random_seed= 42, valid_size = 0.2, shuffle = True,
                                num_workers=4, pin_memory=False)
    else:
        test_loader = DataLoader(test_data, batch_size=256, shuffle=False)

Files already downloaded and verified


In [6]:
logits, labels = get_logits_labels(test_loader, net,device)

In [7]:
logits, labels = get_logits_labels(val_loader, net,device)

In [8]:
logits

tensor([[-1.2115, -3.0624,  1.0376,  ..., -0.1389,  0.9347,  1.0207],
        [ 0.2889,  2.0660, -2.4687,  ..., -4.0581, -3.8591, -1.1131],
        [-1.6086, -1.1720,  0.4942,  ...,  1.8165, -0.1948,  0.4138],
        ...,
        [-2.2142,  1.6175,  0.2877,  ...,  0.6096,  0.6216,  0.9865],
        [-2.4944, -1.3995, -1.6311,  ..., -0.0470, -2.7177,  0.5685],
        [-0.7392,  1.6196, -0.4817,  ...,  0.9357, -0.3500, -0.2709]],
       device='cuda:2')

In [9]:

if post_temp == 'True':
    val_probs, val_pred_classes, val_targets = get_preds_and_targets2(net, val_loader, device)
    test_probs, test_pred_classes, test_targets = get_preds_and_targets2(net, test_loader, device)

    scaled_model = ModelWithTemperature(net)
    scaled_model.set_temperature(val_loader,device=device)

    preds, pred_classes,targets  = get_preds_and_targets2(scaled_model, test_loader, device)
else:

    preds, pred_classes,targets = get_preds_and_targets2(net, test_loader, device)

Before temperature - NLL: 3.604, ECE: 0.040
Optimal temperature: 1.200
After temperature - NLL: 3.588, ECE: 0.026


In [10]:
ece_str = ''
sce_str = ''
tace_str = ''
ace_str = ''
mce_str = ''
cwECE_str = ''
ECE_em_str = ''
ole_str = ''
ole_loss = tace.OELoss()

for n_bin in bin_sizes:
    ece_str += str(get_param_dict(metric.get_ece, preds, targets, n_bins=n_bin)) + ', '
    sce_str += str(get_param_dict(metric.get_sce,preds, targets, n_bins=n_bin,logits = False)) + ', '
    tace_str += str(get_param_dict(metric.get_tace,preds, targets, n_bins=n_bin,logits = False)) + ', '
    ace_str += str(get_param_dict(metric.get_ace,preds, targets, n_bins=n_bin,logits = False)) + ', '
    mce_str += str(get_param_dict(metric.get_mce, preds, targets, n_bins=n_bin)) + ', '
    cwECE_str += str(get_param_dict(metric.get_classwise_ece, preds, targets, n_bins=n_bin)) + ', '
    ECE_em_str += str(get_param_dict(cal.get_ece_em, preds, targets, num_bins=n_bin)) + ', '
    ole_str = str(get_param_dict(ole_loss.loss, preds, targets, n_bins=n_bin,logits = False)) + ', '

# Remove the trailing comma and space
ece_str = ece_str.rstrip(', ')
sce_str = sce_str.rstrip(', ')
tace_str = tace_str.rstrip(', ')
ace_str = ace_str.rstrip(', ')
mce_str = mce_str.rstrip(', ')
cwECE_str = cwECE_str.rstrip(', ')
ECE_em_str = ECE_em_str.rstrip(', ')
ole_str = ole_str.rstrip(', ')

data = {
    'config': [model],
    'info' : [accuracy(preds,targets)],
    'dataset': [dataset],
    'ece': ece_str,
    'sce': sce_str,
    'tace': tace_str,
    'ace': ace_str,
    'MCE': mce_str,
    'cwECE': cwECE_str,
    'Marginal_CE_debias': [get_param_dict(cal.get_calibration_error,preds, targets,p=1)],
    'Marginal_CE': [get_param_dict(cal.get_calibration_error,preds, targets, debias=False,p=1)],
    'ECE_em': ECE_em_str,
    'Ole': ole_str,
    'KSCE': [get_param_dict(metric.get_KSCE,preds, targets)],
    'KDECE': [get_param_dict(metric.get_KDECE,preds, targets)],
    'MMCE': [get_param_dict(metric.get_MMCE,preds, targets)],
    'NLL': [get_param_dict(metric.get_nll,preds, targets)],
    'brier': [get_param_dict(metric.get_brierscore,preds, targets)],
    # 'ECE_KDE': [get_param_dict(ece_kde.get_ece_kde(tensor_preds, tensor_targets, bandwidth=bandwidth, p=1, mc_type='canonical', device=device).item())],
    'timestamp': [datetime.datetime.now()]
}

# print(data)
# Step 2: Convert the dictionary into a DataFrame
df = pd.DataFrame(data)
print(df.head(1))

                                  config     info   dataset  \
0  <function resnet18 at 0x7fa1280b2b80>  0.13625  cifar100   

                                                 ece  \
0  {'n_bins': 5, 'result': 0.009811824304051702},...   

                                                 sce  \
0  {'n_bins': 5, 'logits': False, 'result': 0.003...   

                                                tace  \
0  {'n_bins': 5, 'threshold': 0.001, 'logits': Fa...   

                                                 ace  \
0  {'n_bins': 5, 'logits': False, 'result': 0.003...   

                                                 MCE  \
0  {'n_bins': 5, 'result': 0.19718930721282957}, ...   

                                               cwECE  \
0  {'n_bins': 5, 'result': 0.00337140131123364}, ...   

                                  Marginal_CE_debias  \
0  {'p': 1, 'debias': True, 'mode': 'marginal', '...   

                                         Marginal_CE  \
0  {'p': 1, 'debias': Fals