# Spleen 3D segmentation with MONAI

This tutorial shows how to run SageMaker managed training using MONAI for 3D Segmentation.

This notebook and the `train.py` script in the `source` folder were derived from [this MONAI tutorial notebook](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb).

Key features demonstrated here:
1. SageMaker managed training with the dataset staged in Amazon S3 (an EFS `FileSystemInput` could be used instead)
2. SageMaker Hyperparameter tuning 

The Spleen dataset can be downloaded from https://registry.opendata.aws/msd/.

![spleen](http://medicaldecathlon.com/img/spleen0.png)

Target: Spleen  
Modality: CT  
Size: 61 3D volumes (41 Training + 20 Testing)  
Source: Memorial Sloan Kettering Cancer Center  
Challenge: Large ranging foreground size
    

### Install and import the MONAI libraries

In [None]:
# Install MONAI pinned to version 0.8.0, including all optional dependencies ([all] extra).
!pip install  "monai[all]==0.8.0"

In [None]:
# Install MONAI (the nightly monai-weekly build with extras) only if `import monai` fails.
# NOTE(review): the previous cell already installs monai==0.8.0, so this fallback normally
# does nothing; if it does run, it installs a different MONAI build than the pinned one.
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
# Install matplotlib only if it cannot be imported (it is used for the plots below).
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

import sagemaker 
from sagemaker import get_execution_role


# IAM execution role and SageMaker session shared by the later cells; the AWS
# region and the session's default S3 bucket are both read from that session.
role, sess = get_execution_role(), sagemaker.Session()
region, bucket = sess.boto_session.region_name, sess.default_bucket()

In [None]:
# Display the SageMaker session's default S3 bucket name.
bucket

## Prepare the dataset: Spleen dataset
+ Download the Spleen dataset if it is not available locally
+ Transform the images using Compose 
+ Visualize the image 

In [None]:
# Medical Segmentation Decathlon Task09 (spleen) archive and its expected MD5 checksum.
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = "./Task09_Spleen.tar"

# Local working directory for the Spleen data. The variable name is kept for
# compatibility with existing code, but its value is a filesystem path, not an IP address.
MONAILabelServerIP = "../Spleen3D"
data_dir = MONAILabelServerIP

# Check for the extracted dataset itself rather than for data_dir: data_dir can
# already exist without containing the data (e.g. after the train/test folders
# are created), which would otherwise skip the download entirely.
if not os.path.exists(os.path.join(data_dir, "datasets", "Task09_Spleen")):
    download_and_extract(resource, compressed_file, os.path.join(data_dir, "datasets"), md5)

In [None]:
## transform the images
# Preprocessing pipeline applied to dicts that carry an "image" and a "label" path.
val_transforms = Compose(
    transforms=[
        # read both volumes; the keys list the image first, then the label
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # resample to (1.5, 1.5, 2.0) spacing: bilinear for the image, nearest for the label
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # map image intensities from [-57, 164] to [0, 1], clipping values outside that range
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ],
)

In [None]:
# Pair each training image with its label (assumes the two sorted lists correspond
# one-to-one) and hold out the last 9 cases for this visual check.
train_images = sorted(
    glob.glob(os.path.join(data_dir, "datasets/Task09_Spleen/imagesTr", "*.nii.gz")))
train_labels = sorted(
    glob.glob(os.path.join(data_dir, "datasets/Task09_Spleen/labelsTr", "*.nii.gz")))
# zip() silently truncates on a length mismatch and first() returns None for an
# empty loader, so fail early with a clear message instead of an obscure TypeError.
if not train_images or len(train_images) != len(train_labels):
    raise RuntimeError(
        f"expected matching image/label files under {data_dir}/datasets/Task09_Spleen, "
        f"found {len(train_images)} images and {len(train_labels)} labels"
    )
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

# Run the first held-out case through the preprocessing pipeline.
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
# index [0][0] takes batch item 0 and channel 0, leaving one volume each
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot slice 80 along the last axis, clamped in case the cropped volume has fewer slices
slice_idx = min(80, image.shape[-1] - 1)
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, slice_idx], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, slice_idx])
plt.show()

## Model training 

+ Split the dataset into training and testing sets
+ Upload the dataset into S3 
+ SageMaker training job

In [None]:
from monai.apps import download_and_extract
import os
import glob

