## Set up your endpoint name

Copy the name of your own deployed endpoint into ENDPOINT_NAME below, or follow the instructions provided by the workshop instructor.

In [1]:
ENDPOINT_NAME = 'huggingface-pytorch-tgi-inference-2023-07-20-02-45-33-447'

In [2]:
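import boto3
import json

def query_endpoint_and_parse_response(payload_dict, endpoint_name):
    """Send a JSON payload to a SageMaker endpoint and return the generated text."""
    encoded_json = json.dumps(payload_dict).encode("utf-8")
    # Same runtime client name as used later in this notebook
    client = boto3.client("sagemaker-runtime")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )
    # The response body is a JSON list whose first element holds the generated text
    return json.loads(response["Body"].read().decode())[0]["generated_text"]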
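The helper above needs a live endpoint. To check the request and response shapes offline, the sketch below builds the same JSON request body and parses a simulated response held in an `io.BytesIO` object, which, like the boto3 response `Body`, provides a `read()` method. It assumes the endpoint returns a JSON list whose first element has a `generated_text` key (the shape the helper indexes into); `parse_generated_text` and the sample strings are hypothetical.

```python
import io
import json

def parse_generated_text(body_stream):
    """Parse a body shaped like [{"generated_text": "..."}] (hypothetical helper)."""
    return json.loads(body_stream.read().decode())[0]["generated_text"]

# Build the request body the same way query_endpoint_and_parse_response does.
payload = {"inputs": "Write a SQL query", "parameters": {"max_new_tokens": 20}}
encoded_json = json.dumps(payload).encode("utf-8")

# Simulate the endpoint's response body with an in-memory byte stream.
fake_body = io.BytesIO(json.dumps([{"generated_text": "SELECT 1"}]).encode("utf-8"))

print(encoded_json)
print(parse_generated_text(fake_body))  # SELECT 1
```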



## Set up model parameters


The following parameters are available for controlling text generation using the GenerationConfig class:

- do_sample (bool, optional, defaults to False): Determines whether to use sampling or greedy decoding.
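- temperature (float, optional, defaults to 1.0): Modulates the next token probabilities. Values below 1.0 sharpen the distribution toward the most likely tokens (more deterministic output); values above 1.0 flatten it (more varied output).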
- max_new_tokens (int, optional): Sets the maximum number of tokens to generate, excluding those in the prompt.
- top_k (int, optional, defaults to 50): Sets the number of highest probability vocabulary tokens to keep using top-k filtering.
- top_p (float, optional, defaults to 1.0): When set to a float less than 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
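To make the effect of temperature, top_k and top_p concrete, here is a simplified pure-Python sketch that applies them, in that order, to a hypothetical set of next-token logits. It is not the library's implementation, which works on tensors and handles edge cases differently; the function name and the example logits are made up for illustration.

```python
import math

def filter_next_token_probs(logits, temperature=1.0, top_k=50, top_p=1.0):
    """Apply temperature, top-k and top-p filtering to a dict of token -> logit.

    Returns a renormalised dict of token -> probability over the surviving tokens.
    temperature must be > 0.
    """
    # Temperature: divide the logits before the softmax (T < 1 sharpens, T > 1 flattens).
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    max_logit = max(scaled.values())
    exps = {tok: math.exp(v - max_logit) for tok, v in scaled.items()}  # stable softmax
    total = sum(exps.values())
    probs = sorted(((tok, e / total) for tok, e in exps.items()),
                   key=lambda kv: kv[1], reverse=True)

    # Top-k: keep only the k most probable tokens, then renormalise.
    probs = probs[:top_k]
    k_mass = sum(p for _, p in probs)
    probs = [(tok, p / k_mass) for tok, p in probs]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break

    # Renormalise so the surviving probabilities sum to 1.
    mass = sum(p for _, p in kept)
    return {tok: p / mass for tok, p in kept}

logits = {"SELECT": 3.0, "WITH": 2.0, "FROM": 1.0, "--": 0.5}  # hypothetical logits
print(filter_next_token_probs(logits, temperature=1.0, top_k=3, top_p=0.9))
print(filter_next_token_probs(logits, temperature=0.01, top_k=5, top_p=0.15))  # {'SELECT': 1.0}
```

With a very low temperature and a small top_p, nearly all of the probability mass lands on a single token, so sampling behaves almost like greedy decoding.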

For a complete list of available parameters and their descriptions, refer to the GenerationConfig class documentation at https://huggingface.co/docs/transformers/v4.30.0/main_classes/text_generation.

In [3]:
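parameters = {
    "max_new_tokens": 200,  # upper bound on the number of generated tokens
    "top_k": 5,             # consider only the 5 most likely next tokens
    "top_p": 0.15,          # keep the smallest set of tokens covering 15% of the probability mass
    "do_sample": True,      # sample instead of using greedy decoding
    "temperature": 0.01,    # very low temperature: output is close to deterministic
}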


## Prompt with plain-language input

In [4]:
prompt_data = """
I have a table called patient with fields ID, AGE, WEIGHT, HEIGHT.
Write me a SQL Query which will return the entry with the highest age

"""  # If you'd like to try your own prompt, edit this string!

In [5]:
payload = {"inputs": prompt_data, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(payload, ENDPOINT_NAME)

In [6]:
print(f"Result: {generated_texts}")

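Result: SELECT ID FROM patient WHERE AGE > (SELECT max(AGE) FROM patient)

Note that this generated query is not correct: no row can have an AGE strictly greater than the maximum AGE, so it always returns an empty result. Replacing `>` with `=` returns the ID of the oldest patient (or patients, in case of a tie).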
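A quick sanity check for any generated query is to run it against a few sample rows. This sketch loads hypothetical patient rows into an in-memory SQLite database (SQLite simply stands in for whatever database you actually use) and runs the model's query next to a variant that uses `=` instead of `>`.

```python
import sqlite3

# Hypothetical sample rows; table and column names follow the prompt above.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE patient (ID INTEGER, AGE INTEGER, WEIGHT REAL, HEIGHT REAL)")
conn.executemany(
    "INSERT INTO patient VALUES (?, ?, ?, ?)",
    [(1, 34, 70.5, 175.0), (2, 81, 62.0, 160.0), (3, 57, 88.2, 181.0)],
)

generated = "SELECT ID FROM patient WHERE AGE > (SELECT max(AGE) FROM patient)"
corrected = "SELECT ID FROM patient WHERE AGE = (SELECT max(AGE) FROM patient)"

print(conn.execute(generated).fetchall())  # [] (no AGE can exceed the maximum AGE)
print(conn.execute(corrected).fetchall())  # [(2,)]
```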


## Prompt with Table Schema

In [7]:
prompt_data = """You are an expert in Presto databases. Your task is to generate a SQL query.

Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Your table sales has the following schema:

CREATE EXTERNAL TABLE sales (
    transaction_date DATE COMMENT 'Transaction date',
    user_id STRING COMMENT 'The user who made the purchase',
    product STRING COMMENT 'Product name, e.g. "Fruits", "Ice cream", "Milk"',
    price DOUBLE COMMENT 'The price of the product'
)

Question: What is the total sale amount of Fruits?
SQLQuery:

"""


In [8]:
payload = {"inputs": prompt_data, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(payload, ENDPOINT_NAME)

In [9]:
print(f"Result: {generated_texts}")

Result: SELECT sum(price) FROM sales WHERE product = 'Fruits'
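The schema-based prompt follows a fixed structure: instructions, one schema block per table, then the question followed by `SQLQuery:`. Because that structure is reusable, it can be generated from a template. The sketch below is a hypothetical helper, `build_sql_prompt`, not part of any library; the returned string can be used as the `inputs` value of the payload.

```python
def build_sql_prompt(dialect, schemas, question):
    """Assemble a text-to-SQL prompt in the style used above (hypothetical helper).

    dialect  -- database name placed in the instruction, e.g. "Presto"
    schemas  -- mapping of table name -> CREATE TABLE statement
    question -- natural-language question to translate into SQL
    """
    parts = [
        f"You are an expert in {dialect} databases. Your task is to generate a SQL query.",
        "",
        "Pay attention to use only the column names that you can see in the schema description.",
        "Be careful to not query for columns that do not exist. "
        "Also, pay attention to which column is in which table.",
        "",
    ]
    # One schema block per table.
    for table, ddl in schemas.items():
        parts += [f"Your table {table} has the following schema:", "", ddl.strip(), ""]
    # The trailing "SQLQuery:" cue asks the model to continue with the query.
    parts += [f"Question: {question}", "SQLQuery:", ""]
    return "\n".join(parts)

sales_ddl = """
CREATE EXTERNAL TABLE sales (
    transaction_date DATE COMMENT 'Transaction date',
    user_id STRING COMMENT 'The user who made the purchase',
    product STRING COMMENT 'Product name, e.g. "Fruits", "Ice cream", "Milk"',
    price DOUBLE COMMENT 'The price of the product'
)
"""
print(build_sql_prompt("Presto", {"sales": sales_ddl}, "What is the total sale amount of Fruits?"))
```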


## Another example: joining tables

Can the model generate a query that joins two tables?

In [58]:
prompt_data = """
You are an expert in MySQL databases. Your task is to generate a SQL query.

Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Your table sales has the following schema:

CREATE EXTERNAL TABLE sales (
    transaction_date DATE COMMENT 'The transaction date in the format yyyy-mm-dd',
    user_id STRING COMMENT 'The user who made the purchase',
    product STRING COMMENT 'Product name, e.g. "Fruits", "Ice cream", "Milk"',
    sales_amount DOUBLE COMMENT 'The price of the product'
)

Your table users has the following schema:

CREATE EXTERNAL TABLE users (
    user_id STRING COMMENT 'User ID',
    name STRING COMMENT 'User name'
)

Question: What is the total purchase done by "John"?
SQLQuery:
"""

In [59]:
payload = {"inputs": prompt_data, "parameters": parameters}
generated_texts = query_endpoint_and_parse_response(payload, ENDPOINT_NAME)

In [60]:
print(f"Result: {generated_texts}")

Result: SELECT sum(sales_amount) FROM sales AS t1 JOIN users AS t2 ON t1.user_id = t2.user_id WHERE t2.name = "John"
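To sanity-check the join, this sketch loads hypothetical rows into an in-memory SQLite database and runs the generated query. One change is made: `"John"` becomes `'John'`, because standard SQL reserves double quotes for identifiers; MySQL accepts double-quoted strings by default, but other engines may not.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sales (transaction_date TEXT, user_id TEXT, product TEXT, sales_amount REAL)")
conn.execute("CREATE TABLE users (user_id TEXT, name TEXT)")

# Hypothetical sample rows for illustration only.
conn.executemany("INSERT INTO sales VALUES (?, ?, ?, ?)", [
    ("2023-07-01", "u1", "Fruits", 12.5),
    ("2023-07-02", "u1", "Milk", 3.0),
    ("2023-07-02", "u2", "Ice cream", 4.5),
])
conn.executemany("INSERT INTO users VALUES (?, ?)", [("u1", "John"), ("u2", "Mary")])

# The generated query, with the string literal in single quotes.
query = (
    "SELECT sum(sales_amount) FROM sales AS t1 "
    "JOIN users AS t2 ON t1.user_id = t2.user_id WHERE t2.name = 'John'"
)
print(conn.execute(query).fetchone())  # (15.5,)
```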
