In [12]:
import os
import chromadb
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain

In [13]:
# Model names for the local Ollama server (no base_url is passed, so the
# langchain_community default endpoint is used).
LLM_MODEL = "llama2:7b"
# NOTE: the persisted Chroma collection (documents/chroma_db) holds vectors made
# by this model, and indexing only runs when the collection is empty. Changing
# EMBEDDING_MODEL therefore requires deleting documents/chroma_db and
# re-indexing; otherwise queries are embedded in a different vector space than
# the stored chunks.
# NOTE(review): a chat model is being used for embeddings; a dedicated
# embedding model may retrieve better — consider it when rebuilding the index.
EMBEDDING_MODEL = "llama2:7b"

llm = ChatOllama(model=LLM_MODEL)
emb = OllamaEmbeddings(model=EMBEDDING_MODEL)

In [14]:
# Quick sanity check that the chat model responds before building the RAG chain.
llm.invoke("Hello world!")

AIMessage(content='Hello there! * waves * How are you today?', response_metadata={'model': 'llama2:7b', 'created_at': '2024-04-10T03:49:10.156873Z', 'message': {'role': 'assistant', 'content': ''}, 'done': True, 'total_duration': 1118669292, 'load_duration': 2188959, 'prompt_eval_duration': 257229000, 'eval_count': 12, 'eval_duration': 857954000}, id='run-a80a9db9-e5d0-4e7b-b177-2b1f63632dc0-0')

In [15]:
class OllamaEmbeddingFn(chromadb.EmbeddingFunction):
    """Chroma embedding function backed by a LangChain embeddings object.

    The Chroma collection calls this to embed the documents passed to
    `collection.add`. Using the same embeddings object as the LangChain
    `Chroma` vector store keeps indexed vectors and query vectors comparable.

    Args:
        embeddings: any object with `embed_documents(list[str]) -> list[list[float]]`.
            When omitted, the notebook-level `emb` is used. It is looked up at
            call time, which matches the original behavior.
    """

    def __init__(self, embeddings=None):
        self._embeddings = embeddings

    def __call__(self, input: chromadb.Documents) -> chromadb.Embeddings:
        # The parameter must be named `input`: chromadb checks the signature.
        embedder = self._embeddings if self._embeddings is not None else emb
        return embedder.embed_documents(input)

In [16]:
# Open the on-disk Chroma store and fetch (or create) the collection that holds
# the helpsheet chunks. Documents added to it are embedded by OllamaEmbeddingFn.
# adapted from https://docs.trychroma.com/usage-guide#using-collections
client = chromadb.PersistentClient(path="documents/chroma_db")
collection = client.get_or_create_collection(
    embedding_function=OllamaEmbeddingFn(),
    name="helpsheets",
)

In [17]:
# Filenames (relative to `helpsheet_dir`) of the PDF helpsheets to index.
# - Only regular files ending in .pdf are kept, so stray files (e.g. .DS_Store)
#   never reach PyPDFLoader, which would fail on them.
# - os.listdir returns entries in arbitrary order. Sorting makes the positional
#   chunk ids assigned during indexing ("<file index> - <chunk index>")
#   reproducible across machines.
helpsheet_dir = "documents/helpsheet collection"
helpsheet_paths = sorted(
    f
    for f in os.listdir(helpsheet_dir)
    if f.lower().endswith(".pdf") and os.path.isfile(os.path.join(helpsheet_dir, f))
)

In [18]:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

# Index every helpsheet into the Chroma collection. This is slow: every chunk
# is embedded through Ollama via the collection's embedding function.
# The whole step is skipped once the collection holds anything.
# NOTE(review): because of that guard, an interrupted run leaves a partial index
# that later runs will never complete; delete documents/chroma_db to rebuild.
# adapted from https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf/#using-pypdf
if collection.count() == 0:
    # Wrap the list (not the enumerate iterator) so tqdm knows the total and can
    # show percentage/ETA.
    for i, hs_path in enumerate(tqdm(helpsheet_paths, desc="indexing")):
        # tqdm.write logs the file without corrupting the progress bar.
        tqdm.write(f"{i} {hs_path}")
        # load document
        loader = PyPDFLoader(os.path.join(helpsheet_dir, hs_path))
        docs = loader.load()
        # split document into chunks
        splits = text_splitter.split_documents(docs)
        if not splits:
            # e.g. image-only PDFs with no extractable text. Chroma's `add`
            # rejects an empty id list, so skip instead of aborting the whole run.
            tqdm.write(f"  skipped {hs_path}: no extractable text")
            continue
        # ids unique per chunk: "<file index> - <chunk index>"
        ids = [f"{i} - {j}" for j in range(len(splits))]
        # add chunks into chroma collection (embedded by the collection's function)
        collection.add(
            ids=ids,
            metadatas=[d.metadata for d in splits],
            documents=[d.page_content for d in splits],
        )

In [19]:
# Expose the already-open Chroma collection to LangChain as a vector store and
# turn it into a retriever. The same `emb` model that fills the collection
# (through OllamaEmbeddingFn) is handed over here, so queries and stored chunks
# share one embedding space.
# adapted from https://python.langchain.com/docs/integrations/vectorstores/chroma/#passing-a-chroma-client-into-langchain
vectorstore = Chroma(
    collection_name="helpsheets",
    embedding_function=emb,
    client=client,
)
retriever = vectorstore.as_retriever()  # library-default search settings

