## Training with SageMaker XGBoost as a Framework

When training with the SageMaker XGBoost container as a framework, we provide our own entry-point script, plus an optional source directory with extra modules. The two ways to run SageMaker XGBoost, as a built-in algorithm or as a framework, are described at: https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html.

More details on training with XGBoost as a framework are available here:

https://sagemaker.readthedocs.io/en/stable/using_xgboost.html

The source code of the XGBoost framework container, with details on its contents, is available at:

https://github.com/aws/sagemaker-xgboost-container

When the script is executed on SageMaker, a number of helpful environment variables are available to access properties of the training environment, such as:

- SM_MODEL_DIR: A string representing the path to the directory to write model artifacts to. Any artifacts saved in this folder are uploaded to S3 for model hosting after the training job completes.
- SM_OUTPUT_DATA_DIR: A string representing the filesystem path to write output artifacts to. Output artifacts may include checkpoints, graphs, and other files to save, not including model artifacts. These artifacts are compressed and uploaded to S3 under the same S3 prefix as the model artifacts.
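
As a minimal sketch of how a training script might use SM_MODEL_DIR, the hypothetical helpers below resolve the model directory (falling back to `/opt/ml/model`, the path SageMaker uses inside the container, so the code also runs when the variable is unset) and write a small JSON file into it. The helper names and the `metadata.json` filename are illustrative assumptions; a real XGBoost script would typically save the trained booster into the same directory, for example with `Booster.save_model()`.

```python
import json
import os


def model_dir() -> str:
    """Directory whose contents SageMaker uploads to S3 after training.

    Falls back to /opt/ml/model when SM_MODEL_DIR is not set (e.g. local runs).
    """
    return os.environ.get("SM_MODEL_DIR", "/opt/ml/model")


def save_metadata(metadata: dict, filename: str = "metadata.json") -> str:
    """Hypothetical helper: write a JSON file alongside the model artifacts."""
    out_dir = model_dir()
    os.makedirs(out_dir, exist_ok=True)  # create the directory if it is missing
    path = os.path.join(out_dir, filename)
    with open(path, "w") as f:
        json.dump(metadata, f)
    return path
```

Anything written this way ends up in the same model archive that SageMaker uploads to the estimator's `output_path`.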

Supposing two input channels, 'train' and 'validation', are passed to the XGBoost estimator's fit() method, the following environment variables are set, using the format SM_CHANNEL_[CHANNEL_NAME] with the channel name in upper case:

- SM_CHANNEL_TRAIN: A string representing the path to the directory containing data in the 'train' channel.
- SM_CHANNEL_VALIDATION: Same as above, but for the 'validation' channel.
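
The naming rule for channel variables can be expressed as a small lookup. This is a sketch under stated assumptions: the helper names are hypothetical, and the fallback `/opt/ml/input/data/<channel>` (where SageMaker places channel data in File mode) is only used when the variable is unset, e.g. during local testing.

```python
import os


def channel_dir(channel: str) -> str:
    """Return the local directory for an input channel.

    SageMaker exposes each channel as SM_CHANNEL_<NAME> with the name
    upper-cased; the fallback mirrors the in-container data path.
    """
    env_var = f"SM_CHANNEL_{channel.upper()}"
    return os.environ.get(env_var, os.path.join("/opt/ml/input/data", channel))


def list_csv_files(channel: str) -> list:
    """List the .csv files in a channel directory, sorted for a stable order."""
    directory = channel_dir(channel)
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".csv")
    )
```

For the two channels above, `channel_dir("train")` reads SM_CHANNEL_TRAIN and `channel_dir("validation")` reads SM_CHANNEL_VALIDATION.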

A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves it to the model directory (SM_MODEL_DIR) so that it can be hosted later. Hyperparameters are passed to your script as command-line arguments and can be retrieved with an argparse.ArgumentParser instance.
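
A minimal sketch of the argument-parsing part of such a script follows. It assumes the script accepts the same hyperparameter names as the `hyperparams` dict passed to the estimator below; SageMaker sends them on the command line as strings (e.g. `--num_round 10`), so `type=int` converts them back. Path defaults come from the environment variables described above. The function name `parse_args` is our own choice, not part of any SageMaker API.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker paths for the training script."""
    parser = argparse.ArgumentParser()

    # Hyperparameters: names assumed to mirror the estimator's `hyperparams` dict
    parser.add_argument("--num_class", type=int, default=3)
    parser.add_argument("--objective", type=str, default="multi:softmax")
    parser.add_argument("--num_round", type=int, default=10)
    parser.add_argument("--silent", type=int, default=0)

    # Paths provided by SageMaker through environment variables
    parser.add_argument(
        "--model_dir", type=str,
        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument(
        "--train", type=str,
        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument(
        "--validation", type=str,
        default=os.environ.get("SM_CHANNEL_VALIDATION", "/opt/ml/input/data/validation"))

    # parse_known_args returns unrecognized arguments instead of exiting
    args, _unknown = parser.parse_known_args(argv)
    return args
```

Using `parse_known_args` rather than `parse_args` lets the script tolerate extra arguments it does not declare, instead of stopping with an error.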

```python
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

# Create the default SageMaker bucket (sagemaker-{region}-{account_id}) if it doesn't already exist
bucket = sagemaker_session.default_bucket()
print(bucket)

# Get the ARN of the IAM role used by this Studio instance to pass to training jobs and other Amazon SageMaker tasks
role = get_execution_role()
print(role)
```
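
The default bucket follows the naming convention `sagemaker-{region}-{account_id}`. The sketch below reproduces that convention with placeholder values and shows one way to build the S3 URIs used in the next cell. Both helpers are hypothetical and only format strings; they make no AWS calls.

```python
def default_bucket_name(region: str, account_id: str) -> str:
    """Bucket name pattern used by Session.default_bucket()."""
    return f"sagemaker-{region}-{account_id}"


def s3_uri(bucket: str, *keys: str) -> str:
    """Join a bucket and key parts into an s3:// URI without doubled slashes."""
    key = "/".join(part.strip("/") for part in keys)
    return f"s3://{bucket}/{key}"
```

For example, with the placeholder region `us-east-1` and account `123456789012`, `s3_uri(default_bucket_name("us-east-1", "123456789012"), "iris", "data/iris_train.csv")` gives the same shape of URI as the training input below.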

```python
%%time

import sagemaker
from sagemaker.xgboost.estimator import XGBoost

# Hyperparameters are passed to the entry-point script as command-line arguments
hyperparams = {
    "num_class": "3",
    "silent": "0",
    "objective": "multi:softmax",
    "num_round": "10",
}

# Build a SageMaker Framework estimator that runs our own training script
xgb_estimator = XGBoost(
    role=role,
    # Pin an explicit XGBoost container version (e.g. "1.7-1") rather than "latest";
    # pick one supported by your SDK version and region
    framework_version="1.7-1",
    instance_count=1,
    instance_type="ml.m5.large",
    output_path="s3://{}/iris/output".format(bucket),
    entry_point="./src/train_script.py",  # script executed inside the XGBoost container
    hyperparameters=hyperparams,
    sagemaker_session=sagemaker_session,
)

s3_input_train = sagemaker.inputs.TrainingInput(
    s3_data="s3://{}/iris/data/iris_train.csv".format(bucket), content_type="csv"
)
s3_input_validation = sagemaker.inputs.TrainingInput(
    s3_data="s3://{}/iris/data/iris_val.csv".format(bucket), content_type="csv"
)

# Each key becomes a channel, exposed to the script as SM_CHANNEL_TRAIN / SM_CHANNEL_VALIDATION
xgb_estimator.fit({
    "train": s3_input_train,
    "validation": s3_input_validation,
})
```
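
The contents of `./src/train_script.py` are not shown here. As a hypothetical example of its data-loading step, the helper below reads every CSV file in a channel directory into features and labels. It assumes no header row and the label in the first column (the CSV convention of the built-in XGBoost algorithm); in framework mode the script decides the layout, so adjust it to match how `iris_train.csv` was actually written.

```python
import csv
import os
from typing import List, Tuple


def load_csv_channel(directory: str) -> Tuple[List[List[float]], List[float]]:
    """Read all .csv files in a channel directory into (features, labels).

    Assumes no header row and the label in the first column.
    """
    features, labels = [], []
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".csv"):
            continue
        with open(os.path.join(directory, name), newline="") as f:
            for row in csv.reader(f):
                if not row:
                    continue  # skip blank lines
                labels.append(float(row[0]))
                features.append([float(value) for value in row[1:]])
    return features, labels
```

The resulting lists could then be converted to NumPy arrays and passed to `xgboost.DMatrix` before calling `xgboost.train()`.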