In [6]:
import tensorflow as tf
from tensorflow.keras import datasets,layers,models

In [5]:
# CNN model definition
class CNN(object):
    """Convolutional classifier for 28x28 single-channel images.

    The network is three 3x3 convolution stages (32, 64 and 64 filters, the
    first two each followed by 2x2 max pooling), then a flatten, a 64-unit
    ReLU dense layer and a 10-way softmax output. The assembled Keras model
    is exposed as ``self.model``; it is not compiled here.
    """
    def __init__(self):
        network = models.Sequential([
            # Stage 1: 32 filters of 3x3 over the 28x28x1 input image.
            layers.Conv2D(32, (3, 3), activation='relu',
                          input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            # Stage 2: 64 filters of 3x3, pooled again.
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            # Stage 3: 64 filters of 3x3, flattened for the dense head.
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.Flatten(),
            # Fully connected head ending in a 10-class softmax.
            layers.Dense(64, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
        # Print the layer table every time a CNN is constructed.
        network.summary()
        self.model = network

In [9]:
class DataSource(object):
    """Loads the Keras MNIST dataset and exposes it ready for the CNN.

    Attributes:
        train_images / test_images: image arrays reshaped to
            ``(num_samples, 28, 28, 1)`` and divided by 255.0.
        train_labels / test_labels: the label arrays exactly as returned by
            ``datasets.mnist.load_data()``.
        train_iamges: misspelled alias of ``train_images``, kept because
            existing callers read the attribute under that name.
    """
    def __init__(self):
        (train_images, train_labels), (test_images, test_labels) = \
            datasets.mnist.load_data()
        # Add the trailing channel axis the Conv2D input expects. -1 lets
        # numpy infer the sample count instead of hard-coding 60000 / 10000.
        train_images = train_images.reshape((-1, 28, 28, 1))
        test_images = test_images.reshape((-1, 28, 28, 1))
        # Map pixel values into the 0~1 range.
        train_images, test_images = train_images / 255.0, test_images / 255.0
        self.train_images, self.train_labels = train_images, train_labels
        # Backward-compatible alias for the original (misspelled) attribute.
        self.train_iamges = self.train_images
        self.test_images, self.test_labels = test_images, test_labels

In [29]:
class Train:
    """Builds the CNN and the data source, then trains and evaluates the model."""
    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self, epochs=10, batch_size=32, save_every=5):
        """Compile, fit and evaluate the model, checkpointing weights periodically.

        Args:
            epochs: number of training epochs (default 10, as before).
            batch_size: mini-batch size passed to ``fit`` (32 is the Keras
                default that the original call used implicitly).
            save_every: write a weights checkpoint every this many epochs.

        Prints the test accuracy and the number of evaluated test images.
        """
        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # The attribute name is spelled `train_iamges` on DataSource.
        train_images = self.data.train_iamges
        # `period=` is a deprecated ModelCheckpoint keyword; `save_freq` takes
        # a count of batches, so convert "every N epochs" into batches using
        # a ceiling division for the number of steps per epoch.
        steps_per_epoch = -(-len(train_images) // batch_size)
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(
            check_path, save_weights_only=True, verbose=1,
            save_freq=save_every * steps_per_epoch)
        # Configure optimizer, loss and the reported metric.
        self.cnn.model.compile(optimizer='adam',
                               loss=tf.keras.losses.sparse_categorical_crossentropy,
                               metrics=['acc'])
        # Run training with the checkpoint callback attached.
        self.cnn.model.fit(train_images,
                           self.data.train_labels,
                           epochs=epochs,
                           batch_size=batch_size,
                           callbacks=[save_model_cb])
        # Evaluate on the held-out test split.
        test_loss, test_acc = self.cnn.model.evaluate(
            self.data.test_images, self.data.test_labels)
        print("acc:%.4f,tested %d image"%(test_acc,len(self.data.test_labels)))

In [30]:

# Entry point for the training cell: constructing Train builds the CNN and the
# DataSource, and train() then compiles, fits and evaluates the model.
if __name__=='__main__':
    # `test` is bound at notebook (module) scope.
    test=Train()
    test.train()

Model: "sequential_12"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_36 (Conv2D)          (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d_24 (MaxPooli  (None, 13, 13, 32)        0         
 ng2D)                                                           
                                                                 
 conv2d_37 (Conv2D)          (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_25 (MaxPooli  (None, 5, 5, 64)          0         
 ng2D)                                                           
                                                                 
 conv2d_38 (Conv2D)          (None, 3, 3, 64)          36928     
                                                                 
 flatten_12 (Flatten)        (None, 576)             

In [16]:
from PIL import Image
import numpy as np
import tensorflow as tf

In [32]:
# Prediction helper
class Predict(object):
    """Restores the newest checkpoint under ./ckpt and classifies digit images."""

    def __init__(self):
        latest = tf.train.latest_checkpoint('./ckpt')
        if latest is None:
            # Fail with a clear message instead of passing None to load_weights.
            raise FileNotFoundError(
                "no checkpoint found in './ckpt'; run training first")
        self.cnn = CNN()
        # Restore the network weights.
        self.cnn.model.load_weights(latest)

    @staticmethod
    def _preprocess(pixels):
        """Turn a 28x28 grayscale image (values 0-255) into a model input batch.

        Pixels are scaled to [0, 1] as the training data was, then inverted.
        The inversion is done in floating point: ``1 - uint8_array`` would
        wrap around instead of inverting.
        NOTE(review): the inversion presumably targets dark-digit-on-light
        input images -- confirm against the images actually used.

        Returns:
            float32 array of shape (1, 28, 28, 1).

        Raises:
            ValueError: if ``pixels`` does not contain exactly 28*28 values.
        """
        scaled = np.asarray(pixels, dtype='float32') / 255.0
        return (1.0 - scaled).reshape((1, 28, 28, 1))

    def predict(self, image_path):
        """Classify the image at ``image_path``.

        Prints the path, the per-class output vector and the predicted digit.

        Returns:
            The predicted digit as an int.
        """
        # Read as 8-bit grayscale; the context manager closes the file handle.
        with Image.open(image_path) as source:
            img = source.convert('L')
        # The network input is fixed at 28x28, so rescale anything else.
        if img.size != (28, 28):
            img = img.resize((28, 28))
        x = self._preprocess(img)
        y = self.cnn.model.predict(x)
        digit = int(np.argmax(y[0]))
        print(image_path)
        print(y[0])
        print('  ->Predict digit', digit)
        return digit

In [33]:
if __name__ == "__main__":
    # Restore the checkpoint once, then classify each sample image in turn.
    test=Predict()
    # Single place to change where the sample images live.
    image_dir = "C:/Users/yeolume/Pictures/三星多屏联动"
    for image_name in ("0.jpg", "1.jpg", "2.jpg"):
        test.predict(image_dir + "/" + image_name)

Model: "sequential_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_42 (Conv2D)          (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d_28 (MaxPooli  (None, 13, 13, 32)        0         
 ng2D)                                                           
                                                                 
 conv2d_43 (Conv2D)          (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_29 (MaxPooli  (None, 5, 5, 64)          0         
 ng2D)                                                           
                                                                 
 conv2d_44 (Conv2D)          (None, 3, 3, 64)          36928     
                                                                 
 flatten_14 (Flatten)        (None, 576)             