In [1]:
import tensorflow as tf 
from tensorflow import keras 
from tensorflow.keras import datasets 
import os 


gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.2)
config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)
session = tf.compat.v1.Session(config=config)

In [2]:
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

In [3]:
(x,y),(x_test,y_test)=datasets.mnist.load_data() 

In [4]:
x=tf.convert_to_tensor(x,dtype=tf.float32)/255. 

In [5]:
y=tf.convert_to_tensor(y,dtype=tf.int32) 
x_test=tf.convert_to_tensor(x_test,dtype=tf.float32)/255.0 
y_test=tf.convert_to_tensor(y_test,dtype=tf.int32) 

In [6]:
print(x.shape,y.shape,x.dtype,y.dtype) 

(60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>


In [7]:
print(tf.reduce_min(x),tf.reduce_max(x)) 

tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)


In [8]:
print(tf.reduce_min(y),tf.reduce_max(y)) 

tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)


In [9]:
train_db=tf.data.Dataset.from_tensor_slices((x,y)).batch(128) 
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(128) 
train_iter=iter(train_db) 
sample=next(train_iter) 
print("batch:",sample[0].shape,sample[1].shape) 

batch: (128, 28, 28) (128,)


In [13]:
w1=tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))
b1=tf.Variable(tf.zeros([256])) 

w2=tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2=tf.Variable(tf.zeros([128])) 

w3=tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3=tf.Variable(tf.zeros([10])) 
lr=0.01

for epoch in range(100):
  for step,(x,y) in enumerate(train_db):
    x=tf.reshape(x,[-1,28*28]) 
    with tf.GradientTape() as tape:
      h1=x@w1+tf.broadcast_to(b1,[x.shape[0],256]) 
      h1=tf.nn.relu(h1) 
      h2=h1@w2+b2 
      h2=tf.nn.relu(h2) 
      out=h2@w3+b3 
      y_onehot=tf.one_hot(y,depth=10) 
      loss=tf.square(y_onehot-out) 
      loss=tf.reduce_mean(loss) 
    grads=tape.gradient(loss,[w1,b1,w2,b2,w3,b3]) 
    w1.assign_sub(lr*grads[0]) 
    b1.assign_sub(lr*grads[1]) 
    w2.assign_sub(lr*grads[2]) 
    b2.assign_sub(lr*grads[3]) 
    w3.assign_sub(lr*grads[4]) 
    b3.assign_sub(lr*grads[5])
    if step%100==0:
        total_correct, total_num = 0,0  
        for step,(x,y) in enumerate(test_db):
          x=tf.reshape(x,[-1,28*28]) 
          h1=tf.nn.relu(x@w1+b1)
          h2=tf.nn.relu(h1@w2+b2) 
          out=h2@w3+b3 
          prob = tf.nn.softmax(out,axis=1) 
          pred=tf.argmax(prob,axis=1)
          pred=tf.cast(pred,dtype=tf.int32)
          correct=tf.cast(tf.equal(pred,y),dtype=tf.int32)
          correct=tf.reduce_sum(correct) 
          total_correct+=int(correct) 
          total_num+=x.shape[0] 
        acc=total_correct/total_num 
        print(epoch,step,'loss:',float(loss),";test acc :",acc ) 
    

0 78 loss: 0.3927045166492462 ;test acc : 0.0689
0 78 loss: 0.13074807822704315 ;test acc : 0.2053
0 78 loss: 0.10684148967266083 ;test acc : 0.3248
0 78 loss: 0.0966406911611557 ;test acc : 0.4232
0 78 loss: 0.085474893450737 ;test acc : 0.4926
1 78 loss: 0.07621917873620987 ;test acc : 0.5289
1 78 loss: 0.07783162593841553 ;test acc : 0.5673
1 78 loss: 0.07259271293878555 ;test acc : 0.6003
1 78 loss: 0.07032232731580734 ;test acc : 0.6254
1 78 loss: 0.0694315955042839 ;test acc : 0.6459
2 78 loss: 0.06051882356405258 ;test acc : 0.6585
2 78 loss: 0.06519881635904312 ;test acc : 0.6733
2 78 loss: 0.062377315014600754 ;test acc : 0.6887
2 78 loss: 0.06039971113204956 ;test acc : 0.7007
2 78 loss: 0.06237991899251938 ;test acc : 0.7121
3 78 loss: 0.05312492698431015 ;test acc : 0.7165
3 78 loss: 0.058642417192459106 ;test acc : 0.7269
3 78 loss: 0.05676771327853203 ;test acc : 0.7378
3 78 loss: 0.054765015840530396 ;test acc : 0.7446
3 78 loss: 0.058030206710100174 ;test acc : 0.752
4 

