### Sagemaker image classification training notebook
The purpose of this notebook is to show an end-to-end machine learning workflow:
1. Use a labeled dataset created by SageMaker Ground Truth, then split it into training and validation sets.
2. Train the model using a SageMaker training container.
3. Optimize the model using SageMaker Neo.

The block below shows how to load the Ground Truth labeled dataset and split it into training and validation sets.
### IMPORTANT: Please change `BUCKET` to the name of your S3 bucket

In [None]:
import json
import numpy as np
import boto3
import sagemaker
import time

BUCKET = 'yourbucket'
EXP_NAME = 'dinodetector' # Any valid S3 prefix.

role = sagemaker.get_execution_role()
region = boto3.session.Session().region_name
s3 = boto3.client('s3')
bucket_region = s3.head_bucket(Bucket=BUCKET)['ResponseMetadata']['HTTPHeaders']['x-amz-bucket-region']
assert bucket_region == region, "You S3 bucket {} and this notebook need to be in the same region.".format(BUCKET)

# Enter your Sagemaker GroundTruth output manifests file location
OUTPUT_MANIFEST ='s3://dino-dataset/output/dino-image-classification/manifests/output/output.manifest'
# Download manifests file to local drive
!aws s3 cp {OUTPUT_MANIFEST} 'output.manifest'

SPLIT_SEED = 42        # fixed seed so the train/validation split is the same on every run
TRAIN_FRACTION = 0.8   # share of the labeled records that goes to training


def write_manifest(records, path):
    """Write `records` to `path` as JSON Lines: one JSON object per line."""
    with open(path, 'w') as f:
        for record in records:
            f.write(json.dumps(record))
            f.write('\n')


# Every non-empty line of the manifest is parsed as one JSON object. Blank
# lines (e.g. a trailing newline at the end of the file) are skipped because
# json.loads would raise on them.
with open('output.manifest', 'r') as f:
    output = [json.loads(line) for line in f if line.strip()]

# Shuffle output in place. A dedicated seeded generator keeps the split (and
# therefore num_training_samples) reproducible without touching numpy's
# global random state.
np.random.RandomState(SPLIT_SEED).shuffle(output)

dataset_size = len(output)

# Split dataset in an 80/20 ratio for training and validation
train_test_split_index = round(dataset_size * TRAIN_FRACTION)

train_data = output[:train_test_split_index]
validation_data = output[train_test_split_index:]

# Create training and validation files
write_manifest(train_data, 'train.manifest')
write_manifest(validation_data, 'validation.manifest')

# Number of records written to train.manifest; the training cell passes it on
# as the "num_training_samples" hyperparameter.
num_training_samples = len(train_data)

#### Upload the training and validation manifests to the S3 bucket so that SageMaker training jobs can use them later


In [None]:

# Upload both locally written manifests to the bucket under the EXP_NAME prefix,
# keeping their file names as the last part of the object key.
for manifest_name in ('train.manifest', 'validation.manifest'):
    s3.upload_file(manifest_name, BUCKET, '{}/{}'.format(EXP_NAME, manifest_name))



### Create the SageMaker training job. Change the hyperparameters to suit your training needs

In [None]:
def _manifest_channel(channel_name, manifest_name):
    """Return one InputDataConfig entry for an augmented manifest file.

    The manifest is addressed as s3://BUCKET/EXP_NAME/<manifest_name>; both
    channels of this job share every other setting.
    """
    return {
        "ChannelName": channel_name,
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "AugmentedManifestFile",
                "S3Uri": 's3://{}/{}/{}'.format(BUCKET, EXP_NAME, manifest_name),
                "S3DataDistributionType": "FullyReplicated",
                "AttributeNames": ["source-ref","dino-image-classification"]
            }
        },
        "ContentType": "application/x-recordio",
        "RecordWrapperType": "RecordIO",
        "CompressionType": "None"
    }


# Job name = fixed prefix + UTC timestamp, so each run of this cell gets a new name.
nn_job_name_prefix = 'sagemaker-nvidia-webinar-demo'
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
nn_job_name = '{}{}'.format(nn_job_name_prefix, timestamp)
num_classes = 6
training_image = sagemaker.amazon.amazon_estimator.get_image_uri(
    boto3.Session().region_name,
    'image-classification',
    repo_version='latest'
)

# All hyperparameter values are passed as strings.
hyperparameters = {
    "epochs": "30",
    "image_shape": "3,224,224",
    "learning_rate": "0.001",
    "lr_scheduler_step": "10,20",
    "mini_batch_size": "32",
    "num_classes": str(num_classes),
    "num_layers": "18",
    "num_training_samples": str(num_training_samples),
    "resize": "224",
    "use_pretrained_model": "1"
}

# Keyword arguments for create_training_job (unpacked with ** in the next cell).
training_params = {
    "TrainingJobName": nn_job_name,
    "RoleArn": role,
    "AlgorithmSpecification": {
        "TrainingImage": training_image,
        "TrainingInputMode": "Pipe"
    },
    "HyperParameters": hyperparameters,
    "InputDataConfig": [
        _manifest_channel("train", 'train.manifest'),
        _manifest_channel("validation", 'validation.manifest')
    ],
    "OutputDataConfig": {
        "S3OutputPath": 's3://{}/{}/output/'.format(BUCKET, EXP_NAME)
    },
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.p3.2xlarge",
        "VolumeSizeInGB": 50
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 86400
    }
}

### Now we will create the SageMaker training job.

In [None]:
sagemaker_client = boto3.client('sagemaker')
sagemaker_client.create_training_job(**training_params)

# Poll every 30 seconds until the training job reaches a terminal state.
print('Training job started')
while True:
    # Keep the whole describe response: the failure branch needs FailureReason
    # from it, not just the status string.
    response = sagemaker_client.describe_training_job(TrainingJobName=nn_job_name)
    status = response['TrainingJobStatus']
    if status == 'Completed':
        print("Training job ended with status: " + status)
        break
    if status == 'Failed':
        message = response.get('FailureReason', 'no failure reason reported')
        print('Training failed with the following error: {}'.format(message))
        raise Exception('Training job failed: {}'.format(message))
    if status == 'Stopped':
        # 'Stopped' is terminal too; without this branch the loop never exits.
        raise Exception('Training job was stopped before it completed')
    time.sleep(30)

## Optimize the model specifically for the architecture
Use SageMaker Neo to optimize and compile the model.

In [None]:
output_path = 's3://{}/{}/neo_output/'.format(BUCKET, EXP_NAME)
# The model artifact is expected under the training job's own folder inside
# the S3OutputPath configured for the training job.
model_path = 's3://{}/{}/output/{}/output/model.tar.gz'.format(BUCKET, EXP_NAME, nn_job_name)

# A fixed job name makes every re-run of this cell fail on a name collision,
# so append a UTC timestamp the same way the training job name is built.
compilation_job_name = 'sagemaker-nvidia-webinar' + time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())

print(output_path)
print(model_path)
print(compilation_job_name)
response = sagemaker_client.create_compilation_job(
    CompilationJobName=compilation_job_name,
    RoleArn=role,
    InputConfig={
        'S3Uri': model_path,
        # Input named 'data' with shape [1, 3, 224, 224]; the last three
        # values match the "3,224,224" image_shape used for training.
        'DataInputConfig': "{'data':[1, 3, 224, 224]}",
        'Framework': 'MXNET'
    },
    OutputConfig={
        'S3OutputLocation': output_path,
        'TargetDevice': 'jetson_nano'
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 900
    }
)

print(response)