# Loading dataset

In [37]:
#Sequential API
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Layer, Dense, Flatten, Conv2D, MaxPooling2D, BatchNormalization, Input
from tensorflow.keras.datasets import mnist
from tensorflow.keras import Sequential, Model

In [39]:
fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist
X_train, y_train = X_train_full[:5000], y_train_full[:5000]
X_valid, y_valid = X_train_full[5000:], y_train_full[5000:]

In [40]:
X_train = X_train / 255.
X_valid = X_valid / 255.

# Sequential API

In [43]:
model = keras.Sequential(
    [
        Input(shape=(32, 32, 3)),
        Conv2D(32, 3, padding="valid", activation='relu'),
        MaxPooling2D(),
        Conv2D(64, 3, activation="relu"),
        MaxPooling2D(),
        Conv2D(128, 3, activation="relu"),
        Flatten(),
        Dense(64, activation="relu"),
        Dense(10),
    ]
)

# Functional API

In [47]:
from tensorflow.keras.activations import relu

In [49]:
def my_model():
    inputs = Input(shape=(28, 28, 1))
    x = Conv2D(32, 3)(inputs)
    x = BatchNormalization()(x)
    x = relu(x)
    x = MaxPooling2D()(x)
    x = Conv2D(64, 3)(x)
    x = BatchNormalization()(x)
    x = relu(x)
    x = MaxPooling2D()(x)
    x = Conv2D(128, 3)(x)
    x = BatchNormalization()(x)
    x = relu(x)
    x = Flatten()(x)
    x = Dense(64, activation="relu")(x)
    outputs = Dense(10)(x)
    model = Model(inputs=inputs, outputs=outputs)
    return model

In [51]:
model = my_model()

In [53]:
model.summary()

In [55]:
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=["accuracy"],
)

In [59]:
history = model.fit(X_train, y_train, batch_size=64, validation_data = (X_valid, y_valid), epochs=10, verbose=2)

Epoch 1/10
79/79 - 17s - 214ms/step - accuracy: 0.9530 - loss: 0.1393 - val_accuracy: 0.7867 - val_loss: 0.6357
Epoch 2/10
79/79 - 21s - 265ms/step - accuracy: 0.9670 - loss: 0.1030 - val_accuracy: 0.8362 - val_loss: 0.4748
Epoch 3/10
79/79 - 27s - 339ms/step - accuracy: 0.9756 - loss: 0.0775 - val_accuracy: 0.8596 - val_loss: 0.4327
Epoch 4/10
79/79 - 27s - 342ms/step - accuracy: 0.9846 - loss: 0.0538 - val_accuracy: 0.8547 - val_loss: 0.4915
Epoch 5/10
79/79 - 17s - 217ms/step - accuracy: 0.9920 - loss: 0.0315 - val_accuracy: 0.8506 - val_loss: 0.5534
Epoch 6/10
79/79 - 20s - 249ms/step - accuracy: 0.9908 - loss: 0.0328 - val_accuracy: 0.8454 - val_loss: 0.5767
Epoch 7/10
79/79 - 22s - 279ms/step - accuracy: 0.9954 - loss: 0.0204 - val_accuracy: 0.8572 - val_loss: 0.5365
Epoch 8/10
79/79 - 29s - 363ms/step - accuracy: 0.9980 - loss: 0.0122 - val_accuracy: 0.8643 - val_loss: 0.5578
Epoch 9/10
79/79 - 21s - 263ms/step - accuracy: 1.0000 - loss: 0.0053 - val_accuracy: 0.8651 - val_loss:

In [None]:
import pandas as pd

pd.DataFrame(history.history).plot()