### Retrieval-Augmented Generation
`Query -> Search a Database -> Relevant Documents -> Send to LLM -> Contextually Relevant Answer` <br/>

Most of the complexity comes from design decisions about:
- Chunking.
- Databases.
- Preprocessing query.
- Postprocessing results.
- Semantic vs. keyword search.
- Hypothetical searches.
- Multi-hop retrieval.
- Agentic retrieval.

#### Multi-Hop Retrieval
`Question -> LM <-> Hybrid Search from DB` <br/>
`Context -> LM <-> DB` <br/>
`Context -> LM -> Answer` <br/>

#### Hybrid HyDE Search
`Question -> HyDE LM -> (Semantic Query -> Embedding Search) + (BM-25 Query -> BM-25 Search) -> Reciprocal Rank Fusion`

### Setup Jokes DB
<a href="https://www.kaggle.com/datasets/abhinavmoudgil95/short-jokes">Dataset link.</a>

In [1]:
import torch
import numpy as np
from transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

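def embed_texts(texts):
    # Tokenize as one padded batch; truncate anything longer than the model's input limit
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Use the hidden state of the first ([CLS]) token as the text embedding
    embeddings = model_output.last_hidden_state[:, 0, :].numpy()
    # L2-normalize so that a dot product equals cosine similarity
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    return embeddings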


In [2]:
import pandas as pd
from tqdm import tqdm
from pathlib import Path

if not Path('embeddings.npy').exists():
    data = pd.read_csv('shortjokes.csv')
    jokes = data['Joke'].values
    jokes = jokes[:5000]
    
    batch_size = 512
    all_embeddings = []
    for i in tqdm(range(0, len(jokes), batch_size), desc='Generating embeddings'):
        batch_texts = jokes[i:i+batch_size].tolist()
        batch_embeddings = embed_texts(batch_texts)
        all_embeddings.append(batch_embeddings)

    embeddings = np.concatenate(all_embeddings, axis=0)
    print(f'Total embeddings: {len(embeddings)}')
    np.save('embeddings.npy', embeddings)
    with open('jokes.txt', 'w') as f:
        for joke in jokes:
            f.write(joke+'\n')


### Basic Nearest-Neighbors RAG

In [3]:
class BasicEmbeddingsRAG:
    def __init__(self, texts, embeddings):
        self.texts = texts
        self.embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    
    def get_nearest(self, query: str, k: int = 10):
        query_emb = embed_texts([query])
        query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
        
        # Cosine similarity: both sides are unit-normalized,
        # so a plain dot product is sufficient
        similarity = np.dot(query_emb, self.embeddings.T).flatten()
        
        # argpartition selects the k largest scores without ordering them;
        # sort just those k indices by descending similarity
        topk_idxs = np.argpartition(similarity, -k)[-k:]
        topk_idxs = sorted(topk_idxs, key=lambda x: similarity[x],
                           reverse=True)
        
        return [self.texts[i] for i in topk_idxs]
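
The comment in `get_nearest` relies on the fact that cosine similarity reduces to a dot product once both vectors have unit length. A minimal, dependency-free sketch of that identity follows; the vectors are toy values made up for illustration, not real DistilBERT embeddings, and the helper names are hypothetical.

```python
import math

def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec]

def cosine_similarity(a, b):
    """Textbook cosine similarity: dot product divided by both norms."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

# Toy vectors (assumed values, chosen so the norms are 5 and 3)
query_vec = [3.0, 4.0, 0.0]
doc_vec = [1.0, 2.0, 2.0]

full = cosine_similarity(query_vec, doc_vec)
shortcut = dot(l2_normalize(query_vec), l2_normalize(doc_vec))

# Both routes give the same value
assert math.isclose(full, shortcut)
```

Normalizing the stored embeddings once therefore moves the division out of the query path, which is why the class only needs `np.dot` at search time.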

In [4]:
import time

query = 'Laugh'
with open('jokes.txt', 'r') as f:
    jokes = [l.strip() for l in f.readlines()]
embs = np.load('embeddings.npy')

basic_rag = BasicEmbeddingsRAG(jokes, embs)

start = time.time()
nearest = basic_rag.get_nearest(query, k=10)
end = time.time()

print(f'Time: {end - start}')
print(nearest)

Time: 0.042990922927856445
["The best joke you'll never hear", 'Meet the parents', 'Hire The Pretty Blonde', 'Just one time I wanna see The Bachelor get a cold sore', 'What do you call a bald porcupine? Pointless!', 'pull my upvote', "My life That's the joke.", 'What do you call corn with a sense of humor? Laughing stalk', 'What do you call a bald porcupine? Pointless.', 'Velcro. What a rip off!']


### Approximate Nearest-Neighbors

In [5]:
from annoy import AnnoyIndex

class AnnoyRAG:
    def __init__(self, texts, embeddings, n_trees=10):
        self.texts = texts
        self.emb_dim = embeddings.shape[1]
        self.index = AnnoyIndex(self.emb_dim, 'angular')
        
        normalized_embs = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        for i, vec in enumerate(normalized_embs):
            self.index.add_item(i, vec)
        self.index.build(n_trees)
    
    def get_nearest(self, query: str, k: int = 10):
        query_emb = embed_texts([query])
        query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
        
        nearest_idxs = self.index.get_nns_by_vector(query_emb[0], k)
        return [self.texts[i] for i in nearest_idxs]

In [10]:
query = 'AI is rogue'

basic_rag = BasicEmbeddingsRAG(jokes, embs)
annoy_rag = AnnoyRAG(jokes, embs)

start = time.time()
nearest_basic = basic_rag.get_nearest(query, k=10)
end = time.time()
print(f'Time for Basic: {end - start}')

start = time.time()
nearest_annoy = annoy_rag.get_nearest(query, k=10)
end = time.time()
print(f'Time for Annoy: {end - start}')

print(nearest_basic)
print(nearest_annoy)

Time for Basic: 0.018012285232543945
Time for Annoy: 0.01598978042602539
['What comes before OP? QWERTYUI', 'Be alert! The world needs more lerts.', '"Blinding Nemo" #BPMovies', 'How do you call a beautiful feminist? An oxymoron', 'Who is the king of the pencil case? The Ruler', 'Political Joke The Economy', '"I see people." - The Fifth Sense', "What comes after America? Bmerica. I'll see myself out", 'Genderfluid? I just call that semen', 'Meet the parents']
['Be alert! The world needs more lerts.', '"Blinding Nemo" #BPMovies', 'How do you call a beautiful feminist? An oxymoron', 'Political Joke The Economy', '"I see people." - The Fifth Sense', 'Genderfluid? I just call that semen', 'Meet the parents', 'Velcro. What a rip off!', 'What is it that is yours , but others use it more than you ? Your name', "What do you call someone incapable of eating people? A can't-ibal"]
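
In the run above, 7 of the 10 exact nearest neighbors also appear in the Annoy results. One way to quantify that trade-off is recall@k against the exact search. The sketch below uses placeholder integer ids arranged to mirror that overlap; `recall_at_k` is a hypothetical helper, and in the notebook you would pass `nearest_basic` and `nearest_annoy` instead.

```python
def recall_at_k(exact, approx, k):
    """Fraction of the exact top-k that the approximate top-k also returned."""
    exact_top = set(exact[:k])
    approx_top = set(approx[:k])
    if not exact_top:
        return 0.0
    return len(exact_top & approx_top) / len(exact_top)

# Placeholder ids (assumption): 0..9 stand for the exact results,
# 10..12 for documents that only the approximate index returned.
exact_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
approx_ids = [1, 2, 3, 5, 6, 8, 9, 10, 11, 12]

overlap = recall_at_k(exact_ids, approx_ids, k=10)
print(f'recall@10 = {overlap:.1f}')
```

If recall matters more than latency, building the index with more trees (`n_trees`) or passing a larger `search_k` to `get_nns_by_vector` usually improves it, at the cost of build time or query time.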


### BM-25 Retrieval
- The previous approaches are semantic: they search over embeddings.
  - Embeddings capture overall semantic similarity.
  - They may miss exact keyword matches.
- BM25 is keyword-based retrieval.
  - Scores documents by matching query terms directly, weighted by term frequency.
  - Can't capture synonyms, only direct matches.
- Example use case: searching for a specific model number in a refrigerator manual.
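
To make the points above concrete, here is a minimal sketch of a textbook Okapi BM25 scorer. It is an illustration under stated assumptions, not the exact formula used by `rank_bm25`: the IDF variant, the `k1` and `b` values, and the toy corpus are all choices made for this example.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Score every tokenized document against the query with textbook BM25."""
    n_docs = len(corpus_tokens)
    avg_len = sum(len(doc) for doc in corpus_tokens) / n_docs

    # Document frequency: in how many documents does each term occur?
    doc_freq = Counter()
    for doc in corpus_tokens:
        doc_freq.update(set(doc))

    scores = []
    for doc in corpus_tokens:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            tf = term_freq.get(term, 0)
            if tf == 0:
                continue  # only exact token matches contribute
            # Smoothed, non-negative IDF: rarer terms weigh more
            idf = math.log((n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5) + 1)
            # Term frequency saturates via k1 and is length-normalized via b
            length_norm = k1 * (1 - b + b * len(doc) / avg_len)
            score += idf * tf * (k1 + 1) / (tf + length_norm)
        scores.append(score)
    return scores

# Toy corpus (made up for illustration)
corpus = [
    'my phone battery died'.split(),
    'the fridge is running'.split(),
    'phone phone phone'.split(),
]

keyword_scores = bm25_scores(['phone'], corpus)
synonym_scores = bm25_scores(['mobile'], corpus)  # synonym never appears verbatim
print(keyword_scores)
print(synonym_scores)
```

The query `phone` scores only the documents that contain that exact token, while the synonym `mobile` scores zero everywhere, which is the limitation noted in the list above.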

In [7]:
from rank_bm25 import BM25Okapi

class BM25Retriever:
    def __init__(self, texts):
        self.texts = texts
        # Tokenize by splitting on spaces: matching is case-sensitive and punctuation stays attached
        self.bm25 = BM25Okapi([t.split(' ') for t in texts])
    
    def get_nearest(self, query: str, k: int = 10):
        tokenized = query.split(' ')
        topk_docs = self.bm25.get_top_n(tokenized, self.texts, n=k)
        return topk_docs

In [11]:
query = 'Cell phones'

bm25_retriever = BM25Retriever(jokes)

start = time.time()
nearest_bm25 = bm25_retriever.get_nearest(query, k=10)
end = time.time()
print(f'Time: {end-start}')
print(nearest_bm25)

Time: 0.0019989013671875
['What Cell Phone Company does Usain Bolt use? Sprint', 'Ever since the news came out about Samsung.... Their phones have been blowing up.', "I bet kangaroos get tired of holding all of their friend's keys and cell phones while they're at the beach.", 'I become instantly beautiful when I put on my sunglasses. -Every girl, ever.', 'What did the ruler gain a reputation for while campaigning? Straight talk.', 'How do you fit 4 gays on one barstool? Flip it over!', 'I want my tombstone to read "Free WiFi" so people would visit more often', 'You ever notice that the most dangerous thing about marijuana is getting caught with it?', 'What did Arnold Schwarzenegger say at the abortion clinic? Hasta last vista, baby.', 'Sucks that these Crest strips only come in white']
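
Only the first three results above contain `Cell` or `phones` exactly as typed in the query; the rest appear to be unrelated filler, which is consistent with whitespace tokenization being case-sensitive. A possible normalization step is sketched below. `simple_tokenize` is a hypothetical helper, not part of `rank_bm25`, and it would have to be applied to both the corpus and the query.

```python
import re

def simple_tokenize(text):
    """Lowercase the text and keep runs of letters, digits and apostrophes."""
    return re.findall(r"[a-z0-9']+", text.lower())

joke = 'What Cell Phone Company does Usain Bolt use? Sprint'
lowercase_query = 'cell phones'

# Whitespace splitting: 'Cell' != 'cell', so nothing overlaps
raw_overlap = set(joke.split(' ')) & set(lowercase_query.split(' '))

# Normalized tokens: 'cell' now matches
normalized_overlap = set(simple_tokenize(joke)) & set(simple_tokenize(lowercase_query))

print(raw_overlap)
print(normalized_overlap)
```

Note that `phone` and `phones` still do not match after lowercasing; closing that gap would need stemming or lemmatization, which this sketch leaves out.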


### Combining Both Approaches
#### Reciprocal Rank Fusion
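- Classic search-engine technique for merging several ranked result lists into one.
- Uses only each document's rank in every list, so the retrievers' raw scores do not need to be comparable.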
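
As a reference point, the usual form of the RRF score for a document $d$ across $n$ ranked lists can be written as follows, where $\mathrm{rank}_i(d)$ is the 1-based position of $d$ in list $i$ (the term is skipped when $d$ is absent from that list) and $k$ is a smoothing constant, commonly set to 60:

$$
\mathrm{RRF}(d) = \sum_{i=1}^{n} \frac{1}{k + \mathrm{rank}_i(d)}
$$

For example, with $k = 60$ a document ranked 1st in one list and 2nd in another scores $\frac{1}{61} + \frac{1}{62} \approx 0.0325$, while a document ranked 1st in a single list scores only $\frac{1}{61} \approx 0.0164$. Agreement between retrievers therefore counts for more than a single top placement.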

In [9]:
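def reciprocal_rank_fusion(ranked_lists, k=60, top_n=None):
    """Fuse several ranked lists with Reciprocal Rank Fusion (RRF).

    k is the RRF smoothing constant. top_n caps the number of returned docs;
    when omitted it falls back to k, which keeps the original behavior.
    """
    scores = {}
    # Accumulate 1 / (k + rank) per list; rank is 0-based here, hence the +1
    for ranked_list in ranked_lists:
        for rank, doc in enumerate(ranked_list):
            scores[doc] = scores.get(doc, 0.0) + 1 / (k + rank + 1)

    # Highest fused score first
    docs = sorted(scores, key=scores.get, reverse=True)
    limit = k if top_n is None else top_n
    return docs[:limit]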

In [14]:
query = 'Cell phones'
topk = 10

vector_rag = BasicEmbeddingsRAG(jokes, embs)
bm25_retriever = BM25Retriever(jokes)

start = time.time()
vector_results = vector_rag.get_nearest(query, k=topk)
end = time.time()
vector_time = end - start

start = time.time()
bm25_results = bm25_retriever.get_nearest(query, k=topk)
end = time.time()
bm25_time = end - start

print(f'Vector Results ({vector_time:.4f}s)')
for i, res in enumerate(vector_results):
    print(f'\t{i+1}. {res}')

print(f'BM25 Results ({bm25_time:.4f}s)')
for i, res in enumerate(bm25_results):
    print(f'\t{i+1}. {res}')

fused_results = reciprocal_rank_fusion([vector_results, bm25_results])
print(f'Fused and Re-ranked Results (Top {topk})')
for i, res in enumerate(fused_results[:topk]):
    print(f'\t{i+1}. {res}')

Vector Results (0.0169s)
	1. Meet the parents
	2. South Africa
	3. pull my upvote
	4. Political Joke The Economy
	5. I have a joke to tell. Can you reddit?
	6. My life That's the joke.
	7. The best joke you'll never hear
	8. I like the sound of you not talking.
	9. Hire The Pretty Blonde
	10. I have a joke about Ebola You probably won't get it
BM25 Results (0.0030s)
	1. What Cell Phone Company does Usain Bolt use? Sprint
	2. Ever since the news came out about Samsung.... Their phones have been blowing up.
	3. I bet kangaroos get tired of holding all of their friend's keys and cell phones while they're at the beach.
	4. I become instantly beautiful when I put on my sunglasses. -Every girl, ever.
	5. What did the ruler gain a reputation for while campaigning? Straight talk.
	6. How do you fit 4 gays on one barstool? Flip it over!
	7. I want my tombstone to read "Free WiFi" so people would visit more often
	8. You ever notice that the most dangerous thing about marijuana is getting caught

### Multi-Hop HyDE
- Generate separate queries for semantic and keyword search, since each retriever favors a different query style.
  - The semantic query is a full sentence suited to cosine-similarity search.
  - The BM25 query is short and keyword-based.
- Multiple hops give the LLM more chances to refine the query for a better hit.
  - Often paired with a validation check so the loop can stop early.
  - E.g. checking whether the answer has already been retrieved in a Q/A system.
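
The early-stopping idea from the list above can be sketched independently of any LM or retriever. Everything here is a hypothetical stand-in: `multi_hop_search`, the toy database, and the three callbacks are assumptions for illustration. Unlike a HyDE setup, this sketch searches with the raw query on the first hop and only rewrites it afterwards.

```python
def multi_hop_search(query, rewrite_query, retrieve, is_sufficient, n_hops=3):
    """Run up to n_hops retrieval rounds, stopping once the results pass validation."""
    collected = []
    current_query = query
    hops_used = 0
    for _ in range(n_hops):
        hops_used += 1
        # Collect new documents, skipping ones seen in earlier hops
        for doc in retrieve(current_query):
            if doc not in collected:
                collected.append(doc)
        # Validation check: stop early if what we have is already good enough
        if is_sufficient(query, collected):
            break
        # Otherwise refine the query using what has been retrieved so far
        current_query = rewrite_query(query, collected)
    return collected, hops_used

# Toy stand-ins: real code would call an LM and a retriever here
toy_db = {
    'phones': ['joke about chargers'],
    'phones battery': ['joke about chargers', 'joke about batteries'],
}

def toy_retrieve(q):
    return toy_db.get(q, [])

def toy_rewrite(original, docs):
    return original + ' battery'

def toy_is_sufficient(original, docs):
    return len(docs) >= 2

docs, hops = multi_hop_search('phones', toy_rewrite, toy_retrieve, toy_is_sufficient)
print(docs, hops)
```

With these stubs the loop stops after the second hop, because the validation check passes before the hop budget of three is used up.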

In [16]:
import dspy
from typing import Optional, List

class HypotheticalDoc(dspy.Signature):
    """Given a query, generate hypothetical documents to search a database of one-liner jokes."""
    query: str = dspy.InputField(desc='User wants to fetch jokes related to this topic.')
    retrieved_jokes: Optional[List[str]] = dspy.InputField(desc='Jokes previously retrieved from the DB. Use these to further tune your search.')
    hypothetical_bm25_query: str = dspy.OutputField(desc='Sentence to query to retrieve more jokes about the query from the DB.')
    hypothetical_semantic_query: str = dspy.OutputField(desc='Sentence to search with Cosine Similarity.')

class MultiHopeHyDESearch(dspy.Module):
    def __init__(self, texts, embs, n_hops=3, k=10):
        super().__init__()
        self.pred = dspy.ChainOfThought(HypotheticalDoc)
        self.pred.set_lm(dspy.LM('gemini/gemini-2.5-flash-lite'))
        
        self.emb_retriever = BasicEmbeddingsRAG(texts, embs)
        self.bm25_retriever = BM25Retriever(texts)
        
        self.n_hops = n_hops
        self.k = k
    
    def forward(self, query):
        retrieved_jokes = []
        all_jokes = []
        for _ in range(self.n_hops):
            # Let the LM write both queries, conditioned on the previous hop's results
            new_query = self.pred(query=query, retrieved_jokes=retrieved_jokes)
            print(new_query)
            
            emb_lists = self.emb_retriever.get_nearest(new_query.hypothetical_semantic_query)
            bm25_lists = self.bm25_retriever.get_nearest(new_query.hypothetical_bm25_query)
            # Note: k is passed as the RRF constant, which also caps the fused list length
            retrieved_jokes = reciprocal_rank_fusion([emb_lists, bm25_lists], k=self.k)
            # Keep each joke once, even if several hops retrieve it
            for joke in retrieved_jokes:
                if joke not in all_jokes:
                    all_jokes.append(joke)
        return dspy.Prediction(jokes=all_jokes)

In [17]:
query = 'Cell phones'
k = 5
n_hops = 3

hyde = MultiHopeHyDESearch(jokes, embs, n_hops, k)
retrieved_jokes = hyde(query=query).jokes
print(retrieved_jokes)

Prediction(
    reasoning='The user is asking for jokes about cell phones. Since no jokes have been retrieved yet, I should generate a broad BM25 query and a semantic query that captures the essence of cell phone jokes.',
    hypothetical_bm25_query='cell phone jokes',
    hypothetical_semantic_query='Jokes about mobile phones and their use.'
)
Prediction(
    reasoning='The user is asking for jokes about cell phones. The retrieved jokes include one directly about cell phones ("My cell phone is so nervous whenever I go to the countryside... ...it\'s constantly on EDGE.") and another that mentions cell phone providers ("NEVER date someone that works for your cell phone provider. You\'re welcome."). To find more jokes, I should focus on keywords related to cell phones, mobile phones, smartphones, and common cell phone-related scenarios or features.',
    hypothetical_bm25_query='cell phone jokes, mobile phone humor, smartphone jokes, funny cell phone stories',
    hypothetical_semantic_q

### JokeGenerator Example
`Query -> (Idea LM <-> WebSearch) -> Joke Idea -> (Joke LM <-> Joke DB) -> Joke`