# Prepare session

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.local import LocalSession
import s3fs
import subprocess
from sagemaker.s3 import S3Downloader, S3Uploader
from pathlib import Path
import json

# Naming and S3 layout shared by the cells of this notebook.
image_name = "sagemaker-test"
ecr_namespace = f"{image_name}/"
default_bucket = "prod-test"
bucket = default_bucket
default_uri = f"s3://{default_bucket}"
atf_s3_uri = f"{default_uri}/sagemaker"
data_location_uri = f"{default_uri}/training_data/full"

# Execution role, plus the account id and region derived from the environment.
role = get_execution_role()
# The account id is the fifth colon-separated field of the role string.
account_id = role.split(":")[4]
boto_session = boto3.Session()
region = boto_session.region_name

# SageMaker session bound to the boto session and the default bucket above.
sagemaker_session = sagemaker.Session(
    boto_session=boto_session, default_bucket=default_bucket
)
s3_helper = s3fs.S3FileSystem()

# Echo the resolved configuration, one value per line.
print(
    account_id,
    region,
    role,
    sagemaker_session,
    default_uri,
    atf_s3_uri,
    data_location_uri,
    sep="\n",
)

# Dev in real

## Build and push image

In [None]:
# Look up the "latest" tag in the project's ECR repository and display the
# time it was pushed (the final expression is the cell's output).
ecr_client = boto3.client("ecr")
response = ecr_client.describe_images(
    repositoryName=image_name, imageIds=[{"imageTag": "latest"}]
)
latest_image = response["imageDetails"][0]
str(latest_image["imagePushedAt"])

## Define parameters

In [None]:
# Base directory used as source/destination for processing-job inputs and outputs.
opt_ml_dir = "/opt/ml/processing"
# Identifier embedded in the S3 prefixes used by this notebook.
execution_id = "exp-real-sm"
# ECR URI of the "latest" image, built from the account id and region above.
image_uri = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(
    account_id, region, image_name
)
print(image_uri)

In [None]:
# Compute settings for the processing and training jobs defined further down;
# both job kinds currently use the same instance type.
processing_instance_count = 1
processing_instance_type = training_instance_type = "ml.m5.2xlarge"

## Prepare data

In [None]:
# S3 prefixes of the train and test data for the current execution id.
train_data_uri = f"{atf_s3_uri}/prepared_data/{execution_id}/train"
test_data_uri = f"{atf_s3_uri}/prepared_data/{execution_id}/test"
! aws s3 ls $train_data_uri/
! aws s3 ls $test_data_uri/

## Train

In [None]:
! pip install stepfunctions

In [None]:
import stepfunctions
from stepfunctions import steps
from stepfunctions.inputs import ExecutionInput
from stepfunctions.workflow import Workflow

# S3 locations handed to the workflow steps.
input_data = dict(
    TrainingUri=train_data_uri,
    TestUri=test_data_uri,
    # NOTE(review): train_data_uri already ends in "/train", so this resolves
    # to ".../train/train/train.csv" - confirm that this matches the S3 layout.
    BaselineUri=train_data_uri + "/train/train.csv",
)
hyperparameters = dict(learning_rate=0.05)

# Container image reference, built from the same account id and region as image_uri.
ecr_image_name = "sagemaker-test"
image_detail = dict(
    ImageUri="{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(
        account_id, region, ecr_image_name
    ),
)

output_data = dict(ModelOutputUri=f"{atf_s3_uri}/model")

# Placeholders whose values are supplied when a workflow execution is started.
execution_input = ExecutionInput(
    schema=dict(
        PreprocessingJobName=str,
        TrainingJobName=str,
        EvaluationProcessingJobName=str,
    )
)
# Display the input locations as the cell output.
input_data

## Test evaluation step

In [None]:
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor

# Model archive to evaluate (presumably the output of an earlier training run - verify).
model_data_s3_uri = "s3://prod-test/sagemaker/model/exp-step-functions-979a5224513111ecad540a957bd2374c/output/model.tar.gz"
MODELEVALUATION_SCRIPT_LOCATION = "./container/code/evaluate.py"

# Upload the local evaluation script; the returned value is used as the
# source of the "code" processing input below.
input_evaluation_code = sagemaker_session.upload_data(
    MODELEVALUATION_SCRIPT_LOCATION,
    bucket=default_bucket,
    key_prefix="sagemaker/evaluation/{}".format(execution_id),
)

model_evaluation_processor = ScriptProcessor(
    role=role,
    image_uri=image_uri,
    command=["python"],
    instance_type=processing_instance_type,
    instance_count=processing_instance_count,
    max_runtime_in_seconds=1200,
)

