# Building a custom container for PyTorch on Amazon SageMaker

In [3]:
import sagemaker
# Open a SageMaker session and resolve the IAM role used for training.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

# The AWS account id (from STS, for the boto session's credentials) and the
# session's region are used below to build the ECR registry address.
account = (
    sess.boto_session
    .client('sts')
    .get_caller_identity()['Account']
)
region = sess.boto_session.region_name

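ContextualVersionConflict: (requests 2.18.4 (/home/colm/anaconda3/lib/python3.6/site-packages), Requirement.parse('requests<2.21,>=2.20.0'), {'sagemaker'})

> **Note:** the error above means the installed `requests` (2.18.4) does not satisfy the SageMaker SDK's requirement (`>=2.20.0,<2.21`), so `import sagemaker` failed. Install a compatible version, e.g. `%pip install "requests>=2.20.0,<2.21"`, then restart the kernel and re-run this cell.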

## Prepare the files required to build the container

In [None]:
# Show the Dockerfile that defines the training image
!cat Dockerfile

In [None]:
# Create the Docker build context directory first: `cp Dockerfile build/`
# fails with "Not a directory" when build/ does not exist yet.
!mkdir -p build
# Copy Dockerfile
!cp Dockerfile build/

In [None]:
# Copy training script
!cp mnist_cnn.py build/

## Create and log in to a repository in Amazon ECR


### Container settings

In [None]:
# --- Image / ECR settings ------------------------------------------------------
repo_name     = 'pytorch'            # ECR repository name
image_tag     = 'pytorch-py3'        # name of the locally built image (pushed to ECR as :latest)
base_job_name = 'pytorch-mnist-cnn'  # prefix for SageMaker training job names

# --- Training hardware ---------------------------------------------------------
# GPU instance by default. For CPU-only training use instead:
#   train_instance_type = 'ml.c5.9xlarge'
#   gpu_count = 0
train_instance_type = 'ml.p3.2xlarge'
gpu_count = 1  # NOTE(review): not referenced by any later cell in this notebook

batch_size = 128

# Export settings as environment variables so that shell cells (`!cmd`, `%%sh`)
# can read them as $dockerfile, $account, $region, $repo_name, $image_tag.
%env dockerfile=Dockerfile
%env account={account}
%env region={region}
%env repo_name={repo_name}
%env image_tag={image_tag}

### Create the repository and log in

In [None]:
%%sh

# Create the ECR repository if it does not exist yet. --region is passed explicitly
# so the repository is created in the same region the image is tagged and pushed
# for ($region), not in whatever default region the AWS CLI is configured with.
if ! aws ecr describe-repositories --region "$region" --repository-names "$repo_name" > /dev/null 2>&1
then
    aws ecr create-repository --region "$region" --repository-name "$repo_name" > /dev/null
fi

# Log Docker in to this account's ECR registry. `aws ecr get-login` was removed in
# AWS CLI v2; `get-login-password` (AWS CLI v1.17.10+ and v2) replaces it.
aws ecr get-login-password --region "$region" \
    | docker login --username AWS --password-stdin "$account.dkr.ecr.$region.amazonaws.com"

## Build and tag Docker image

In [None]:
# Build using build/ as the Docker context without changing the kernel's working
# directory: with `%cd build ... %cd ..`, an interrupted build left the kernel inside
# build/, and re-running the cell then failed. A relative -f path is resolved from
# the current directory, so it must include the build/ prefix.
!docker build -t $image_tag -f build/$dockerfile build

In [None]:
# Tag the local image with its full ECR repository URI so it can be pushed
!docker tag $image_tag $account.dkr.ecr.$region.amazonaws.com/$repo_name:latest

In [None]:
!docker images

In [None]:
# It's probably a good idea to inspect your container before pushing it :)
# Open a shell in it (--entrypoint overrides any ENTRYPOINT set in the Dockerfile).
# `-it` needs a terminal, so run this from a shell rather than from a notebook cell:
#   docker run --rm -it --entrypoint /bin/bash pytorch-py3

## Push Docker image to ECR

In [None]:
# Upload the ECR-tagged image so SageMaker can pull it for training
!docker push "$account.dkr.ecr.$region.amazonaws.com/$repo_name:latest"

## Upload MNIST data to S3

In [None]:
# Mirror the local MNIST folders to S3 (default session bucket) under
# <repo_name>/input/training and <repo_name>/input/validation.
local_directory = 'data'
prefix = '{}/input'.format(repo_name)

train_input_path = sess.upload_data(
    '{}/training/'.format(local_directory),
    key_prefix='{}/training'.format(prefix),
)
validation_input_path = sess.upload_data(
    '{}/validation/'.format(local_directory),
    key_prefix='{}/validation'.format(prefix),
)

## Train with custom container

In [None]:
# S3 location for training output, and the ECR image (the :latest tag pushed above)
output_path = 's3://%s/%s/output' % (sess.default_bucket(), repo_name)
image_name  = '%s.dkr.ecr.%s.amazonaws.com/%s:latest' % (account, region, repo_name)

print(output_path)
print(image_name)

# Uses SageMaker Python SDK v1 argument names (image_name, train_instance_*);
# SDK v2 renamed these to image_uri / instance_count / instance_type.
estimator = sagemaker.estimator.Estimator(
    image_name=image_name,
    role=role,
    train_instance_count=1,
    train_instance_type=train_instance_type,
    base_job_name=base_job_name,
    output_path=output_path,
    sagemaker_session=sess,
)

# Hyperparameters handed to the training container
estimator.set_hyperparameters(
    lr=0.01,
    epochs=10,
    batch_size=batch_size,
)

# Launch the training job with 'training' and 'validation' input channels
estimator.fit({
    'training': train_input_path,
    'validation': validation_input_path,
})