# Training the V1 model

## Setup

In [1]:
# sagemaker
import boto3
import sagemaker
from sagemaker import get_execution_role

# import a PyTorch wrapper
from sagemaker.pytorch import PyTorch

# importing PyTorchModel
from sagemaker.pytorch import PyTorchModel

In [2]:
# IAM execution role (via the `get_execution_role` name imported above)
role = get_execution_role()

# session object used to talk to SageMaker, and the default S3 bucket it provides
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()

## Data

In [4]:
# preview the bucket: print the keys of the first 10 objects, then stop
bucket_objects = boto3.resource('s3').Bucket(bucket).objects.all()
for position, obj in enumerate(bucket_objects):
    if position == 10:
        break
    print(obj.key)

images/train/wario/wario_land_3_img_0.jpg
images/train/wario/wario_land_3_img_10.jpg
images/train/wario/wario_land_3_img_100.jpg
images/train/wario/wario_land_3_img_1000.jpg
images/train/wario/wario_land_3_img_1001.jpg
images/train/wario/wario_land_3_img_1002.jpg
images/train/wario/wario_land_3_img_1003.jpg
images/train/wario/wario_land_3_img_1004.jpg
images/train/wario/wario_land_3_img_1005.jpg
images/train/wario/wario_land_3_img_1006.jpg


## Training

In [None]:
# S3 output location: s3://<default bucket>/model_v1
prefix = 'model_v1'
output_path = 's3://{bucket}/{prefix}'.format(bucket=bucket, prefix=prefix)

# hyperparameters handed to the estimator
hyperparameters = {
    'epochs': 80,  # could change to higher
}

# build the PyTorch estimator: train.py from the model_v1 source directory,
# on a single ml.c4.xlarge instance
estimator = PyTorch(
    entry_point='train.py',
    source_dir='model_v1',
    role=role,
    framework_version='1.0',
    train_instance_count=1,
    train_instance_type='ml.c4.xlarge',
    output_path=output_path,
    sagemaker_session=sagemaker_session,
    hyperparameters=hyperparameters,
)

In [None]:
%%time 
# train the estimator on S3 training data
estimator.fit({'train': input_data})

## Deploying the model for inference

In [None]:



# Wrap the artifacts produced by the trained estimator in a PyTorchModel,
# using predict.py as the entry point.
# NOTE(review): source_dir here is 'source_solution', while training used
# source_dir='model_v1' - confirm predict.py really lives in 'source_solution'.
model = PyTorchModel(
    model_data=estimator.model_data,
    role=role,
    framework_version='1.0',
    entry_point='predict.py',
    source_dir='source_solution',
)

In [None]:
%%time
# deploy and create a predictor
predictor = model.deploy(initial_instance_count=1, instance_type='ml.t2.medium')