# Model, test data and script are each mapped to a directory under opt_ml_dir.
evaluation_inputs = [
    ProcessingInput(
        input_name="input-1",
        source=model_data_s3_uri,
        destination=f"{opt_ml_dir}/model",
    ),
    ProcessingInput(
        input_name="input-2",
        source=test_data_uri,
        destination=f"{opt_ml_dir}/test",
    ),
    ProcessingInput(
        input_name="code",
        source=input_evaluation_code,
        destination=f"{opt_ml_dir}/input/code",
    ),
]

# The "evaluation" output is the one lambda_handler looks up by name.
evaluation_outputs = [
    ProcessingOutput(
        output_name="evaluation",
        source=f"{opt_ml_dir}/evaluation",
        destination=f"{atf_s3_uri}/evaluation/{execution_id}",
    ),
]

processing_evaluation_step = steps.ProcessingStep(
    "SageMaker Processing Model Evaluation step",
    processor=model_evaluation_processor,
    job_name=execution_input["EvaluationProcessingJobName"],
    inputs=evaluation_inputs,
    outputs=evaluation_outputs,
    container_entrypoint=["python", "/opt/ml/processing/input/code/evaluate.py"],
)

In [None]:
# Terminal Fail state the workflow moves to when the processing task fails.
failed_state_sagemaker_processing_failure = steps.states.Fail(
    "ML Workflow failed",
    cause="SageMakerProcessingJobFailed",
)

# Catch rule that routes States.TaskFailed errors to the Fail state above.
catch_state_processing = steps.states.Catch(
    error_equals=["States.TaskFailed"],
    next_step=failed_state_sagemaker_processing_failure,
)

processing_evaluation_step.add_catch(catch_state_processing)
# The same catch rule can be attached to a training step with
# training_step.add_catch(catch_state_processing) once that step exists.

In [None]:
import uuid
# Fixed evaluation job name (prefix + hex suffix), so the cells below refer to
# one specific processing job.
evaluation_job_name = "exp-step-functions-evaluation-{}".format(
    "4a7f963e513611ec9c140a957bd2374c"
)
# To start a new job instead, format uuid.uuid1().hex into the same template:
# each evaluation job requires a unique name.

In [None]:
from sagemaker.s3 import S3Downloader
import json

# SageMaker API client; lambda_handler uses it to describe processing jobs.
sm_client = boto3.client("sagemaker")
def lambda_handler(event, context):
    """Return the status and evaluation metrics of a SageMaker processing job.

    Args:
        event: Mapping that must contain ``"EvaluationProcessingJobName"``.
        context: Unused; present to match the Lambda handler signature.

    Returns:
        dict: ``{"statusCode": 200, "results": {...}}`` where ``results`` holds
        the job name, its ``ProcessingJobStatus`` and the parsed content of the
        ``eval.json`` file stored under the job's ``"evaluation"`` output.
        The result keys keep their ``Training*`` names for compatibility with
        existing consumers.

    Raises:
        KeyError: If ``event`` has no ``"EvaluationProcessingJobName"`` entry,
            or the job has no processing output named ``"evaluation"``.
    """
    if "EvaluationProcessingJobName" not in event:
        raise KeyError(
            "EvaluationProcessingJobName not found for event: {}.".format(
                json.dumps(event)
            )
        )
    job_name = event["EvaluationProcessingJobName"]

    # Get the processing job
    response = sm_client.describe_processing_job(ProcessingJobName=job_name)
    status = response["ProcessingJobStatus"]

    # Locate the S3 prefix of the "evaluation" output. Failing explicitly here
    # avoids an UnboundLocalError further down when no such output exists.
    outputs = response["ProcessingOutputConfig"]["Outputs"]
    evaluation_prefix = next(
        (
            output["S3Output"]["S3Uri"]
            for output in outputs
            if output["OutputName"] == "evaluation"
        ),
        None,
    )
    if evaluation_prefix is None:
        raise KeyError(
            "No processing output named 'evaluation' found for job: {}.".format(
                job_name
            )
        )

    # Strip a trailing slash so the joined URI never contains "//eval.json".
    evaluation_s3_uri = "{}/{}".format(evaluation_prefix.rstrip("/"), "eval.json")

    # Get the metrics as a dictionary
    evaluation_output = S3Downloader.read_file(evaluation_s3_uri)
    evaluation_output_dict = json.loads(evaluation_output)
    return {
        "statusCode": 200,
        "results": {
            "TrainingJobName": job_name,
            "TrainingJobStatus": status,
            "TrainingMetrics": evaluation_output_dict,
        },
    }

# Call the handler directly in the notebook with the evaluation job name
# defined earlier; the returned dict is shown as the cell output.
lambda_handler({"EvaluationProcessingJobName": evaluation_job_name}, None)