32 78 loss: 0.026917720213532448 ;test acc : 0.8953
32 78 loss: 0.02728351019322872 ;test acc : 0.8964
32 78 loss: 0.028706630691885948 ;test acc : 0.8966
33 78 loss: 0.022611204534769058 ;test acc : 0.8961
33 78 loss: 0.02656487561762333 ;test acc : 0.8964
33 78 loss: 0.026587393134832382 ;test acc : 0.8969
33 78 loss: 0.02695661224424839 ;test acc : 0.8976
33 78 loss: 0.02836812101304531 ;test acc : 0.8975
34 78 loss: 0.022287271916866302 ;test acc : 0.8977
34 78 loss: 0.02625555731356144 ;test acc : 0.8977
34 78 loss: 0.02626574970781803 ;test acc : 0.8982
34 78 loss: 0.02663826383650303 ;test acc : 0.8984
34 78 loss: 0.02804647386074066 ;test acc : 0.899
35 78 loss: 0.021977528929710388 ;test acc : 0.899
35 78 loss: 0.025960305705666542 ;test acc : 0.8984
35 78 loss: 0.02595910057425499 ;test acc : 0.8991
35 78 loss: 0.026330390945076942 ;test acc : 0.8998
35 78 loss: 0.027739685028791428 ;test acc : 0.8993
36 78 loss: 0.021678129211068153 ;test acc : 0.8994
36 78 loss: 0.025675997

64 78 loss: 0.020466594025492668 ;test acc : 0.9185
64 78 loss: 0.020653704181313515 ;test acc : 0.9183
64 78 loss: 0.02233697846531868 ;test acc : 0.9186
65 78 loss: 0.016505492851138115 ;test acc : 0.9184
65 78 loss: 0.020994190126657486 ;test acc : 0.9185
65 78 loss: 0.02034853771328926 ;test acc : 0.919
65 78 loss: 0.020525138825178146 ;test acc : 0.9188
65 78 loss: 0.022220429033041 ;test acc : 0.9189
66 78 loss: 0.016402671113610268 ;test acc : 0.9188
66 78 loss: 0.020904744043946266 ;test acc : 0.919
66 78 loss: 0.02023285999894142 ;test acc : 0.9193
66 78 loss: 0.020398836582899094 ;test acc : 0.9194
66 78 loss: 0.022107992321252823 ;test acc : 0.9196
67 78 loss: 0.016304831951856613 ;test acc : 0.9193
67 78 loss: 0.020817067474126816 ;test acc : 0.9197
67 78 loss: 0.020116019994020462 ;test acc : 0.9196
67 78 loss: 0.02027280442416668 ;test acc : 0.9194
67 78 loss: 0.021997826173901558 ;test acc : 0.92
68 78 loss: 0.016207540407776833 ;test acc : 0.9198
68 78 loss: 0.020730972

96 78 loss: 0.017593957483768463 ;test acc : 0.9286
96 78 loss: 0.017666051164269447 ;test acc : 0.9288
96 78 loss: 0.019562486559152603 ;test acc : 0.9285
97 78 loss: 0.014236497692763805 ;test acc : 0.9287
97 78 loss: 0.01879177987575531 ;test acc : 0.9284
97 78 loss: 0.01752633973956108 ;test acc : 0.9288
97 78 loss: 0.01759621873497963 ;test acc : 0.9289
97 78 loss: 0.019496072083711624 ;test acc : 0.9287
98 78 loss: 0.014186844229698181 ;test acc : 0.9291
98 78 loss: 0.0187370702624321 ;test acc : 0.9287
98 78 loss: 0.017460156232118607 ;test acc : 0.9294
98 78 loss: 0.017527032643556595 ;test acc : 0.9292
98 78 loss: 0.01942979358136654 ;test acc : 0.9292
99 78 loss: 0.014138281345367432 ;test acc : 0.9292
99 78 loss: 0.018682602792978287 ;test acc : 0.9291
99 78 loss: 0.017394665628671646 ;test acc : 0.9294
99 78 loss: 0.01745874248445034 ;test acc : 0.9298
99 78 loss: 0.019364366307854652 ;test acc : 0.9295
