https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/pytorch_cnn_cifar10/pytorch_local_mode_cifar10.ipynb

## データ準備

In [None]:
import sagemaker

# Session object; used below to look up the default bucket and to upload the dataset.
sagemaker_session = sagemaker.Session()

# S3 bucket name returned by the session's default_bucket().
bucket = sagemaker_session.default_bucket() # ex. 'sagemaker-us-east-2-xxxxxxxxxxxx'

# Execution role; passed as `role` when the estimator is created.
role = sagemaker.get_execution_role() # ex. arn:aws:iam::xxxxxxxxxxxx:role/service-role/...

In [None]:
# Upload the local raw dataset directory to S3 with sagemaker.Session.upload_data.
# The value it returns is kept in `inputs` and handed to fit() when the
# training job is started.
inputs = sagemaker_session.upload_data(
    path='../data/raw',
    bucket=bucket,
    key_prefix='data/raw/cifar10',
)  # ex. s3://sagemaker-us-east-2-xxxxxxxxxxxx/data/raw/cifar10

## インスタンスタイプ設定

In [None]:
import subprocess

# Default instance type. Swap in the commented line to use 'local' instead.
# instance_type = 'local'
instance_type = 'ml.m5.large'

try:
    # A zero exit status from nvidia-smi is treated as "a GPU is present".
    gpu_available = subprocess.call('nvidia-smi') == 0
except OSError:
    # The nvidia-smi executable could not be launched (e.g. it is not
    # installed on a CPU-only host): keep the default instead of crashing.
    gpu_available = False

if gpu_available:
    # Set type to GPU if one is present
    instance_type = 'local_gpu'

print("Instance type = " + instance_type)

## 学習コンテナ起動

default

In [None]:
from sagemaker.pytorch import PyTorch
# from datetime import datetime, timedelta, timezone

# Hyperparameters handed to the estimator.
hyper_param = dict(
    workers=2,
    epochs=2,
    batch_size=4,
    lr=0.001,
    momentum=0.9,
)

# Estimator for the ../models/cifar10_sagemaker.py entry point, configured
# with one instance of the type selected earlier in the notebook.
cifar10_estimator = PyTorch(
    entry_point='../models/cifar10_sagemaker.py',
    hyperparameters=hyper_param,
    role=role,
    framework_version='1.1.0',
    train_instance_count=1,
    train_instance_type=instance_type,
)

# Start training on the uploaded dataset location.
# An explicit, timestamped job name can be passed as well, e.g.:
#   date = datetime.now(timezone(timedelta(hours=+9), 'JST')).strftime("%Y%m%d%H%M")
#   cifar10_estimator.fit(inputs=inputs, job_name='' + date)
cifar10_estimator.fit(inputs=inputs)

hyperparameter tuning

In [None]:
# from sagemaker.pytorch import PyTorch
# from datetime import datetime, timedelta, timezone
# from sagemaker.tuner import IntegerParameter, ContinuousParameter, HyperparameterTuner

# # estimator
# hyper_param = {
#     'workers': 2,
#     'epochs':2,
#     'batch_size': 4,
#     'lr': 0.001,
#     'momentum': 0.9,
# }

# cifar10_estimator = PyTorch(entry_point='../models/cifar10_sagemaker.py',
#                             hyperparameters=hyper_param,
#                             role=role,
#                             framework_version='1.1.0',
#                             train_instance_count=1,
#                             train_instance_type=instance_type)

# # tuner
# hyperparameter_ranges = {
#     'batch_size': IntegerParameter(4, 64),
#     'lr': ContinuousParameter(1e-4, 0.1),
#     'momentum': ContinuousParameter(0.5, 0.9)
# }
# metric_definitions = [
#     {'Name': 'loss', 'Regex': r'loss: (\S+)'}
# ]

# cifar10_tuner = HyperparameterTuner(estimator=cifar10_estimator,
#                                     objective_metric_name='loss',
#                                     objective_type='Minimize',
#                                     hyperparameter_ranges=hyperparameter_ranges,
#                                     metric_definitions=metric_definitions,
#                                     max_jobs=1,
#                                     max_parallel_jobs=1,
#                                     early_stopping_type='Auto')

# # fit
# # date = datetime.now(timezone(timedelta(hours=+9), 'JST')).strftime("%Y%m%d%H%M")

# cifar10_tuner.fit(
#     inputs=inputs
# #     job_name='' + date,
# )

## 推論エンドポイント起動

Amazon SageMaker Python SDK

In [None]:
from sagemaker.pytorch import PyTorchModel

# Deploy the trained estimator and keep the returned predictor object; the
# inference cells below call cifar10_predictor.predict(). The endpoint uses
# the same instance_type value that was selected for training.
cifar10_predictor = cifar10_estimator.deploy(initial_instance_count=1,
                                             instance_type=instance_type)

