In [1]:
LOCAL_MODE: bool = False  # NOTE(review): not read by any cell visible in this notebook

# 0. 환경설정

In [2]:
import argparse
import os
import requests
import tempfile
import subprocess, sys

import pandas as pd
import numpy as np
from glob import glob
import copy
from collections import OrderedDict
from pathlib import Path
import joblib

import logging
import logging.handlers

import json
import base64
import boto3
import sagemaker
from botocore.client import Config
from botocore.exceptions import ClientError

import time
from datetime import datetime as dt
import datetime
from pytz import timezone
from dateutil.relativedelta import *

In [3]:
# Korea Standard Time (UTC+9). Computed from an explicit UTC offset so the
# result does not depend on the host's local timezone; adding 9 hours to the
# host-local "today()" is only correct when the host clock runs in UTC.
KST_OFFSET = datetime.timedelta(hours=9)
KST = dt.now(datetime.timezone(KST_OFFSET)).replace(tzinfo=None)
# Partition date used in S3 paths: the day before "now" in KST.
KST_aday_before = KST - datetime.timedelta(days=1)
yyyy, mm, dd = KST_aday_before.strftime("%Y-%m-%d").split("-")
print(f"Start job time: {KST}")

Start job time: 2023-03-21 12:55:18.465732


In [4]:
def get_secret(secret_name="dev/ForecastPalmOilPrice", region_name="ap-northeast-2"):
    """Fetch the project secret from AWS Secrets Manager.

    Args:
        secret_name: Secrets Manager secret id.
        region_name: AWS region that hosts the secret.

    Returns:
        The ``SecretString`` (str) when the secret has one, otherwise the
        base64-decoded ``SecretBinary`` (bytes).

    Raises:
        botocore.exceptions.ClientError: for every Secrets Manager error code
            (DecryptionFailureException, InternalServiceErrorException,
            InvalidParameterException, InvalidRequestException,
            ResourceNotFoundException, or any other). Errors are never
            swallowed: returning None would only surface later as a TypeError
            in ``json.loads``.
    """
    session = boto3.session.Session()
    client = session.client(service_name='secretsmanager', region_name=region_name)

    # No error code is recoverable here, so any ClientError propagates as-is.
    response = client.get_secret_value(SecretId=secret_name)

    if 'SecretString' in response:
        return response['SecretString']
    return base64.b64decode(response['SecretBinary'])

# Project configuration and credentials, stored as one JSON secret.
keychain = json.loads(get_secret())

# AWS credentials used for every client created below.
ACCESS_KEY_ID, ACCESS_SECRET_KEY = keychain['AWS_ACCESS_KEY_ID'], keychain['AWS_ACCESS_SECRET_KEY']

# Project and data-lake bucket names.
BUCKET_NAME_USECASE, DATALAKE_BUCKET_NAME = keychain['PROJECT_BUCKET_NAME'], keychain['DATALAKE_BUCKET_NAME']

# S3 prefixes for each data stage (note: the forecast prefix is stored under 'S3_PATH_PREDICTION').
(S3_PATH_REUTER, S3_PATH_WWO, S3_PATH_STAGE,
 S3_PATH_GOLDEN, S3_PATH_TRAIN, S3_PATH_FORECAST) = (
    keychain[secret_key]
    for secret_key in ('S3_PATH_REUTER', 'S3_PATH_WWO', 'S3_PATH_STAGE',
                       'S3_PATH_GOLDEN', 'S3_PATH_TRAIN', 'S3_PATH_PREDICTION')
)

# Authenticated session shared by boto3 and the SageMaker SDK.
region = 'ap-northeast-2'
boto3_session = boto3.Session(
    aws_access_key_id=ACCESS_KEY_ID,
    aws_secret_access_key=ACCESS_SECRET_KEY,
    region_name=region,
)
sm_session = sagemaker.Session(boto_session=boto3_session)

# High-level S3 bucket handles.
s3_resource = boto3_session.resource('s3')
palmoil_bucket, datalake_bucket = (
    s3_resource.Bucket(bucket_name)
    for bucket_name in (BUCKET_NAME_USECASE, DATALAKE_BUCKET_NAME)
)

# Low-level service clients.
sm_client, qs_client, s3_client, sts_client = (
    boto3_session.client(service) for service in ('sagemaker', 'quicksight', 's3', "sts")
)

In [49]:
%%writefile src/v1.2/visualization.py

import argparse
import os
import requests
import tempfile
import subprocess, sys
import json

import glob
import pandas as pd
import joblib
import pickle
import tarfile
from io import StringIO, BytesIO

import logging
import logging.handlers

import time
import calendar
from datetime import datetime as dt

import boto3


###############################
######### util 함수 설정 ##########
###############################
def _get_logger():
    loglevel = logging.DEBUG
    l = logging.getLogger(__name__)
    if not l.hasHandlers():
        l.setLevel(loglevel)
        logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))        
        l.handler_set = True
    return l  
logger = _get_logger()

def get_secret(secret_name="dev/ForecastPalmOilPrice", region_name="ap-northeast-2"):
    """Fetch the project secret from AWS Secrets Manager.

    Args:
        secret_name: Secrets Manager secret id.
        region_name: AWS region that hosts the secret.

    Returns:
        The ``SecretString`` (str) when the secret has one, otherwise the
        base64-decoded ``SecretBinary`` (bytes).

    Raises:
        botocore.exceptions.ClientError: for every Secrets Manager error code.
            Errors are never swallowed (returning None would only fail later in
            ``json.loads``).
    """
    # This script's top-level imports do not include base64 (nor ClientError,
    # which is no longer needed because errors simply propagate).
    import base64

    session = boto3.session.Session()
    client = session.client(service_name='secretsmanager', region_name=region_name)

    response = client.get_secret_value(SecretId=secret_name)

    if 'SecretString' in response:
        return response['SecretString']
    return base64.b64decode(response['SecretBinary'])
        

def register_manifest(source_path,
                      target_path,
                      s3_client,
                      BUCKET_NAME_USECASE):
    """Build a QuickSight S3 manifest for every object under ``source_path`` and
    upload it as ``<target_path>/visual_validation.manifest``.

    Args:
        source_path: ``s3://<bucket>/<prefix>`` URI whose objects are listed.
        target_path: ``s3://<bucket>/<prefix>`` URI of the folder that receives
            the manifest.
        s3_client: boto3 S3 client.
        BUCKET_NAME_USECASE: bucket name; both paths must contain ``<bucket>/``.

    Returns:
        The S3 key (without the bucket) of the uploaded manifest.

    Raises:
        IndexError: if a path does not contain ``<bucket>/``.
    """
    def _key_after_bucket(uri):
        # Everything after the first "<bucket>/". maxsplit=1 keeps the rest of
        # the key intact even if the bucket name reappears later in the path.
        return uri.split(BUCKET_NAME_USECASE + '/', 1)[1]

    template_json = {"fileLocations": [{"URIPrefixes": []}],
                     "globalUploadSettings": {
                         "format": "CSV",
                         "delimiter": ","
                     }}
    uri_prefixes = template_json['fileLocations'][0]['URIPrefixes']

    paginator = s3_client.get_paginator('list_objects_v2')
    response_iterator = paginator.paginate(Bucket=BUCKET_NAME_USECASE,
                                           Prefix=_key_after_bucket(source_path))
    for page in response_iterator:
        logger.info(f"\n#### page {page}")
        # A page for a prefix with no objects has no 'Contents' key at all.
        for content in page.get('Contents', []):
            uri_prefixes.append(f's3://{BUCKET_NAME_USECASE}/' + content['Key'])
    if not uri_prefixes:
        logger.warning(f"No objects found under {source_path}; the manifest will be empty")

    manifest_key = f"{_key_after_bucket(target_path)}/visual_validation.manifest"
    local_manifest = './manifest_testing.manifest'
    with open(local_manifest, 'w') as f:
        json.dump(template_json, f, indent=2)

    s3_client.upload_file(local_manifest, BUCKET_NAME_USECASE, manifest_key)
    return manifest_key
    

def _list_all_pages(list_call, result_key, **kwargs):
    """Call a NextToken-paginated QuickSight ``list_*`` API until exhausted and
    return the concatenated ``result_key`` items from every page."""
    items = []
    while True:
        page = list_call(**kwargs)
        items.extend(page.get(result_key, []))
        next_token = page.get('NextToken')
        if not next_token:
            return items
        kwargs['NextToken'] = next_token


def _wait_for_ingestion(qs_client, user_account_id, dataset_id, ingestion_id, poll_seconds=5):
    """Poll one SPICE ingestion until it leaves the in-progress states and return
    the last ``describe_ingestion`` response."""
    while True:
        response = qs_client.describe_ingestion(DataSetId=dataset_id,
                                                IngestionId=ingestion_id,
                                                AwsAccountId=user_account_id)
        ingestion = response['Ingestion']
        status = ingestion['IngestionStatus']
        if status in ('INITIALIZED', 'QUEUED', 'RUNNING'):
            time.sleep(poll_seconds)  # tune according to the dataset size
        elif status == 'COMPLETED':
            logger.info("refresh completed. RowsIngested {0}, RowsDropped {1}, IngestionTimeInSeconds {2}, IngestionSizeInBytes {3}".format(
                ingestion['RowInfo']['RowsIngested'],
                ingestion['RowInfo']['RowsDropped'],
                ingestion['IngestionTimeInSeconds'],
                ingestion['IngestionSizeInBytes']))
            return response
        else:
            logger.info("refresh failed for {0}! - status {1}".format(dataset_id, status))
            return response


def refresh_of_spice_datasets(user_account_id,
                              qs_data_name,
                              manifest_file_path,
                              BUCKET_NAME_USECASE,
                              qs_client):
    """Point matching QuickSight data sources at a new S3 manifest and refresh
    the matching SPICE datasets.

    Every data source / dataset whose name contains ``qs_data_name`` is
    affected. Each matching data source is updated (and renamed to
    ``qs_data_name``) to use ``s3://<BUCKET_NAME_USECASE>/<manifest_file_path>``.
    An ingestion is then started for each matching dataset and polled until it
    finishes.

    Args:
        user_account_id: AWS account id owning the QuickSight resources.
        qs_data_name: substring used to select data sources and datasets.
        manifest_file_path: S3 key of the manifest inside ``BUCKET_NAME_USECASE``.
        BUCKET_NAME_USECASE: bucket holding the manifest.
        qs_client: boto3 QuickSight client.

    Returns:
        The last QuickSight response: the final ``describe_ingestion`` if any
        ingestion was started, else the final ``update_data_source``, else None.
    """
    response = None  # returned when nothing matched

    data_sources = _list_all_pages(qs_client.list_data_sources, 'DataSources',
                                   AwsAccountId=user_account_id)
    datasource_ids = [summary["DataSourceId"] for summary in data_sources
                      if qs_data_name in summary["Name"]]
    for datasource_id in datasource_ids:
        response = qs_client.update_data_source(
            AwsAccountId=user_account_id,
            DataSourceId=datasource_id,
            Name=qs_data_name,
            DataSourceParameters={
                'S3Parameters': {
                    'ManifestFileLocation': {
                        'Bucket': BUCKET_NAME_USECASE,
                        'Key': manifest_file_path
                    },
                },
            })
        logger.info(f"datasource_id:{datasource_id} 의 manifest를 업데이트: {response}")

    data_sets = _list_all_pages(qs_client.list_data_sets, 'DataSetSummaries',
                                AwsAccountId=user_account_id)
    datasets_ids = [summary["DataSetId"] for summary in data_sets
                    if qs_data_name in summary["Name"]]

    # Keep each dataset paired with its own ingestion id: a dataset whose
    # ingestion fails to start must not shift later ids onto the wrong dataset.
    started = []
    for dataset_id in datasets_ids:
        ingestion_id = str(calendar.timegm(time.gmtime()))
        try:
            qs_client.create_ingestion(DataSetId=dataset_id,
                                       IngestionId=ingestion_id,
                                       AwsAccountId=user_account_id)
        except Exception as e:
            # Best effort: one dataset failing to start must not block the rest.
            logger.info(e)
            continue
        started.append((dataset_id, ingestion_id))

    for dataset_id, ingestion_id in started:
        response = _wait_for_ingestion(qs_client, user_account_id, dataset_id, ingestion_id)
    return response
        
def parse_args(argv=None):
    """Parse the command-line arguments of the visualization processing job.

    Args:
        argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``source_path``, ``qs_data_name`` and
        ``model_package_group_name``.
    """
    parser = argparse.ArgumentParser()
    # Required: the manifest location is derived from it, so a missing value
    # would otherwise only fail later with an AttributeError on None.
    parser.add_argument("--source_path", type=str, required=True, help='prediction_data')
    parser.add_argument("--qs_data_name", type=str, default='forecast_result')
    # BUCKET_NAME_USECASE is a module global assigned in the __main__ block
    # before this runs; fall back to None if the module is imported without it.
    parser.add_argument('--model_package_group_name', type=str,
                        default=globals().get('BUCKET_NAME_USECASE'))
    return parser.parse_args(argv)
        
if __name__=='__main__':
    ############################################
    ###### Load key values from Secrets Manager #######
    ############################################
    logger.info(f"\n### Loading Key value from Secret Manager")
    keychain = json.loads(get_secret())
    ACCESS_KEY_ID = keychain['AWS_ACCESS_KEY_ID']
    ACCESS_SECRET_KEY = keychain['AWS_ACCESS_SECRET_KEY']
    # Module-level global: parse_args() uses it as the default
    # --model_package_group_name, so it must be set before parse_args() runs.
    BUCKET_NAME_USECASE = keychain['PROJECT_BUCKET_NAME']

    boto3_session = boto3.Session(aws_access_key_id=ACCESS_KEY_ID,
                                  aws_secret_access_key=ACCESS_SECRET_KEY,
                                  region_name='ap-northeast-2')
    s3_client = boto3_session.client('s3')
    qs_client = boto3_session.client('quicksight')
    sts_client = boto3_session.client("sts")
    user_account_id = sts_client.get_caller_identity()["Account"]

    ######################################
    ## Command-line arguments            ##
    ######################################
    logger.info("######### Argument Info ####################################")
    logger.info("### start visualization code")
    logger.info("### Argument Info ###")
    args = parse_args()
    logger.info(f"args.source_path: {args.source_path}")
    logger.info(f"args.qs_data_name: {args.qs_data_name}")
    logger.info(f"args.model_package_group_name: {args.model_package_group_name}")

    source_path = args.source_path
    qs_data_name = args.qs_data_name

    # The manifest goes into a "manifest" folder next to the source folder.
    # Strip a trailing '/' first: for ".../<run>/result/" a bare rsplit yields
    # ".../<run>/result/manifest", which lies *inside* the source prefix, so a
    # later run would list the manifest itself as a CSV data file.
    target_path = source_path.rstrip('/').rsplit('/', 1)[0] + '/manifest'
    logger.info(f"\n#### target_path : {target_path}")

    logger.info(f"\n#### register_manifest")
    manifest_file_path = register_manifest(source_path,
                                           target_path,
                                           s3_client,
                                           BUCKET_NAME_USECASE)
    logger.info(f'### manifest_file_path : {manifest_file_path}')
    logger.info(f"\n#### refresh_of_spice_datasets")
    res = refresh_of_spice_datasets(user_account_id,
                                    qs_data_name,
                                    manifest_file_path,
                                    BUCKET_NAME_USECASE,
                                    qs_client)
    logger.info(f'### refresh_of_spice_datasets : {res}')

Overwriting src/v1.2/visualization.py


In [50]:
!aws s3 cp 'src/v1.2/visualization.py' 's3://crude-palm-oil-prices-forecast/src/visualization.py' --exclude ".ipynb_checkpoints*"

upload: src/v1.2/visualization.py to s3://crude-palm-oil-prices-forecast/src/visualization.py


In [51]:
visualization_code = 's3://crude-palm-oil-prices-forecast/src/visualization.py'
%store visualization_code

Stored 'visualization_code' (str)


In [39]:
# List the variables currently persisted with %store and their stored values.
%store

Stored variables and their in-db values:
model_validation_code             -> 's3://crude-palm-oil-prices-forecast/src/model_val
prediction_code                   -> 's3://crude-palm-oil-prices-forecast/src/predictio
visualization_code                -> 's3://crude-palm-oil-prices-forecast/src/visualiza


In [40]:
# Restore every %store-persisted variable (e.g. visualization_code) into this kernel.
%store -r

# 1. 모델 빌딩 파이프라인의 스텝(Step) 생성
## 1) 모델 빌딩 파이프라인 변수 생성
파이프라인에서 사용할 파이프라인 파라미터를 정의합니다. 파이프라인을 스케줄하고 실행할 때 파라미터를 이용하여 실행조건을 커스터마이징할 수 있습니다. 파라미터를 이용하면 파이프라인 실행시마다 매번 파이프라인 정의를 수정하지 않아도 됩니다.

지원되는 파라미터 타입은 다음과 같습니다:

- ParameterString - 파이썬 타입에서 str
- ParameterInteger - 파이썬 타입에서 int
- ParameterFloat - 파이썬 타입에서 float
이들 파라미터를 정의할 때 디폴트 값을 지정할 수 있으며 파이프라인 실행시 재지정할 수도 있습니다. 지정하는 디폴트 값은 파라미터 타입과 일치하여야 합니다.

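본 노트북의 파이프라인에서 실제로 사용하는 파라미터는 다음과 같습니다.

- visualization_instance_type (VisualizationInstanceType) - 시각화 프로세싱 작업에서 사용할 ml.* 인스턴스 타입 (디폴트는 "ml.m5.xlarge")
- visualization_instance_count (VisualizationInstanceCount) - 시각화 프로세싱 작업에서 사용할 인스턴스 개수 (디폴트는 1)

다음 파라미터들은 이 시각화 파이프라인에는 전달되지 않습니다 (model_approval_status는 노트북 마지막 셀에서 정의만 됩니다).

- processing_instance_type - 프로세싱 작업에서 사용할 ml.* 인스턴스 타입
- processing_instance_count - 프로세싱 작업에서 사용할 인스턴스 개수
- training_instance_type - 학습작업에서 사용할 ml.* 인스턴스 타입
- model_approval_status - 학습된 모델을 CI/CD를 목적으로 등록할 때의 승인 상태 (디폴트는 "PendingManualApproval")
- input_data - 입력데이터에 대한 S3 버킷 URI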
파이프라인의 각 스텝에서 사용할 변수를 파라미터 변수로서 정의합니다.

# 2. 파이프라인 정의 및 실행

In [39]:
from sagemaker.workflow.parameters import ParameterInteger, ParameterString

# Execution-time parameters of the visualization processing step; both can be
# overridden when the pipeline is started.
visualization_instance_type = ParameterString(name="VisualizationInstanceType",
                                              default_value="ml.m5.xlarge")
visualization_instance_count = ParameterInteger(name="VisualizationInstanceCount",
                                                default_value=1)

In [43]:
# Preview the S3 key of the manifest folder derived from the prediction path.
# A trailing '/' is stripped first so ".../result/" maps to the sibling
# ".../manifest" folder instead of ".../result/manifest".
# NOTE: relies on prediction_input_path, which is assigned in a later cell.
(prediction_input_path.rstrip('/').rsplit('/', 1)[0] + '/manifest').split(BUCKET_NAME_USECASE + '/', 1)[1]

'predicted-data/2023/03/19/1679292475.0/manifest'

In [9]:
# Prediction output folder of one prediction run, partitioned by yesterday's KST date.
# NOTE(review): PREDICTION_RUN_ID looks like an epoch-seconds run id written by
# the prediction job — update it to the run you want to visualize.
PREDICTION_RUN_ID = "1679292475.0"
prediction_input_path = f"s3://{BUCKET_NAME_USECASE}/{S3_PATH_FORECAST}/{yyyy}/{mm}/{dd}/{PREDICTION_RUN_ID}/result"
print(prediction_input_path)

s3://crude-palm-oil-prices-forecast/predicted-data/2023/03/20/1679292475.0/result


## 1) 스텝정의

### (1) SKLearnProcessor 정의

In [52]:
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker import get_execution_role

skframework_version = "1.0-1"  # alternative container version: "0.23-1"
role = get_execution_role()

# SageMaker job names may only contain alphanumerics and '-', so no
# parentheses in base_job_name (pipeline steps generate their own job names,
# but running this processor directly would be rejected otherwise).
skprocessor_visualization = SKLearnProcessor(
    framework_version=skframework_version,
    instance_type=visualization_instance_type,
    instance_count=visualization_instance_count,
    base_job_name=f"{BUCKET_NAME_USECASE}-visualization",
    role=role,
)

The input argument instance_type of function (sagemaker.image_uris.retrieve) is a pipeline variable (<class 'sagemaker.workflow.parameters.ParameterString'>), which is not allowed. The default_value of this Parameter object will be used to override it. Please make sure the default_value is valid.


In [22]:
# Display the prediction input path currently selected for the pipeline.
prediction_input_path

's3://crude-palm-oil-prices-forecast/predicted-data/2023/03/20/1679292475.0/result'

In [25]:
# List the objects in one prediction run's result folder.
!aws s3 ls 's3://crude-palm-oil-prices-forecast/predicted-data/2023/03/22/1679559317.0/result/'
# Intended manifest location for this run (sibling "manifest" folder):
# s3://crude-palm-oil-prices-forecast/predicted-data/2023/03/22/1679559317.0/manifest/visual_validation.manifest

2023-03-22 23:22:16       4356 prediction_result.csv


In [26]:
# Point the pipeline at a specific prediction run (2023/03/22, run id 1679559317.0),
# built from the configured bucket and forecast prefix instead of a hard-coded URI.
prediction_input_path = f"s3://{BUCKET_NAME_USECASE}/{S3_PATH_FORECAST}/2023/03/22/1679559317.0/result/"

In [53]:
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow.steps import ProcessingStep

# Command-line arguments forwarded to visualization.py.
visualization_job_arguments = [
    "--source_path", prediction_input_path,
    "--qs_data_name", 'forecast_result',
    "--model_package_group_name", BUCKET_NAME_USECASE,
]

# The script reads from and writes to S3 through boto3 itself, so the step
# declares no ProcessingInput/ProcessingOutput channels.
step_visualization = ProcessingStep(
    name=f"{BUCKET_NAME_USECASE}-Visualization",
    processor=skprocessor_visualization,
    inputs=[],
    outputs=[],
    job_arguments=visualization_job_arguments,
    code=visualization_code,
)

## 2) 파이프라인 정의 및 실행

In [54]:
from sagemaker.workflow.pipeline import Pipeline

# Single-step pipeline named after the project bucket; the visualization
# instance type and count are exposed as execution-time parameters.
pipeline = Pipeline(
    name=BUCKET_NAME_USECASE,
    parameters=[
        visualization_instance_type,
        visualization_instance_count,
    ],
    steps=[step_visualization],
)

In [55]:
# Inspect the pipeline definition JSON before upserting (json is imported in the setup cell).
definition = json.loads(pipeline.definition())
definition

{'Version': '2020-12-01',
 'Metadata': {},
 'Parameters': [{'Name': 'VisualizationInstanceType',
   'Type': 'String',
   'DefaultValue': 'ml.m5.xlarge'},
  {'Name': 'VisualizationInstanceCount',
   'Type': 'Integer',
   'DefaultValue': 1}],
 'PipelineExperimentConfig': {'ExperimentName': {'Get': 'Execution.PipelineName'},
  'TrialName': {'Get': 'Execution.PipelineExecutionId'}},
 'Steps': [{'Name': 'crude-palm-oil-prices-forecast-Visualization',
   'Type': 'Processing',
   'Arguments': {'ProcessingResources': {'ClusterConfig': {'InstanceType': {'Get': 'Parameters.VisualizationInstanceType'},
      'InstanceCount': {'Get': 'Parameters.VisualizationInstanceCount'},
      'VolumeSizeInGB': 30}},
    'AppSpecification': {'ImageUri': '366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-cpu-py3',
     'ContainerArguments': ['--source_path',
      's3://crude-palm-oil-prices-forecast/predicted-data/2023/03/22/1679559317.0/result/',
      '--qs_data_name',
      'for

In [56]:
%%time
start = time.time()
pipeline.upsert(role_arn=sagemaker.get_execution_role())
execution = pipeline.start()
execution.wait() #실행이 완료될 때까지 기다린다.
end = time.time()

CPU times: user 587 ms, sys: 4.62 ms, total: 592 ms
Wall time: 4min 34s


In [57]:
# Report the pipeline run duration in seconds and minutes.
elapsed_sec = end - start
print(f"visualization 시간 : {elapsed_sec:.1f} sec")
print(f"visualization 시간 : {(elapsed_sec / 60):.1f} min")

visualization 시간 : 274.0 sec
visualization 시간 : 4.6 min


[2022년 11월 29일]
- prediction 시간 : 423.1 sec
- prediction 시간 : 7.1 min

In [57]:
# Show the execution's metadata: status, ARNs, display name and timestamps.
execution.describe()

{'PipelineArn': 'arn:aws:sagemaker:ap-northeast-2:108594546720:pipeline/crude-palm-oil-prices-forecast',
 'PipelineExecutionArn': 'arn:aws:sagemaker:ap-northeast-2:108594546720:pipeline/crude-palm-oil-prices-forecast/execution/03tvfcaaprlz',
 'PipelineExecutionDisplayName': 'execution-1677737728133',
 'PipelineExecutionStatus': 'Succeeded',
 'PipelineExperimentConfig': {'ExperimentName': 'crude-palm-oil-prices-forecast',
  'TrialName': '03tvfcaaprlz'},
 'CreationTime': datetime.datetime(2023, 3, 2, 6, 15, 27, 687000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2023, 3, 2, 6, 26, 57, 655000, tzinfo=tzlocal()),
 'CreatedBy': {},
 'LastModifiedBy': {},
 'ResponseMetadata': {'RequestId': '78e05f6b-51a4-40ce-9736-816ef89db047',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '78e05f6b-51a4-40ce-9736-816ef89db047',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '541',
   'date': 'Thu, 02 Mar 2023 06:28:05 GMT'},
  'RetryAttempts': 0}}

In [58]:
# List the execution's steps with their status and the backing processing-job ARN.
execution.list_steps()

[{'StepName': 'crude-palm-oil-prices-forecast-Visualization',
  'StartTime': datetime.datetime(2023, 3, 2, 6, 15, 29, 846000, tzinfo=tzlocal()),
  'EndTime': datetime.datetime(2023, 3, 2, 6, 26, 57, 79000, tzinfo=tzlocal()),
  'StepStatus': 'Succeeded',
  'AttemptCount': 0,
  'Metadata': {'ProcessingJob': {'Arn': 'arn:aws:sagemaker:ap-northeast-2:108594546720:processing-job/pipelines-03tvfcaaprlz-crude-palm-oil-price-dbgzcq8j0f'}}}]

In [None]:
# NOTE(review): not passed to the visualization pipeline defined above.
model_approval_status = ParameterString(name="ModelApprovalStatus",
                                        default_value="PendingManualApproval")