In [66]:
import tensorflow as tf
import numpy as np

class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.train_data, self.train_flag), (self.test_data, self.test_flag) = mnist.load_data()
        
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis = -1)
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis = -1)
        
        self.train_flag = self.train_flag.astype(np.int32)
        self.test_flag = self.test_flag.astype(np.int32)
        
        self.train_num = self.train_data.shape[0]
        self.test_num = self.test_data.shape[0]
        
        print(self.train_num, self.train_data.shape)
        
    def get_batch(self, batch_size):
        index = np.random.randint(0, self.train_data.shape[0], batch_size)
        return self.train_data[index, :], self.train_flag[index]

In [67]:
data = MNISTLoader()

60000 (60000, 28, 28, 1)


In [61]:
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units = 100, activation = tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units = 10)
        
    def call(self, input):
        x = self.flatten(input)
        x = self.dense1(x)
        x = self.dense2(x)
        output = tf.nn.softmax(x)
        return output

In [62]:
class CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        self.conv1 = tf.keras.layers.Conv2D(
        filters = 32,
        kernel_size = [5,5],
        padding = "same",
        activation = tf.nn.relu)
        
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size = [2, 2], strides = 2)
        
        self.conv2 = tf.keras.layers.Conv2D(
        filters = 64,
        kernel_size = [5, 5],
        padding = "same",
        activation = tf.nn.relu)
        
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size = [2, 2], strides = 2)
        
        self.flatten = tf.keras.layers.Reshape(target_shape = (7 * 7 * 64, ))
        
        self.dense1 = tf.keras.layers.Dense(units = 1024, activation = tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units = 10)
        
        
    def call(self, input):
        x = self.conv1(input)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        output = tf.nn.softmax(x)
        return output

In [63]:
num_epochs = 5
batch_size = 50
learning_rate = 0.001

module = CNN()
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate)

num_batchs = int(data_loader.train_num / batch_size * num_epochs)
for batch_index in range(10):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = module(X)
        print(X.shape)
        print(y.shape, y_pred.shape)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true = y, y_pred = y_pred)
        loss = tf.reduce_mean(loss)
        print("batch: %d, loss: %f" %(batch_index, loss.numpy()))
        
    grads = tape.gradient(loss, module.variables)
    optimizer.apply_gradients(grads_and_vars = zip(grads, module.variables))

60000 (60000,)
(50, 28, 28, 1)
(50,) (50, 10)
batch: 0, loss: 2.318369
(50, 28, 28, 1)
(50,) (50, 10)
batch: 1, loss: 2.234157
(50, 28, 28, 1)
(50,) (50, 10)
batch: 2, loss: 2.060000
(50, 28, 28, 1)
(50,) (50, 10)
batch: 3, loss: 1.955134
(50, 28, 28, 1)
(50,) (50, 10)
batch: 4, loss: 1.749707
(50, 28, 28, 1)
(50,) (50, 10)
batch: 5, loss: 1.478910
(50, 28, 28, 1)
(50,) (50, 10)
batch: 6, loss: 1.459552
(50, 28, 28, 1)
(50,) (50, 10)
batch: 7, loss: 0.956123
(50, 28, 28, 1)
(50,) (50, 10)
batch: 8, loss: 1.139385
(50, 28, 28, 1)
(50,) (50, 10)
batch: 9, loss: 0.964800


In [68]:
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batchs = int(data_loader.test_num / batch_size)
for batch_index in range(10):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = module.predict(data_loader.test_data[start_index:end_index])
    
sparse_categorical_accuracy.update_state(y_true = data_loader.test_flag[start_index:end_index], y_pred = y_pred)
print(data_loader.test_flag.shape)
print("test saccuracy: %f" % sparse_categorical_accuracy.result())

(10000,)
test saccuracy: 0.600000
