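## Test the deployed text-generation endpoint with example prompts

This notebook looks up today's InService endpoint whose name contains `meta-textgeneration-l`, picks its newest inference component, and sends it a few zero-shot and few-shot prompts.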

### Update SageMaker

In [None]:
!pip install -U sagemaker

### Restart the kernel, then run the next cell

Restarting ensures the packages upgraded above are the versions that get imported.

In [None]:
import boto3
import datetime

client = boto3.client('sagemaker')

# Only consider endpoints created since local midnight today. Pass a timezone-aware
# datetime: botocore treats a bare "YYYY-MM-DD" string as midnight *UTC*, which
# silently skips today's endpoints for anyone east of UTC.
start_of_today = datetime.datetime.now().astimezone().replace(
    hour=0, minute=0, second=0, microsecond=0
)

# NameContains is a substring match; sort newest-first so [0] is the latest deployment.
matching_endpoints = client.list_endpoints(
    NameContains='meta-textgeneration-l',
    CreationTimeAfter=start_of_today,
    StatusEquals='InService',
    SortBy='CreationTime',
    SortOrder='Descending',
)['Endpoints']
if not matching_endpoints:
    raise RuntimeError(
        "No InService endpoint whose name contains 'meta-textgeneration-l' "
        f"was created since {start_of_today.isoformat()}; deploy one first."
    )
endpoint_name = matching_endpoints[0]['EndpointName']
endpoint_name

In [None]:
# Pick the newest InService inference component hosted on the endpoint found above.
inference_components = client.list_inference_components(
    SortBy='CreationTime',
    SortOrder='Descending',
    StatusEquals='InService',
    EndpointNameEquals=endpoint_name
)['InferenceComponents']
if not inference_components:
    raise RuntimeError(
        f"No InService inference component found on endpoint {endpoint_name!r}."
    )
inference_component_name = inference_components[0]['InferenceComponentName']
# The query cell reads the original (misspelled) name, so keep it bound too.
inference_conponent_name = inference_component_name
inference_conponent_name

In [None]:
import json

zero_shot_prompts = [
    "I believe the meaning of life is",
    "Simply put, the theory of relativity states that ",
    """A brief message congratulating the team on the launch:

Hi everyone,

I just """,
]
few_shot_prompts = [
    """Translate English to French:
sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>""",
]

# Sampling settings shared by every request. `return_full_text: False` presumably asks
# the model server for only the generated continuation (TGI-style parameter).
GENERATION_PARAMETERS = {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6, "return_full_text": False}

# One request per prompt, zero-shot first then few-shot. Each payload gets its own
# copy of the parameters so tweaking one request cannot affect the others.
payloads = [
    {"inputs": prompt, "parameters": dict(GENERATION_PARAMETERS)}
    for prompt in zero_shot_prompts + few_shot_prompts
]


_sagemaker_runtime_client = None


def _get_runtime_client():
    """Return a lazily created ``sagemaker-runtime`` boto3 client, reused across calls."""
    global _sagemaker_runtime_client
    if _sagemaker_runtime_client is None:
        _sagemaker_runtime_client = boto3.client("sagemaker-runtime")
    return _sagemaker_runtime_client


def query_endpoint(payload, endpoint=None, inference_component=None, runtime_client=None):
    """Send one JSON payload to the SageMaker endpoint and return the decoded JSON reply.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": ..., "parameters": ...}``.
        endpoint: Endpoint to invoke; defaults to the notebook global ``endpoint_name``.
        inference_component: Inference component to target; defaults to the notebook
            global ``inference_conponent_name``.
        runtime_client: Optional ``sagemaker-runtime`` client. When omitted, a single
            shared client is created on first use instead of one per request.

    Returns:
        The parsed JSON response body. The printing loop expects a list whose first
        element has a ``"generated_text"`` key.
    """
    if endpoint is None:
        endpoint = endpoint_name
    if inference_component is None:
        inference_component = inference_conponent_name
    if runtime_client is None:
        runtime_client = _get_runtime_client()

    # Named `runtime_response` so it cannot be confused with the global SageMaker `client`.
    runtime_response = runtime_client.invoke_endpoint(
        EndpointName=endpoint,
        InferenceComponentName=inference_component,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(runtime_response["Body"].read().decode("utf8"))

# Send every prepared request, then show each prompt followed by the model's continuation.
for request in payloads:
    reply = query_endpoint(request)
    completion = reply[0]['generated_text']
    print(request["inputs"], f"> {completion}", "\n======\n", sep="\n")
