In [3]:
'''
all using pitfalls data

cifar100 and cifar10
'''

import os
import sys
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import copy
import time
import platform

from analysis.UN_utils         import w_l, \
                                      v_project2
from analysis.util_calibration import ts_calibrate, \
                                      ets_calibrate, \
                                      mir_calibrate, \
                                      irova_calibrate
from analysis.util_evaluation  import ece_eval_binary, \
                                      ece_hist_binary

In [4]:
#%% ========== ========== ========== ========== ========== ========== ========== ==========
# Load one cached ensemble-output file from the "pitfalls" data directory
# and turn it into per-member class probabilities.
from scipy.special import logsumexp, log_softmax, softmax

data_dir    = './pitfalldee/megacache/'

''' 
file list

----- CIFAR100 -----
'logits_CIFAR100-PreResNet110-deepens-100-1.pkl.npy'
'logits_CIFAR100-VGG16BN-deepens-100-1.pkl.npy'
'logits_CIFAR100-PreResNet164-deepens-100-1.pkl.npy'
'logits_CIFAR100-WideResNet28x10-deepens-100-1.pkl.npy'
'logits_CIFAR100-myResNet18-csgld-100-1.pkl.npy'

----- CIFAR10 -----
'logits_CIFAR10-PreResNet110-deepens-100-1.pkl.npy'
'logits_CIFAR10-VGG16BN-deepens-100-1.pkl.npy'
'logits_CIFAR10-PreResNet164-deepens-100-1.pkl.npy'
'logits_CIFAR10-WideResNet28x10-deepens-100-1.pkl.npy'

'''

# Pick exactly one entry of the file list above.
logits_file = 'logits_CIFAR10-PreResNet110-deepens-100-1.pkl.npy'

# NOTE: despite the variable name, `log_probs` holds the raw network output
# stored in the file, not log-probabilities.
logits_path = os.path.join(data_dir, logits_file)
log_probs   = np.load(logits_path)
print(log_probs.shape)

# Softmax over the last axis; `log_probs_p` therefore holds probabilities.
# The whole array is converted at once because the data is small.
log_probs_p = softmax(log_probs, axis=-1)
print(log_probs_p.shape)

# Sanity check: the first 20 rows of the first leading-axis entry should each sum to ~1.
first_rows = log_probs_p[0, :20, :]
print(first_rows.sum(axis=-1)) # check prob sum


(100, 10000, 10)
(100, 10000, 10)
[1.0000001  0.99999917 0.99999976 0.99999946 0.9999999  1.0000006
 1.0000004  1.         0.99999917 0.99999994 0.99999976 1.0000002
 0.9999992  0.9999999  1.0000005  0.9999997  0.99999946 1.0000005
 0.9999999  1.0000004 ]


In [5]:
#%% load label
# Infer the dataset (and its number of classes) from the logits file name.
# The branches are exclusive and an unknown name fails loudly: otherwise a
# reused kernel would silently keep `dataset` / `n_class` from an earlier run
# and pair the new logits with the wrong label file.
if 'CIFAR100-' in logits_file:
    dataset     = 'cifar100'
    num_feature = 100
    n_class     = 100
elif 'CIFAR10-' in logits_file:
    dataset     = 'cifar10'
    num_feature = 10
    n_class     = 10
else:
    raise ValueError(
        'cannot infer dataset from logits file name: {!r} '
        "(expected it to contain 'CIFAR10-' or 'CIFAR100-')".format(logits_file)
    )

fname = './labels/' + dataset + '-labels.pk'
print('> reading label file from:', fname)
# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
# only point this at label files you created yourself or otherwise trust.
with open(fname, 'rb') as f:
    # The file holds two consecutive pickles: train labels first, then test labels.
    train_label = pickle.load(f)
    test_label = pickle.load(f)

# Integer dtype is required below, where the labels are used as row indices.
test_label = test_label.astype(int)
train_label = train_label.astype(int)

# One-hot encode the test labels by selecting rows of the identity matrix.
label_total = np.eye(n_class)[test_label]

print(label_total.shape)


> reading label file from: ./labels/cifar10-labels.pk
(10000, 10)


In [6]:
#%% cut in pipeline
# For a growing number of ensemble members ("cut" C) compare three predictions:
#   1. the plain mean of the member probabilities,
#   2. the iteratively re-weighted estimate,
#   3. the re-weighted estimate after the accuracy-preserving projection.
# Each comparison appends one two-line report ("header\nvalues") to `lines`.

N_REFINE_ITER = 5   # re-weighting rounds applied after the first weighted estimate
CUTS          = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]


def _refine_estimate(values, n_iter=N_REFINE_ITER):
    """Return (plain mean, re-weighted estimate) for one sample.

    values -- 2-D slice log_probs_p[:C, s, :]: one probability vector per
              ensemble member for sample `s`.
    n_iter -- number of additional re-weighting rounds.

    The weights come from the project helper `w_l(values, estimate)`; each
    round recomputes them from the current estimate and then recombines the
    member vectors with `np.dot(values.T, w)`.
    """
    v_mean = np.mean(values, axis=0)    # 1st estimation
    w      = w_l(values, v_mean)
    v_     = np.dot(values.T, w)        # 2nd estimation
    for _ in range(n_iter):
        w  = w_l(values, v_)
        v_ = np.dot(values.T, w)
    return v_mean, v_


def _evaluate(probs, labels):
    """Run both project evaluators and return (ece, nll, mse, accu, hist)."""
    ece, nll, mse, accu = ece_eval_binary(probs, labels)
    hist_ece = ece_hist_binary(probs, labels)
    return ece, nll, mse, accu, hist_ece.numpy()[0]


lines = []

