In [None]:
!pip install joblib

In [None]:
import pandas as pd
import boto3
import sagemaker
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

In [None]:
# Open a SageMaker session and fetch the IAM execution role this notebook
# runs under; both are handed to the estimator further down.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Resolve the session's default S3 bucket (the SDK provisions it on first
# use) and echo its name so it is visible in the cell output.
bucket = sagemaker_session.default_bucket()
print(f'{bucket}')

In [None]:
# Local folder holding the project CSVs, and the S3 key prefix they are
# stored under (training later reads s3://{bucket}/{prefix}).
data_dir = 'capstone_data'
prefix = 'capstone_project'

# Upload the local data only if nothing exists under the prefix yet. This
# keeps Restart & Run All working on a fresh bucket (previously the upload
# was commented out and relied on an earlier manual run) while repeat runs
# skip re-uploading the same files.
existing_objects = list(
    boto3.resource('s3').Bucket(bucket).objects.filter(Prefix=prefix).limit(1)
)
if not existing_objects:
    input_data = sagemaker_session.upload_data(path=data_dir, bucket=bucket, key_prefix=prefix)

In [None]:
# Confirm the training data is actually in S3 under the project prefix.
# Only keys under `prefix` are listed. That is the location the `train`
# channel reads, and unrelated objects elsewhere in the default bucket
# (e.g. earlier job artifacts) must not make an empty prefix look populated.
data_keys = []
for obj in boto3.resource('s3').Bucket(bucket).objects.filter(Prefix=prefix):
    data_keys.append(obj.key)
    print(obj.key)

assert len(data_keys) != 0, f'No objects found under s3://{bucket}/{prefix}.'
print('Test passed!')

In [None]:
# Configure (but do not yet launch) a SageMaker scikit-learn training job
# that runs source_sklearn/train.py on a single ml.c4.xlarge instance.
# NOTE(review): `train_instance_count` / `train_instance_type` are the
# SageMaker SDK v1 keyword names; SDK v2 renamed them to `instance_count` /
# `instance_type`. Switch once the installed SDK version is confirmed.
from sagemaker.sklearn.estimator import SKLearn

model = SKLearn(
    entry_point='train.py',
    source_dir='source_sklearn',
    framework_version='0.20.0',
    py_version='py3',
    role=role,
    sagemaker_session=sagemaker_session,
    train_instance_count=1,
    train_instance_type='ml.c4.xlarge',
)

In [None]:
%%time

# Train your estimator on S3 training data
model.fit({'train': f's3://{bucket}/{prefix}'})

In [None]:
%%time

# uncomment, if needed
# from sagemaker.pytorch import PyTorchModel


# deploy your model to create a predictor
predictor = model.deploy(initial_instance_count=1, instance_type='ml.t2.medium')

In [None]:
# Load the locally stored test split (no header row in the CSV).
import os

test_csv_path = os.path.join(data_dir, "test_lstm.csv")
test_lstm = pd.read_csv(test_csv_path, header=None)

# Column 0 is the target; every remaining column is a feature.
test_y = test_lstm.iloc[:, 0]
test_x = test_lstm.iloc[:, 1:]

In [None]:
# Send the test features to the endpoint and keep its predictions.
test_y_preds = predictor.predict(data=test_x)

In [None]:
# Root-mean-squared error of the endpoint's predictions against the true
# test targets (`test_y`, i.e. column 0 of the test CSV).
# The square root of the MSE is taken explicitly because `squared=False` was
# deprecated in scikit-learn 1.4 and removed in 1.6.
rmse = mean_squared_error(test_y, test_y_preds) ** 0.5
print('Test RMSE:', rmse)

In [None]:
# Overlay actual vs. predicted test targets on a shared x-axis.
# `test_y_preds` comes straight from `predictor.predict` (presumably a NumPy
# array via the SKLearn predictor's default deserializer). Arrays have no
# `.plot()` method, so it is drawn with `ax.plot`, which also accepts an
# (n, 1)-shaped result.
fig, ax = plt.subplots()
ax.plot(test_y.index, test_y, label='actual')
ax.plot(test_y.index, test_y_preds, label='predicted')
ax.set(title='Test set: actual vs. predicted', xlabel='sample index', ylabel='target')
ax.legend();