In [1]:
import os
import requests
from dotenv import load_dotenv
from operator import itemgetter

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.llms import Ollama

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel

load_dotenv()
# Base URL of the Ollama server, read from the environment / .env file.
# Falls back to Ollama's default local address when OLLAMA_URL is unset, so the
# f"{OLLAMA_URL}/api/..." URLs built later never become "None/api/...".
# A trailing "/" is stripped to avoid "//api/..." in those URLs.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434").rstrip("/")

In [2]:
# List the models installed on the Ollama server (GET /api/tags).
# timeout: fail fast instead of hanging if the server is unreachable.
# raise_for_status: report HTTP errors directly rather than as a KeyError on "models".
response = requests.get(f"{OLLAMA_URL}/api/tags", timeout=10)  # seconds
response.raise_for_status()
payload = response.json()
available_models = [model["name"] for model in payload["models"]]
# str.join adds no trailing newline, so no rstrip is needed.
model_names = "\n".join(available_models)
print(model_names)

bakllava:7b
llama2:13b
llama2:7b
mistral:7b


In [3]:
# Generation model served by the Ollama instance at OLLAMA_URL.
# temperature=0.5 gives moderate sampling randomness in the answers.
llm = Ollama(base_url=OLLAMA_URL, model="mistral:7b", temperature=0.5)

In [4]:
# Prompt for the answer step. Placeholders: {context} (the formatted retrieved
# documents), {question} and {language}, all filled in by the chain.
template = (
    "Use the following context strictly to formulate your answers:\n"
    "Context:\n"
    "{context}\n"
    "\n"
    "Question: {question}\n"
    "\n"
    "Make sure to answer this briefly. Answer strictly in {language} only. "
    "No need to mention the source documents.\n"
    "Answer: "
)

prompt = PromptTemplate.from_template(template)

# Turns the LLM's output into a plain string.
output_parser = StrOutputParser()

In [5]:
# Load every .txt file under ../documents and split it into overlapping chunks
# (chunk_size=1000, chunk_overlap=200) for embedding.
loader = DirectoryLoader('../documents', glob="**/*.txt", show_progress=True, loader_cls=TextLoader)
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = text_splitter.split_documents(documents)
embedding = OllamaEmbeddings(
    model="mistral:7b",
    base_url=OLLAMA_URL,
)

# Cached FAISS index. NOTE: it is reused as-is whenever the directory exists, so
# delete it after changing ../documents or the embedding model to force a rebuild.
DB_PATH = "../.data/faiss_index"
# Vectors are L2-normalised when the index is built. The flag is not saved with
# the index, so it must be passed again on load; otherwise a reloaded store would
# embed queries without normalisation while the stored vectors are normalised.
NORMALIZE_L2 = True
if os.path.isdir(DB_PATH):
    print("Loading existing FAISS index")
    # load_local unpickles the saved docstore. That is acceptable only because this
    # notebook wrote the file itself; never point DB_PATH at an untrusted index.
    vectorstore = FAISS.load_local(
        DB_PATH,
        embeddings=embedding,
        allow_dangerous_deserialization=True,
        normalize_L2=NORMALIZE_L2,
    )
else:
    print("Creating new FAISS index")
    vectorstore = FAISS.from_documents(documents=chunks, embedding=embedding, normalize_L2=NORMALIZE_L2)
    # Save to DB_PATH itself so the save and load locations cannot drift apart.
    vectorstore.save_local(DB_PATH)

retriever = vectorstore.as_retriever()

100%|██████████| 2/2 [00:00<?, ?it/s]

Loading existing FAISS index





In [14]:
def format_docs(docs):
    """Render retrieved documents as a single context string for the prompt.

    Each document becomes a "[Document = <source>]" header line followed by its
    page_content. Consecutive documents are separated by a blank line.

    Args:
        docs: Iterable of documents with ``metadata['source']`` and ``page_content``.

    Returns:
        The combined string ("" for no documents).
    """
    sections = []
    for doc in docs:
        header = f"[Document = {doc.metadata['source']}]"
        sections.append(f"{header}\n{doc.page_content}")
    return "\n\n".join(sections)

def debug(name:str=None):
    """Build a pass-through runnable that prints whatever flows through it.

    Args:
        name: Heading printed above the value, underlined with dashes.
            Falls back to "DEBUG" when empty or None.

    Returns:
        A RunnableLambda that prints its input and returns it unchanged, so it
        can be spliced anywhere into a chain with ``|``.
    """
    label = name or "DEBUG"
    underline = "-" * len(label)

    def _debug(inputs):
        # Show the value's `.text` attribute when it has one, otherwise the value
        # itself. Only a missing attribute is expected here; the previous bare
        # `except:` also hid KeyboardInterrupt and genuine bugs.
        body = getattr(inputs, "text", inputs)
        print(f"{label}\n{underline}\n{body}\n")
        return inputs

    return RunnableLambda(_debug)

def get_sources(docs):
    """Return the distinct source paths of the retrieved documents.

    Args:
        docs: Documents produced by the retriever for the current question; each
            must carry a ``metadata['source']`` entry.

    Returns:
        The sources joined by newlines, de-duplicated in first-seen order (several
        chunks of the same file share one source). Empty string for no documents.
    """
    # Use the documents actually retrieved for this question, not every chunk
    # stored in the vector store.
    unique_sources = dict.fromkeys(doc.metadata["source"] for doc in docs)
    return "\n".join(unique_sources)

In [15]:
# Answer-generation sub-chain. It receives {"context": <retrieved documents>,
# "question": ..., "language": ...} and renders the documents to text before
# filling the prompt.
llm_chain = (
    RunnablePassthrough.assign(
        context=itemgetter("context") | RunnableLambda(format_docs)
    )
    | prompt
    | llm
    | output_parser
)

# Full RAG pipeline:
#   1. retrieve documents for the question while passing question/language through;
#   2. in parallel, generate the answer and list the sources of the retrieved docs.
# The result is a dict with keys "llm" (answer text) and "sources".
chain = (
    RunnableParallel(
        context=itemgetter("question") | retriever,
        question=itemgetter("question"),
        language=itemgetter("language"),
    )
    | RunnableParallel(
        llm=llm_chain | debug("LLM Output"),
        sources=itemgetter("context") | RunnableLambda(get_sources) | debug("Source Documents"),
    )
)

output = chain.invoke({
    "question": "What is current value of Bitcoin?",
    "language": "English",
})
print(output)

Source Dcouments
----------------
..\documents\info.txt
..\documents\other.txt

LLM Output
----------
 The current value of Bitcoin is 12 USD.

{'llm': ' The current value of Bitcoin is 12 USD.', 'sources': '..\\documents\\info.txt\n..\\documents\\other.txt'}


In [None]:
# Preview exactly what the prompt's {context} receives for a query:
# retrieval followed by format_docs, without calling the LLM.
rag = retriever | RunnableLambda(format_docs)
context_preview = rag.invoke("Bitcoin")
print(context_preview)