In [15]:
import os
from sagemaker.pytorch import PyTorch
from sagemaker import get_execution_role

In [16]:
# Set the S3 location of the input dataset, and the S3 locations where the
# preprocessed data and the training artifacts are written.
# With test_mode=True every project location lives under the .../stoic21/test/
# prefix, so test runs never read from or write to the production locations.
test_mode = True

# Common S3 prefix of all locations this project writes to.
S3_PROJECT_PREFIX = "s3://rumc-stoic-p-sagemaker-luuk-boulogne/stoic21"

if test_mode:
    dataset_location = f"{S3_PROJECT_PREFIX}/test/dataset/"
    preprocessed_location = f"{S3_PROJECT_PREFIX}/test/preprocessed/"
    # Must stay under test/: the training job uses artifact_location as its
    # checkpoint_s3_uri, so sharing the production prefix would mix test and
    # production checkpoints.
    artifact_location = f"{S3_PROJECT_PREFIX}/test/artifact/"
else:
    dataset_location = "s3://stoic2021-training/"
    preprocessed_location = f"{S3_PROJECT_PREFIX}/stoic2021-training/preprocessed/"
    artifact_location = f"{S3_PROJECT_PREFIX}/stoic2021-training/artifact/"


In [None]:
# Do preprocessing
# PyTorch itself is not needed for preprocessing. The PyTorch estimator is
# presumably used only as a ready-made container in which to run
# do_preprocess.py as a SageMaker job.

estimator = PyTorch(
    entry_point='do_preprocess.py',  # a Python file in source_dir
    source_dir=os.getcwd(),  # local path (in SageMaker Studio)
    role=get_execution_role(),
    instance_type='ml.m5.large',
    instance_count=1,
    # volume_size is not passed, so the SageMaker default storage size is used.
    framework_version='1.12',
    py_version='py38',
    max_run=60,  # hard time limit in seconds; the maximum is 5 days
    # Checkpoint sync: the contents of preprocessed_location are copied into
    # checkpoint_local_path when the job starts, and what the job writes there
    # is pushed back to preprocessed_location. The training cell later reads
    # preprocessed_location as its input channel.
    checkpoint_s3_uri=preprocessed_location,
    checkpoint_local_path='/preprocessed/',
)

estimator.fit(
    # Each input channel's S3 location must have contents. On the instance, a
    # channel's local path is exposed in the environment variable
    # SM_CHANNEL_<KEY in upper case>, i.e. SM_CHANNEL_DATASET for 'dataset'.
    inputs={
        'dataset': dataset_location,
    },
)

In [None]:
# Do Training
# Consumes the preprocessed data that the preprocessing job synced to
# preprocessed_location, and syncs its own output to artifact_location.

estimator = PyTorch(
    entry_point='train.py',  # a Python file in source_dir
    source_dir=os.getcwd(),  # local path (in SageMaker Studio)
    role=get_execution_role(),
    instance_type='ml.g4dn.xlarge',
    instance_count=1,
    # Storage volume size in GB. The original note said it is the scratch
    # volume mounted at /tmp/ -- TODO confirm.
    volume_size=5,
    framework_version='1.12',
    py_version='py38',
    max_run=600,  # hard time limit in seconds; the maximum is 5 days
    # Checkpoint sync: the contents of artifact_location are copied into
    # checkpoint_local_path when the job starts, and what the job writes there
    # is pushed back to artifact_location.
    checkpoint_s3_uri=artifact_location,
    checkpoint_local_path='/artifact/',
)

estimator.fit(
    # Each input channel's S3 location must have contents. On the instance, a
    # channel's local path is exposed in the environment variable
    # SM_CHANNEL_<KEY in upper case>, i.e. SM_CHANNEL_PREPROCESSED here.
    inputs={
        'preprocessed': preprocessed_location,
    },
)