In [20]:
# Step 1 of the conversational RAG pipeline: before retrieval, ask the LLM to
# rewrite a follow-up question (e.g. one using "it") into a standalone question
# using the chat history. Retrieval then runs on that rewritten question.
# Chat-history handling (this cell and the next two) adapted from
# https://python.langchain.com/docs/use_cases/question_answering/chat_history/
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),  # prior turns are injected here
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm,
    retriever,
    contextualize_q_prompt,
)

In [21]:
# System prompt for the answering step. The `{context}` slot receives the
# retrieved helpsheet chunks (presumably via create_stuff_documents_chain's
# default document variable name "context" — confirm if that is overridden).
# The backslashes are line continuations inside the triple-quoted string, so the
# instructions form one paragraph followed by a newline and the context.
qa_system_prompt = """You are a computer science tutoring assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
Do not mention "the context". \
If you don't know the answer, just say that you don't know. \
Use five sentences maximum and keep the answer concise.\

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),  # prior turns are injected here
        ("human", "{input}"),
    ]
)


# Answers the question from the retrieved documents plus the chat history.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Full pipeline: history-aware retrieval, then answering. Streaming it yields
# chunks keyed by "input", "chat_history", "context" and "answer".
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

In [22]:
### Statefully manage chat history ###
# In-memory session store: session_id -> ChatMessageHistory. It lives only as
# long as the kernel, and `store.clear()` wipes every conversation.
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for `session_id`, creating an empty one on first use."""
    try:
        return store[session_id]
    except KeyError:
        history = store[session_id] = ChatMessageHistory()
        return history


# Wrap the RAG chain so each call reads earlier turns from the session history
# and records the new turn in it:
# - "input" is the user message,
# - "chat_history" is where past messages are injected,
# - "answer" is stored as the AI reply.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

In [23]:
# streams llm output
# adapted from https://python.langchain.com/docs/use_cases/question_answering/streaming/
def query_and_print(input: str, session_id: str = "abc123", show_intermediate: bool = True):
    """Ask the conversational RAG chain a question and print the streamed result.

    Args:
        input: the user's question. The name is kept for compatibility; it
            shadows the builtin `input` only inside this function.
        session_id: key into `store`. Calls with the same id share chat history.
        show_intermediate: when True (the default, matching the original
            behavior), non-answer chunks ("input", "chat_history", "context")
            are also printed. Set False to print only the answer text; the
            context chunks can be very long.
    """
    # get_session_history creates store[session_id] the first time this id is used.
    config = {"configurable": {"session_id": session_id}}
    for chunk in conversational_rag_chain.stream({"input": input}, config=config):
        if "answer" in chunk:
            # The answer arrives in pieces. Print them contiguously and flush so
            # the text appears as it streams.
            print(chunk["answer"], end="", flush=True)
        elif show_intermediate:
            print(chunk)
    print()  # terminate the streamed answer line

In [28]:
# First turn of session "abc123" (query_and_print's default session_id).
query_and_print("What is an MST?")

{'input': 'What is an MST?'}
{'chat_history': []}
{'context': [Document(page_content='•A≤PB,B≤PC⇒A≤PCPolynomial Time\n•polynomialtime→runtime is polynomial in the length of\nthe encoding of the problem instance\n•"standard"encodings\n•binary encoding of integers\n•list of parameters enclosed in braces (graphs/matrices)\n•pseudo-polynomialalgorithm→runs in time polynomial in\nthenumeric value if the input but is exponential in the\nlengthof the input\n•e.g. DP algo for Knapsack sinceWis in numeric value\n•Knapsack is NOT polynomial time: O(nWlogM)butWis\nnot the number of bits\n•Fractional Knapsack is polynomial time:\nO(nlognlogWlogM)\nDecision Problems\n•decisionproblem→a function that maps an instance\nspaceIto the solution set{YES,NO}\n•decision vs optimisation problem:\n•decision problem : given a directed graph G,is therea\npath from vertex utovof length≤k?\n•optimisation problem : given ..., what is the lengthof the\nshortest path ... ?\n•convert from decision→optimisation : give

In [29]:
# Follow-up in the same session: "it" refers to the MST from the previous turn.
# The history-aware retriever rewrites it into a standalone question before retrieval.
query_and_print("How to solve it using Prim's algorithm?")

{'input': "How to solve it using Prim's algorithm?"}
{'chat_history': [HumanMessage(content='What is an MST?'), AIMessage(content="An MST (Minimum Spanning Tree) is a subset of the edges of a connected weighted graph that connect all the vertices together while minimizing the total edge weight. In other words, it is a subgraph of the original graph that contains all the vertices and has the smallest possible total edge weight. The MST is used in many algorithms, such as the Kruskal's algorithm and Prim's algorithm, to find the minimum spanning tree of a graph.")]}
{'context': [Document(page_content='graph faster than O(VE), because for every vertex, we need O(E) time to determine \nits neighbours. \n Proof: \nFalse. If we sort the edge list by the starting vertex followed by the second vertex, \nwe can potentially reduce any searching time to be O(log E) for binary search. The sorting itself will be O(E log E). This combined with the processing of N vertices \nresults in around O(Max(V

In [27]:
# Wipe every session's chat history so the next query starts a fresh conversation.
# NOTE(review): this cell's execution count (27) is lower than the two query
# cells above it (28, 29), so it was run before them. Under Restart & Run All it
# runs last instead; move it above the queries if a clean start is intended.
store.clear()