In [2]:
import tensorflow as tf
from keras.datasets import mnist
from keras.layers import Input, Conv2D, MaxPool2D, Dense
from keras.layers import BatchNormalization, Dropout,Flatten
from keras.optimizers import RMSprop
import numpy as np
import matplotlib.pyplot as plt

# Load MNIST and prepare it for a channels-last 2-D CNN.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities to [0, 1] as float32, then append a trailing
# single-channel axis: (N, H, W) -> (N, H, W, 1).
x_train = np.expand_dims(x_train.astype("float32") / 255.0, axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255.0, axis=-1)

# One-hot encode the integer labels, as categorical_crossentropy expects.
y_train, y_test = map(tf.keras.utils.to_categorical, (y_train, y_test))

def create_cnn2d(input_shape, num_class=10):
    """Build a small 2-D CNN classifier.

    Architecture: Conv2D(16, 3x3, ReLU) -> BatchNormalization -> MaxPool2D
    -> Conv2D(32, 3x3, ReLU) -> MaxPool2D -> Dropout(0.2) -> Flatten
    -> Dense(num_class, softmax).

    Args:
        input_shape: Shape of a single input sample, excluding the batch
            axis (e.g. ``(height, width, channels)``).
        num_class: Number of output classes; sets the width of the softmax
            output layer. Defaults to 10.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping ``input_shape`` inputs to
        ``num_class`` class probabilities.
    """
    inputs = Input(shape=input_shape)
    x = Conv2D(filters=16, kernel_size=(3, 3), activation="relu")(inputs)
    x = BatchNormalization()(x)
    x = MaxPool2D()(x)
    x = Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)
    x = MaxPool2D()(x)
    x = Dropout(rate=0.2)(x)
    x = Flatten()(x)

    # Head width follows ``num_class`` so it matches the label encoding;
    # ``Dense`` comes from the same ``keras.layers`` package as the layers above.
    outputs = Dense(units=num_class, activation="softmax")(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

# Build the CNN for the image shape and the one-hot label width.
num_classes = y_train.shape[1]
model = create_cnn2d(input_shape=x_train.shape[1:], num_class=num_classes)

# NOTE(review): Keras documents 0.001 as the RMSprop default learning rate;
# 0.01 is kept to preserve this experiment, but it is worth tuning.
opt = tf.keras.optimizers.RMSprop(learning_rate=0.01)
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])
# `ret` is the History object returned by fit (per-epoch metrics).
ret = model.fit(x_train, y_train, epochs=100, batch_size=400, verbose=0)

# Confusion matrix on the TRAINING set (rows: true class, cols: predicted).
y_pred = model.predict(x_train)
y_label = np.argmax(y_pred, axis=1)
C = tf.math.confusion_matrix(
    np.argmax(y_train, axis=1), y_label, num_classes=num_classes
)
print("confusion_matrix:", C)

# Training-set results are optimistic; also report held-out test performance.
y_test_label = np.argmax(model.predict(x_test), axis=1)
C_test = tf.math.confusion_matrix(
    np.argmax(y_test, axis=1), y_test_label, num_classes=num_classes
)
print("test confusion_matrix:", C_test)

train_loss, train_acc = model.evaluate(x_train, y_train, verbose=0)
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"train: loss={train_loss:.4f} acc={train_acc:.4f}")
print(f"test:  loss={test_loss:.4f} acc={test_acc:.4f}")


confusion_matrix: tf.Tensor(
[[5923    0    0    0    0    0    0    0    0    0]
 [   0 6739    0    0    0    0    0    1    1    1]
 [   0    0 5956    0    1    0    0    0    0    1]
 [   1    0    3 6111    0    5    0    1    6    4]
 [   0    1    0    0 5837    0    0    1    2    1]
 [   1    0    0    0    0 5416    3    0    1    0]
 [   2    0    0    0    1    0 5914    0    1    0]
 [   0    4    5    0    0    0    0 6256    0    0]
 [   2    1    2    0    2    0    1    0 5843    0]
 [   1    2    0    0   10    0    0    2    3 5931]], shape=(10, 10), dtype=int32)
