# Train and Evaluate ResNet-20 on CIFAR-10

## Import

In [None]:
from __future__ import print_function
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
# from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
import os

import pickle
import matplotlib.pyplot as plt
import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath("utility"))))
from utility.evaluation import softmax, ECE, MCE
import sklearn.metrics as metrics
import glob
import csv

In [None]:
# Turn on TensorFlow's "memory growth" option for the first GPU; this was added
# as a workaround for an out-of-memory issue.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
# Stop the notebook right away when TensorFlow reports no GPU at all.
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# Only device 0 is configured here; further GPUs (if any) are left untouched.
tf.config.experimental.set_memory_growth(physical_devices[0], True)

## Parameter

In [None]:
# Mini-batch size handed to datagen.flow(...) in the training cell.
batch_size = 512
# Number of training epochs passed to fit_generator.
epochs = 500
# NOTE(review): this flag is not read anywhere in the visible notebook; the
# training cell always trains through ImageDataGenerator.
data_augmentation = True
# Number of label classes used for one-hot encoding.
num_classes = 10
# When True, the per-pixel mean of the training images is subtracted from the
# training and test images.
subtract_pixel_mean = True
# Network depth is derived as 6n + 2, i.e. 20 for n = 3.
n = 3
depth = n * 6 + 2
# random_state for the train/validation split.
seed = 86

## Prepare Dataset

In [None]:
# Load CIFAR-10 and one-hot encode the labels of both splits.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train, y_test = (keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))

# Convert the images to float32 and divide by 255.
x_train, x_test = (images.astype('float32') / 255
                   for images in (x_train, x_test))

# Optionally centre both splits with the per-pixel mean of the training images.
# The mean is taken before the validation split below, so it also covers the
# images that end up in x_val.
if subtract_pixel_mean:
    x_train_mean = np.mean(x_train, axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

# Hold out 10% of the training data as a validation split (seeded).
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.1, random_state=seed)
input_shape = x_train.shape[1:]

## Model

In [None]:
# model from https://keras.io/examples/cifar10_resnet/

def lr_schedule(epoch):
    """Step-wise learning rate schedule

    Starts at 1e-3 and is scaled down once `epoch` exceeds 80, 120, 160
    and 180. Intended to be handed to a LearningRateScheduler callback,
    which calls it with the current epoch.

    # Arguments
        epoch (int): epoch number to look up

    # Returns
        lr (float): learning rate for that epoch
    """
    base_lr = 1e-3
    # (epoch threshold, multiplier) pairs, checked from the latest stage down
    # so the first match is the stage that applies.
    decay_steps = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (80, 1e-1),
    )
    for threshold, factor in decay_steps:
        if epoch > threshold:
            return base_lr * factor
    return base_lr

def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """Builds one convolution / batch-norm / activation stack

    # Arguments
        inputs (tensor): input tensor from input image or previous layer
        num_filters (int): Conv2D number of filters
        kernel_size (int): Conv2D square kernel dimensions
        strides (int): Conv2D square stride dimensions
        activation (string): activation name, or None to skip the activation
        batch_normalization (bool): whether to include batch normalization
        conv_first (bool): conv-bn-activation (True) or
            bn-activation-conv (False)

    # Returns
        x (tensor): tensor as input to the next layer
    """
    # The convolution is instantiated before the other layers in both
    # orderings, exactly as a straight-line implementation would do.
    conv_layer = Conv2D(num_filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        padding='same',
                        kernel_initializer='he_normal',
                        kernel_regularizer=l2(1e-4))

    def _norm_then_activate(tensor):
        # Optional batch normalization followed by the optional activation.
        if batch_normalization:
            tensor = BatchNormalization()(tensor)
        if activation is not None:
            tensor = Activation(activation)(tensor)
        return tensor

    if conv_first:
        return _norm_then_activate(conv_layer(inputs))
    return conv_layer(_norm_then_activate(inputs))

def resnet(input_shape, depth, num_classes=10):
    """Builds the ResNet (version 1) classifier

    Three stacks of residual blocks with 16, 32 and 64 filters. The first
    block of the second and third stack downsamples with stride 2 and uses
    a 1x1 convolution on the shortcut so both branches keep matching shapes.

    # Arguments
        input_shape (tuple): shape of one input image, passed to Input
        depth (int): network depth, must be of the form 6n+2
        num_classes (int): number of classes (CIFAR10 has 10)

    # Returns
        model (Model): Keras model instance

    # Raises
        ValueError: if depth is not of the form 6n+2
    """
    quotient, remainder = divmod(depth - 2, 6)
    if remainder != 0:
        raise ValueError('depth should be 6n+2 (eg 20, 32, 44 in [a])')
    blocks_per_stack = int(quotient)

    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    for stack_idx in range(3):
        # 16 filters in the first stack, doubled for every following stack
        stack_filters = 16 * 2 ** stack_idx
        for block_idx in range(blocks_per_stack):
            # first block of every stack but the first one downsamples
            is_transition = stack_idx > 0 and block_idx == 0
            stride = 2 if is_transition else 1

            residual = resnet_layer(inputs=x,
                                    num_filters=stack_filters,
                                    strides=stride)
            residual = resnet_layer(inputs=residual,
                                    num_filters=stack_filters,
                                    activation=None)
            if is_transition:
                # project the shortcut so it matches the residual branch
                x = resnet_layer(inputs=x,
                                 num_filters=stack_filters,
                                 kernel_size=1,
                                 strides=stride,
                                 activation=None,
                                 batch_normalization=False)
            x = keras.layers.add([x, residual])
            x = Activation('relu')(x)

    # Classifier head: average pooling, flatten, softmax dense layer.
    pooled = AveragePooling2D(pool_size=8)(x)
    flat = Flatten()(pooled)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(flat)

    return Model(inputs=inputs, outputs=outputs)

