In [1]:
import tensorflow as tf 
import datetime
import os 

In [2]:
tf.__version__

'2.3.0'

In [3]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

In [4]:
train_images = tf.expand_dims(train_images, -1)

In [5]:
test_images = tf.expand_dims(test_images, -1)

In [6]:
train_images = tf.cast(train_images/255, tf.float32)
test_images = tf.cast(test_images/255, tf.float32)

In [7]:
train_labels = tf.cast(train_labels, tf.int64)
test_labels = tf.cast(test_labels, tf.int64)

In [8]:
dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))

In [9]:
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

In [10]:
dataset

<TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.int64)>

In [11]:
dataset = dataset.repeat().shuffle(60000).batch(128)

In [12]:
test_dataset = test_dataset.repeat().batch(128)

In [13]:
dataset

<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>

In [14]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, [3,3], activation='relu', input_shape=(None, None, 1)),
    tf.keras.layers.Conv2D(32, [3,3], activation='relu'),
    tf.keras.layers.GlobalMaxPooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')    
])

In [15]:
# 使用 Keras 的自定义训练 API 

In [16]:
optimizer = tf.keras.optimizers.Adam()

In [17]:
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()

In [18]:
def loss(model, x, y):
    y_ = model(x)
    return loss_func(y, y_)

In [19]:
train_loss = tf.keras.metrics.Mean('train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')

test_loss = tf.keras.metrics.Mean('test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')

In [20]:
def train_step(model, images, labels):
    with tf.GradientTape() as t:
        pred = model(images)
        loss_step = loss_func(labels, pred)
    grads = t.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss_step)
    train_accuracy(labels, pred)

In [21]:
def test_step(model, images, labels):
    pred = model(images)
    loss_step = loss_func(labels, pred)
    test_loss(loss_step)
    test_accuracy(labels, pred)

In [22]:
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

In [23]:
train_log_dir = os.path.join('log/gradient_tape', current_time, 'train')
test_log_dir = os.path.join('log/gradient_tape', current_time, 'test')
train_writer = tf.summary.create_file_writer(train_log_dir)
test_writer = tf.summary.create_file_writer(test_log_dir)

In [24]:
def train():
    for epoch in range(10):
        for (batch, (images, labels)) in enumerate(dataset):
            train_step(model, images, labels)
        with train_writer.as_default():
            tf.summary.scalar('loss', train_loss.result(), step=epoch)
            tf.summary.scalar('acc', train_accuracy.result(), step=epoch)
        
        for (batch, (images, labels)) in enumerate(dataset):
            test_step(model, images, labels)
            print('*', end='')
        with test_writer.as_default():
            tf.summary.scalar('loss', test_loss.result(), step=epoch)
            tf.summary.scalar('acc', test_accuracy.result(), step=epoch)
            
        template = 'Epoch {}, Loss: {}, Accuracy {}, Test Loss: {}, Test Accuracy {}'
        print(template.format(epoch+1,
                              train_loss.result(),
                             train_accuracy.result()*100,
                             test_loss.result(),
                             test_accuracy.result()*100))
        
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

In [25]:
train()

KeyboardInterrupt: 