# Exploring Abstention Loss
author: Elizabeth A. Barnes, Randal J. Barnes
date: January 15, 2021, 0738MST

* based on Thulasidasan, S., T. Bhattacharya, J. Bilmes, G. Chennupati, and J. Mohd-Yusof, 2019: Combating Label Noise in Deep Learning Using Abstention. arXiv [stat.ML].
* thesis: https://digital.lib.washington.edu/researchworks/handle/1773/45781
* code base is here: https://github.com/thulas/dac-label-noise/blob/master/dac_loss.py

In [None]:
import numpy as np
import time
import sys
import collections
import os
import glob
import pickle

import sklearn
from sklearn.model_selection import train_test_split
from sklearn import preprocessing

import tensorflow as tf
from tensorflow.keras import optimizers
import matplotlib as mpl
import matplotlib.pyplot as plt
import cartopy as ct
import cartopy.crs as ccrs

import abstentionloss
import metrics
import network
import plots
import climatedata
import experiments

import imp
imp.reload(experiments)
imp.reload(abstentionloss)
imp.reload(plots)
imp.reload(climatedata)

import palettable
import pprint

mpl.rcParams['figure.facecolor'] = 'white'
mpl.rcParams['figure.dpi'] = 150
dpiFig = 300.

# Silence NumPy's VisibleDeprecationWarning. ``np.warnings`` (a re-export of
# the stdlib ``warnings`` module) has been removed from recent NumPy, and
# NumPy 2.0 moved VisibleDeprecationWarning into ``np.exceptions``; look the
# class up in either place so this cell works on old and new releases.
import warnings
warnings.filterwarnings(
    'ignore',
    category=getattr(np, 'exceptions', np).VisibleDeprecationWarning,
)
tf.print(f"sys.version = {sys.version}", output_stream=sys.stdout)
tf.print(f"tf.version.VERSION = {tf.version.VERSION}", output_stream=sys.stdout)

In [None]:
# ---- experiment configuration ---------------------------
# DATA_NAME is handed to experiments.define_experiments() to obtain the
# settings dict EXPINFO (an earlier value was 'tranquilFOO0').
DATA_NAME = 'tranquilFOO23'
# File name of this script; used when the script relaunches itself for the
# next network seed.
SCRIPT_NAME = 'trainingApproach_climatedata_v2.26_cmdA.py'
# Directory for the per-epoch weight checkpoints written during training.
checkpointDir = '/Users/eabarnes/Data/2021/abstention_loss/checkpoints/'

EXPINFO = experiments.define_experiments(DATA_NAME)
pprint.pprint(EXPINFO, width=60)
# ----------------------------------------------------------

In [None]:
# Seed both NumPy and TensorFlow for reproducibility. TensorFlow reuses
# NP_SEED (previously a duplicated literal 99) so the two seeds cannot
# silently drift apart when NP_SEED is edited.
NP_SEED = 99
np.random.seed(NP_SEED)
tf.random.set_seed(NP_SEED)

## Internal functions

