# Note:  This notebook requires a set of large instances due to the amount of data.

In [None]:
# Install the AWS SDK for Python and xgboost from the notebook (-q keeps pip output quiet).
# xgboost is pinned to 0.90 -- presumably to line up with framework_version='0.90-2'
# on the estimator further down; confirm before bumping either one.
!pip install -q boto3
!pip install -q xgboost==0.90

In [None]:
import boto3
import sagemaker
import pandas as pd

# AWS region of the default boto3 session; the SageMaker API client below is pinned to it.
region = boto3.Session().region_name

# SageMaker SDK session, its default S3 bucket, and the IAM execution role of this notebook.
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()

# Low-level boto3 client for the SageMaker service API.
sm = boto3.Session().client(region_name=region, service_name='sagemaker')

In [None]:
# Restore the variable previously persisted with IPython's %store magic
# (presumably by the Spark processing notebook -- verify it has been run first).
%store -r spark_processing_job_s3_output_prefix

In [None]:
# The restored value is an S3 key prefix (it is joined onto the bucket name further
# down to build the feature URIs), not a job name, so the label says so.
print('Previous Spark Processing Job S3 Output Prefix: {}'.format(spark_processing_job_s3_output_prefix))

# Specify the S3 Location of the Features

In [None]:

# TODO:  Add `/output/` and remove `-csv` once the new spark job is done!

# For each split we derive, in order: the S3 key prefix under the Spark job's
# output prefix, a matching relative local path, the full S3 URI in the default
# bucket, and a CSV input channel pointing at that URI.

# --- train ---
prefix_train = '{}/tfidf-labeled-split-balanced-noheader-train'.format(spark_processing_job_s3_output_prefix)
balanced_tfidf_without_header_train_path = './{}'.format(prefix_train)
balanced_tfidf_without_header_train_s3_uri = 's3://{}/{}'.format(bucket, prefix_train)
s3_input_train_data = sagemaker.s3_input(
    s3_data=balanced_tfidf_without_header_train_s3_uri,
    content_type='text/csv',
)

# --- validation ---
prefix_validation = '{}/tfidf-labeled-split-balanced-noheader-validation'.format(spark_processing_job_s3_output_prefix)
balanced_tfidf_without_header_validation_path = './{}'.format(prefix_validation)
balanced_tfidf_without_header_validation_s3_uri = 's3://{}/{}'.format(bucket, prefix_validation)
s3_input_validation_data = sagemaker.s3_input(
    s3_data=balanced_tfidf_without_header_validation_s3_uri,
    content_type='text/csv',
)

# --- test ---
prefix_test = '{}/tfidf-labeled-split-balanced-noheader-test'.format(spark_processing_job_s3_output_prefix)
balanced_tfidf_without_header_test_path = './{}'.format(prefix_test)
balanced_tfidf_without_header_test_s3_uri = 's3://{}/{}'.format(bucket, prefix_test)
s3_input_test_data = sagemaker.s3_input(
    s3_data=balanced_tfidf_without_header_test_s3_uri,
    content_type='text/csv',
)

# Echo the resolved channel configurations, one per line, as a sanity check.
print(
    s3_input_train_data.config,
    s3_input_validation_data.config,
    s3_input_test_data.config,
    sep='\n',
)

In [None]:
# Show the training script that the XGBoost estimator below runs as its entry point
# (entry_point='xgboost_reviews.py' inside source_dir='src/').
!cat src/xgboost_reviews.py

In [None]:
from sagemaker.xgboost import XGBoost

# Model artifacts of every training run are written under this S3 location.
model_output_path = 's3://{}/models/amazon-reviews/script-mode/training-runs'.format(bucket)

# Script-mode XGBoost estimator: runs src/xgboost_reviews.py with the 0.90-2
# framework version (Python 3) on three ml.c5.8xlarge instances.
# To try the script in local mode instead, set train_instance_type='local'.
xgb_estimator = XGBoost(
    # What to run
    entry_point='xgboost_reviews.py',
    source_dir='src/',
    framework_version='0.90-2',
    py_version='py3',
    # Where / as whom to run it
    role=role,
    train_instance_type='ml.c5.8xlarge',
    train_instance_count=3,
    # Hyperparameters handed to the training script
    hyperparameters=dict(
        objective='binary:logistic',
        num_round=1,
        max_depth=5,
    ),
    # Outputs and monitoring
    output_path=model_output_path,
    enable_cloudwatch_metrics=True,
)

### Train the model

In [None]:
# Launch the training job on the train and validation channels.
# wait=False returns control right away rather than blocking until the job ends.
xgb_estimator.fit(
    inputs=dict(
        train=s3_input_train_data,
        validation=s3_input_validation_data,
    ),
    wait=False,
)

In [None]:
# Capture the name of the estimator's most recent training job so it can be
# looked up or re-attached to later (fit() was started with wait=False).
training_job_name = xgb_estimator.latest_training_job.name
print('training_job_name:  {}'.format(training_job_name))

In [None]:
from sagemaker.xgboost import XGBoost

# Rebuild an estimator object bound to the existing training job, looked up by name.
# NOTE(review): in SageMaker SDK v1, attach() appears to wait for the job to finish
# before returning -- confirm before relying on it as a completion barrier.
xgb_estimator = XGBoost.attach(training_job_name=training_job_name)