This notebook was developed on an `ml.t3.medium` instance with the `Python 3 (Data Science)` kernel in SageMaker Studio.

Import SageMaker SDK and Create a Session

In [None]:
import boto3
import sagemaker
import time
from time import gmtime, strftime

# SageMaker session, plus the region and IAM execution role derived for it
session = sagemaker.Session()
aws_region = session.boto_region_name
role = sagemaker.get_execution_role()

# Project bucket (the session's default bucket) and the S3 key prefixes
# of the original, scaled, and scaled + zipped dataset variants
bucket = session.default_bucket()
dataset_prefix = 'medical-imaging-workshop/dataset'
scaled_dataset_prefix = 'medical-imaging-workshop/scaled_dataset'
scaled_zipped_dataset_prefix = 'medical-imaging-workshop/scaled_zipped_dataset'

Define a SageMaker `PyTorch Estimator`

In [None]:
from sagemaker.pytorch import PyTorch
def get_pytorch_estimator(entry_point, hyperparameters, instance_type,
                          instance_count, output_prefix,
                          dist_training_config=None, volume_size=10,
                          subnets=None, security_group_ids=None):
    """Build (but do not start) a SageMaker ``PyTorch`` estimator.

    Relies on the notebook-level globals ``role``, ``session``, ``bucket``
    and ``metric_def``; they must be defined before this function is called.

    Args:
        entry_point: Training script, resolved against the local ``src``
            source directory.
        hyperparameters: Dict of hyperparameters forwarded to the script.
        instance_type: SageMaker training instance type.
        instance_count: Number of training instances.
        output_prefix: Key prefix inside ``bucket``; both the uploaded source
            code and the job output go to ``s3://<bucket>/<prefix>/output``.
        dist_training_config: Value for the estimator's ``distribution``
            argument, or ``None`` for non-distributed training.
        volume_size: Passed as the estimator's ``volume_size`` (storage
            volume size in GB per the SageMaker SDK).
        subnets: Optional subnets passed through to the estimator.
        security_group_ids: Optional security group IDs passed through to
            the estimator.

    Returns:
        The configured ``PyTorch`` estimator.
    """
    # Code upload location and output location intentionally coincide.
    artifacts_uri = f's3://{bucket}/{output_prefix}/output'

    return PyTorch(
        # Identity and networking
        role=role,
        sagemaker_session=session,
        subnets=subnets,
        security_group_ids=security_group_ids,
        # Training code and framework container
        source_dir='src',
        entry_point=entry_point,
        hyperparameters=hyperparameters,
        py_version='py38',
        framework_version='1.12',
        # Compute resources
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size=volume_size,
        # Persist the metrics described by the notebook-level metric_def
        enable_sagemaker_metrics=True,
        metric_definitions=metric_def,
        # Debugger hook and profiler are switched off for these runs
        debugger_hook_config=False,
        disable_profiler=True,
        distribution=dist_training_config,
        # S3 locations
        code_location=artifacts_uri,
        output_path=artifacts_uri,
        max_run=432000,  # Max runtime of 5 days, in seconds
    )

# Training loop metrics to persist.
# Each entry pairs a metric name with the regex used to pull its value out of
# the training output; the single capture group is the reported value.
metric_def = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in (
        ("train_loss", "train_loss: (.*?)$"),
        ("average_loss", "average loss: (.*?)$"),
        ("mean_dice", "current mean dice: (.*?) "),
        ("time_per_epoch", "secs_time_per_epoch: (.*?)$"),
        ("dice_tc", "tc: (.*?) "),
        ("dice_wt", "wt: (.*?) "),
        ("dice_et", "et: (.*?)$"),
    )
]

### Single GPU Device Experiments - Original Dataset 484 training pairs (4.65 GB)

Run training for three of MONAI's dataset classes:
1. `Dataset`: standard data loading
2. `PersistentDataset`: persist processed data on disk
3. `CacheDataset`: persist processed data in CPU memory

In [None]:
# S3 location of the original (unscaled) BraTS training data
training_data_on_s3 = f"s3://{bucket}/{dataset_prefix}/Task01_BrainTumour"

hyperparameters = {
    'torch_dataset_type': "Dataset",
    'lr': 5e-3,
    'epochs': 10,
    'batch_size': 16,
    'num_workers': 4,
}

# One training job is launched per entry. Add 'PersistentDataset' and/or
# 'CacheDataset' to the tuple to run those experiments as well.
for dataset_type in ('Dataset',):
    hyperparameters["torch_dataset_type"] = dataset_type

    # Instantiate a PyTorch training estimator with a timestamped output prefix
    WORKFLOW_DATE_TIME = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    output_prefix = f"brats_ebs/{WORKFLOW_DATE_TIME}/{dataset_type}/sagemaker"
    pt_estimator = get_pytorch_estimator(
        entry_point='single_gpu_training.py',
        hyperparameters=hyperparameters,
        instance_type='ml.g5.2xlarge',
        instance_count=1,
        output_prefix=output_prefix,
        dist_training_config=None,
        volume_size=100,
    )

    # Launch the training job without blocking the notebook
    pt_estimator.fit(
        job_name=f'monai-1gpu-{dataset_type}-{WORKFLOW_DATE_TIME}',
        inputs={'train': training_data_on_s3},
        wait=False,
    )
    # Short pause between submissions
    time.sleep(1)

#### Results:
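The runs above produce one training job per dataset type listed in the loop: 3 jobs when `Dataset`, `PersistentDataset` and `CacheDataset` are all enabled (as written, only `Dataset` is enabled). Visit the SageMaker training jobs page for details on each. The `CacheDataset` run should be the fastest, followed by `PersistentDataset` and then `Dataset`.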

### Multi GPU Device Experiment using SageMaker distributed training library (data parallel)

In [None]:
# Timestamped S3 prefix under which model artifacts and source code are stored
TRAINING_JOB_DATETIME = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
output_prefix = f"brats_ebs/{TRAINING_JOB_DATETIME}/sagemaker"

# Compute resources: world_size counts 8 devices per instance
instance_type = 'ml.p3.16xlarge'
instance_count = 1
world_size = instance_count * 8
num_vcpu = 64
num_workers = 16

# Network hyperparameters: learning rate and batch size scale with world_size
hyperparameters = {
    'lr': 1e-4 * world_size,
    'batch_size': 4 * world_size,
    'epochs': 10,
    'num_workers': num_workers,
}

# Enable the SageMaker distributed data parallel library
dist_config = {
    'smdistributed': {
        'dataparallel': {'enabled': True},
    },
}

pt_estimator = get_pytorch_estimator(
    entry_point='multi_gpu_training.py',
    hyperparameters=hyperparameters,
    instance_type=instance_type,
    instance_count=instance_count,
    output_prefix=output_prefix,
    dist_training_config=dist_config,
)

In [None]:
# Job name encodes instance count, global batch size, worker count and timestamp
multi_gpu_job_name = (
    f'brats-{instance_count}p316-s3-{4 * world_size}batch-'
    f'{num_workers}worker-{TRAINING_JOB_DATETIME}'
)

# Submit the job and return immediately instead of streaming its logs
pt_estimator.fit(
    job_name=multi_gpu_job_name,
    inputs={'train': training_data_on_s3},
    wait=False,
)

#### Results:
The above run should produce 1 SageMaker training job which uses 8 GPUs in an `ml.p3.16xlarge` instance. Visit the SageMaker training jobs for details.