In [1]:
import tensorflow as tf 

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front (needed when running locally).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

# Load the MNIST dataset as (train images, train labels), (test images, test labels).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [2]:
# Add a trailing channel axis, giving per-example shape (28, 28, 1) to match
# the Conv2D input_shape used by the model.
x_train, x_test = (tf.expand_dims(split, axis=-1) for split in (x_train, x_test))

In [3]:
# Scale pixel values (MNIST ships them as 0-255 integers) into [0, 1].
x_train, x_test = x_train / 255, x_test / 255

In [4]:
# Cast both image splits to float32 before they are wrapped in datasets.
x_train, x_test = (tf.cast(split, dtype=tf.float32) for split in (x_train, x_test))

In [5]:
# Slice the (images, labels) tensors into per-example tf.data pipelines:
# first the training split, then the test split.
dataset_train, dataset_test = (
    tf.data.Dataset.from_tensor_slices(pair)
    for pair in ((x_train, y_train), (x_test, y_test))
)

In [6]:
# Display both datasets to inspect their element shapes and dtypes.
(dataset_train, dataset_test)

(<TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.uint8)>,
 <TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.uint8)>)

In [7]:
BATCH_SIZE = 128

# Shuffle the training split with a buffer as large as the training set
# (a full uniform shuffle; previously sized by the *test* set by mistake),
# then batch it.
dataset_train = dataset_train.shuffle(x_train.shape[0]).batch(BATCH_SIZE)
# The test split is only batched. It must stay finite: train() iterates it to
# exhaustion every epoch, so a .repeat() here would make that loop never end.
dataset_test = dataset_test.batch(BATCH_SIZE)


In [None]:
# Small CNN: two 3x3 ReLU conv layers, global max pooling, and a 10-unit Dense
# head with no activation (raw logits; the loss below uses from_logits=True).
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
model.add(tf.keras.layers.GlobalMaxPooling2D())
model.add(tf.keras.layers.Dense(10))


In [9]:
# Adam with the Keras default learning rate, spelled out explicitly.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [10]:
# Cross-entropy over integer class labels.
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,  # the model's final Dense layer has no activation
)

In [11]:
def loss(model, x, y):
    """Return `loss_func` evaluated on `model(x)` against the integer labels `y`.

    Not used by train_step, which computes the same loss inline under its
    GradientTape.
    """
    return loss_func(y, model(x))

In [12]:
# Streaming metrics accumulated batch by batch inside train_step / test_step
# and read/reset by train() once per epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [13]:
def train_step(model, images, labels):
    """Apply one Adam update on a batch and record its loss and accuracy.

    Args:
        model: the Keras model; its outputs are treated as logits.
        images: a batch of input images.
        labels: the batch's integer class labels.
    """
    with tf.GradientTape() as tape:
        logits = model(images)
        batch_loss = loss_func(labels, logits)
    variables = model.trainable_variables
    gradients = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    # Feed the running epoch metrics.
    train_loss(batch_loss)
    train_accuracy(labels, logits)

In [14]:
def test_step(model, images, labels):
    """Evaluate the model on a batch (no gradient update) and record the test metrics.

    Args:
        model: the Keras model; its outputs are treated as logits.
        images: a batch of input images.
        labels: the batch's integer class labels.
    """
    logits = model(images)
    test_loss(loss_func(labels, logits))
    test_accuracy(labels, logits)


In [15]:
import datetime

# One TensorBoard run per execution, timestamped, with separate train/ and
# test/ sub-directories so the two sets of curves show up as distinct runs.
# (Previously the test writer also pointed at .../train, and the '/' after
# 'gradient_tape' was missing.)
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_root = 'logs/gradient_tape/' + current_time
train_log_dir = log_root + '/train'
test_log_dir = log_root + '/test'
train_writer = tf.summary.create_file_writer(train_log_dir)
test_writer = tf.summary.create_file_writer(test_log_dir)

In [16]:
def train(epochs=10):
    """Run the custom training loop, evaluating on the test set after every epoch.

    Per epoch:
      1. reset the four streaming metrics;
      2. run `train_step` over every batch of `dataset_train` and log the epoch's
         loss/accuracy to `train_writer`;
      3. run `test_step` over every batch of `dataset_test` and log loss/accuracy
         to `test_writer`;
      4. print a one-line summary.

    Args:
        epochs: number of passes over the training data (default 10, the value
            that used to be hard-coded).

    Raises:
        ValueError: if `dataset_test` is infinite (e.g. built with `.repeat()`),
            which would otherwise make the evaluation loop run forever.
    """
    if (tf.data.experimental.cardinality(dataset_test)
            == tf.data.experimental.INFINITE_CARDINALITY):
        raise ValueError('dataset_test is infinite; remove .repeat() so evaluation can finish')

    for epoch in range(epochs):
        # Reset first, so an earlier interrupted call cannot leak partial
        # results into this epoch's numbers.
        for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
            metric.reset_states()

        for images, labels in dataset_train:
            train_step(model, images, labels)
        with train_writer.as_default():
            tf.summary.scalar('train_loss', train_loss.result(), step=epoch)
            tf.summary.scalar('train_acc', train_accuracy.result(), step=epoch)

        for images, labels in dataset_test:
            test_step(model, images, labels)
        with test_writer.as_default():
            tf.summary.scalar('test_loss', test_loss.result(), step=epoch)
            tf.summary.scalar('test_acc', test_accuracy.result(), step=epoch)

        template = 'Epoch{} ,loss: {} ,acc: {} ,test_loss: {} ,test_acc: {}'
        print(
            template.format(
                epoch + 1,
                train_loss.result(),
                train_accuracy.result(),
                test_loss.result(),
                test_accuracy.result(),
            )
        )

In [None]:
# Run the training/evaluation loop defined in train(); scalars are written via
# train_writer / test_writer for TensorBoard.
train()

