# YOLOv5 on SageMaker: Training

## 1 Overview
This chapter trains the model with SageMaker, using data stored in S3.

## 2 Runtime environment
Select the pytorch_latest_p36 kernel.  
This notebook was tested with boto3 1.17.12 and sagemaker 2.26.0.

In [None]:
import boto3,sagemaker
print(boto3.__version__)
print(sagemaker.__version__)

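## 3 Get the image

The image for this project has already been built and stored in ECR, so it can be deployed to SageMaker directly. Select the appropriate version.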

In [None]:
tag = "v3.1"

In [None]:
import boto3
region = boto3.session.Session().region_name
image_uri = '048912060910.dkr.ecr.{}.amazonaws.com.cn/nwcd/yolov5-training:{}'.format(region,tag)
image_uri
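
The URI assembled above follows the ECR pattern `<account>.dkr.ecr.<region>.<domain>/<repository>:<tag>`. As a minimal sketch, the same string can be built by a small helper that makes no AWS call. The function name `build_image_uri` is hypothetical; the account ID and repository are the ones from the cell above. The `amazonaws.com.cn` domain is the one used by the AWS China regions, so the sketch assumes the notebook runs in one of them.

```python
ACCOUNT_ID = "048912060910"          # account hosting the prebuilt image (from the cell above)
REPOSITORY = "nwcd/yolov5-training"  # repository name (from the cell above)


def build_image_uri(region: str, tag: str) -> str:
    """Return the ECR image URI for an AWS China region and an image tag."""
    if not region.startswith("cn-"):
        # Assumption: the image is only reachable under the amazonaws.com.cn domain.
        raise ValueError("expected an AWS China region such as cn-northwest-1")
    return "{}.dkr.ecr.{}.amazonaws.com.cn/{}:{}".format(
        ACCOUNT_ID, region, REPOSITORY, tag
    )


print(build_image_uri("cn-northwest-1", "v3.1"))
```

Failing early on an unexpected region makes a wrong URI visible before the training job is submitted.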

## 4 Train on SageMaker

In [None]:
# Set the S3 bucket and key prefix where the data is stored
bucket = 'junzhong'
pre_key = 'yolov5'

In [None]:
training_uri='s3://{}/{}/training/'.format(bucket, pre_key)
outpath='s3://{}/{}/results/'.format(bucket, pre_key)
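
The two URIs above are plain string formatting. As a sketch only, a helper can join the parts and normalize stray slashes so that a prefix such as `yolov5/` does not produce a double slash. The name `s3_uri` and the `example_` variables are hypothetical and do not replace the variables defined above.

```python
def s3_uri(bucket: str, *parts: str) -> str:
    """Join a bucket and key parts into an s3:// URI that ends with a slash."""
    # Drop empty parts and surrounding slashes before joining.
    cleaned = [part.strip("/") for part in parts if part.strip("/")]
    key = "/".join(cleaned)
    if key:
        return "s3://{}/{}/".format(bucket, key)
    return "s3://{}/".format(bucket)


example_training_uri = s3_uri("junzhong", "yolov5/", "training")
example_outpath = s3_uri("junzhong", "yolov5", "results")
print(example_training_uri)
print(example_outpath)
```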

In [None]:
import sagemaker,boto3

iam = boto3.client('iam')
roles = iam.list_roles(PathPrefix='/service-role')
role=""
for current_role in roles["Roles"]:
    if current_role["RoleName"].startswith("AmazonSageMaker-ExecutionRole-"):
        role=current_role["Arn"]
        break
# An empty role means no matching execution role was found
print(role)
sm = boto3.client('sagemaker')
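
The role lookup above takes the first role whose name starts with `AmazonSageMaker-ExecutionRole-`. A minimal sketch of the same selection as a pure function follows, so that the "no role found" case can be handled explicitly. The function name `find_sagemaker_role` and the sample ARNs are hypothetical; the input is assumed to be shaped like the `Roles` list returned by `iam.list_roles`.

```python
from typing import Dict, List, Optional


def find_sagemaker_role(
    roles: List[Dict[str, str]],
    prefix: str = "AmazonSageMaker-ExecutionRole-",
) -> Optional[str]:
    """Return the ARN of the first role whose name starts with prefix, or None."""
    for candidate in roles:
        if candidate["RoleName"].startswith(prefix):
            return candidate["Arn"]
    return None


# Hypothetical sample data, not taken from a real account.
sample_roles = [
    {
        "RoleName": "SomeOtherRole",
        "Arn": "arn:aws-cn:iam::123456789012:role/service-role/SomeOtherRole",
    },
    {
        "RoleName": "AmazonSageMaker-ExecutionRole-20210101T000000",
        "Arn": "arn:aws-cn:iam::123456789012:role/service-role/"
               "AmazonSageMaker-ExecutionRole-20210101T000000",
    },
]

role_arn = find_sagemaker_role(sample_roles)
if role_arn is None:
    raise RuntimeError("no SageMaker execution role found")
print(role_arn)
```

In the notebook, the call would be `find_sagemaker_role(roles["Roles"])`.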

In [None]:
# Whether to use spot instances for training
use_spot = True

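The YOLOv5 hyperparameters are in `hyp.yaml` under the `container/local_test/input/data/training/cfg/` directory. If you need to change them, edit the file first. After each change, sync the directory to S3 again.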

In [None]:
!aws s3 sync container/local_test/input/data/training/ s3://{bucket}/{pre_key}/training/

In [None]:
from datetime import datetime
now = datetime.now()
job_name = 'yolov5-' + now.strftime("%Y-%m-%d-%H-%M-%S")
job_name
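
The timestamp suffix keeps each job name unique. As a sketch, the name can also be checked before submission against the constraint documented for `TrainingJobName` in the `CreateTrainingJob` API (at most 63 characters, letters, digits and hyphens only). The helper name `make_job_name` is hypothetical.

```python
import re
from datetime import datetime

# Pattern documented for TrainingJobName; the length is checked separately.
JOB_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def make_job_name(prefix: str, when: datetime) -> str:
    """Build '<prefix>-<timestamp>' and reject names SageMaker would refuse."""
    name = prefix + "-" + when.strftime("%Y-%m-%d-%H-%M-%S")
    if len(name) > 63 or not JOB_NAME_PATTERN.match(name):
        raise ValueError("invalid training job name: " + name)
    return name


example_name = make_job_name("yolov5", datetime(2021, 3, 1, 12, 30, 45))
print(example_name)
```

A prefix containing an underscore, such as `yolo_v5`, is rejected by this check.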

In [None]:
response = sm.create_training_job(
      TrainingJobName=job_name,
      HyperParameters={
          'img':"640",
          'batch':"16",
          'epochs':"3",
          'hyp':"/opt/ml/input/data/training/cfg/hyp.yaml",
          'data':"/opt/ml/input/data/training/cfg/data.yaml",
          'cfg':"/opt/ml/input/data/training/cfg/yolov5s.yaml",
          'weights':"/opt/ml/input/data/training/weights/yolov5s.pt"
      },
      AlgorithmSpecification={
          'TrainingImage': image_uri,
          'TrainingInputMode': 'File',
      },
      RoleArn=role,
      InputDataConfig=[
          {
              'ChannelName': 'training',
              'DataSource': {
                  'S3DataSource': {
                      'S3DataType': 'S3Prefix',
                      'S3Uri': training_uri,
                      'S3DataDistributionType': 'FullyReplicated',
                  },
              },
              'InputMode': 'File'
          }
      ],
      OutputDataConfig={
          'S3OutputPath': outpath
      },
      ResourceConfig={
          'InstanceType': 'ml.p3.2xlarge',
          'InstanceCount': 1,
          'VolumeSizeInGB': 100,
      },
      EnableManagedSpotTraining=use_spot,
      StoppingCondition={"MaxWaitTimeInSeconds": 3600,"MaxRuntimeInSeconds": 3600} if use_spot else {"MaxRuntimeInSeconds": 3600}
  )
response
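
The `StoppingCondition` in the call above depends on `use_spot`: managed spot training also needs `MaxWaitTimeInSeconds`, which must be at least `MaxRuntimeInSeconds`. A minimal sketch of that logic as a helper follows; the function name `build_stopping_condition` is hypothetical and the 3600-second defaults are taken from the cell above.

```python
from typing import Dict


def build_stopping_condition(
    use_spot: bool, max_runtime: int = 3600, max_wait: int = 3600
) -> Dict[str, int]:
    """Return a StoppingCondition dict for on-demand or managed spot training."""
    condition = {"MaxRuntimeInSeconds": max_runtime}
    if use_spot:
        # The wait time covers both waiting for spot capacity and the run itself.
        if max_wait < max_runtime:
            raise ValueError("MaxWaitTimeInSeconds must be >= MaxRuntimeInSeconds")
        condition["MaxWaitTimeInSeconds"] = max_wait
    return condition


print(build_stopping_condition(use_spot=True))
print(build_stopping_condition(use_spot=False))
```

The result could be passed as `StoppingCondition=build_stopping_condition(use_spot)` in place of the inline conditional expression.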

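Check the job status below, or look it up in the SageMaker console. With the data provided in this workshop, training takes about 15 minutes.  
The status is polled every 120 seconds, so the reported result can lag by up to 2 minutes.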

In [None]:
status = sm.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))

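from botocore.exceptions import WaiterError

try:
    sm.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name)
    training_info = sm.describe_training_job(TrainingJobName=job_name)
    status = training_info['TrainingJobStatus']
    print("Training job ended with status: " + status)
except WaiterError:
    # The waiter raises when the job fails or when the wait times out
    training_info = sm.describe_training_job(TrainingJobName=job_name)
    print('Training did not complete, status: ' + training_info['TrainingJobStatus'])
    # FailureReason is only present for failed jobs
    message = training_info.get('FailureReason', 'no failure reason reported')
    print('Training failed with the following error: {}'.format(message))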
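
To make the polling behavior concrete, here is a minimal sketch of a loop that does roughly what the waiter does: call a describe function, stop on a terminal status, otherwise sleep and retry. It is not the boto3 implementation. The name `wait_for_training_job` is hypothetical, the 120-second delay matches the text above, and the limit of 180 attempts is an assumption. The describe call is passed in as a function so the sketch runs without AWS access.

```python
import time
from typing import Callable, Dict

TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(
    describe: Callable[[], Dict[str, str]],
    delay: float = 120,
    max_attempts: int = 180,  # assumption, not read from boto3
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll describe() until the job reaches a terminal status and return it."""
    for _ in range(max_attempts):
        status = describe()["TrainingJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        sleep(delay)
    raise TimeoutError("training job did not finish in time")


# Demo with a fake describe function and no real sleeping.
fake_statuses = iter(["InProgress", "InProgress", "Completed"])
final_status = wait_for_training_job(
    lambda: {"TrainingJobStatus": next(fake_statuses)},
    sleep=lambda seconds: None,
)
print(final_status)
```

In the notebook, `describe` would be `lambda: sm.describe_training_job(TrainingJobName=job_name)`.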

If you see

> `Training job ended with status: Completed`

the training completed successfully.

## 5 Download the training results

Copy the `model_data` value output by the code below; it is needed for inference.

In [None]:
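response = sm.describe_training_job(TrainingJobName=job_name)
model_data = response['ModelArtifacts']['S3ModelArtifacts']
# Show the S3 location of the model artifact and save it for the inference notebook
print(model_data)
!echo -n $model_data > model_data.txt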

In [None]:
!aws s3 cp {model_data} model.tar.gz

In [None]:
!tar -xvf model.tar.gz
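
The archive can also be unpacked from Python instead of the shell, which is useful when the step runs outside a notebook. This is a sketch using the standard `tarfile` module; the function name `extract_model` is hypothetical, and the demo builds a tiny placeholder archive because the real contents of `model.tar.gz` depend on the training image.

```python
import os
import tarfile
import tempfile
from typing import List


def extract_model(archive_path: str, target_dir: str) -> List[str]:
    """Extract a (possibly compressed) tar archive and return the member names."""
    os.makedirs(target_dir, exist_ok=True)
    # "r:*" lets tarfile detect the compression, as `tar -xvf` does.
    with tarfile.open(archive_path, "r:*") as archive:
        names = archive.getnames()
        archive.extractall(target_dir)
    return names


# Demo with a placeholder archive in a temporary directory.
with tempfile.TemporaryDirectory() as workdir:
    sample = os.path.join(workdir, "placeholder.txt")
    with open(sample, "w") as handle:
        handle.write("demo")
    demo_archive = os.path.join(workdir, "model.tar.gz")
    with tarfile.open(demo_archive, "w:gz") as archive:
        archive.add(sample, arcname="placeholder.txt")
    out_dir = os.path.join(workdir, "out")
    extracted = extract_model(demo_archive, out_dir)
    extracted_exists = os.path.exists(os.path.join(out_dir, "placeholder.txt"))

print(extracted)
```

Only extract archives from a trusted source, such as the output of your own training job.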