# Exploring Abstention Loss
author: Elizabeth A. Barnes, Randal J. Barnes
date: February 23, 2021

* based on Thulasidasan, S., T. Bhattacharya, J. Bilmes, G. Chennupati, and J. Mohd-Yusof, 2019: Combating Label Noise in Deep Learning Using Abstention. arXiv [stat.ML].
* thesis: https://digital.lib.washington.edu/researchworks/handle/1773/45781
* code base is here: https://github.com/thulas/dac-label-noise/blob/master/dac_loss.py

In [1]:
import numpy as np
import numpy.ma as ma
import pandas as pd
import random
import xarray as xr
import scipy.stats as stats
import random
import time
import sys
from collections import Counter
import os.path
from os import path

import sklearn
from sklearn.model_selection import train_test_split
from sklearn import preprocessing

import tensorflow as tf
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras import regularizers
from tensorflow.keras import metrics
from tensorflow.keras import optimizers
from tensorflow.keras.losses import Loss
from tensorflow.keras.models import Sequential

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns
import palettable
import pprint

import metrics
import climatedata
import plots
import network
import experiments

import imp
import warnings

# Re-import the project modules so edits made on disk are picked up without
# restarting the kernel.
imp.reload(experiments)
imp.reload(plots)

# Figure defaults: white background, 150 dpi on screen, 300 dpi for saved files.
mpl.rcParams['figure.facecolor'] = 'white'
mpl.rcParams['figure.dpi']= 150
dpiFig = 300.

# `np.warnings` was only an incidental re-export of the standard-library module
# and no longer exists in NumPy >= 1.24; calling the stdlib module directly
# installs exactly the same filter on every NumPy version.
warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)

# Record the interpreter and TensorFlow versions alongside the results.
tf.print(f"sys.version = {sys.version}", output_stream=sys.stdout)
tf.print(f"tf.version.VERSION = {tf.version.VERSION}", output_stream=sys.stdout)

sys.version = 3.7.9 (default, Aug 31 2020, 07:22:35) 
[Clang 10.0.0 ]
tf.version.VERSION = 2.4.1


In [2]:
# DATA_NAME = 'badClasses0' #done
# DATA_NAME = 'badClasses1' # done
# DATA_NAME = 'mixedLabels2' #done
# DATA_NAME = 'mixedLabels3' # done
# DATA_NAME = 'tranquilFOO10' #done
# DATA_NAME = 'tranquilFOO12' # done
# DATA_NAME = 'tranquilFOO17' #done
# DATA_NAME = 'tranquilFOO18' # done
# DATA_NAME = 'tranquilFOO19' #done 
# DATA_NAME = 'tranquilFOO20' #done
# DATA_NAME = 'tranquilFOO22' #done
# Name of the experiment to analyse; it keys the settings returned by
# experiments.define_experiments and is used in model/figure file names below.
DATA_NAME = 'tranquilFOO23'

# Look up the settings for this experiment and echo them for the record.
EXPINFO = experiments.define_experiments(DATA_NAME)
pprint.pprint(EXPINFO, width=60)

{'batch_size': 32,
 'cutoff': 0.5,
 'foo_region': 'nhENSO',
 'hiddens': [50, 25],
 'loss': 'NotWrongLoss',
 'lr_init': 0.001,
 'nSamples': 18000,
 'np_seed': 99,
 'numClasses': 50,
 'nupd': 6,
 'prNoise': 1.0,
 'simple_data': '15x60',
 'spinup': 0,
 'undersample': False,
 'updater': 'Colorado'}


In [3]:
# One shared constant seeds both random number generators used in the notebook.
NP_SEED = 99
tf.random.set_seed(NP_SEED)
np.random.seed(NP_SEED)

