# YOLOv5 on SageMaker: Training

## Overview
This chapter builds a custom container image, pushes it to AWS ECR, and then launches it with SageMaker to run training on data stored in S3.

## Runtime environment
Select the `pytorch_latest_p36` kernel.  
This notebook was tested with boto3 1.15.16 and sagemaker 2.15.0.

In [None]:
import boto3,sagemaker
print(boto3.__version__)
print(sagemaker.__version__)

## Amazon Deep Learning Containers

* [List of available container images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)
* This notebook builds on the PyTorch training image: `727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.6.0-gpu-py36-cu101-ubuntu16.04`

## Build the custom training image

In [None]:
import boto3
region = boto3.session.Session().region_name
account_id = boto3.client('sts').get_caller_identity().get('Account')
ecr_repository = 'yolov5-training'
tag = ':latest'
uri_suffix = 'amazonaws.com'
if region in ['cn-north-1', 'cn-northwest-1']:
    uri_suffix = 'amazonaws.com.cn'
image_uri = '{}.dkr.ecr.{}.{}/{}'.format(account_id, region, uri_suffix, ecr_repository + tag)
print(image_uri)
ecr = '{}.dkr.ecr.{}.{}'.format(account_id, region, uri_suffix)
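The URI logic in the cell above can also be wrapped in a small helper, which makes it easier to reuse for other repositories or tags. The following is a minimal sketch, not part of the original notebook: `build_image_uri` and `CHINA_REGIONS` are hypothetical names, and the account ID in the example is a placeholder.

```python
# Regions that use the amazonaws.com.cn domain suffix (same list as the cell above).
CHINA_REGIONS = {"cn-north-1", "cn-northwest-1"}


def build_image_uri(account_id, region, repository, tag="latest"):
    """Return (registry, image_uri) for an ECR repository.

    Hypothetical helper: it mirrors the string formatting used in the
    notebook cell and makes no AWS calls.
    """
    suffix = "amazonaws.com.cn" if region in CHINA_REGIONS else "amazonaws.com"
    registry = "{}.dkr.ecr.{}.{}".format(account_id, region, suffix)
    return registry, "{}/{}:{}".format(registry, repository, tag)


# Placeholder account ID, for illustration only.
registry, uri = build_image_uri("123456789012", "cn-northwest-1", "yolov5-training")
print(registry)
print(uri)
```

The `registry` value corresponds to the `ecr` variable in the notebook, and `uri` corresponds to `image_uri`.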

In [None]:
# Base PyTorch training image hosted in the AWS China regions; do not modify
base_img='727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.6.0-gpu-py36-cu101-ubuntu16.04'
# Log in to the ECR registry that hosts the base image; do not modify
!aws ecr get-login-password --region cn-northwest-1 | docker login --username AWS --password-stdin 727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn

In [None]:
!aws ecr create-repository --repository-name $ecr_repository

In [None]:
%%time
%cd container
!docker build -t $ecr_repository -f Dockerfile --build-arg BASE_IMG=$base_img .
%cd ../

In [None]:
!docker tag $ecr_repository $image_uri
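# Log in to your own registry; `aws ecr get-login` is not available in AWS CLI v2, so use get-login-password
!aws ecr get-login-password --region $region | docker login --username AWS --password-stdin $ecr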
!docker push $image_uri

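## Train locally with the container (optional)
If the local machine has a GPU, use `nvidia-docker run`; otherwise use `docker run`. An instance size of 2xlarge or larger is recommended, because smaller instances may not have enough memory.  
The trained model is written to `container/local_test/model/runs/exp0/weights`.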

In [None]:
# !nvidia-docker run -v $(pwd)/container/local_test/:/opt/ml/ --shm-size=8g --rm $ecr_repository train
!docker run -v $(pwd)/container/local_test/:/opt/ml/ --shm-size=8g --rm $ecr_repository train

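## Train on SageMaker
The YOLOv5 hyperparameters are defined in `hyp.yaml` and `train-args.json` under `container/local_test/input/data/training/cfg/`. If you need to change them, do so before running the cells below, because that directory is what gets uploaded to S3.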

In [None]:
# Set the S3 bucket and key prefix where the data is stored
bucket = 'junzhong'
pre_key = 'yolov5'

In [None]:
!aws s3 sync container/local_test/input/data/training/ s3://{bucket}/{pre_key}/training/

In [None]:
from datetime import datetime
now = datetime.now()
job_name = 'yolov5-' + now.strftime("%Y-%m-%d-%H-%M-%S")
job_name
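SageMaker requires training job names to be unique within a region and account, at most 63 characters long, and made up of alphanumerics and hyphens only. The timestamp suffix above handles uniqueness. As a minimal sketch, the other constraints can be checked before the job is submitted. `make_job_name` is a hypothetical helper, not part of the original notebook.

```python
import re
from datetime import datetime

# Pattern documented for TrainingJobName; the length limit is checked separately.
JOB_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*$")
MAX_JOB_NAME_LEN = 63


def make_job_name(prefix, now=None):
    """Build '<prefix>-<timestamp>' and validate it locally (hypothetical helper)."""
    now = now or datetime.now()
    name = "{}-{}".format(prefix, now.strftime("%Y-%m-%d-%H-%M-%S"))
    if len(name) > MAX_JOB_NAME_LEN or not JOB_NAME_RE.match(name):
        raise ValueError("Invalid training job name: {!r}".format(name))
    return name


print(make_job_name("yolov5"))
```

A prefix containing an underscore, such as `yolo_v5`, raises `ValueError` here rather than failing later inside `create_training_job`.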

In [None]:
training_uri='s3://{}/{}/training/'.format(bucket, pre_key)
outpath='s3://{}/{}/results/'.format(bucket, pre_key)

In [None]:
import sagemaker,boto3

iam = boto3.client('iam')
roles = iam.list_roles(PathPrefix='/service-role')
role=""
for current_role in roles["Roles"]:
    if current_role["RoleName"].startswith("AmazonSageMaker-ExecutionRole-"):
        role=current_role["Arn"]
        break
# An empty role means no matching SageMaker execution role was found
print(role)
sm = boto3.client('sagemaker')

If you do not want to use Spot instances, set the `EnableManagedSpotTraining` parameter below to `False` and remove `MaxWaitTimeInSeconds` together with its value.

In [None]:
response = sm.create_training_job(
      TrainingJobName=job_name,
      AlgorithmSpecification={
          'TrainingImage': image_uri,
          'TrainingInputMode': 'File',
      },
      RoleArn=role,
      InputDataConfig=[
          {
              'ChannelName': 'training',
              'DataSource': {
                  'S3DataSource': {
                      'S3DataType': 'S3Prefix',
                      'S3Uri': training_uri,
                      'S3DataDistributionType': 'FullyReplicated',
                  },
              },
              'InputMode': 'File'
          }
      ],
      OutputDataConfig={
          'S3OutputPath': outpath
      },
      ResourceConfig={
          'InstanceType': 'ml.p3.2xlarge',
          'InstanceCount': 1,
          'VolumeSizeInGB': 100,
      },
      EnableManagedSpotTraining=True,
      StoppingCondition={
        "MaxRuntimeInSeconds": 360000,
        "MaxWaitTimeInSeconds": 360000
      }
  )
response
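Switching between Spot and on-demand training changes two arguments together: `EnableManagedSpotTraining` and the contents of `StoppingCondition`. One way to keep them consistent is to build both from a single flag. This is a sketch under that assumption: `spot_settings` is a hypothetical helper, and the default runtime simply reuses the 360000 seconds from the cell above.

```python
def spot_settings(use_spot, max_runtime=360000, max_wait=None):
    """Return the Spot-related keyword arguments for create_training_job.

    Hypothetical helper. MaxWaitTimeInSeconds is only included for Spot
    training, and it must not be smaller than MaxRuntimeInSeconds.
    """
    stopping = {"MaxRuntimeInSeconds": max_runtime}
    if use_spot:
        if max_wait is None:
            max_wait = max_runtime
        if max_wait < max_runtime:
            raise ValueError("max_wait must be >= max_runtime")
        stopping["MaxWaitTimeInSeconds"] = max_wait
    return {"EnableManagedSpotTraining": use_spot, "StoppingCondition": stopping}


print(spot_settings(True))
print(spot_settings(False))
```

The result can be unpacked into the call, for example `sm.create_training_job(..., **spot_settings(False))`, in place of the explicit `EnableManagedSpotTraining` and `StoppingCondition` arguments.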

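Check the job status here or in the SageMaker console. With the data provided in this workshop, training takes roughly 15 minutes.  
The waiter polls the status every 120 seconds, so the result shown here may lag by up to 2 minutes.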

In [None]:
status = sm.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))

try:
    sm.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name)
    training_info = sm.describe_training_job(TrainingJobName=job_name)
    status = training_info['TrainingJobStatus']
    print("Training job ended with status: " + status)
except Exception as err:
    # The waiter raises an error if the job fails or polling times out
    print('Waiting for the training job failed: {}'.format(err))
    # FailureReason is only present when the job has failed
    message = sm.describe_training_job(TrainingJobName=job_name).get('FailureReason', 'not reported')
    print('Training failed with the following error: {}'.format(message))

If you see

> `Training job ended with status: Completed`

the training job has finished successfully.
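Beyond the final status, the `describe_training_job` response carries a few other fields that are useful after a run: `FailureReason` for failed jobs, and `TrainingTimeInSeconds` and `BillableTimeInSeconds`, from which Spot savings can be estimated. The sketch below condenses them into one line. `summarize_training_job` is a hypothetical helper, and the example dictionary is made-up data rather than a real response.

```python
def summarize_training_job(desc):
    """Condense a describe_training_job response into one line (hypothetical helper)."""
    status = desc["TrainingJobStatus"]
    parts = ["status={}".format(status)]
    if status == "Failed":
        # FailureReason is only present for failed jobs.
        parts.append("reason={}".format(desc.get("FailureReason", "not reported")))
    train_s = desc.get("TrainingTimeInSeconds")
    bill_s = desc.get("BillableTimeInSeconds")
    if train_s and bill_s is not None:
        parts.append("training_seconds={}".format(train_s))
        parts.append("billable_seconds={}".format(bill_s))
        # With managed Spot training, billable time can be lower than training time.
        parts.append("spot_savings={:.0f}%".format((1 - bill_s / train_s) * 100))
    return ", ".join(parts)


# Made-up values, for illustration only.
example = {
    "TrainingJobStatus": "Completed",
    "TrainingTimeInSeconds": 900,
    "BillableTimeInSeconds": 270,
}
print(summarize_training_job(example))
```

In the notebook this would be called as `summarize_training_job(sm.describe_training_job(TrainingJobName=job_name))`.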

### Download the training results

In [None]:
response = sm.describe_training_job(TrainingJobName=job_name)
model_url = response['ModelArtifacts']['S3ModelArtifacts']
model_url

In [None]:
!aws s3 cp {model_url} ./model.tar.gz
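The same download can be done from Python with boto3, which avoids depending on the AWS CLI inside the kernel. This is a minimal sketch: `parse_s3_uri` and `download_s3_uri` are hypothetical helpers, and the URI in the example is a placeholder that follows the `<S3OutputPath>/<job-name>/output/model.tar.gz` layout SageMaker uses for model artifacts.

```python
from urllib.parse import urlparse


def parse_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key). Hypothetical helper."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError("Not an S3 URI: {!r}".format(uri))
    return parsed.netloc, parsed.path.lstrip("/")


def download_s3_uri(uri, filename):
    """Download the object at `uri` to `filename` using boto3."""
    import boto3  # imported here so parse_s3_uri stays usable without boto3

    bucket, key = parse_s3_uri(uri)
    boto3.client("s3").download_file(bucket, key, filename)


# Placeholder URI, for illustration only.
print(parse_s3_uri("s3://my-bucket/yolov5/results/yolov5-job/output/model.tar.gz"))
```

In the notebook, `download_s3_uri(model_url, "model.tar.gz")` would take the place of the `aws s3 cp` command.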

In [None]:
!tar -xvf ./model.tar.gz

### Copy the weights to the inference directory

In [None]:
!cp runs/exp0/weights/best.pt ../2-inference/source/yolov5s.pt
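The `cp` command above assumes the weights end up in `runs/exp0`. If the extracted archive contains a different or additional experiment directory, the path has to be adjusted by hand. As a hedged alternative, the sketch below searches for the most recently modified `best.pt` under `runs/`. `find_best_weights` is a hypothetical helper, and the `runs/<experiment>/weights/best.pt` layout is inferred from the path used in this notebook.

```python
from pathlib import Path


def find_best_weights(root="runs"):
    """Return the newest '<root>/<experiment>/weights/best.pt' (hypothetical helper)."""
    candidates = sorted(
        Path(root).glob("*/weights/best.pt"),
        key=lambda p: p.stat().st_mtime,  # oldest first, newest last
    )
    if not candidates:
        raise FileNotFoundError("No best.pt found under {}".format(root))
    return candidates[-1]


# Example usage in the notebook, after extracting model.tar.gz:
# import shutil
# shutil.copy(str(find_best_weights("runs")), "../2-inference/source/yolov5s.pt")
```

Picking by modification time is only a heuristic; if several experiments are present, check that the selected directory is the one you intend to deploy.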