In [6]:
import json
import time
import boto3
import numpy as np
import sagemaker
import sagemaker.huggingface

# Session-level settings shared by every later cell.
# BUCKET defaults to the SageMaker session's default bucket; to use your own
# bucket instead, assign BUCKET = "[BUCKET_NAME]" after the default_bucket() call.
PREFIX = "whisper-hi"  # S3 key prefix used to build the training output path
ROLE = sagemaker.get_execution_role()
sess = sagemaker.Session()
BUCKET = sess.default_bucket()

# Echo the resolved settings.
# NOTE(review): the role ARN and bucket name embed the AWS account id; clear this
# cell's output before sharing or committing the notebook.
print(f"sagemaker role arn: {ROLE}")
print(f"sagemaker bucket: {BUCKET}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker role arn: arn:aws:iam::348052051973:role/service-role/AmazonSageMakerServiceCatalogProductsExecutionRole
sagemaker bucket: sagemaker-us-east-1-348052051973
sagemaker session region: us-east-1


In [15]:
from sagemaker.huggingface import HuggingFace

# Unique id (Unix timestamp in seconds) used to tag the training job, model name and endpoint name.
# NOTE: `id` shadows the built-in id(). The name is kept because later cells may
# reference it; consider renaming it (e.g. RUN_ID) across the whole notebook.
id = int(time.time())

TRAINING_JOB_NAME = f"huggingface-whisper-training-{id}"
print('Training job name: ', TRAINING_JOB_NAME)

# Hyperparameters forwarded to scripts/train.py.
hyperparameters = {
    'max_steps': 4000,  # you can increase the max steps to improve model accuracy
    'train_batch_size': 4,
    'eval_batch_size': 2,
    'model_name': "openai/whisper-small",
    'language': "Hindi",
    'dataloader_num_workers': 16,
}

# Regex fragment for one numeric metric value as printed in a Python dict repr,
# e.g. 120, 45.67 or 1.5e-05. The outer group is the FIRST capture group and
# spans the whole number (mantissa plus optional exponent). `\.` is escaped so it
# only matches a decimal point. Raw strings avoid invalid-escape warnings.
_METRIC_VALUE_REGEX = r"([0-9]+(\.[0-9]+)?(e[-+]?[0-9]+)?)"

# Metric definitions: each regex looks for "'<name>': <number>" in the training logs.
metric_definitions = [
    {'Name': name, 'Regex': f"'{name}': {_METRIC_VALUE_REGEX},?"}
    for name in ('eval_loss', 'eval_wer', 'eval_runtime', 'eval_samples_per_second', 'epoch')
]

Training job name:  huggingface-whisper-training-1672982311


In [16]:
from sagemaker.inputs import TrainingInput
# S3 prefix holding the training data for the `train` channel
# (presumably Hindi Common Voice, judging by the key name; confirm).
training_input_path = f"s3://{BUCKET}/whisper/data/hi-common-voice"

# FastFile input mode streams objects from S3 on demand rather than copying the
# whole prefix to the instance before the script starts.
training = TrainingInput(
    s3_data=training_input_path,
    s3_data_type='S3Prefix',         # options: S3Prefix | ManifestFile | AugmentedManifestFile
    distribution='FullyReplicated',  # options: FullyReplicated | ShardedByS3Key
    input_mode='FastFile',
)

In [17]:
# Distributed-training strategy handed to the estimator: SageMaker distributed
# data parallelism (smdistributed.dataparallel).
#
# Alternative (not used): SageMaker model parallelism. To try it, replace the
# assignment below with:
#   smp_options = {"enabled": True,
#                  "parameters": {"microbatches": 4, "placement_strategy": "spread",
#                                 "pipeline": "interleaved", "optimize": "speed",
#                                 "partitions": 4, "ddp": True, "fp16": True}}
#   distribution = {"smdistributed": {"modelparallel": smp_options},
#                   "mpi": {"enabled": True, "processes_per_host": 8}}
distribution = dict(smdistributed=dict(dataparallel=dict(enabled=True)))


In [None]:
# S3 location passed as the estimator's output_path for this job's artifacts.
OUTPUT_PATH = f"s3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/"

# Hugging Face estimator that runs ./scripts/train.py on one ml.p3.16xlarge
# instance, with the hyperparameters, metric regexes and distribution
# configuration defined in earlier cells.
# NOTE(review): transformers 4.17.0 predates native Whisper support; presumably
# ./scripts installs a newer transformers (e.g. via requirements.txt). Confirm.
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    role=ROLE,
    instance_type='ml.p3.16xlarge',
    instance_count=1,
    volume_size=200,
    transformers_version='4.17.0',
    pytorch_version='1.10.2',
    py_version='py38',
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    distribution=distribution,
    output_path=OUTPUT_PATH,
)

# Launch training on the `train` channel under the pre-computed job name.
# fit() waits for the job and streams its logs; this takes roughly 2 hours.
huggingface_estimator.fit({'train': training}, job_name=TRAINING_JOB_NAME)

2023-01-06 05:18:32 Starting - Starting the training job...ProfilerReport-1672982312: InProgress
......
2023-01-06 05:19:56 Starting - Preparing the instances for training......
2023-01-06 05:20:57 Downloading - Downloading input data...
2023-01-06 05:21:17 Training - Downloading the training image........................
2023-01-06 05:25:23 Training - Training image download completed. Training in progress...[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2023-01-06 05:25:47,289 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2023-01-06 05:25:47,365 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2023-01-06 05:25:47,368 sagemaker_pytorch_container.training INFO     Invoking SMDataParallel[0m
[34m2023-01-06 05:25:47,368 sagemaker_pytorch_container.training INFO     Invoking user training script.[0