# CIFAR pre-trained model prediction

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.layers import Flatten, Dense, Dropout, Convolution2D, Activation, GlobalAveragePooling2D, Rescaling

from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.models import Sequential, Model,load_model
from tensorflow.keras.layers.experimental import preprocessing

from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt


import numpy as np

### Download CIFAR dataset

In [None]:
# Fetch CIFAR-10 through TFDS. The official test split is divided 40/60 so
# that its first part can serve as a validation set and the rest as the test set.
# NOTE: despite its name, `builder` receives the list of datasets (one per
# requested split); `ds_info` carries the dataset metadata.
cifar_splits = ["train", "test[:40%]", "test[40%:]"]
builder, ds_info = tfds.load(
    'cifar10',
    split=cifar_splits,
    with_info=True,
    as_supervised=True,
)

In [None]:
# Display the feature specification stored in the dataset metadata.
ds_info.features

In [None]:
# Unpack the three datasets in the order the splits were requested.
ds_train, ds_validation, ds_test = builder

# Report how many examples each split holds.
print(f'train size: {len(ds_train)}')
print(f'validation size: {len(ds_validation)}')
print(f'test size: {len(ds_test)}')

In [None]:
# Peek at the first training example: draw the image and print its class id.
image, label = next(iter(ds_train.take(1)))
plt.imshow(image)
print(int(label))

In [None]:
# Display the image shape recorded in the dataset metadata.
ds_info.features['image'].shape

In [None]:
# Side length (in pixels) of the square images fed to the network.
input_size = 80
size = (input_size, input_size)


def _resize_example(image, label):
    """Resize one image to `size` with area interpolation; the label passes through."""
    resized = tf.image.resize(image, size, method=tf.image.ResizeMethod.AREA)
    return resized, label


# Apply the same resize to every split.
ds_train = ds_train.map(_resize_example)
ds_validation = ds_validation.map(_resize_example)
ds_test = ds_test.map(_resize_example)

### Verify data

In [None]:
# Human-readable class names; the list position is used as the integer label id.
class_names = [
    'airplane',    # 0
    'automobile',  # 1
    'bird',        # 2
    'cat',         # 3
    'deer',        # 4
    'dog',         # 5
    'frog',        # 6
    'horse',       # 7
    'ship',        # 8
    'truck',       # 9
]

In [None]:
# Sanity-check the data: show the first 25 training images in a 5x5 grid,
# each captioned with its class name.
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(ds_train.take(25)):
    plt.subplot(5, 5, i + 1)
    # Hide ticks and grid lines; only the picture and its caption matter here.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # Cast to uint8 so matplotlib interprets the pixels on a 0-255 scale.
    plt.imshow(image.numpy().astype('uint8'))
    # Each label is a scalar, so it converts straight to a Python int and can
    # index class_names directly.
    plt.xlabel(class_names[int(label)])
plt.show()

## Define model

In [None]:
# Number of target classes, read from the dataset metadata instead of being hard-coded.
label_feature = ds_info.features['label']
NUM_CLASSES = label_feature.num_classes

# Mini-batch size used when the datasets are batched.
BATCH_SIZE = 64

**Load pre-trained model**

In [None]:
# Load the ImageNet-pretrained backbone without its classification head.
# `classes` is intentionally not passed: Keras only uses it when
# include_top=True, so it has no effect on a headless backbone.
base_model = tf.keras.applications.xception.Xception(
    weights='imagenet',
    input_shape=(input_size, input_size, 3),
    include_top=False,
)
# Alternative backbones (switch the matching preprocess_input call as well):
#base_model = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(weights='imagenet', input_shape=(input_size, input_size, 3), include_top=False)
#base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights='imagenet', input_shape=(input_size, input_size, 3), include_top=False)
#base_model = tf.keras.applications.resnet_v2.ResNet50V2(weights='imagenet', input_shape=(input_size, input_size, 3), include_top=False)

In [None]:
# Freeze the pre-trained backbone so its weights are not updated during training.
# This is done before summary() so the printed parameter counts reflect it.
base_model.trainable = False
base_model.summary()

**Define model**

In [None]:
# Random augmentation pipeline: horizontal flips plus small rotations and zooms.
# It is only defined here; the call that would apply it to the model inputs is
# commented out in the model-definition cell.
augmentation_layers = [
    preprocessing.RandomFlip("horizontal"),
    preprocessing.RandomRotation(0.1),
    preprocessing.RandomZoom(0.1),
]
data_augmentation = Sequential(augmentation_layers)

