In [None]:
import os
from dotenv import load_dotenv
from openai import OpenAI
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_classic.chains import RetrievalQA
from langchain_classic.prompts import PromptTemplate
import mercury as mr


# Development convenience: pick up edits to imported modules without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Data locations ---------------------------------------------------------
# Base data directory. The default is the original machine-specific location;
# set the RAG_DATA_DIR environment variable to run the notebook elsewhere.
# The value is read from the process environment as it is when this cell runs.
local_directory_data = os.environ.get(
    'RAG_DATA_DIR',
    'C:/Users/maran/OneDrive/Documents/Git Profile/Data-Projects/Inference Group/rag_data/',
)
state_of_union_file = '2024_State_of_the_Union.txt'
rag_data_dir = '/rag_data'
vector_store_faiss = '/faiss_index'

# Build paths with os.path instead of raw string concatenation. The
# sub-directory names carry a leading '/', which os.path.join would treat as
# an absolute path (discarding the base), so it is stripped before joining.
# normpath collapses the doubled separator the old concatenation produced.
# NOTE(review): the default base directory already ends in 'rag_data/', so the
# index lands in '<base>/rag_data/faiss_index' (i.e. rag_data/rag_data/...).
# That location is kept as-is for compatibility - confirm the nesting is intended.
path_to_faiss_dir = os.path.normpath(
    os.path.join(
        local_directory_data,
        rag_data_dir.lstrip('/'),
        vector_store_faiss.lstrip('/'),
    )
)
path_to_state_of_union = os.path.join(local_directory_data, state_of_union_file)

# Custom prompt template
# Placeholders: {context} receives the retrieved chunks, {question} the user's query.
prompt_template = """Use the following pieces of context to answer the question at the end. 
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}
Answer: """

In [None]:
# Mercury application shell: code cells are hidden from the rendered app and
# the notebook is not served as a static page.
app = mr.App(
    title="Simple RAG Query System",
    description="Ask questions about your documents",
    static_notebook=False,
    show_code=False,
)

# Input widgets, created in the order they should appear.
# Free-text question to send to the model.
question = mr.Text(label="What do you want to know?", value=None)

# Number of chunks the retriever returns (used as `k`).
search_result_choices = mr.Slider(label="Number of results", min=1, max=5, value=3)

# Toggle for printing the retrieved source chunks under the answer.
show_sources = mr.Checkbox(label="Show Sources", value=False)

# Optional document to index; when nothing is uploaded `filepath` is None.
file_upload = mr.File(max_file_size="10MB", label="File upload")

In [None]:
# Load environment variables (from a .env file, if one is present) and read
# the OpenAI key from the environment.
load_dotenv()
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
    # Fail early with an actionable message instead of a client-side error
    # that does not say where the key is expected to come from.
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Add it to the environment or to a .env file "
        "before running this notebook."
    )
client = OpenAI(api_key=api_key)

# Connectivity smoke test: one tiny chat completion to confirm the key works.
try:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say 'API is working!'"}],
        max_tokens=10
    )
    print("✅ API Response:", response.choices[0].message.content)

except Exception as e:
    # Deliberately best-effort: report the failure and let the rest of the
    # notebook run so the problem is visible in the app output.
    print("❌ API Error:", e)

In [None]:
# Embedding model used to vectorise the document chunks.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# Only build an index when the user has uploaded a file; otherwise
# `retriever_object` is intentionally left undefined and the notebook falls
# back to a plain (no-context) prompt in the following cells.
if file_upload.filepath is not None:
    with open(file_upload.filepath, 'r', encoding='utf-8') as file_upload_obj:
        file_upload_read_obj = file_upload_obj.read()

    # Split the raw text into overlapping chunks so each fits comfortably in
    # the prompt while keeping some shared context between neighbours.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100
    )

    example_texts = text_splitter.split_text(file_upload_read_obj)
    if not example_texts:
        # An empty or whitespace-only upload produces no chunks; building a
        # FAISS index from nothing fails with an unhelpful low-level error,
        # so stop here with a clear message instead.
        raise ValueError(
            "The uploaded file contains no text to index. "
            "Please upload a non-empty UTF-8 text file."
        )
    documents = [Document(page_content=text) for text in example_texts]

    # Embed the chunks, build the in-memory index and persist it to disk.
    vector_store = FAISS.from_documents(documents, embeddings)
    vector_store.save_local(path_to_faiss_dir)

    # Similarity retriever returning as many chunks as the slider requests.
    retriever_object = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": search_result_choices.value}
    )

In [None]:
# Chat model shared by both answer paths; temperature 0 is requested.
llm = ChatOpenAI(
    openai_api_key=os.getenv('OPENAI_API_KEY'),
    temperature=0,
    model_name="gpt-3.5-turbo",
)

# Prompt used by the retrieval chain: fills in {context} and {question}.
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

if file_upload.filepath is not None:
    # A document was uploaded: answer from retrieved chunks, "stuffing" them
    # into the custom prompt. Source documents are returned only on request.
    qa_chain = RetrievalQA.from_chain_type(
        retriever=retriever_object,
        llm=llm,
        chain_type="stuff",
        return_source_documents=show_sources.value,
        chain_type_kwargs={"prompt": PROMPT},
    )
else:
    # No document: pipe a bare question prompt straight into the model.
    normal_prompt = PromptTemplate(
        template="Answer the following question: {question}",
        input_variables=["question"],
    )
    qa_chain = normal_prompt | llm


In [None]:
# Only call the model when there is an actual question: invoking the chain
# with an empty/blank input wastes an API call and can fail on the retrieval
# path.
question_was_asked = bool((question.value or '').strip())

if not question_was_asked:
    # Keep `llm_result` defined with the usual shape for any later cell.
    llm_result = {'result': ''}
elif file_upload.filepath is None:
    # Plain prompt | llm pipeline: the response object exposes the answer text
    # as `.content`; wrap it so both paths yield a dict with a 'result' key.
    normal_no_context_result = qa_chain.invoke({"question": question.value})
    llm_result = {'result': normal_no_context_result.content}
else:
    # Retrieval chain: takes the question under "query"; the answer is read
    # from 'result' and, when present, the sources from 'source_documents'.
    llm_result = qa_chain.invoke({"query": question.value})

if question_was_asked:
    print(llm_result['result'])

    if show_sources.value:
        # Without an uploaded file there are no sources; .get() keeps this a
        # harmless "0 documents" listing instead of a KeyError.
        source_docs = llm_result.get('source_documents', [])
        print(f"\n📚 SOURCES ({len(source_docs)} documents):")
        for i, doc in enumerate(source_docs):
            print(f"\nSource {i+1}:")
            print(f"Content: {doc.page_content}...")
            if hasattr(doc, 'metadata') and doc.metadata:
                print(f"Metadata: {doc.metadata}")