In [1]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import HyperparameterTuner, IntegerParameter, CategoricalParameter, ContinuousParameter

import os

In [2]:
sagemaker_session = sagemaker.Session()

bucket = 'animal-familiars-mnist-test'
prefix = 'sagemaker/pytorch-mnist'

role = sagemaker.get_execution_role()

In [3]:
print(bucket, role)
print(os.getcwd())

animal-familiars-mnist-test arn:aws:iam::974431750608:role/service-role/AmazonSageMaker-ExecutionRole-20210503T160271
/home/ec2-user/SageMaker/ee148b_sagemaker_mnist_public


In [4]:
estimator = PyTorch(entry_point='main.py',
                    source_dir='./',
                    framework_version='1.8.0',
                    role=role,
                    py_version='py3',
                    instance_count=1,
                    instance_type='ml.g4dn.xlarge',
                    hyperparameters={
                        'batch-size': 256,
                    })

In [5]:
# Define hyperparameter tuning
hyperparameter_ranges = {'lr': ContinuousParameter(0.01, 0.2),
                         'epochs': IntegerParameter(10, 50)}

# TODO: should definitely use validation accuracy here, not test accuracy
# But for the sake of trying it out, we use test accuracy here
objective_metric_name = 'Test-Accuracy'
metric_definitions = [{'Name': 'Test-Accuracy',
                       'Regex': 'Test-Accuracy: ([0-9\\.]+)'},
                      {'Name': 'Test-loss',
                       'Regex': 'Test-loss: ([0-9\\.]+)'}]

In [6]:
train_data_location = f's3://{bucket}/'
test_data_location = f's3://{bucket}/'

print('train_data_location: {train_data_location}')
print('test_data_location: {test_data_location}')

train_data_location: {train_data_location}
test_data_location: {test_data_location}


In [None]:
tuner = HyperparameterTuner(estimator,
                            objective_metric_name,
                            hyperparameter_ranges,
                            metric_definitions,
                            max_jobs=3
                           )
tuner.fit({'train': train_data_location, 'test': test_data_location})

....................................................................................................................................................................................................................................................

In [None]:
# Print results
tuner.analytics()