# TensorFlow distributed training on SageMaker

## 1 Overview
This chapter demonstrates distributed TensorFlow training on SageMaker using smdistributed.

## 2 Runtime environment
Select the tensorflow2_p36 kernel.  
This notebook was tested with boto3 1.17.109 and sagemaker 2.48.1.
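The tested versions can also be checked programmatically. The sketch below is a minimal version check, assuming plain `major.minor.patch` version strings. `version_tuple`, `meets_minimum`, and `check_installed` are hypothetical helpers, not part of boto3 or the SageMaker SDK. Versions are compared numerically, so 2.9.0 correctly counts as older than 2.48.1, which a plain string comparison would get wrong. It avoids `importlib.metadata` so it also runs on the Python 3.6 kernel.

```python
import importlib

# Versions this notebook was tested with.
TESTED = {"boto3": "1.17.109", "sagemaker": "2.48.1"}


def version_tuple(version):
    """Turn '2.48.1' into (2, 48, 1); non-numeric suffixes such as 'rc1' are dropped."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, tested):
    """Return True if `installed` is the same as or newer than `tested`."""
    a, b = version_tuple(installed), version_tuple(tested)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.48' compares equal to '2.48.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_installed(tested=TESTED):
    """Map each package name to True/False, or None if it is not installed."""
    report = {}
    for name, minimum in tested.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None
            continue
        report[name] = meets_minimum(module.__version__, minimum)
    return report
```

In the notebook, `print(check_installed())` reports which packages need the upgrade described below.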

In [None]:
import boto3
import sagemaker
print(boto3.__version__)
print(sagemaker.__version__)

If the versions are older than these, run the following commands, restart the kernel, and then check the versions again.

In [None]:
!pip install -U boto3 -i https://opentuna.cn/pypi/web/simple/

In [None]:
!pip install -U sagemaker -i https://opentuna.cn/pypi/web/simple/

## 3 Set up and retrieve parameters

In [None]:
import boto3
import sagemaker

sagemaker_session = sagemaker.Session()
iam = boto3.client('iam')
# Look for the SageMaker execution role among the service roles.
# ListRoles is paginated, so scan every page rather than only the first one.
role = ""
for page in iam.get_paginator('list_roles').paginate(PathPrefix='/service-role'):
    for current_role in page["Roles"]:
        if current_role["RoleName"].startswith("AmazonSageMaker-ExecutionRole-"):
            role = current_role["Arn"]
            break
    if role:
        break
# An empty role means something is wrong: first open https://cn-northwest-1.console.amazonaws.cn/sagemaker/home?region=cn-northwest-1#/notebook-instances/create to create an IAM role
print(role)
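The role lookup can be isolated as a pure function so it can be checked without calling IAM. This is a sketch: `find_execution_role` is a hypothetical helper that accepts any iterable of ListRoles responses, such as the pages yielded by `iam.get_paginator("list_roles").paginate(PathPrefix="/service-role")`. On a SageMaker notebook instance, `sagemaker.get_execution_role()` is an alternative that returns the role attached to the notebook.

```python
PREFIX = "AmazonSageMaker-ExecutionRole-"


def find_execution_role(pages, prefix=PREFIX):
    """Return the ARN of the first role whose name starts with `prefix`.

    `pages` is an iterable of IAM ListRoles responses, i.e. dicts with a
    "Roles" list. Returns "" when no role matches, mirroring the cell above.
    """
    for page in pages:
        for current_role in page.get("Roles", []):
            if current_role["RoleName"].startswith(prefix):
                return current_role["Arn"]
    return ""
```

Usage: `role = find_execution_role(iam.get_paginator("list_roles").paginate(PathPrefix="/service-role"))`.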

In [None]:
# S3 location of mnist.npz; change this to a bucket and key that your execution role can read
data_input = "s3://junzhong/data/mnist.npz"
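If you need to produce your own `mnist.npz` for `data_input`, the sketch below writes and reads it in the layout `tf.keras.datasets.mnist` uses (arrays `x_train`, `y_train`, `x_test`, `y_test`). That the training script expects this layout is an assumption, since the script is not shown here. `write_mnist_npz` and `read_mnist_npz` are hypothetical helpers, and the zero arrays are synthetic stand-ins for real images. The resulting file could then be uploaded with `sagemaker_session.upload_data(path=local_path, bucket=..., key_prefix=...)`, which returns the S3 URI.

```python
import os
import tempfile

import numpy as np


def write_mnist_npz(path, x_train, y_train, x_test, y_test):
    """Save the four arrays under the key names tf.keras.datasets.mnist uses."""
    np.savez_compressed(path, x_train=x_train, y_train=y_train,
                        x_test=x_test, y_test=y_test)


def read_mnist_npz(path):
    """Load the arrays back as ((x_train, y_train), (x_test, y_test))."""
    with np.load(path) as f:
        return (f["x_train"], f["y_train"]), (f["x_test"], f["y_test"])


# Synthetic data: 10 training and 2 test images of 28x28 pixels.
tmp_dir = tempfile.mkdtemp()
local_path = os.path.join(tmp_dir, "mnist.npz")
write_mnist_npz(local_path,
                np.zeros((10, 28, 28), dtype=np.uint8), np.zeros(10, dtype=np.uint8),
                np.zeros((2, 28, 28), dtype=np.uint8), np.zeros(2, dtype=np.uint8))
(x_train, y_train), (x_test, y_test) = read_mnist_npz(local_path)
print(x_train.shape, x_test.shape)  # (10, 28, 28) (2, 28, 28)
```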

In [None]:
# True: use managed spot training (lower cost, but the job can be interrupted)
use_spot = True
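With managed spot training, `max_wait` covers both the time spent waiting for spot capacity and the training time, and it must be greater than or equal to `max_run`. The sketch below builds the timeout-related `TensorFlow` estimator arguments and enforces that rule; `spot_kwargs` is a hypothetical helper, and defaulting `max_wait` to `max_run` mirrors the 7200-second values used in the training cell.

```python
def spot_kwargs(use_spot, max_run=7200, max_wait=None):
    """Return estimator keyword arguments for spot or on-demand training."""
    if not use_spot:
        return {"use_spot_instances": False, "max_run": max_run}
    if max_wait is None:
        max_wait = max_run
    if max_wait < max_run:
        # SageMaker rejects spot jobs whose max_wait is shorter than max_run.
        raise ValueError("max_wait must be >= max_run for spot training")
    return {"use_spot_instances": True, "max_run": max_run, "max_wait": max_wait}
```

It could be passed as `TensorFlow(..., **spot_kwargs(use_spot))` in place of the three separate arguments.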

## 4 Training

In [None]:
from sagemaker.tensorflow import TensorFlow

estimator = TensorFlow(
    base_job_name="tensorflow2-smdataparallel-mnist",
    source_dir="code",
    entry_point="train_tensorflow_smdataparallel_mnist.py",
    role=role,
    py_version="py37",
    framework_version="2.3.1",
    instance_count=2,
    instance_type="ml.p3.16xlarge",
    sagemaker_session=sagemaker_session,
    hyperparameters={'batch_size':128},
    use_spot_instances=use_spot,
    max_wait=7200 if use_spot else None,
    max_run=7200,
    # Training using SMDataParallel Distributed Training Framework
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)
# All training logs are output on the first node
estimator.fit(data_input)
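With data parallelism, each GPU runs one worker, so the job above has 2 × 8 = 16 workers (an ml.p3.16xlarge has 8 V100 GPUs). The sketch below computes the worker count and the effective global batch size, assuming the training script treats the `batch_size` hyperparameter as a per-worker batch; the script is not shown here, so check it before relying on this. `GPUS_PER_INSTANCE`, `data_parallel_layout`, and `shard_indices` are hypothetical names. `shard_indices` shows round-robin sharding, the same split `tf.data.Dataset.shard(num_shards, index)` produces.

```python
# GPU counts entered by hand; not queried from the SageMaker SDK.
GPUS_PER_INSTANCE = {"ml.p3.16xlarge": 8, "ml.p3dn.24xlarge": 8, "ml.p4d.24xlarge": 8}


def data_parallel_layout(instance_type, instance_count, per_worker_batch):
    """Number of workers (one per GPU) and the resulting global batch size."""
    world_size = GPUS_PER_INSTANCE[instance_type] * instance_count
    return {"world_size": world_size, "global_batch": world_size * per_worker_batch}


def shard_indices(num_examples, rank, world_size):
    """Example indices that worker `rank` reads under round-robin sharding."""
    return list(range(rank, num_examples, world_size))


layout = data_parallel_layout("ml.p3.16xlarge", 2, 128)
print(layout)  # {'world_size': 16, 'global_batch': 2048}
```

Because the global batch grows with the number of workers, data-parallel scripts often scale the learning rate accordingly.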

In [None]:
model_data = estimator.model_data
model_data

## 5 Deployment

In [None]:
from sagemaker.tensorflow.model import TensorFlowModel
model = TensorFlowModel(
    model_data=model_data,
    role=role,
    framework_version='2.3.1')
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.large",
    endpoint_name="tensorflowmnist")
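Endpoint names must be unique within an account and region, so deploying twice with the fixed name `tensorflowmnist` fails while the first endpoint exists. SageMaker also limits endpoint names to 63 characters of letters, digits, and hyphens. The sketch below appends a UTC timestamp to a base name; `unique_endpoint_name` is a hypothetical helper (the SDK's own `sagemaker.utils.name_from_base` serves a similar purpose).

```python
import re
import time

# Allowed endpoint-name pattern: alphanumerics, with hyphens only in between.
ENDPOINT_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*$")


def unique_endpoint_name(base, max_len=63):
    """Return base + '-YYYYMMDD-HHMMSS', sanitized and truncated to max_len."""
    suffix = time.strftime("-%Y%m%d-%H%M%S", time.gmtime())
    base = re.sub(r"[^a-zA-Z0-9-]", "-", base).strip("-")
    name = base[: max_len - len(suffix)] + suffix
    if not ENDPOINT_NAME_RE.match(name):
        raise ValueError("invalid endpoint name: " + name)
    return name
```

If you use it, pass the generated name to `model.deploy(..., endpoint_name=...)` and reuse the same name in the `TensorFlowPredictor` cell instead of the hard-coded `"tensorflowmnist"`.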

## 6 Inference

In [None]:
import tensorflow as tf
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()

In [None]:
import numpy as np
import random
image_size = 3
mask1 = random.sample(range(len(mnist_images)), image_size)
mask2 = np.array(mask1, dtype=int)  # np.int was removed in NumPy 1.24; use the builtin int
data = mnist_images[mask2]
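The sampled batch has shape `(3, 28, 28)`: three grayscale 28×28 images with no channel axis. The prediction cell adds one with `np.expand_dims(data, axis=3)`, which suggests the model expects NHWC input with a single channel; that is an inference from the notebook, not something the training script confirms here. The sketch below traces the shapes on a synthetic stand-in for `mnist_images` (1000 zero images instead of the real 60000).

```python
import random

import numpy as np

# Synthetic stand-in for mnist_images: 1000 grayscale 28x28 uint8 images.
images = np.zeros((1000, 28, 28), dtype=np.uint8)
image_size = 3

indices = random.sample(range(len(images)), image_size)  # distinct indices
batch = images[np.array(indices, dtype=int)]             # shape (3, 28, 28)
batch_nhwc = np.expand_dims(batch, axis=3)               # shape (3, 28, 28, 1)
print(batch.shape, batch_nhwc.shape)  # (3, 28, 28) (3, 28, 28, 1)
```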

In [None]:
from matplotlib import pyplot as plt
plt.figure(figsize=(2,2))
for index, mask in enumerate(mask1):
    plt.subplot(1,image_size,index+1)
    plt.imshow(mnist_images[mask], cmap="gray")
    plt.axis('off')
plt.show()

In [None]:
from sagemaker.tensorflow.model import TensorFlowPredictor
endpoint_name = "tensorflowmnist"
predictor = TensorFlowPredictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker.Session())

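The input and output formats of `predict` correspond directly to the request and response formats of the Predict method in the TensorFlow Serving REST API.  
In addition, simplified JSON, line-delimited JSON objects ("jsons" or "jsonlines"), and CSV data are also supported.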
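In the TensorFlow Serving REST "row" format, a Predict request body is `{"instances": [...]}` with one entry per example, and the response is `{"predictions": [...]}`, which is why the next cell reads `response["predictions"]`. The SDK predictor serializes the NumPy array for you, so the sketch below only illustrates the payload shape. `to_predict_request` and `top_classes` are hypothetical helpers, and `fake_response` is made-up data, not real model output.

```python
import json

import numpy as np


def to_predict_request(batch):
    """Serialize a batch as a TF Serving 'row' format request body."""
    return json.dumps({"instances": np.asarray(batch).tolist()})


def top_classes(response_body):
    """Return the most likely class index for each prediction in a response body."""
    predictions = json.loads(response_body)["predictions"]
    return [int(np.argmax(p)) for p in predictions]


batch = np.zeros((2, 28, 28, 1), dtype=np.uint8)
body = to_predict_request(batch)
fake_response = json.dumps({"predictions": [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]})
print(top_classes(fake_response))  # [1, 0]
```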

In [None]:
response = predictor.predict(np.expand_dims(data, axis=3))
for i in range(0,image_size):
    print("Most likely answer: {}".format(np.argmax(response["predictions"][i])))

## 7 Cleanup

In [None]:
predictor.delete_endpoint()