# Rebuild the image/label pairs; this time only the last case is held out for testing.
train_images = sorted(glob.glob(
    os.path.join(data_dir, "datasets/Task09_Spleen/imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(
    os.path.join(data_dir, "datasets/Task09_Spleen/labelsTr", "*.nii.gz")))
data_dicts = [
    dict(image=img_path, label=lbl_path)
    for img_path, lbl_path in zip(train_images, train_labels)
]
train_files = data_dicts[:-1]
val_files = data_dicts[-1:]

In [None]:
## copy dataset for training 
## create local folders for the training and testing datasets
## (os.makedirs with exist_ok=True behaves like `mkdir -p` without needing a shell)
for split in ("train", "test"):
    for sub_dir in ("imagesTr", "labelsTr"):
        os.makedirs(os.path.join("../Spleen3D", split, sub_dir), exist_ok=True)

In [None]:
## copy dataset for training
# Copy every training image/label pair into the local train folders
# (../Spleen3D/train is the folder uploaded to S3 for training).
image_dest = "../Spleen3D/train/imagesTr"
label_dest = "../Spleen3D/train/labelsTr"
for file in train_files:
    image, label = file['image'], file['label']
    shutil.copy(image, image_dest)
    shutil.copy(label, label_dest)

In [None]:
## copy dataset for testing
# Copy every held-out image/label pair into the local test folders
# (../Spleen3D/test is the folder uploaded to S3 for testing).
image_dest = "../Spleen3D/test/imagesTr"
label_dest = "../Spleen3D/test/labelsTr"
for file in val_files:
    image, label = file['image'], file['label']
    shutil.copy(image, image_dest)
    shutil.copy(label, label_dest)

In [None]:
## upload the dataset to S3
## upload the dataset to S3
prefix = "MONAI-Segmentation"
bucket = sess.default_bucket()

# upload_data copies a local folder to s3://<bucket>/<key_prefix> and returns that S3 URI.
## training dataset
S3_inputs = sess.upload_data(path="../Spleen3D/train", bucket=bucket, key_prefix=f"{prefix}/train")
## testing dataset
S3_test = sess.upload_data(path="../Spleen3D/test", bucket=bucket, key_prefix=f"{prefix}/test")


In [None]:
# Display the S3 URI returned by upload_data for the training data.
S3_inputs

### SageMaker training job

In [None]:
%time 
import sagemaker
from sagemaker.inputs import FileSystemInput
from sagemaker.pytorch import PyTorch

metrics=[
   {'Name': 'train:average epoch loss', 'Regex': 'average loss: ([0-9\\.]*)'},
   {'Name': 'train:current mean dice', 'Regex': 'current mean dice: ([0-9\\.]*)'},
   {'Name': 'train:best mean dice', 'Regex': 'best mean dice: ([0-9\\.]*)'}
]

estimator = PyTorch(source_dir='source',
                    entry_point='train.py',
                    role=role,
                    framework_version='1.6.0',
                    py_version='py3',
                    instance_count=1,
                    instance_type='ml.p2.xlarge',
                    hyperparameters={
                       "seed": 2,
                       "lr": 0.001,
                       "epochs": 10
                    },
                    metric_definitions=metrics,
#                     ### spot instance training ###
#                    use_spot_instances=True,
#                     max_run=2400,
#                     max_wait=2400
                )


estimator.fit(S3_inputs)

## Inference 

+ Deploy the model with a customized inference script
+ Run inference on the test images stored in S3
+ Visualize the results

In [None]:
# Deploy the trained estimator to a real-time endpoint served by the custom
# source/inference.py; requests are serialized as JSON and responses are
# deserialized into NumPy arrays.
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.2xlarge',
    entry_point='inference.py',
    source_dir='source',
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.NumpyDeserializer(),
)

In [None]:
from sagemaker.pytorch.model import PyTorchModel

# Deploy a previously trained model artifact directly from S3 (no retraining needed).
# NOTE: this URI points at one specific account's bucket; replace it with your own
# artifact, e.g. the value of `estimator.model_data` after the training cell.
model_data = "s3://sagemaker-us-east-1-741261399688/pytorch-training-2022-04-19-02-45-48-341/model.tar.gz"

model = PyTorchModel(
    entry_point="inference.py",  ## inference code with customization
    source_dir="source",         ## folder with the inference code
    role=role,
    model_data=model_data,
    # Serve with the same PyTorch version used for training (framework_version='1.6.0'):
    # checkpoints written by torch.save in 1.6 default to the zipfile format, which
    # PyTorch 1.5 cannot load (assuming train.py uses torch.save defaults).
    framework_version="1.6.0",
    py_version="py3",
)

# Real-time endpoint: JSON requests in, NumPy arrays out.
predictor3 = model.deploy(
    initial_instance_count=1,
    instance_type='ml.m4.2xlarge',
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.NumpyDeserializer(),
)

In [51]:
%time
payload={"bucket": 'sagemaker-us-east-1-741261399688',
    "key":"MONAI-Segmentation/test"}
predictor3.predict(payload)

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 7.15 µs


ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (503) from primary with message "{
  "code": 503,
  "type": "InternalServerException",
  "message": "Unsupported model output data type."
}
". See https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logEventViewer:group=/aws/sagemaker/Endpoints/pytorch-inference-2022-04-19-03-56-38-186 in account 741261399688 for more information.

## Clean up the resources

In [None]:
import boto3
# SageMaker API client used by the clean-up cells.
client = boto3.client(service_name='sagemaker')
# Show the endpoints in the current region. list_endpoints returns a single page
# of results (MaxResults defaults to 10), so this listing may not be exhaustive.
endpoints = client.list_endpoints()['Endpoints']
endpoints

In [None]:
# Delete only the endpoints deployed by this notebook (`predictor` and `predictor3`)
# so that unrelated endpoints in the same account and region are left untouched.
# Deletion goes by name, so it does not depend on the paginated listing above;
# predictor variables that were never created (e.g. a skipped deploy cell) are ignored.
notebook_endpoint_names = {
    globals()[var_name].endpoint_name
    for var_name in ("predictor", "predictor3")
    if var_name in globals()
}
for endpoint_name in sorted(notebook_endpoint_names):
    response = client.delete_endpoint(
        EndpointName=endpoint_name
    )
    print(f"Deleted endpoint: {endpoint_name}")