### LangChain local LLM RAG example with self chunking, reranking, and maximising context based on token length
### For LangSmith users (requires API key)
Utilising LangChain v0.1

This notebook demonstrates the use of LangChain for Retrieval Augmented Generation in Linux with Nvidia's CUDA. LLMs are run using Ollama.

It uses self-chunking (we split our documents into chunks ourselves) and re-ranks the retrieved results before passing them to the LLM.

Finally, it uses the token counts of paragraphs for the context to maximise how much we give to the LLM.

Models tested:
- Llama 2
- Mistral 7B
- Mixtral 8x7B
- Neural Chat 7B
- Orca 2
- Phi-2
- Solar 10.7B
- Yi 34B


See the [README.md](README.md) file for help on how to set up your environment to run this.

We start by creating a callback handler so that we can get the token counts once the LLM has finished inference.

In [1]:
from typing import Any, Optional, Sequence
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from langchain_core.documents import Document

# Updated by the callback below with the prompt's token count each time the LLM finishes inference
llmresult_prompt_token_count = 0
# Updated by the callback below with the response's token count each time the LLM finishes inference
llmresult_response_token_count = 0

class GenerationStatisticsCallback(BaseCallbackHandler):
    def on_llm_end(self, response: LLMResult, **kwargs) -> None:

        # When the LLM inference has finished, store token counts in global variables for use outside of here
        global llmresult_prompt_token_count
        llmresult_prompt_token_count = response.generations[0][0].generation_info["prompt_eval_count"]

        global llmresult_response_token_count
        llmresult_response_token_count = response.generations[0][0].generation_info["eval_count"]

        print(f"\n\n ----\n\n[ PROMPT TOKEN COUNT {llmresult_prompt_token_count} | RESPONSE TOKEN COUNT {llmresult_response_token_count} ]")
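
The handler above assumes that `generation_info` is always present and contains both keys. As a minimal sketch of a more defensive lookup, the hypothetical helper `extract_token_counts` below falls back to 0 when a count is missing. The `SimpleNamespace` objects are stand-ins shaped like an `LLMResult`, used here only so the snippet runs without LangChain or Ollama.

```python
from types import SimpleNamespace

def extract_token_counts(response):
    """Return (prompt_tokens, response_tokens), using 0 when a count is missing."""
    info = response.generations[0][0].generation_info or {}
    return info.get("prompt_eval_count", 0), info.get("eval_count", 0)

# Hypothetical stand-in shaped like an LLMResult, with counts present
full = SimpleNamespace(
    generations=[[SimpleNamespace(generation_info={"prompt_eval_count": 1711, "eval_count": 272})]]
)
# Hypothetical stand-in where the backend returned no generation_info at all
empty = SimpleNamespace(generations=[[SimpleNamespace(generation_info=None)]])

full_counts = extract_token_counts(full)
empty_counts = extract_token_counts(empty)
print(full_counts, empty_counts)
```

The same two lines could replace the direct dictionary indexing inside `on_llm_end` if missing counts turn out to be a problem in practice.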

In [2]:
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

callback_manager = CallbackManager(
    [StreamingStdOutCallbackHandler(), GenerationStatisticsCallback()]
)

In [3]:
# Target Context token count
# This is the total amount of tokens from the context we retrieve that we want to put into the prompt for the LLM to use for RAG
# We want to maximise the context put in without putting in more than a certain amount

# We'll set this with the model selection in the next cell
maximum_context_tokens = 0

# Important - about 500 tokens are added to this to generate the full prompt for the LLM before it responds, so the
# full prompt is roughly maximum_context_tokens + 500 tokens, plus the answer tokens

In [29]:
# Select your model here, put the name of the model in the ollama_model_name variable
# Ensure you have pulled them or run them so Ollama has downloaded them and can load them (which it will do automatically)

# Ollama installation (if you haven't done it yet): $ curl https://ollama.ai/install.sh | sh
# Models need to be running in Ollama for LangChain to use them, to test if it can be run: $ ollama run mistral:7b-instruct-q6_K

# Creating a list of tuples (model_name, max_context_tokens)
ollama_model_configs = [
    ("llama2:7b-chat-q6_K", 2000),                  # 0
    ("mistral:7b-instruct-q6_K", 2000),             # 1
    ("mixtral:8x7b-instruct-v0.1-q4_K_M", 2000),    # 2
    ("neural-chat:7b-v3.3-q6_K", 2000),             # 3
    ("orca2:13b-q5_K_S", 2000),                     # 4
    ("phi", 1000),                                  # 5
    ("solar:10.7b-instruct-v1-q5_K_M", 2000),       # 6
]

# CHANGE THIS VALUE TO THE INDEX OF THE MODEL YOU WANT TO USE:
model_index = 2

# Then we load the values into our variables
ollama_model_name, maximum_context_tokens = ollama_model_configs[model_index]

print(f"Ollama Model selected: {ollama_model_name} with maximum context tokens allowed set to {maximum_context_tokens}")

# Note: Can't run "yi:34b-chat-q3_K_M" or "yi:34b-chat-q4_K_M" - never stopped with inference

Ollama Model selected: mixtral:8x7b-instruct-v0.1-q4_K_M with maximum context tokens allowed set to 2000


In [30]:
# Our LangSmith API key is stored in apikeys.py
# Store your LangSmith key in a variable called LangSmith_API

from apikeys import LangSmith_API
import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = LangSmith_API

# Project Name
os.environ["LANGCHAIN_PROJECT"] = "LangChain RAG Linux Chunking"

In [31]:
# Load the LLM with Ollama, setting the temperature low so it's not too creative

from langchain_community.llms import Ollama
llm = Ollama(model=ollama_model_name,
    callback_manager=callback_manager,
    temperature=0.1)

In [32]:
# Quick test of the LLM with a general question before we start doing RAG
# llm.invoke("why is the sky blue?")

# Note: This line would not complete for Yi-34B - need to work out why inferencing never finishes (works fine when running with the same prompt in ollama.)

In [33]:
# Embeddings will be based on the Ollama loaded model

from langchain_community.embeddings import OllamaEmbeddings

embeddings = OllamaEmbeddings(model=ollama_model_name)

In [34]:
from langchain_community.document_loaders import DirectoryLoader

loader = DirectoryLoader('Data', glob="**/*.docx")

In [35]:
# Load documents

docs = loader.load()

In [36]:
docs

 Document(page_content="Thundertooth\n\nEmbraced by the futuristic city and its inhabitants, Thundertooth found a sense of purpose beyond merely satisfying his hunger. Inspired by the advanced technology surrounding him, he decided to channel his creativity into something extraordinary. With the help of the city's brilliant engineers, Thundertooth founded a one-of-a-kind toy factory that produced amazing widgets – magical, interactive toys that captivated the hearts of both children and adults alike.\n\nThundertooth's toy factory became a sensation, and its creations were highly sought after. The widgets incorporated cutting-edge holographic displays, levitation technology, and even the ability to change shapes and colors with a mere thought. Children across the city rejoiced as they played with these incredible toys that seemed to bring their wildest fantasies to life.\n\nAs the years passed, Thundertooth's life took a heartwarming turn. He met a kind and intelligent dinosaur named Se

In [37]:
# Ensure we have the right number of Word documents loaded

len(docs)

4

We create a function to split text into paragraphs but keep numbered sections, bullet points, and lists together. This suits these documents because they contain numbered and bulleted points; the rules would need to be adapted for documents with a different structure.

In [38]:
import re

# Define the regular expression pattern for splitting paragraphs
para_split_pattern = re.compile(r'\n\n')

# Splits a document's text into paragraphs. Numbered or dashed points, and paragraphs whose first word ends
# with a colon, are appended to the paragraph before them.
def split_text_into_paragraphs(text):


    # Use the pattern to split the text into paragraphs
    paragraphs = para_split_pattern.split(text)

    # Combine paragraphs that should not be split
    combined_paragraphs = [paragraphs[0]]

    for p in paragraphs[1:]:
        # Strip surrounding whitespace so that blank fragments are skipped safely
        p = p.strip()

        # If the paragraph starts with a digit or a dash, or its first word ends with a colon,
        # concatenate it to the previous paragraph so we keep them all in one chunk
        if p and (p[0].isdigit() or p[0] == '-' or p.split()[0].endswith(':')):
            combined_paragraphs[-1] += '\n\n\n' + p
        else:
            combined_paragraphs.append(p)

    # Remove empty strings from the result
    combined_paragraphs = [p.strip() for p in combined_paragraphs if p.strip()]

    return combined_paragraphs
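
To see the merging rules at work on a small input, here is a self-contained sketch. The function name `merge_list_paragraphs` and the sample text are hypothetical (the sample is not taken from the story documents); the rules mirror the function above.

```python
import re

def merge_list_paragraphs(text):
    """Split on blank lines, merging list-like paragraphs into the one before."""
    paragraphs = re.split(r'\n\n', text)
    combined = [paragraphs[0]]
    for p in paragraphs[1:]:
        p = p.strip()
        # Numbered items, dashes, and "Label:" openers stay with the previous paragraph
        if p and (p[0].isdigit() or p[0] == '-' or p.split()[0].endswith(':')):
            combined[-1] += '\n\n\n' + p
        else:
            combined.append(p)
    return [p.strip() for p in combined if p.strip()]

# Hypothetical sample text with a numbered list and a bullet
sample = (
    "Intro paragraph.\n\n1. First point\n\n2. Second point\n\n"
    "A new paragraph.\n\n- A bullet"
)

chunks = merge_list_paragraphs(sample)
for chunk in chunks:
    print(repr(chunk))
```

The five blank-line-separated fragments collapse into two chunks: the intro with its two numbered points, and the second paragraph with its bullet.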

Create nodes from the paragraphs that we've carefully split up, counting the tokens in each paragraph so we know what kind of token length we're working with.

We can use the LLM object to count the tokens with get_num_tokens.
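
As a small worked example of the statistics gathered in the next cell, the sketch below uses the six token counts printed for "Thundertooth Part 2.docx" in this notebook's output. The 2000-token budget matches the value configured for most models above; the variable names are illustrative only.

```python
# Token counts printed for "Thundertooth Part 2.docx" in this notebook's output
token_counts = [3, 88, 70, 327, 64, 69]

max_tokens = max(token_counts)
average_tokens = int(sum(token_counts) / len(token_counts))

# Rough estimate of how many average-sized paragraphs fit in a 2000-token context budget
budget = 2000
paragraphs_that_fit = budget // average_tokens

print(max_tokens, average_tokens, paragraphs_that_fit)
```

The single 327-token paragraph (a merged list) pulls the average up noticeably, which is worth keeping in mind when choosing how many chunks to retrieve.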

In [39]:
from langchain.docstore.document import Document

paragraph_separator = "\n\n\n"

# Stores the maximum length of a paragraph, in tokens
max_paragraph_tokens = 0

# Total tokens, used to determine average
total_paragraph_tokens = 0

# Nodes
paragraph_nodes = []

# Loop through the documents, splitting each into paragraphs and checking the number of tokens per paragraph
for document in docs:

    paragraph_token_lens = []
    paragraphs = split_text_into_paragraphs(document.page_content)
    print(f"Document {document.metadata['source']} has {len(paragraphs)} paragraphs")
    for paragraph in paragraphs:

        # Count the tokens in this paragraph
        token_count = llm.get_num_tokens(paragraph)
        paragraph_token_lens.append(token_count)
        print(f"Paragraph tokens: {token_count}")

        if token_count > max_paragraph_tokens:
            max_paragraph_tokens = token_count

        total_paragraph_tokens = total_paragraph_tokens + token_count

        # Create and add the node from the paragraph
        # Copy the source from the Word document's metadata so we can use it for citations
        node = Document(page_content=paragraph)
        node.metadata["source"] = document.metadata["source"]
        node.metadata["token_count"] = token_count
        paragraph_nodes.append(node)

    # print(paragraph_token_lens)

print(f"\n** The maximum paragraph tokens is {max_paragraph_tokens} **")

average_paragraph_tokens = int(total_paragraph_tokens / len(paragraph_nodes))
print(f"\n** The average paragraph's token count is {average_paragraph_tokens} **")

print(f"\n** Created {len(paragraph_nodes)} nodes **")


Document Data/Thundertooth Part 3.docx has 10 paragraphs
Paragraph tokens: 3
Paragraph tokens: 75
Paragraph tokens: 51
Paragraph tokens: 54
Paragraph tokens: 193
Paragraph tokens: 60
Paragraph tokens: 86
Paragraph tokens: 65
Paragraph tokens: 57
Paragraph tokens: 83
Document Data/Thundertooth Part 2.docx has 6 paragraphs
Paragraph tokens: 3
Paragraph tokens: 88
Paragraph tokens: 70
Paragraph tokens: 327
Paragraph tokens: 64
Paragraph tokens: 69
Document Data/Thundertooth Part 1.docx has 13 paragraphs
Paragraph tokens: 3
Paragraph tokens: 89
Paragraph tokens: 68
Paragraph tokens: 83
Paragraph tokens: 56
Paragraph tokens: 73
Paragraph tokens: 60
Paragraph tokens: 24
Paragraph tokens: 37
Paragraph tokens: 23
Paragraph tokens: 49
Paragraph tokens: 54
Paragraph tokens: 105
Document Data/Thundertooth Part 4.docx has 14 paragraphs
Paragraph tokens: 3
Paragraph tokens: 89
Paragraph tokens: 61
Paragraph tokens: 71
Paragraph tokens: 66
Paragraph tokens: 67
Paragraph tokens: 72
Paragraph tokens: 

Let's see the split data - now neatly in paragraphs and the bullet points and lists are with their respective paragraph.

In [40]:
paragraph_nodes

[Document(page_content='Thundertooth', metadata={'source': 'Data/Thundertooth Part 3.docx', 'token_count': 3}),
 Document(page_content="One fateful day, as the citizens of the futuristic city went about their daily lives, a collective gasp echoed through the streets as a massive meteor hurtled towards Earth. Panic spread like wildfire as people looked to the sky in horror, realizing the impending catastrophe. The city's advanced technology detected the threat, and an emergency broadcast echoed through the streets, urging everyone to seek shelter.", metadata={'source': 'Data/Thundertooth Part 3.docx', 'token_count': 75}),
 Document(page_content="Thundertooth, ever the protector of his newfound home, wasted no time. With a determined gleam in his eyes, he gathered his family and hurried to the city's command center, where Mayor Grace and the leading scientists were coordinating the evacuation efforts.", metadata={'source': 'Data/Thundertooth Part 3.docx', 'token_count': 51}),
 Document(p

We no longer need to use the LangChain text splitter as we've already done the splitting.

In [41]:
# Split them up into chunks using a Text Splitter

# from langchain.text_splitter import RecursiveCharacterTextSplitter

# text_splitter = RecursiveCharacterTextSplitter()
# documents = text_splitter.split_documents(docs)

In [42]:
# Create the embeddings from our split up chunks

from langchain_community.vectorstores import FAISS

vector = FAISS.from_documents(paragraph_nodes, embeddings)

In preparing the prompt, we add direction to include citations so that the LLM is instructed to include the sources in its response (hopefully!).

In [43]:
# Prepare the prompt and then the chain

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

if ollama_model_name == "phi" or ollama_model_name == "phi:chat":
    # Phi-2 prompt is less flexible
    prompt_template = """Instruct: With this context\n\n{context}\n\nQuestion: (Include citations) {question}\nOutput:"""

else:
    # prompt_template = """You are a story teller , answering questions in an excited, insightful, and empathetic way. Answer the question based only on the provided context:
    prompt_template = """You are a story teller writing in the style of Agatha Christie. Answer the question only with the provided context. YOU MUST INCLUDE THE SOURCES.

    <context>
    {context}
    </context>

    Question: {question}"""

prompt = PromptTemplate(
    template=prompt_template, 
    input_variables=[
        'context', 
        'question',
    ]
)
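
To show what the filled-in prompt looks like, here is a minimal sketch that uses plain `str.format` as a stand-in for `PromptTemplate`, which by default substitutes `{context}` and `{question}` in a similar way. The template is a shortened version of the one above, and the context string is an illustrative excerpt.

```python
# Shortened version of the prompt template used in this notebook
template = """Answer the question only with the provided context. YOU MUST INCLUDE THE SOURCES.

<context>
{context}
</context>

Question: {question}"""

# Illustrative context in the same shape the chain produces later on
context = (
    "[Source 'Data/Thundertooth Part 2.docx', Relevance 59]\n"
    "Thundertooth and Seraphina reveled in the joy of parenthood."
)

filled = template.format(context=context, question="Did they have any children?")
print(filled)
```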

Now that we have broken the documents down into paragraph-sized chunks, we need to retrieve more of them so the LLM has a decent amount of context to use. Without the "search_kwargs" parameter, the answers to the questions were worse. For example, when asked whether they had any children, no relevant context was retrieved.

Note: For the context containing the children's names to be included (and then reranked to the top), I needed to set the number of retrieved chunks to 20. The section with the children's names was the 11th result from the retriever! This suggests you will likely need to retrieve more chunks than you think.

In [44]:
# Create the retriever and set it to return a good amount of chunks

from langchain.chains import create_retrieval_chain

# We use a variable to store the number of results as we'll use the same amount for the reranking (as we'll manually remove some)
retrieval_chunks = 20

retriever = vector.as_retriever(search_kwargs={"k": retrieval_chunks})

Let's implement the Cohere reranking, utilising our retriever (which is getting more results to work with now) and our LLM.

Note: You'll need a Cohere API key. A trial key is free for non-commercial purposes. I've stored it in apikeys.py as Cohere_API = "your key in here"

https://cohere.com/

In [45]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank

from apikeys import Cohere_API

# Create the retriever
# Here we rerank the same number of chunks; each gets a relevance score so we can later cut out the
# lowest-ranking ones that do not fit into our target maximum context tokens
compressor = CohereRerank(cohere_api_key=Cohere_API, top_n=retrieval_chunks)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=retriever,
)

In [46]:
# Let's test that it includes the paragraph starting with "As the years passed..." when asked about their children.

test_retrieval = compression_retriever.get_relevant_documents("Did they have any children? If so, what were their names?")

test_retrieval

[Document(page_content="Lumina, Echo, and Sapphire grew concerned as they noticed Ignis's increasingly erratic behavior. They attempted to reason with him, pleading for him to abandon his destructive ambitions and embrace the family's legacy of unity. However, Ignis, consumed by his thirst for power, rejected their pleas and retreated further into the shadows.", metadata={'source': 'Data/Thundertooth Part 4.docx', 'token_count': 66, 'relevance_score': 0.66551924}),
 Document(page_content="Thundertooth and Seraphina reveled in the joy of parenthood, watching their children grow and flourish in the futuristic landscape they now called home. The family became an integral part of the city's fabric, not only through the widgets produced in their factory but also through the positive impact each member had on the community.", metadata={'source': 'Data/Thundertooth Part 2.docx', 'token_count': 64, 'relevance_score': 0.5951397}),
 Document(page_content='Thundertooth', metadata={'source': 'Data

The above shows that, indeed, we are able to get that paragraph and it is the highest ranked.

Importantly, if we had not brought enough chunks back with the retriever (referring to the vector store retriever) then we would not have had the right chunks to run through Cohere for reranking.

So if this line:
```
retriever = vector.as_retriever(search_kwargs={"k": 20})
```

was:
```
retriever = vector.as_retriever(search_kwargs={"k": 10})
```

We would not have been able to get that "As the years passed..." chunk for reranking.

Additionally, reranking lets us pass fewer chunks to the LLM: although we needed to retrieve 11 or more chunks to find the right one, after reranking we can keep only the best few (for example, the top 5). This reduces the number of tokens the LLM has to process.
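
The dependency between the two stages can be sketched with a toy example. Everything here is hypothetical: the `similarity` values stand in for vector store scores, the `relevance` values stand in for reranker scores, and chunk 10 plays the role of the paragraph that was the 11th retrieval result.

```python
def retrieve(candidates, k):
    """First stage: keep the k candidates with the highest similarity score."""
    return sorted(candidates, key=lambda c: c["similarity"], reverse=True)[:k]

def rerank(candidates):
    """Second stage: reorder the retrieved candidates by reranker relevance."""
    return sorted(candidates, key=lambda c: c["relevance"], reverse=True)

# 20 hypothetical chunks, already ordered by similarity; chunk 10 is 11th by similarity
chunks = [{"id": i, "similarity": 1.0 - i * 0.01, "relevance": 0.1} for i in range(20)]
chunks[10]["relevance"] = 0.9  # the reranker considers this one the best answer

ids_with_k10 = [c["id"] for c in rerank(retrieve(chunks, k=10))]
ids_with_k20 = [c["id"] for c in rerank(retrieve(chunks, k=20))]

print(10 in ids_with_k10)  # the reranker never sees chunk 10
print(ids_with_k20[0])     # with k=20 it is retrieved and promoted to the top
```

The reranker can only reorder what the first stage hands it, so the first-stage `k` sets an upper bound on what can reach the LLM.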

Now, we create a LangChain chain with the Cohere reranker retriever.

However, we need a workaround so that our callback is triggered after the documents are retrieved from Cohere. We build the chain in a function that returns it, rather than creating it with RetrievalQA as we did previously (shown commented out in the following cell).

See GitHub issue for more information:
https://github.com/langchain-ai/langchain/issues/7290

In [47]:
# from langchain.chains import RetrievalQA

# rerank_chain = RetrievalQA.from_chain_type(
    # llm=llm,
    # retriever=compression_retriever,
    # return_source_documents=True,
    # chain_type_kwargs={"prompt": prompt}    # Pass in our prompt
# )

In [48]:
from langchain.schema.runnable import RunnablePassthrough
from uuid import uuid4

def get_rerankchain(retriever, llm, callbacks):

    # Format the documents for our context
    # We include the source and relevance, hoping the LLMs will use them
    # The prompt could be used to include examples of the citations
    def format_docs(docs):
        formatted_text = []

        for doc in docs:
            source = doc.metadata.get("source", "Unknown Source")
            relevance_score = int(doc.metadata.get("relevance_score", 0.0) * 100)
            page_content = doc.page_content

            formatted_doc = (
                f"[Source '{source}', Relevance {relevance_score}]\n"
                f"{page_content}\n"
            )
            formatted_text.append(formatted_doc)

        result = "\n".join(formatted_text)
        return result

    def hack_inject_callback(docs):
        # https://github.com/langchain-ai/langchain/issues/7290
        for callback in callbacks:
            callback.on_retriever_end(docs, run_id=uuid4())

        return docs

    return (
        {"context": retriever | hack_inject_callback | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
    )
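
To make the context format concrete, the sketch below runs the same formatting logic on two sample documents. The `Doc` dataclass is a hypothetical stand-in for a LangChain `Document`, used only so the snippet runs on its own; the first sample reuses an excerpt and relevance score from the reranker output above.

```python
from dataclasses import dataclass, field

@dataclass
class Doc:
    """Minimal stand-in for a LangChain Document (hypothetical, for illustration)."""
    page_content: str
    metadata: dict = field(default_factory=dict)

def format_docs(docs):
    """Prefix each document with its source and relevance as a whole-number percentage."""
    formatted = []
    for doc in docs:
        source = doc.metadata.get("source", "Unknown Source")
        relevance = int(doc.metadata.get("relevance_score", 0.0) * 100)
        formatted.append(f"[Source '{source}', Relevance {relevance}]\n{doc.page_content}\n")
    return "\n".join(formatted)

docs = [
    Doc("Thundertooth and Seraphina reveled in the joy of parenthood.",
        {"source": "Data/Thundertooth Part 2.docx", "relevance_score": 0.5951397}),
    Doc("Thundertooth"),  # no metadata, so the defaults are used
]

context = format_docs(docs)
print(context)
```

Documents without metadata fall back to "Unknown Source" and a relevance of 0, which makes missing metadata easy to spot in the prompt.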

We create a callback handler to be called after we've retrieved all the documents with their relevance score.

With these, we add up the token counts in ranked order and remove any documents that would take us past our maximum context token length.

This maximises the number of documents we send to the LLM while keeping within our context length bounds.

In [49]:
class RetrievalHandler(BaseCallbackHandler):

    def on_retriever_end(self, documents, **kwargs):
        total_tokens = 0
        context_tokens_used = 0

        # Documents we'll keep as they fit within our target token count
        documents_to_keep = []

        for idx, doc in enumerate(documents):
            total_tokens += doc.metadata["token_count"]

            if total_tokens <= maximum_context_tokens:
                # Good to keep
                context_tokens_used += doc.metadata["token_count"]
                documents_to_keep.append(doc)

        print(f"[ON RETRIEVER END - FINISH - total tokens {total_tokens} across {len(documents)} documents. Kept {context_tokens_used} across {len(documents_to_keep)}]")

        # Modify the contents of the original 'documents' list which we'll then format for insertion into the context within the prompt
        documents.clear()
        documents.extend(documents_to_keep)
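
The trimming logic can be tried in isolation with plain dictionaries. This is a minimal sketch: the helper name `trim_to_budget` and the sample data are hypothetical. It stops at the first document that would exceed the budget, which has the same effect as the handler above as long as token counts are non-negative, because the handler's running total never decreases.

```python
def trim_to_budget(docs, max_tokens):
    """Keep documents in ranked order until the next one would exceed max_tokens."""
    kept, used = [], 0
    for doc in docs:
        if used + doc["token_count"] > max_tokens:
            break  # everything after this point is lower ranked, so stop here
        used += doc["token_count"]
        kept.append(doc)
    return kept, used

# Hypothetical reranked results, best first
ranked = [
    {"id": "a", "token_count": 66},
    {"id": "b", "token_count": 64},
    {"id": "c", "token_count": 327},
    {"id": "d", "token_count": 3},
]

kept, used = trim_to_budget(ranked, max_tokens=200)
print([d["id"] for d in kept], used)
```

Note that the small document "d" is dropped even though it would fit on its own. Like the handler, this keeps a ranked prefix rather than packing in every document that fits.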


In [50]:
# Here are our test questions

TestQuestions = [
    "Summarise the story for me.",
    "Who was the main protagonist?",
    "Did they have any children? If so, what were their names?",
    "Did anything eventful happen?",
    "Who are the main characters?",
    "What do you think happens next in the story?"
]

Ask our questions with our reranking chain.

Here we reiterate that we want to include citations.

In [53]:
# Our Question and Answers for display
qa_pairs = []

# our Retrieval Call Back Handler - so we can intercept the results of the retriever and adjust how many documents are passed through to the LLM
retrieval_callback_handler = RetrievalHandler()

for index, question in enumerate(TestQuestions, start=1):
    question = question.strip() # Clean up

    print(f"\n{index}/{len(TestQuestions)}: {question}")

    rerank_chain = get_rerankchain(retriever=compression_retriever, llm=llm, callbacks=[retrieval_callback_handler])

    # Ask the Retriever and then the LLM the question
    response = rerank_chain.invoke(question)

    # Keep track of question, answer, prompt tokens, and response tokens
    qa_pairs.append((question, response.strip(), llmresult_prompt_token_count, llmresult_response_token_count)) # Add to our output array

    # Uncomment the following line if you want to test just the first question
    # break 


1/6: Summarise the story for me.
[ON RETRIEVER END - FINISH - total tokens 965 across 20 documents. Kept 965 across 20]
 In a futuristic city of advanced technology and flying cars, Thundertooth, a talking dinosaur from the past, arrives through a time portal. The mayor, Grace, invites him to stay in the city, where he becomes a symbol of unity between the past and the future. Thundertooth and his family, including Lumina, Echo, Sapphire, and Ignis, live in harmony with humans, becoming beloved figures and an integral part of the community. They also establish a toy factory that brings dinosaurs and humans together in shared creativity and innovation.

When a meteor threatens the city, Thundertooth's family uses their unique abilities to divert it from its path, saving the city and its inhabitants. However, Ignis becomes consumed by darkness and goes on a rampage, causing chaos in the city. The other siblings confront Ignis in an epic battle, determined to save their brother and the c

In [54]:
# Print out the questions and answers

for index, (question, answer, prompttokens, responsetokens) in enumerate(qa_pairs, start=1):
    print(f"{index}/{len(qa_pairs)} {question}\n\n[Prompt Tokens: {prompttokens}, Response Tokens: {responsetokens}]\n\n{answer}\n\n--------\n")

1/6 Summarise the story for me.

[Prompt Tokens: 1711, Response Tokens: 272]

In a futuristic city of advanced technology and flying cars, Thundertooth, a talking dinosaur from the past, arrives through a time portal. The mayor, Grace, invites him to stay in the city, where he becomes a symbol of unity between the past and the future. Thundertooth and his family, including Lumina, Echo, Sapphire, and Ignis, live in harmony with humans, becoming beloved figures and an integral part of the community. They also establish a toy factory that brings dinosaurs and humans together in shared creativity and innovation.

When a meteor threatens the city, Thundertooth's family uses their unique abilities to divert it from its path, saving the city and its inhabitants. However, Ignis becomes consumed by darkness and goes on a rampage, causing chaos in the city. The other siblings confront Ignis in an epic battle, determined to save their brother and the city. In the end, they manage to reach Ignis,