# SageMaker Feature Store

In [None]:
# to get the latest sagemaker python sdk
!pip install -U sagemaker

In [1]:
import boto3
import sagemaker
from sagemaker.session import Session

print(sagemaker.__version__)

region = boto3.Session().region_name

boto_session = boto3.Session(region_name=region)

sagemaker_client = boto_session.client(service_name='sagemaker', region_name=region)
featurestore_runtime = boto_session.client(service_name='sagemaker-featurestore-runtime', region_name=region)

feature_store_session = Session(
    boto_session=boto_session,
    sagemaker_client=sagemaker_client,
    sagemaker_featurestore_runtime_client=featurestore_runtime
)


2.17.0


In [2]:
# You can modify the following to use a bucket of your choosing
default_s3_bucket_name = 'beyoung-sagemaker'
prefix = 'sagemaker-featurestore-demo'

print(default_s3_bucket_name)

beyoung-sagemaker


In [3]:
from sagemaker import get_execution_role

# You can modify the following to use a role of your choosing. See the documentation for how to create this.
role = get_execution_role()
print (role)

arn:aws:iam::476271697919:role/service-role/AmazonSageMaker-ExecutionRole-20200728T220473


In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import io

s3_client = boto3.client('s3', region_name=region)

fraud_detection_bucket_name = 'sagemaker-sample-files'
identity_file_key = 'datasets/tabular/fraud_detection/synthethic_fraud_detection_SA/sampled_identity.csv'
transaction_file_key = 'datasets/tabular/fraud_detection/synthethic_fraud_detection_SA/sampled_transactions.csv'

identity_data_object = s3_client.get_object(Bucket=fraud_detection_bucket_name, Key=identity_file_key)
transaction_data_object = s3_client.get_object(Bucket=fraud_detection_bucket_name, Key=transaction_file_key)

identity_data = pd.read_csv(io.BytesIO(identity_data_object['Body'].read()))
transaction_data = pd.read_csv(io.BytesIO(transaction_data_object['Body'].read()))

identity_data = identity_data.round(5)
transaction_data = transaction_data.round(5)

identity_data = identity_data.fillna(0)
transaction_data = transaction_data.fillna(0)

# Feature transformations for this dataset are applied before ingestion into FeatureStore.
# One hot encode card4, card6
encoded_card_bank = pd.get_dummies(transaction_data['card4'], prefix = 'card_bank')
encoded_card_type = pd.get_dummies(transaction_data['card6'], prefix = 'card_type')

transformed_transaction_data = pd.concat([transaction_data, encoded_card_type, encoded_card_bank], axis=1)
# blank space is not allowed in feature name
transformed_transaction_data = transformed_transaction_data.rename(columns={"card_bank_american express": "card_bank_american_express"})

In [5]:
identity_data.head()

Unnamed: 0,TransactionID,id_01,id_02,id_03,id_04,id_05,id_06,id_07,id_08,id_09,...,id_11,id_12,id_13,id_14,id_15,id_16,id_17,id_18,id_19,id_20
0,2990130,-5,38780.0,0.0,0.0,0.0,-70,0,1,100.0,...,32,80,253,241,260,125,T,F,F,T
1,2990266,-10,69246.0,0.0,0.0,0.0,-67,0,2,100.0,...,47,47,122,33,38,60,T,F,T,F
2,2992553,-45,348819.0,0.0,0.0,0.0,-73,0,0,100.0,...,21,143,268,111,2,135,F,F,T,F
3,2994568,-15,337170.0,0.0,0.0,0.0,-10,1,2,100.0,...,55,127,253,202,135,49,F,F,T,T
4,2994749,-5,680670.0,0.0,0.0,8.0,-1,2,2,100.0,...,52,43,257,7,19,254,F,F,T,T


In [6]:
transformed_transaction_data.head()

