Building an image classifier for the Fashion MNIST dataset that achieves better accuracy by using a convolutional network.

In [51]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import TensorBoard

# Report which TensorFlow build is installed; the cells below rely on TF 1.x
# APIs such as tf.train.AdamOptimizer.
print("{}".format(tf.__version__))

1.12.0


In [None]:
# Fetch Fashion-MNIST through the Keras dataset helper.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Rescale pixel values by 1/255 (assumes raw 0-255 intensities, giving [0, 1])
# and append a single channel axis so every sample is (28, 28, 1), the shape
# the first Conv2D layer is declared with.
train_images, test_images = (
    (images / 255.0).reshape(images.shape[0], 28, 28, 1)
    for images in (train_images, test_images)
)

train_images.shape

In [None]:
def _conv_block(filters, kernel_size, stride, **conv_kwargs):
    """Return one Conv2D -> BatchNormalization -> ReLU stage as a list of layers.

    The convolution uses "same" padding, so only ``stride`` shrinks the
    spatial size. Extra keyword arguments (e.g. ``input_shape`` for the first
    stage) are forwarded to ``layers.Conv2D``.
    """
    return [
        layers.Conv2D(filters, kernel_size, strides=stride, padding="same",
                      **conv_kwargs),
        layers.BatchNormalization(),
        layers.Activation("relu"),
    ]


# Feature extractor: three conv stages with strides 1, 2, 2 ("same" padding
# maps 28x28 -> 28x28 -> 14x14 -> 7x7), then a dense classifier head.
model = keras.Sequential(
    _conv_block(6, (6, 6), 1, input_shape=(28, 28, 1))
    + _conv_block(12, (5, 5), 2)
    + _conv_block(24, (4, 4), 2)
    + [
        layers.Flatten(),
        layers.Dense(200, activation=tf.nn.relu),
        # Drops 75% of the hidden activations during training.
        layers.Dropout(0.75),
        # Softmax over 10 outputs: one probability per class.
        layers.Dense(10, activation=tf.nn.softmax),
    ]
)

print(model.input_shape)
model.summary()

In [None]:
# Sparse categorical cross-entropy takes integer class labels directly; the
# same quantity is also tracked as a metric so it is logged next to accuracy.
model.compile(
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy', 'sparse_categorical_crossentropy'],
    # TF 1.x optimizer API (tf.train); fixed learning rate of 3e-3.
    optimizer=tf.train.AdamOptimizer(learning_rate=0.003),
)

In [None]:
import os
import time

# Give each training run its own timestamped subdirectory under ./logs. With
# the default log_dir every re-run appends event files to the same directory,
# and TensorBoard then overlays the curves of unrelated runs.
tensorBoard = TensorBoard(
    log_dir=os.path.join("logs", time.strftime("fit-%Y%m%d-%H%M%S")))

# NOTE(review): the test split doubles as validation data, so the reported
# val_* metrics are not an unbiased held-out estimate.
history = model.fit(train_images, train_labels,
                    epochs=20,
                    validation_data=(test_images, test_labels),
                    verbose=2,
                    callbacks=[tensorBoard])

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Per-epoch metrics recorded by model.fit. Keras names the accuracy entries
# 'acc'/'val_acc' in TF 1.x but 'accuracy'/'val_accuracy' in TF 2.x, so
# accept either spelling.
_metrics = history.history
acc = _metrics['acc'] if 'acc' in _metrics else _metrics['accuracy']
loss = _metrics['loss']
val_acc = (_metrics['val_acc'] if 'val_acc' in _metrics
           else _metrics['val_accuracy'])
val_loss = _metrics['val_loss']

# history.epoch holds 0-based epoch indices; plot 1-based epoch numbers.
epoch = [x + 1 for x in history.epoch]

plt.plot(epoch, acc, 'bo', label='training accuracy')
plt.plot(epoch, val_acc, 'r', label='validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()

In [None]:
# Open a fresh figure so the loss curves are not drawn onto the accuracy
# axes. A bare `plt.cla` (without parentheses) would be a no-op.
plt.figure()
plt.plot(epoch, loss, 'bo', label='training loss')
plt.plot(epoch, val_loss, 'r', label='validation loss')
plt.xlabel('epoch')
plt.ylabel('sparse categorical cross-entropy')
plt.legend()
plt.show()