for C in CUTS:
    print('> cut at:', C)
    members   = log_probs_p[:C,:,:]
    log_prob3 = np.mean(members, axis=0)
    print(members.shape)

    # Single pass over the samples. The re-weighted estimate is identical for
    # the "with truth" and the "acc preserving" variants, so it is computed
    # once and stored twice (the original recomputed it in a second loop).
    # NOTE(review): this assumes w_l is deterministic and does not modify its
    # inputs -- confirm against analysis.UN_utils.
    log_tD_raw  = np.zeros_like(log_prob3)   # re-weighted estimate as computed
    log_tD      = np.zeros_like(log_prob3)   # re-weighted estimate, accuracy preserving
    acc_changed = 0
    for s in range(log_prob3.shape[0]):
        v_mean, v_ = _refine_estimate(members[:, s, :])
        # Store the raw estimate before the projection below can touch `v_`.
        log_tD_raw[s,:] = v_

        # acc preserving step: if re-weighting changed the predicted class
        # relative to the plain mean, hand the vector to v_project2 together
        # with the mean's class index (presumably to restore that class as
        # the argmax -- see analysis.UN_utils).
        ind_mean = np.argmax(v_mean)
        if np.argmax(v_) != ind_mean:
            acc_changed += 1
            v_ = v_project2(v_, ind_mean, C=n_class)
        log_tD[s,:] = v_

    #% % without truth ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # mean of probabilities:
    # note: np.mean(log_probs_p, axis=0) <=> log_prob3
    ece, nll, mse, accu, hist = _evaluate(log_prob3, label_total)
    line = '>>without truth<< ece | nll | mse | accu | hist \n{:.7f} {:.7f} {:.7f} {:.7f} {:.7f}'.format(ece, nll, mse, accu, hist)
    print('\n'+line)
    lines.append(line)

    #% % after truth ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ece, nll, mse, accu, hist = _evaluate(log_tD_raw, label_total)
    line = '>>with truth<< ece | nll | mse | accu | hist \n{:.7f} {:.7f} {:.7f} {:.7f} {:.7f}'.format(ece, nll, mse, accu, hist)
    print('\n'+line)
    lines.append(line)

    print('> acc changed:', acc_changed)

    #% % after acc preserving ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The header used to repeat '>>with truth<<' and listed five columns for
    # six values; it now names this variant and the extra column. The numeric
    # row (second line of `line`) is unchanged.
    ece, nll, mse, accu, hist = _evaluate(log_tD, label_total)
    line = '>>acc preserving<< ece | nll | mse | accu | hist | acc_changed \n{:.7f} {:.7f} {:.7f} {:.7f} {:.7f} {}'.format(ece, nll, mse, accu, hist, acc_changed)
    print('\n'+line)
    lines.append(line)

> cut at: 10
(10, 10000, 10)

>>without truth<< ece | nll | mse | accu | hist 
0.0109697 0.1131105 0.0540911 0.9628000 0.0050049

>>with truth<< ece | nll | mse | accu | hist 
0.0137652 0.1187795 0.0556528 0.9632000 0.0134574
> acc changed: 22

>>with truth<< ece | nll | mse | accu | hist 
0.0142573 0.1188241 0.0556964 0.9628000 0.0137872 22
> cut at: 20
(20, 10000, 10)

>>without truth<< ece | nll | mse | accu | hist 
0.0124757 0.1098123 0.0529854 0.9642000 0.0044891

>>with truth<< ece | nll | mse | accu | hist 
0.0122656 0.1123410 0.0536506 0.9639000 0.0099984
> acc changed: 9

>>with truth<< ece | nll | mse | accu | hist 
0.0121165 0.1123311 0.0536444 0.9642000 0.0096817 9
> cut at: 30
(30, 10000, 10)

>>without truth<< ece | nll | mse | accu | hist 
0.0126912 0.1074398 0.0523041 0.9642000 0.0042738

>>with truth<< ece | nll | mse | accu | hist 
0.0114973 0.1094113 0.0526672 0.9642000 0.0083415
> acc changed: 13

>>with truth<< ece | nll | mse | accu | hist 
0.0119043 0.1094062 0.0

In [7]:
#%%
# Each entry of `lines` is "header\nvalues"; print only the values row so the
# numbers can be copied out as a plain table.
for report in lines:
    header_and_values = report.split('\n')
    print(header_and_values[1])

0.0109697 0.1131105 0.0540911 0.9628000 0.0050049
0.0137652 0.1187795 0.0556528 0.9632000 0.0134574
0.0142573 0.1188241 0.0556964 0.9628000 0.0137872 22
0.0124757 0.1098123 0.0529854 0.9642000 0.0044891
0.0122656 0.1123410 0.0536506 0.9639000 0.0099984
0.0121165 0.1123311 0.0536444 0.9642000 0.0096817 9
0.0126912 0.1074398 0.0523041 0.9642000 0.0042738
0.0114973 0.1094113 0.0526672 0.9642000 0.0083415
0.0119043 0.1094062 0.0526627 0.9642000 0.0083316 13
0.0132586 0.1070942 0.0521342 0.9641000 0.0041239
0.0108600 0.1085312 0.0524417 0.9639000 0.0078179
0.0108953 0.1085259 0.0524374 0.9641000 0.0076104 9
0.0131082 0.1060188 0.0519437 0.9638000 0.0048223
0.0105708 0.1074706 0.0522784 0.9638000 0.0075961
0.0107360 0.1074718 0.0522798 0.9638000 0.0075940 5
0.0131180 0.1058330 0.0519794 0.9645000 0.0037859
0.0105098 0.1070941 0.0522778 0.9640000 0.0070593
0.0109282 0.1070901 0.0522739 0.9645000 0.0065575 5
0.0123611 0.1056936 0.0518050 0.9642000 0.0038555
0.0103756 0.1068218 0.0520204 0.9638