In [75]:
import boto3
import json
import os

boto_session = boto3.Session(profile_name='dev-profile')
client = boto_session.client("sagemaker-runtime")

# This has to be the same name as used in terraform to name the endpoint
ENDPOINT_NAME = "mistral-model-endpoint"
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

#os.environ["HF_API_TOKEN"] = 'token'
# HF_API_TOKEN= os.environ["HF_API_TOKEN"]

In [83]:
# first test to verify endpoint is functioning correctly
user_message = "Write a poem about a cat named Homer."
prompt = f"<s>[INST] {user_message} [/INST]"

payload = {
    "inputs": prompt,
}

response = client.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType="application/json",
    Accept="application/json",
    Body=json.dumps(payload)
)
result = json.loads(response['Body'].read().decode())
print(result)

[{'generated_text': "<s>[INST] Write a poem about a cat named Homer. [/INST] In a quiet little nook, where the sun gently shed,\nLive stories, tales and lore, of a feline named Homer, a thread.\nA coat of midnight hues, beneath the crescent moon's pull,\nHis eyes, emerald orbs, gleaming with ageless allure, and a brood so full.\n\nWith whiskers twitching, and ears that perked so fine,\nHis graceful fingers swiped"}]


In [84]:
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.runnables import RunnablePassthrough

# Finally, we need to combine the llm_chain with the retriever to create a RAG chain.
# We pass the original question through to the final generation step, as well as the retrieved context docs
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

db = FAISS.load_local(folder_path="faiss_db/", embeddings=embeddings, index_name="nasa_index", allow_dangerous_deserialization=True)

# This retriever returns the top 5 similar chunks
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})



In [85]:
from langchain.llms.sagemaker_endpoint import LLMContentHandler

# TODO figure out how to adapt it
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"
    
    # this should be adapted to include context
    def transform_input(self, prompt, model_kwargs):
        input_str = json.dumps({
                "inputs" : prompt,
                "parameters" : {**model_kwargs}
            })
        return input_str.encode('utf-8')
    
    def transform_output(self, output):
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generation"]["content"]
    

content_handler = ContentHandler()


In [90]:
from langchain import SagemakerEndpoint

llm=SagemakerEndpoint(
    client = client,
    endpoint_name=ENDPOINT_NAME, 
    region_name='eu-west-1', 
    model_kwargs={"max_new_tokens": 100, "top_p": 0.9, "temperature": 0.2},
    content_handler=content_handler
 )

In [87]:
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser


# Create a prompt template
# this should follow the format of the model, so make sure to use the appropriate formatting.
prompt_template = """
<|system|>
The following is a friendly conversation between a human and an AI. 
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it does not know. 
Use the following context to help:

{context}

</s>
{question}
</s>
"""

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

llm_chain = prompt | llm | StrOutputParser()

# You can also use tokenizer.apply_chat_template to convert a list of messages (as dicts: {'role': 'user', 'content': '(...)'})
# into a string with the appropriate chat format.

In [88]:
# RunnablePassthrough Usage: The query is then passed along using RunnablePassthrough(). 
# This function is a part of LangChain’s API and is used to pass the query to the next step in the chain.
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | llm_chain

In [91]:
import logging
# A RAG response considering our NASA context
try:
    # Need to reduce size of chunks
    response = rag_chain.invoke("What can you tell me about the latest space discoveries?")
    logging.info(response)
except Exception as e:
    logging.error(f"An error occurred: {e}")

ERROR:root:An error occurred: Error raised by inference endpoint: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (422) from primary with message "{"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 3651 `inputs` tokens and 4000 `max_new_tokens`","error_type":"validation"}". See https://eu-west-1.console.aws.amazon.com/cloudwatch/home?region=eu-west-1#logEventViewer:group=/aws/sagemaker/Endpoints/mistral-model-endpoint in account 051285996764 for more information.