In [None]:
def in_ipynb():
    """Return True when running inside an IPython kernel (e.g. Jupyter).

    Otherwise the matplotlib backend is switched to the non-interactive
    'Agg' backend and False is returned. This covers three cases: IPython
    is not installed, IPython is not active, or the IPython shell is not a
    kernel.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        mpl.use('Agg')
        return False

    shell = get_ipython()
    # get_ipython() returns None when not running under IPython at all;
    # only a kernel has 'IPKernelApp' in its config.
    if shell is None or 'IPKernelApp' not in shell.config:  # pragma: no cover
        mpl.use('Agg')
        return False
    return True

In [None]:
def get_exp_name(loss, data_name, extra_text = ''):
    """Build the experiment name used for saved models, histories and plots.

    Reads the module-level globals PR_NOISE, NETWORK_SEED and NP_SEED and,
    when ``loss`` is not 'DNN', also UPDATER and setpoint.

    Args:
        loss: 'DNN' for the plain network, otherwise the abstention loss name.
        data_name: dataset/experiment identifier placed first in the name.
        extra_text: suffix appended verbatim at the end.
    """
    if loss == 'DNN':
        pieces = [data_name, '_DNN']
    else:
        pieces = [
            data_name,
            '_' + loss,
            '_' + UPDATER,
            '_abstSetpoint' + str(setpoint),
        ]
    # fields shared by every approach
    pieces += [
        '_prNoise' + str(PR_NOISE),
        '_networkSeed' + str(NETWORK_SEED),
        '_npSeed' + str(NP_SEED),
        extra_text,
    ]
    return ''.join(pieces)

In [None]:
def make_model(loss_str = 'DNN', updater_str='Colorado', setpoint=.5, spinup_epochs=10, nupd=10):
    """Build and compile a fresh dense network for one experiment.

    Reads the module-level globals ``hiddens``, ``X_train_std``, ``NLABEL``,
    ``RIDGE``, ``NETWORK_SEED`` and ``LR_INIT``.

    Args:
        loss_str: 'DNN' for a plain classifier trained with categorical
            cross-entropy; otherwise the name of a loss class in
            ``abstentionloss``. In that case the network gets one extra
            output unit for the abstention class.
        updater_str: name of the updater class in ``abstentionloss``
            (ignored for 'DNN').
        setpoint: abstention setpoint passed to the updater (ignored for 'DNN').
        spinup_epochs: passed to the abstention loss (ignored for 'DNN').
        nupd: passed to the updater as ``length`` (ignored for 'DNN').

    Returns:
        ``(model, loss_function)``: the compiled Keras model and its loss object.
    """
    tf.keras.backend.clear_session()

    use_abstention = loss_str != 'DNN'
    # the abstention loss needs one extra output unit for the abstention class
    output_shape = NLABEL + 1 if use_abstention else NLABEL
    model = network.defineNN(hiddens, input_shape=X_train_std.shape[1], output_shape=output_shape, ridge_penalty=RIDGE, act_fun='relu', network_seed=NETWORK_SEED)

    if use_abstention:
        updater = getattr(abstentionloss, updater_str)(setpoint=setpoint,
                                                       alpha_init=.5,
                                                       length=nupd)
        loss_function = getattr(abstentionloss, loss_str)(updater=updater,
                                                          spinup_epochs=spinup_epochs)
        # alpha_value (defined later in this file) reads the module-level
        # ``loss_function``, so callers must assign the returned loss there.
        model_metrics = [
            alpha_value,
            metrics.AbstentionFraction(NLABEL),
            metrics.PredictionLoss(NLABEL),
            metrics.PredictionAccuracy(NLABEL),
        ]
    else:
        loss_function = tf.keras.losses.CategoricalCrossentropy()
        model_metrics = [
            metrics.AbstentionFraction(NLABEL),
            metrics.PredictionAccuracy(NLABEL),
        ]

    model.compile(
        # ``lr`` is a deprecated alias that newer Keras optimizers reject;
        # ``learning_rate`` is the supported keyword.
        optimizer=optimizers.SGD(learning_rate=LR_INIT, momentum=0.9, nesterov=True),
        loss=loss_function,
        metrics=model_metrics,
    )
    # model.summary()

    return model, loss_function

## Load the data

In [None]:
# Load the SST fields and labels once; re-running the cell reuses SSTrand.
if 'SSTrand' not in globals():
    SIMPLE_DATA = EXPINFO.get('simple_data', False)
    REGION_NAME = EXPINFO.get('foo_region', 'ENSO')

    # SIMPLE_DATA: True -> 15x60 simple data, False -> full data, anything
    # else is used as the simple-data size. '==' (not 'is') is intentional so
    # that 1/0 behave like True/False.
    if SIMPLE_DATA == True:
        SSTrand, y, lat, lon = climatedata.load_simpledata(size='15x60')
    elif SIMPLE_DATA == False:
        SSTrand, y, lat, lon = climatedata.load_data()
    else:
        SSTrand, y, lat, lon = climatedata.load_simpledata(size=SIMPLE_DATA)

lat, lon = np.squeeze(lat), np.squeeze(lon)
print('SST shape = ' + str(np.shape(SSTrand)))

# define the target region (REGION_NAME, default 'ENSO')
reg_lats, reg_lons = climatedata.get_region(region_name=REGION_NAME)

# Plot one SST sample on a map with the target region outlined
# (only when running interactively).
cmap = palettable.cartocolors.diverging.Geyser_7.mpl_colormap

if in_ipynb():
    plt.figure(figsize=(12, 2.73 * 2))
    mapProj = ccrs.EqualEarth(central_longitude=0.)
    ax = plt.subplot(1, 2, 1, projection=mapProj)
    cb, image = plots.drawOnGlobe(
        ax,
        mapProj,
        SSTrand[20, :, :],
        np.squeeze(lat),
        np.squeeze(lon),
        cmap=cmap,
        vmin=-3,
        vmax=3,
        cbarBool=True,
        fastBool=True,
        extent='both',
    )
    # closed rectangle around the region, drawn in lon/lat coordinates
    plt.plot(
        [reg_lons[0], reg_lons[0], reg_lons[1], reg_lons[1], reg_lons[0]],
        [reg_lats[0], reg_lats[1], reg_lats[1], reg_lats[0], reg_lats[0]],
        color='white',
        linestyle='--',
        transform=ccrs.PlateCarree(),
    )
    plt.show()

In [None]:
imp.reload(climatedata)
np.random.seed(NP_SEED)  # re-seed NumPy before building the dataset

# experiment settings
NLABEL = EXPINFO['numClasses']
NSAMPLES = EXPINFO['nSamples']
PR_NOISE = EXPINFO['prNoise']
CUTOFF = EXPINFO['cutoff']
UNDERSAMPLE = EXPINFO['undersample']

#----------------------------
# Build the (noisy) dataset from the first NSAMPLES samples, then split it.
X, y_cat, tranquil, corrupt, y_perc = climatedata.add_noise(
    data_name=DATA_NAME,
    X=SSTrand[:NSAMPLES],
    y=y[:NSAMPLES],
    lat=lat,
    lon=lon,
    pr_noise=PR_NOISE,
    nlabel=NLABEL,
    cutoff=CUTOFF,
    region_name=REGION_NAME,
)
data_train, data_val, data_test = climatedata.split_data(X, y_cat, tranquil, corrupt)
(X_train, y_train, tr_train, cr_train), (X_val, y_val, tr_val, cr_val) = data_train, data_val

print(f'Train Shape = {np.shape(X_train)}')
print(f'Validation Shape = {np.shape(X_val)}')

# undersample the data
# NOTE(review): cr_train / cr_val are not passed through undersample(). If
# undersampling changes the sample count (presumably it does), they no longer
# line up with y_train / y_val. Both the corrupted-label histogram and the
# ORACLE index selection rely on that alignment. Confirm whether
# climatedata.undersample can subsample them as well.
if UNDERSAMPLE:
    print('----Training----')
    X_train, y_train, tr_train = climatedata.undersample(X_train, y_train, tr_train)  # training data
    print(f'total samples = {np.shape(X_train)[0]}')
    print('----Validation----')
    X_val, y_val, tr_val = climatedata.undersample(X_val, y_val, tr_val)  # validation data
    print(f'total samples = {np.shape(X_val)[0]}')

# Process data for training. Judging by the returned names, this presumably
# standardizes X and one-hot encodes y.
(X_train_std, onehotlabels,
 X_val_std, onehotlabels_val,
 xmean, xstd) = climatedata.preprocess_data(X_train, y_train, X_val, y_val, NLABEL)

if in_ipynb():
    # Class histograms of the training labels: all samples and subsets.
    _bins = np.arange(0, NLABEL + 1)
    _panels = (
        # (subplot position, sample mask or None for all, x-label, title)
        (1, None, 'labels', 'all'),
        (4, cr_train == 1, 'class', 'corrupted labels'),
        (3, tr_train == 1, 'class', 'tranquil labels'),
        (2, tr_train == 0, 'class', 'not tranquil'),
    )

    plt.figure(figsize=(6 * 1.5, 3 * 1.5))
    for _pos, _mask, _xlabel, _title in _panels:
        plt.subplot(2, 2, _pos)
        plt.hist(y_train if _mask is None else y_train[_mask], _bins)
        plt.xlabel(_xlabel)
        plt.title(_title)

    plt.tight_layout()
    plt.show()


# Train the model

In [None]:
def alpha_value(y_true,y_pred):
    """Keras metric that reports the abstention loss's current ``alpha``.

    The arguments are ignored. The value is read from the updater of the
    module-level ``loss_function``, i.e. whichever loss was last assigned
    to that global name.
    """
    del y_true, y_pred  # unused; required by the Keras metric signature
    current_updater = loss_function.updater
    return current_updater.alpha

def scheduler(epoch, lr):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Keeps the current rate ``lr`` while ``epoch`` is below the module-level
    ``lr_epoch_bound``. From then on it returns a fixed ``LR_INIT / 2``.
    """
    # (an exponential decay, lr*tf.math.exp(-0.1), was tried previously)
    return lr if epoch < lr_epoch_bound else LR_INIT / 2.

