# 3a - Training Amazon's XGBoost

## Introduction
In this notebook, we will train a model with Amazon SageMaker's XGBoost algorithm, evaluate the training performance, and output model artifacts.

## Setup

In [1]:
import json
import random
import time
import uuid

import boto3
import sagemaker as sm

In [2]:
# Sessions and IAM role shared by the rest of the notebook:
#   boto3_session - plain boto3 session, used below to build AWS clients/resources
#   sm_session    - SageMaker SDK session
#   role          - execution role returned by the SageMaker SDK for this environment
boto3_session = boto3.session.Session()
sm_session = sm.Session()
role = sm.get_execution_role()

In [3]:
# Region and account ID of the identity behind the current boto3 session.
region = boto3_session.region_name
account = boto3_session.client("sts").get_caller_identity()["Account"]

# S3 resource handle built from the same boto3 session.
s3_resource = boto3_session.resource("s3")

# Bucket names are read from the stack-data JSON file on this instance; the
# file is expected to contain the "data_bucket" and "model_bucket" keys.
with open("/home/ec2-user/.aiml-bb/stack-data.json", "r") as f:
    data = json.load(f)

data_bucket = data["data_bucket"]
model_bucket = data["model_bucket"]

## Define resources for estimator

In [4]:
# Look up the XGBoost container image URI for the region of the current session.
# NOTE(review): "latest" is an unpinned version tag - consider pinning an
# explicit XGBoost version for reproducibility; confirm against the SDK docs.
xgb_container_image = sm.image_uris.retrieve(
    framework="xgboost",
    region=region,
    version="latest",
)

# Build a training job name with a random 8-character hexadecimal suffix.
job_name_suffix = uuid.uuid4().hex[:8]
training_job_name = f"xgboost-{job_name_suffix}"

In [12]:
# S3 prefixes holding the preprocessed training and validation data.
train_data_uri = f"s3://{model_bucket}/preprocessing_output/train_data/"
validation_data_uri = f"s3://{model_bucket}/preprocessing_output/validate_data/"

# Wrap each prefix as a CSV input channel for the estimator.
train_input = sm.inputs.TrainingInput(s3_data=train_data_uri, content_type="csv")
validation_input = sm.inputs.TrainingInput(
    s3_data=validation_data_uri, content_type="csv"
)

## Create and fit estimator

In [15]:
# Generic SageMaker estimator that runs the XGBoost container image on four
# ml.m5.4xlarge instances and writes its output under the model bucket.
xgb_estimator = sm.estimator.Estimator(
    image_uri=xgb_container_image,
    role=role,
    instance_count=4,
    instance_type="ml.m5.4xlarge",
    output_path=f"s3://{model_bucket}/sagemaker-xgboost/",
)

# Hyperparameters passed to the estimator, collected in one place so they are
# easy to inspect and tune.
xgb_hyperparameters = {
    "max_depth": 5,
    "eta": 0.2,
    "gamma": 4,
    "min_child_weight": 6,
    "subsample": 0.8,
    "silent": 0,
    "objective": "binary:logistic",
    "num_round": 100,
}
xgb_estimator.set_hyperparameters(**xgb_hyperparameters)

In [None]:
# Fit the model.
# The job is launched under `training_job_name` so that the name generated
# earlier in the notebook is the one the training job actually runs under;
# without `job_name` the SDK auto-generates a different name and lookups by
# `training_job_name` cannot find the job.
# NOTE: training job names must be unique, so regenerate `training_job_name`
# before re-running this cell.
xgb_estimator.fit(
    {"train": train_input, "validation": validation_input},
    job_name=training_job_name,
)

2022-01-25 02:15:05 Starting - Starting the training job...
2022-01-25 02:15:31 Starting - Launching requested ML instancesProfilerReport-1643076905: InProgress
......
2022-01-25 02:16:31 Starting - Preparing the instances for training.......

In [None]:
# Monitor the status until completed.
# `describe_training_job` is an operation of the low-level SageMaker (boto3)
# client, not of the `sagemaker` module, so the client held by the SDK session
# is used here.
sm_client = sm_session.sagemaker_client

# Query the job the estimator actually launched; this is the name known to the
# service whether or not a custom job name was supplied when fitting.
monitored_job_name = xgb_estimator.latest_training_job.name

# Statuses after which the job will not change state any more.
TERMINAL_STATUSES = ("Failed", "Completed", "Stopped")
POLL_INTERVAL_SECONDS = 30

while True:
    job_run_status = sm_client.describe_training_job(
        TrainingJobName=monitored_job_name
    )["TrainingJobStatus"]
    if job_run_status in TERMINAL_STATUSES:
        break

    # Still running: report the current status, then wait for the next heartbeat.
    print(job_run_status)
    time.sleep(POLL_INTERVAL_SECONDS)

print(f"Training job {monitored_job_name} finished with status: {job_run_status}")