# Train an object detection model using Ground Truth labels
At this stage, you have fully labeled your dataset and you can train a machine learning model to perform object detection. You'll do so using the **augmented manifest** output of your labeling job - no additional file translation or manipulation required! For a more complete description of the augmented manifest, see our other [example notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/ground_truth_labeling_jobs/object_detection_augmented_manifest_training/object_detection_augmented_manifest_training.ipynb).
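For orientation, here is a minimal sketch of one line of an augmented manifest for a bounding-box job, and of how to extract the boxes from it. The record is simplified and hypothetical. `bird-labels` stands in for your label attribute name (this notebook uses `DATA_USER`). The `image_size` and metadata fields are illustrative, so check them against your own `output.manifest`.

```python
import json

# Hypothetical augmented-manifest line; 'bird-labels' stands in for your label attribute.
line = json.dumps({
    "source-ref": "s3://example-bucket/images/bird-001.jpg",
    "bird-labels": {
        "image_size": [{"width": 500, "height": 375, "depth": 3}],
        "annotations": [
            {"class_id": 0, "left": 120, "top": 80, "width": 150, "height": 110}
        ],
    },
    "bird-labels-metadata": {"class-map": {"0": "bird"}, "type": "groundtruth/object-detection"},
})


def boxes_from_manifest_line(line, label_attribute):
    """Return the image URI and a list of (class_id, left, top, width, height) pixel boxes."""
    record = json.loads(line)
    labels = record.get(label_attribute, {})
    return record["source-ref"], [
        (a["class_id"], a["left"], a["top"], a["width"], a["height"])
        for a in labels.get("annotations", [])
    ]


print(boxes_from_manifest_line(line, "bird-labels"))
# ('s3://example-bucket/images/bird-001.jpg', [(0, 120, 80, 150, 110)])
```

Each line is a self-contained JSON object. This is why the manifest can be shuffled, split and uploaded line by line without any format conversion.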

**NOTE:** Object detection is a complex task, and training neural networks to high accuracy requires large datasets and careful hyperparameter tuning. The following cells illustrate how to train a neural network using a Ground Truth output augmented manifest, and how to interpret the results. However, you shouldn't expect a network trained on 10 or 1000 images to do a great job on unseen images!

## Settings
Be sure to set `USER` and `BUCKET` in the cell below to your user name and lab bucket name.

In [None]:
USER = 'your-user-name'     # replace with your user name, e.g. 'whitefish'
BUCKET = 'your-lab-bucket'  # replace with the lab bucket name, e.g. 'grr.amazon.com-lab'

# these settings let you use example data in case your previous lab was not complete
DATA_BUCKET = BUCKET # 'grr.amazon.com-lab' 
DATA_USER = 'example'
S3_BASE_OVERRIDE = None #'s3://grr.amazon.com-lab/labs/groundtruth/output/all-bird-labels'
EPOCHS = 1
MINIBATCH = 1

Now import the required libraries. The `%matplotlib inline` magic command renders plots directly in the notebook.

In [None]:
%matplotlib inline

import json, boto3, sagemaker, re, time, sys, glob, os
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker import get_execution_role
import numpy as np
from time import gmtime, strftime
import importlib, helper
importlib.reload(helper)
from helper import training_status, visualize_detection

## Prepare Data
First, split your augmented manifest into a training set and a validation set using an 80/20 split, and save the results to files that the model will use during training. To do this, read the output manifest to get the list of all labeled images, shuffle that list, and divide it at the 80% mark.
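The shuffle-and-split logic can be sketched on its own, independent of S3. This is a minimal sketch that assumes a fixed seed through Python's `random.Random`. The notebook's cell calls `np.random.shuffle` without a seed, so its split changes on every run. `split_manifest` is a hypothetical helper name.

```python
import json
import random


def split_manifest(lines, train_fraction=0.8, seed=0):
    """Parse manifest lines, shuffle them reproducibly, and split into train/validation."""
    records = [json.loads(line) for line in lines if line.strip()]
    rng = random.Random(seed)  # seeded so the same split is produced every run
    rng.shuffle(records)
    split_index = round(len(records) * train_fraction)
    return records[:split_index], records[split_index:]


# Ten hypothetical manifest lines.
lines = [json.dumps({"source-ref": f"s3://example-bucket/images/{i}.jpg"}) for i in range(10)]
train, validation = split_manifest(lines)
print(len(train), len(validation))  # 8 2
```

Seeding keeps the validation set stable across reruns. Otherwise mAP values from different training jobs are not measured on the same images.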

In [None]:
if S3_BASE_OVERRIDE:
    S3_BASE = S3_BASE_OVERRIDE  
else:
    S3_BASE =  's3://{}/labs/groundtruth/output/{}'.format(DATA_BUCKET, DATA_USER)
    
print(S3_BASE)

In [None]:
# Download the output manifest and parse one JSON record per line.
output_manifest = S3_BASE + '/manifests/output/output.manifest'
!aws s3 cp {output_manifest} temp/output.manifest
with open('temp/output.manifest', 'r') as f:
    output = [json.loads(line) for line in f if line.strip()]
print('read output manifest ' + output_manifest)

# Retrieve the worker annotations.
worker_response = S3_BASE + '/annotations/worker-response'
!aws s3 cp {worker_response} temp/od_output_data/worker-response --recursive --quiet

# Find the worker files.
worker_file_names = glob.glob(
    'temp/od_output_data/worker-response/**/*.json', recursive=True)

# Shuffle output in place.
np.random.shuffle(output)
    
dataset_size = len(output)
train_test_split_index = round(dataset_size*0.8)

train_data = output[:train_test_split_index]
validation_data = output[train_test_split_index:]

num_training_samples = len(train_data)

# Write each split as JSON Lines: one augmented-manifest record per line.
with open('temp/train.manifest', 'w') as f:
    for line in train_data:
        f.write(json.dumps(line) + '\n')
print('created training manifest')

with open('temp/validation.manifest', 'w') as f:
    for line in validation_data:
        f.write(json.dumps(line) + '\n')
print('created validation manifest')

Next, upload these manifest files to your lab bucket (`BUCKET`, defined in the settings) so that the training job can use them.

In [None]:
s3_prefix = 'labs/objectdetect/{}'.format(USER)
s3_base_path = 's3://{}/{}'.format(BUCKET, s3_prefix)

s3_output_path = s3_base_path + '/output'
s3_train_data_path = s3_base_path + '/train.manifest'
s3_validation_data_path = s3_base_path + '/validation.manifest'

!aws s3 cp temp/train.manifest {s3_train_data_path}
!aws s3 cp temp/validation.manifest {s3_validation_data_path}
print('uploaded manifests to s3')

## Initial Training
Now that the setup is complete, you are ready to train your object detector. To begin, create a `sagemaker.estimator.Estimator` object. This estimator launches the training job.

In [None]:
# get required objects and parameters
region = boto3.Session().region_name
role = get_execution_role()
training_image = get_image_uri(region, 'object-detection', repo_version='latest')
session = sagemaker.Session()

# create an estimator
estimator = sagemaker.estimator.Estimator(
    training_image,
    role,
    train_instance_count=1,
    train_instance_type='ml.p2.xlarge',
    train_volume_size=50,
    train_max_run=360000,
    input_mode='Pipe',
    output_path=s3_output_path,
    sagemaker_session = session
)
print('created estimator')

At its core, the object detection algorithm is the Single Shot MultiBox Detector (SSD). SSD uses a base network, typically VGG or ResNet. The Amazon SageMaker object detection algorithm currently supports VGG-16 and ResNet-50. The algorithm also exposes many hyperparameters that configure the training job. The next step is to set up these hyperparameters and the data channels for training the model. The cell below shows an example hyperparameter definition. See the SageMaker Object Detection documentation for details on each hyperparameter.

One of these hyperparameters is `epochs`, the number of passes the algorithm makes over the training set. It largely determines the training time. In this example, you train the model for `EPOCHS` epochs (1, as set in the Settings cell) to produce a basic bird detector.
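To see how `epochs`, `mini_batch_size` and `num_training_samples` interact, here is a short arithmetic sketch. It assumes that a final partial mini-batch still counts as one step. The algorithm may handle the last batch differently, so treat the numbers as estimates.

```python
import math


def steps_per_epoch(num_training_samples, mini_batch_size):
    """Number of mini-batch updates in one pass over the training set."""
    return math.ceil(num_training_samples / mini_batch_size)


def total_steps(num_training_samples, mini_batch_size, epochs):
    """Total mini-batch updates over the whole training job."""
    return steps_per_epoch(num_training_samples, mini_batch_size) * epochs


print(steps_per_epoch(80, 1))  # 80
print(total_steps(80, 16, 5))  # 25
```

With `MINIBATCH = 1` every image is its own update. This is why larger mini-batches usually make each epoch faster on a GPU.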

In [None]:
# NB. These hyperparameters are at the user's discretion and are beyond the scope of this demo.
hyperparameters = {  
            "base_network": "resnet-50",
            "use_pretrained_model": "1",
            "num_classes": "1",
            "mini_batch_size": MINIBATCH,
            "epochs": EPOCHS,
            "learning_rate": "0.001",
            "lr_scheduler_step": "",
            "lr_scheduler_factor": "0.1",
            "optimizer": "sgd",
            "momentum": "0.9",
            "weight_decay": "0.0005",
            "overlap_threshold": "0.5",
            "nms_threshold": "0.45",
            "image_shape": "300",
            "label_width": "350",
            "num_training_samples": str(num_training_samples)
        }

estimator.set_hyperparameters(**hyperparameters)
print('set hyperparameters')
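Two of these hyperparameters are thresholds on intersection over union (IoU), the overlap ratio between two boxes. `overlap_threshold` is the IoU a prediction needs to count as a match during evaluation. `nms_threshold` is the IoU above which non-maximum suppression (NMS) discards a lower-scoring duplicate box. The following is a minimal sketch of IoU and plain greedy NMS, not SageMaker's internal implementation, which may differ in details such as per-class handling.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, threshold=0.45):
    """Greedy NMS: keep boxes in descending score order, dropping any that overlap
    an already-kept box by more than `threshold` IoU. Returns the kept indices."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= threshold for j in keep):
            keep.append(i)
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(round(iou(boxes[0], boxes[1]), 3))  # 0.681
print(nms(boxes, scores))                 # [0, 2]
```

The second box overlaps the first with IoU of about 0.68, which exceeds 0.45, so NMS suppresses it. The third box does not overlap either and is kept.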

Now that the hyperparameters are set up, connect your data channels to the algorithm. To do this, create `sagemaker.session.s3_input` objects from your data channels and put them in a dictionary, which the algorithm consumes. Each channel reads an augmented manifest. `attribute_names` selects which fields of each JSON line are passed to the algorithm: the image reference and the label attribute.

In [None]:
def create_data_input(uri):
    return sagemaker.session.s3_input(
        uri, 
        s3_data_type='AugmentedManifestFile',
        distribution='FullyReplicated', 
        content_type='application/x-recordio',
        record_wrapping='RecordIO',
        attribute_names = ['source-ref', DATA_USER]
    )

data_channels = {
    'train': create_data_input(s3_train_data_path), 
    'validation': create_data_input(s3_validation_data_path)
}
print('created data channels')
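A training job fails if manifest lines lack any field named in `attribute_names`, so checking the manifests locally before launching can save a wasted job. This is a minimal sketch with a hypothetical helper. In this notebook you would call it as `check_manifest_attributes('temp/train.manifest', ['source-ref', DATA_USER])`. The example below uses a temporary file and a placeholder `labels` attribute.

```python
import json
import os
import tempfile


def check_manifest_attributes(path, attribute_names):
    """Return (line_number, missing_attributes) for every manifest line lacking a required key."""
    problems = []
    with open(path) as f:
        for number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            record = json.loads(line)
            missing = [name for name in attribute_names if name not in record]
            if missing:
                problems.append((number, missing))
    return problems


# Example: the second line has no 'labels' attribute.
with tempfile.NamedTemporaryFile('w', suffix='.manifest', delete=False) as tmp:
    tmp.write(json.dumps({"source-ref": "s3://b/1.jpg", "labels": {}}) + "\n")
    tmp.write(json.dumps({"source-ref": "s3://b/2.jpg"}) + "\n")
problems = check_manifest_attributes(tmp.name, ["source-ref", "labels"])
os.remove(tmp.name)
print(problems)  # [(2, ['labels'])]
```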

You have now created the Estimator, set its hyperparameters, and linked the data channels to the algorithm. The only remaining step is to train the algorithm, which the following cell does. Training involves a few steps:

- The instances you requested when creating the Estimator are provisioned and set up with the appropriate libraries.
- The data from your channels is made available to the instances. Because the estimator uses `input_mode='Pipe'`, the data is streamed from S3 rather than copied up front.
- The training job begins.

Provisioning and data loading take time, depending on the size of the data, so it might be a few minutes before the job starts producing logs. The logs report the mean average precision (mAP) on the validation data, among other losses, after every epoch (one full pass over the dataset). This metric is a proxy for the quality of the model.

Because the job is launched with `wait=False`, the cell returns immediately and the logs are not streamed into the notebook. You can view them in Amazon CloudWatch Logs or in the SageMaker console. Once the job has finished, the trained model artifact is saved under the `output_path` S3 location set in the estimator.
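To make the reported metric less opaque, here is a minimal sketch of average precision (AP) for one class. mAP averages AP over classes. With `num_classes` set to 1, mAP here equals the AP of the bird class. The sketch assumes you already know whether each detection is a true positive, typically meaning its IoU with a not-yet-matched ground-truth box is at least `overlap_threshold`. It uses all-point interpolation. The exact variant SageMaker uses (for example 11-point interpolation) may differ.

```python
def average_precision(detections, num_ground_truth):
    """All-point interpolated AP.

    detections: list of (confidence, is_true_positive) pairs for a single class.
    num_ground_truth: number of ground-truth boxes of that class.
    """
    if num_ground_truth == 0:
        return 0.0
    ranked = sorted(detections, key=lambda d: d[0], reverse=True)
    tp = fp = 0
    recalls, precisions = [], []
    for _, is_tp in ranked:
        if is_tp:
            tp += 1
        else:
            fp += 1
        recalls.append(tp / num_ground_truth)
        precisions.append(tp / (tp + fp))
    # Precision envelope: make precision non-increasing as recall grows.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    # Sum precision over each recall increment (area under the envelope).
    ap, prev_recall = 0.0, 0.0
    for r, p in zip(recalls, precisions):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


print(average_precision([(0.9, True), (0.8, False), (0.7, True)], num_ground_truth=2))
```

In this example the false positive ranked second lowers precision for the remaining recall, giving an AP of 5/6 rather than 1.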

In [None]:
# create unique job name
job_name = USER + '-od-' + time.strftime('%Y%m%d%H%M%S', time.gmtime())
print('using job name '+job_name)
# start training
estimator.fit(inputs=data_channels, wait=False, logs=True, job_name=job_name)
print('training job launched')
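The cell above builds the job name directly from `USER`. SageMaker training job names must be at most 63 characters and may contain only letters, digits and hyphens. A user name with a dot or underscore would therefore make `fit` fail. This minimal sketch (hypothetical helper `make_job_name`) sanitizes the user name first.

```python
import re
import time

# Allowed SageMaker training job name characters; length is checked separately (max 63).
JOB_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$')


def make_job_name(user, suffix='od', timestamp=None):
    """Build '<user>-<suffix>-<timestamp>' with invalid characters replaced by hyphens."""
    if timestamp is None:
        timestamp = time.strftime('%Y%m%d%H%M%S', time.gmtime())
    safe_user = re.sub(r'[^a-zA-Z0-9-]', '-', user).strip('-') or 'user'
    tail = '-{}-{}'.format(suffix, timestamp)
    # Truncate the user part so the whole name fits in 63 characters.
    name = safe_user[:63 - len(tail)].rstrip('-') + tail
    if len(name) > 63 or not JOB_NAME_PATTERN.match(name):
        raise ValueError('invalid job name: ' + name)
    return name


print(make_job_name('white.fish', timestamp='20240101120000'))
# white-fish-od-20240101120000
```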

To check the progress of the training job, rerun the following cell as often as needed. When the training job status reads `'Completed'`, move on to the next part of the tutorial.

In [None]:
training_status(job_name)
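Instead of rerunning the status cell by hand, you can poll in a loop. This is a minimal sketch of a generic poller that takes the status lookup as a function, so it runs without AWS access. The commented line shows how it could be wired to boto3's `describe_training_job`. `wait_for_status` is a hypothetical helper, not part of the lab's `helper` module.

```python
import time

TERMINAL_STATES = {'Completed', 'Failed', 'Stopped'}


def wait_for_status(get_status, poll_seconds=60, timeout_seconds=3600, sleep=time.sleep):
    """Call get_status() until it returns a terminal state or the timeout expires."""
    waited = 0
    while True:
        status = get_status()
        if status in TERMINAL_STATES:
            return status
        if waited >= timeout_seconds:
            raise TimeoutError('still {} after {} s'.format(status, waited))
        sleep(poll_seconds)
        waited += poll_seconds


# With SageMaker (requires AWS credentials):
# sm = boto3.client('sagemaker')
# wait_for_status(lambda: sm.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus'])

# Offline example with a fake sequence of statuses:
statuses = iter(['InProgress', 'InProgress', 'Completed'])
final_status = wait_for_status(lambda: next(statuses), poll_seconds=0)
print(final_status)  # Completed
```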

## Hosting

Once training is done, you can deploy the trained model as an Amazon SageMaker real-time hosted endpoint, which lets you make predictions (inference) with the model. You don't have to host on the same instance type that you used to train. Training is a long-running, compute-intensive job whose compute and memory requirements differ from those of hosting, so you can host the model on any instance type. In this case you trained on a GPU instance, `ml.p2.xlarge`, and you host the model on a less expensive CPU instance, `ml.m4.xlarge`. Deploy the endpoint as follows:

In [None]:
object_detector = estimator.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')

## Inference

Now that the trained model is deployed to a running endpoint, you can use the endpoint for inference. First, download the images in the validation set you created earlier and collect their ground-truth bounding boxes.

In [None]:
# Download each validation image and collect its ground-truth boxes.
os.makedirs('od_output_data', exist_ok=True)
file_names = []
gt_bboxes = []
s3 = boto3.resource('s3')
for file_data in validation_data:
    file_path = file_data['source-ref']
    # Split 's3://bucket/key' into its bucket and key parts.
    src_bucket, key = file_path[len('s3://'):].split('/', 1)
    dest = 'od_output_data/' + file_path.split('/')[-1]

    file_names.append(dest)
    s3.Bucket(src_bucket).download_file(key, dest)

    # Always append (possibly empty) so gt_bboxes stays aligned with file_names.
    annotations = file_data.get(DATA_USER, {}).get('annotations', [])
    gt_bboxes.append([[int(a['left']), int(a['top']), int(a['width']), int(a['height'])]
                      for a in annotations])

Now use your endpoint to detect objects in each validation image. Because the images are JPEGs, set the predictor's content type to `image/jpeg` before calling it. The endpoint returns JSON that you can load and inspect.

In [None]:
# Send raw JPEG bytes to the endpoint and keep one result per image.
object_detector.content_type = 'image/jpeg'
all_detections = []
for file_name in file_names:
    with open(file_name, 'rb') as image:
        payload = bytearray(image.read())
    results = object_detector.predict(payload)
    all_detections.append(json.loads(results))

The results use a format similar to the .lst format, with a confidence score added for each detected object. Each detection has the form `[class_index, confidence_score, xmin, ymin, xmax, ymax]`. The box coordinates are normalized to the range 0 to 1 relative to the image width and height. Typically, you ignore low-confidence predictions.

The `visualize_detection` function from the provided `helper` module plots the detections. It shows only high-confidence predictions by filtering out detections below a threshold. It also receives the ground-truth boxes, so you can compare the predictions with the human labels:

In [None]:
object_categories = ['bird']

# A threshold of 0.55 plots only detections with a confidence score greater than 0.55.
# Adjust this value until you see only a few colored bounding boxes.
threshold = 0.55

# Visualize the detections for up to 10 images.
max_to_display = 10
for file_name, detections, bboxes in zip(file_names, all_detections, gt_bboxes):
    visualize_detection(file_name, detections['prediction'], object_categories, threshold, bboxes)
    max_to_display -= 1
    if max_to_display == 0:
        break
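To compare predictions with the human labels numerically rather than visually, you need both in the same coordinate system. Ground Truth boxes are stored as pixel `left, top, width, height`. The endpoint returns normalized `[class_index, confidence_score, xmin, ymin, xmax, ymax]`. This minimal sketch, using the hypothetical helper `to_pixel_boxes`, filters by confidence and converts predictions to the Ground Truth form. It assumes you know the image size, for example from the image file itself.

```python
def to_pixel_boxes(predictions, image_width, image_height, threshold=0.5):
    """Convert normalized [class_index, score, xmin, ymin, xmax, ymax] predictions to
    (class_index, score, left, top, width, height) in pixels, dropping low-confidence ones."""
    boxes = []
    for class_index, score, xmin, ymin, xmax, ymax in predictions:
        if score < threshold:
            continue
        boxes.append((
            int(class_index),
            score,
            round(xmin * image_width),
            round(ymin * image_height),
            round((xmax - xmin) * image_width),
            round((ymax - ymin) * image_height),
        ))
    return boxes


preds = [[0.0, 0.92, 0.1, 0.2, 0.5, 0.6], [0.0, 0.30, 0.0, 0.0, 0.2, 0.2]]
print(to_pixel_boxes(preds, 500, 400, threshold=0.55))
# [(0, 0.92, 50, 80, 200, 160)]
```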

Not surprisingly, a model trained on so few images does not do nearly as well as the human labelers. To keep training quick, you also trained the model for only one epoch. To achieve better detection results, you could tune the hyperparameters and train the model for more epochs on a larger dataset.

### Delete the Endpoint

A running endpoint incurs charges, so as a clean-up step, delete the endpoint.

In [None]:
sagemaker.Session().delete_endpoint(object_detector.endpoint)