# 训练

## 1 说明
本章内容为用SageMaker进行训练，数据来自S3。

## 2 运行环境
Kernel 选择Python 3。  
本文在boto3 1.17.17下测试通过。

In [None]:
import boto3

# Show the installed boto3 version (this notebook was tested with 1.17.17).
print('{}'.format(boto3.__version__))

## 3 获取image

本项目的image已构建完毕并存放在ECR中，可直接用于训练。出于训练成本考虑，该image仅存放在cn-northwest-1（宁夏）区域，因此训练任务也需要在该区域中创建。

In [None]:
# Prebuilt OCR text-recognition training image (tag "rec"), hosted in ECR in cn-northwest-1.
image_uri = (
    '048912060910.dkr.ecr.cn-northwest-1.amazonaws.com.cn'
    '/nwcd/ocr-training:rec'
)

## 4 在SageMaker上训练

In [None]:
# S3 bucket and key prefix holding the training data; change these to point
# at your own data location.
bucket, pre_key = 'junzhong', 'data/ocr/Chinese'

In [None]:
# S3 locations used by the training job (editable): training/validation data,
# checkpoint prefix, and output prefix for the model artifacts.
training_uri = f's3://{bucket}/{pre_key}/training.zip'
validation_uri = f's3://{bucket}/{pre_key}/validation.zip'
checkpoint = f's3://{bucket}/{pre_key}/checkpoint/'
outpath = f's3://{bucket}/{pre_key}/results/'

In [None]:
import boto3

# Look up the SageMaker execution role that the training job will run as.
# IAM ListRoles is paginated (a single call returns at most 100 roles by
# default), so walk every page; one call can miss the role in large accounts.
iam = boto3.client('iam')
role_pages = iam.get_paginator('list_roles').paginate(PathPrefix='/service-role')
role = next(
    (
        current_role['Arn']
        for page in role_pages
        for current_role in page['Roles']
        if current_role['RoleName'].startswith('AmazonSageMaker-ExecutionRole-')
    ),
    '',
)
# An empty role means no matching execution role was found.
print(role)
if not role:
    print('WARNING: no AmazonSageMaker-ExecutionRole-* role found under /service-role; '
          'set `role` to a SageMaker execution role ARN before creating the training job.')
sm = boto3.client('sagemaker')

In [None]:
# Whether to train on managed spot instances. This flag also selects which
# StoppingCondition is passed when the training job is created.
use_spot = True

In [None]:
from datetime import datetime

# Timestamp the job name so every run creates a distinct training job.
now = datetime.now()
job_name = 'ocr-{:%Y-%m-%d-%H-%M-%S}'.format(now)
job_name

In [None]:
# Launch the SageMaker training job. Both data channels are S3 prefixes copied
# into the container in File mode. The checkpoint prefix is synced with the
# job, and the model artifacts are written under `outpath`.
response = sm.create_training_job(
    TrainingJobName=job_name,
    HyperParameters={
        'gpu': '0',
        'Global.epoch_num': '1',
        'Train.loader.batch_size_per_card': '128',
    },
    AlgorithmSpecification={
        'TrainingImage': image_uri,
        'TrainingInputMode': 'File',
    },
    RoleArn=role,
    InputDataConfig=[
        {
            'ChannelName': channel_name,
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': channel_uri,
                    'S3DataDistributionType': 'FullyReplicated',
                },
            },
            'InputMode': 'File',
        }
        for channel_name, channel_uri in (
            ('training', training_uri),
            ('validation', validation_uri),
        )
    ],
    OutputDataConfig={'S3OutputPath': outpath},
    ResourceConfig={
        'InstanceType': 'ml.p3.2xlarge',
        'InstanceCount': 1,
        'VolumeSizeInGB': 100,
    },
    CheckpointConfig={'S3Uri': checkpoint},
    EnableManagedSpotTraining=use_spot,
    # 432000 s = 5 days. Managed spot training additionally requires
    # MaxWaitTimeInSeconds, which must be >= MaxRuntimeInSeconds.
    StoppingCondition=(
        {"MaxWaitTimeInSeconds": 432000, "MaxRuntimeInSeconds": 432000}
        if use_spot
        else {"MaxRuntimeInSeconds": 432000}
    ),
)
response

In [None]:
# Query the job once and report its current state (e.g. InProgress, Completed, Failed).
status = sm.describe_training_job(
    TrainingJobName=job_name,
)['TrainingJobStatus']
print('Training job current status: %s' % (status,))