In [4]:
def in_ipynb():
    """Report whether the code is running inside an IPython kernel.

    When it is not (IPython missing, no active shell, or a shell whose config
    has no 'IPKernelApp' entry), matplotlib is switched to the 'Agg' backend
    as a side effect.

    Returns:
        bool: True inside an IPython kernel, False otherwise.
    """
    try:
        from IPython import get_ipython
        in_kernel = 'IPKernelApp' in get_ipython().config  # pragma: no cover
    except Exception:
        # Best-effort probe: a missing IPython (ImportError) or no active shell
        # (get_ipython() returns None -> AttributeError) both mean "not a
        # notebook". Unlike a bare `except:`, this lets KeyboardInterrupt and
        # SystemExit propagate.
        in_kernel = False

    if not in_kernel:
        mpl.use('Agg')
    return in_kernel

# Internal functions

In [5]:
def get_exp_name(loss, data_name, extra_text = ''):
    """Assemble the experiment name used in saved-model file names.

    Besides its arguments, this reads the module-level names PR_NOISE,
    NETWORK_SEED and NP_SEED, and additionally UPDATER and setpoint whenever
    `loss` is anything other than 'DNN'.

    Args:
        loss: loss identifier; 'DNN' selects the short (no updater/setpoint) form.
        data_name: data set name placed at the front of the result.
        extra_text: suffix appended verbatim to the result.

    Returns:
        str: the assembled experiment name.
    """
    pieces = [data_name]

    if loss == 'DNN':
        pieces.append('_DNN')
    else:
        pieces.append('_' + loss)
        pieces.append('_' + UPDATER)
        pieces.append('_abstSetpoint' + str(setpoint))

    # tail shared by both naming schemes
    pieces.append('_prNoise' + str(PR_NOISE))
    pieces.append('_networkSeed' + str(NETWORK_SEED))
    pieces.append('_npSeed' + str(NP_SEED))

    return ''.join(pieces) + extra_text



In [6]:
def get_frac_correct(y_true,y_pred):
    """Return the fraction of matching entries and the indices where they match.

    Args:
        y_true: array of reference labels.
        y_pred: array of predicted labels (must support `y_pred - y_true`).

    Returns:
        tuple: (fraction of entries with y_pred - y_true == 0, first-axis
        indices of those entries). The fraction is 0. when y_true is empty.
    """
    matches = np.where((y_pred - y_true) == 0)[0]
    n_total = len(y_true)
    fraction = len(matches)/n_total if n_total != 0 else 0.
    return fraction, matches

def get_acc_stats(onehotlabels, y_pred, tranquil, abstain, dnn=False):
    """Compute accuracy statistics over the samples that are not abstained on.

    Args:
        onehotlabels: one-hot true labels; the class is the argmax of the last axis.
        y_pred: network output; the predicted class is the argmax of the last axis.
        tranquil: per-sample 0/1 flag array, indexed the same way as the samples.
        abstain: when `dnn` is True, the fraction of samples to drop (the
            samples kept are those whose maximum output is at or above the
            100*abstain percentile); otherwise the class index that marks an
            abstention.
        dnn: pass True to select the percentile-based coverage rule.

    Returns:
        tuple: (acc, acc_tr, acc_ntr, n, n_tr, n_tr_corr, n_corr) -- accuracy on
        covered samples, accuracy on covered samples flagged 1 / flagged 0,
        number covered, sum of flags over covered samples, sum of flags over
        covered-and-correct samples, and number covered-and-correct.
    """
    pred_class = np.argmax(y_pred,axis=-1)
    true_class = np.argmax(onehotlabels,axis=-1)

    # select the covered (non-abstained) samples
    if(dnn is True):
        confidence = np.max(y_pred,axis=-1)
        threshold = np.percentile(confidence, 100.*abstain)
        covered = np.where(confidence >= threshold)[0]
    else:
        covered = np.where(pred_class != abstain)[0]

    true_cov = true_class[covered]
    pred_cov = pred_class[covered]
    flag_cov = tranquil[covered]

    # overall accuracy and counts on the covered samples
    acc, hits = get_frac_correct(true_cov, pred_cov)
    n = len(covered)
    n_corr = len(hits)
    n_tr_corr = np.sum(flag_cov[hits])
    n_tr = np.sum(flag_cov)

    # accuracy on the flagged (==1) and unflagged (==0) covered samples
    flagged = np.where(flag_cov==1)[0]
    acc_tr, __ = get_frac_correct(true_cov[flagged], pred_cov[flagged])

    unflagged = np.where(flag_cov==0)[0]
    acc_ntr, __ = get_frac_correct(true_cov[unflagged], pred_cov[unflagged])

    return acc, acc_tr, acc_ntr, n, n_tr, n_tr_corr, n_corr


