# CIFAR-10 CNN: training, prediction and saving the trained model

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.layers import Flatten, Dense, Dropout, Convolution2D, Activation, GlobalAveragePooling2D, Rescaling,Conv2D, BatchNormalization, MaxPooling2D

from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.models import Sequential, Model,load_model
from tensorflow.keras.layers.experimental import preprocessing

from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt


import numpy as np

### Download the CIFAR-10 dataset

In [None]:
# Load CIFAR-10 as (image, label) pairs together with its DatasetInfo.
# The whole "train" split is used for training; the "test" split is sliced
# into the first 40% (validation) and the remaining 60% (test).
# NOTE: despite its name, `builder` is the list of three tf.data datasets,
# returned in the same order as the `split` list.
builder, ds_info = tfds.load(
    'cifar10',
    split=["train", "test[:40%]", "test[40%:]"],
    with_info=True,
    as_supervised=True,
)

In [None]:
# Display the dataset's feature spec (the 'image' and 'label' features used below).
ds_info.features

In [None]:
# tfds.load returned one dataset per entry of the `split` list, in order:
# train, validation (test[:40%]) and test (test[40%:]).
ds_train, ds_validation, ds_test = builder

# Report how many examples ended up in each split.
print(f'train size: {len(ds_train)}')
print(f'validation size: {len(ds_validation)}')
print(f'test size: {len(ds_test)}')

In [None]:
# Show the first training image and print its integer class label.
image, label = next(iter(ds_train.take(1)))
plt.imshow(image)
print(int(label))

In [None]:
# Inspect the spec of the 'image' feature on its own.
ds_info.features['image']

### Verify data

In [None]:
# Human-readable CIFAR-10 class names; position i is the name of integer label i.
class_names = 'airplane automobile bird cat deer dog frog horse ship truck'.split()

In [None]:
# Show the first 25 training images in a 5x5 grid, each captioned with its
# class name.
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(ds_train.take(25)):
    plt.subplot(5, 5, i + 1)  # subplot indices are 1-based
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy().astype('uint8'))
    # With as_supervised=True each label is a scalar integer tensor, so it
    # indexes class_names directly; int() makes the conversion explicit.
    plt.xlabel(class_names[int(label)])
plt.show()

## Define model

In [None]:
BATCH_SIZE = 32  # mini-batch size for the train/validation/test pipelines
# Output width of the classifier, read from the dataset metadata instead of hard-coded.
NUM_CLASSES = ds_info.features['label'].num_classes

In [None]:
# Random horizontal flip / rotation / zoom augmentation for 32x32 RGB inputs.
# It is defined here but not currently added to the model.
# Uses the stable `layers.Random*` classes instead of the deprecated
# `layers.experimental.preprocessing` namespace. The file already imports the
# stable `Rescaling` layer from `tensorflow.keras.layers`, so the TF version in
# use provides these as well.
data_augmentation = Sequential(
    [
        layers.RandomFlip("horizontal", input_shape=(32, 32, 3)),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ]
)

**Define model**

In [None]:
def _add_conv_block(target, filters, dropout_rate):
    """Append one feature-extraction stage to the Sequential model *target*.

    The stage is two (Conv2D 3x3 'same' + ReLU, BatchNormalization) pairs,
    followed by MaxPooling2D with its default settings and Dropout(dropout_rate).
    """
    for _ in range(2):
        target.add(Conv2D(filters, kernel_size=3, activation='relu', padding='same'))
        target.add(BatchNormalization())
    target.add(MaxPooling2D())
    target.add(Dropout(dropout_rate))


model = Sequential()

# Data augmentation is currently disabled; to enable it, add
# `model.add(data_augmentation)` here as the first layer.

# Map uint8 pixel values from [0, 255] to [-1, 1] (x / 127.5 - 1).
model.add(Rescaling(scale=1./127.5, offset=-1, input_shape=(32,32,3)))

