In [None]:
import tensorflow as tf
from tensorflow import keras
import os
print(tf.__version__)  # Show which TensorFlow version this notebook is running against.

In [None]:
# Load the MNIST train/test splits shipped with Keras (downloaded on first use, then cached locally).
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()

In [None]:
# Scale pixels from the raw uint8 range [0, 255] to [0, 1] as float32 (the Keras
# default float type; float64 would only double the memory held by these arrays).
# The dtype check keeps this cell idempotent: re-running it must not divide by 255 again.
if train_image.dtype == 'uint8':
    train_image = train_image.astype('float32') / 255.0
if test_image.dtype == 'uint8':
    test_image = test_image.astype('float32') / 255.0

In [None]:
# Wrap the in-memory (image, label) arrays as tf.data pipelines; each element is
# one (image, label) example sliced along the first axis.
ds_train, ds_test = [
    tf.data.Dataset.from_tensor_slices(split)
    for split in ((train_image, train_label), (test_image, test_label))
]

In [None]:
# Define the distribution strategy. MirroredStrategy keeps a copy of the model on each
# device it uses (by default every GPU visible to TensorFlow, or the CPU if there are none).
strategy = tf.distribute.MirroredStrategy()
print("number of devices:" + str(strategy.num_replicas_in_sync))

In [None]:
# Input-pipeline settings.
# When training on several GPUs, raise the global batch size to make use of the extra
# compute: keep the per-replica batch small enough to fit in one GPU's memory, scale it
# by the number of replicas, and adjust the learning rate to match.
BUFFER_SIZE = 10000           # shuffle-buffer size for the training set
BATCH_SIZE_PER_REPLICA = 64
# Global batch size; the strategy splits each global batch across the replicas.
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

In [None]:
# Training pipeline: cache the examples after the first pass, shuffle within a
# BUFFER_SIZE window, then group into global batches.
train_dataset = (
    ds_train
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
# Evaluation pipeline: batching only; order does not matter for evaluation.
eval_dataset = ds_test.batch(BATCH_SIZE)

In [None]:
# Build the model.
# Both model creation AND compile() happen inside strategy.scope(). The model's
# variables and the optimizer's variables must all be created under the distribution
# strategy so that MirroredStrategy can mirror them across the replicas.
with strategy.scope():
    model = tf.keras.Sequential([
        # NOTE(review): the preprocessing cells do not add a channel axis, so images
        # presumably arrive as (28, 28); this relies on Keras expanding the last axis to
        # match input_shape=(28, 28, 1) — confirm, or add the axis explicitly.
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        # Raw logits (no softmax); the loss below is built with from_logits=True.
        tf.keras.layers.Dense(10),
    ])

    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=['accuracy'],
    )

In [None]:
## Define callbacks (定义回调函数)

In [None]:
# Directory that the ModelCheckpoint callback writes weights into.
checkpoint_dir = './train_checkpoints'
# Checkpoint path template; ModelCheckpoint fills in "{epoch}" with the epoch number.
# NOTE(review): "cpkt" looks like a typo for "ckpt"; kept as-is so checkpoint file names do not change.
checkpoint_prefix = os.path.join(checkpoint_dir,"cpkt_{epoch}")

In [None]:
def decay(epoch):
    """Step learning-rate schedule for ``LearningRateScheduler``.

    Args:
        epoch: epoch index supplied by Keras (0-based).

    Returns:
        1e-3 for epochs below 3, 1e-4 for epochs 3 through 6, and 1e-5 from epoch 7 on.
    """
    # (exclusive upper bound, learning rate) pairs, checked in ascending order.
    for upper_bound, rate in ((3, 1e-3), (7, 1e-4)):
        if epoch < upper_bound:
            return rate
    return 1e-5

In [None]:
class PrintLR(tf.keras.callbacks.Callback):
    """Keras callback that prints the optimizer's learning rate at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the learning rate in effect for the epoch that just finished.

        Args:
            epoch: 0-based epoch index supplied by Keras.
            logs: metric dict supplied by Keras (unused).
        """
        # Use the model Keras attached to this callback rather than the notebook-global
        # `model`, so the callback reports on whichever model it is actually passed to.
        # `learning_rate` is the current attribute name; `lr` is a deprecated alias.
        current_lr = self.model.optimizer.learning_rate.numpy()
        # Report the epoch 1-based for readability.
        print("\nLearning rate for epoch {} is {}".format(epoch + 1, current_lr))

In [None]:
# Callbacks to pass to model.fit().
callbacks = [
    # Write training summaries for TensorBoard under ./logs.
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    # Save only the model weights, at the templated checkpoint path.
    tf.keras.callbacks.ModelCheckpoint(
        checkpoint_prefix,
        save_weights_only=True,
    ),
    # At the start of each epoch, set the learning rate returned by `decay`.
    tf.keras.callbacks.LearningRateScheduler(schedule=decay),
    # Print the learning rate once each epoch ends.
    PrintLR(),
]

In [None]:

# Train for 12 epochs with the callbacks defined above: TensorBoard logging, per-epoch
# weight checkpoints, the step learning-rate schedule from `decay`, and LR printing.
# Without `callbacks=` none of them run and the learning rate never decays.
model.fit(train_dataset,
          epochs=12,
          callbacks=callbacks)