In [39]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, callbacks
from keras.datasets import cifar10
from keras import layers
from sklearn.model_selection import train_test_split

In [40]:
# class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#
# plt.figure(figsize=(15, 15))
# for i in range(20):
#     plt.subplot(6, 6, i + 1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(X_train[i])
#     plt.xlabel(class_names[y_train[i][0]])
#
# plt.show()

In [41]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
print(y_train.shape)

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=42)

(50000, 10)


In [42]:
batch_size = 64
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

In [43]:
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomTranslation(0.1, 0.1),
])

In [44]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=AUTOTUNE
).prefetch(AUTOTUNE)
val_ds = val_ds.prefetch(AUTOTUNE)
test_ds = test_ds.prefetch(AUTOTUNE)

In [45]:
def create_model():
    inputs = keras.Input(shape=(32, 32, 3)) # 32 x 32, 3 channels

    x = layers.Conv2D(32, 3, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.2)(x)

    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Dropout(0.4)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(10, activation='softmax')(x)

    ret = keras.Model(inputs=inputs, outputs=outputs)

    ret.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss = 'categorical_crossentropy',
        metrics = ['accuracy']
    )

    return ret

In [46]:
model = create_model()
model.summary()

In [47]:
callback_list = [
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True
    ),
    callbacks.ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=0.5,
        paticen=5,
        min_lr=1e-6
    ),
    callbacks.ModelCheckpoint(
        filepath='../best_cifar10_model.keras',
        monitor='val_accuracy',
        save_best_only=True
    )
]

In [48]:
epochs = 200
history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=callback_list
)

Epoch 1/200
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 116ms/step - accuracy: 0.2992 - loss: 2.2357 - val_accuracy: 0.5050 - val_loss: 1.3790 - learning_rate: 0.0010
Epoch 2/200
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 127ms/step - accuracy: 0.4616 - loss: 1.4980 - val_accuracy: 0.4938 - val_loss: 1.5152 - learning_rate: 0.0010
Epoch 3/200
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 159ms/step - accuracy: 0.5282 - loss: 1.3108 - val_accuracy: 0.5940 - val_loss: 1.1549 - learning_rate: 0.0010
Epoch 4/200
[1m464/704[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m41s[0m 174ms/step - accuracy: 0.5671 - loss: 1.2120

KeyboardInterrupt: 