# Load the data

In [7]:
# load the data
# Read the experiment settings on every run of this cell. They used to be
# assigned only inside the cache guard below, so after switching experiments in
# a live kernel REGION_NAME (used later when the noise is added) kept the value
# of the previous experiment.
SIMPLE_DATA = EXPINFO.get('simple_data', False)
REGION_NAME = EXPINFO.get('foo_region', 'ENSO')

# The data are loaded only once per kernel session.
# NOTE(review): the cached SSTrand is reused even if SIMPLE_DATA has changed
# since it was loaded -- delete SSTrand (or restart) when switching data sets.
if 'SSTrand' not in globals():
    if(SIMPLE_DATA==True):
        SSTrand, y, lat, lon = climatedata.load_simpledata(size='15x60')
    elif(SIMPLE_DATA==False):
        SSTrand, y, lat, lon = climatedata.load_data()
    else:
        # any other value is passed through as the size specifier
        SSTrand, y, lat, lon = climatedata.load_simpledata(size=SIMPLE_DATA)

lat = np.squeeze(lat)
lon = np.squeeze(lon)
print('SST shape = ' + str(np.shape(SSTrand)))


SST shape = (50000, 15, 60)


In [8]:
# Pick up on-disk edits to climatedata and reseed NumPy before add_noise runs.
imp.reload(climatedata)
np.random.seed(NP_SEED)

# experiment settings used by this cell
NLABEL = EXPINFO['numClasses']
NSAMPLES = EXPINFO['nSamples']
PR_NOISE = EXPINFO['prNoise']
CUTOFF = EXPINFO['cutoff']
UNDERSAMPLE = EXPINFO['undersample']

# ---- build the labelled data set from the first NSAMPLES samples ----
X, y_cat, tranquil, corrupt, y_perc = climatedata.add_noise(
    data_name=DATA_NAME,
    X=SSTrand[:NSAMPLES],
    y=y[:NSAMPLES],
    lat=lat,
    lon=lon,
    pr_noise=PR_NOISE,
    nlabel=NLABEL,
    cutoff=CUTOFF,
    region_name=REGION_NAME,
)

# ---- split; only the training and testing parts are unpacked here ----
data_train, data_val, data_test = climatedata.split_data(X, y_cat, tranquil, corrupt)
X_train, y_train, tr_train, cr_train = data_train
X_test, y_test, tr_test, cr_test = data_test

# ---- optional undersampling of the training and testing sets ----
if UNDERSAMPLE:
    print('----Training----')
    X_train, y_train, tr_train = climatedata.undersample(X_train, y_train, tr_train)
    print(f'total samples = {np.shape(X_train)[0]}')
    print('----Testing----')
    X_test, y_test, tr_test = climatedata.undersample(X_test, y_test, tr_test)
    print(f'total samples = {np.shape(X_test)[0]}')

# ---- preprocess the inputs/labels for the networks ----
(X_train_std,
 onehotlabels,
 X_test_std,
 onehotlabels_test,
 xmean,
 xstd) = climatedata.preprocess_data(X_train, y_train, X_test, y_test, NLABEL)


region shape = 1 x 15

----Mislabeled----
# tranquil = 5227 out of 18000 samples
percent tranquil = 29.0%
tranquil mislabeled = 0.0%
non-tranquil mislabeled = 100.0%
total mislabeled = 71.0%
----Training----
(8000, 15, 60)
(8000, 1)
(8000,)
----Validation----
(5000, 15, 60)
(5000, 1)
(5000,)
----Testing----
(5000, 15, 60)
(5000, 1)
(5000,)


# Plot the results

In [9]:
# Pick up on-disk edits to the project's metrics module.
imp.reload(metrics)

