In [None]:
import os
import sys
import tarfile
from six.moves import urllib
from ipywidgets import FloatProgress
from IPython.display import display

DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'


def cifar10_download(data_dir='/tmp/cifar10_data', print_progress=True):
    """Download the CIFAR-10 tarball from DATA_URL and extract it into data_dir.

    The function is a no-op (apart from a message) when
    ``<data_dir>/cifar-10-batches-py`` already exists. If the tarball is
    already present in ``data_dir`` the download step is skipped and only the
    extraction is performed.

    Args:
        data_dir: Directory that receives the tarball and the extracted
            files. It is created when it does not exist.
        print_progress: When True, an ipywidgets progress bar is displayed
            and updated during the download. When False no widget is created.

    Returns:
        None.

    Raises:
        Whatever ``urlretrieve`` or ``tarfile`` raise on network or archive
        errors; a failed download leaves no partial tarball behind.
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    if os.path.exists(os.path.join(data_dir, 'cifar-10-batches-py')):
        print('cifar dataset already downloaded')
        return

    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(data_dir, filename)

    if not os.path.exists(filepath):
        # Only build and show the widget when progress reporting is requested.
        progress_bar = None
        if print_progress:
            progress_bar = FloatProgress(min=0, max=100)
            display(progress_bar)
        sys.stdout.write('\r>> Downloading %s ' % (filename))

        def _progress(count, block_size, total_size):
            """Report hook for urlretrieve: update the bar as a percentage."""
            # total_size is not positive when the server sends no usable
            # Content-Length; a percentage cannot be computed in that case.
            if progress_bar is None or total_size <= 0:
                return
            # The last block is usually short, so count * block_size can
            # overshoot total_size; clamp to the bar's maximum.
            progress_bar.value = min(100.0, 100.0 * count * block_size / total_size)

        # Download to a temporary name and rename on success, so that an
        # interrupted transfer is never mistaken for a complete tarball by
        # the os.path.exists(filepath) check on the next run.
        partial_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(DATA_URL, partial_path, _progress)
            os.rename(partial_path, filepath)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

    # NOTE(review): extractall trusts the member paths inside the archive;
    # acceptable only as long as DATA_URL points at a trusted source.
    with tarfile.open(filepath, 'r:gz') as archive:
        archive.extractall(data_dir)

In [None]:
import os
import sagemaker
from sagemaker import get_execution_role

# Session object reused by the upload and estimator cells below.
sagemaker_session = sagemaker.Session()

# Execution role passed to the estimator as `role`.
role = get_execution_role()
# Default bucket of the session; used for both the training data and the model output.
bucket = sagemaker_session.default_bucket()

In [None]:
# Download the CIFAR-10 dataset into the default data_dir (/tmp/cifar10_data)
cifar10_download()

In [None]:
# Upload the extracted batches to S3 under the 'cifar10_data' key prefix.
# NOTE: the local path hard-codes cifar10_download's default data_dir
# ('/tmp/cifar10_data'); keep the two in sync if that default is overridden.
sagemaker_session.upload_data(path='/tmp/cifar10_data/cifar-10-batches-py', key_prefix='cifar10_data')

In [None]:
# Training configuration provided by the instructor.
# Container image (Horovod) used for the training job.
training_image = '500842391574.dkr.ecr.us-west-2.amazonaws.com/horovod:latest'
#hosting_image = '<<TO BE TESTED>>'

# Training data channel: the S3 prefix the dataset was uploaded to
channels = {'training': 's3://{}/cifar10_data'.format(bucket)}

# Optimized training parameters
hyperparameters = {
    'learning_rate': 1e-4,
    'epochs': 20,
    'batch_size': 32,
}

# Output of trained model: root of the session's default bucket
output_location = 's3://' + bucket

In [None]:
from sagemaker.estimator import Estimator
# SageMaker estimator built from the custom training image.
# NOTE(review): the `train_instance_count` / `train_instance_type` keyword
# names appear to match SageMaker Python SDK v1 — confirm against the
# installed SDK version before upgrading.
horovod_estimator = Estimator(
    training_image,
    role=role,
    output_path=output_location,
    train_instance_count=2,  # two instances for the training job
    train_instance_type='ml.p3.8xlarge',
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session
)

In [None]:
# Start training, passing the 'training' channel that points at the uploaded dataset
horovod_estimator.fit(channels)