In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
print("Tensorflow version: {}".format(tf.__version__))

Tensorflow version: 2.10.0


In [3]:
# 获取mnist数据集
(train_image, train_label), (test_image, test_label) = keras.datasets.mnist.load_data()

In [4]:
train_image.shape

(60000, 28, 28)

In [5]:
# 增加训练集和测试集的维度，并且做归一化
train_image = tf.expand_dims(train_image, -1)
test_image = tf.expand_dims(test_image, -1)

In [6]:
train_image.shape

TensorShape([60000, 28, 28, 1])

In [7]:
# 改变数据类型，image为float32，label为int64
train_image = tf.cast(train_image / 255, tf.float32)
test_image = tf.cast(test_image / 255, tf.float32)
train_label = tf.cast(train_label, tf.int64)
test_label = tf.cast(test_label, tf.int64)

In [8]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_label))

In [9]:
train_dataset = train_dataset.shuffle(10000).batch(32)
test_dataset = test_dataset.batch(32)

In [10]:
train_dataset

<BatchDataset element_spec=(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

In [11]:
# 定义模型，并初始化
model = keras.Sequential([
    layers.Conv2D(16, [3, 3], input_shape = (28, 28, 1), activation = "relu"),
    layers.Conv2D(32, [3, 3], activation = "relu"),
    layers.GlobalAveragePooling2D(),
    layers.Dense(10)
])

In [12]:
# 使用优化器
optimizer = keras.optimizers.Adam()

In [13]:
# 定义损失函数方法，调用即可
loss_function = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [14]:
# 迭代查看一个dataset的数据
features, label = next(iter(train_dataset))

In [15]:
features.shape

TensorShape([32, 28, 28, 1])

In [16]:
label

<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([6, 9, 7, 1, 2, 2, 8, 0, 5, 3, 8, 0, 0, 9, 7, 2, 5, 3, 9, 9, 0, 4,
       6, 9, 9, 2, 6, 2, 7, 6, 9, 3], dtype=int64)>

In [17]:
predictions = model(features)

In [18]:
predictions.shape

TensorShape([32, 10])

In [19]:
tf.argmax(predictions, axis=1)

<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3], dtype=int64)>

In [20]:
# 定义一个损失函数
def loss(model, x, y):
    # 预测的值为y_
    y_ = model(x)
    return loss_function(y, y_)

In [21]:
train_loss = keras.metrics.Mean("train_loss")
train_accuracy = keras.metrics.SparseCategoricalAccuracy("train_accuracy")
test_loss = keras.metrics.Mean("test_loss")
test_accuracy = keras.metrics.SparseCategoricalAccuracy("test_accuracy")

In [22]:
# 一步训练数据集
def train_setp(model, images, labels):
    with tf.GradientTape() as t:
        pred = model(images)
        loss_step = loss_function(labels, pred)
    grads = t.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss_step)
    train_accuracy(labels, pred)

In [31]:
def test_step(model, images, labels):
    pred = model(images,training=False)
    loss_step = loss_function(labels, pred)
    test_loss(loss_step)
    test_accuracy(labels, pred)

In [34]:
# 开始训练
def train():
    for epoch in range(10):
        for (batch, (images, labels)) in enumerate(train_dataset):
            train_setp(model, images, labels)
        print("Epoch{} loss is {}, accuracy is {}".format(epoch, 
                                                          train_loss.result(),
                                                          train_accuracy.result()))
        
        
        for (batch, (images, labels)) in enumerate(test_dataset):
            test_step(model, images, labels)
        print("Epoch{} test_loss is {}, test_accuracy is {}".format(epoch, 
                                                          test_loss.result(),
                                                          test_accuracy.result()))
        
        
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

In [35]:
train()

Epoch0 loss is 0.9787788987159729, accuracy is 0.6949796080589294
Epoch0 test_loss is 0.857010006904602, test_accuracy is 0.7325999736785889
Epoch1 loss is 0.8332871794700623, accuracy is 0.750249981880188
Epoch1 test_loss is 0.755294144153595, test_accuracy is 0.7682999968528748
Epoch2 loss is 0.7381232380867004, accuracy is 0.7803333401679993
Epoch2 test_loss is 0.6700170040130615, test_accuracy is 0.7896999716758728
Epoch3 loss is 0.6676739454269409, accuracy is 0.8007500171661377
Epoch3 test_loss is 0.6248553991317749, test_accuracy is 0.8062000274658203
Epoch4 loss is 0.6115558743476868, accuracy is 0.8194166421890259
Epoch4 test_loss is 0.5403087735176086, test_accuracy is 0.8395000100135803
Epoch5 loss is 0.5652474164962769, accuracy is 0.8312666416168213
Epoch5 test_loss is 0.5157349705696106, test_accuracy is 0.8457000255584717
Epoch6 loss is 0.5282272696495056, accuracy is 0.8415833115577698
Epoch6 test_loss is 0.48312732577323914, test_accuracy is 0.8555999994277954
Epoch7 l