class EarlyStoppingDAC(tf.keras.callbacks.Callback):
    """Early stopping for the abstention (DAC) loss.

    Monitors ``val_prediction_accuracy`` (higher is better) and remembers the
    weights of the best epoch. An epoch only qualifies as the new best when
    its ``val_abstention_fraction`` is within 0.1 of the module-level global
    ``setpoint``.

    Patience semantics:
      * If accuracy is not above the best so far, ``wait`` is incremented.
        When ``wait`` reaches ``patience``, training stops and the best
        weights are restored. These are the initial, untrained weights if no
        epoch ever qualified.
      * If accuracy is above the best but the abstention fraction is outside
        the tolerance, the epoch is ignored: it neither becomes the best nor
        increments ``wait``.

    NOTE(review): if training runs to the last epoch without stopping early,
    the best weights are not restored; the model keeps its final weights.

    Arguments:
        patience: number of non-improving epochs to wait after the best one
            before stopping training.
    """

    def __init__(self, patience=0):
        super(EarlyStoppingDAC, self).__init__()
        self.patience = patience
        # best_weights stores the weights of the best qualifying epoch.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Number of epochs waited since the last improvement.
        self.wait = 0
        # The epoch training was stopped at (stays 0 unless stopped).
        self.stopped_epoch = 0
        # Best validation prediction accuracy so far, initialized to zero.
        self.best = 0.
        # "No qualifying epoch yet". np.inf, not the np.Inf alias that
        # NumPy 2.0 removed.
        self.best_epoch = np.inf
        # Start from the untrained weights so there is always something to restore.
        self.best_weights = self.model.get_weights()

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("val_prediction_accuracy")
        if np.greater(current, self.best):
            abstention_error = np.abs(logs.get("val_abstention_fraction") - setpoint)
            if np.less(abstention_error, .1):
                self.best = current
                self.wait = 0
                # Record the best weights since the current result is better.
                self.best_weights = self.model.get_weights()
                self.best_epoch = epoch
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            if np.isinf(self.best_epoch):
                # No epoch met the abstention tolerance, so the restored
                # weights are the initial ones from on_train_begin.
                print("Early stopping, no epoch met the abstention setpoint tolerance; restored the initial weights.")
            else:
                print("Early stopping, setting to best_epoch = " + str(self.best_epoch + 1))