## 5 下载训练结果

In [None]:
# Look up the S3 URI of the trained model archive produced by the job.
response = sm.describe_training_job(TrainingJobName=job_name)
job_status = response['TrainingJobStatus']
# SageMaker uploads the model artifacts only when the job ends. While the job is
# still running (or if it failed), stop here with a clear message instead of
# hitting a KeyError or a failed download in the next cell.
if job_status not in ('Completed', 'Stopped'):
    raise RuntimeError(
        'Training job {} has status {}; model artifacts are only available once '
        'it has Completed (or was Stopped).'.format(job_name, job_status))
model_data = response['ModelArtifacts']['S3ModelArtifacts']

In [None]:
# Download the job's model archive (model.tar.gz) from S3 into the current directory.
!aws s3 cp {model_data} ./

In [None]:
# Unpack the trained model into container/local_test/model/. The test cells mount
# container/local_test/ at /opt/ml/, so the model is then visible at /opt/ml/model/.
# tar -C fails if the target directory is missing, so create it first.
!mkdir -p container/local_test/model/
!tar -xvf model.tar.gz -C container/local_test/model/
!rm model.tar.gz

## 6 测试

In [None]:
# The training image lives in a private ECR registry, so authenticate Docker
# against it before pulling. The registry host is the part of image_uri before
# the first '/'.
registry = image_uri.split('/')[0]
!aws ecr get-login-password --region cn-northwest-1 | docker login --username AWS --password-stdin $registry
# Run text-recognition inference on a sample image, on CPU, with the trained
# weights. container/local_test/ is mounted at /opt/ml/, so /opt/ml/model/latest
# is presumably the checkpoint unpacked above; verify against the archive contents.
!docker run -v $(pwd)/container/local_test/:/opt/ml/ --rm $image_uri \
   python3 tools/infer_rec.py -c rec_chinese_common_train_v2.0.yml -o Global.pretrained_model=/opt/ml/model/latest Global.load_static_weights=false Global.use_gpu=false Global.infer_img=doc/imgs_words/ch/word_2.jpg

## 7 转化为推理格式

In [None]:
# Export the trained weights to an inference model. Because container/local_test/
# is mounted at /opt/ml/, the output lands in container/local_test/model/inference/rec/.
!docker run -v $(pwd)/container/local_test/:/opt/ml/ --rm $image_uri \
   python3 tools/export_model.py -c rec_chinese_common_train_v2.0.yml -o Global.pretrained_model=/opt/ml/model/latest Global.save_inference_dir=/opt/ml/model/inference/rec/

## 8 下载检测模型、方向分类器的推理模型

In [None]:
# Work inside the inference directory created by the export step (it contains rec/).
%cd container/local_test/model/inference

In [None]:
# This directory was created from inside the Docker container, presumably as root;
# make it writable so the model downloads in the next cell can be saved here.
!sudo chmod a+w ./

In [None]:
# Download and unpack the Chinese OCR text-detection inference model as det/.
# Remove any existing det/ first: if it already exists (e.g. the cell is re-run),
# `mv` would move the new copy *inside* it instead of replacing it.
!wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
!tar xf ch_ppocr_server_v2.0_det_infer.tar
!rm ch_ppocr_server_v2.0_det_infer.tar
!rm -rf det
!mv ch_ppocr_server_v2.0_det_infer det
# Download and unpack the Chinese OCR text-direction classifier model as cls/
# (same re-run protection as above).
!wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
!tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
!rm ch_ppocr_mobile_v2.0_cls_infer.tar
!rm -rf cls
!mv ch_ppocr_mobile_v2.0_cls_infer cls

In [None]:
# Go up to container/local_test/model/ so the inference/ directory can be archived.
%cd ../

In [None]:
# Package the inference models (det/, cls/ and rec/ under inference/) into one gzipped tarball.
!tar zcvf inference.tar.gz ./inference/

In [None]:
# Upload the packaged inference models to the same S3 bucket/prefix as the training data.
!aws s3 cp inference.tar.gz s3://$bucket/$pre_key/

In [None]:
# Return to the 1-training directory (three levels up from container/local_test/model/).
%cd ../../../

In [None]:
# Record the S3 URI of the packaged inference models in model_data.txt, with no
# trailing newline (presumably read by a later step; confirm against its consumer).
with open('model_data.txt', 'w') as model_data_file:
    model_data_file.write('s3://{}/{}/inference.tar.gz'.format(bucket, pre_key))