In [None]:
%tensorflow_version 2.x

import tensorflow as tf
from tensorflow.keras.datasets import mnist

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class Model(tf.Module):
  """Single linear layer followed by softmax: 784 input features -> 10 outputs."""

  def __init__(self):
    super(Model, self).__init__()
    self.build()

  def build(self):
    """Create the two trainable parameters of the layer."""
    # Weights start as small truncated-normal draws; the bias starts at zero.
    initial_w = tf.random.truncated_normal(shape=(784, 10), stddev=0.01)
    initial_b = tf.zeros(shape=(10,))
    self.__w = tf.Variable(initial_w)
    self.__b = tf.Variable(initial_b, dtype=tf.float32)

  @tf.function
  def __call__(self, x):
    """Return softmax(x @ w + b) for a batch `x` whose last dimension is 784."""
    logits = tf.matmul(x, self.__w) + self.__b
    return tf.nn.softmax(logits)

  @property
  def variables(self):
    """The (weights, bias) pair, in that order."""
    return (self.__w, self.__b)


In [None]:
# y: observed (target) values   y_hat: values produced by the model
def loss_fun(y,y_hat):
  """Mean squared error: the mean over all elements of (y - y_hat)**2."""
  residual = tf.subtract(y, y_hat)
  return tf.reduce_mean(tf.square(residual))

In [None]:
# Prediction accuracy
def accuracy(X,label):
  """Fraction of rows where the argmax along axis 1 of `X` equals that of `label`.

  Returns a scalar float32 tensor.
  """
  predicted = tf.argmax(X, axis=1)
  expected = tf.argmax(label, axis=1)
  hits = tf.cast(tf.equal(predicted, expected), tf.float32)
  return tf.reduce_mean(hits)

In [None]:
# Training routine for an instance of the custom Model class.
def train_Model(x,y,test_x,test_y,net=None,epochs=50,batch_size=300,
                lr_start=0.3,lr_end=0.01,report_every=10):
  """Train with mini-batch gradient descent and a linearly decaying learning rate.

  Args:
    x: training inputs; sliced along the first axis into individual samples.
    y: training targets aligned with `x` along the first axis.
    test_x: held-out inputs, used only for the periodic accuracy report.
    test_y: held-out targets aligned with `test_x`.
    net: the object to train. It must be callable on a batch of inputs and
      expose `variables` as a (weights, bias) pair of tf.Variable. When None,
      the notebook-level global `model` is used (the original behaviour).
    epochs: number of passes over the training data.
    batch_size: number of samples per gradient step.
    lr_start: learning rate used in the first epoch.
    lr_end: value the linear schedule decays toward; the last epoch uses
      lr_start-(lr_start-lr_end)*(epochs-1)/epochs, so it is not reached exactly.
    report_every: print metrics on every epoch divisible by this positive integer.

  Returns:
    None. The variables of `net` are updated in place and progress is printed.
  """
  if net is None:
    net=model  # backward-compatible fallback to the notebook global

  # Build the input pipeline once. shuffle() reshuffles on every new iteration
  # by default, so each epoch still sees a fresh ordering.
  mnist_batch=tf.data.Dataset.from_tensor_slices((x,y)).shuffle(x.shape[0]).batch(batch_size)

  for epoch in range(epochs):
    l_r=lr_start-(lr_start-lr_end)/epochs*epoch  # linearly decayed learning rate
    loss_sum,n=0.0,0

    for x_train,y_train in mnist_batch:
      # Trainable tf.Variables are watched by the tape automatically.
      with tf.GradientTape() as tape:
        y_hat=net(x_train)
        loss=loss_fun(y_train,y_hat)
      weights,bias=net.variables
      grad_w,grad_b=tape.gradient(loss,(weights,bias))
      weights.assign_sub(l_r*grad_w)
      bias.assign_sub(l_r*grad_b)

      # `loss` is already a mean over the batch: weight it by the batch size so
      # that loss_sum/n is the true per-sample mean (the last batch may be smaller).
      batch_n=x_train.shape[0]
      loss_sum += loss.numpy()*batch_n
      n += batch_n

    if epoch % report_every ==0:
      print("epoch={} n={} loss={} train_Accuracy={}".format(epoch,n,loss_sum/n,accuracy(net(x),y)))
      print("test_accuracy={}".format(accuracy(net(test_x),test_y)))

