In [1]:
import os
%pip install -q dotenv
%pip install -U langchain-community
from dotenv import load_dotenv

from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Load environment variables (e.g., OpenAI API key) via python-dotenv's
# load_dotenv() before any OpenAI-backed object is constructed below.
load_dotenv()

# Constants
# DATA_PATH: text file read by load_and_split_documents().
# CHROMA_PATH: directory used as Chroma's persist_directory in get_vector_store().
# Both are relative paths, so they resolve against the kernel's current
# working directory rather than the notebook's location.
# NOTE(review): the recorded run of this notebook ended with
# "RuntimeError: Error loading data/sample_docs.txt" -- presumably the file was
# not present relative to the working directory; confirm before re-running.
DATA_PATH = "data/sample_docs.txt"
CHROMA_PATH = "db"

In [3]:
# Step 1: Load and chunk documents
def load_and_split_documents(data_path=None, chunk_size=500, chunk_overlap=50):
    """Load a text file and split it into overlapping chunks.

    Args:
        data_path: File to load. Defaults to the module-level ``DATA_PATH``
            when omitted, which is what the original zero-argument call did.
        chunk_size: Forwarded to ``RecursiveCharacterTextSplitter`` as
            ``chunk_size``.
        chunk_overlap: Forwarded to ``RecursiveCharacterTextSplitter`` as
            ``chunk_overlap``.

    Returns:
        The chunks returned by ``splitter.split_documents``.

    Raises:
        RuntimeError: If the file does not exist, or if ``TextLoader`` fails
            to read it. RuntimeError is kept for the missing-file case because
            that is the type the loader surfaced for it before, so existing
            handlers keep working.
    """
    path = DATA_PATH if data_path is None else data_path

    # Fail with an actionable message instead of the loader's bare
    # "Error loading <path>": relative paths depend on the kernel's working
    # directory, which is the usual reason the file is "missing".
    if not os.path.isfile(path):
        raise RuntimeError(
            f"Error loading {path}: file not found "
            f"(resolved to {os.path.abspath(path)}; "
            f"working directory is {os.getcwd()})"
        )

    # Pin the encoding: without it the platform default is used, which makes
    # UTF-8 files with non-ASCII characters fail on some systems.
    loader = TextLoader(path, encoding="utf-8")
    documents = loader.load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(documents)

In [4]:
# Step 2: Create or load vector store using LangChain
def get_vector_store(chunks):
    """Return a Chroma vector store persisted under ``CHROMA_PATH``.

    An existing, populated store is reused as-is. Otherwise the store is
    built (or back-filled) from ``chunks``.

    Args:
        chunks: Documents to index when no populated store exists yet.

    Returns:
        The ``Chroma`` vector store instance.

    Raises:
        ValueError: If a new store has to be created but ``chunks`` is empty.
    """
    embeddings = OpenAIEmbeddings()

    # Existence alone is not proof of a usable index: an empty directory, or
    # one left behind by a run that failed before any embedding was written,
    # would otherwise be "loaded" as a store with nothing to retrieve.
    has_persisted_files = os.path.isdir(CHROMA_PATH) and bool(os.listdir(CHROMA_PATH))

    if has_persisted_files:
        vectordb = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
        # Reuse the store when it already holds entries, or when there is
        # nothing to add anyway (matches the previous load-only behaviour).
        if not chunks or vectordb.get(limit=1)["ids"]:
            return vectordb
        vectordb.add_documents(chunks)
    else:
        if not chunks:
            raise ValueError(
                f"Cannot build a vector store at {CHROMA_PATH!r}: no document chunks were provided"
            )
        vectordb = Chroma.from_documents(
            chunks,
            embedding=embeddings,
            persist_directory=CHROMA_PATH,
        )

    # Older Chroma releases need an explicit persist(); newer ones persist
    # automatically and some wrappers no longer expose the method at all.
    persist = getattr(vectordb, "persist", None)
    if callable(persist):
        persist()
    return vectordb

In [5]:
# Step 3: Build LangChain RetrievalQA pipeline
def build_rag_chain(vectordb):
    """Assemble a RetrievalQA chain on top of the given vector store.

    Args:
        vectordb: Vector store whose ``as_retriever()`` supplies the retriever.

    Returns:
        The chain built by ``RetrievalQA.from_chain_type``, configured to
        return its source documents alongside the answer.
    """
    chat_model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    return RetrievalQA.from_chain_type(
        llm=chat_model,
        retriever=vectordb.as_retriever(),
        return_source_documents=True,
    )

In [6]:
# Step 4: Main loop
def main():
    """Build the RAG pipeline and run an interactive question/answer loop.

    Loads and chunks the documents, obtains the vector store, builds the QA
    chain, then reads questions from standard input until the user types
    ``exit``/``quit`` or input ends (EOF or keyboard interrupt).
    """
    print("Loading and processing documents...")
    chunks = load_and_split_documents()
    vectordb = get_vector_store(chunks)
    qa_chain = build_rag_chain(vectordb)

    # Calling a chain directly is deprecated in recent LangChain releases;
    # use invoke() when it exists and fall back to the call for older ones.
    ask = getattr(qa_chain, "invoke", qa_chain)

    print("\nRAG Agent ready. Ask a question or type 'exit'.\n")
    while True:
        try:
            query = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or an interrupted kernel: leave quietly instead of
            # dumping a traceback.
            print()
            break

        if not query:
            # Don't spend an API call on a blank line.
            continue
        if query.lower() in ("exit", "quit"):
            break

        result = ask({"query": query})
        print("\nAnswer:\n", result["result"])
        print("\nSources:")
        # Several retrieved chunks usually come from the same file; list each
        # source once, in first-seen order.
        sources = dict.fromkeys(
            doc.metadata.get("source", "Unknown") for doc in result["source_documents"]
        )
        for source in sources:
            print("—", source)

# Start the interactive loop only when this code runs as the main module,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()

Loading and processing documents...


RuntimeError: Error loading data/sample_docs.txt