In [8]:
%%time
import os
import json

import sagemaker
import boto3
from sagemaker.pytorch import PyTorch
from sagemaker import get_execution_role
import torch
from sagemaker.utils import unique_name_from_base

# SageMaker session, its default S3 bucket, and the notebook's execution role;
# later cells reuse these names.
sess = sagemaker.Session()
bucket, role = sess.default_bucket(), get_execution_role()

# Common S3 prefix for this experiment.
prefix = 'wenet2x'
output_path = f"s3://{bucket}/{prefix}"

# Record the library versions and account-level settings this run used.
_env_report = (
    "torch.__version__:{}".format(torch.__version__),
    "boto3.__version__:{}".format(boto3.__version__),
    "sagemaker.__version__:{}".format(sagemaker.__version__),
    "bucket:{}".format(bucket),
    "role:{}".format(role),
)
for _line in _env_report:
    print(_line)

torch.__version__:1.12.1+cpu
boto3.__version__:1.24.84
sagemaker.__version__:2.117.0
bucket:sagemaker-us-east-1-348052051973
role:arn:aws:iam::348052051973:role/service-role/AmazonSageMakerServiceCatalogProductsExecutionRole
CPU times: user 136 ms, sys: 28.1 ms, total: 164 ms
Wall time: 280 ms


In [9]:
%%markdown
Copy wenet/examples/librispeech/s0/*.sh and the wenet/examples/librispeech/s0/local directory into wenet/, as required by SageMaker.
Overwrite wenet/wenet/bin/train.py with the provided one.
Change /root/wenet to /opt/ml/input in all data.list files (especially those for train_960 and dev).

If you are cloning from GitHub:
The "Librispeech" entries in the data.list files on GitHub have the wrong capitalization (it was already wrong when the files were uploaded). Please correct it yourself!

Copy the wenet/examples/librispeech/s0/*.sh and wenet/examples/librispeech/s0/local to wenet/ as requested by Sagemaker
Overwrite the wenet/wenet/bin/train.py with the given one
Change the /root/wenet to /opt/ml/input in all data.list files (especially for train_960 and dev)

If you are cloning from Github:
The "Librispeech" in data.list file in Github has the wrong captalization because it's wrong when I upload it. Please change it yourself!


In [10]:
from sagemaker.inputs import TrainingInput
# S3 location of the exported dataset that feeds the training channel.
prefix_dataset = "wenet/export"
loc = f"s3://{bucket}/{prefix_dataset}"
print(loc)

# Input channel definition.
#   s3_data_type options: S3Prefix | ManifestFile | AugmentedManifestFile
#   distribution options: FullyReplicated | ShardedByS3Key
training = TrainingInput(
    s3_data=loc,
    s3_data_type='S3Prefix',
    input_mode='FastFile',
    distribution='FullyReplicated',
)



s3://sagemaker-us-east-1-348052051973/wenet/export


In [11]:
%%time
# Instance type for the training job (swap in 'local' to try the job locally).
instance_type = "ml.p3.2xlarge"
# instance_type='local'

# Upper bound on job runtime, in seconds: 5 days == 432000 s.
max_run = 5 * 24 * 60 * 60
checkpoint_s3_uri = f"s3://{bucket}/{prefix}/checkpoints"

# Arguments forwarded to the entry-point script; all values are passed as strings.
hyperparameters = dict(
    datadir='/opt/ml/input/data/training',
    stage='4',
    stop_stage='5',
    train_config='examples/librispeech/s0/conf/train_conformer.yaml',
    model_dir='/opt/ml/model',
    checkpoint_dir='/opt/ml/checkpoints',
    output_dir='/opt/ml/output/data',
)

# NOTE(review): `output_path` defined in the setup cell is not passed here, so the
# estimator uses the SDK's default output location — confirm this is intended.
est = PyTorch(
    # Code to run: ./wenet is uploaded as the source directory and run.sh is its entry point.
    source_dir="./wenet",
    entry_point="run.sh",
    # Custom training image; framework_version is left out because image_uri is given.
    image_uri="348052051973.dkr.ecr.us-east-1.amazonaws.com/sagemaker-pytorch110:3",
    # framework_version="1.11.0",
    py_version="py38",
    # Compute resources.
    role=role,
    instance_type=instance_type,
    instance_count=1,
    volume_size=200,
    max_run=max_run,
    # keep_alive_period_in_seconds=1800,
    # Profiler and debugger hooks are switched off for this job.
    disable_profiler=True,
    debugger_hook_config=False,
    # Job naming, script arguments and checkpoint location.
    base_job_name=prefix,
    hyperparameters=hyperparameters,
    checkpoint_s3_uri=checkpoint_s3_uri,
    tags=[
        {"Key": tag_key, "Value": tag_value}
        for tag_key, tag_value in (
            ("team", "asr"),
            ("person", "andrew"),
            ("project", "abc"),
        )
    ],
)


CPU times: user 149 ms, sys: 16.1 ms, total: 165 ms
Wall time: 164 ms


In [None]:
%%time
# Launch the training job, feeding the `training` input through the "training" channel.
# (To launch without an input channel, call est.fit() instead.)
est.fit({"training": training})

# Estimator.fit() returns None, so assigning its result left job_name as None.
# The launched job is recorded on the estimator; read the real job name from it.
job_name = est.latest_training_job.name

2022-12-08 04:30:25 Starting - Starting the training job...
2022-12-08 04:30:55 Starting - Preparing the instances for training.........
2022-12-08 04:32:12 Downloading - Downloading input data......
2022-12-08 04:33:02 Training - Downloading the training image..........[34m2022-12-08 05:23:16,997 DEBUG TRAIN Batch 0/1400 loss 304.129456 loss_att 283.731781 loss_ctc 351.723999 lr 0.00022416 rank 0[0m


In [None]:
# Read the model artifact location from the estimator and print it.
# NOTE(review): presumably only meaningful after the est.fit(...) cell has completed — confirm.
model_data = est.model_data
print("Model artifact saved at:\n", model_data)