# Training and tuning a Keras CNN on Fashion-MNIST

In [None]:
# Fashion-MNIST dataset: https://github.com/zalandoresearch/fashion-mnist/

In [None]:
import sagemaker
# SageMaker session and execution role used by every later cell.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

# Region and AWS account id, both taken from the session's underlying boto session;
# later cells use them to build the ECR image URI.
region = sess.boto_session.region_name
account = sess.boto_session.client('sts').get_caller_identity()['Account']

## Prepare files required to build the containers

In [None]:
# Print the GPU Dockerfile so the image definition can be reviewed before building.
!cat Dockerfile.gpu

In [None]:
# Stage the Dockerfiles in build/, the directory the image is built from.
# Create the directory first: `cp` fails if the target directory does not exist
# (e.g. on a fresh checkout), and `mkdir -p` is a no-op when it is already there.
!mkdir -p build
!cp Dockerfile.* build/

In [None]:
# Copy the training script into build/, the directory used as the Docker build context,
# so the image build can see it.
!cp mnist_cnn.py build/

## Create a repository in ECR and log in to it

### GPU settings

In [None]:
# Names shared by the Docker, ECR and SageMaker steps below.
repo_name     = 'keras-tf-gpu'        # ECR repository name
image_tag     = 'keras-tf-gpu-py3'    # name given to the locally built image (`docker build -t`)
base_job_name = 'keras-tf-mnist-cnn'  # prefix for SageMaker training job names

# Dockerfile selected for the image build; shell commands read it as $dockerfile.
%env dockerfile=Dockerfile.gpu

# Instance type for training jobs, the GPU count handed to the training script,
# and a batch size that scales linearly with that GPU count.
train_instance_type = 'ml.p3.2xlarge'
gpu_count = 1
batch_size = 128 * gpu_count

### Common settings

In [None]:
# Export the Python-side settings as environment variables so that shell cells
# and `!` commands can read them as $account, $region, $repo_name and $image_tag.
%env account={account}
%env region={region}
%env repo_name={repo_name}
%env image_tag={image_tag}

### Create the repository and log in

In [None]:
%%sh

# Create the ECR repository if it is missing, then log Docker in to ECR.
# Reads $repo_name and $region from the environment (exported earlier with %env).
# Every call targets "$region" explicitly so the repository is looked up and
# created in the same region that the login (and the image URI) uses.

# describe-repositories exits non-zero when the repository cannot be described
# (typically because it does not exist yet); in that case try to create it.
if ! aws ecr describe-repositories --region "$region" --repository-names "$repo_name" > /dev/null 2>&1
then
    aws ecr create-repository --region "$region" --repository-name "$repo_name" > /dev/null
fi

# get-login prints a ready-made `docker login ...` command; $(...) executes it.
# NOTE(review): `aws ecr get-login` exists only in AWS CLI v1. On CLI v2 it must be
# replaced by `aws ecr get-login-password | docker login --password-stdin ...` --
# confirm which CLI version this notebook runs with.
$(aws ecr get-login --region "$region" --no-include-email)

## Build and tag Docker image

In [None]:
# Build the image with build/ as the Docker context and the Dockerfile path given
# explicitly, instead of changing into build/ and back with %cd. The kernel's
# working directory is never modified, so interrupting or re-running this cell
# cannot leave later cells running from inside build/.
!docker build -t $image_tag -f build/$dockerfile build

In [None]:
# Give the local image its full ECR name (repository URI plus the `latest` tag).
!docker tag {image_tag} {account}.dkr.ecr.{region}.amazonaws.com/{repo_name}:latest

In [None]:
# List the local Docker images to check that the build and tag steps produced the expected entries.
!docker images

In [None]:
# It's probably a good idea to inspect your container before pushing it :)
# From a terminal (an interactive shell needs a TTY): docker run -it <image_tag> /bin/bash

## Push Docker image to ECR