In [None]:
# ---- run options -----------------------------------------
LOSS = EXPINFO['loss']
UPDATER = EXPINFO['updater']
REWRITE = False        # True: retrain even if saved_models/model_<name>.h5 exists
SAVE_HISTORY = True    # pickle the Keras history next to the saved weights
EXTRA_TEXT = ''

# ---- hyper-parameters from the experiment definition -----
NUPD = EXPINFO['nupd']
hiddens = EXPINFO['hiddens']
SPINUP_EPOCHS = EXPINFO['spinup']
BATCH_SIZE = EXPINFO['batch_size']
LR_INIT = EXPINFO['lr_init']

# ---- fixed training settings -----------------------------
N_EPOCHS = 200
lr_epoch_bound = 10000   # scheduler() keeps lr until this epoch, then uses LR_INIT/2
RIDGE = 0.               # ridge_penalty passed to network.defineNN
#---------------------

In [None]:
# Approaches to train, mapped to the suffix appended to their experiment name.
# Other approaches the training loop supports (currently disabled):
#   'DNN-DNN': '_postDNN-DNN', 'DAC-DNN': '_postDAC-DNN', 'SELENE': '_selene'
approach_dic = dict(DNN='', DAC='', ORACLE='_oracle')
abstain_setpoint = np.around(np.arange(0., 1., .1), 3)   # 0.0, 0.1, ..., 0.9
seed_vector = np.arange(50)