# ---- settings taken from the experiment definition ----
LOSS = EXPINFO['loss']
UPDATER = EXPINFO['updater']
hiddens = EXPINFO['hiddens']
SPINUP_EPOCHS = EXPINFO['spinup']
BATCH_SIZE = EXPINFO['batch_size']

# ---- settings fixed in this notebook ----
REWRITE = False
EXTRA_TEXT = ''
N_EPOCHS = 200
lr_epoch_bound = 10000
RIDGE = 0.
DNN_EPOCHS = 999
DAC_EPOCHS = 999

# ---- networks that the saved weights are loaded into ----
# The two definitions differ only in the output size: the DAC network has one
# output more than the DNN (NLABEL+1 versus NLABEL).
DNN_model = network.defineNN(
    hiddens,
    input_shape=X_train_std.shape[1],
    output_shape=NLABEL,
    ridge_penalty=RIDGE,
    act_fun='relu',
    network_seed=99,
)
DAC_model = network.defineNN(
    hiddens,
    input_shape=X_train_std.shape[1],
    output_shape=NLABEL+1,
    ridge_penalty=RIDGE,
    act_fun='relu',
    network_seed=99,
)

In [10]:
# raise ValueError('here')

## Compare all stats

In [11]:
# Approaches to evaluate, mapped to the suffix added to their saved-model name.
approach_dic = {'DNN':'', 
                'DAC':'', 
#                 'DNN-DNN':'_postDNN-DNN', 
#                 'DAC-DNN':'_postDAC-DNN', 
                'ORACLE':'_oracle', 
#                 'SELENE':'_selene'
               }
abstain_setpoint = np.around(np.arange(0., 1.01, .01), 3)
seed_vector = np.arange(0,50)

stat_columns = ['epochs',
                'network_seed',
                'np_seed',
                'app_type',
                'setpoint',
                'frac_abstain',
                'coverage',
                'acc',
                'acc_tr',
                'acc_ntr',
                'n_cover',
                'n_tr',
                'n_tr_corr',
                'n_corr',
                'frac_tr',
                'frac_corr_tr',
                'acc_portion_tr',
                'acc_portion_ntr',
                'perf_frac_tr',
               ]

# One dict per evaluated model; the DataFrame is built once after the loops
# (DataFrame.append copies the whole frame on every call and no longer exists
# in pandas 2).
stat_rows = []
curr_seed = -999

