# Use SageMaker Distributed Data Parallel library to pre-train DeBERTa v3

[Amazon SageMaker's distributed library](https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-training.html) can be used to train deep learning models faster and cheaper. The [data parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html) feature in this library (`smdistributed.dataparallel`) is a distributed data parallel training framework that provides seamless integration with common frameworks, like PyTorch and TensorFlow.

This notebook example shows how to use the SageMaker distributed data parallel (SMDDP) library with PyTorch (version 1.10.2) on [Amazon SageMaker](https://aws.amazon.com/sagemaker/) to pre-train DeBERTa v3 using the publicly available wiki103 dataset.


The outline of steps is as follows:

1. Prepare the training dataset and stage it in [Amazon S3](https://aws.amazon.com/s3/).
2. Configure the estimator function options, like distribution strategy and hyperparameters.
3. Use the PyTorch estimator to pre-train DeBERTa v3 on the wiki103 dataset.

**NOTE:** This example requires SageMaker Python SDK v2.X.


## Amazon SageMaker Initialization

Initialize the notebook instance. Get the AWS Region and a SageMaker execution role.

### SageMaker role

The following code cell defines `role`, which is the IAM role ARN used to create and run SageMaker training and hosting jobs. This is the same IAM role used to create this SageMaker notebook instance.

`role` must have permission to create a SageMaker training job and host a model. For granular policies you can use to grant these permissions, see [Amazon SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html). If you do not need fine-grained permissions for this demo, you can use the IAM managed policy `AmazonSageMakerFullAccess`.

In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker.local import LocalSession
import boto3

local_mode = False

# Use the notebook's execution role, or assign a pre-existing role ARN here instead.
role = get_execution_role()
print(f"SageMaker Execution Role: {role}")

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f"AWS account: {account}")
if local_mode:
    session = LocalSession()
    session.config = {'local': {'local_code': True}}
    sagemaker_session = session
    region = 'us-west-2'  # Update to the appropriate region.
else:
    session = boto3.session.Session()
    region = session.region_name
    print(f"AWS region: {region}")
    sm_boto_client = boto3.client("sagemaker")
    sagemaker_session = sagemaker.session.Session(boto_session=session)

### Configure the Amazon S3 bucket to host the dataset

In [None]:
# Get default bucket
default_bucket = sagemaker_session.default_bucket()
print("Default bucket for this session: ", default_bucket)

s3_output_location = f"s3://{default_bucket}/output/"

# S3 prefix that holds the training data; point this at your own staged dataset if needed.
s3_train_bucket = 's3://sagemaker-us-west-2-570106654206/data/deberta-mlm/'
train = sagemaker.inputs.TrainingInput(s3_train_bucket, distribution="FullyReplicated", s3_data_type="S3Prefix")
data_channels = {"train": train}

volume_size = 500 # Size in GB of the EBS volume to use for storing input data during training 

# No trailing slash: a key prefix is appended to this URI when the estimator is created.
checkpoint_bucket = f"s3://sagemaker-{region}-{account}"
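
When a key prefix is appended to an S3 URI such as `checkpoint_bucket`, a trailing slash on one side combined with a leading slash on the other produces `//`, which ends up as an empty folder-like segment in the object key. The following is a minimal sketch of a helper that joins the pieces with exactly one slash. The name `join_s3_uri` and the bucket `my-bucket` are hypothetical and are not part of the SageMaker SDK.

```python
def join_s3_uri(base: str, *parts: str) -> str:
    """Join an S3 URI and key segments with exactly one slash between them.

    Hypothetical helper for illustration; not part of the SageMaker SDK.
    """
    if not base.startswith("s3://"):
        raise ValueError(f"not an S3 URI: {base!r}")
    # Drop trailing slashes from the base and surrounding slashes from each part.
    segments = [base.rstrip("/")]
    segments += [part.strip("/") for part in parts if part.strip("/")]
    return "/".join(segments)


# "my-bucket" is a placeholder bucket name.
print(join_s3_uri("s3://my-bucket/", "/experiments/", "deberta-v3/"))
```

The result has no doubled slashes regardless of how the inputs were written, so the same helper can be used for both output and checkpoint locations.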

### Configure SageMaker PyTorch Estimator function options

In the following code blocks, you can update the estimator function to use a different instance type, instance count, distribution strategy, and hyperparameters. The estimator also takes an entry point that names the training script.

**Instance types**

`smdistributed.dataparallel` supports model training on SageMaker with the following instance types only. For best performance, it is recommended you use an instance type that supports [Amazon Elastic Fabric Adapter (EFA)](https://aws.amazon.com/hpc/efa/).

1. `ml.p3.16xlarge`
1. `ml.p3dn.24xlarge` [Recommended]
1. `ml.p4d.24xlarge` [Recommended]

**Instance count**

To get the best performance and the most out of `smdistributed.dataparallel`, you should use at least 2 instances, but you can also use 1 for testing this example.

In [None]:
instance_type = "ml.p4d.24xlarge"
instance_count = 2
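
In data parallel training, one worker process typically runs per GPU, so the instance type and count determine the total number of workers and, with it, the effective global batch size. The sketch below works through that arithmetic. Each of the instance types listed above has 8 GPUs. The helper names and the per-GPU batch size of 32 are assumptions for illustration, not values taken from the training script.

```python
# GPUs per instance for the instance types listed above.
GPUS_PER_INSTANCE = {
    "ml.p3.16xlarge": 8,
    "ml.p3dn.24xlarge": 8,
    "ml.p4d.24xlarge": 8,
}


def world_size(instance_type: str, instance_count: int) -> int:
    """Total number of GPU workers across all instances."""
    if instance_count < 1:
        raise ValueError("instance_count must be at least 1")
    return GPUS_PER_INSTANCE[instance_type] * instance_count


def global_batch_size(per_gpu_batch_size: int, instance_type: str, instance_count: int) -> int:
    """Samples processed per optimizer step when every worker sees its own batch."""
    return per_gpu_batch_size * world_size(instance_type, instance_count)


print(world_size("ml.p4d.24xlarge", 2))             # 16
print(global_batch_size(32, "ml.p4d.24xlarge", 2))  # 512 (32 is an assumed per-GPU batch size)
```

Because the global batch size grows with the number of instances, hyperparameters such as the learning rate may need to be revisited when you change `instance_count`.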

**Distribution strategy**

To use SMDDP, set the estimator's `distribution` parameter so that `smdistributed.dataparallel` is enabled.

In [None]:
distribution_strategy = {"smdistributed": {"dataparallel": {"enabled": True}}}
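
Since `smdistributed.dataparallel` only runs on the instance types listed earlier, it can be convenient to check the instance type before a job is submitted. The following sketch wraps the dictionary above in a small function that performs this check. The function name `smddp_distribution` is hypothetical, and the check does not cover local mode (`local_gpu`).

```python
# Instance types listed in this notebook as supported by smdistributed.dataparallel.
SUPPORTED_INSTANCE_TYPES = (
    "ml.p3.16xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
)


def smddp_distribution(instance_type: str) -> dict:
    """Return the `distribution` argument for SMDDP, or fail early.

    Hypothetical convenience wrapper; not part of the SageMaker SDK.
    """
    if instance_type not in SUPPORTED_INSTANCE_TYPES:
        raise ValueError(
            f"{instance_type!r} is not in the supported list: {SUPPORTED_INSTANCE_TYPES}"
        )
    return {"smdistributed": {"dataparallel": {"enabled": True}}}


print(smddp_distribution("ml.p4d.24xlarge"))
```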

**Assign a base name**

The base name is used as a prefix for the SageMaker training job name, so you can identify the job easily in the [SageMaker console](https://console.aws.amazon.com/sagemaker/).

In [None]:
base_job_name = "deberta-v3"

**Create the estimator function and pass the parameters**

Use all parameters from previous sections to configure the estimator function.

In [None]:
from sagemaker.pytorch import PyTorch
import os

estimator = PyTorch(
    entry_point='launcher_ddp.py', 
    source_dir=os.path.dirname(os.path.dirname(os.getcwd())),
    role=role,
    instance_type=instance_type if not local_mode else 'local_gpu',
    volume_size=volume_size,
    instance_count=instance_count,
    sagemaker_session=sagemaker_session,
    distribution=distribution_strategy,
    framework_version="1.10",
    py_version="py38",
    output_path=s3_output_location,
    checkpoint_s3_uri=f"{checkpoint_bucket}/experiments/{base_job_name}/",
    checkpoint_local_path='/opt/ml/checkpoints/',
    debugger_hook_config=False,
    disable_profiler=True,
    base_job_name=base_job_name,
)
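
With `checkpoint_s3_uri` and `checkpoint_local_path` set, SageMaker syncs files written to the local checkpoint directory to Amazon S3 during training, and copies any existing checkpoints from that S3 location back into the local directory when a job starts. A training script can use this to resume from the most recent checkpoint. The sketch below shows one way to locate it. The file naming scheme `checkpoint-<step>.pt` is an assumption for illustration, and whether `launcher_ddp.py` resumes this way is not shown in this notebook.

```python
import os
import re
import tempfile

# Assumed naming scheme: checkpoint-<step>.pt (hypothetical, adjust to your script).
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)\.pt$")


def latest_checkpoint(checkpoint_dir: str):
    """Return the path of the checkpoint with the highest step, or None if there is none."""
    if not os.path.isdir(checkpoint_dir):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = CHECKPOINT_PATTERN.match(name)
        # Compare steps numerically so that 2000 ranks above 900.
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path


# Demonstration on a temporary directory instead of /opt/ml/checkpoints/.
with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 2000, 900):
        open(os.path.join(tmp, f"checkpoint-{step}.pt"), "w").close()
    print(os.path.basename(latest_checkpoint(tmp)))  # checkpoint-2000.pt
```

In a training script, the function would be called with the same path passed as `checkpoint_local_path`, and training would start from scratch when it returns `None`.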

## Start the SageMaker training job
Run the cell below to start pre-training DeBERTa v3 on the wiki103 dataset.

In [None]:
estimator.fit(inputs=data_channels)

## Clean Up

To avoid incurring unnecessary charges, follow these [steps to use the AWS Management Console to delete resources such as endpoints, notebook instances, S3 buckets, and CloudWatch logs](https://docs.aws.amazon.com/sagemaker/latest/dg/ex1-cleanup.html).