In [None]:
pip install sagemaker --upgrade

In [None]:
import sagemaker

# Execution role handed to the estimator as its `role` argument further down.
role = sagemaker.get_execution_role()

# Session object shared with the estimator via its `sagemaker_session` argument.
sagemaker_session = sagemaker.Session()

In [None]:
# Print the training script with syntax highlighting (pygmentize is the
# Pygments command-line tool) so it can be read inline before launching the job.
!pygmentize code/train_tensorflow_smdataparallel_mnist.py

In [None]:
from sagemaker.tensorflow import TensorFlow

# Describe the training job: which script to run, in which framework container,
# on what cluster, and with which distribution strategy.
estimator = TensorFlow(
    # Training code: the entry point script lives inside the `code` directory.
    source_dir="code",
    entry_point="train_tensorflow_smdataparallel_mnist.py",
    base_job_name="tensorflow2-smdataparallel-mnist",
    # Framework container selection.
    framework_version="2.4.1",
    py_version="py37",
    # Cluster shape. Alternative instance types for this example are
    # ml.p3dn.24xlarge and ml.p4d.24xlarge.
    instance_type="ml.p3.16xlarge",
    # Number of training instances; a value above 1 (here 2) gives multi-node training.
    instance_count=2,
    # Enable the SMDataParallel distributed training framework.
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    # Execution role and session created earlier in the notebook.
    role=role,
    sagemaker_session=sagemaker_session,
)


In [None]:
# Launch the training job described by the estimator. No input data channels
# are passed here, so the training script is presumably responsible for
# obtaining MNIST itself -- TODO confirm against the script.
estimator.fit()

In [None]:
# Capture the trained model's artifact location from the estimator and report it.
model_data = estimator.model_data
print("Storing {} as model_data".format(model_data))
# Persist the variable with IPython's %store magic so that another notebook or
# kernel can restore it with `%store -r model_data`.
%store model_data

In [None]:
# Deploy the trained model behind a real-time endpoint backed by a single
# ml.m4.xlarge instance; the returned predictor is used for the requests below.
predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

In [None]:
# Show the name of the endpoint that was just created.
print(predictor.endpoint_name)

In [None]:
import tensorflow as tf
import numpy as np

# Load MNIST through Keras. Only the first (images, labels) pair of the returned
# tuple is kept -- the training split per the Keras documentation -- and the
# second pair is discarded. `path` is where Keras caches the downloaded dataset.
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(path="/tmp/data")

In [None]:
# Query the endpoint with the first ten MNIST images, one request per image,
# and print each ground-truth label next to the model's predicted class.
for i, label in enumerate(mnist_labels[:10]):
    # Reshape the i-th image into a single-item batch of shape (1, 28, 28, 1).
    # NOTE(review): pixel values are sent exactly as loaded; confirm this matches
    # the preprocessing applied in the training script.
    data = mnist_images[i].reshape(1, 28, 28, 1)
    predict_response = predictor.predict(data)

    print("========================================")

    # The position of the largest value in the "predictions" payload is taken
    # as the predicted class.
    predict_label = np.argmax(predict_response["predictions"])

    print("label is {}".format(label))
    print("prediction is {}".format(predict_label))

In [None]:
# Clean up: delete the endpoint created by deploy() so it does not keep running
# after the notebook is finished.
predictor.delete_endpoint()