# NOTE: get_exp_name reads NETWORK_SEED and setpoint as module-level names, so
# the loop variables must keep exactly these names.
for NETWORK_SEED in seed_vector:
    for setpoint in abstain_setpoint:
        for app in approach_dic:

            #-------------------
            # get model names
            if app in ('DNN', 'ORACLE', 'SELENE'):
                EXP_NAME = get_exp_name(loss='DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app == 'DAC':
                EXP_NAME = get_exp_name(loss=LOSS, data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app in ('DNN-DNN', 'DAC-DNN'):
                EXP_NAME = get_exp_name(loss='DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
                # these names carry the setpoint just before the 'prNoise' part
                insert_at = EXP_NAME.find('prNoise')
                EXP_NAME = EXP_NAME[:insert_at] + 'abstSetpoint' + str(setpoint) + '_' + EXP_NAME[insert_at:]
            else:
                raise ValueError('no such approach')
            model_file = 'saved_models/model_' +  EXP_NAME + '.h5'

            # skip combinations for which no model was saved
            if not os.path.exists(model_file):
                continue
            if curr_seed != NETWORK_SEED:
                print('network seed ' + str(NETWORK_SEED) + ' of ' + str(seed_vector[-1]))
                curr_seed = NETWORK_SEED

            if app == 'DAC':
                # the abstention class is the extra output with index NLABEL
                DAC_model.load_weights(model_file)
                y_pred = DAC_model.predict(X_test_std)
                abst = np.argmax(y_pred,axis=-1)
                frac_abstain = len(np.where(abst==NLABEL)[0])/np.shape(y_pred)[0]
                epochs = DAC_EPOCHS  # was labelled with DNN_EPOCHS
                abstain_arg, percentile_rule = NLABEL, False
            else:
                # no abstention output: get_acc_stats drops the `setpoint`
                # fraction of samples with the lowest maximum output instead
                DNN_model.load_weights(model_file)
                y_pred = DNN_model.predict(X_test_std)
                frac_abstain = setpoint
                epochs = DNN_EPOCHS  # was labelled with DAC_EPOCHS
                abstain_arg, percentile_rule = setpoint, True

            acc, acc_tr, acc_ntr, n_cover, n_tr, n_tr_corr, n_corr = get_acc_stats(
                onehotlabels=onehotlabels_test,
                y_pred=y_pred,
                tranquil=tr_test,
                abstain=abstain_arg,
                dnn=percentile_rule,
            )
            stat_rows.append({'epochs': epochs,
                              'network_seed': NETWORK_SEED,
                              'np_seed': NP_SEED,
                              'app_type': app,
                              'setpoint': setpoint,
                              'frac_abstain': frac_abstain,
                              'coverage': 100.*(1.-frac_abstain),
                              'acc': acc,
                              'acc_tr': acc_tr,
                              'acc_ntr': acc_ntr,
                              'n_cover': n_cover,
                              'n_tr': n_tr,
                              'n_tr_corr': n_tr_corr,
                              'n_corr': n_corr,
                              'frac_tr': n_tr/n_cover,
                              'frac_corr_tr': n_tr_corr/n_corr,
                              'acc_portion_tr': n_tr_corr/n_cover,
                              'acc_portion_ntr': (n_corr-n_tr_corr)/n_cover,
                              'perf_frac_tr': np.minimum( 1., np.sum(tr_test)/ n_cover),
                             })

df = pd.DataFrame(stat_rows, columns=stat_columns)
#----------------------------------------------------------            
savename = (DATA_NAME
           + '_' + LOSS
           + '_npSeed'
           + str(NP_SEED)
           )
df.to_pickle('predictions/' + savename + '.pkl')
print('computation done.')

plots.plot_stats_comparisons(df, savename=DATA_NAME, lines=True, shades=True)                
plt.savefig('figures/summary_plots/'
            + 'statsComparisons_'
            + savename
            +'.png',dpi=dpiFig)      
plt.close()
print('plot done.')


network seed 0 of 49
network seed 1 of 49
network seed 2 of 49
network seed 3 of 49
network seed 4 of 49
network seed 5 of 49
network seed 6 of 49
network seed 7 of 49
network seed 8 of 49
network seed 9 of 49
network seed 10 of 49
network seed 11 of 49
network seed 12 of 49
network seed 13 of 49
network seed 14 of 49
network seed 15 of 49
network seed 16 of 49
network seed 17 of 49
network seed 18 of 49
network seed 19 of 49




network seed 20 of 49
network seed 21 of 49
network seed 22 of 49
network seed 23 of 49
network seed 24 of 49
network seed 25 of 49
network seed 26 of 49
network seed 27 of 49
network seed 28 of 49
network seed 29 of 49
network seed 30 of 49
network seed 31 of 49
network seed 32 of 49
network seed 33 of 49
network seed 34 of 49
network seed 35 of 49
network seed 36 of 49
network seed 37 of 49
network seed 38 of 49
network seed 39 of 49
network seed 40 of 49
network seed 41 of 49
network seed 42 of 49
network seed 43 of 49




network seed 44 of 49
network seed 45 of 49
network seed 46 of 49
network seed 47 of 49
network seed 48 of 49
network seed 49 of 49
computation done.
plot done.


## Compare stats on samples on which the DAC did not abstain

In [12]:
# Unconditional stop -- presumably deliberate, so that running the notebook
# top-to-bottom ends here and the cells below only run when executed by hand.
raise ValueError('here')

ValueError: here

In [None]:
noabs_columns = ['epochs',
                 'network_seed',
                 'np_seed',
                 'app_type',
                 'setpoint',
                 'frac_abstain',
                 'coverage',
                 'acc',
                 'acc_tr',
                 'acc_ntr',
                 'n_cover',
                 'n_tr',
                 'n_tr_corr',
                 'n_corr',
                 'frac_tr',
                 'frac_corr_tr',
                 'acc_portion_tr',
                 'acc_portion_ntr',
                 'perf_frac_tr',
                ]

# One dict per evaluated model; the DataFrame is built once after the loops
# (DataFrame.append copies the whole frame on every call and no longer exists
# in pandas 2).
noabs_rows = []

# The DAC must be evaluated first for each setpoint: every other approach is
# masked with the samples the DAC abstained on. Iterating in dictionary order
# evaluated any approach listed before 'DAC' with an empty mask.
# sorted() is stable, so the remaining approaches keep their relative order.
app_order = sorted(approach_dic, key=lambda name: name != 'DAC')

curr_seed = -999
# NOTE: get_exp_name reads NETWORK_SEED and setpoint as module-level names, so
# the loop variables must keep exactly these names.
for NETWORK_SEED in seed_vector:  
    for setpoint in abstain_setpoint:
        # stays empty when no DAC model exists for this seed/setpoint
        i_dac_abstain = []
        for app in app_order:

            #-------------------
            # get model names
            if app in ('DNN', 'ORACLE', 'SELENE'):
                EXP_NAME = get_exp_name(loss='DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app == 'DAC':
                EXP_NAME = get_exp_name(loss=LOSS, data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app in ('DNN-DNN', 'DAC-DNN'):
                EXP_NAME = get_exp_name(loss='DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
                # these names carry the setpoint just before the 'prNoise' part
                insert_at = EXP_NAME.find('prNoise')
                EXP_NAME = EXP_NAME[:insert_at] + 'abstSetpoint' + str(setpoint) + '_' + EXP_NAME[insert_at:]
            else:
                raise ValueError('no such approach')
            model_file = 'saved_models/model_' +  EXP_NAME + '.h5'

            # skip combinations for which no model was saved
            if not os.path.exists(model_file):
                continue
            if curr_seed != NETWORK_SEED:
                print('network seed ' + str(NETWORK_SEED) + ' of ' + str(seed_vector[-1]))
                curr_seed = NETWORK_SEED

            if app == 'DAC':
                DAC_model.load_weights(model_file)
                y_pred = DAC_model.predict(X_test_std)
                abst = np.argmax(y_pred,axis=-1)
                # remember which samples the DAC abstained on (class index NLABEL)
                i_dac_abstain = np.where(abst==NLABEL)[0]
                epochs = DAC_EPOCHS
            else:
                DNN_model.load_weights(model_file)
                y_pred = DNN_model.predict(X_test_std)
                # add an abstention column and force an abstention on exactly
                # the samples the DAC abstained on
                y_new_pred = np.append(y_pred,np.zeros([len(y_pred),1]),1)
                y_new_pred[i_dac_abstain,:] = 0.
                y_new_pred[i_dac_abstain,-1] = 1.
                y_pred = y_new_pred
                abst = np.argmax(y_pred,axis=-1)
                epochs = DNN_EPOCHS

            frac_abstain = len(np.where(abst==NLABEL)[0])/np.shape(y_pred)[0]  # compute abstention fraction

            acc, acc_tr, acc_ntr, n_cover, n_tr, n_tr_corr, n_corr = get_acc_stats(
                onehotlabels=onehotlabels_test,
                y_pred=y_pred,
                tranquil=tr_test,
                abstain=NLABEL,
                dnn=False,
            )
            noabs_rows.append({'epochs': epochs,
                               'network_seed': NETWORK_SEED,
                               'np_seed': NP_SEED,
                               'app_type': app,
                               'setpoint': setpoint,
                               'frac_abstain': frac_abstain,
                               'coverage': 100.*(1.-frac_abstain),
                               'acc': acc,
                               'acc_tr': acc_tr,
                               'acc_ntr': acc_ntr,
                               'n_cover': n_cover,
                               'n_tr': n_tr,
                               'n_tr_corr': n_tr_corr,
                               'n_corr': n_corr,
                               'frac_tr': n_tr/n_cover,
                               'frac_corr_tr': n_tr_corr/n_corr,
                               'acc_portion_tr': n_tr_corr/n_cover,
                               'acc_portion_ntr': (n_corr-n_tr_corr)/n_cover,
                               'perf_frac_tr': np.minimum( 1., np.sum(tr_test)/ n_cover),
                              })

df_noabs = pd.DataFrame(noabs_rows, columns=noabs_columns)
#----------------------------------------------------------            
savename = (DATA_NAME
           + '_' + LOSS
           + '_npSeed'
           + str(NP_SEED)
           + '_noAbstain' 
           )
# save the frame computed in this cell (this used to pickle `df`, the result of
# the all-samples comparison, under the noAbstain file name)
df_noabs.to_pickle('predictions/' + savename + '.pkl')

plots.plot_stats_comparisons(df_noabs, savename, lines=True, shades=False)
plt.savefig('figures/summary_plots/'
            + 'statsComparisons_noAbstain_'
            + savename
            +'.png',dpi=dpiFig)            


## Look at softmax output

In [None]:
# Unconditional stop -- presumably deliberate, so that the cells below only run
# when executed by hand.
raise ValueError('here')

In [None]:
# Only the DAC, at a single setpoint and network seed, is examined here.
approach_dic = {
#                 'DNN':'', 
                'DAC':'', 
#                 'DNN-DNN':'_postDNN-DNN', 
#                 'DAC-DNN':'_postDAC-DNN', 
#                 'ORACLE':'_oracle', 
#                 'SELENE':'_selene'
               }
abstain_setpoint = (.6,)#np.around(np.arange(0., 1.01, .01), 3)
seed_vector = (19,)#np.arange(0,50)

stat_columns = ['epochs',
                'network_seed',
                'np_seed',
                'app_type',
                'setpoint',
                'frac_abstain',
                'coverage',
                'acc',
                'acc_tr',
                'acc_ntr',
                'n_cover',
                'n_tr',
                'n_tr_corr',
                'n_corr',
                'frac_tr',
                'frac_corr_tr',
                'acc_portion_tr',
                'acc_portion_ntr',
                'perf_frac_tr',
               ]

# One dict per evaluated model; the DataFrame is built once after the loops.
stat_rows = []
curr_seed = -999

# Cleared up front so the logits table below can never be built from
# predictions left over from an earlier cell.
y_pred = None

# NOTE: get_exp_name reads NETWORK_SEED and setpoint as module-level names, so
# the loop variables must keep exactly these names.
for NETWORK_SEED in seed_vector:
    for setpoint in abstain_setpoint:
        for app in approach_dic:

            #-------------------
            # get model names
            if app in ('DNN', 'ORACLE', 'SELENE'):
                EXP_NAME = get_exp_name(loss='DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app == 'DAC':
                EXP_NAME = get_exp_name(loss=LOSS, data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app in ('DNN-DNN', 'DAC-DNN'):
                EXP_NAME = get_exp_name(loss='DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
                # these names carry the setpoint just before the 'prNoise' part
                insert_at = EXP_NAME.find('prNoise')
                EXP_NAME = EXP_NAME[:insert_at] + 'abstSetpoint' + str(setpoint) + '_' + EXP_NAME[insert_at:]
            else:
                raise ValueError('no such approach')
            model_file = 'saved_models/model_' +  EXP_NAME + '.h5'

            # skip combinations for which no model was saved
            if not os.path.exists(model_file):
                continue
            if curr_seed != NETWORK_SEED:
                print('network seed ' + str(NETWORK_SEED) + ' of ' + str(seed_vector[-1]))
                curr_seed = NETWORK_SEED

            if app == 'DAC':
                # the abstention class is the extra output with index NLABEL
                DAC_model.load_weights(model_file)
                y_pred = DAC_model.predict(X_test_std)
                abst = np.argmax(y_pred,axis=-1)
                frac_abstain = len(np.where(abst==NLABEL)[0])/np.shape(y_pred)[0]
                epochs = DAC_EPOCHS  # was labelled with DNN_EPOCHS
                abstain_arg, percentile_rule = NLABEL, False
            else:
                # no abstention output: get_acc_stats drops the `setpoint`
                # fraction of samples with the lowest maximum output instead
                DNN_model.load_weights(model_file)
                y_pred = DNN_model.predict(X_test_std)
                frac_abstain = setpoint
                epochs = DNN_EPOCHS  # was labelled with DAC_EPOCHS
                abstain_arg, percentile_rule = setpoint, True

            acc, acc_tr, acc_ntr, n_cover, n_tr, n_tr_corr, n_corr = get_acc_stats(
                onehotlabels=onehotlabels_test,
                y_pred=y_pred,
                tranquil=tr_test,
                abstain=abstain_arg,
                dnn=percentile_rule,
            )
            stat_rows.append({'epochs': epochs,
                              'network_seed': NETWORK_SEED,
                              'np_seed': NP_SEED,
                              'app_type': app,
                              'setpoint': setpoint,
                              'frac_abstain': frac_abstain,
                              'coverage': 100.*(1.-frac_abstain),
                              'acc': acc,
                              'acc_tr': acc_tr,
                              'acc_ntr': acc_ntr,
                              'n_cover': n_cover,
                              'n_tr': n_tr,
                              'n_tr_corr': n_tr_corr,
                              'n_corr': n_corr,
                              'frac_tr': n_tr/n_cover,
                              'frac_corr_tr': n_tr_corr/n_corr,
                              'acc_portion_tr': n_tr_corr/n_cover,
                              'acc_portion_ntr': (n_corr-n_tr_corr)/n_cover,
                              'perf_frac_tr': np.minimum( 1., np.sum(tr_test)/ n_cover),
                             })

df = pd.DataFrame(stat_rows, columns=stat_columns)

if y_pred is None:
    raise FileNotFoundError('no saved model was found for the requested approach/seed/setpoint')

# make logits dataframe                
# One row per test sample, taken from the last model evaluated above: the
# largest network output, the class it belongs to, and the true class.
df_logits = pd.DataFrame({'max_logit': np.max(y_pred,axis=-1),
                          'pred_label': np.argmax(y_pred,axis=-1),
                          'true_label': np.argmax(onehotlabels_test,axis=-1),
                         })

df_logits

In [None]:
colors = cmap = palettable.cartocolors.qualitative.Vivid_10.mpl_colors
inc = .05
# NOTE(review): np.arange with a float step may or may not emit the last edge
# at 1 + inc/2 -- confirm that the bins cover outputs close to 1.
xbins = np.arange(-inc/2,1.+inc/2,inc)

# One panel per class, one for the abstention class and one for the text label.
# The grid was hard-coded to 3x4 (12 panels), which makes plt.subplot raise as
# soon as NLABEL + 2 exceeds 12; the number of rows now follows NLABEL, and
# with NLABEL == 10 the layout and figure size are the same as before.
n_cols = 4
n_panels = NLABEL + 2
n_rows = int(np.ceil(n_panels / n_cols))

plt.figure(figsize=(4*n_cols, 8.*n_rows/3.))
for label in np.arange(0,NLABEL+1):

    # the abstention class (index NLABEL) gets its own colour and title
    is_abstention = (label==NLABEL)
    clr = colors[0] if is_abstention else colors[1]

    data_plot = df_logits[df_logits['pred_label']==label]
    plt.subplot(n_rows, n_cols, label+1)
    sns.histplot(data_plot, x="max_logit", element="step", bins=xbins, legend=False, color=clr)
    plt.xlim(0,1.0)
    plt.xlabel('softmax output')
    if is_abstention:
        plt.title('predicted label = abstention class')
    else:
        plt.title('predicted label = ' + str(label))

# text-only panel right after the abstention histogram
plt.subplot(n_rows, n_cols, n_panels)
plt.text(.1,.4,DATA_NAME+'\nsetpoint = ' + str(setpoint), horizontalalignment='left', fontsize=10)
plt.axis('off')

# #-----------------------------
plt.tight_layout()
save_name = DATA_NAME + '_setpoint' + str(setpoint)
plt.savefig('figures/summary_plots/' 
            'logit_histograms_'
            + save_name
            +'.png',dpi=dpiFig)  
plt.close()