In [1]:
import asyncio
import json
import os
from typing import Sequence

from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict

from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages

In [2]:
# Load OPENAI_API_KEY and other settings from a local .env file *before* the
# model clients below are built: the class body instantiates them eagerly, so a
# fresh kernel would otherwise fail here if the key only lives in .env.
# (By default load_dotenv() does not override variables already set in the environment.)
load_dotenv()


class ModelConfig:
    """Registry of the chat and embedding models available to the notebook.

    Each entry maps an ID to display metadata ("name", "provider",
    "description") and a client under "instance". The clients are created once,
    when this class body runs, and every lookup returns that same shared object.
    """

    # Chat models, keyed by LLM ID.
    _MODELS = {
        "openai-gpt-4o": {
            "name": "gpt-4o",
            "provider": "OpenAI API",
            "description": "Most capable OpenAI model",
            "instance": ChatOpenAI(model="gpt-4o")
        }
    }

    # Embedding models, keyed by embeddings ID.
    # NOTE(review): OpenAIEmbeddings() is created without a `model` argument, so it
    # uses the library's default model, which may not be the "text-embedding-3-large"
    # listed under "name" -- confirm. Passing model=... changes the vectors, so any
    # FAISS index cached with the old model must be deleted and rebuilt.
    _EMBEDDINGS = {
        "openai-default": {
            "name": "text-embedding-3-large",
            "provider": "OpenAI",
            "description": "Default OpenAI embedding model",
            "instance": OpenAIEmbeddings()
        }
    }

    @classmethod
    def get_llm(cls, model_id):
        """Return the shared chat-model instance registered under ``model_id``.

        Raises:
            ValueError: if ``model_id`` is not a key of ``_MODELS``.
        """
        if model_id not in cls._MODELS:
            raise ValueError(f"Unknown LLM: {model_id}")
        return cls._MODELS[model_id]["instance"]

    @classmethod
    def get_embeddings(cls, embedding_id):
        """Return the shared embeddings instance registered under ``embedding_id``.

        Raises:
            ValueError: if ``embedding_id`` is not a key of ``_EMBEDDINGS``.
        """
        if embedding_id not in cls._EMBEDDINGS:
            raise ValueError(f"Unknown embeddings: {embedding_id}")
        return cls._EMBEDDINGS[embedding_id]["instance"]

In [4]:
# Active model selection. Change the IDs here only: they must be keys registered
# in ModelConfig, and EMBEDDINGS_ID also names the cached FAISS index.
LLM_ID = "openai-gpt-4o"
EMBEDDINGS_ID = "openai-default"
LLM = ModelConfig.get_llm(LLM_ID)
EMBEDDINGS = ModelConfig.get_embeddings(EMBEDDINGS_ID)

In [5]:
# should probably add Grading prompt here as well

def setup_prompts():
    """Build the two chat prompt templates used by the RAG chain.

    Both templates share one layout: a system message, the prior turns
    (``chat_history``) and the latest user message (``{input}``).

    Returns:
        tuple: ``(system_prompt, context_prompt)``.
            * ``system_prompt`` answers as "Talin" using the retrieved ``{context}``.
            * ``context_prompt`` rewrites the latest input into a standalone
              question that can be understood without the chat history.
    """
    # NOTE(review): the answer prompt tells the assistant to deny being a bot. That may
    # conflict with bot-disclosure rules (e.g. California SB 1001) and with user trust --
    # confirm this is intended.
    answer_instructions = "Your name is Talin. You're part of the HOME0001 team. You are an expert customer service rep for the housing collective HOME0001. Your goal is to help customers who are looking to buy a home at one of our locations. Use the following pieces of retrieved context to answer any question. If you don't find an answer just admit that you don't know. Only answer questions related to HOME0001. Talk in a pragmatic tone. don't get cute, keep it professional but avoid marketing or corporate speak at all costs. Talk like you're texting. Use three sentences maximum and keep the answer concise. Do not admit that you're a bot and do not give it away. Always act human. Talk about HOME0001 as 'us' and 'we'. \n\n{context}"
    rewrite_instructions = "Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is."

    def _template_with_history(instructions):
        # System text, then earlier turns, then the new user turn.
        return ChatPromptTemplate.from_messages([
            ("system", instructions),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])

    return (
        _template_with_history(answer_instructions),
        _template_with_history(rewrite_instructions),
    )

In [6]:
# The answering prompt and the question-rewriting prompt, built once.
(
    SYSTEM_PROMPT,
    CONTEXT_PROMPT,
) = setup_prompts()

In [7]:
# Minimal JSON Q&A loader -- TODO: replace with a proper document loader.
def load_documents(file_path):
    """Load Q&A pairs from a JSON file as LangChain Documents.

    Expects a JSON array of objects with "answer" and "question" keys. Each
    answer becomes a Document's ``page_content`` (the text that is embedded and
    retrieved), and its question is kept in ``metadata["question"]``.

    Args:
        file_path: Path to the JSON file.

    Returns:
        list[Document]: one Document per item, in file order.
    """
    # Explicit UTF-8 so the result does not depend on the platform's default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        qa_items = json.load(f)
    return [
        Document(page_content=item["answer"], metadata={"question": item["question"]})
        for item in qa_items
    ]

In [8]:
# Q&A corpus used for retrieval (path is relative to the notebook's working directory).
documents_path = "../data/home0001qa.json"
DOCUMENTS = load_documents(file_path=documents_path)

In [9]:
def setup_vector_db(documents, folder_path="./FAISS", index_name=None, embeddings=None):
    """Load the cached FAISS vector store, or build it from ``documents`` and cache it.

    Args:
        documents: Documents to embed when no complete cached index exists. They are
            ignored when the cache is found, so delete the cache files to rebuild
            after the source data changes.
        folder_path: Directory holding the cached index. Defaults to "./FAISS".
        index_name: Base filename of the cached index. Defaults to EMBEDDINGS_ID, so
            each embedding configuration gets its own cache.
        embeddings: Embedding model to use. Defaults to EMBEDDINGS.

    Returns:
        FAISS: the loaded or newly built vector store.
    """
    if index_name is None:
        index_name = EMBEDDINGS_ID
    if embeddings is None:
        embeddings = EMBEDDINGS

    # save_local writes <index_name>.faiss and <index_name>.pkl. Only load when both
    # exist; a half-written cache would otherwise make load_local fail.
    cache_files = [
        os.path.join(folder_path, f"{index_name}{suffix}") for suffix in (".faiss", ".pkl")
    ]
    if all(os.path.isfile(path) for path in cache_files):
        # The .pkl is unpickled on load. This is acceptable only because the file is
        # our own cache -- never point folder_path at untrusted data.
        return FAISS.load_local(
            folder_path=folder_path,
            embeddings=embeddings,
            index_name=index_name,
            allow_dangerous_deserialization=True,
        )

    vectorstore = FAISS.from_documents(documents, embeddings)
    vectorstore.save_local(folder_path, index_name)
    return vectorstore

In [10]:
# Load the cached FAISS index, or embed DOCUMENTS and cache them on first run.
VECTORSTORE = setup_vector_db(documents=DOCUMENTS)

### Implement a document grader in the RAG chain

In [11]:
# Small local model served by Ollama, used only to grade retrieved documents.
# (ChatOllama is imported in the top import cell.)
# JSON output mode keeps the grade machine-parseable; temperature 0 makes it more repeatable.
llm_model = "llama3.2:3b-instruct-fp16"

llm_json_mode = ChatOllama(model=llm_model, temperature=0, format="json")

In [12]:
# Doc grader: system instructions for the JSON-mode grading model.
doc_grader_instructions = """You are a grader assessing relevance of a retrieved document to a user question.

If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant."""

# Grader prompt: filled via str.format(document=..., question=...) in grade_documents.
doc_grader_prompt = """Here is the retrieved document: \n\n {document} \n\n Here is the user question: \n\n {question}. 

Think carefully and objectively assess whether the document contains at least some information that is relevant to the question.

Return JSON with single key, binary_score, that is 'yes' or 'no' score to indicate whether the document contains at least some information that is relevant to the question."""

In [13]:
def grade_documents(documents, question, llm=None):
    """Keep only the retrieved documents that the grader LLM judges relevant.

    Each document is graded independently with ``doc_grader_instructions`` and
    ``doc_grader_prompt``. The model is asked to reply with JSON like
    ``{"binary_score": "yes"}`` or ``{"binary_score": "no"}``.

    Args:
        documents: Retrieved Documents to grade.
        question: The user question they were retrieved for.
        llm: Chat model used for grading. Defaults to ``llm_json_mode``.

    Returns:
        dict: ``{"documents": [...]}`` with the relevant documents in their
        original order (an empty list if none are relevant).
    """
    grader = llm_json_mode if llm is None else llm
    system_message = SystemMessage(content=doc_grader_instructions)

    filtered_docs = []
    for d in documents:
        prompt = doc_grader_prompt.format(document=d.page_content, question=question)
        result = grader.invoke([system_message, HumanMessage(content=prompt)])
        if parse_binary_score(result.content) == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            # Anything but an explicit "yes" (including malformed grader output)
            # counts as not relevant, so one bad reply cannot crash the pipeline.
            print("---GRADE: DOCUMENT NOT RELEVANT---")

    return {"documents": filtered_docs}


def parse_binary_score(raw):
    """Return the grader's ``binary_score`` as a stripped, lower-cased string.

    Returns None (after printing a warning) when ``raw`` is not a JSON object
    containing a ``binary_score`` key.
    """
    try:
        payload = json.loads(raw)
    except (TypeError, ValueError):
        print(f"---GRADE: UNPARSEABLE GRADER OUTPUT: {raw!r}---")
        return None
    if not isinstance(payload, dict) or "binary_score" not in payload:
        print(f"---GRADE: GRADER OUTPUT MISSING binary_score: {raw!r}---")
        return None
    return str(payload["binary_score"]).strip().lower()

In [15]:
# Smoke test: retrieve for a sample question, then keep only the docs the grader accepts.
test_retriever = VECTORSTORE.as_retriever()
question = "Do i own my 0001 home outright?"
retrieved_docs = test_retriever.invoke(question)
# grade_documents returns {"documents": [...]}; unpack it so the name matches its content.
filtered_docs = grade_documents(retrieved_docs, question)["documents"]
len(retrieved_docs), len(filtered_docs)

---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---


In [13]:
# Shall we add prompt evaluation here?

def initialize_rag_chain():
    """Assemble the conversational RAG chain.

    Steps:
      1. A history-aware retriever uses LLM with CONTEXT_PROMPT (which rewrites the
         latest input as a standalone question) in front of the VECTORSTORE retriever.
      2. The retrieved documents are stuffed into SYSTEM_PROMPT's ``{context}``,
         and LLM writes the reply.

    Returns:
        The chain built by ``create_retrieval_chain``. Its output includes
        "context" (the retrieved documents) and "answer".
    """
    answer_chain = create_stuff_documents_chain(LLM, SYSTEM_PROMPT)
    question_rewriting_retriever = create_history_aware_retriever(
        LLM, VECTORSTORE.as_retriever(), CONTEXT_PROMPT
    )
    return create_retrieval_chain(question_rewriting_retriever, answer_chain)

In [14]:
# Built once and shared by every graph invocation.
RAG_CHAIN = (
    initialize_rag_chain()
)

In [15]:
### Statefully manage chat history ###
# Graph state, checkpointed per conversation thread by the app's checkpointer.
# - chat_history uses the add_messages reducer: messages returned by a node are
#   merged into the stored history (appended, or replacing a message with the
#   same ID) instead of overwriting it.
# - The other keys have no reducer, so each node update simply replaces them.
class State(TypedDict):
    input: str  # latest user message
    chat_history: Annotated[Sequence[BaseMessage], add_messages]
    # Documents retrieved for the latest turn. call_model stores the RAG chain's
    # "context" output here, which is a list of Documents, not a string.
    context: Sequence[Document]
    answer: str  # latest model reply

In [16]:
def call_model(state: State):
    """Graph node: run the RAG chain for the current turn.

    Returns a partial state update:
      * the user/AI message pair for this turn, which the add_messages reducer
        merges into chat_history;
      * the retrieved context and the answer, which overwrite the previous turn's.
    """
    rag_output = RAG_CHAIN.invoke(state)
    reply = rag_output["answer"]
    turn_messages = [HumanMessage(state["input"]), AIMessage(reply)]
    return {
        "chat_history": turn_messages,
        "context": rag_output["context"],
        "answer": reply,
    }

In [17]:
def initialize_app():
    """Compile the single-node LangGraph app with an in-memory checkpointer.

    Graph shape: START -> "model" (call_model). The MemorySaver checkpointer
    stores State per ``thread_id`` given in ``config["configurable"]``.
    """
    # TO DO: add document grade to check for retrieved doc relevance AND
    # add case for when no doc is relevant

    graph = StateGraph(state_schema=State)
    graph.add_node("model", call_model)
    graph.add_edge(START, "model")
    return graph.compile(checkpointer=MemorySaver())

In [18]:
# One compiled app for all conversations; its checkpointer keys history by thread_id.
LANGCHAIN_APP = (
    initialize_app()
)

In [19]:
async def get_response(message: str, config: dict) -> str:
    """Send one user message through LANGCHAIN_APP and return the reply text.

    Args:
        message: The user's message for this turn.
        config: LangGraph run config. ``config["configurable"]["thread_id"]``
            selects which conversation's stored history is used.

    Returns:
        The model's answer; a warning string if the result has no "answer" key;
        or ``"Error: ..."`` if the invocation raised.
    """
    try:
        # The graph and its RAG chain are synchronous. Running them in a worker
        # thread keeps this coroutine from blocking the event loop (e.g. other
        # FastAPI/WebSocket clients) while waiting on the LLM.
        response = await asyncio.to_thread(
            LANGCHAIN_APP.invoke, {"input": message}, config=config
        )
        return response.get("answer", "Warning: Response object missing 'answer' field. Please check the invocation process.")
    except Exception as e:
        # NOTE(review): this returns raw exception text to the caller; avoid
        # forwarding it verbatim to end users.
        return f"Error: {str(e)}"

In [37]:
config_01 =  {"configurable": {"thread_id": "test-id-05"}}
response_01 = await get_response("hey what's up, i'm frank", config_01)
print(response_01)

Hey Frank, how can I help you today?


In [38]:
config_02 =  {"configurable": {"thread_id": "test-id-06"}}
response_02 = await get_response("hey what's up, i'm lutz", config_02)
print(response_02)

Hey Lutz, how can I help you today?


In [39]:
config_02 =  {"configurable": {"thread_id": "test-id-05"}}
response_02 = await get_response("what's my name?", config_02)
print(response_02)

You mentioned your name is Frank. How can I assist you with your home search?


In [40]:
config_01 =  {"configurable": {"thread_id": "test-id-06"}}
response_01 = await get_response("what's my name?", config_01)
print(response_02)

You mentioned your name is Frank. How can I assist you with your home search?


In [None]:
''' 
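LANGCHAIN_APP probably has to be reinitialized per thread, otherwise conversations seem to get mixed up.
Update: the observed mix-up is most likely a notebook bug, not a threading problem. Cell [40], as first run, stored the thread "test-id-06" reply in response_01. It then printed response_02, which still held the thread "test-id-05" reply ("Frank") from cell [39]. The MemorySaver checkpointer keeps state per thread_id, so a single LANGCHAIN_APP should serve all threads. Re-run cell [40] printing response_01 to confirm.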

'''