In [None]:
# Lambda task that queries the evaluation results of a processing job.
# NOTE(review): query_training_function_name is not assigned in any visible
# cell; it must hold the name of the deployed Lambda function before this
# cell is run, otherwise this raises NameError.
training_query_step = steps.compute.LambdaStep(
    "Query Training Results",
    parameters={
        "FunctionName": query_training_function_name,
        "Payload": {"EvaluationProcessingJobName.$": "$.EvaluationProcessingJobName"},
    },
    result_path="$.QueryTrainingResults",
)

# Reached through the Choice state's default branch, i.e. when the RMSE is NOT
# below the threshold, so the state is named for an error that is too high.
check_accuracy_fail_step = steps.states.Fail(
    "Model Error Too High", comment="RMSE accuracy higher than threshold"
)

check_accuracy_succeed_step = steps.states.Succeed("Model Error Acceptable")

# TODO: Update query method to query validation error using better result path
# NOTE(review): this path indexes TrainingMetrics as a list of {"Value": ...}
# entries - confirm that eval.json actually has that shape.
threshold_rule = steps.choice_rule.ChoiceRule.NumericLessThan(
    variable=training_query_step.output()["QueryTrainingResults"]["Payload"]["results"][
        "TrainingMetrics"
    ][0]["Value"],
    value=10,
)

check_accuracy_step = steps.states.Choice("RMSE < 10")

# RMSE below the threshold -> succeed; anything else -> fail.
check_accuracy_step.add_choice(rule=threshold_rule, next_step=check_accuracy_succeed_step)
check_accuracy_step.default_choice(next_step=check_accuracy_fail_step)

In [None]:
# State machine name and the IAM role used to create the workflow.
model_name = "exp-step-functions-evaluation"
workflow_role_arn = "arn:aws:iam::852039983533:role/AmazonSageMaker-StepFunctionsWorkflowExecutionRole"

# Single-step workflow containing only the evaluation processing step.
workflow_definition = steps.states.Chain([processing_evaluation_step])
workflow = Workflow(
    name=model_name,
    definition=workflow_definition,
    role=workflow_role_arn,
)

In [None]:
# Create the state machine, start one execution and wait for its output.
workflow.create()
# Each SageMaker processing job requires a unique name.
# NOTE(review): evaluation_job_name is a fixed literal in this notebook -
# confirm that no job with that name exists before running this cell.
evaluation_execution_inputs = {"EvaluationProcessingJobName": evaluation_job_name}
execution = workflow.execute(inputs=evaluation_execution_inputs)
execution_output = execution.get_output(wait=True)

In [None]:
# Remove the workflow created above once the test execution is done.
workflow.delete()

## Test training step

In [None]:
# Estimator that runs the project image as a managed spot training job.
est = sagemaker.estimator.Estimator(
    image_uri=image_uri,
    role=role,
    instance_count=1,
    instance_type=training_instance_type,
    hyperparameters=hyperparameters,
    # NOTE: execution_input placeholders cannot be used for output_path.
    output_path=output_data["ModelOutputUri"],
    use_spot_instances=True,
    max_run=1200,  # training timeout in seconds
    max_wait=1200,  # spot wait timeout in seconds; must be >= max_run
)

# Single "train" channel reading from the training data prefix.
s3_input_train = sagemaker.inputs.TrainingInput(s3_data=input_data["TrainingUri"])
data = dict(train=s3_input_train)

# Training step; the job name is supplied per execution and the task result
# is stored under $.TrainingResults.
training_step = steps.TrainingStep(
    "Training Job",
    estimator=est,
    data=data,
    job_name=execution_input["TrainingJobName"],
    result_path="$.TrainingResults",
)

In [None]:
# State machine name and the IAM role used to create the training workflow.
model_name = "exp-step-functions-model"
workflow_role_arn = "arn:aws:iam::852039983533:role/AmazonSageMaker-StepFunctionsWorkflowExecutionRole"

# Single-step workflow containing only the training step.
workflow_definition = steps.states.Chain([training_step])
workflow = Workflow(
    name=model_name,
    definition=workflow_definition,
    role=workflow_role_arn,
)

In [None]:
import uuid
# Create the state machine, start one training execution and wait for its output.
workflow.create()
# Each training job requires a unique name, so append a fresh UUID hex string.
training_job_name = f"exp-step-functions-{uuid.uuid1().hex}"
training_execution_inputs = {"TrainingJobName": training_job_name}
execution = workflow.execute(inputs=training_execution_inputs)
execution_output = execution.get_output(wait=True)

In [None]:
# Display the output of the last workflow execution as the cell result.
execution_output