if not in_ipynb():
    # Script mode: the network seed is the last command-line argument; exit
    # once it is beyond the largest seed in seed_vector.
    NETWORK_SEED_LIST = (int(sys.argv[-1]),)
    if NETWORK_SEED_LIST[0] > seed_vector.max():
        sys.exit()
else:
    NETWORK_SEED_LIST = (0,)
# Main experiment loop. NETWORK_SEED, setpoint and loss_function are kept as
# module-level names on purpose: get_exp_name(), make_model(), alpha_value()
# and EarlyStoppingDAC read them as globals, so this loop must stay at module
# scope.
for NETWORK_SEED in NETWORK_SEED_LIST:
    for setpoint in abstain_setpoint:
        for app in approach_dic.keys():

            # skipping rules----
            # setpoint 0: train only DNN / ORACLE / SELENE;
            # any other setpoint: train only the remaining approaches.
            if setpoint == 0:
                if app not in ('DNN', 'ORACLE', 'SELENE'):
                    continue
            elif app in ('DNN', 'ORACLE', 'SELENE'):
                continue
            #-------------------

            if app in ('DNN', 'ORACLE', 'SELENE'):
                EXP_NAME = get_exp_name(loss = 'DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
            elif app == 'DAC':
                EXP_NAME = get_exp_name(loss = LOSS, data_name=DATA_NAME, extra_text = approach_dic[app])
            elif app in ('DNN-DNN', 'DAC-DNN'):
                # DNN-style name with the abstention setpoint inserted before 'prNoise'
                EXP_NAME = get_exp_name(loss = 'DNN', data_name=DATA_NAME, extra_text=approach_dic[app])
                i = EXP_NAME.find('prNoise')
                EXP_NAME = EXP_NAME[:i] + 'abstSetpoint' + str(setpoint) + '_' + EXP_NAME[i:]
            else:
                raise ValueError('no such approach')
            model_name = 'saved_models/model_' +  EXP_NAME

            # Skip experiments whose weights are already saved unless REWRITE is set.
            if os.path.exists(model_name + '.h5') and not REWRITE:
                continue
            print(EXP_NAME)

            #-------------------------------
            # Determine indices to grab for training of the different approaches
            if app in ('DNN', 'DAC'):
                # all samples
                i_train = np.arange(0,np.shape(onehotlabels)[0])
                i_val = np.arange(0,np.shape(onehotlabels_val)[0])

            elif app == 'ORACLE':
                # only samples whose labels were not flagged as corrupted
                # NOTE(review): cr_train/cr_val are not undersampled in the data
                # cell; with UNDERSAMPLE these indices may not match X_train_std.
                i_train = np.where(cr_train==0)[0]
                i_val = np.where(cr_val==0)[0]

            elif app == 'SELENE':
                # only samples flagged as tranquil
                i_train = np.where(tr_train==1)[0]
                i_val = np.where(tr_val==1)[0]

            elif app == 'DNN-DNN':
                # keep samples where a previously trained DNN's max output is at
                # or above the (100*setpoint)-th percentile
                exp_name_0 = get_exp_name(loss = 'DNN', data_name=DATA_NAME, extra_text='')
                model_name_0 = 'saved_models/model_' +  exp_name_0 + '.h5'
                model0, __ = make_model(loss_str = 'DNN')
                model0.load_weights(model_name_0)

                y_pred_train_0 = model0.predict(X_train_std)
                y_pred_val_0 = model0.predict(X_val_std)
                max_logits = np.max(y_pred_train_0,axis=-1)
                i_train = np.where(max_logits >= np.percentile(max_logits, 100*setpoint))[0]
                max_logits = np.max(y_pred_val_0,axis=-1)
                i_val = np.where(max_logits >= np.percentile(max_logits, 100*setpoint))[0]

            elif app == 'DAC-DNN':
                # keep samples a previously trained DAC model does not assign to
                # its extra output (index NLABEL, the abstention class)
                exp_name_0 = get_exp_name(loss = LOSS, data_name=DATA_NAME, extra_text='')
                model_name_0 = 'saved_models/model_' +  exp_name_0 + '.h5'
                model0, __ = make_model(loss_str = LOSS)
                model0.load_weights(model_name_0)

                y_pred_train_0 = model0.predict(X_train_std)
                y_pred_val_0 = model0.predict(X_val_std)
                i_train = np.where(np.argmax(y_pred_train_0,axis=-1) != NLABEL)[0]
                i_val = np.where(np.argmax(y_pred_val_0,axis=-1) != NLABEL)[0]

            else:
                raise ValueError('no such app')

            #-------------------------------
            # Get the model
            tf.keras.backend.clear_session()

            # callbacks
            lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler,verbose=0)
            cp_callback = tf.keras.callbacks.ModelCheckpoint(
                filepath = checkpointDir + 'model_' + EXP_NAME + '_epoch{epoch:03d}.h5',
                verbose=0,
                save_weights_only=True,
            )

            # define the model and loss function
            if app == 'DAC':
                es_dac_callback = EarlyStoppingDAC(patience=30)
                model, loss_function = make_model(loss_str = LOSS,
                                                  updater_str=UPDATER,
                                                  setpoint=setpoint,
                                                  spinup_epochs=SPINUP_EPOCHS,
                                                  nupd=NUPD)
                callbacks = [abstentionloss.AlphaUpdaterCallback(), lr_callback, cp_callback, es_dac_callback]
            else:
                es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_prediction_accuracy', patience=30, mode='max', restore_best_weights=True, verbose=1)
                model, loss_function = make_model(loss_str = 'DNN')
                callbacks = [lr_callback, cp_callback, es_callback]

            #-------------------------------
            # Remake onehotencoding
            hotlabels = onehotlabels[:,:model.output_shape[-1]] # strip off abstention class if using the DNN
            hotlabels_val = onehotlabels_val[:,:model.output_shape[-1]] # strip off abstention class if using the DNN

            #-------------------------------
            # Train the model

            start_time = time.time()

            try:
                history = model.fit(
                    X_train_std[i_train],
                    hotlabels[i_train],
                    validation_data=(X_val_std[i_val], hotlabels_val[i_val]),
                    batch_size=BATCH_SIZE,
                    epochs=N_EPOCHS,
                    shuffle=True,
                    verbose=0,
                    callbacks=callbacks
                )
                if SAVE_HISTORY:
                    # save history data
                    history_dict = model.history.history
                    history_file = 'saved_models/history_' +  EXP_NAME + '.pickle'
                    with open(history_file, 'wb') as handle:
                        pickle.dump(history_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

            except ValueError as err:
                # Report the failure instead of skipping silently; no weights
                # or plots are produced for this experiment.
                print(f"Skipping {EXP_NAME}: training raised ValueError: {err}")
                continue

            stop_time = time.time()
            tf.print(f"Elapsed time during fit = {stop_time - start_time:.2f} seconds\n")

            model.save_weights(model_name + '.h5')
            # per-epoch checkpoints are no longer needed once final weights are saved
            for f in glob.glob(checkpointDir + 'model_' + EXP_NAME + "_epoch*.h5"):
                os.remove(f)

            #-------------------------------
            # Display the results

            exp_info=(LOSS, N_EPOCHS, setpoint, SPINUP_EPOCHS, hiddens, LR_INIT, lr_epoch_bound, BATCH_SIZE, NETWORK_SEED)

            plots.plot_results(
                EXP_NAME,
                history,
                exp_info=exp_info,
                saveplot=True,
                showplot=True
            )

if not in_ipynb():
    # Script mode: relaunch this script in a fresh Python process for the
    # next network seed. os.execv replaces the current process image and
    # never returns on success, and Python's buffered output is not flushed
    # automatically, so flush explicitly or pending log lines are lost.
    print('-----starting new kernel-----', flush=True)
    sys.stdout.flush()
    sys.stderr.flush()
    os.execv(sys.executable, ['python'] + ['/Users/eabarnes/GoogleDrive/WORK/RESEARCH/2021/abstention_networks/' + SCRIPT_NAME] + [str(NETWORK_SEED+1)])

In [None]:
# (X_val_std[i_val], hotlabels_val[i_val])
# model.evaluate(x=X_val_std[i_val], y=hotlabels_val[i_val])