# RAG

This notebook:
1. Loads documents
2. Splits the documents into chunks
3. Saves the documents and chunks to the `data` directory

References
* [Contextual Retrieval](https://www.anthropic.com/news/contextual-retrieval)
* [Hallucination Elimination Using Acurai](https://arxiv.org/pdf/2412.05223)
* [s1: Simple test-time scaling](https://arxiv.org/pdf/2501.19393)
* [Implementing Contextual Retrieval in RAG pipeline](https://medium.com/the-ai-forum/implementing-contextual-retrieval-in-rag-pipeline-8f1bc7cbd5e0)

In [None]:
# Rag Setup
%%capture
!pip install -q sentence_transformers
!pip -q install langchain
!pip -q install langchain-qdrant
!pip install langchain_community
!pip install --upgrade --quiet chromadb bs4 qdrant-client
!pip install langchainhub
!pip install -U langchain-huggingface
!pip install --upgrade --quiet  pymupdf

### Contextual Retrieval Installations
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf==3.20.3 datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth
!pip install transformers

### Contextual BM25
!pip install elasticsearch
!pip install rank_bm25
!pip install faiss-cpu
!pip install flashrank

In [None]:
import torch
import os
import json
import pickle
import numpy as np
import threading
import time, locale
from google.colab import drive
from pprint import pprint
from datetime import datetime
from typing import List, Dict, Any
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.utils.math import cosine_similarity
from langchain.schema import Document

from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import PyMuPDFLoader

### Contextual Retrieval Imports
from tqdm import tqdm
from unsloth import FastLanguageModel
from transformers import TextStreamer
from unsloth.chat_templates import get_chat_template

### BM25
import hashlib
import os
import getpass
from typing import List, Tuple
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from rank_bm25 import BM25Okapi
from langchain.retrievers import ContextualCompressionRetriever,BM25Retriever,EnsembleRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.embeddings import HuggingFaceEmbeddings

import flashrank

### TODO: Parallelize inferencing
# from concurrent.futures import ThreadPoolExecutor, as_completed

# from langchain_openai import ChatOpenAI
# from langchain_groq import ChatGroq
# from transformers import AutoTokenizer , AutoModelForCausalLM
# from transformers import pipeline, BitsAndBytesConfig
# from langchain_huggingface import HuggingFacePipeline
# from langchain.llms import HuggingFacePipeline
# from langchain import PromptTemplate, LLMChain
# from langchain_core.prompts import ChatPromptTemplate
# from langchain_text_splitters import CharacterTextSplitter
# from langchain_core.output_parsers import StrOutputParser
# from langchain import hub
# from langchain_community.document_loaders import WebBaseLoader
# from langchain_community.vectorstores import Qdrant
# from langchain_core.output_parsers import StrOutputParser
# from langchain_core.runnables import RunnablePassthrough

In [None]:
def _utf8_preferred_encoding():
  """Stand-in for locale.getpreferredencoding that always answers "UTF-8"."""
  return "UTF-8"


# Force every locale lookup in this kernel to report UTF-8.
locale.getpreferredencoding = _utf8_preferred_encoding
# Select the pure-Python protobuf implementation.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")

In [None]:
##### Hyper Parameters
# Embedding checkpoints (another candidate noted by the author: gte-Qwen2-1.5B-instruct)
embedding_base, embedding_bge = "multi-qa-mpnet-base-dot-v1", "BAAI/bge-large-en-v1.5"

# Language model loaded later through unsloth.
# model_name = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
model_name = "unsloth/Meta-Llama-3.1-8B-Instruct"

# Text splitter settings: size of each chunk and overlap between neighbours.
chunk_size, chunk_overlap = 256, 40

In [None]:
# Embedding Model
%%capture
base_embeddings = HuggingFaceEmbeddings(model_name=embedding_bge)

In [None]:
# Folder
drive.mount('/content/drive')
folder_location = '/content/drive/MyDrive/capstone/RAG_items'

pdf_file_path = 'pdfs_to_ingest.txt'
extra_pdf = 'OSW Self-Care Workbook.pdf' # TODO: still need to ingest

# Everything in the RAG folder except the URL list and the not-yet-ingested PDF.
# Filtering (instead of list.remove) keeps the cell from raising ValueError when
# either entry is absent from the folder.
list_of_documents = [
    name for name in os.listdir(folder_location)
    if name not in (pdf_file_path, extra_pdf)
]

# Output directory for pickled documents/chunks; exist_ok makes the cell re-runnable.
save_file_location = os.path.join(folder_location, 'data')
os.makedirs(save_file_location, exist_ok=True)

### TXT location
txt_files_to_ingest_path = os.path.join(folder_location, 'txt_files_to_ingest')
txt_files_to_ingest = os.listdir(txt_files_to_ingest_path)

### PDFs directory location for the urls
pdf_files_to_ingest_path = os.path.join(folder_location, pdf_file_path)
with open(pdf_files_to_ingest_path, 'r') as f:
    # One URL per line. Blank lines are skipped and duplicates removed while
    # keeping first-seen order, so document numbering is reproducible between runs
    # (a set would shuffle the order on every kernel start).
    pdf_urls = list(dict.fromkeys(line.strip() for line in f if line.strip()))

In [None]:
# Display the de-duplicated PDF URLs read from pdfs_to_ingest.txt.
pdf_urls

In [None]:
class GetDocumentLoader:
  """Loads text files and PDFs as documents, splits them into chunks and pickles the result.

  Every loaded document receives a sequential ``doc_num`` in its metadata. After
  chunking, every chunk carries ``chunk_num`` (its position among all chunks) and
  ``chunk_id`` (``doc_<doc_num>_chunk_<chunk_num>``).
  """

  # How many leading lines of a .txt file are scanned for "key: value" metadata.
  HEADER_SCAN_LINES = 6

  def __init__(self):
    self.text_splitter = None     # created by set_text_splitter or restored by load_from_file
    self.global_doc_number = 0    # next doc_num to hand out
    self.documents = []
    self.chunks = []

  def get_text_splitter(self):
    """Return the current text splitter, or None if none has been set or loaded."""
    return self.text_splitter

  def get_global_doc_number(self):
    """Return the doc_num that the next loaded document will receive."""
    return self.global_doc_number

  def get_documents(self):
    """Return the list of loaded documents (the live list, not a copy)."""
    return self.documents

  def get_chunks(self):
    """Return the list of chunks (the live list, not a copy)."""
    return self.chunks

  def set_text_splitter(self, chunk_size, chunk_overlap):
    """Replace the splitter with a RecursiveCharacterTextSplitter using the given sizes."""
    self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

  def load_from_file(self, path_to_file):
    """Restore documents, chunks, doc counter and splitter from a pickle written by save_to_file.

    SECURITY: pickle.load can execute arbitrary code from the file; only open
    pickles that were produced by this notebook.
    """
    with open(path_to_file, 'rb') as f:
      data = pickle.load(f)
    self.documents = data['documents']
    self.chunks = data['chunks']
    self.global_doc_number = data['global_doc_number']
    self.text_splitter = data['text_splitter']
    print(f'Loaded {len(self.documents)} documents')
    print(f'Loaded {len(self.chunks)} chunks')

  def save_to_file(self, path_to_file):
    """Pickle documents, chunks, doc counter and splitter to ``path_to_file``."""
    data = {
        'documents': self.documents,
        'chunks': self.chunks,
        'global_doc_number': self.global_doc_number,
        'text_splitter': self.text_splitter
    }
    with open(path_to_file, 'wb') as f:
      pickle.dump(data, f)
    print(f"Saved {len(self.documents)} documents")
    print(f"Saved {len(self.chunks)} chunks")

  def data_txt_separator(self, file_path):
    """Read one text file and return it as a Document with header metadata.

    Lines among the first HEADER_SCAN_LINES that contain ":" are parsed as
    ``key: value`` metadata. The page content is everything after the last such
    line. ``doc_num`` is always added to the metadata and the global counter is
    advanced by one.

    Args:
      file_path: path to a UTF-8 text file (a leading BOM is tolerated).

    Returns:
      Document with the stripped body as ``page_content``.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
      lines = f.readlines()

    metadata = {}
    header_end = 0  # index of the first line that belongs to the body
    for idx, line in enumerate(lines[:self.HEADER_SCAN_LINES]):
      if ":" in line:
        key, value = line.strip().split(":", 1)
        metadata[key.strip()] = value.strip()
        header_end = idx + 1

    # Set doc_num unconditionally: chunking reads metadata['doc_num'] for every
    # document, including files that have no header lines at all.
    metadata['doc_num'] = self.get_global_doc_number()
    self.global_doc_number += 1

    # Slice by header position, not by len(metadata): the dict also holds doc_num
    # (and collapses duplicate keys), so its size is not the header line count.
    content = "".join(lines[header_end:]).strip()
    return Document(page_content=content, metadata=metadata)

  def load_text_documents(self, directory_path):
    """Load every ``.txt`` file in ``directory_path`` into ``self.documents``.

    Files are processed in sorted name order so doc_num assignment is the same on
    every run; other directory entries are ignored.
    """
    txt_files = sorted(name for name in os.listdir(directory_path) if name.endswith('.txt'))
    # The bar total counts only the files that are actually loaded.
    with tqdm(total=len(txt_files), desc="Processing files") as pbar:
      for filename in txt_files:
        file_path = os.path.join(directory_path, filename)
        self.documents.append(self.data_txt_separator(file_path))
        pbar.update(1)
    print(f'Complete Loading of : {self.get_global_doc_number()}')

  def load_pdf_documents(self, url_paths: list) -> None:
    """Load each PDF in ``url_paths`` with PyMuPDFLoader and append every page as a document.

    Each page gets its own ``doc_num``. Nothing is returned; the pages are added
    to ``self.documents``.
    """
    with tqdm(total=len(url_paths), desc="Processing URLs") as pbar:
      for url_path in url_paths:
        pages = PyMuPDFLoader(url_path).load()
        for page in pages:
          page.metadata['doc_num'] = self.get_global_doc_number()
          self.documents.append(page)
          self.global_doc_number += 1
        # Advance once per URL so the bar matches total=len(url_paths).
        pbar.update(1)
    print(f'Global doc number after file: {self.global_doc_number}')

  def index_splitter_doc_chunks(self, chunk_size, chunk_overlap):
    """Split all loaded documents into chunks and tag each chunk with its ids.

    Replaces ``self.chunks`` entirely, so metadata previously attached to old
    chunk objects is not carried over.
    """
    self.set_text_splitter(chunk_size, chunk_overlap)
    splits = self.text_splitter.split_documents(self.documents)
    for idx, split in enumerate(splits):
      split.metadata['chunk_num'] = idx
      split.metadata['chunk_id'] = f"doc_{split.metadata['doc_num']}_chunk_{idx}"
    self.chunks = splits
    print(f'number of splits/chunks: {len(self.chunks)}')

In [None]:
# Loader instance shared by the remaining cells.
dl = GetDocumentLoader()
# Today's date as YYYY-MM-DD; used in the names of the pickle files written below.
date_now = f"{datetime.now():%Y-%m-%d}"

In [None]:
# ### Load documents and chunks from file
# if os.path.exists(os.path.join(save_file_location, 'data--contextual-retrieval-2025-02-11-28.pkl')):
#   dl.load_from_file(os.path.join(save_file_location, 'data--contextual-retrieval-2025-02-11-28.pkl'))
# else:
#   print('file does not exists')

In [None]:
# ### Load documents and chunks
# print('Start the loading process')
# dl.load_text_documents(txt_files_to_ingest_path)
# dl.get_global_doc_number()
# print(f"Amount of Text Documents: {len(dl.get_documents())}")
# dl.load_pdf_documents(pdf_urls)
# print(f"Amount of PDF Documents: {len(dl.get_documents())}")
# print(f"Total Documents: {dl.get_global_doc_number()}")

# # split documents up into chunks and index each chunk
# print('Start chunking process')
# dl.index_splitter_doc_chunks(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

In [None]:
# if not os.path.exists(os.path.join(save_file_location, f'data-doc-chunks-{date_now}.pkl')):
#   dl.save_to_file(os.path.join(save_file_location, f'data-doc-chunks-{date_now}.pkl'))

In [None]:
### Load new txt in
# Google Drive folder holding the batch of text files added on 2025-02-14.
text_2025_02_14 = "/content/drive/MyDrive/capstone/RAG_items/2025_02_14_text"

In [None]:
# Append every text file from the new batch folder to the loader's documents.
dl.load_text_documents(text_2025_02_14)

In [None]:
### Load new pdfs in
# Each page of the PDF is appended as its own document with its own doc_num.
dl.load_pdf_documents(['https://careerdevelopment.princeton.edu/sites/g/files/toruqf1041/files/documents/networking_guide-oct._2020.pdf'])

In [None]:
# Spot-check the metadata of one document. NOTE(review): index 29 is hard-coded
# and raises IndexError when fewer than 30 documents are loaded — confirm it
# still points at the intended document.
pprint(dl.get_documents()[29].metadata)

In [None]:
### Load new documents into chunks
print('Start chunking process')
# Chunk only when the loader has no chunks yet. On a fresh kernel this produces
# the chunks that every later cell needs; when chunks already exist (for example
# restored from a pickle) they are left alone, because re-chunking would replace
# the chunk objects together with any metadata already attached to them.
if not dl.get_chunks():
  dl.index_splitter_doc_chunks(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
len(dl.get_chunks())


In [None]:
### Save new data into RAG
# Write today's documents/chunks pickle unless a file with that name already exists.
doc_chunks_pickle_path = os.path.join(save_file_location, f'data-doc-chunks-{date_now}.pkl')
if not os.path.exists(doc_chunks_pickle_path):
  dl.save_to_file(doc_chunks_pickle_path)

In [None]:
# dl.get_chunks()[56].metadata

In [None]:
### Load model
max_seq_length = 2048
dtype = None
load_in_4bit = True
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
print(device)

## Loading Model
# Alternative checkpoints tried earlier:
#   "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
#   "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
model_load_options = dict(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
model, tokenizer = FastLanguageModel.from_pretrained(**model_load_options)
### Enable Inference Optimization
FastLanguageModel.for_inference(model)

In [None]:
### Contextual Retrieval
# Prompt template: the full document, then the chunk, then an open <answer> tag
# that the model is expected to complete with a short situating context.
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>

Here is the chunk we want to situate within the whole document:
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.

<answer>
"""

def situate_context(doc: str, chunk: str) -> str:
  """Ask the language model for a short context that situates ``chunk`` within ``doc``.

  Args:
    doc: full text of the source document.
    chunk: text of the chunk to situate.

  Returns:
    The decoded generation output with special tokens removed. Callers extract
    the answer by looking for the ``<answer>`` marker, so the text is returned
    unmodified.

  NOTE(review): the whole document is placed in the prompt without truncation;
  confirm long documents fit the model's context window.
  """
  prompt = DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc, chunk_content=chunk)
  # Place the inputs on the device that holds the model weights.
  inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
  # Greedy decoding. `temperature` is intentionally not passed: it has no effect
  # when do_sample=False and only produces a warning on every call.
  outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
  return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

In [None]:
### Generate the situating context for one chunk and store it on the chunk
def get_contextual_content(doc: "Document", chunk: "Document") -> "Document":
  """Attach a model-generated context string to ``chunk`` and return the chunk.

  Calls situate_context with the document and chunk text, keeps the text that
  follows the first ``<answer>`` marker up to the next ``<answer>`` or
  ``</answer>`` tag, and stores it (stripped) in
  ``chunk.metadata['contextualized_content']``.

  Args:
    doc: source document; only ``page_content`` is read.
    chunk: chunk to annotate; its ``metadata`` dict is modified in place.

  Returns:
    The same ``chunk`` object that was passed in.
  """
  result = situate_context(doc.page_content, chunk.page_content)
  _, marker, completion = result.partition('<answer>')
  if not marker:
    # No marker in the text: treat all of it as the completion rather than
    # failing with IndexError.
    completion = result
  # Cut at whichever tag comes first after the opening marker.
  for stop_tag in ('<answer>', '</answer>'):
    completion = completion.split(stop_tag, 1)[0]
  chunk.metadata['contextualized_content'] = completion.strip()
  return chunk

In [None]:
# chunk_count = 0
# for t_chunk in dl.get_chunks():
#   if t_chunk.metadata['doc_num'] == dl.get_documents()[7].metadata['doc_num']:
#     chunk_count += 1
# print(chunk_count)

# chunk_count = len([t_chunk for t_chunk in dl.get_chunks() if t_chunk.metadata['doc_num'] == dl.get_documents()[7].metadata['doc_num']])
# print(chunk_count)

In [None]:
def process_doc_chunk_for_contextual_retrieval(docs: List[Document], chunks: List[Document]):
  """Contextualize, one document at a time, every chunk that belongs to that document.

  A chunk belongs to a document when both carry the same ``metadata['doc_num']``.
  Chunks are processed sequentially (no threading) and are annotated in place by
  get_contextual_content.
  """
  for doc in docs:
    # Collect this document's chunks once; the list feeds both the bar total and the loop.
    doc_chunks = [
        candidate for candidate in chunks
        if candidate.metadata['doc_num'] == doc.metadata['doc_num']
    ]
    with tqdm(total=len(doc_chunks), desc="Processing chunks") as pbar:
      print(f"Processing doc: {doc.metadata['doc_num']}")
      for doc_chunk in doc_chunks:
        get_contextual_content(doc, doc_chunk)
        pbar.update(1)

### TODO: Implement parallel threading
# parallel_threads = 4
# def process_doc_chunk_for_contextual_retrieval(docs: List[Document], chunks: List[Document], parallel_threads):
#   with ThreadPoolExecutor(max_workers=parallel_threads) as executor:
#     futures = []
#     results = []
#     for doc in docs[:1]:
#       print(f'Document metadata: {doc.metadata}')
#       for chunk in chunks[:2]:
#         print(f'Chunk metadata: {chunk.metadata}')
#         if chunk.metadata['doc_num'] == doc.metadata['doc_num']:
#           futures.append(executor.submit(get_contextual_content, doc, chunk))
#     for future in tqdm(as_completed(futures), total=len(futures), desc="Processing chunks"):
#       results.append(future.result())
#     return results

In [None]:
# Index of the first document to contextualize. NOTE(review): 29 is hard-coded,
# presumably the first newly added document — confirm before re-running.
document_number = 29
# Exclusive upper bound for the contextual-retrieval loop below.
max_doc_number = len(dl.get_documents())

In [None]:
# Display the number of loaded documents.
max_doc_number

In [None]:
# Peek at the most recently loaded document.
dl.get_documents()[-1]

In [None]:
# Show the container type returned by get_documents().
type(dl.get_documents())

In [None]:
# Spot-check one chunk. NOTE(review): index 606 is hard-coded and raises
# IndexError when fewer than 607 chunks exist.
dl.get_chunks()[606]

In [None]:
### Run contextual retrieval
## Checkpointed run: one document (with all of its chunks) per iteration,
## followed by a pickle of the loader state so an interrupted run loses at most
## the document in progress.
for i in range(document_number, max_doc_number):
  process_doc_chunk_for_contextual_retrieval(dl.get_documents()[i:i+1], dl.get_chunks())
  checkpoint_path = os.path.join(save_file_location, f'data--contextual-retrieval-{date_now}-{document_number}.pkl')
  # An existing checkpoint with the same name is left untouched.
  if not os.path.exists(checkpoint_path):
    dl.save_to_file(checkpoint_path)
  # document_number is advanced with the loop, so re-running the cell resumes
  # at the first document that has not completed.
  document_number += 1
  print(f'Completed document {i}...saving')

  # process_doc_chunk_for_contextual_retrieval(dl.get_documents(), dl.get_chunks())

In [None]:
# Final snapshot after all documents were contextualized; skipped when a file
# with today's name already exists.
contextual_pickle_path = os.path.join(save_file_location, f'data--contextual-retrieval-{date_now}.pkl')
if not os.path.exists(contextual_pickle_path):
  dl.save_to_file(contextual_pickle_path)

In [None]:
# SentenceTransformer("hkunlp/instructor-large")

In [None]:
### Initiate vectorstore
# In-memory Qdrant instance: contents are lost when the kernel stops.
qdrant_collection = "mental_health_db"
client = QdrantClient(":memory:")

# NOTE(review): size=1024 presumably matches the output dimension of the
# embedding model behind base_embeddings — confirm when changing the model.
collection_vectors = VectorParams(size = 1024, distance = Distance.COSINE)
client.create_collection(collection_name = qdrant_collection, vectors_config = collection_vectors)

vector_store = QdrantVectorStore(
    client = client,
    collection_name = qdrant_collection,
    embedding = base_embeddings,
)

In [None]:
### Add
# vector_store.add_documents(dl.get_chunks())

In [None]:
# Smoke-test query against the Qdrant-backed vector store.
query = "What is a financial budget?"

# Four nearest matches together with their scores. NOTE(review): the
# add_documents call in the preceding cell is commented out, so the collection
# is presumably empty at this point — confirm.
results = vector_store.similarity_search_with_score(query, k=4)

In [None]:
class BM25:
  """Helpers that build the hybrid retrieval stack and answer questions with the LLM.

  Despite the name, the class covers more than BM25: it builds a FAISS vector
  store, a BM25Okapi index, a BM25 retriever, FlashRank-reranked retrievers and
  an ensemble of dense and lexical retrieval, and it can generate an answer from
  retrieved chunks.

  Annotations that name third-party types are written as strings so they are
  not evaluated when the class is defined.
  """

  def __init__(self, text_splitter, base_embeddings, model, tokenizer):
    self.text_splitter = text_splitter
    self.base_embeddings = base_embeddings
    self.model = model
    self.tokenizer = tokenizer

  def get_text_splitter(self):
    """Return the text splitter passed at construction."""
    return self.text_splitter

  def get_base_embeddings(self):
    """Return the embedding model passed at construction."""
    return self.base_embeddings

  def get_model(self):
    """Return the language model passed at construction."""
    return self.model

  def create_vectorstores(self, chunks: "List[Document]") -> "FAISS":
    """
    Create a FAISS vector store from the given chunks using the base embeddings
    """
    return FAISS.from_documents(chunks, self.base_embeddings)

  def create_bm25_index(self, chunks: "List[Document]") -> "BM25Okapi":
    """
    Create a BM25 index for the given chunks (whitespace-tokenized page content)
    """
    tokenized_chunks = [chunk.page_content.split() for chunk in chunks]
    return BM25Okapi(tokenized_chunks)

  @staticmethod
  def _build_reranker():
    """Return a ready-to-use FlashrankRerank compressor."""
    # model_rebuild() is a classmethod that finalizes the pydantic model and does
    # NOT return an instance, so it is called on the class first and the
    # compressor is instantiated afterwards. Ranker is imported into this scope
    # so the rebuild can resolve the class's reference to it.
    from flashrank import Ranker  # noqa: F401
    FlashrankRerank.model_rebuild()
    return FlashrankRerank()

  def create_flashrank_index(self, vectorstore):
    """
    Create a retriever that takes the top 10 vector-store hits and reranks them with FlashRank
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k":10})
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=self._build_reranker(), base_retriever=retriever)
    return compression_retriever

  def create_bm25_retriever(self, chunks: "List[Document]") -> "BM25Retriever":
    """
    Create a BM25 retriever for the given chunks
    """
    bm25_retriever = BM25Retriever.from_documents(chunks)
    return bm25_retriever

  def create_ensemble_retriever_reranker(self, vectorstore, bm25_retriever) -> "EnsembleRetriever":
    """
    Create an ensemble (dense + BM25) retriever whose results are de-duplicated and reranked

    Note: sets ``bm25_retriever.k = 5`` on the retriever that is passed in. The
    returned object is the ContextualCompressionRetriever that wraps the ensemble.
    """
    retriever_vs = vectorstore.as_retriever(search_kwargs={"k":10})
    bm25_retriever.k = 5
    ensemble_retriever = EnsembleRetriever(
        retrievers=[retriever_vs, bm25_retriever],
        weights=[0.5, 0.5]
    )
    # The filter is a pydantic model, so its field must be passed by keyword.
    redundant_filter = EmbeddingsRedundantFilter(embeddings=self.base_embeddings)
    reranker = self._build_reranker()
    pipeline_compressor = DocumentCompressorPipeline(
        transformers=[redundant_filter, reranker])
    compression_pipeline = ContextualCompressionRetriever(
        base_compressor=pipeline_compressor, base_retriever=ensemble_retriever)
    return compression_pipeline

  @staticmethod
  def generate_cache_key(document: str) -> str:
    """
    Generate a cache key (MD5 hex digest) for a document
    """
    return hashlib.md5(document.encode()).hexdigest()

  def generate_answer(self, query: str, relevant_chunks: List[str]) -> str:
    """
    Generate an answer for the given query and relevant chunks

    Returns the decoded generation output with special tokens removed.
    NOTE(review): the output is not post-processed, so it presumably still
    contains the prompt text ahead of the answer — confirm with the caller.
    """
    prompt = """
    Based on the following information, please provide a concise and accurate answer to the question.
    If the information is not sufficient to answer the question, say so.

    Question: {query}

    Relevant information:
    {chunks}

    Answer:
    """
    full_prompt = prompt.format(query=query, chunks="\n\n".join(relevant_chunks))
    # Use the device of the model held by this instance instead of a notebook global.
    inputs = self.tokenizer([full_prompt], return_tensors="pt").to(self.model.device)
    # Greedy decoding; `temperature` is omitted because it is ignored when
    # do_sample=False and only triggers a warning.
    outputs = self.model.generate(**inputs, max_new_tokens=1024, do_sample=False)
    return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

In [None]:
# Bundle the splitter, embeddings, LLM and tokenizer into the retrieval helper.
# NOTE(review): dl.get_text_splitter() returns None unless a splitter was set by
# chunking or restored from a pickle.
bm = BM25(dl.get_text_splitter(), base_embeddings, model, tokenizer)

In [None]:
# FAISS vector store built from the chunks with the base embeddings.
contextualized_vectorstore = bm.create_vectorstores(dl.get_chunks())

In [None]:
# BM25Okapi index over the whitespace-tokenized page content of the chunks.
contextual_bm25_index = bm.create_bm25_index(dl.get_chunks())

In [None]:
# Retriever that pulls k=10 candidates from the vector store and passes them
# through the FlashRank compressor.
contextualized_reranker = bm.create_flashrank_index(contextualized_vectorstore)

In [None]:
### Vector Store setup
# client = QdrantClient(":memory:")
# client.create_collection(
#     collection_name="mental_health_db",
#     vectors_config=VectorParams(size=768, distance=Distance.COSINE),
# )
# vector_store = QdrantVectorStore(
#     client = client,
#     collection_name="mental_health_db",
#     embedding=base_embeddings,
# )

In [None]:
# vector_store.add_documents(txt_splits)
# vector_store.add_documents(pdf_splits)

In [None]:
# query = "what is a financial budget?"

# results = vector_store.similarity_search_with_score(query, k=4)

In [None]:
# for res in results:
#   print(res)
#   print('\n')

In [None]:
# for res in results:
#   print(res[1])

In [None]:
### Simple Test-Time Scaling Technique

## Apply Test-Time Scaling (Budget Forcing)
def generate_with_budget_forcing(model, tokenizer, prompt, max_thinking_tokens=30, extra_tokens=10):
  """Generate text and cut it off at a fixed budget, marking the cut with ``<end-of-thinking>``.

  Args:
    model: language model exposing ``generate`` and ``device``.
    tokenizer: tokenizer matching ``model``.
    prompt: input text.
    max_thinking_tokens: budget applied to the decoded output.
    extra_tokens: additional generation head-room beyond the budget.

  Returns:
    The decoded output unchanged when it fits the budget; otherwise its first
    ``max_thinking_tokens`` whitespace-separated words joined by single spaces
    and followed by ``<end-of-thinking>``.

  Note: the budget is checked on whitespace-separated words of the decoded text
  (special tokens included), not on model tokens.
  """
  # Place the inputs on the device of the model that was passed in.
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  generate_ids = model.generate(**inputs, max_new_tokens=max_thinking_tokens + extra_tokens)
  output_text = tokenizer.decode(generate_ids[0], skip_special_tokens=False)

  words = output_text.split()
  # Compare against the parameter name as declared in the signature.
  if len(words) > max_thinking_tokens:
    truncated = words[:max_thinking_tokens] + ["<end-of-thinking>"]
    return " ".join(truncated)
  return output_text


In [None]:
## Run Inference with Adapted Generation Function
# Try the budget-forced generation on one question, using the default budget.
prompt = "what is a financial budget?"
output = generate_with_budget_forcing(model, tokenizer, prompt)
print("Generated Output:", output, sep="\n")