# In [None]:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnableWithMessageHistory
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.schema import Document

# -----------------------------
# 1. Knowledge Base
# -----------------------------
# Raw text of the toy knowledge base; each entry becomes one Document below.
_knowledge_base_texts = (
    "LangChain is a framework for building applications using large language models.",
    "Retrieval-Augmented Generation (RAG) improves LLM responses by retrieving relevant context.",
    "LangChain Expression Language allows declarative construction of LLM pipelines.",
    "Chunking strategy directly affects retrieval accuracy in RAG systems.",
)
documents = [Document(page_content=text) for text in _knowledge_base_texts]

# -----------------------------
# 2. Chunk Documents
# -----------------------------
# Chunks hold at most _CHUNK_SIZE characters, and neighbouring chunks share
# _CHUNK_OVERLAP characters. Every document above is shorter than
# _CHUNK_SIZE, so each one currently comes through as a single chunk.
_CHUNK_SIZE = 200
_CHUNK_OVERLAP = 30

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=_CHUNK_SIZE,
    chunk_overlap=_CHUNK_OVERLAP,
)
splits = text_splitter.split_documents(documents)

# -----------------------------
# 3. Vector Store & Retriever
# -----------------------------
# Embed every chunk with OpenAI embeddings, index the vectors in Chroma, and
# expose the index as a retriever with default search settings.
# NOTE(review): no collection_name or ids are passed, so the default
# collection is used. If Chroma's in-memory client is shared within one
# process, re-running this cell may add duplicate chunks. Confirm.
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever()

# -----------------------------
# 4. Chat History Store
# -----------------------------
# session_id -> InMemoryChatMessageHistory. This is a plain module-level dict,
# so histories live only as long as this process.
chat_store = {}

def get_session_history(session_id: str):
    """Return the message history for ``session_id``, creating it on first use.

    Repeated calls with the same id return the same history object. That is
    how earlier turns stay visible to ``RunnableWithMessageHistory`` across
    invocations.
    """
    try:
        return chat_store[session_id]
    except KeyError:
        fresh_history = chat_store[session_id] = InMemoryChatMessageHistory()
        return fresh_history

# -----------------------------
# 5. Prompt Template
# -----------------------------
# Message layout:
# - the system turn restricts the model to the retrieved context;
# - earlier conversation turns are spliced in from the "chat_history" variable;
# - the final human turn carries the retrieved context and the current question.
_prompt_messages = [
    ("system", "Answer the question using ONLY the provided context."),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "Context:\n{context}\n\nQuestion:\n{question}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# -----------------------------
# 6. LLM
# -----------------------------
# Chat model used to answer. temperature=0 requests the least random
# sampling the API offers.
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# -----------------------------
# 7. Helper: Format Retrieved Docs
# -----------------------------
def format_docs(docs):
    """Join the text of retrieved documents into one context string.

    Args:
        docs: Iterable of objects that expose a ``page_content`` string,
            e.g. the Documents returned by ``retriever.invoke``.

    Returns:
        The contents separated by a blank line (``"\\n\\n"``), or an empty
        string when ``docs`` is empty.
    """
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)

# -----------------------------
# 8. Build RAG Chain (LCEL)
# -----------------------------
# LCEL pipeline:
# 1. fan the input dict out into the three prompt variables;
# 2. render the prompt, call the model, and reduce the reply to a plain string.
# NOTE: retrieval uses only the current "question"; earlier turns in
# chat_history do not affect which documents are fetched.
rag_chain_base = RunnableParallel(
    context=lambda inputs: format_docs(retriever.invoke(inputs["question"])),
    question=lambda inputs: inputs["question"],
    chat_history=lambda inputs: inputs.get("chat_history", []),
).pipe(prompt, llm, StrOutputParser())

# -----------------------------
# 9. Add Conversational Memory
# -----------------------------
# Conversational wrapper around rag_chain_base. Histories are looked up via
# get_session_history using the session_id supplied in
# config["configurable"]. The incoming "question" is treated as the user
# message, and stored turns are passed to the chain under "chat_history",
# the variable the prompt's placeholder reads.
rag_chain = RunnableWithMessageHistory(
    rag_chain_base,
    get_session_history,
    history_messages_key="chat_history",
    input_messages_key="question",
)

# -----------------------------
# 10. Run the System
# -----------------------------
if __name__ == "__main__":
    session_id = "student-session"
    # Both turns share one session, so the second question is answered with
    # the first exchange available as chat history.
    session_config = {"configurable": {"session_id": session_id}}

    # Each printed question is the exact text sent to the chain, so the
    # transcript matches what the model actually saw.
    question1 = "What is RAG?"
    print(f"Q1: {question1}")
    answer1 = rag_chain.invoke({"question": question1}, config=session_config)
    print("A1:", answer1, "\n")

    question2 = "Why is chunking important in RAG systems?"
    print(f"Q2: {question2}")
    answer2 = rag_chain.invoke({"question": question2}, config=session_config)
    print("A2:", answer2)
