# Respond to natural language questions using the RAG approach

This notebook contains the steps and code to demonstrate Retrieval Augmented Generation (RAG) with both a local model and watsonx.ai. It introduces code for data retrieval, building and querying a knowledge base, and testing models.

Some familiarity with Python is helpful. This notebook uses Python 3.11.

#### About Retrieval Augmented Generation
Retrieval Augmented Generation (RAG) is a versatile pattern that can unlock a number of use cases requiring factual recall of information, such as querying a knowledge base in natural language.

In its simplest form, RAG consists of three phases:

- Phase 1: Index knowledge base passages (once)
- Phase 2: Retrieve relevant passage(s) from knowledge base (for every user query)
- Phase 3: Generate a response by feeding the retrieved passage(s) into a large language model (for every user query)


![image](images/RAG.png)
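The three phases can be sketched end to end in plain Python before any libraries are involved. This is a toy, assuming simple word overlap as a stand-in for the embeddings and vector search used later; the names `toy_words`, `toy_build_index`, `toy_retrieve` and `toy_make_prompt` and the two sample passages are hypothetical and only illustrate where each phase fits.

```python
import re

def toy_words(text: str) -> set[str]:
    """Lower-cased word set; a crude stand-in for an embedding vector."""
    return set(re.findall(r"[a-z']+", text.lower()))

# Phase 1 (once): index the passages by precomputing their word sets
def toy_build_index(passages: list[str]) -> list[tuple[str, set[str]]]:
    return [(passage, toy_words(passage)) for passage in passages]

# Phase 2 (per query): rank passages by how many words they share with the query
def toy_retrieve(index: list[tuple[str, set[str]]], query: str, k: int = 1) -> list[str]:
    query_words = toy_words(query)
    ranked = sorted(index, key=lambda item: len(query_words & item[1]), reverse=True)
    return [passage for passage, _ in ranked[:k]]

# Phase 3 (per query): put the retrieved passages into the prompt for the LLM
def toy_make_prompt(query: str, passages: list[str]) -> str:
    context = "\n\n".join(passages)
    return f"Search results:\n{context}\n\nQuery: {query}\n\nAnswer: "

toy_passages = [
    "Mr. Dursley was the director of a firm that made drills.",
    "Mrs. Dursley was thin and blonde.",
]
toy_index = toy_build_index(toy_passages)
toy_question = "What was the job of Mr. Dursley?"
toy_top = toy_retrieve(toy_index, toy_question)
toy_prompt = toy_make_prompt(toy_question, toy_top)
print(toy_prompt)
```

In the rest of the notebook, Phase 1 is handled by Chroma, Phase 2 by `collection.query`, and Phase 3 by `build_prompt` plus an LLM call.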

<a id="setup"></a>
##  Set up the environment



**Note:** Please restart the notebook kernel so that it picks up the package versions installed above.
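After restarting, it can help to confirm that every package the notebook imports is actually available. A minimal sketch, assuming the top-level module names below match what you installed (`missing_modules` is a hypothetical helper); it only looks the modules up with `importlib.util.find_spec` and does not import them.

```python
import importlib.util

# Top-level modules imported later in this notebook. PyPI package names can
# differ from module names, e.g. the "dotenv" module comes from "python-dotenv".
REQUIRED_MODULES = [
    "chromadb",
    "pypdf",
    "sentence_transformers",
    "dotenv",
    "ibm_watson_machine_learning",
    "httpx",
]

def missing_modules(names: list[str]) -> list[str]:
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules(REQUIRED_MODULES)
print("Missing modules:", missing or "none")
```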

## Import relevant libraries

Because the knowledge base is built from a document (here, a PDF), the main imports are pypdf, to parse the document, and chromadb, to set up the knowledge base.

In [None]:
import os
import re

from enum import auto
from enum import Enum
from pathlib import Path
from typing import Optional

import chromadb

from chromadb import Collection
from chromadb.utils import embedding_functions

from pypdf import PdfReader, PageObject


## Helper classes and functions

In [None]:
class CollectionStatus(Enum):
    COLLECTION_CREATED = auto()
    COLLECTION_EXISTS = auto()


def ensure_collection(client: chromadb.ClientAPI) -> tuple[CollectionStatus, Collection]:
    demo_collection = "harry_potter"
    try:
        # Return the existing collection so later cells can query it directly
        collection = client.get_collection(name=demo_collection)
        return CollectionStatus.COLLECTION_EXISTS, collection
    except ValueError:  # older chromadb releases raise ValueError for a missing collection; newer ones may use another type
        collection = client.get_or_create_collection(name=demo_collection, metadata={"hnsw:space": "cosine"})
        return CollectionStatus.COLLECTION_CREATED, collection


def clean_text(raw_text: str) -> str:
    cleaned_text = raw_text.replace("\n", " ")
    cleaned_text = re.sub(r"\s+", " ", cleaned_text) #space, tab or line break
    return cleaned_text


def get_chunks(pages: list[PageObject], max_words: int = 150) -> list[tuple[str, int]]:
    text_tokens = [(clean_text(page.extract_text()).split(" "), page.page_number) for page in pages]
    chunks = []

    for idx, (words, page_number) in enumerate(text_tokens):
        for i in range(0, len(words), max_words):
            chunk = words[i:i + max_words]
            # A short trailing chunk is prepended to the next page instead of
            # being stored on its own (except on the last page)
            if len(chunk) < max_words and idx + 1 < len(text_tokens):
                next_page = text_tokens[idx + 1]
                text_tokens[idx + 1] = (chunk + next_page[0], next_page[1])
                continue
            chunk = " ".join(chunk).strip()
            chunk = f'[Page no. {page_number}] "{chunk}"'
            chunks.append((chunk, page_number))

    return chunks


def insert_document(document_path: Path, collection: Collection) -> None:
    document_reader = PdfReader(document_path)
    document_name = document_path.stem.replace(" ", "-").replace("_", "-")
    pages = document_reader.pages

    document_chunks = []
    document_ids = []

    chunks = get_chunks(pages)
    for chunk_index, (chunk, page_number) in enumerate(chunks):
        document_ids.append(f"{document_name}_p{page_number}-{chunk_index}")
        document_chunks.append(chunk)

    collection.add(
        documents=document_chunks,
        ids=document_ids,
    )
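The carry-over rule in `get_chunks` is easiest to see on a tiny input. The sketch below reimplements the same logic on plain `(text, page_number)` pairs instead of pypdf `PageObject`s, so it runs without a PDF; `chunk_pages` is a hypothetical stand-in, not part of the notebook's pipeline.

```python
def chunk_pages(pages: list[tuple[str, int]], max_words: int = 150) -> list[tuple[str, int]]:
    """Same carry-over logic as get_chunks, applied to (text, page_number) pairs."""
    # Collapse whitespace, then split into words (mirrors clean_text + split(" "))
    tokens = [(" ".join(text.split()).split(" "), page_no) for text, page_no in pages]
    chunks = []
    for idx, (words, page_no) in enumerate(tokens):
        for i in range(0, len(words), max_words):
            chunk = words[i:i + max_words]
            # Short tail of a non-final page: prepend it to the next page's words
            if len(chunk) < max_words and idx + 1 < len(tokens):
                next_words, next_page = tokens[idx + 1]
                tokens[idx + 1] = (chunk + next_words, next_page)
                continue
            text = " ".join(chunk)
            chunks.append((f'[Page no. {page_no}] "{text}"', page_no))
    return chunks

for chunk, page_no in chunk_pages([("one two three four five", 0), ("six seven", 1)], max_words=3):
    print(page_no, chunk)
```

With `max_words=3`, the leftover "four five" from page 0 is merged into page 1, so only the final chunk of the last page may be shorter than `max_words`.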


## Phase 1: Ingest data & build the knowledge base

![image](images/Ingest_Data.png)

In [None]:
base_directory = Path("./")
db_directory = base_directory / "db"
files_directory = base_directory / "db_files"

db_directory.mkdir(exist_ok=True)

if not files_directory.exists():
    # Stop here instead of silently creating an empty collection
    raise FileNotFoundError(f"DB files were not copied to {files_directory}! Abort.")

chroma_client = chromadb.PersistentClient(path=str(db_directory))

collection_status, collection = ensure_collection(chroma_client)

if collection_status == CollectionStatus.COLLECTION_EXISTS:
    print("Collection already exists. No new files are loaded.")
else:
    print("Creating collection...")
    for document_path in files_directory.glob("*.pdf"):
        insert_document(document_path, collection)


### _Optional: Check collection_

In [None]:
client = chromadb.PersistentClient(path="./db")
collection = client.get_collection(name="harry_potter")

In [None]:
collection.peek(5)['documents']

### _Excursus 1: Tokenization_

![image](images/Tokenization.png)

Credits to Andreas, Hardy, Alex & Nils :)

In [None]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')

sentence = ["What did you eat for breakfast?"]

tokens = model.tokenize(sentence)

print(f"Number of tokens: {len(tokens['input_ids'][0])}")
tokens['input_ids']
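The token count printed above is usually larger than the number of words: the BERT-style WordPiece tokenizer behind all-MiniLM-L6-v2 adds special `[CLS]` and `[SEP]` tokens and can split a word into subword pieces (continuations are prefixed with `##`). A toy sketch of greedy longest-match-first WordPiece, assuming a tiny hypothetical vocabulary `VOCAB`; the real model's vocabulary is far larger, so it will not necessarily split "breakfast" this way.

```python
import re

# Hypothetical toy vocabulary, for illustration only
VOCAB = {"[CLS]", "[SEP]", "[UNK]", "what", "did", "you", "eat", "for", "break", "##fast", "?"}

def wordpiece(word: str, vocab: set[str]) -> list[str]:
    """Split one word greedily into the longest pieces found in the vocabulary."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        while end > start:
            piece = word[start:end] if start == 0 else "##" + word[start:end]
            if piece in vocab:
                break
            end -= 1
        else:
            return ["[UNK]"]  # no piece matched: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces

def toy_tokenize(text: str, vocab: set[str]) -> list[str]:
    # Lower-case, then split into words and individual punctuation marks
    words = re.findall(r"\w+|[^\w\s]", text.lower())
    tokens = ["[CLS]"]
    for word in words:
        tokens.extend(wordpiece(word, vocab))
    return tokens + ["[SEP]"]

toy_tokens = toy_tokenize("What did you eat for breakfast?", VOCAB)
print(len(toy_tokens), toy_tokens)
```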

### _Excursus 2: Embeddings_

![image](images/Embeddings.png)

Credits to Andreas, Hardy, Alex & Nils :)

In [None]:
embeddings = model.encode(sentence)

print(f"Embedding dimensions: {len(embeddings[0])}")
print(embeddings[0][:100])
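all-MiniLM-L6-v2 maps each text to a 384-dimensional vector, and similar texts end up pointing in similar directions. The collection in Phase 1 was created with `{"hnsw:space": "cosine"}`, so the distances Chroma returns in Phase 2 are cosine distances, i.e. 1 minus the cosine similarity. A small pure-Python sketch on 2-D toy vectors (real embeddings are 384-D); the function names are illustrative, not Chroma's API.

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """cos(angle) between a and b: dot product divided by the product of the norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def cosine_distance(a: list[float], b: list[float]) -> float:
    # 0 = same direction, 1 = orthogonal, 2 = opposite direction
    return 1.0 - cosine_similarity(a, b)

print(cosine_distance([1.0, 0.0], [1.0, 0.0]))   # 0.0
print(cosine_distance([1.0, 0.0], [0.0, 1.0]))   # 1.0
print(cosine_distance([1.0, 0.0], [-1.0, 0.0]))  # 2.0
```

Lower distances therefore mean more relevant chunks in the query results.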

## Phase 2: Retrieve relevant passage(s) from Knowledge Base

![image](images/Retrieve_Data.png)

In [None]:
question = "What was the job of Mr. Dursley?"  # alternatives: "Do Mr. and Mrs. Dursley have a son?", "How old is Mr. Dursley?"
results = collection.query(
    query_texts=[question],
    n_results=2,
)

results['distances'], results['documents']
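`collection.query` returns one inner list per query text, which is why the next step uses `results["documents"][0]` for the single question. One possible refinement is to drop weak matches before building the prompt. A sketch, assuming a hand-made dict with the same nesting as the query output and a hypothetical cut-off `max_distance=0.6` that would need tuning on real questions:

```python
def relevant_chunks(results: dict, max_distance: float) -> list[str]:
    """Keep the first query's documents whose distance is at most max_distance."""
    documents = results["documents"][0]  # index 0: results for the first query text
    distances = results["distances"][0]
    return [doc for doc, dist in zip(documents, distances) if dist <= max_distance]

# Hand-made example with the same shape as collection.query(...) output
sample_results = {
    "ids": [["doc_p0-0", "doc_p3-7"]],
    "distances": [[0.35, 0.82]],
    "documents": [["chunk about Mr. Dursley's job", "unrelated chunk"]],
}
print(relevant_chunks(sample_results, max_distance=0.6))
```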

## Phase 3: Build the prompt, pass it to the LLM & generate a response

![image](images/Generate_Response.png)

In [None]:
def build_prompt(question, topn_chunks: list[str]):
    prompt = "Search results:\n"

    for chunk in topn_chunks:
        prompt += chunk + "\n\n"

    prompt += "Instructions: Compose a comprehensive reply to the query using the search results given. " \
              "If the search results mention multiple subjects " \
              "with the same name, create separate answers for each. Only include information found in the results and " \
              "don't add any additional information. Make sure the answer is correct and don't output false content. " \
              "If the text does not relate to the query, simply state 'Found Nothing'. Ignore outlier " \
              "search results which have nothing to do with the question. Only answer what is asked. The " \
              "answer should be short and concise."

    prompt += f"\n\n\nQuery: {question}\n\nAnswer: "

    return prompt
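`build_prompt` concatenates every retrieved chunk. With `n_results=2` and 150-word chunks the prompt stays small, but with more results the prompt plus the generated answer must still fit the model's context window (4,096 tokens for Llama 2). A minimal sketch of one way to cap the context, assuming word count as a rough proxy for tokens (Excursus 1 shows they differ) and a hypothetical helper `fit_chunks`:

```python
def fit_chunks(chunks: list[str], max_words: int) -> list[str]:
    """Keep chunks in rank order until the next one would exceed the word budget."""
    kept, used = [], 0
    for chunk in chunks:
        n_words = len(chunk.split())
        if used + n_words > max_words:
            break  # chunks are ranked, so stop at the first one that does not fit
        kept.append(chunk)
        used += n_words
    return kept

print(fit_chunks(["a b c", "d e", "f g h i"], max_words=5))
```

The result can be passed to `build_prompt` in place of the full list of retrieved chunks.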

In [None]:
prompt = build_prompt(question, results["documents"][0])
print(prompt)

## 3a. Interacting with the LLM on GPU (using watsonx)

In [None]:
from dotenv import load_dotenv
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

load_dotenv()
api_key = os.getenv("API_KEY", None)
api_url = os.getenv("IBM_CLOUD_URL", None)
project_id = os.getenv("PROJECT_ID", None)

creds = {
    "url": api_url,
    "apikey": api_key,
}

#creds, project_id
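`os.getenv(..., None)` silently returns `None` when a variable is missing from `.env`, and the problem only surfaces later when the model is created. A small sketch that fails early instead, assuming the same variable names as above; `require_env` is a hypothetical helper, not part of python-dotenv.

```python
import os

def require_env(names: list[str]) -> dict[str, str]:
    """Return the requested environment variables, raising if any is unset or empty."""
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: os.environ[name] for name in names}

# Example (after load_dotenv()):
# settings = require_env(["API_KEY", "IBM_CLOUD_URL", "PROJECT_ID"])
```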

In [None]:
params = {
    GenParams.DECODING_METHOD: "sample",
    GenParams.MAX_NEW_TOKENS: 200,
    GenParams.TEMPERATURE: 0.1,
}

model = Model("meta-llama/llama-2-70b-chat", params=params, credentials=creds, project_id=project_id)

In [None]:
print("▌")  # marks the start of the streamed answer
for response in model.generate_text_stream(prompt):
    print(response, end="")  # print every chunk, including the first

## 3b. Interacting with a small & quantized LLM on CPU (llama.cpp)

The model deployed here is llama-2-7b-q8_0 (Llama 2 7B, 8-bit quantized in llama.cpp's Q8_0 format), served by the llama.cpp HTTP server.
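To give an idea of what 8-bit quantization does to the weights, here is a sketch of per-block absmax quantization in the spirit of llama.cpp's Q8_0: blocks of 32 values share one scale `d = max|x| / 127`, and each value is stored as the int8 `round(x / d)`. This is an illustration, not llama.cpp's code; llama.cpp stores `d` as fp16 (32 bytes + 2 bytes per block, about 8.5 bits per weight versus 16 for fp16), whereas this sketch keeps it as a Python float.

```python
import math

BLOCK_SIZE = 32  # values per block; each block shares one scale

def quantize_q8(values: list[float]) -> list[tuple[float, list[int]]]:
    """Absmax int8 quantization per block: q = round(x / d), d = max|x| / 127."""
    blocks = []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(x) for x in block)
        d = amax / 127 if amax > 0 else 1.0  # avoid dividing by zero for all-zero blocks
        blocks.append((d, [round(x / d) for x in block]))
    return blocks

def dequantize_q8(blocks: list[tuple[float, list[int]]]) -> list[float]:
    return [d * q for d, qs in blocks for q in qs]

weights = [math.sin(i) for i in range(70)]  # toy "weights"
blocks = quantize_q8(weights)
restored = dequantize_q8(blocks)
print(max(abs(a - b) for a, b in zip(weights, restored)))  # at most d / 2 per block
```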

In [None]:
import httpx
import json

SERVER = "127.0.0.1"  # server where you deployed llama.cpp, 127.0.0.1 => localhost
PORT = "8080"

json_data = {
    'prompt': prompt,
    'n_predict': 256,
    'stream': True,
}

# timeout=None: generation can pause longer than httpx's default timeout
async with httpx.AsyncClient(timeout=None) as client:
    async with client.stream('POST', f'http://{SERVER}:{PORT}/completion', json=json_data) as response:
        # The server streams events as "data: {...}" lines, separated by blank lines
        async for line in response.aiter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separator lines
            data = json.loads(line[len("data: "):])
            if data['stop'] is False:
                print(data['content'], end="")
            else:
                print('\n\n')
                print(data['timings'])

### Further topics to consider

- Performance Benchmarks?
- Quantization: post-training dynamic (PTDQ) and static (PTSQ) quantization, quantization-aware training (QAT)?
- LangChain?
