# Train and Evaluate ResNet-50 on CIFAR-10

## Import

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D, Dense, Input, add, Activation, GlobalAveragePooling2D
from tensorflow.keras.callbacks import LearningRateScheduler, TensorBoard, ModelCheckpoint
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers, regularizers
from sklearn.model_selection import train_test_split
import pickle
import matplotlib.pyplot as plt

import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath("utility"))))
from utility.evaluation import evaluate_model
import glob
import csv

In [None]:
# Solves out-of-memory issues: let TensorFlow grow its GPU allocation on
# demand instead of reserving the whole device up front.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if not physical_devices:
    # Explicit raise (same exception type as before) because a bare `assert`
    # is stripped when Python runs with -O.
    raise AssertionError("Not enough GPU hardware devices available")
# Apply the setting to every visible GPU rather than only the first one, so
# all devices are configured the same way.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

## Parameters

In [None]:
# --- network / input geometry ---
stack_n = 8                # residual blocks per stage of the network
num_classes = 10
img_rows = img_cols = 32
img_channels = 3

# --- optimisation ---
batch_size = 512
epochs = 500
# Steps per epoch; 45000 is presumably the size of the training split
# (floor division drops the final partial batch).
iterations = 45000 // batch_size
weight_decay = 1e-4        # L2 regularisation factor used by the layers
seed = 333                 # random_state for the train/validation split

# Label smoothing rate passed to the loss (0 disables smoothing).
label_smoothing_r = 0.1

## Model

In [None]:
def scheduler(epoch):
    """Return the learning rate to use for the given epoch index.

    Step schedule: 0.1 while ``epoch < 250``, 0.01 while ``epoch < 375``,
    and 0.001 for every later epoch.
    """
    # (exclusive upper epoch bound, learning rate) pairs, checked in order.
    steps = ((250, 0.1), (375, 0.01))
    for upper_bound, rate in steps:
        if epoch < upper_bound:
            return rate
    return 0.001

def residual_network(img_input, classes_num=10, stack_n=8):
    """Build the residual network graph on top of ``img_input``.

    The network is a stem convolution followed by three stages of
    ``stack_n`` residual blocks each (16, 32 and 64 channels), then
    BatchNorm -> ReLU -> global average pooling -> softmax classifier.
    Every convolution and the final dense layer use ``he_normal``
    initialisation and an L2 penalty taken from the module-level
    ``weight_decay``.

    Args:
        img_input: Keras input tensor the network is attached to.
        classes_num: Number of units of the final softmax layer.
        stack_n: Number of residual blocks per stage.

    Returns:
        The output tensor of the softmax layer.
    """
    def l2_conv(filters, kernel, stride):
        # All convolutions share padding, initialiser and regulariser; a new
        # regulariser object is created per layer, as in a hand-written stack.
        return Conv2D(filters, kernel_size=kernel, strides=stride, padding='same',
                      kernel_initializer="he_normal",
                      kernel_regularizer=regularizers.l2(weight_decay))

    def residual_block(block_in, out_channel, increase=False):
        # `increase` marks the first block of a new stage: the first conv
        # uses stride 2 and the shortcut becomes a strided 1x1 projection.
        first_stride = (2, 2) if increase else (1, 1)

        # Main branch: BN -> ReLU -> conv, applied twice.
        h = BatchNormalization()(block_in)
        h = Activation('relu')(h)
        h = l2_conv(out_channel, (3, 3), first_stride)(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = l2_conv(out_channel, (3, 3), (1, 1))(h)

        if not increase:
            # Identity shortcut: shapes already match.
            return add([block_in, h])

        # Projection shortcut is computed from the raw block input.
        shortcut = l2_conv(out_channel, (1, 1), (2, 2))(block_in)
        return add([h, shortcut])

    # Stem; for a 32x32x3 input this yields 32x32x16.
    x = l2_conv(16, (3, 3), (1, 1))(img_input)

    # Stage 1: resolution and width unchanged (32x32x16).
    for _ in range(stack_n):
        x = residual_block(x, 16, False)

    # Stages 2 and 3: the first block halves the resolution and switches to
    # the new width (16x16x32, then 8x8x64); the remaining stack_n - 1
    # blocks keep the shape.
    for width in (32, 64):
        x = residual_block(x, width, True)
        for _ in range(1, stack_n):
            x = residual_block(x, width, False)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

    # Classifier head: pooled features -> classes_num probabilities.
    x = Dense(classes_num, activation='softmax',
              kernel_initializer="he_normal",
              kernel_regularizer=regularizers.l2(weight_decay))(x)
    return x

## Prepare Dataset

In [None]:
# Load CIFAR-10 and one-hot encode both label arrays.
(x_train_50, y_train_50), (x_test_10, y_test_10) = cifar10.load_data()
y_train_50, y_test_10 = (
    keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train_50, y_test_10)
)