Unnamed: 0,TransactionID,isFraud,TransactionDT,TransactionAmt,card1,card2,card3,card4,card5,card6,...,N8,N9,card_type_0,card_type_credit,card_type_debit,card_bank_0,card_bank_american_express,card_bank_discover,card_bank_mastercard,card_bank_visa
0,3343087,0,8810855,29.0,12469,360.0,150.0,mastercard,126.0,debit,...,F,T,0,0,1,0,0,0,1,0
1,3307318,0,7955295,107.95,16188,178.0,150.0,mastercard,224.0,debit,...,F,T,0,0,1,0,0,0,1,0
2,3555327,0,15084339,159.95,1825,555.0,150.0,visa,226.0,debit,...,T,F,0,0,1,0,0,0,0,1
3,3310736,0,8017157,159.95,10057,225.0,150.0,mastercard,224.0,debit,...,F,F,0,0,1,0,0,0,1,0
4,3034711,0,1127470,117.0,11444,555.0,150.0,visa,226.0,debit,...,F,F,0,0,1,0,0,0,0,1


In [7]:
from time import gmtime, strftime, sleep

identity_feature_group_name = 'identity-feature-group-' + strftime('%d-%H-%M-%S', gmtime())
transaction_feature_group_name = 'transaction-feature-group-' + strftime('%d-%H-%M-%S', gmtime())

In [8]:
from sagemaker.feature_store.feature_group import FeatureGroup

identity_feature_group = FeatureGroup(name=identity_feature_group_name, sagemaker_session=feature_store_session)
transaction_feature_group = FeatureGroup(name=transaction_feature_group_name, sagemaker_session=feature_store_session)

In [9]:
identity_feature_group

FeatureGroup(name='identity-feature-group-03-15-46-06', sagemaker_session=<sagemaker.session.Session object at 0x7f9810b283d0>, feature_definitions=[])

In [10]:
import time

current_time_sec = int(round(time.time()))

def cast_object_to_string(data_frame):
    for label in data_frame.columns:
        if data_frame.dtypes[label] == 'object':
            data_frame[label] = data_frame[label].astype("str").astype("string")

# cast object dtype to string. The SageMaker FeatureStore Python SDK will then map the string dtype to String feature type.
cast_object_to_string(identity_data)
cast_object_to_string(transformed_transaction_data)

# record identifier and event time feature names
record_identifier_feature_name = "TransactionID"
event_time_feature_name = "EventTime"

# append EventTime feature
identity_data[event_time_feature_name] = pd.Series([current_time_sec]*len(identity_data), dtype="float64")
transformed_transaction_data[event_time_feature_name] = pd.Series([current_time_sec]*len(transaction_data), dtype="float64")

# load feature definitions to the feature group. SageMaker FeatureStore Python SDK will auto-detect the data schema based on input data.
identity_feature_group.load_feature_definitions(data_frame=identity_data); # output is suppressed
transaction_feature_group.load_feature_definitions(data_frame=transformed_transaction_data); # output is suppressed


In [12]:
def wait_for_feature_group_creation_complete(feature_group):
    status = feature_group.describe().get("FeatureGroupStatus")
    while status == "Creating":
        print("Waiting for Feature Group Creation")
        time.sleep(5)
        status = feature_group.describe().get("FeatureGroupStatus")
    if status != "Created":
        raise RuntimeError(f"Failed to create feature group {feature_group.name}")
    print(f"FeatureGroup {feature_group.name} successfully created.")

identity_feature_group.create(
    s3_uri=f"s3://{default_s3_bucket_name}/{prefix}",
    record_identifier_name=record_identifier_feature_name,
    event_time_feature_name=event_time_feature_name,
    role_arn=role,
    enable_online_store=True
)

transaction_feature_group.create(
    s3_uri=f"s3://{default_s3_bucket_name}/{prefix}",
    record_identifier_name=record_identifier_feature_name,
    event_time_feature_name=event_time_feature_name,
    role_arn=role,
    enable_online_store=True
)

wait_for_feature_group_creation_complete(feature_group=identity_feature_group)
wait_for_feature_group_creation_complete(feature_group=transaction_feature_group)

Waiting for Feature Group Creation
Waiting for Feature Group Creation
FeatureGroup identity-feature-group-03-15-46-06 successfully created.
FeatureGroup transaction-feature-group-03-15-46-06 successfully created.


In [13]:
identity_feature_group.describe()