# Three conv stages with widening filters and increasing dropout.
_add_conv_block(model, 32, 0.2)
_add_conv_block(model, 64, 0.3)
_add_conv_block(model, 128, 0.4)

# Classifier head: one hidden dense layer, then a softmax over the classes.
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))
model.add(Dense(NUM_CLASSES, activation='softmax'))

model.summary()

### Train model

In [None]:
# Integer (non one-hot) labels -> sparse categorical cross-entropy.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy'],
)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Training data: cache the decoded examples, then shuffle *after* cache() and
# *before* batch(), so each epoch draws a fresh example order instead of
# replaying the same fixed sequence. The buffer spans the whole training split,
# which gives a full shuffle.
ds_train = (
    ds_train.cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
# Validation/test are deliberately left unshuffled. The evaluation cells
# iterate ds_test more than once (predictions, then labels) and rely on a
# stable order for the two to line up.
ds_validation = ds_validation.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds_test = ds_test.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

**Define callbacks to improve the training loop**

In [None]:
def lr_schedule(epoch):
    """Return a step-decayed learning rate for the given epoch index.

    Suitable for ``tf.keras.callbacks.LearningRateScheduler``:
    epochs 0-3 -> 1e-3, epochs 4-10 -> 1e-3 * 0.1, epochs > 10 -> 1e-3 * 0.01.
    """
    base_lr = 1e-3
    if epoch > 10:
        return base_lr * 0.01
    if epoch > 3:
        return base_lr * 0.1
    return base_lr

# Cut the learning rate by 10x whenever val_loss fails to improve by at least
# 0.001 for 2 epochs. A fixed step schedule is the alternative:
# tf.keras.callbacks.LearningRateScheduler(lr_schedule).
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=2,
    min_delta=0.001,
)

In [None]:
# Stop when val_loss has not improved for 5 epochs, then restore the weights
# from the best epoch seen.
es = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

In [None]:
epochs = 15
# Both callbacks monitor val_loss, which is computed on ds_validation.
history = model.fit(
    ds_train,
    validation_data=ds_validation,
    epochs=epochs,
    callbacks=[lr_callback, es],
)

### Plot training

In [None]:
# summarize history for accuracy
# Summarize history for accuracy. `val_accuracy` is measured on ds_validation
# (the validation_data passed to model.fit), not on the held-out test split,
# so the second curve is labelled 'validation'.
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

In [None]:
# Summarize history for loss. `val_loss` comes from ds_validation (the
# validation_data passed to model.fit), not from the test split.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
# Loss curves start high on the left, so keep the legend out of their way.
plt.legend(['train', 'validation'], loc='upper right')
plt.show()

## Evaluate model

In [None]:
 model.evaluate(ds_test)

In [None]:
# Per-class softmax scores for every test example, reduced to predicted class ids.
y_pred = model.predict(ds_test)
y_pred_classes = y_pred.argmax(axis=1)
# Ground-truth labels, gathered in the same (unshuffled) order ds_test yields them.
y_test = np.concatenate([labels for _, labels in ds_test], axis=0)
confusion_mtx = tf.math.confusion_matrix(y_test, y_pred_classes)

In [None]:
import seaborn as sns

# Annotated heatmap of the test-set confusion matrix. tf.math.confusion_matrix
# puts true labels on the rows and predictions on the columns, so the axes are
# labelled accordingly.
plt.figure(figsize=(12, 9))
c = sns.heatmap(confusion_mtx, annot=True, fmt='g')
c.set(
    xticklabels=class_names,
    yticklabels=class_names,
    xlabel='predicted label',
    ylabel='true label',
)
# Render explicitly; this also stops the cell echoing the list returned by c.set().
plt.show()

### Save model

In [None]:
# Persist the trained model under "models/cifar_cnn".
# NOTE(review): with no file extension, TF 2.x Keras writes a SavedModel
# directory, but Keras 3 rejects extension-less paths and expects ".keras" or
# ".h5". Confirm which TF/Keras version this notebook targets.
model.save("models/cifar_cnn")