# Distributed data parallel MNIST training with PyTorch and SMDataParallel


## Background
SMDataParallel is a new capability in Amazon SageMaker for training deep learning models faster and at lower cost. It is a distributed data parallel training framework for PyTorch.

This notebook example shows how to use SMDataParallel with PyTorch in SageMaker using the MNIST dataset.

For more information:
1. [PyTorch in SageMaker](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html)
2. [SMDataParallel PyTorch API Specification] < LINK TO BE ADDED >
3. [Getting started with SMDataParallel on SageMaker] < LINK TO BE ADDED >

**NOTE:** This example requires SageMaker Python SDK v2.X.


### Dataset
This example uses the MNIST dataset, a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits).



### SageMaker execution roles

The IAM role ARN is used to give training and hosting access to your data. See [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) for how to create one. If more than one role is required for notebook instances, training, and/or hosting, replace `sagemaker.get_execution_role()` with the appropriate full IAM role ARN string(s).

In [1]:
pip install sagemaker --upgrade

Collecting sagemaker
  Using cached sagemaker-2.23.4.post0-py2.py3-none-any.whl
Collecting smdebug-rulesconfig==1.0.1
  Using cached smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl (20 kB)
Installing collected packages: smdebug-rulesconfig, sagemaker
  Attempting uninstall: smdebug-rulesconfig
    Found existing installation: smdebug-rulesconfig 1.0.0
    Uninstalling smdebug-rulesconfig-1.0.0:
      Successfully uninstalled smdebug-rulesconfig-1.0.0
  Attempting uninstall: sagemaker
    Found existing installation: sagemaker 2.19.0
    Uninstalling sagemaker-2.19.0:
      Successfully uninstalled sagemaker-2.19.0
Successfully installed sagemaker-2.23.4.post0 smdebug-rulesconfig-1.0.1
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import sagemaker

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
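
If a separate role is needed for training, as noted above, the role can be supplied explicitly instead of being looked up from the notebook. The following is a minimal sketch under stated assumptions: `resolve_role` is a hypothetical helper (not part of the SageMaker SDK), and the account ID and role name in the usage line are placeholders.

```python
import re
from typing import Optional

# Rough shape of an IAM role ARN: arn:<partition>:iam::<12-digit account>:role/<name>
_ROLE_ARN_PATTERN = re.compile(r"^arn:aws[a-z-]*:iam::\d{12}:role/.+$")


def resolve_role(explicit_arn: Optional[str] = None) -> str:
    """Return explicit_arn if given, otherwise the notebook's execution role.

    Hypothetical helper for illustration; not part of the SageMaker SDK.
    """
    if explicit_arn is not None:
        if not _ROLE_ARN_PATTERN.match(explicit_arn):
            raise ValueError(f"Not an IAM role ARN: {explicit_arn!r}")
        return explicit_arn
    # Imported lazily so the helper works without the SDK when an ARN is passed.
    import sagemaker

    return sagemaker.get_execution_role()


# Placeholder ARN; substitute the role created for your training jobs.
training_role = resolve_role("arn:aws:iam::123456789012:role/MySageMakerTrainingRole")
```

The returned string can be passed as the `role` argument of the estimator in place of the value from `sagemaker.get_execution_role()`.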

## Model training with SMDataParallel

### Training script

The MNIST dataset is downloaded using the `torchvision.datasets` PyTorch module; you can see how this is implemented in the `train_pytorch_smdataparallel_mnist.py` training script that is printed out in the next cell.

The training script provides the code you need for distributed data parallel (DDP) training using SMDataParallel. The training script is very similar to a PyTorch training script you might run outside of SageMaker, but modified to run with SMDataParallel. SMDataParallel's PyTorch client provides an alternative to PyTorch's native DDP. For details about how to use SMDataParallel's DDP in your native PyTorch script, see the Getting Started with SMDataParallel tutorials.
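
To make the data parallel idea concrete, the sketch below shows how the indices of a dataset could be split so that each GPU worker processes its own shard in every epoch. It is a pure-Python illustration, not SMDataParallel's implementation: `shard_indices` is a hypothetical helper, and its padding scheme is modeled on the non-shuffled behavior of samplers such as `torch.utils.data.distributed.DistributedSampler`. Whether the training script shards data exactly this way is an assumption.

```python
import math
from typing import List


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Return the sample indices assigned to one worker (hypothetical helper).

    Indices are dealt out round-robin. If num_samples is not divisible by
    world_size, the first few indices are repeated so every worker gets
    the same number of samples.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    if num_samples < world_size:
        raise ValueError("need at least one sample per worker")
    per_worker = math.ceil(num_samples / world_size)
    total = per_worker * world_size
    indices = list(range(num_samples))
    indices += indices[: total - num_samples]  # pad to a multiple of world_size
    return indices[rank:total:world_size]


if __name__ == "__main__":
    # Assumed setup: 2 instances with 8 GPUs each, one worker per GPU.
    world_size = 2 * 8
    shard = shard_indices(60_000, world_size, rank=0)
    print(len(shard))  # 3750 training images per worker
```

Because each worker sees only its shard, an epoch over the 60,000 training images takes roughly 1/16 of the per-worker steps in this assumed setup, with gradients synchronized across workers after each step.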


In [3]:
!pygmentize code/train_pytorch_smdataparallel_mnist.py

# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import print_function

import os
import argparse
import

### Estimator function options

In the following code block, you can update the estimator function to use a different instance type, instance count, and distribution strategy. You're also passing in the training script you reviewed in the previous cell.

**Instance types**

SMDataParallel supports model training on SageMaker with the following instance types only:
1. ml.p3.16xlarge
1. ml.p3dn.24xlarge [Recommended]
1. ml.p4d.24xlarge [Recommended]

**Instance count**

To get the best performance and the most out of SMDataParallel, you should use at least 2 instances, but you can also use 1 for testing this example.
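
As a rough guide to how instance type and instance count translate into parallel workers, the sketch below multiplies the number of GPUs per instance by the instance count. It assumes one worker process per GPU, which is an assumption about how the job is launched; `GPUS_PER_INSTANCE` and `total_workers` are hypothetical names introduced for illustration.

```python
# GPUs per instance for the instance types listed above (8 GPUs each).
GPUS_PER_INSTANCE = {
    "ml.p3.16xlarge": 8,
    "ml.p3dn.24xlarge": 8,
    "ml.p4d.24xlarge": 8,
}


def total_workers(instance_type: str, instance_count: int) -> int:
    """Return the number of GPU workers, assuming one worker per GPU."""
    if instance_type not in GPUS_PER_INSTANCE:
        supported = ", ".join(sorted(GPUS_PER_INSTANCE))
        raise ValueError(f"{instance_type} is not supported; choose one of: {supported}")
    if instance_count < 1:
        raise ValueError("instance_count must be at least 1")
    return GPUS_PER_INSTANCE[instance_type] * instance_count


if __name__ == "__main__":
    print(total_workers("ml.p3.16xlarge", 2))  # 16
```

A check like this can be run before creating the estimator to catch an unsupported instance type early rather than after the training job has been submitted.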

**Distribution strategy**

Note that to use DDP mode, you update the `distribution` strategy and set it to use `smdistributed dataparallel`.

In [4]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    base_job_name='pytorch-smdataparallel-mnist',
    source_dir='code',
    entry_point='train_pytorch_smdataparallel_mnist.py',
    role=role,
    framework_version='1.6.0',
    py_version='py36',
    # Number of instances; use 2 or more for multi-node distributed training
    instance_count=2,
    # To train on a p3dn instance, use 'ml.p3dn.24xlarge'
    instance_type='ml.p3.16xlarge',
    sagemaker_session=sagemaker_session,
    # Train with the SMDataParallel distributed training framework
    distribution={
        'smdistributed': {
            'dataparallel': {
                'enabled': True
            }
        }
    },
    debugger_hook_config=False,
)

In [5]:
estimator.fit()

2021-01-18 04:48:36 Starting - Starting the training job...
2021-01-18 04:49:00 Starting - Launching requested ML instances.........
ProfilerReport-1610945316: InProgress
2021-01-18 04:50:21 Starting - Preparing the instances for training.........
2021-01-18 04:52:07 Downloading - Downloading input data...
2021-01-18 04:52:22 Training - Downloading the training image..............
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2021-01-18 04:54:41,915 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2021-01-18 04:54:41,993 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2021-01-18 04:54:44,917 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container


2021-01-18 04:55:04 Training - Training image download completed. Training in progress.
[1,8]<stdout>:NCCL version 2.7.8+cuda11.0
[1,0]<stdout>:NCCL version 2.7.8+cuda11.0
[1,8]<stdout>:Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/data/MNIST/raw/train-images-idx3-ubyte.gz
[1,0]<stdout>:Running smdistributed.dataparallel v1.0.0
[1,0]<stdout>:Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/data/MNIST/raw/train-images-idx3-ubyte.gz
[1,8]<stdout>:Extracting /tmp/data/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/MNIST/raw
[1,0]<stdout>:Extracting /tmp/data/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/MNIST/raw
[1,8]<stdout>:Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/data/MNIST/raw/train-labels-idx1-ubyte.gz
[1,0]<stdout>:Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/dat

[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0533, Accuracy: 9826/10000 (98%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0422, Accuracy: 9851/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0400, Accuracy: 9856/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0388, Accuracy: 9864/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0381, Accuracy: 9862/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0369, Accuracy: 9871/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0371, Accuracy: 9876/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0376, Accuracy: 9871/

[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0368, Accuracy: 9871/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0368, Accuracy: 9874/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:
[1,0]<stdout>:Test set: Average loss: 0.0367, Accuracy: 9873/10000 (99%)
[1,0]<stdout>:
[1,0]<stdout>:Saving the model...
2021-01-18 04:56:23,295 sagemaker-training-toolkit INFO     Reporting training SUCCESS
2021-01-18 04:56:23,295 sagemaker-training-toolkit INFO     Orted process exited
2021-01-18 04:56:53,325 sagemaker-training-toolkit INFO     MPI process finished.
2021-01-18 04:56:53,326 sagemaker-training-toolkit INFO     Reporting training SUCCESS

2021-01-18 04:57:06 Uploading - Uploading generated training model
2021-01-18 04:57:06 Completed - Training job completed
Training seconds: 598
Billable seconds: 598
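
The per-epoch evaluation lines in the log above can be pulled out programmatically, for example to track accuracy across epochs. The sketch below is one way to do that, assuming the log text has been captured as a string; `parse_test_metrics` is a hypothetical helper, and the regular expression is based only on the line format shown in this output.

```python
import re
from typing import List, Tuple

# Matches lines such as:
#   Test set: Average loss: 0.0367, Accuracy: 9873/10000 (99%)
_TEST_LINE = re.compile(r"Test set: Average loss: (\d+\.\d+), Accuracy: (\d+)/(\d+)")


def parse_test_metrics(log_text: str) -> List[Tuple[float, int, int]]:
    """Return (average_loss, correct, total) for every complete test-set line."""
    return [
        (float(loss), int(correct), int(total))
        for loss, correct, total in _TEST_LINE.findall(log_text)
    ]


if __name__ == "__main__":
    sample = (
        "Test set: Average loss: 0.0368, Accuracy: 9874/10000 (99%)\n"
        "Test set: Average loss: 0.0367, Accuracy: 9873/10000 (99%)\n"
    )
    for loss, correct, total in parse_test_metrics(sample):
        print(f"loss={loss:.4f} accuracy={correct / total:.2%}")
```

Lines that are cut off before the total (for example, one ending in `Accuracy: 9871/`) do not match the pattern and are skipped.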


## Next steps

Now that you have a trained model, you can deploy an endpoint to host the model and then test it with inference requests. The following cell stores the `model_data` variable for use with the inference notebook.

In [None]:
model_data = estimator.model_data
print("Storing {} as model_data".format(model_data))
%store model_data