{'FeatureGroupArn': 'arn:aws:sagemaker:us-west-2:476271697919:feature-group/identity-feature-group-03-15-46-06',
 'FeatureGroupName': 'identity-feature-group-03-15-46-06',
 'RecordIdentifierFeatureName': 'TransactionID',
 'EventTimeFeatureName': 'EventTime',
 'FeatureDefinitions': [{'FeatureName': 'TransactionID',
   'FeatureType': 'Integral'},
  {'FeatureName': 'id_01', 'FeatureType': 'Integral'},
  {'FeatureName': 'id_02', 'FeatureType': 'Fractional'},
  {'FeatureName': 'id_03', 'FeatureType': 'Fractional'},
  {'FeatureName': 'id_04', 'FeatureType': 'Fractional'},
  {'FeatureName': 'id_05', 'FeatureType': 'Fractional'},
  {'FeatureName': 'id_06', 'FeatureType': 'Integral'},
  {'FeatureName': 'id_07', 'FeatureType': 'Integral'},
  {'FeatureName': 'id_08', 'FeatureType': 'Integral'},
  {'FeatureName': 'id_09', 'FeatureType': 'Fractional'},
  {'FeatureName': 'id_10', 'FeatureType': 'Integral'},
  {'FeatureName': 'id_11', 'FeatureType': 'Integral'},
  {'FeatureName': 'id_12', 'FeatureTyp

In [14]:
transaction_feature_group.describe()

{'FeatureGroupArn': 'arn:aws:sagemaker:us-west-2:476271697919:feature-group/transaction-feature-group-03-15-46-06',
 'FeatureGroupName': 'transaction-feature-group-03-15-46-06',
 'RecordIdentifierFeatureName': 'TransactionID',
 'EventTimeFeatureName': 'EventTime',
 'FeatureDefinitions': [{'FeatureName': 'TransactionID',
   'FeatureType': 'Integral'},
  {'FeatureName': 'isFraud', 'FeatureType': 'Integral'},
  {'FeatureName': 'TransactionDT', 'FeatureType': 'Integral'},
  {'FeatureName': 'TransactionAmt', 'FeatureType': 'Fractional'},
  {'FeatureName': 'card1', 'FeatureType': 'Integral'},
  {'FeatureName': 'card2', 'FeatureType': 'Fractional'},
  {'FeatureName': 'card3', 'FeatureType': 'Fractional'},
  {'FeatureName': 'card4', 'FeatureType': 'String'},
  {'FeatureName': 'card5', 'FeatureType': 'Fractional'},
  {'FeatureName': 'card6', 'FeatureType': 'String'},
  {'FeatureName': 'B1', 'FeatureType': 'Integral'},
  {'FeatureName': 'B2', 'FeatureType': 'Integral'},
  {'FeatureName': 'B3', '

In [15]:
sagemaker_client.list_feature_groups() # use boto client to list FeatureGroups

{'FeatureGroupSummaries': [{'FeatureGroupName': 'transaction-feature-group-03-15-46-06',
   'FeatureGroupArn': 'arn:aws:sagemaker:us-west-2:476271697919:feature-group/transaction-feature-group-03-15-46-06',
   'CreationTime': datetime.datetime(2020, 12, 3, 15, 52, 4, 61000, tzinfo=tzlocal()),
   'FeatureGroupStatus': 'Created'},
  {'FeatureGroupName': 'identity-feature-group-03-15-46-06',
   'FeatureGroupArn': 'arn:aws:sagemaker:us-west-2:476271697919:feature-group/identity-feature-group-03-15-46-06',
   'CreationTime': datetime.datetime(2020, 12, 3, 15, 52, 0, 760000, tzinfo=tzlocal()),
   'FeatureGroupStatus': 'Created'}],
 'ResponseMetadata': {'RequestId': '3956878d-4daf-410b-868a-fba553900c7e',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '3956878d-4daf-410b-868a-fba553900c7e',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '494',
   'date': 'Thu, 03 Dec 2020 15:59:58 GMT'},
  'RetryAttempts': 0}}

In [16]:
identity_feature_group.ingest(
    data_frame=identity_data, max_workers=3, wait=True
)

IngestionManagerPandas(feature_group_name='identity-feature-group-03-15-46-06', sagemaker_session=<sagemaker.session.Session object at 0x7f9810b283d0>, data_frame=     TransactionID  id_01     id_02  id_03  id_04  id_05  id_06  id_07  id_08  \
0          2990130     -5   38780.0    0.0    0.0    0.0    -70      0      1   
1          2990266    -10   69246.0    0.0    0.0    0.0    -67      0      2   
2          2992553    -45  348819.0    0.0    0.0    0.0    -73      0      0   
3          2994568    -15  337170.0    0.0    0.0    0.0    -10      1      2   
4          2994749     -5  680670.0    0.0    0.0    8.0     -1      2      2   
..             ...    ...       ...    ...    ...    ...    ...    ...    ...   
471        3572028     -5   92780.0    0.0    0.0    0.0    -19      2      1   
472        3575285     -5   34477.0    0.0    0.0    0.0    -25      1      0   
473        3575848     -5   45284.0    0.0    0.0    3.0    -71      1      0   
474        3576043     -5  

In [17]:
transaction_feature_group.ingest(
    data_frame=transformed_transaction_data, max_workers=5, wait=True
)

IngestionManagerPandas(feature_group_name='transaction-feature-group-03-15-46-06', sagemaker_session=<sagemaker.session.Session object at 0x7f9810b283d0>, data_frame=      TransactionID  isFraud  TransactionDT  TransactionAmt  card1  card2  \
0           3343087        0        8810855          29.000  12469  360.0   
1           3307318        0        7955295         107.950  16188  178.0   
2           3555327        0       15084339         159.950   1825  555.0   
3           3310736        0        8017157         159.950  10057  225.0   
4           3034711        0        1127470         117.000  11444  555.0   
...             ...      ...            ...             ...    ...    ...   
1995        3252738        1        6443158         200.000   6019  583.0   
1996        3548960        1       14873644           6.517  10175  176.0   
1997        3319928        1        8196787          67.067  14276  177.0   
1998        3256349        1        6539367          59.000  146

In [18]:
record_identifier_value = str(2990130)

featurestore_runtime.get_record(FeatureGroupName=transaction_feature_group_name, RecordIdentifierValueAsString=record_identifier_value)


{'ResponseMetadata': {'RequestId': '84547154-d755-4547-9166-db55769aa942',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '84547154-d755-4547-9166-db55769aa942',
   'content-type': 'application/json',
   'content-length': '2636',
   'date': 'Thu, 03 Dec 2020 16:11:09 GMT'},
  'RetryAttempts': 0},
 'Record': [{'FeatureName': 'TransactionID', 'ValueAsString': '2990130'},
  {'FeatureName': 'isFraud', 'ValueAsString': '0'},
  {'FeatureName': 'TransactionDT', 'ValueAsString': '152647'},
  {'FeatureName': 'TransactionAmt', 'ValueAsString': '75.0'},
  {'FeatureName': 'card1', 'ValueAsString': '4577'},
  {'FeatureName': 'card2', 'ValueAsString': '583.0'},
  {'FeatureName': 'card3', 'ValueAsString': '150.0'},
  {'FeatureName': 'card4', 'ValueAsString': 'mastercard'},
  {'FeatureName': 'card5', 'ValueAsString': '219.0'},
  {'FeatureName': 'card6', 'ValueAsString': 'credit'},
  {'FeatureName': 'B1', 'ValueAsString': '69'},
  {'FeatureName': 'B2', 'ValueAsString': '80'},
  {'Featur

In [19]:
print(identity_feature_group.as_hive_ddl())

CREATE EXTERNAL TABLE IF NOT EXISTS sagemaker_featurestore.identity-feature-group-03-15-46-06 (
  TransactionID INT
  id_01 INT
  id_02 FLOAT
  id_03 FLOAT
  id_04 FLOAT
  id_05 FLOAT
  id_06 INT
  id_07 INT
  id_08 INT
  id_09 FLOAT
  id_10 INT
  id_11 INT
  id_12 INT
  id_13 INT
  id_14 INT
  id_15 INT
  id_16 INT
  id_17 STRING
  id_18 STRING
  id_19 STRING
  id_20 STRING
  EventTime FLOAT
)
ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
  STORED AS
  INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'
  OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'
LOCATION 's3://beyoung-sagemaker/sagemaker-featurestore-demo/476271697919/sagemaker/us-west-2/offline-store/identity-feature-group-03-15-46-06'


In [20]:
print(transaction_feature_group.as_hive_ddl())

CREATE EXTERNAL TABLE IF NOT EXISTS sagemaker_featurestore.transaction-feature-group-03-15-46-06 (
  TransactionID INT
  isFraud INT
  TransactionDT INT
  TransactionAmt FLOAT
  card1 INT
  card2 FLOAT
  card3 FLOAT
  card4 STRING
  card5 FLOAT
  card6 STRING
  B1 INT
  B2 INT
  B3 INT
  B4 INT
  B5 INT
  B6 INT
  B7 INT
  B8 INT
  B9 INT
  B10 INT
  B11 INT
  B12 INT
  F1 INT
  F2 INT
  F3 INT
  F4 INT
  F5 INT
  F6 INT
  F7 INT
  F8 INT
  F9 INT
  F10 INT
  F11 INT
  F12 INT
  F13 INT
  F14 INT
  F15 INT
  F16 INT
  F17 INT
  N1 STRING
  N2 STRING
  N3 STRING
  N4 STRING
  N5 STRING
  N6 STRING
  N7 STRING
  N8 STRING
  N9 STRING
  card_type_0 INT
  card_type_credit INT
  card_type_debit INT
  card_bank_0 INT
  card_bank_american_express INT
  card_bank_discover INT
  card_bank_mastercard INT
  card_bank_visa INT
  EventTime FLOAT
)
ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
  STORED AS
  INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'
  OU

In [21]:
account_id = boto3.client('sts').get_caller_identity()["Account"]

identity_feature_group_s3_prefix = prefix + '/' + account_id + '/sagemaker/' + region + '/offline-store/' + identity_feature_group_name + '/data'
transaction_feature_group_s3_prefix = prefix + '/' + account_id + '/sagemaker/' + region + '/offline-store/' + transaction_feature_group_name + '/data'

offline_store_contents = None
while (offline_store_contents is None):
    objects_in_bucket = s3_client.list_objects(Bucket=default_s3_bucket_name,Prefix=transaction_feature_group_s3_prefix)
    if ('Contents' in objects_in_bucket and len(objects_in_bucket['Contents']) > 1):
        offline_store_contents = objects_in_bucket['Contents']
    else:
        print('Waiting for data in offline store...\n')
        sleep(60)

print('Data available.')

Waiting for data in offline store...

Waiting for data in offline store...

Waiting for data in offline store...

Waiting for data in offline store...

Data available.


# Build Training Dataset

In [22]:
identity_query = identity_feature_group.athena_query()
transaction_query = transaction_feature_group.athena_query()

identity_table = identity_query.table_name
transaction_table = transaction_query.table_name

query_string = 'SELECT * FROM "'+transaction_table+'" LEFT JOIN "'+identity_table+'" ON "'+transaction_table+'".transactionid = "'+identity_table+'".transactionid'
print('Running ' + query_string)

# run Athena query. The output is loaded to a Pandas dataframe.
dataset = pd.DataFrame()
identity_query.run(query_string=query_string, output_location='s3://'+default_s3_bucket_name+'/'+prefix+'/query_results/')
identity_query.wait()
dataset = identity_query.as_dataframe()

dataset

Running SELECT * FROM "transaction-feature-group-03-15-46-06-1607010724" LEFT JOIN "identity-feature-group-03-15-46-06-1607010720" ON "transaction-feature-group-03-15-46-06-1607010724".transactionid = "identity-feature-group-03-15-46-06-1607010720".transactionid


Unnamed: 0,transactionid,isfraud,transactiondt,transactionamt,card1,card2,card3,card4,card5,card6,...,id_15,id_16,id_17,id_18,id_19,id_20,eventtime.1,write_time.1,api_invocation_time.1,is_deleted.1
0,3207251,0,5173433,58.43,8452,524.0,150.0,visa,226.0,debit,...,,,,,,,,,,
1,3269912,0,6905286,500.00,3895,399.0,150.0,american express,118.0,credit,...,109.0,196.0,T,F,T,F,1.607010e+09,2020-12-03 16:15:47.141,2020-12-03 16:10:27.000,False
2,3006216,0,506631,55.00,14649,548.0,150.0,visa,226.0,debit,...,,,,,,,,,,
3,3142951,0,3251537,77.00,7826,481.0,150.0,mastercard,224.0,debit,...,,,,,,,,,,
4,3247321,0,6234414,67.95,7664,490.0,150.0,visa,226.0,debit,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,3268027,0,6835786,107.95,13979,474.0,150.0,visa,226.0,credit,...,,,,,,,,,,
1996,3512673,0,13816421,117.00,12932,361.0,150.0,visa,226.0,debit,...,,,,,,,,,,
1997,3509132,0,13718899,77.95,2884,490.0,150.0,visa,226.0,debit,...,,,,,,,,,,
1998,3124196,0,2762595,100.00,3682,264.0,150.0,visa,162.0,credit,...,86.0,78.0,T,F,F,F,1.607010e+09,2020-12-03 16:15:17.807,2020-12-03 16:10:26.000,False


In [23]:
# Prepare query results for training.
query_execution = identity_query.get_query_execution()
query_result = 's3://'+default_s3_bucket_name+'/'+prefix+'/query_results/'+query_execution['QueryExecution']['QueryExecutionId']+'.csv'
print(query_result)

# Select useful columns for training with target column as the first.
dataset = dataset[["isfraud", "transactiondt", "transactionamt", "card1", "card2", "card3", "card5", "card_type_credit", "card_type_debit", "card_bank_american_express", "card_bank_discover", "card_bank_mastercard", "card_bank_visa", "id_01", "id_02", "id_03", "id_04", "id_05"]]

# Write to csv in S3 without headers and index column.
dataset.to_csv('dataset.csv', header=False, index=False)
s3_client.upload_file('dataset.csv', default_s3_bucket_name, prefix+'/training_input/dataset.csv')
dataset_uri_prefix = 's3://'+default_s3_bucket_name+'/'+prefix+'/training_input/';

dataset

s3://beyoung-sagemaker/sagemaker-featurestore-demo/query_results/25423aba-26f5-4940-8729-d38a0a726818.csv


Unnamed: 0,isfraud,transactiondt,transactionamt,card1,card2,card3,card5,card_type_credit,card_type_debit,card_bank_american_express,card_bank_discover,card_bank_mastercard,card_bank_visa,id_01,id_02,id_03,id_04,id_05
0,0,5173433,58.43,8452,524.0,150.0,226.0,0,1,0,0,0,1,,,,,
1,0,6905286,500.00,3895,399.0,150.0,118.0,1,0,1,0,0,0,-5.0,41726.0,0.0,0.0,0.0
2,0,506631,55.00,14649,548.0,150.0,226.0,0,1,0,0,0,1,,,,,
3,0,3251537,77.00,7826,481.0,150.0,224.0,0,1,0,0,1,0,,,,,
4,0,6234414,67.95,7664,490.0,150.0,226.0,0,1,0,0,0,1,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,0,6835786,107.95,13979,474.0,150.0,226.0,1,0,0,0,0,1,,,,,
1996,0,13816421,117.00,12932,361.0,150.0,226.0,0,1,0,0,0,1,,,,,
1997,0,13718899,77.95,2884,490.0,150.0,226.0,0,1,0,0,0,1,,,,,
1998,0,2762595,100.00,3682,264.0,150.0,162.0,1,0,0,0,0,1,-10.0,138748.0,0.0,0.0,3.0


# Train and Deploy the Model

In [30]:
training_output_path = 's3://'+default_s3_bucket_name+'/'+prefix+'/training_output'
training_image = '257758044811.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3'

In [31]:
from sagemaker.estimator import Estimator

training_model = Estimator(training_image,
                           role,
                           instance_count=1,
                           instance_type='ml.m5.2xlarge',
                           volume_size = 5,
                           max_run = 3600,
                           input_mode= 'File',
                           output_path=training_output_path,
                           sagemaker_session=feature_store_session)

In [32]:
training_model.set_hyperparameters(objective = "binary:logistic",
                                   num_round = 50)

In [33]:
import sagemaker.inputs

train_data = sagemaker.inputs.TrainingInput(dataset_uri_prefix, distribution='FullyReplicated',
                                            content_type='text/csv', s3_data_type='S3Prefix')
data_channels = {'train': train_data}

In [None]:
training_model.fit(inputs=data_channels, logs=True)