# DeepHyperX on SageMaker: Training

## 1 Overview
This chapter trains the model on SageMaker, reading the training data from S3.

## 2 Environment
Select the `pytorch_latest_p36` kernel.  
This notebook was tested with boto3 1.17.84 and sagemaker 2.43.0.

In [None]:
import boto3
import sagemaker
print(boto3.__version__)
print(sagemaker.__version__)
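
The estimator later in this notebook uses SageMaker Python SDK v2 argument names (`instance_count`, `instance_type`); SDK v1 called them `train_instance_count` and `train_instance_type`. A minimal sketch, using the hypothetical helpers `parse_version` and `check_versions`, that compares the printed versions against the tested ones and flags an SDK older than v2:

```python
import re

# Versions this notebook was tested with (taken from the text above)
TESTED = {"boto3": "1.17.84", "sagemaker": "2.43.0"}


def parse_version(text):
    """Return the leading numeric components, e.g. '2.43.0' -> (2, 43, 0)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first component that does not start with a digit
        parts.append(int(match.group()))
    return tuple(parts)


def check_versions(installed):
    """Return warnings for a dict such as {'boto3': '1.17.84', 'sagemaker': '2.43.0'}."""
    warnings = []
    if parse_version(installed.get("sagemaker", "0")) < (2,):
        warnings.append("sagemaker SDK older than 2.x; v2 argument names such as instance_type will not work")
    for name, tested in TESTED.items():
        if name in installed and parse_version(installed[name]) != parse_version(tested):
            warnings.append(f"{name} {installed[name]} differs from the tested version {tested}")
    return warnings
```

In the notebook it would be called as `check_versions({"boto3": boto3.__version__, "sagemaker": sagemaker.__version__})`; a version mismatch is only a hint, not necessarily an error.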

## 3 Training on SageMaker

In [None]:
# S3 bucket that holds the training data (change this to your own bucket)
bucket = 'junzhong'

In [None]:
input_path='s3://{}/data/deephyper/'.format(bucket)
output_path='s3://{}/result/deephyper/'.format(bucket)
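
`input_path` is the S3 prefix holding the training data and `output_path` is where SageMaker writes `model.tar.gz`. S3 APIs such as boto3's `list_objects_v2(Bucket=..., Prefix=...)` take the bucket and key prefix separately, so a small helper (hypothetical name `split_s3_uri`) can split these URIs, for example to confirm the data is present before starting a job:

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/key/prefix/' into ('bucket', 'key/prefix/')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"not an S3 URI: {uri}")
    # netloc is the bucket; the path keeps a leading '/' that S3 keys do not use
    return parsed.netloc, parsed.path.lstrip("/")
```

For example, `bucket_name, prefix = split_s3_uri(input_path)` followed by `boto3.client("s3").list_objects_v2(Bucket=bucket_name, Prefix=prefix)` lists the uploaded training files.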

In [None]:
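import boto3
iam = boto3.client('iam')
# list_roles returns results in pages, so walk every page instead of only the first one
paginator = iam.get_paginator('list_roles')
role = ""
for page in paginator.paginate(PathPrefix='/service-role'):
    for current_role in page["Roles"]:
        if current_role["RoleName"].startswith("AmazonSageMaker-ExecutionRole-"):
            role = current_role["Arn"]
            break
    if role:
        break
# An empty role means no SageMaker execution role was found; create one first at
# https://cn-northwest-1.console.amazonaws.cn/sagemaker/home?region=cn-northwest-1#/notebook-instances/create
print(role)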
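
Printing an empty `role` is easy to overlook, and the estimator would then fail later with a less obvious error. A minimal sketch (hypothetical helper `require_role`) that fails immediately instead; the regular expression assumes IAM role ARNs in the `aws`, `aws-cn` and `aws-us-gov` partitions:

```python
import re

# Assumed shape of an IAM role ARN, e.g. arn:aws-cn:iam::123456789012:role/...
ROLE_ARN = re.compile(r"^arn:aws(-cn|-us-gov)?:iam::\d{12}:role/.+$")


def require_role(role):
    """Raise instead of passing an empty or malformed role ARN to the estimator."""
    if not role:
        raise ValueError(
            "No AmazonSageMaker-ExecutionRole-* role found; create one in the SageMaker console first"
        )
    if not ROLE_ARN.match(role):
        raise ValueError(f"unexpected role ARN format: {role}")
    return role
```

Calling `role = require_role(role)` right after the lookup stops the notebook at the point where the problem actually is.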

In [None]:
from sagemaker.pytorch import PyTorch

# Change the training instance type, and whether to use Spot Instances, as needed
instance_type="ml.p3.2xlarge"
use_spot_instances=False

estimator = PyTorch(entry_point="main.py",
                    source_dir="./source",
                    role=role,
                    output_path=output_path,
                    framework_version='1.6.0',
                    # "folder" is where SageMaker places the "training" input channel inside the container
                    hyperparameters={"folder": "/opt/ml/input/data/training/",
                                     "model": "he",
                                     "dataset": "leaf",
                                     "cuda": "0",
                                     "training_sample": 0.7,
                                     "patch_size": 17,
                                     "epoch": 20,
                                     "batch_size": 32},
                    py_version="py3",
                    instance_count=1,
                    instance_type=instance_type,
                    use_spot_instances=use_spot_instances,
                    # Spot training requires max_wait >= max_run (24 h by default); 432000 s = 5 days
                    max_wait=432000 if use_spot_instances else None,
                    )
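
The SageMaker framework containers hand each entry in `hyperparameters` to the entry point as a command-line flag, so `main.py` receives roughly `--batch_size 32 --cuda 0 ... --training_sample 0.7` and parses it with `argparse`. A simplified sketch of that round trip: `to_cli_args` approximates the container's conversion (the real training toolkit also handles encoding details), and `parse_like_main` is a hypothetical subset of the flags `main.py` is assumed to declare:

```python
import argparse


def to_cli_args(hyperparameters):
    """Flatten {'epoch': 20} into ['--epoch', '20']; keys are sorted for a stable order."""
    args = []
    for key in sorted(hyperparameters):
        args += [f"--{key}", str(hyperparameters[key])]
    return args


def parse_like_main(argv):
    """Hypothetical subset of main.py's argparse setup, used only for illustration."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--folder", type=str)
    parser.add_argument("--model", type=str)
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--cuda", type=int, default=-1)
    parser.add_argument("--training_sample", type=float)
    parser.add_argument("--patch_size", type=int)
    parser.add_argument("--epoch", type=int)
    parser.add_argument("--batch_size", type=int)
    return parser.parse_args(argv)
```

Because every value arrives as a string, the entry point's `type=` declarations are what turn `"0.7"` back into a float; a hyperparameter name that `main.py` does not declare makes `argparse` exit with an error.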

In [None]:
estimator.fit(input_path)
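
Passing a plain S3 URI to `fit()` assigns it to the default channel named `training`; the explicit equivalent is `estimator.fit({"training": input_path})`. SageMaker downloads that channel to `/opt/ml/input/data/training/` inside the container, which is why the `folder` hyperparameter points there, and the training toolkit also exposes the path as the environment variable `SM_CHANNEL_TRAINING`. A minimal sketch (hypothetical helper `channel_dir`) of how an entry point could resolve the directory instead of hard-coding it:

```python
import os

DEFAULT_CHANNEL_ROOT = "/opt/ml/input/data"


def channel_dir(channel="training", environ=None):
    """Return the local directory for a SageMaker input channel.

    Prefers the SM_CHANNEL_<NAME> variable set in SageMaker training containers
    and falls back to the conventional /opt/ml/input/data/<name>/ layout.
    """
    environ = os.environ if environ is None else environ
    env_value = environ.get(f"SM_CHANNEL_{channel.upper()}")
    if env_value:
        return env_value.rstrip("/") + "/"
    return f"{DEFAULT_CHANNEL_ROOT}/{channel}/"
```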

In [None]:
import os
os.makedirs("result", exist_ok=True)

In [None]:
!aws s3 cp $estimator.model_data ./result

In [None]:
%%sh
cd result
tar zxvf model.tar.gz
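
The shell cell above unpacks the downloaded archive with `tar`. The same step can be done from Python with the standard `tarfile` module, which is convenient if the extracted files are loaded in the next cell. A minimal sketch (hypothetical helper `extract_model_archive`) that also refuses archive entries that would land outside the destination folder:

```python
import tarfile
from pathlib import Path


def extract_model_archive(archive="result/model.tar.gz", dest="result"):
    """Python equivalent of `tar zxvf model.tar.gz`; returns the extracted member names."""
    dest_path = Path(dest).resolve()
    with tarfile.open(archive, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            # Reject entries such as '../x' or absolute paths that would escape dest
            target = (dest_path / member.name).resolve()
            if target != dest_path and dest_path not in target.parents:
                raise ValueError(f"unsafe path in archive: {member.name}")
        tar.extractall(dest_path)
    return [m.name for m in members]
```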