In [None]:
# Uncomment to install the dependencies into the active kernel's environment:
# %pip install -q openai faiss-cpu numpy python-dotenv langchain langchain-openai langchain-community

In [None]:
import os
import re
from dotenv import load_dotenv
from IPython.display import display, Markdown
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_community.vectorstores import FAISS

# Load environment variables from a local .env file, if one exists.
# Presumably this supplies OPENAI_API_KEY for the OpenAI clients created below -- confirm.
load_dotenv()

In [None]:
# Tunable settings for the retrieval side of the pipeline.
EMBEDDING_MODEL_NAME = "text-embedding-3-small"
FAISS_INDEX_PATH = "faiss_index_directory"
RETRIEVER_TOP_K = 3  # number of documents fetched per query

# NOTE(review): the embedding model presumably has to match the one used when
# the index was built -- confirm against the indexing notebook/script.
embeddings_model = OpenAIEmbeddings(model=EMBEDDING_MODEL_NAME)

# SECURITY: a saved LangChain FAISS index includes pickled data, and
# allow_dangerous_deserialization=True opts in to unpickling it. Only load
# index directories that you created yourself or otherwise fully trust.
vectorstore = FAISS.load_local(
    FAISS_INDEX_PATH,
    embeddings_model,
    allow_dangerous_deserialization=True,
)

retriever = vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_TOP_K})

In [None]:
CHAT_MODEL_NAME = "gpt-4.1"

# Conversation buffer shared by every chat turn. The chain below returns both
# "answer" and "source_documents", so output_key tells the memory which of the
# two to record as the assistant's reply.
memory = ConversationBufferMemory(
    return_messages=True,
    memory_key="chat_history",
    output_key="answer",
)

# temperature=0.0 keeps the answers as repeatable as the API allows.
llm = ChatOpenAI(model=CHAT_MODEL_NAME, temperature=0.0)

# Retrieval-augmented chat: retriever + LLM + memory, with the retrieved
# documents returned alongside each answer.
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    return_source_documents=True,
)

In [None]:
# The negative lookbehind keeps a LaTeX line break with a spacing argument
# (e.g. ``\\[2mm]``) or a line break followed by a parenthesis (``\\(``) from
# being mistaken for a math opener. DOTALL lets a formula span several lines.
_DISPLAY_MATH_RE = re.compile(r'(?<!\\)\\\[(.*?)\\\]', flags=re.DOTALL)
_INLINE_MATH_RE = re.compile(r'(?<!\\)\\\((.*?)\\\)', flags=re.DOTALL)


def fix_latex_delimiters(text):
    """Rewrite LaTeX math delimiters into the dollar-sign form.

    ``\\[ ... \\]`` becomes ``$$ ... $$`` and ``\\( ... \\)`` becomes
    ``$ ... $``. The formula body itself is copied through unchanged.

    Args:
        text: String that may contain ``\\[...\\]`` / ``\\(...\\)`` math.

    Returns:
        The string with both delimiter styles converted.
    """
    text = _DISPLAY_MATH_RE.sub(r'$$\1$$', text)
    text = _INLINE_MATH_RE.sub(r'$\1$', text)
    return text

def chat(query, char_limit=200, use_retriever=True):
    """Run one chat turn and render the question, answer and sources.

    Args:
        query: The user's message.
        char_limit: Maximum number of characters shown from each source
            document; longer excerpts are cut and marked as truncated.
            ``None`` disables truncation.
        use_retriever: When True, the turn goes through ``conversation_chain``
            (retrieval plus memory). When False, ``llm`` is called directly
            with the stored chat history and no sources are shown.

    Returns:
        None. Everything is printed or displayed as Markdown.
    """
    if use_retriever:
        result = conversation_chain.invoke({"question": query})
        question = fix_latex_delimiters(result['question'])
        answer = fix_latex_delimiters(result['answer'])
        sources = result['source_documents']
    else:
        chat_history = memory.load_memory_variables({})["chat_history"]
        # The memory holds message objects (return_messages=True). Pass them
        # to the model as a message list rather than interpolating their
        # repr() into a single prompt string.
        reply = llm.invoke([*chat_history, ("human", query)])
        # The chain is bypassed here, so record the turn manually; otherwise
        # later turns would not see this exchange.
        memory.save_context({"question": query}, {"answer": reply.content})
        question = fix_latex_delimiters(query)
        answer = fix_latex_delimiters(reply.content)
        sources = []

    print("\nQuestion:")
    display(Markdown(question))

    print("\nAnswer:")
    display(Markdown(answer))

    print("\nSources:")
    for i, doc in enumerate(sources, 1):
        source_info = doc.metadata.get('source', 'Unknown')
        page_content = fix_latex_delimiters(doc.page_content)

        # Limit page content length (char_limit=None means "show everything").
        truncated = char_limit is not None and len(page_content) > char_limit
        if truncated:
            page_content = page_content[:char_limit] + "..."

        source_md = f"#### Source {i}: {source_info}\n\n```\n{page_content}\n```"
        if truncated:
            # Keep the marker outside the code fence so the italics render.
            source_md += "\n\n*(truncated)*"
        display(Markdown(source_md))


In [None]:
# Reset the conversation buffer so the chat turns below start from an empty history.
memory.clear()

In [None]:
# Opening turn: use_retriever defaults to True, so this goes through
# conversation_chain and lists the retrieved source documents.
chat("""
what is probability?
""")

In [None]:
# Follow-up turn: use_retriever=False skips retrieval and calls the LLM
# directly with the stored chat history, so no sources are listed.
chat("""
Please elaborate.
""", use_retriever=False)