# Training on AWS SageMaker

To train a model on AWS SageMaker, the model needs to be Dockerized and the Docker image needs to be deployed on an appropriate instance.

The dataset is uploaded to an AWS S3 bucket, and the SageMaker service needs to be given the path to the data in that bucket.

The models will be trained on a GPU, using an ml.g4dn.xlarge instance.

To obtain the required permissions, a new IAM role with permission to execute the SageMaker service needs to be created.
Additionally, the AWS access key ID and the AWS secret access key need to be added to the `~/.aws/credentials` file.
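
Before running the cells in this notebook, it can help to confirm that the credentials file contains the two keys the AWS SDKs look for. The following is a minimal sketch using only the standard library. The helper name `missing_credential_keys` and the use of the `default` profile are assumptions for illustration and are not part of the original setup.

```python
import configparser
import os

# Key names the AWS SDKs read from a profile in the credentials file.
REQUIRED_KEYS = ("aws_access_key_id", "aws_secret_access_key")


def missing_credential_keys(path, profile="default"):
    """Return the required keys that are absent from the given profile.

    Hypothetical helper: profile="default" is an assumption; pass the
    profile name your own setup relies on.
    """
    parser = configparser.ConfigParser()
    parser.read(path)  # a missing file simply yields no sections
    if not parser.has_section(profile):
        return list(REQUIRED_KEYS)
    return [key for key in REQUIRED_KEYS if not parser.has_option(profile, key)]


if __name__ == "__main__":
    credentials_file = os.path.expanduser("~/.aws/credentials")
    missing = missing_credential_keys(credentials_file)
    print("missing keys:", missing or "none")
```

The check only looks at whether the keys are present; it does not verify that they are valid or that the role has the right permissions.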

In [1]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.image_uris import retrieve 

# Set up S3 bucket and paths to input data and model artifacts
bucket_name = ...
prefix = "sagemaker/SKU110K"

training_data_path = f's3://{bucket_name}/{prefix}'
model_artifacts_path = f's3://{bucket_name}/model-artifacts/'

# Set up SageMaker session, role, and instance types
sagemaker_session = sagemaker.Session()
train_instance_type = 'ml.g4dn.xlarge'

role = ...

The following cell retrieves the URI of the Docker image used for training the model.

In [None]:

training_image = retrieve(
    "pytorch",
    region="eu-central-1",
    version="1.13.1",
    py_version="py39",
    instance_type=train_instance_type,
    image_scope="training",
)
print(training_image)


The following cell starts the training job and streams all logs related to it.

In [None]:
# Define the training job
estimator = PyTorch(
    image_uri=training_image,
    source_dir="code",
    entry_point="train.py",
    role=role,
    py_version="py39",
    framework_version="1.13.1",
    instance_count=1,
    instance_type=train_instance_type,
    output_path=model_artifacts_path,
    sagemaker_session=sagemaker_session,
    hyperparameters={
        'epochs': 10,
        'batch-size': 2,
        'model': 'Faster_RCNN',
        'sagemaker': True,
    },
)
# Start the training job; both channels read from the same S3 prefix
estimator.fit(
    {'train': training_data_path, 'test': training_data_path},
    logs="All",
)
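
On the script side, SageMaker script mode passes each entry of `hyperparameters` to the entry point as a command-line argument (for example `--batch-size 2`) and exposes each channel given to `fit` through an environment variable such as `SM_CHANNEL_TRAIN`. The original `train.py` is not shown here, so the following is only a sketch of how such a script might read these values. The function name `parse_args`, the `str_to_bool` helper, the `--train-dir`, `--test-dir`, and `--model-dir` options, and all defaults are assumptions.

```python
import argparse
import os


def str_to_bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool."""
    return str(value).strip().lower() in ("true", "1", "yes")


def parse_args(argv=None):
    """Hypothetical argument parser for a train.py entry point."""
    parser = argparse.ArgumentParser()
    # Hyperparameters from the estimator arrive as "--name value" pairs.
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--model", type=str, default="Faster_RCNN")
    # Booleans arrive as text, so they need an explicit conversion.
    parser.add_argument("--sagemaker", type=str_to_bool, default=False)
    # Inside the container, SageMaker sets these environment variables to
    # the local directories for each channel and for the model output.
    parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test-dir", default=os.environ.get("SM_CHANNEL_TEST"))
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR"))
    # parse_known_args tolerates arguments this sketch does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return args
```

Outside SageMaker the environment variables are unset, so the directory options default to `None` and can be supplied on the command line for local runs.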