# Hold out 10% of the training data as a validation split (seeded).
x_train_45, x_val_5, y_train_45, y_val_5 = train_test_split(
    x_train_50, y_train_50, test_size=0.1, random_state=seed)

# Standardise every split with statistics taken along axis 0 of the training
# split only, so no validation/test information leaks into the normalisation.
img_mean = x_train_45.mean(axis=0)
img_std = x_train_45.std(axis=0)
x_train_45, x_val_5, x_test_10 = (
    (images - img_mean) / img_std
    for images in (x_train_45, x_val_5, x_test_10)
)

## Train

In [None]:
# Local imports keep this cell self-contained; they back the pure-Python
# directory handling below (which replaces the IPython `!rm` / `!mkdir`
# shell escapes).
import os
import shutil

# build network
img_input = Input(shape=(img_rows,img_cols,img_channels))
output    = residual_network(img_input,num_classes,stack_n)
resnet    = Model(img_input, output)
# print(resnet.summary())

# set optimizer (`learning_rate` replaces the deprecated `lr` alias)
optimizer = optimizers.SGD(learning_rate=.1, momentum=0.9, nesterov=True)
# Plain categorical cross-entropy is tracked as a metric in addition to the
# (possibly label-smoothed) training loss; later cells plot it as NLL.
metrics = ['accuracy', 'categorical_crossentropy']

if label_smoothing_r > 0:
    loss = CategoricalCrossentropy(label_smoothing=label_smoothing_r)
else:
    loss = CategoricalCrossentropy()
resnet.compile(loss=loss, optimizer=optimizer, metrics=metrics)

# set checkpoint
checkpoint_dir = '../../models/EXP2/ls_{}/'.format(str(int(label_smoothing_r*10)))
checkpoint_path = checkpoint_dir + 'weights.{epoch:03d}.hdf5'

print('writing checkpoints to ' + checkpoint_dir)

# Start from an empty checkpoint directory. rmtree tolerates a missing
# directory and makedirs also creates missing parents, which plain
# `rm -r` / `mkdir` did not.
shutil.rmtree(checkpoint_dir, ignore_errors=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Make sure the history output directory exists *before* training, so the
# run cannot fail on a missing folder after all epochs have completed.
pickle_path = "../../logits/EXP2/hist_{}.pkl".format(str(int(label_smoothing_r*10)))
os.makedirs(os.path.dirname(pickle_path), exist_ok=True)

# Weights-only checkpoint every 5 epochs.
cp_callback = ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, period=5)

# set callback
cbks = [LearningRateScheduler(scheduler), cp_callback]

# set data augmentation
print('Using real-time data augmentation.')
datagen = ImageDataGenerator(horizontal_flip=True,
                             width_shift_range=0.125,
                             height_shift_range=0.125,
                             fill_mode='constant',cval=0.)

datagen.fit(x_train_45)

# start training
hist = resnet.fit_generator(datagen.flow(x_train_45, y_train_45, batch_size=batch_size),
                     steps_per_epoch=iterations,
                     epochs=epochs,
                     callbacks=cbks,
                     validation_data=(x_val_5, y_val_5))

# print("Get test accuracy:")
# loss, accuracy = resnet.evaluate(x_test, y_test, verbose=0)
# print("Test: accuracy1 = %f  ;  loss1 = %f" % (accuracy, loss))

# save history (context manager guarantees the file is closed even if
# pickling fails)
with open(pickle_path, "wb") as pickle_out:
    pickle.dump(hist.history, pickle_out)

## Evaluate

In [None]:
# Load the pickled training histories, keyed by label smoothing rate.
# NOTE: pickle.load must only be used on files written by the training cell
# of this notebook; never point it at untrusted data.
ls_rate = [0, 0.1]
history_list = {}
for rate in ls_rate:
    pickle_path = "../../logits/EXP2/hist_{}.pkl".format(str(int(rate*10)))
    # `with` closes each file handle; previously they were left open.
    with open(pickle_path, "rb") as pickle_in:
        hist = pickle.load(pickle_in)
    history_list[rate] = hist

### Plot Test Error & NLL

In [None]:
# Select the run to plot.
ls_r = 0.1
hist = history_list[ls_r]

# Error in percent, derived from the `val_accuracy` history entries.
error = [100 * (1 - acc) for acc in hist['val_accuracy']]
# NLL is affinely rescaled (x20, +5) so it shares the 5-30 y-range with the
# error curve.
nll = [5 + 20 * cce for cce in hist['val_categorical_crossentropy']]

plt.axis([0, len(error), 5, 30])
for curve in (error, nll):
    plt.plot(curve)