In [None]:
# Assemble the classifier with the functional API:
# input -> backbone-specific preprocessing -> frozen backbone -> pooled head.
image_shape = (input_size, input_size, 3)
inputs = tf.keras.Input(shape=image_shape)

# Augmentation is currently disabled. To enable it, apply it to `inputs` and
# feed the result to preprocess_input below instead of `inputs`.
#augmented = data_augmentation(inputs)

# The preprocessing function must match the backbone that was loaded.
preprocessed = tf.keras.applications.xception.preprocess_input(inputs)
#preprocessed = tf.keras.applications.inception_resnet_v2.preprocess_input(inputs)
#preprocessed = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
#preprocessed = tf.keras.applications.resnet_v2.preprocess_input(inputs)

# training=False makes the backbone run in inference mode even while the
# surrounding model is being trained.
features = base_model(preprocessed, training=False)
#features = base_model(inputs, training=False)

# Classification head: spatial average pooling, dropout, softmax over the classes.
pooled = GlobalAveragePooling2D()(features)
regularized = Dropout(0.3)(pooled)
outputs = Dense(NUM_CLASSES, activation='softmax')(regularized)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

### Train model

In [None]:
# Labels are integer class ids, hence the sparse variant of categorical cross-entropy.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE
# Number of examples held in the shuffle buffer of the training pipeline.
SHUFFLE_BUFFER_SIZE = 10_000

# cache() stores the already-resized examples so the resize is not recomputed
# every epoch. The training set is shuffled after the cache and before
# batching; without the shuffle every epoch would see identical batches in
# identical order.
ds_train = (
    ds_train.cache()
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

# Validation and test data keep a fixed order: the evaluation code pairs
# model.predict() outputs with labels read in a separate pass over ds_test.
ds_validation = ds_validation.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds_test = ds_test.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

In [None]:
def lr_schedule(epoch):
    """Return the step-decayed learning rate for the given epoch index.

    The base rate is 1e-3. It is multiplied by 0.1 once ``epoch`` exceeds 3
    and by 0.01 once ``epoch`` exceeds 10.
    """
    base_lr = 1e-3
    # (epoch threshold, multiplier) pairs, checked from the latest stage backwards
    # so the strongest applicable decay wins.
    decay_steps = ((10, 0.01), (3, 0.1))
    for threshold, factor in decay_steps:
        if epoch > threshold:
            return base_lr * factor
    return base_lr

# Alternative: drive the learning rate from the fixed lr_schedule function instead.
#lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

# Multiply the learning rate by 0.1 when val_loss stops improving
# (patience of 2 epochs, improvements smaller than 0.001 are ignored).
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=2,
    min_delta=0.001,
)

In [None]:
# Train on the training set, tracking validation metrics after each epoch;
# the callback adjusts the learning rate during training.
epochs = 20
history = model.fit(
    ds_train,
    validation_data=ds_validation,
    epochs=epochs,
    callbacks=[lr_callback],
)

### Plot training

In [None]:
# Plot accuracy per epoch for the training and validation sets.
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
# The second curve comes from the validation data passed to fit(), not from
# the separate test split, so it is labelled accordingly.
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

In [None]:
# Plot loss per epoch for the training and validation sets.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
# The second curve comes from the validation data passed to fit(), not from
# the separate test split, so it is labelled accordingly.
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

## Evaluate model

In [None]:
 model.evaluate(ds_test)

In [None]:
# Predicted class probabilities for the test set, reduced to the most likely class id.
y_pred = model.predict(ds_test)
y_pred_classes = np.argmax(y_pred, axis=1)

# Ground-truth labels, collected batch by batch in a second pass over ds_test.
y_test = np.concatenate([labels for _, labels in ds_test], axis=0)

# Confusion matrix of true labels against predicted classes.
confusion_mtx = tf.math.confusion_matrix(y_test, y_pred_classes)

In [None]:
import seaborn as sns

# Draw the confusion matrix as an annotated heatmap with class names on the ticks.
# tf.math.confusion_matrix puts the true labels on the rows and the predictions
# on the columns, so the axes are labelled accordingly.
plt.figure(figsize=(12, 9))
c = sns.heatmap(confusion_mtx, annot=True, fmt='g')
c.set(xticklabels=class_names, yticklabels=class_names,
      xlabel='predicted label', ylabel='true label')

### Save model

In [None]:
# Persist the trained model under the models/ directory.
export_path = "models/cifar_xception"
model.save(export_path)