# Use the link below to find the endpoints of different models:
# https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text2text-generation-flan-t5-ul2.ipynb

# 1. Setup

In [None]:
# Install the latest version of ipywidgets
#!pip install --upgrade ipywidgets

# Install a specific version of ipywidgets (version 7.0.0) quietly without displaying output
#!pip install ipywidgets==7.0.0 --quiet

# Upgrade the sagemaker package to the latest version quietly without displaying output
#!pip install --upgrade sagemaker --quiet

In [None]:
# Import necessary libraries for working with SageMaker, AWS, and JSON
import sagemaker
import boto3
import json

# Import the Session class from the sagemaker.session module
from sagemaker.session import Session

In [None]:
# Region name configured for the default boto3 session
aws_region = boto3.Session().region_name

# Two separate SageMaker sessions: `sagemaker_session` is used for the role lookup below;
# `sess` is not referenced anywhere else in the visible code
sagemaker_session = Session()
sess = sagemaker.Session()

# ARN of the caller identity, used as the execution role when creating the model
aws_role = sagemaker_session.get_caller_identity_arn()

In [None]:
# Print the detected AWS region so it can be checked before anything is deployed
print(aws_region)

# 2. Select a pre-trained model

In [None]:
# JumpStart identifier of the pre-trained model to deploy
model_id = "huggingface-text2text-flan-t5-xl"

# Version specification; the wildcard matches any 1.x release
model_version = "1.*"


# 3. Retrieve Artifacts & Deploy an Endpoint

In [None]:
def get_sagemaker_session(local_download_dir) -> sagemaker.Session:
    """
    Build a SageMaker session configured with a local download directory.

    Args:
        local_download_dir (str): Directory passed to the session settings as
            the local download location.

    Returns:
        sagemaker.Session: Session backed by a boto3 SageMaker client for the
        region of the default boto3 session.
    """
    # Session settings carrying the requested download directory
    settings = sagemaker.session_settings.SessionSettings(
        local_download_dir=local_download_dir
    )

    # Low-level SageMaker client bound to the default boto3 region
    client = boto3.client(
        service_name="sagemaker", region_name=boto3.Session().region_name
    )

    return sagemaker.session.Session(sagemaker_client=client, settings=settings)


In [None]:
# Create the local 'download_dir' directory via the notebook's `!` shell escape.
# The -p flag avoids an error when the directory already exists.
!mkdir -p download_dir

In [None]:
# Worker-count environment settings for large models
# (not referenced by any other cell in the visible code)
_large_model_env = dict(
    SAGEMAKER_MODEL_SERVER_WORKERS="1",
    TS_DEFAULT_WORKERS_PER_MODEL="1",
)

# Per-model deployment settings, keyed by JumpStart model id:
#   instance_type - instance type used to host the endpoint
#   env           - environment variables passed to the model container
_model_config_map = dict()
_model_config_map["huggingface-text2text-flan-t5-xl"] = dict(
    instance_type="ml.m5.2xlarge",
    env=dict(MMS_DEFAULT_WORKERS_PER_MODEL="1"),
)


In [None]:
# Import necessary modules from SageMaker
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base

In [None]:
# Build the endpoint name from a base string that embeds the model id.
# NOTE(review): name_from_base presumably appends a unique suffix to the base
# name — confirm against the SageMaker SDK documentation.
endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

In [None]:
# Show the generated endpoint name
print(endpoint_name)

In [None]:
# Choose the instance type that will host the endpoint.
if model_id in _model_config_map:
    # The model has an explicit entry in the configuration map: use its instance type
    inference_instance_type = _model_config_map[model_id]["instance_type"]
    # Report the choice (replaces a placeholder debug print)
    print(f"Using configured instance type for {model_id}: {inference_instance_type}")
else:
    # No explicit configuration for this model: fall back to the default instance type
    inference_instance_type = "ml.m5.2xlarge"


In [None]:
# Look up the inference container image URI for the selected model.
# `region` and `framework` are passed as None so that the SDK infers them
# (from the AWS region and from model_id respectively).
deploy_image_uri = image_uris.retrieve(
    model_id=model_id,
    model_version=model_version,
    image_scope="inference",
    instance_type=inference_instance_type,
    region=None,
    framework=None,
)


In [None]:
# Look up the URI of the inference script bundle (model loading and
# inference handling code plus its dependencies) for the selected model.
deploy_source_uri = script_uris.retrieve(
    script_scope="inference",
    model_id=model_id,
    model_version=model_version,
)


In [None]:
# Look up the URI of the model artifacts to use for inference.
model_uri = model_uris.retrieve(
    model_scope="inference",
    model_id=model_id,
    model_version=model_version,
)


In [None]:
# Create the SageMaker model instance
if model_id in _model_config_map:
    # For models listed in the configuration map, the inference script and model
    # artifacts are already repacked, so the `source_dir` argument to Model is
    # not required; the configured environment variables are passed instead.
    model = Model(
        image_uri=deploy_image_uri,
        model_data=model_uri,
        role=aws_role,
        predictor_cls=Predictor,
        name=endpoint_name,
        env=_model_config_map[model_id]["env"],  # Environment variables for the container
    )
    # Report which path was taken (replaces a placeholder debug print)
    print(f"Created model '{endpoint_name}' using the preconfigured environment for {model_id}")
else:
    # For other models, include source_dir, entry_point, and sagemaker_session parameters
    model = Model(
        image_uri=deploy_image_uri,
        source_dir=deploy_source_uri,
        model_data=model_uri,
        entry_point="inference.py",  # Entry point file expected inside deploy_source_uri
        role=aws_role,
        predictor_cls=Predictor,
        name=endpoint_name,
        sagemaker_session=get_sagemaker_session("download_dir"),  # Session using the local download directory
    )