In [None]:
# Push the tagged image to the ECR repository.
!docker push {account}.dkr.ecr.{region}.amazonaws.com/{repo_name}:latest

## Upload Fashion-MNIST data to S3

In [None]:
# Upload the local training and validation folders to S3 under '<repo_name>/input/...'.
# No bucket is passed, so the session decides which bucket receives the files;
# each call returns the S3 location that is later given to fit().
local_directory = 'data'
prefix = repo_name+'/input'

train_input_path = sess.upload_data(
    path=local_directory+'/train/',
    key_prefix=prefix+'/train')

validation_input_path = sess.upload_data(
    path=local_directory+'/validation/',
    key_prefix=prefix+'/validation')

## Train with the custom container

In [None]:
# Where training output is written, and the URI of the custom image pushed to ECR.
# The image URI must match the name used by the `docker tag` / `docker push` cells.
image_name  = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, repo_name)
output_path = 's3://{}/{}/output'.format(sess.default_bucket(), repo_name)

print(output_path)
print(image_name)

# Generic estimator that runs the custom container on a single training instance.
# NOTE(review): `image_name`, `train_instance_count` and `train_instance_type` look like
# SageMaker Python SDK v1 argument names -- confirm the installed SDK version before upgrading.
estimator = sagemaker.estimator.Estimator(
    image_name=image_name,
    role=role,
    train_instance_count=1,
    train_instance_type=train_instance_type,
    base_job_name=base_job_name,
    output_path=output_path,
    sagemaker_session=sess,
)

In [None]:
# Baseline run: fixed hyperparameters, one training job.
estimator.set_hyperparameters(
    lr=0.1,
    epochs=100,
    gpus=gpu_count,
    batch_size=batch_size,
)

# Train on the two uploaded channels; the channel names are the dictionary keys.
estimator.fit(inputs={'training': train_input_path,
                      'validation': validation_input_path})

In [None]:
from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner

In [None]:
# Search space for the tuning job: the learning rate, the two integer `filter*`
# parameters and the two `dropout*` values, each with its (min, max) bounds.
hyperparameter_ranges = {
    'lr':       ContinuousParameter(0.001, 0.5),
    'filter1':  IntegerParameter(16, 256),
    'filter2':  IntegerParameter(16, 256),
    'dropout1': ContinuousParameter(0.01, 0.99),
    'dropout2': ContinuousParameter(0.01, 0.99),
}

In [None]:
# Hyperparameters held constant across all tuning jobs; `lr` is not set here
# because it is part of the search space.
# NOTE(review): the `lr=0.1` set by the baseline cell may still be stored on the
# estimator -- confirm how the tuner treats a value present both as static and as a range.
estimator.set_hyperparameters(
    epochs=100,
    gpus=gpu_count,
    batch_size=batch_size,
)

In [None]:
# Objective for the tuning job: maximize the `val_acc` value that the regex
# below extracts from the training job's log output.
objective_type = 'Maximize'
objective_metric_name = 'val_acc'
metric_definitions = [
    {
        'Name': 'val_acc',
        'Regex': 'val_acc: ([0-9\\.]+)',
    },
]

In [None]:
# Tuning job definition: up to 50 training jobs in total, at most 2 running at once,
# searching `hyperparameter_ranges` for the best value of the objective metric.
tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name=objective_metric_name,
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=metric_definitions,
    objective_type=objective_type,
    max_jobs=50,
    max_parallel_jobs=2,
)

In [None]:
# Start the tuning job on the same two input channels used for the baseline run.
tuner.fit(inputs={'training': train_input_path,
                  'validation': validation_input_path})

In [None]:
import boto3

# Look up the status of the latest tuning job. The request goes through the
# SageMaker client owned by `sess`, so it uses the same boto session as every
# other call in this notebook rather than a separately created default client.
sess.sagemaker_client.describe_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=tuner.latest_tuning_job.job_name)['HyperParameterTuningJobStatus']