In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics


In [5]:
# Read the MNIST archive and pull out the train / test arrays by key.
mnistx = np.load('../data/mnist.npz')
x = mnistx['x_train']
y = mnistx['y_train']
x_val = mnistx['x_test']
y_val = mnistx['y_test']
# Sanity check: array shapes plus the raw pixel value range.
print('datasets:', x.shape, y.shape, x.min(), x.max())

datasets: (60000, 28, 28) (60000,) 0 255


In [6]:
batch_size = 32  # number of samples per mini-batch fed to the training loop

In [8]:
# Convert the images to float32 and divide by 255 so pixel values are rescaled
# from the raw 0..255 range down to 0..1.
xs = tf.convert_to_tensor(x, dtype=tf.float32) / 255.

# Pair every image with its label, group the pairs into mini-batches and
# replay the whole dataset 30 times (i.e. 30 passes over the data).
db = (
    tf.data.Dataset
    .from_tensor_slices((xs, y))
    .batch(batch_size)
    .repeat(30)
)

In [9]:
# Fully connected classifier: 784 -> 256 -> 128 -> 10.
# The last layer has no activation, so it emits raw (unbounded) scores.
model = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(10),
])
# Create the weights up front so summary() can report parameter counts.
# The batch axis is left as None: only the feature size (28*28) determines
# the weight shapes, and the model is trained with a different batch size
# than any fixed number that could be written here.
model.build(input_shape=(None, 28*28))
model.summary()

# Plain SGD. `learning_rate` is the supported keyword; the short `lr` alias is
# deprecated and is rejected by newer Keras releases.
optimizer = optimizers.SGD(learning_rate=0.01)
# Running accuracy; the training loop feeds it and resets it periodically.
acc_meter = metrics.Accuracy()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  200960    
_________________________________________________________________
dense_1 (Dense)              multiple                  32896     
_________________________________________________________________
dense_2 (Dense)              multiple                  1290      
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
_________________________________________________________________


In [10]:
# Custom training loop: one SGD update per mini-batch drawn from `db`.
#
# The loop targets are deliberately named `x_batch` / `y_batch` rather than
# `x` / `y`: reusing `x` / `y` would overwrite the full dataset arrays loaded
# earlier in the notebook, so re-running the dataset-building cell after
# training would silently build the pipeline from the last mini-batch only.
for step, (x_batch, y_batch) in enumerate(db):
    with tf.GradientTape() as tape:
        # Flatten each 28x28 image into a 784-long vector for the Dense layers.
        x_flat = tf.reshape(x_batch, (-1, 28*28))

        # Forward pass: one row of 10 raw scores per sample.
        out = model(x_flat)

        # Integer labels -> one-hot rows so they can be compared with `out`.
        y_onehot = tf.one_hot(y_batch, depth=10)

        # Squared error, summed over every element and divided by the number
        # of samples in the batch.
        loss = tf.square(out - y_onehot)
        loss = tf.reduce_sum(loss) / x_flat.shape[0]

    # Accumulate accuracy from the highest-scoring class vs. the integer label.
    acc_meter.update_state(tf.argmax(out, axis=1), y_batch)

    # Backward pass and parameter update.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if step % 200 == 0:
        # `loss` is the current batch only; `acc` covers every batch seen
        # since the previous report, after which the meter starts over.
        print(step,'loss:',float(loss),'acc:',acc_meter.result().numpy())
        # NOTE(review): newer Keras releases spell this `reset_state()` -
        # confirm against the installed version before upgrading.
        acc_meter.reset_states()

0 loss: 1.840341329574585 acc: 0.03125
200 loss: 0.4670301377773285 acc: 0.66734374
400 loss: 0.3423806130886078 acc: 0.8389062
600 loss: 0.36635205149650574 acc: 0.8565625
800 loss: 0.28271153569221497 acc: 0.8871875
1000 loss: 0.3174542784690857 acc: 0.885625
1200 loss: 0.3056131601333618 acc: 0.9039062
1400 loss: 0.22257688641548157 acc: 0.9146875
1600 loss: 0.22300714254379272 acc: 0.9078125
1800 loss: 0.21345525979995728 acc: 0.9253125
2000 loss: 0.2243383526802063 acc: 0.93875
2200 loss: 0.15718558430671692 acc: 0.929375
2400 loss: 0.27255016565322876 acc: 0.9278125
2600 loss: 0.22062839567661285 acc: 0.9351562
2800 loss: 0.14155571162700653 acc: 0.9326562
3000 loss: 0.2083386927843094 acc: 0.9334375
3200 loss: 0.1914547085762024 acc: 0.93703127
3400 loss: 0.13000856339931488 acc: 0.93421876
3600 loss: 0.12920421361923218 acc: 0.935
3800 loss: 0.1752210110425949 acc: 0.9529688
4000 loss: 0.22880887985229492 acc: 0.9475
4200 loss: 0.16257014870643616 acc: 0.93890625
4400 loss: 0.1

35600 loss: 0.02910851687192917 acc: 0.98296875
35800 loss: 0.08281341195106506 acc: 0.9815625
36000 loss: 0.06775905191898346 acc: 0.97859377
36200 loss: 0.09053294360637665 acc: 0.98125
36400 loss: 0.04121784120798111 acc: 0.9803125
36600 loss: 0.034611914306879044 acc: 0.9792187
36800 loss: 0.03272838890552521 acc: 0.97984374
37000 loss: 0.04326413571834564 acc: 0.97875
37200 loss: 0.18887025117874146 acc: 0.9764063
37400 loss: 0.07659545540809631 acc: 0.98
37600 loss: 0.08551669120788574 acc: 0.9853125
37800 loss: 0.042726244777441025 acc: 0.97984374
38000 loss: 0.06473325937986374 acc: 0.9817188
38200 loss: 0.03706770017743111 acc: 0.980625
38400 loss: 0.0334138385951519 acc: 0.9790625
38600 loss: 0.06359656155109406 acc: 0.9814063
38800 loss: 0.08631355315446854 acc: 0.9790625
39000 loss: 0.07115783542394638 acc: 0.97828126
39200 loss: 0.025479109957814217 acc: 0.9784375
39400 loss: 0.029756061732769012 acc: 0.98484373
39600 loss: 0.07847161591053009 acc: 0.9828125
39800 loss: 0.