In [None]:
# Deploy the model to a single-instance endpoint. The Predictor class is passed
# explicitly so that inference can be run through the SageMaker API.
model_predictor = model.deploy(
    endpoint_name=endpoint_name,
    instance_type=inference_instance_type,
    initial_instance_count=1,
    predictor_cls=Predictor,
)


# 4. Query endpoint and parse response

In [None]:
# Text-formatting helpers used when printing inference results in a terminal or console.
newline = "\n"        # line break
bold = "\033[1m"      # escape sequence that switches bold text on
unbold = "\033[0m"    # escape sequence that resets formatting back to normal


In [None]:
def query_endpoint(encoded_text, endpoint_name, client=None):
    """
    Queries a SageMaker endpoint with the provided encoded text.

    Args:
        encoded_text (bytes): The encoded text data to be sent to the endpoint.
        endpoint_name (str): The name of the SageMaker endpoint to query.
        client (optional): An existing SageMaker runtime client exposing
            ``invoke_endpoint``. When omitted, a new boto3 client is created
            for this call. Passing a client lets callers reuse one client
            across many queries instead of rebuilding it every time.

    Returns:
        dict: The response from the SageMaker endpoint, exactly as returned by
        ``invoke_endpoint``.
    """
    if client is None:
        # No client supplied: create a SageMaker runtime client on demand
        client = boto3.client("runtime.sagemaker")

    # Invoke the endpoint with the encoded text as a plain-text request body
    return client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/x-text",
        Body=encoded_text,
    )


In [None]:
import json

def parse_response(query_response):
    """
    Pull the generated text out of an endpoint invocation response.

    Args:
        query_response (dict): Invocation response whose "Body" entry exposes
            a ``read()`` method returning a JSON document.

    Returns:
        The value stored under the "generated_text" key of the decoded JSON.
    """
    # Read the raw body and decode it as JSON in one step
    payload = query_response["Body"].read()
    return json.loads(payload)["generated_text"]


# Example

In [None]:
# Example prompt 1: a translation request
text1 = "Translate to German:  My name is Arthur"

In [None]:
# Send the translation prompt to the endpoint and print the prompt
# together with the generated answer (answer shown in bold).
text = text1
query_response = query_endpoint(text.encode("utf-8"), endpoint_name=endpoint_name)
generated_text = parse_response(query_response)
print(
    f"Inference:{newline}"
    f"input text: {text}{newline}"
    f"generated text: {bold}{generated_text}{unbold}{newline}"
)

In [None]:
# Example prompt 2: open-ended text generation (a recipe)
text2 = "A step by step recipe to make bolognese pasta:"

In [None]:
# Send the recipe prompt to the endpoint and print the prompt
# together with the generated answer (answer shown in bold).
text = text2
query_response = query_endpoint(text.encode("utf-8"), endpoint_name=endpoint_name)
generated_text = parse_response(query_response)
print(
    f"Inference:{newline}"
    f"input text: {text}{newline}"
    f"generated text: {bold}{generated_text}{unbold}{newline}"
)

In [None]:
# Example prompt 3: SQL generation from a table schema. The prompt embeds the
# schema as JSON, gives one worked question/answer pair, and then asks a
# follow-up question in the same style.
text3 = """
     From the table schema below, generate a SQL code for question "What is the total sales quantity in New Zealand?"
     Schema of the table is as below in {} brackets.
     {"database": "datalake-with-catalog",
    "table": "hanz_sales",
    "table_description": "The table represents daily sales of various types of cement materials in New Zealand",
    "columns": [
        ["Material", "The various types of cement material code", "string", ""],
        ["Ship-to","The customer Codes", "string", ""],
        ["Material-Description", "Material Description", "string", ""],
        ["Ship-to party", "Customer description", "string", ""],
        ["Del.Date", "Delivery Date", "date", ""],
        ["Rpt Qty", "Reported Quantity", "float", "Tons"],
        ["Reporting UOM", "Quantity Unit in Tons", "string", ""]
    ]} 
    Question: What is the total sales quantity in New Zealand?
    Answer: The query should be "Select sum(Rpt Qty) from hanz_sales;"
    
    Similar way answer what are the unique materials?
     """

In [None]:
# Send the schema-based SQL prompt to the endpoint and print the prompt
# together with the generated answer (answer shown in bold).
text = text3
query_response = query_endpoint(text.encode("utf-8"), endpoint_name=endpoint_name)
generated_text = parse_response(query_response)
print(
    f"Inference:{newline}"
    f"input text: {text}{newline}"
    f"generated text: {bold}{generated_text}{unbold}{newline}"
)

In [None]:
# Example prompt 4: SQL generation from a plain-language request, without a schema
text4 = """
     Give me SQL code for: Select all customers from the customer table who have placed an order in the last 30 days
     """

In [None]:
# Send the plain-language SQL prompt to the endpoint and print the prompt
# together with the generated answer (answer shown in bold).
text = text4
query_response = query_endpoint(text.encode("utf-8"), endpoint_name=endpoint_name)
generated_text = parse_response(query_response)
print(
    f"Inference:{newline}"
    f"input text: {text}{newline}"
    f"generated text: {bold}{generated_text}{unbold}{newline}"
)

# 5. Clean up the endpoint

In [None]:
# Clean up the resources created by the deployment: first the SageMaker model,
# then the endpoint itself.
# NOTE(review): delete_model() runs before delete_endpoint(); presumably the
# predictor needs the still-existing endpoint to locate the model — keep this
# order unless the SDK documentation confirms otherwise.
model_predictor.delete_model()
model_predictor.delete_endpoint()