plt.title('Test Error & NLL (label smoothing rate = {})'.format(ls_r))
plt.xlabel('epoch')
plt.ylabel('Error (%)/NLL (Scaled)')
plt.legend(['Test Error', 'Test NLL'])

plot_file = '../../logits/EXP2/plots/error_nll_{}.png'.format(str(int(ls_r*10)))
plt.savefig(plot_file, dpi=400)

### Plot Test Error at Different Label Smoothing Rates

In [None]:
# One error curve (in percent) per loaded history, in dictionary order.
for history in history_list.values():
    error = [100 * (1 - acc) for acc in history['val_accuracy']]
    plt.plot(error)

# The x-range follows the length of the last plotted series.
plt.axis([0, len(error), 5, 30])
plt.title('Test Error at different Label Smoothing Rate')
plt.xlabel('epoch')
plt.ylabel('Error (%)')
plt.legend(['no label smoothing', 'label smoothing rate=0.1'])
plt.savefig('../../logits/EXP2/plots/error.png', dpi=400)

### Plot NLL at Different Label Smoothing Rates

In [None]:
# One NLL curve per loaded history, taken from the
# `val_categorical_crossentropy` history entries.
for i in history_list:
    nll = list(history_list[i]['val_categorical_crossentropy'])
    plt.plot(nll)

# The x-range comes from this cell's own series. It previously used `error`,
# a variable left over from another cell, so this cell could not run on its
# own and could pick up a stale length.
plt.axis([0, len(nll), 0.2, 1.5])
plt.title('Test NLL at different Label Smoothing Rate')
plt.ylabel('NLL')
plt.xlabel('epoch')
plt.legend(['no label smoothing', 'label smoothing rate=0.1'])
plt.savefig('../../logits/EXP2/plots/nll.png', dpi=400)

### ECE, MCE and Reliability Diagram

In [None]:
# Epoch numbers of the checkpoints that will be evaluated.
cps = [5, 10, 30, 50, 70, 100, 150, 200, 300, 400, 500]

# Label smoothing rate of the run under evaluation, plus the integer suffix
# (rate * 10) used in that run's directory names.
eva_lr_rate = 0.1
eva_lr_rate_str = '{:d}'.format(int(eva_lr_rate * 10))
print('evaluating model with label smoothing value of ' + str(eva_lr_rate))
def evaluate_epoch(cp):
    """Evaluate the checkpoint numbered ``cp`` of the selected run.

    Locates the weights file for ``cp`` in the checkpoint directory of the
    run identified by the module-level ``eva_lr_rate_str``, builds a freshly
    initialised network of the same architecture, and hands both to
    ``evaluate_model`` together with the test and validation splits.

    Args:
        cp: Checkpoint epoch number; formatted as a zero-padded 3-digit
            field in the weights file name.

    Returns:
        The ``(accuracy, ece, mce)`` tuple returned by ``evaluate_model``.

    Raises:
        Exception: If the lookup does not match exactly one checkpoint file.
    """
    cp_path = glob.glob('../../models/EXP2/ls_{}/weights.{:03d}.hdf5'.format(eva_lr_rate_str, cp))
    if len(cp_path) != 1:
        # Show what was (not) found before failing, to ease debugging.
        print(cp_path)
        raise Exception('checkpoint name confusion')
    cp_path = cp_path[0]

    # Untrained model with the training architecture; presumably
    # evaluate_model loads the checkpoint weights into it - TODO confirm.
    img_input = Input(shape=(img_rows,img_cols,img_channels))
    output = residual_network(img_input,num_classes,stack_n)
    resnet_random = Model(img_input, output)

    accuracy, ece, mce = evaluate_model(resnet_random, cp_path, x_test_10, y_test_10, bins = 15, verbose = True, 
                   pickle_file = '../../logits/EXP2/hist_{}/cp_'.format(eva_lr_rate_str) + str(cp), x_val = x_val_5, y_val = y_val_5)
    return accuracy, ece, mce

# Local import: only `os.path` is imported at the top of the notebook.
import os

csv_path = '../../logits/EXP2/hist_{}/results.csv'.format(eva_lr_rate_str)

# open(..., 'w') does not create missing parent directories, so make sure
# the output folder exists before any evaluation work starts.
os.makedirs(os.path.dirname(csv_path), exist_ok=True)

# Evaluate every selected checkpoint and record one CSV row per epoch.
with open(csv_path, 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(['Epoch', 'Accuracy', 'ECE', 'MCE'])

    for cp in cps:
        print('[{} Epochs]\n'.format(cp))
        accuracy, ece, mce = evaluate_epoch(cp)
        csv_writer.writerow([str(cp), str(accuracy), str(ece), str(mce)])
        print('\n---------\n')