# SageMaker PyTorch binary segmentation intro

Automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset, each testing a different combination of hyperparameter values. You choose the tunable hyperparameters, a range of values for each, and an objective metric from the metrics that the algorithm computes. The tuner then searches those ranges for the combination of values that yields the model that optimizes the objective metric.
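
To make the idea concrete, here is a minimal sketch of a tuning loop using plain random search. It is an illustration only: SageMaker's tuner defaults to a Bayesian strategy and launches real training jobs, whereas `SEARCH_SPACE`, `sample_config`, `toy_objective`, and `random_search` below are hypothetical names, and the objective is a made-up function standing in for a training run.

```python
import random

# Hypothetical search space: an integer range, a continuous range, a categorical choice.
SEARCH_SPACE = {
    "epochs": ("int", 5, 20),
    "lr": ("float", 1e-4, 1e-3),
    "arch": ("choice", ["FPN", "Unet"]),
}


def sample_config(rng):
    """Draw one value per hyperparameter from its range."""
    config = {}
    for name, spec in SEARCH_SPACE.items():
        kind = spec[0]
        if kind == "int":
            config[name] = rng.randint(spec[1], spec[2])
        elif kind == "float":
            config[name] = rng.uniform(spec[1], spec[2])
        else:
            config[name] = rng.choice(spec[1])
    return config


def toy_objective(config):
    """Made-up stand-in for a training job; returns a score to maximize."""
    arch_bonus = 0.02 if config["arch"] == "FPN" else 0.0
    return 0.85 + arch_bonus + 0.002 * config["epochs"] - 10.0 * config["lr"]


def random_search(n_jobs, seed=0):
    """Run n_jobs trials and return the best (score, config) pair."""
    rng = random.Random(seed)
    trials = []
    for _ in range(n_jobs):
        config = sample_config(rng)
        trials.append((toy_objective(config), config))
    return max(trials, key=lambda trial: trial[0])


best_score, best_config = random_search(n_jobs=24)
print(best_score, best_config)
```

A real tuner differs mainly in how the next configuration is chosen: instead of sampling blindly, it uses the scores of finished jobs to decide where to look next.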

## Introduction

This notebook demonstrates [Segmentation models with pretrained backbones](https://github.com/qubvel/segmentation_models.pytorch), a Python library of neural networks for image segmentation built on [PyTorch](https://pytorch.org/).

The main features of this library are:

- High-level API (just two lines to create a neural network)
- 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 124 available encoders (and 500+ encoders from timm)
- All encoders have pre-trained weights for faster and better convergence
- Popular metrics and losses for training routines

This notebook shows how to use `segmentation-models-pytorch` for **binary** semantic segmentation. We will use [the Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/). The example is adapted from the Albumentations package [docs](https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/), which are strongly recommended reading, especially if you have never used this package for augmentations before.

In [None]:
!pip install sagemaker -U

In [3]:
from sagemaker.pytorch import PyTorch
from datetime import datetime
import sagemaker

sess = sagemaker.Session()

# Define IAM role
import boto3
import re
from sagemaker import get_execution_role

role = get_execution_role()

print(f"Using SageMaker version {sagemaker.__version__}")
print(f"Using boto3 version {boto3.__version__}")

Using SageMaker version 2.135.1.post0
Using boto3 version 1.26.79


In [4]:
training_dataset_s3_path = "s3://aws-ml-blog/artifacts/amazon-sagemaker-binary-segmentation-intro/oxford-pet-dataset"

In [5]:
!aws s3 ls {training_dataset_s3_path}/

                           PRE annotations/
                           PRE images/
2023-03-05 13:24:33   19173078 annotations.tar.gz
2023-03-05 13:25:57  791918971 images.tar.gz


## Launching a training job with the Python SDK

In [6]:
metric_definitions = [
    {'Name': 'test_dataset_iou', 'Regex': 'test_dataset_iou: ([0-9.]+).*$'},
    {'Name': 'test_per_image_iou', 'Regex': 'test_per_image_iou: ([0-9.]+).*$'},
]
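
SageMaker scans the training job's log output with each `Regex` and records the captured group as the value of the metric. The sketch below applies the same two patterns locally with Python's `re` module. The log lines in `sample_log` are hypothetical (the real format depends on what `train.py` prints, which is not shown here), and `extract_metrics` is an illustrative helper, not part of the SageMaker SDK.

```python
import re

metric_definitions = [
    {"Name": "test_dataset_iou", "Regex": "test_dataset_iou: ([0-9.]+).*$"},
    {"Name": "test_per_image_iou", "Regex": "test_per_image_iou: ([0-9.]+).*$"},
]

# Hypothetical log lines; the real ones come from train.py.
sample_log = [
    "epoch 1 finished",
    "test_dataset_iou: 0.9176, test_per_image_iou: 0.9102",
]


def extract_metrics(lines, definitions):
    """Return {metric name: float}, keeping the last match of each regex."""
    found = {}
    for line in lines:
        for definition in definitions:
            match = re.search(definition["Regex"], line)
            if match:
                found[definition["Name"]] = float(match.group(1))
    return found


print(extract_metrics(sample_log, metric_definitions))
```

Testing the patterns like this before launching a job is a cheap way to catch a regex that never matches, which would otherwise only surface as a missing metric.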

In [7]:
estimator = PyTorch(entry_point='train.py',
                        source_dir='./code',
                        role=role,
                        framework_version='1.10',
                        py_version='py38',
                        instance_count=1,
                        instance_type='ml.g5.2xlarge',
                        # keep_alive_period_in_seconds=3600,
                        metric_definitions=metric_definitions,
                        hyperparameters={
                            'epochs': 1,
                            'arch': "DeepLabV3Plus"  # Unet | UnetPlusPlus | FPN | DeepLabV3 | DeepLabV3Plus
                        })
estimator.fit({"training": training_dataset_s3_path}, logs=True)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2023-03-05-18-22-08-076


2023-03-05 18:22:08 Starting - Starting the training job...
2023-03-05 18:22:23 Starting - Preparing the instances for training......
2023-03-05 18:23:24 Downloading - Downloading input data......
2023-03-05 18:24:34 Training - Downloading the training image...............
2023-03-05 18:26:45 Training - Training image download completed. Training in progress...
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2023-03-05 18:27:17,799 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2023-03-05 18:27:17,820 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2023-03-05 18:27:17,822 sagemaker_pytorch_container.training INFO     Invoking user training script.
2023-03-05 18:27:17,980 sagemaker-training-toolkit INFO     Installing dependencies from requirements.txt:
/opt/conda/bin/python3.8 -m pi
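
In script mode, the entries of the `hyperparameters` dictionary reach the entry point as command-line flags, and the `training` channel passed to `fit()` is exposed through the `SM_CHANNEL_TRAINING` environment variable. The actual `train.py` in `./code` is not shown in this notebook, so the following is only a sketch of how such a script might read its inputs. The `parse_args` helper, the `--lr` flag, the default values, and the local fallback paths are assumptions.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and data locations as a script-mode entry point might."""
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as flags, e.g. --epochs 1 --arch DeepLabV3Plus
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)    # assumed default
    parser.add_argument("--arch", type=str, default="Unet")  # assumed default
    # The "training" channel is mounted at the path held in SM_CHANNEL_TRAINING;
    # the fallbacks below are hypothetical paths for running outside SageMaker.
    parser.add_argument("--data-dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAINING", "./data"))
    parser.add_argument("--model-dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "./model"))
    return parser.parse_args(argv)


args = parse_args(["--epochs", "1", "--arch", "DeepLabV3Plus"])
print(args.epochs, args.arch, args.lr)
```

Reading locations from environment variables with local fallbacks lets the same script run unchanged on a laptop and inside the training container.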

## Launching a tuning job with the Python SDK

In [8]:
hpo_estimator = PyTorch(entry_point='train.py',
                        source_dir='./code',
                        role=role,
                        framework_version='1.10',
                        py_version='py38',
                        instance_count=1,
                        instance_type='ml.g5.2xlarge',
                        metric_definitions=metric_definitions
                        )

In [9]:
from sagemaker.tuner import (
    IntegerParameter,
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
)

hyperparameter_ranges = {
    "epochs": IntegerParameter(5, 20),
    "lr": ContinuousParameter(1e-4, 1e-3),
    "arch": CategoricalParameter(["FPN", 
                                  "DeepLabV3", 
                                  "DeepLabV3Plus", 
                                  "Unet", 
                                  "UnetPlusPlus", 
                                  "Linknet", 
                                  "PSPNet", 
                                  "PAN"])
}

In [10]:
objective_metric_name = "test_dataset_iou"
objective_type = "Maximize"
hpo_metric_definitions = [
    {'Name': 'test_dataset_iou', 'Regex': 'test_dataset_iou: ([0-9.]+).*$'},
]
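
The objective is an intersection-over-union (IoU) score, which for a binary mask is TP / (TP + FP + FN). The sketch below computes it in plain Python and contrasts two ways of aggregating over a test set. It assumes, based only on the metric names, that `test_dataset_iou` pools pixel counts across all images while `test_per_image_iou` averages one score per image; `train.py` is not shown here, so treat the helper functions and the tiny masks as illustrative.

```python
def confusion_counts(pred, target):
    """Count true positives, false positives and false negatives for flat 0/1 masks."""
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    return tp, fp, fn


def iou(tp, fp, fn):
    """Intersection over union: TP / (TP + FP + FN)."""
    union = tp + fp + fn
    return tp / union if union else 1.0  # convention chosen here for two empty masks


def per_image_iou(preds, targets):
    """Average of the IoU scores computed separately for each image."""
    scores = [iou(*confusion_counts(p, t)) for p, t in zip(preds, targets)]
    return sum(scores) / len(scores)


def dataset_iou(preds, targets):
    """IoU computed once from counts pooled over every image."""
    totals = [0, 0, 0]
    for p, t in zip(preds, targets):
        for i, count in enumerate(confusion_counts(p, t)):
            totals[i] += count
    return iou(*totals)


# Two tiny flattened masks (1 = pet pixel, 0 = background), made up for illustration.
preds = [[1, 1, 0, 0], [1, 0, 0, 0]]
targets = [[1, 0, 0, 0], [1, 1, 1, 0]]

print(per_image_iou(preds, targets), dataset_iou(preds, targets))
```

The two aggregates generally differ because pooling weights images by how many foreground pixels they contain, while the per-image average weights every image equally.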

In [11]:
tuner = HyperparameterTuner(
    hpo_estimator,
    objective_metric_name,
    hyperparameter_ranges,
    hpo_metric_definitions,
    max_jobs=24,
    max_parallel_jobs=4,
    objective_type=objective_type,
)

In [None]:
tuner.fit({"training": training_dataset_s3_path})

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating hyperparameter tuning job with name: pytorch-training-230305-1828


......................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

## Get tuner results in a dataframe

In [21]:
import pandas as pd

full_df = tuner.analytics().dataframe()
full_df.head()

| | arch | epochs | lr | TrainingJobName | TrainingJobStatus | FinalObjectiveValue | TrainingStartTime | TrainingEndTime | TrainingElapsedTimeSeconds |
|---|---|---|---|---|---|---|---|---|---|
| 0 | "DeepLabV3Plus" | 13.0 | 0.0001 | pytorch-training-230305-1828-024-3e231ad7 | Completed | 0.910989 | 2023-03-05 19:13:10+00:00 | 2023-03-05 19:19:16+00:00 | 366.0 |
| 1 | "UnetPlusPlus" | 20.0 | 0.0001 | pytorch-training-230305-1828-023-8f3767c7 | Completed | 0.914339 | 2023-03-05 19:09:44+00:00 | 2023-03-05 19:26:09+00:00 | 985.0 |
| 2 | "Linknet" | 16.0 | 0.0001 | pytorch-training-230305-1828-022-29dc6377 | Completed | 0.907217 | 2023-03-05 19:06:49+00:00 | 2023-03-05 19:13:52+00:00 | 423.0 |
| 3 | "FPN" | 15.0 | 0.000107 | pytorch-training-230305-1828-021-5e508084 | Completed | 0.898145 | 2023-03-05 19:07:34+00:00 | 2023-03-05 19:14:17+00:00 | 403.0 |
| 4 | "PAN" | 17.0 | 0.000163 | pytorch-training-230305-1828-020-b19ae976 | Completed | 0.894409 | 2023-03-05 19:05:34+00:00 | 2023-03-05 19:12:27+00:00 | 413.0 |


In [22]:
# Start from an empty frame so `df` is defined even if no jobs have run yet
df = pd.DataFrame()
if len(full_df) > 0:
    # Drop jobs whose objective is missing (NaN) or -inf
    df = full_df[full_df["FinalObjectiveValue"] > -float("inf")]
    if len(df) > 0:
        df = df.sort_values("FinalObjectiveValue", ascending=False)
        print("Number of training jobs with valid objective: %d" % len(df))
        print({"lowest": min(df["FinalObjectiveValue"]), "highest": max(df["FinalObjectiveValue"])})
        pd.set_option("display.max_colwidth", None)  # None = don't truncate TrainingJobName
    else:
        print("No training jobs have reported valid results yet.")

df

Number of training jobs with valid objective: 24
{'lowest': 0.857784628868103, 'highest': 0.9176449179649353}



| | arch | epochs | lr | TrainingJobName | TrainingJobStatus | FinalObjectiveValue | TrainingStartTime | TrainingEndTime | TrainingElapsedTimeSeconds |
|---|---|---|---|---|---|---|---|---|---|
| 23 | "FPN" | 14.0 | 0.000132 | pytorch-training-230305-1828-001-73f304c4 | Completed | 0.917645 | 2023-03-05 18:30:17+00:00 | 2023-03-05 18:39:37+00:00 | 560.0 |
| 5 | "FPN" | 16.0 | 0.0001 | pytorch-training-230305-1828-019-04bb80bf | Completed | 0.917006 | 2023-03-05 19:02:09+00:00 | 2023-03-05 19:09:23+00:00 | 434.0 |
| 10 | "FPN" | 5.0 | 0.0001 | pytorch-training-230305-1828-014-4a53958c | Completed | 0.916254 | 2023-03-05 18:52:48+00:00 | 2023-03-05 18:55:55+00:00 | 187.0 |
| 1 | "UnetPlusPlus" | 20.0 | 0.0001 | pytorch-training-230305-1828-023-8f3767c7 | Completed | 0.914339 | 2023-03-05 19:09:44+00:00 | 2023-03-05 19:26:09+00:00 | 985.0 |
| 8 | "FPN" | 19.0 | 0.000103 | pytorch-training-230305-1828-016-03435183 | Completed | 0.913673 | 2023-03-05 18:56:40+00:00 | 2023-03-05 19:04:58+00:00 | 498.0 |
| 11 | "FPN" | 20.0 | 0.000208 | pytorch-training-230305-1828-013-3e64cce0 | Completed | 0.912859 | 2023-03-05 18:52:05+00:00 | 2023-03-05 19:00:44+00:00 | 519.0 |
| 15 | "DeepLabV3" | 10.0 | 0.0001 | pytorch-training-230305-1828-009-6200c769 | Completed | 0.911724 | 2023-03-05 18:43:14+00:00 | 2023-03-05 18:51:28+00:00 | 494.0 |
| 0 | "DeepLabV3Plus" | 13.0 | 0.0001 | pytorch-training-230305-1828-024-3e231ad7 | Completed | 0.910989 | 2023-03-05 19:13:10+00:00 | 2023-03-05 19:19:16+00:00 | 366.0 |
| 7 | "FPN" | 20.0 | 0.000138 | pytorch-training-230305-1828-017-41b36ff0 | Completed | 0.907232 | 2023-03-05 18:57:58+00:00 | 2023-03-05 19:06:36+00:00 | 518.0 |
| 2 | "Linknet" | 16.0 | 0.0001 | pytorch-training-230305-1828-022-29dc6377 | Completed | 0.907217 | 2023-03-05 19:06:49+00:00 | 2023-03-05 19:13:52+00:00 | 423.0 |
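
One way to read a result table like this is to summarize it per architecture. The sketch below does that in plain Python for the ten rows displayed above only, not for all 24 jobs, so the counts describe the top of the ranking rather than the whole search. The `rows` list is copied by hand from the `arch` and `FinalObjectiveValue` columns, with the surrounding quotes dropped from the architecture names.

```python
from collections import defaultdict

# (arch, FinalObjectiveValue) copied from the ten rows displayed above.
rows = [
    ("FPN", 0.917645),
    ("FPN", 0.917006),
    ("FPN", 0.916254),
    ("UnetPlusPlus", 0.914339),
    ("FPN", 0.913673),
    ("FPN", 0.912859),
    ("DeepLabV3", 0.911724),
    ("DeepLabV3Plus", 0.910989),
    ("FPN", 0.907232),
    ("Linknet", 0.907217),
]

best_by_arch = defaultdict(float)  # best objective seen per architecture
counts = defaultdict(int)          # how many of the displayed rows use it
for arch, score in rows:
    best_by_arch[arch] = max(best_by_arch[arch], score)
    counts[arch] += 1

# Print architectures from best to worst objective value.
for arch, score in sorted(best_by_arch.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{arch:14s} best={score:.6f} rows={counts[arch]}")
```

With the full dataframe available, `df.groupby("arch")["FinalObjectiveValue"].max()` would give the per-architecture best directly.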
