In [74]:
import boto3
import json
import os

boto_session = boto3.Session(profile_name='dev-profile')
client = boto_session.client("sagemaker-runtime")

# This has to be the same name as used in terraform to name the endpoint
ENDPOINT_NAME = "mistral-model-endpoint"
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

#os.environ["HF_API_TOKEN"] = 'token'
# HF_API_TOKEN= os.environ["HF_API_TOKEN"]

### Check: Is the endpoint functioning as expected?

In [75]:
# first test to verify endpoint is functioning correctly
user_message = "Write a poem about a cat named Homer."
prompt = f"<s>[INST] {user_message} [/INST]"

payload = {
    "inputs": prompt,
}

response = client.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType="application/json",
    Accept="application/json",
    Body=json.dumps(payload)
)
result = json.loads(response['Body'].read().decode())
print(result)

[{'generated_text': "<s>[INST] Write a poem about a cat named Homer. [/INST] In a quiet little nook, beneath the moon's soft glow,\nLived a cat named Homer, known to the town,\nHis fur was thick and warm, a coat of rich brown,\nHis eyes shimmered in the dark, like polished gold.\n\nHomer was no ordinary feline, that much was true,\nHe held a secret wisdom, in those ancient eyes,\nA gentle soul, with a meow and a mou,\nHe"}]


### RAG using the hosted endpoint

In [76]:
# loading the vector DB with all of our chunked nasa articles
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.runnables import RunnablePassthrough

embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

db = FAISS.load_local(folder_path="faiss_db/", embeddings=embeddings, index_name="nasa_index", allow_dangerous_deserialization=True)

# This retriever returns the top 5 similar chunks
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})



In [77]:
from pydantic import BaseModel, Field
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Define the SageMaker configuration model
class SageMakerConfig(BaseModel):
    endpoint_name: str = Field(...)
    profile_name: str = Field('default')

# Define the custom SageMaker LLM class
class SageMakerLLM(LLM):
    endpoint_name: str
    profile_name: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _call(self, prompt, **kwargs):
        payload = {"inputs": prompt}
        response = boto3.Session(profile_name=self.profile_name).client("sagemaker-runtime").invoke_endpoint(
            EndpointName=self.endpoint_name,
            ContentType="application/json",
            Accept="application/json",
            Body=json.dumps(payload)
        )
        result = json.loads(response['Body'].read().decode())
        
        # Print out the response for debugging
        # print("Raw response from SageMaker endpoint:", result)
        
        # Assuming the response is a list of dictionaries with 'generated_text' keys
        if isinstance(result, list) and 'generated_text' in result[0]:
            return result[0]['generated_text']
        else:
            raise ValueError("Unexpected response format from SageMaker endpoint")

    @property
    def _llm_type(self):
        return "sagemaker_llm"



In [78]:
# Initialize the custom SageMaker LLM
config = SageMakerConfig(endpoint_name="mistral-model-endpoint", profile_name='dev-profile')
sagemaker_llm = SageMakerLLM(endpoint_name=config.endpoint_name, profile_name=config.profile_name)

# Define the prompt template
prompt_template = """
Answer the question based on your knowledge. Use the following context to help:

{context}

</s>

{question}
</s>
"""

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

In [79]:
# Chain setup
llm_chain = prompt | sagemaker_llm | StrOutputParser()

# Combine retriever and LLM chain
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | llm_chain


In [80]:
# Example query
question = "What are the latest findings from the Mars Rover?"

# Invoke the RAG chain
response = rag_chain.invoke(question)

print(response)


Answer the question based on your knowledge. Use the following context to help:

[Document(page_content='Sols 4120-4122: Mars Throws Us a Curveball! \n As we previously documented, the first “Mineral King” drill hole did not quite reach the target depth that we typically desire to ensure that we have enough sample in the drill stem to deliver to our internal CheMin and SAM instruments. While we did get a successful X-ray diffraction CheMin analysis, we did not quite have enough sample left for SAM to be able to complete their Evolved Gas Analysis (EGA). The rover engineers selected a new potential drill spot on the same block, and this morning we got the results of the APXS, MAHLI and preload test (to check for stability and drillability) on that spot. While the chemistry and imaging indicated that it was a good candidate the preload test did not pass. The selected target was just a little too close to the rover. As the APXS strategic planner today, I reported the results of the APXS 