In [4]:
import json
import time
import boto3
import numpy as np
import sagemaker
import sagemaker.huggingface

# Key prefix used for this notebook's job names and S3 output locations.
PREFIX = "whisper-zhcn"

# Open a SageMaker session and take its default bucket.
# To use a bucket of your own instead, assign it here: BUCKET = "[BUCKET_NAME]"
sess = sagemaker.Session()
BUCKET = sess.default_bucket()

# Execution role of the environment this notebook runs in.
ROLE = sagemaker.get_execution_role()

# Echo the resolved settings so a wrong account or region is spotted early.
print(f"sagemaker role arn: {ROLE}")
print(f"sagemaker bucket: {BUCKET}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker role arn: arn:aws:iam::348052051973:role/service-role/AmazonSageMakerServiceCatalogProductsExecutionRole
sagemaker bucket: sagemaker-us-east-1-348052051973
sagemaker session region: us-east-1


In [5]:
from sagemaker.huggingface import HuggingFace

# Whole-second Unix timestamp used as a unique tag for the training job
# (and, per the original note, for the model name and endpoint name too).
# NOTE(review): `id` shadows the Python builtin; the name is kept because
# other cells may refer to it.
id = int(time.time())

TRAINING_JOB_NAME = f"{PREFIX}-{id}"

print('Training job name: ', TRAINING_JOB_NAME)

# Hyperparameters passed to the estimator as `hyperparameters`.
hyperparameters = dict(
    max_steps=50000,  # you can increase the max steps to improve model accuracy
    train_batch_size=16,
    eval_batch_size=8,
    model_name="openai/whisper-medium",
    language="Chinese",
    dataloader_num_workers=16,
)

# Metric definitions: one regex per metric, applied to the training log.
#
# Each pattern looks for a "'<name>': <number>" entry and captures the number
# in its single capture group. The number pattern:
#   * escapes the decimal point, so it no longer matches arbitrary characters;
#   * accepts plain integers (e.g. 3) as well as decimals (e.g. 0.4512);
#   * accepts an optional exponent (e.g. 1.5e-05) and captures it in full
#     rather than stopping at the mantissa;
#   * is a raw string, so the backslash is not an invalid string escape.
_METRIC_VALUE = r"([0-9]+(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?)"

metric_definitions = [
    {'Name': name, 'Regex': f"'{name}': {_METRIC_VALUE},?"}
    for name in (
        'eval_loss',
        'eval_wer',
        'eval_runtime',
        'eval_samples_per_second',
        'epoch',
    )
]

Training job name:  whisper-zhcn-1673529542


In [6]:
from sagemaker.inputs import TrainingInput
# S3 prefix holding the training data (presumably the pre-processed zh-CN
# Common Voice set, judging by the path name - confirm against the data-prep step).
training_input_path = f's3://{BUCKET}/whisper/data/zhcn-common-voice-processed'

# Input channel description later passed to fit() under the 'train' key.
training = TrainingInput(
    s3_data=training_input_path,
    s3_data_type='S3Prefix',         # other options: ManifestFile | AugmentedManifestFile
    input_mode='FastFile',
    distribution='FullyReplicated',  # other option: ShardedByS3Key
)

In [7]:
# Distributed-training settings handed to the estimator.
#
# Option A (disabled): smdistributed model parallel, launched through MPI.
#
# mpi_options = {
#     "enabled" : True,
#     "processes_per_host" : 8
# }
#
# smp_options = {
#     "enabled":True,
#     "parameters": {
#         "microbatches": 4,
#         "placement_strategy": "spread",
#         "pipeline": "interleaved",
#         "optimize": "speed",
#         "partitions": 4,
#         "ddp": True,
#         "fp16": True,
#     }
# }
#
# distribution={
#     "smdistributed": {"modelparallel": smp_options},
#     "mpi": mpi_options
# }
#
# Option B (disabled): smdistributed data parallel.
#
# distribution = {'smdistributed':{'dataparallel':{ 'enabled': True }}}
#
# NOTE(review): both options presumably need a multi-GPU instance type;
# confirm before enabling either of them with the instance type below.

# Active configuration: no distribution strategy.
distribution = None
instance_type = 'ml.p3.2xlarge'


In [None]:
# Each training job writes its output under its own S3 prefix.
OUTPUT_PATH = f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'

# Estimator that runs ./scripts/train.py with the hyperparameters, metric
# regexes and distribution settings defined in the earlier cells.
huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    output_path=OUTPUT_PATH,
    role=ROLE,
    instance_type=instance_type,
    instance_count=1,
    volume_size=500,
    max_run=5 * 24 * 60 * 60,  # 432000 seconds, i.e. 5 days
    transformers_version='4.17.0',
    pytorch_version='1.10.2',
    py_version='py38',
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    distribution=distribution,
)

# Launch the training job; the data channel is exposed to the script as 'train'.
# The original note estimated roughly 2 hours of training.
# NOTE(review): confirm that estimate against max_steps and the 5-day max_run.
huggingface_estimator.fit({'train': training}, job_name=TRAINING_JOB_NAME)

2023-01-12 13:19:03 Starting - Starting the training job...
2023-01-12 13:19:30 Starting - Preparing the instances for trainingProfilerReport-1673529543: InProgress
.........
2023-01-12 13:20:48 Downloading - Downloading input data...
2023-01-12 13:21:27 Training - Downloading the training image.