# Distributed training using torch.distributed.launch module on Azure Machine Learning


This example shows how to train a language model with the Hugging Face library, distributed across multiple nodes on Azure Machine Learning, using the PyTorch estimator.

In [None]:
%load_ext autoreload
%autoreload 2


import wget
import os

from azureml.core import (Workspace, Experiment, 
                          VERSION, Datastore)
from azureml.core.compute import ComputeTarget, AmlCompute
from azureml.core.environment import Environment,CondaDependencies
from azureml.train.dnn import PyTorch,Nccl
from azureml.data.data_reference import DataReference
from azureml.core.compute_target import ComputeTargetException
from azureml.widgets import RunDetails

# Azure workspace coordinates; left empty here and must be filled in
# before the workspace can be resolved.
SUBSCRIPTION_ID = ""
RESOURCE_GROUP = ""
WORKSPACE_NAME = ""

# Cluster shape: NUM_NODES is used both as the cluster's maximum size and as
# the node count of the training job; NUM_GPU_PER_NODE is forwarded to the
# training script as --process_per_node.
NUM_NODES = 8
NUM_GPU_PER_NODE = 2
SKU = "Standard_NC12"
CLUSTER_NAME = "two-gpu-cluster"
EXP_NAME = "Azureml-LM_huggingface_example"

# Local paths: the dataset download folder and the directory submitted as
# the job's source directory (the notebook's current working directory).
DATA_DIR = "data"
RUN_DIR = os.getcwd()

print("SDK version:", VERSION)

In [None]:
# Resolve the Azure ML workspace from the subscription / resource group /
# workspace name constants configured earlier in the notebook.
ws = Workspace(
    subscription_id=SUBSCRIPTION_ID,
    resource_group=RESOURCE_GROUP,
    workspace_name=WORKSPACE_NAME,
)

# Experiment under which the training run is later submitted.
exp = Experiment(workspace=ws, name=EXP_NAME)

In [None]:
# Fetch the WikiText-2 raw archive into DATA_DIR.
#
# The download is skipped when the archive is already present so that
# re-running this cell does not leave renamed duplicate copies in DATA_DIR
# (the whole folder is uploaded to the datastore in the next step).
wikitext_url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip"
os.makedirs(DATA_DIR, exist_ok=True)

# Expected local file name, assuming wget names the file after the last URL
# segment; if it picks another name the check simply never matches and the
# archive is downloaded as before.
archive_path = os.path.join(DATA_DIR, os.path.basename(wikitext_url))
if not os.path.exists(archive_path):
    wget.download(wikitext_url, out=DATA_DIR)

In [None]:
# Upload the local dataset folder to the workspace's default datastore under
# the 'wikitext' path, replacing any files already stored there.
datastore = ws.get_default_datastore()
ds_reference = datastore.upload(
    # Use the shared DATA_DIR constant (the folder the download step writes
    # to) instead of repeating the literal folder name.
    src_dir=DATA_DIR,
    target_path='wikitext',
    overwrite=True,
    show_progress=True,
)


In [None]:
from azureml.core.compute import AmlCompute
from azureml.core.compute import ComputeTarget



# Reuse the compute target named CLUSTER_NAME when the workspace already has
# one of type AmlCompute; otherwise provision a new cluster with that name.
workspace_targets = ws.compute_targets
candidate = workspace_targets[CLUSTER_NAME] if CLUSTER_NAME in workspace_targets else None

if candidate is not None and candidate.type == 'AmlCompute':
    print('Found existing compute target.')
    compute_target = candidate
else:
    print('Creating a new compute target...')
    cluster_config = AmlCompute.provisioning_configuration(
        vm_size=SKU,
        max_nodes=NUM_NODES,
    )

    # Create the cluster.
    compute_target = ComputeTarget.create(ws, CLUSTER_NAME, cluster_config)

# Block until the target finishes provisioning (or the 20 minute timeout hits).
print('Checking cluster status...')
compute_target.wait_for_completion(
    show_output=True,
    min_node_count=None,
    timeout_in_minutes=20,
)

In [None]:
# Command-line arguments handed to the training script.
script_params = {
    # Uploaded dataset, mounted on the compute nodes.
    '--dataset-path': ds_reference.as_mount(),
    # NOTE(review): presumably an environment variable resolved on each node
    # to that node's index -- confirm against train.py.
    '--rank': '$AZ_BATCHAI_TASK_INDEX',
    '--node_count': NUM_NODES,
    '--process_per_node': NUM_GPU_PER_NODE,
    '--batch_size': '2',
}

# PyTorch estimator describing the distributed job: the source folder, the
# extra pip packages, the target cluster, and one NCCL-backed task per node.
est = PyTorch(
    source_directory=RUN_DIR,
    pip_packages=[
        'gitpython',
        'scikit-learn',
        'seqeval',
        'tensorboardX',
        'tqdm',
        'transformers',
    ],
    script_params=script_params,
    use_gpu=True,
    compute_target=compute_target,
    # entry_script is documented as a path relative to source_directory, so
    # pass the bare file name rather than a local absolute path.
    entry_script='train.py',
    framework_version='1.4',
    node_count=NUM_NODES,
    distributed_training=Nccl(),
)

In [None]:
# Submit the estimator to the experiment, then render the run-details
# widget for the resulting run inside the notebook.
run = exp.submit(est)
run_widget = RunDetails(run)
run_widget.show()

In [None]:
# Bare expression: presumably relies on the notebook echoing the value of a
# cell's last expression to display the submitted run.
run