In [4]:
import json
import time
import boto3
import numpy as np
import sagemaker
import sagemaker.huggingface
import os

ROLE = sagemaker.get_execution_role()
sess = sagemaker.Session()
BUCKET = sess.default_bucket()
# To use your own bucket instead of the SageMaker default bucket, uncomment the
# next line. It must stay *after* default_bucket() above, otherwise it is overwritten.
# BUCKET = "[BUCKET_NAME]"  # please use your bucket name
PREFIX = "whisper/data/zhtw-common-voice-processed"
# Build the S3 URI with plain string formatting: os.path.join uses the local OS
# path separator (a backslash on Windows), which is not valid inside an S3 URI.
s3uri = f"s3://{BUCKET}/{PREFIX}"
print(f"sagemaker role arn: {ROLE}")
print(f"sagemaker bucket: {BUCKET}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"data uri: {s3uri}")


sagemaker role arn: arn:aws:iam::348052051973:role/service-role/AmazonSageMakerServiceCatalogProductsExecutionRole
sagemaker bucket: sagemaker-us-east-1-348052051973
sagemaker session region: us-east-1
data uri: s3://sagemaker-us-east-1-348052051973/whisper/data/zhtw-common-voice-processed


In [7]:
# Instance configuration: keep exactly one of the two settings below active.

# Distributed training (SageMaker data parallel on ml.p3.16xlarge):
# distribution = {'smdistributed': {'dataparallel': {'enabled': True}}}
# instance_type = 'ml.p3.16xlarge'
# training_batch_size = 4
# eval_batch_size = 2

# Single-instance training (active):
distribution, instance_type = None, 'ml.p3.2xlarge'
training_batch_size, eval_batch_size = 16, 8

In [5]:
from sagemaker.huggingface import HuggingFace

# Unix timestamp (whole seconds) used as a unique id to tag the training job,
# model name and endpoint name.
# NOTE(review): `id` shadows the built-in id(); kept because later cells may refer to it.
id = int(time.time())

TRAINING_JOB_NAME = "whisper-zhtw-{}".format(id)
print('Training job name: ', TRAINING_JOB_NAME)

# Hyperparameters handed to the HuggingFace estimator for the train.py entry point.
# The batch sizes come from the instance-configuration cell.
hyperparameters = dict(
    max_steps=16000,  # you can increase the max steps to improve model accuracy
    train_batch_size=training_batch_size,
    eval_batch_size=eval_batch_size,
    model_name="openai/whisper-small",
    language="Chinese",
    # NOTE(review): 16 workers may exceed the vCPU count of ml.p3.2xlarge
    # (8 vCPUs per AWS instance specs) — confirm this is intended.
    dataloader_num_workers=16,
)

# Metric definitions: SageMaker scrapes these values from the training log.
# The value pattern matches integers, decimals and scientific notation
# (e.g. "27", "0.2786", "1.2e-05"). The "." is escaped so it only matches a
# decimal point, and raw strings avoid invalid "\-"-style string escapes.
# The outermost parentheses are capture group 1, i.e. the full numeric value.
_METRIC_VALUE = r"([0-9]+(\.[0-9]+)?(e[-+]?[0-9]+)?)"
metric_definitions = [
    {'Name': name, 'Regex': rf"'{name}': {_METRIC_VALUE},?"}
    for name in ('eval_loss', 'eval_wer', 'eval_runtime', 'eval_samples_per_second', 'epoch')
]

Training job name:  whisper-zhtw-1675236655


In [6]:
from sagemaker.inputs import TrainingInput
# Training channel: the processed Common Voice dataset stored under PREFIX.
# Reuse PREFIX instead of repeating the literal path so the two cannot drift apart.
# With S3Prefix, every object whose key starts with this prefix is part of the
# channel, so nothing else (e.g. job outputs) should be written beneath it.
training_input_path = f"s3://{BUCKET}/{PREFIX}"

training = TrainingInput(
    s3_data_type='S3Prefix',  # Available Options: S3Prefix | ManifestFile | AugmentedManifestFile
    s3_data=training_input_path,
    distribution='FullyReplicated',  # Available Options: FullyReplicated | ShardedByS3Key
    input_mode='FastFile',  # objects are streamed from S3 on access (see TrainingInput docs)
)

In [None]:
# Job artifacts go to a prefix that is NOT under the training-data PREFIX: the
# 'train' channel uses S3Prefix, which includes every key starting with PREFIX,
# so artifacts stored beneath it would be pulled into later jobs' training input.
# SageMaker appends "<job-name>/output/" under output_path itself, so the job
# name does not need to be part of this path.
OUTPUT_PATH = f"s3://{BUCKET}/whisper/output/"

huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    output_path=OUTPUT_PATH,
    instance_type=instance_type,
    instance_count=1,
    transformers_version='4.17.0',
    pytorch_version='1.10.2',
    py_version='py38',
    role=ROLE,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    volume_size=200,  # storage volume size in GB
    distribution=distribution,
)

# Start the training job; fit() streams the job's logs into this cell while it runs.
# In the single-instance configuration (ml.p3.2xlarge, max_steps=16000) the logged
# progress is ~2.1 s/step, i.e. roughly 9-10 hours of training.
huggingface_estimator.fit({'train': training}, job_name=TRAINING_JOB_NAME)

2023-02-01 07:32:35 Starting - Starting the training job...
2023-02-01 07:33:01 Starting - Preparing the instances for trainingProfilerReport-1675236755: InProgress
.........
2023-02-01 07:34:35 Downloading - Downloading input data
2023-02-01 07:34:35 Training - Downloading the training image...............
1%|          | 180/16000 [06:40<10:06:20,  2.30s/it]
1%|          | 181/16000 [06:42<9:33:37,  2.18s/it]
1%|          | 182/16000 [06:44<9:22:05,  2.13s/it]
1%|          | 183/16000 [06:46<9:19:20,  2.12s/it]
1%|          | 184/16000 [06:48<9:20:16,  2.13s/it]
1%|          | 185/16000 [06:50<9:02:52,  2.06s/it]
1%|          | 186/16000 [06:52<9:06:40,  2.07s/it]
1%|          | 187/16000 [06:55<9:17:18,  2.11s/it]
1%|          | 188/16000 [06:57<9:21:57,  2.13s/it]
1%|          | 189/16000 [06:59<9:43:41,  2.22s/it]
1%|          | 190/16000 [07:01<9:37:27,  2.19s/it]
[34m1%|          | 