AWS SDK for Python (Boto3)

In [None]:
# # TODO
# # モデル作成
# import boto3
# sm = boto3.client('sagemaker')

# training_job_name = 'sagemaker-pytorch-xxxxxxxx'
# model_name = training_job_name + '-mod'
# container = '520713654638.dkr.ecr.us-east-2.amazonaws.com/sagemaker-pytorch:1.1.0-cpu-py3'

# info = sm.describe_training_job(TrainingJobName=training_job_name)
# model_data = info['ModelArtifacts']['S3ModelArtifacts']
# print(model_data)

# primary_container = {
#     'Image': container,
#     'ModelDataUrl': model_data
# }

# create_model_response = sm.create_model(
#                             ModelName = model_name,
#                             ExecutionRoleArn = role,
#                             PrimaryContainer = primary_container
#                         )

# print(create_model_response['ModelArn'])

In [None]:
# # エンドポイント設定
# from datetime import datetime, timedelta, timezone
# date = datetime.now(timezone(timedelta(hours=+9), 'JST')).strftime("%Y%m%d%H%M")
# endpoint_config_name = 'cifar10-' + date 
# print(endpoint_config_name)

# create_endpoint_config_response = sm.create_endpoint_config(
#     EndpointConfigName = endpoint_config_name,
#     ProductionVariants=[{
#         'InstanceType':'ml.m4.xlarge',
#         'InitialVariantWeight':1,
#         'InitialInstanceCount':1,
#         'ModelName':model_name,
#         'VariantName':'AllTraffic'}])

# print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

In [None]:
# %%time
# # エンドポイント作成
# import time

# endpoint_name = 'cifar10-' + date
# print(endpoint_name)

# create_endpoint_response = sm.create_endpoint(
#                                 EndpointName=endpoint_name,
#                                 EndpointConfigName=endpoint_config_name)
# print(create_endpoint_response['EndpointArn'])

# resp = sm.describe_endpoint(EndpointName=endpoint_name)
# status = resp['EndpointStatus']
# print("Status: " + status)

# while status=='Creating':
#     time.sleep(60)
#     resp = sm.describe_endpoint(EndpointName=endpoint_name)
#     status = resp['EndpointStatus']
#     print("Status: " + status)

# print("Arn: " + resp['EndpointArn'])
# print("Status: " + status)

## 推論実行

Amazon SageMaker Python SDK

In [None]:
# テストデータ準備
import torch
import torchvision
import torchvision.transforms as transforms

# Convert images to tensors, then normalize every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 test split (train=False), downloaded into ./data when needed.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Unshuffled loader yielding batches of 4 samples.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Class names, indexed by label / predicted class id in the cells below.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# 画像表示準備
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

def imshow(img):
    """Show an image tensor with matplotlib.

    The tensor is first mapped back with ``img / 2 + 0.5``, which undoes a
    per-channel normalization with mean 0.5 and std 0.5, then converted to a
    NumPy array whose first axis is moved to the end before plotting.

    Args:
        img: Tensor exposing ``.numpy()``; presumably laid out as
            (channels, height, width) -- confirm against the caller.
    """
    restored = img / 2 + 0.5  # undo the normalization
    pixels = restored.numpy()
    # Axis order (1, 2, 0): the leading axis becomes the trailing one.
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

In [None]:
# Run inference on one sample batch
dataiter = iter(testloader)
# Use the built-in next(): the iterator's .next() method is not part of the
# Python 3 iterator protocol and is not guaranteed to exist on DataLoader
# iterators.
images, labels = next(dataiter)

# Show the batch as an image grid together with its ground-truth class names.
imshow(torchvision.utils.make_grid(images))
# Iterate over the batch itself rather than a hard-coded range(4), so the cell
# keeps working if the loader's batch size is changed.
print('GroundTruth: ', ' '.join('%4s' % classes[label] for label in labels))

# Send the batch to the deployed endpoint as a NumPy array.
outputs = cifar10_predictor.predict(images.numpy())

# The index of the largest value along dim 1 is taken as the predicted class.
_, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)

print('Predicted: ', ' '.join('%4s' % classes[idx] for idx in predicted))

In [None]:
# Run inference over the whole test loader and compute the accuracy
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        # Predict each batch through the endpoint; the arg-max along dim 1
        # is taken as the predicted class id.
        outputs = cifar10_predictor.predict(images.numpy())
        _, predicted = torch.from_numpy(np.array(outputs)).max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy_pct = 100 * correct / total
print('Accuracy of the network on the 10000 test images: %d %%' % accuracy_pct)

AWS SDK for Python (Boto3)

In [None]:
# TODO

## 推論エンドポイント削除

Amazon SageMaker Python SDK

In [None]:
# cifar10_estimator.delete_endpoint()

AWS SDK for Python (Boto3)

In [None]:
# TODO