In [None]:
import boto3, sagemaker
from sagemaker.processing import ScriptProcessor, ProcessingOutput
from sagemaker.tensorflow import TensorFlow
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

# Resolve the session's default bucket and the notebook's execution role,
# then echo both so the reader can confirm which resources will be used.
sagemaker_session = sagemaker.session.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()
print(bucket, role, sep='\n')

## 学習データ作成
### データ作成用コンテナイメージのビルド
* Pillow を入れるために BYOC(Bring Your Own Container)でコンテナイメージを用意する
* SageMaker Processing で動かす

In [None]:
image_name = 'amazon-elasticache-police-generate-image'
tag = ':1'
%cd ./docker/make_image_container
!docker rmi -f $(docker images -a -q)
!docker build -t {image_name}{tag} .
%cd ../../

### イメージを ECR へプッシュ

In [None]:
account_id = boto3.client('sts').get_caller_identity().get('Account')
region = boto3.session.Session().region_name
ecr_endpoint = f'{account_id}.dkr.ecr.{region}.amazonaws.com/' 
repository_uri = f'{ecr_endpoint}{image_name}'
image_uri = f'{repository_uri}{tag}'

# ECR ログイン
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {ecr_endpoint}

!docker tag {image_name}{tag} {image_uri}

# 同名のリポジトリがあった場合は削除
!aws ecr delete-repository --repository-name $image_name --force
# リポジトリを作成
!aws ecr create-repository --repository-name $image_name
# イメージをプッシュ
!docker push {image_uri}

### 学習データ作成

In [None]:
# Base name used for the processing job.
job_name = 'generate-image'

# Paths used by the processing job; both live under the same root directory.
processing_root_dir = '/opt/ml/processing'
processing_input_dir = f'{processing_root_dir}/input'
processing_output_dir = f'{processing_root_dir}/generated_image'

In [None]:
# Processor that launches the script with `python3` inside the custom image,
# on a single ml.c5.xlarge instance.
processor = ScriptProcessor(
    image_uri=image_uri,
    command=['python3'],
    role=role,
    instance_type='ml.c5.xlarge',
    instance_count=1,
    base_job_name=job_name,
)

In [None]:
# Register processing_output_dir as the job output named 'output'.
generated_image_output = ProcessingOutput(
    source=processing_output_dir,
    output_name='output',
)

# Command-line arguments handed to generate_image.py.
# NOTE(review): '--check-names' looks like "<correct>/<misspelled>" product
# names separated by '/' — confirm against the script.
script_arguments = [
    '--output-dir', processing_output_dir,
    '--check-names', 'Amazon ElastiCache/Amazon ElasticCache',
]

processor.run(
    code='./src/generate_image.py',
    outputs=[generated_image_output],
    arguments=script_arguments,
)

### 生成した学習データ確認

In [None]:
# Get the S3 URI of the training data from the most recent processing job:
# describe the job, then read the S3 location of its first declared output.
latest_processing_job = processor.jobs[-1]
processor_description = latest_processing_job.describe()
first_output = processor_description['ProcessingOutputConfig']['Outputs'][0]
generate_data_s3_uri = first_output['S3Output']['S3Uri']

In [None]:
!aws s3 cp {generate_data_s3_uri}/train_X.npy ./
!aws s3 cp {generate_data_s3_uri}/train_y.npy ./

In [None]:
# Load the downloaded arrays: train_X from train_X.npy, train_y from train_y.npy.
train_X, train_y = (
    np.load(npy_path) for npy_path in ('./train_X.npy', './train_y.npy')
)

In [None]:
# Display the shapes of both arrays as the cell output.
(train_X.shape, train_y.shape)

In [None]:
# Show a grid of randomly chosen training samples, each titled with its label.
rows, cols = 10, 1
fig = plt.figure(figsize=(70, 10))
axes = []

for subplot_index in range(1, rows * cols + 1):
    # Pick a random sample index (repeats are possible).
    i = np.random.randint(0, train_X.shape[0])
    ax = fig.add_subplot(rows, cols, subplot_index)
    axes.append(ax)
    # A label equal to 1 is titled 'alert'; anything else 'No Problem'.
    if train_y[i] == 1:
        subplot_title = 'alert'
    else:
        subplot_title = 'No Problem'
    ax.set_title(subplot_title)
    # Drawn on the current axes, which is the subplot just added above.
    plt.imshow(train_X[i, :, :], 'gray')

fig.tight_layout()
plt.show()

## 学習

In [None]:
# Hyperparameters passed to the training job.
training_hyperparameters = {"epochs": 30}

# TensorFlow 2.4 / Python 3.7 estimator that runs ./src/train.py on a single
# ml.g4dn.xlarge instance.
estimator = TensorFlow(
    entry_point='./src/train.py',
    framework_version='2.4',
    py_version='py37',
    instance_type='ml.g4dn.xlarge',
    instance_count=1,
    role=role,
    hyperparameters=training_hyperparameters,
)

In [None]:
# Train using the generated data as the 'train' input channel.
train_channels = {'train': generate_data_s3_uri}
estimator.fit(train_channels)

In [None]:
# Check which container image the training job used.
training_job_description = estimator.latest_training_job.describe()
print(training_job_description['AlgorithmSpecification']['TrainingImage'])

## 推論
### SageMaker Hosting の場合

In [None]:
# Deploy the trained model to an endpoint backed by one ml.m5.xlarge instance.
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')

#### 綴りが正しい画像
「No Problem」と出力されれば、モデルは正解を返している

In [None]:
# Test image of the correctly spelled name; the file is opened once and the
# same object is reused for both inference and display.
true_image_path = './test_data/AmazonElastiCache_ZenjidoJP-FeltPenLMT-TTF.ttf0.png'
true_image = Image.open(true_image_path)

# (x - 127.5) / 127.5 maps pixel value 0 to -1 and 255 to 1; the result is
# reshaped to (1, 50, 700, 1) before being sent to the endpoint.
true_arr = ((np.array(true_image) - 127.5) / 127.5).reshape(1, 50, 700, 1)

# A score below 0.5 is reported as 'No Problem', otherwise as 'Alert!!'.
true_score = predictor.predict(true_arr.tolist())['predictions'][0][0]
print('No Problem' if true_score < 0.5 else 'Alert!!')

# Last expression: render the image as the cell output.
true_image

#### 綴りが誤っている画像
「Alert!!」と出力されれば、モデルは正解を返している

In [None]:
# Test image of the misspelled name; the file is opened once and the same
# object is reused for both inference and display.
false_image_path = './test_data/AmazonElasticCache_ZenjidoJP-FeltPenLMT-TTF.ttf0.png'
false_image = Image.open(false_image_path)

# (x - 127.5) / 127.5 maps pixel value 0 to -1 and 255 to 1; the result is
# reshaped to (1, 50, 700, 1) before being sent to the endpoint.
false_arr = ((np.array(false_image) - 127.5) / 127.5).reshape(1, 50, 700, 1)

# A score below 0.5 is reported as 'No Problem', otherwise as 'Alert!!'.
false_score = predictor.predict(false_arr.tolist())['predictions'][0][0]
print('No Problem' if false_score < 0.5 else 'Alert!!')

# Last expression: render the image as the cell output.
false_image

In [None]:
# Clean up: delete the endpoint behind `predictor` (created by estimator.deploy).
predictor.delete_endpoint()