# Define the optimizer, metrics and loss; optimizer and loss are passed to
# model.compile both for training and inside evaluate_model.
# Adam starts at the rate lr_schedule returns for epoch 0.
optimizer = Adam(learning_rate=lr_schedule(0))
# 'categorical_crossentropy' is tracked as a metric as well; its validation
# history is read back as the NLL curve in the plotting cell.
metric = ['accuracy', 'categorical_crossentropy']
loss = CategoricalCrossentropy()

# Training Part

In [None]:
# Build the network and print its layer summary.
model = resnet(input_shape=input_shape, depth=depth)
model.summary()

# set optimizer
model.compile(loss=loss, optimizer=optimizer, metrics=metric)

# set checkpoint path
checkpoint_dir = '../../models/EXP3/'
checkpoint_path = checkpoint_dir + 'weights.{epoch:03d}.hdf5'
# Start every run from an empty checkpoint directory. Done with the standard
# library instead of `!rm -r` / `!mkdir` so the cell also runs outside IPython,
# does not complain when the directory is absent, and creates missing parents.
import shutil
shutil.rmtree(checkpoint_dir, ignore_errors=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# set callbacks
# lr_scheduler applies lr_schedule during training; checkpoint writes
# weights-only files with period=10 (the evaluation cell loads checkpoints
# whose epoch numbers are all multiples of 10).
lr_scheduler = LearningRateScheduler(lr_schedule)
checkpoint = ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, period=10)
callbacks = [lr_scheduler, checkpoint]

# lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
#                                cooldown=0,
#                                patience=5,
#                                min_lr=0.5e-6)

# callbacks = [checkpoint, lr_reducer, lr_scheduler]

# set data augmentation
# Only random horizontal/vertical shifts (10% of the image size) and random
# horizontal flips are enabled; every other transform is switched off.
datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-06,
    rotation_range=0,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.,
    zoom_range=0.,
    channel_shift_range=0.,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=True,
    vertical_flip=False,
    rescale=None,
    preprocessing_function=None,
    data_format=None,
    validation_split=0.0)

datagen.fit(x_train)
# Keep the returned History object: its `.history` dict is what gets pickled
# below (previously the result was discarded and `hist` was undefined here).
hist = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                           validation_data=(x_val, y_val),
                           epochs=epochs,
                           callbacks=callbacks)

# save history
pickle_path = "../../logits/EXP3/hist.pkl"
# Make sure the target directory exists so a long training run is not lost to
# a missing folder, and close the file even if pickling fails.
os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
with open(pickle_path, "wb") as pickle_out:
    pickle.dump(hist.history, pickle_out)

# Evaluation Part

In [None]:
# Reload the training history written by the training cell. After this cell
# `hist` is the plain history dict (metric name -> list of per-epoch values).
pickle_path = "../../logits/EXP3/hist.pkl"
# Context manager so the file handle is closed once the history is loaded.
with open(pickle_path, "rb") as pickle_in:
    hist = pickle.load(pickle_in)

## Plot Test Error & NLL

In [None]:
# Per-epoch error in percent, derived from the recorded validation accuracy.
error = [(1 - i)*100 for i in hist['val_accuracy']]
# Per-epoch cross-entropy, rescaled (x20, +5) so it can share the error axis.
nll = [i*20 + 5 for i in hist['val_categorical_crossentropy']]
# x covers every recorded epoch; y is fixed to the 5-25 window.
plt.axis([0, len(error), 5, 25])
plt.plot(error)
plt.plot(nll)
# NOTE(review): both curves come from the 'val_*' history keys, i.e. the
# held-out validation split, although the labels below call them "Test".
plt.title('Test Error & NLL')
plt.ylabel('Error (%)/NLL (Scaled)')
plt.xlabel('epoch')
plt.legend(['Test Error', 'Test NLL'])
# The target directory must already exist; it is not created here.
plt.savefig('../../logits/EXP3/plots/error_nll.png', dpi=400)

## Generate Logits