In [None]:
# Load the MNIST data: an (inputs, labels) pair for the training split and one for the test split.
(train_x0,train_y0),(test_x0,test_y0)=mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Convert the inputs: cast to float32, divide by 255.0 and flatten each sample to 784 features.
train_x = tf.reshape(tf.cast(train_x0, tf.float32) / 255.0, shape=(-1, 784))
test_x = tf.reshape(tf.cast(test_x0, tf.float32) / 255.0, shape=(-1, 784))

# Convert the labels to float32 one-hot vectors of depth 10.
train_label = tf.one_hot(train_y0, depth=10, dtype=tf.float32)
test_label = tf.one_hot(test_y0, depth=10, dtype=tf.float32)

In [None]:
# Fix the global TensorFlow seed so that weight initialisation and batch shuffling
# repeat across "Restart & Run All"; without it every run trains a different model.
SEED=42
tf.random.set_seed(SEED)

# Build a fresh model and train it on the prepared MNIST tensors.
model=Model()
train_Model(train_x,train_label,test_x,test_label)

epoch=0 n=60000 loss=0.0002355153949931264 train_Accuracy=0.7440166473388672
test_accuracy=0.755299985408783
epoch=10 n=60000 loss=6.575437230058014e-05 train_Accuracy=0.886983335018158
test_accuracy=0.895799994468689
epoch=20 n=60000 loss=5.739706383707623e-05 train_Accuracy=0.8969333171844482
test_accuracy=0.90420001745224
epoch=30 n=60000 loss=5.419166508751611e-05 train_Accuracy=0.9019333124160767
test_accuracy=0.9067999720573425
epoch=40 n=60000 loss=5.271263892451922e-05 train_Accuracy=0.9037333130836487
test_accuracy=0.9088000059127808


In [None]:
import os

# Directory that receives the checkpoint files; created if it does not exist yet.
file_path="./mnist"
if not os.path.exists(file_path):
  os.mkdir(file_path)


# Save the trained model parameters as a checkpoint.
checkpoint=tf.train.Checkpoint(model=model)
# Files are written with the prefix <file_path>/ckpt.
filePrefix=os.path.join(file_path,'ckpt')
checkpoint.save(file_prefix=filePrefix)

'./mnist/ckpt-1'

In [None]:
# List the contents of the checkpoint directory.
!ls ./mnist

checkpoint  ckpt-1.data-00000-of-00001	ckpt-1.index


In [None]:
m=Model()  # only the parameters are restored here, so a Model instance has to be created first
check_point=tf.train.Checkpoint(model=m)
# Restore from the most recent checkpoint found under ./mnist.
# NOTE(review): `status` is never inspected in the cells shown; calling
# status.assert_existing_objects_matched() would verify the restore - TODO confirm.
status=check_point.restore(tf.train.latest_checkpoint('./mnist'))

In [None]:
# Evaluate the checkpoint-restored model on the test set.
mm=check_point.model  # the Model instance that was passed to tf.train.Checkpoint
print("--------------check_point-------------------")
print("checkpoint:restored model test accuracy:{}".format(accuracy(mm(test_x),test_label)))

--------------check_point-------------------
checkpoint:restored model test accuracy:0.9093000292778015


In [None]:
# Save the complete model with tf.saved_model.save.
tf.saved_model.save(model,"./mnist/2")

INFO:tensorflow:Assets written to: ./mnist/2/assets


In [None]:
# Load the saved model back and evaluate it on the test set.
ms=tf.saved_model.load("./mnist/2")
print("------------saved_Model---------------------")
print("saved_model :restored model test accuracy:{}".format(accuracy(ms(test_x),test_label)))

------------saved_Model---------------------
saved_model :restored model test accuracy:0.9093000292778015
