In [16]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers , optimizers , datasets ,Sequential ,metrics
# Load the MNIST train/test splits through the Keras datasets helper.
# `x`/`y` hold the training images/labels, `x_test`/`y_test` the test split.
(x, y), (x_test, y_test) = datasets.mnist.load_data()

In [17]:
def preprocess(x, y):
    """Convert one (image, label) pair to the dtypes used for training.

    Args:
        x: image tensor; cast to float32 and divided by 255
           (for 8-bit pixel data this maps values into [0, 1]).
        y: label tensor; cast to int32.

    Returns:
        Tuple ``(scaled_image, int_label)``.
    """
    pixels = tf.cast(x, tf.float32)
    labels = tf.cast(y, tf.int32)
    return pixels / 255., labels

In [18]:
# Mini-batch size shared by the training and test pipelines.
batchsz = 128

In [19]:
# Training input pipeline: slice -> shuffle -> preprocess -> batch.
# The shuffle buffer is sized to the whole training set so every epoch is a
# uniform permutation (a buffer smaller than the dataset only shuffles
# approximately). Shuffling happens before `map` so the buffer holds the raw
# samples rather than their float32 copies.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.shuffle(len(x)).map(preprocess).batch(batchsz)

In [20]:
# Test input pipeline: same preprocessing and batch size as training,
# but no shuffling.
db_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess)
    .batch(batchsz)
)

In [22]:
# Pull a single batch from the training pipeline and print the shapes of
# its image and label tensors as a sanity check.
db_iter = iter(db)
sample = next(db_iter)
print('batch:', *(part.shape for part in sample))

batch: (128, 28, 28) (128,)


In [23]:
# Fully connected classifier: four ReLU hidden layers (256 -> 128 -> 64 -> 32)
# followed by a 10-unit output layer with no activation, i.e. raw logits.
model = Sequential([
    *(layers.Dense(units, activation=tf.nn.relu) for units in (256, 128, 64, 32)),
    layers.Dense(10),
])
# Create the weights for flattened 28x28 inputs with an unspecified batch size.
model.build(input_shape=(None, 28 * 28))
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 256)               200960    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_3 (Dense)              (None, 32)                2080      
_________________________________________________________________
dense_4 (Dense)              (None, 10)                330       
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________


In [24]:
# Adam optimizer with a 1e-3 step size. `learning_rate` is the supported
# keyword; the old `lr` alias is deprecated and rejected by newer Keras.
optimizer = optimizers.Adam(learning_rate=1e-3)

In [None]:
# Custom training loop: 30 epochs of cross-entropy training, with a full
# pass over the test set at the end of every epoch.
for epoch in range(30):
    # Batches are bound to `x_batch`/`y_batch` so the notebook-level `x`/`y`
    # arrays are not overwritten; earlier cells stay safe to re-run.
    for step, (x_batch, y_batch) in enumerate(db):
        # Flatten each 28x28 image into a 784-vector for the Dense layers.
        x_batch = tf.reshape(x_batch, [-1, 28 * 28])
        y_onehot = tf.one_hot(y_batch, depth=10)

        with tf.GradientTape() as tape:
            logits = model(x_batch)
            loss_ce = tf.losses.categorical_crossentropy(y_onehot,
                                                         logits,
                                                         from_logits=True)
            loss_ce = tf.reduce_mean(loss_ce)

        grads = tape.gradient(loss_ce, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            # MSE between one-hot labels and the raw logits is reported for
            # comparison only; it is not optimised, so it is computed just
            # when logging and outside the gradient tape.
            loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
            print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))

    # Evaluate classification accuracy on the test set.
    total_correct = 0
    total_num = 0
    for x_batch, y_batch in db_test:
        x_batch = tf.reshape(x_batch, [-1, 28 * 28])
        logits = model(x_batch)
        prob = tf.nn.softmax(logits, axis=1)
        # argmax yields int64; cast to int32 to compare with the int32 labels.
        pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y_batch), dtype=tf.int32))

        total_correct += int(correct)
        total_num += x_batch.shape[0]
    acc = total_correct / total_num
    print(epoch, 'test acc', acc)
    

0 0 loss: 0.03167692571878433 29.402713775634766
0 100 loss: 0.019802510738372803 36.26633834838867
0 200 loss: 0.04704906791448593 37.99005126953125
0 300 loss: 0.022478271275758743 32.15018844604492
0 400 loss: 0.11422568559646606 35.02384948730469
0 test acc 0.9737
1 0 loss: 0.031096648424863815 29.94021987915039
1 100 loss: 0.008172761648893356 38.496604919433594
1 200 loss: 0.0023693502880632877 41.98213195800781
1 300 loss: 0.007467607501894236 42.995140075683594
1 400 loss: 0.010752130299806595 38.128326416015625
1 test acc 0.9767
2 0 loss: 0.05066175386309624 32.60882568359375
2 100 loss: 0.007179488427937031 43.025115966796875
2 200 loss: 0.022776931524276733 53.033424377441406
2 300 loss: 0.10913066565990448 47.601478576660156
2 400 loss: 0.010676905512809753 52.85319519042969
2 test acc 0.9794
3 0 loss: 0.04869420453906059 48.71522903442383
3 100 loss: 0.016710542142391205 45.26734161376953
3 200 loss: 0.004624741617590189 47.733802795410156
3 300 loss: 0.012886697426438332 

29 400 loss: 0.00012383046851027757 109.33740234375
29 test acc 0.9805