In [None]:
def evaluate_model(model, weights_file, x_test, y_test, x_val, y_val, bins = 15, pickle_file = None):
    """Evaluates a checkpoint on the test set and optionally saves the logits

    The activation of the model's last layer is switched to linear so that
    predictions are raw logits; probabilities are then obtained with the
    project's `softmax` helper. Accuracy, ECE and MCE are all computed on the
    test data. The validation data is only run through the model so that its
    logits can be saved alongside the test logits.

    Args:
        model (keras.model): constructed model whose last layer is the output layer
        weights_file (string): path to weights file
        x_test: (numpy.ndarray) with test data
        y_test: (numpy.ndarray) with test data labels, one-hot or one numeric column
        x_val: (numpy.ndarray) with validation data
        y_val: (numpy.ndarray) with validation data labels, one-hot or one numeric column
        bins: (int) number of confidence bins; ECE/MCE receive bin_size = 1/bins
        pickle_file: (string) path prefix; when given, logits and numeric labels of
            both splits are pickled to pickle_file + '_logits.p'

    Returns:
        (acc, ece, mce): accuracy of model in percent, ECE and MCE (calibration errors)
    """
    # Change last activation to linear (instead of softmax).
    # Indexing instead of pop(): only a handle on the layer is needed, and the
    # layers[-2] lookup below relies on the layer list being left intact.
    last_layer = model.layers[-1]
    last_layer.activation = keras.activations.linear
    logits_out = last_layer(model.layers[-2].output)
    model = keras.models.Model(inputs=model.input, outputs=[logits_out])

    # First load in the weights
    model.load_weights(weights_file)
    model.compile(optimizer=optimizer, loss=loss)

    # If 1-hot representation, get back to numeric (kept as a single column)
    if y_test.shape[1] > 1:
        y_test = np.argmax(y_test, axis=1).reshape(-1, 1)

    if y_val.shape[1] > 1:
        y_val = np.argmax(y_val, axis=1).reshape(-1, 1)

    # Next get predictions
    y_logits_test = model.predict(x_test, verbose=1)
    y_probs_test = softmax(y_logits_test)
    y_preds_test = np.argmax(y_probs_test, axis=1)

    # Confidence of prediction
    y_confs_test = np.max(y_probs_test, axis=1)  # Take only maximum confidence
    # Calculate Accuracy
    accuracy = metrics.accuracy_score(y_test, y_preds_test) * 100
    # Calculate ECE and MCE. The confidences and predictions come from the test
    # set, so they must be scored against the test labels (not y_val, which
    # belongs to a different split with a different number of samples).
    # NOTE(review): assumes the helpers take (conf, pred, true, bin_size) -
    # confirm against utility.evaluation.
    ece = ECE(y_confs_test, y_preds_test, y_test, bin_size = 1/bins)
    mce = MCE(y_confs_test, y_preds_test, y_test, bin_size = 1/bins)

    # Validation logits are only needed for the saved file
    y_logits_val = model.predict(x_val)

    # Pickle logits and labels for validation and test
    if pickle_file:
        with open(pickle_file + '_logits.p', 'wb') as f:
            pickle.dump([(y_logits_val, y_val),(y_logits_test, y_test)], f)

    # Return the basic results
    return (accuracy, ece, mce)

In [None]:
cps = [10,30,50,70,100,150,200,300,400,500]
def evaluate_epoch(cp):
    """Evaluates the checkpoint saved after `cp` epochs

    Looks up the weights file for that epoch, builds a fresh network and runs
    evaluate_model on it, which also writes the logits to
    '../../logits/EXP3/cp_<cp>_logits.p'.

    Args:
        cp (int): epoch number of the checkpoint

    Returns:
        (accuracy, ece, mce) as returned by evaluate_model

    Raises:
        Exception: if the lookup does not yield exactly one weights file
    """
    matches = glob.glob('../../models/EXP3/weights.{:03d}.hdf5'.format(cp))
    if len(matches) != 1:
        print(matches)
        raise Exception('checkpoint name confusion')
    weights_path = matches[0]

    # A fresh network per checkpoint: evaluate_model changes the activation of
    # the output layer of the model it is given.
    fresh_model = resnet(input_shape=input_shape, depth=depth)

    acc, ece_value, mce_value = evaluate_model(
        fresh_model, weights_path, x_test, y_test, x_val, y_val,
        bins = 15, pickle_file = '../../logits/EXP3/cp_' + str(cp))
    return acc, ece_value, mce_value

# Evaluate every checkpoint in `cps`, writing one CSV row per checkpoint and
# echoing the same numbers to the notebook output.
csv_path = '../../logits/EXP3/results.csv'

with open(csv_path, 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(['Epoch', 'Accuracy', 'ECE', 'MCE'])

    for cp in cps:
        print('[{} Epochs]\n'.format(cp))
        accuracy, ece, mce = evaluate_epoch(cp)
        # every field is stringified explicitly before it is written
        csv_writer.writerow([str(field) for field in (cp, accuracy, ece, mce)])
        for caption, score in zip(("Accuracy:", "ECE:", "MCE:"), (accuracy, ece, mce)):
            print(caption, score)
        print('\n---------\n')