# Build a simple RAG system with semantic search and BM25

To build a simple RAG, we need build a query function, a retrieval function, and a generation function. We will also need a dataset to retrieve from.




## Understand Data Structures

Many tutorials usually use some sort of open-source dataset as examples to show how RAG works. For me, the problem is not really about the RAG workflow, but how to work with data. For beginners like me, it is not a simple question. It is actually the most mind-boggling part of the whole process.



In [6]:
!pip install bm25s
!pip install nltk
!pip install scikit-learn
!pip install pandas



In [None]:
import pandas as pd
import numpy as np
import bm25s
import os
from sentence_transformers import SentenceTransformer

First we need to read the dataset. Pandas is a great and most common choice for data manipulation. We can use `pd.read_csv()` to read the dataset into a DataFrame and assign it to a variable. 

Then we can do some basic data exploration to understand the structure of the dataset. For example, we can use `df.head()` to see the first few rows of the dataset, and `df.info()` to see the data types and null values in each column. Here "df" is the DataFrame variable we assigned earlier, it can be any name you choose but somehow people like to use "df" as a short form for DataFrame.





In [9]:
NEWS_DATA = pd.read_csv("./resource/sample_data/news_data_dedup.csv")
NEWS_DATA.head(3) # Display the first few rows of the dataset

Unnamed: 0,guid,title,description,venue,url,published_at,updated_at
0,e3dc5caa18f9a16d7edcc09f8d5c2bb4,Harvey Weinstein's 2020 rape conviction overtu...,Victims group describes the New York appeal co...,BBC,https://www.bbc.co.uk/news/world-us-canada-688...,2024-04-25 18:24:04+00,2024-04-26 20:03:00.628113+00
1,297b7152cd95e80dd200a8e1997e10d9,Police and activists clash on Atlanta campus a...,"Meanwhile, hundreds of students march in Washi...",BBC,https://www.bbc.co.uk/news/live/world-us-canad...,2024-04-25 13:40:25+00,2024-04-26 20:03:00.654819+00
2,170bd18d1635c44b9339bdbaf1e62123,Haiti PM resigns as transitional council sworn in,The council will try to restore order and form...,BBC,https://www.bbc.co.uk/news/world-latin-america...,2024-04-25 18:11:02+00,2024-04-26 20:03:00.663393+00


In [8]:
NEWS_DATA.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 870 entries, 0 to 869
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   guid          870 non-null    object
 1   title         870 non-null    object
 2   description   870 non-null    object
 3   venue         870 non-null    object
 4   url           870 non-null    object
 5   published_at  870 non-null    object
 6   updated_at    870 non-null    object
dtypes: object(7)
memory usage: 47.7+ KB


In [10]:
NEWS_DATA['description'][0] #access the first row of the description column

'Victims group describes the New York appeal court\'s decision to retry Hollywood mogul as "profoundly unjust".'

In [17]:
NEWS_DATA.loc[0] #access the first row

guid                             e3dc5caa18f9a16d7edcc09f8d5c2bb4
title           Harvey Weinstein's 2020 rape conviction overtu...
description     Victims group describes the New York appeal co...
venue                                                         BBC
url             https://www.bbc.co.uk/news/world-us-canada-688...
published_at                               2024-04-25 18:24:04+00
updated_at                          2024-04-26 20:03:00.628113+00
Name: 0, dtype: object

It takes some practice to become comfortable with accessing and manipulating data in Pandas. Don't panic.

Pandas library has very good documentation here: https://pandas.pydata.org/

And a cheat sheet here: https://pandas.pydata.org/Pandas_Cheat_Sheet.pdf . A copy of the cheat sheet is also included in the resource folder of this project. 


## Start with querying

If we have a structured dataset, we can start with querying by index. We can create a function that takes an index or a list of indices as input and returns the corresponding rows from the DataFrame. This is a simple way to retrieve data without any complex logic.

In [None]:
def query_news(indices, dataset=NEWS_DATA):
    """
    Retrieves elements from a dataset based on specified indices.

    Parameters:
    indices (list of int): A list containing the indices of the desired elements in the dataset.
    dataset (list or sequence): The dataset from which elements are to be retrieved. It should support indexing.

    Returns:
    list: A list of elements from the dataset corresponding to the indices provided in list_of_indices.
    """

    output = [dataset.iloc[index] for index in indices]

    return output

# Example usage
indices = [0, 1, 2]
result = query_news(indices)
print(result)
print("Type of result:", type(result))
print("=" * 50)
# Display the first element's description from the result
print(f"""Description of {result[0]['title']}:\n 
\"{result[0]['description']} \"\n 
Publication date: {result[0]['published_at']}""")  



[guid                             e3dc5caa18f9a16d7edcc09f8d5c2bb4
title           Harvey Weinstein's 2020 rape conviction overtu...
description     Victims group describes the New York appeal co...
venue                                                         BBC
url             https://www.bbc.co.uk/news/world-us-canada-688...
published_at                               2024-04-25 18:24:04+00
updated_at                          2024-04-26 20:03:00.628113+00
Name: 0, dtype: object, guid                             297b7152cd95e80dd200a8e1997e10d9
title           Police and activists clash on Atlanta campus a...
description     Meanwhile, hundreds of students march in Washi...
venue                                                         BBC
url             https://www.bbc.co.uk/news/live/world-us-canad...
published_at                               2024-04-25 13:40:25+00
updated_at                          2024-04-26 20:03:00.654819+00
Name: 1, dtype: object, guid                       

### BM25 Retrieval

Now that we can retrieve data by index, we can work on the retrieval function, which ideally should take a query as input and return the indices of the relevant rows from the DataFrame. 

Let's start with bm25 using `bm25s` library. BM25 is a popular algorithm for information retrieval that ranks documents based on their relevance to a given query. It is widely used in search engines and can be a good starting point for building a retrieval function.

First, we will construct the "corpus" from the DataFrame. The corpus is a list of strings, where each string is a document in the dataset.

In [45]:
# The corpus used will be the title appended with the description
# the "to_dict('records')" method converts the DataFrame to a list of dictionaries, where each dictionary represents a row in the DataFrame.
records = NEWS_DATA.to_dict('records')
print(records[0])  # Display the first record to check the structure
corpus = [x['title'] + " " + x['description'] for x in records]
print("Corpus created with", len(corpus), "documents. The type of corpus is:", type(corpus))
print("First document in the corpus:", corpus[0])

{'guid': 'e3dc5caa18f9a16d7edcc09f8d5c2bb4', 'title': "Harvey Weinstein's 2020 rape conviction overturned", 'description': 'Victims group describes the New York appeal court\'s decision to retry Hollywood mogul as "profoundly unjust".', 'venue': 'BBC', 'url': 'https://www.bbc.co.uk/news/world-us-canada-68899382', 'published_at': '2024-04-25 18:24:04+00', 'updated_at': '2024-04-26 20:03:00.628113+00'}
Corpus created with 870 documents. The type of corpus is: <class 'list'>
First document in the corpus: Harvey Weinstein's 2020 rape conviction overturned Victims group describes the New York appeal court's decision to retry Hollywood mogul as "profoundly unjust".


We then need to instantiate the BM25 retriever by passing the corpus data. The `bm25s` library provides a `BM25` class that we can use for this purpose.

In the follow code, the `BM25_RETRIEVER` is an instance of the `BM25` class, which is initialized with the corpus. There are some common methods we can use with the `BM25_RETRIEVER` object, such as `index()` to index the tokenized chunks, and `retrieve()` to retrieve relevant documents based on a query.

To instantiate the retriever, we first need call the `bm25s.BM25()` constructor with the corpus data. 

Then we can tokenize the corpus data using the `bm25s.tokenize()` function. This function takes the corpus and an optional list of stopwords as input and returns a list of tokenized documents. Tokenization is the process of breaking down the text into smaller units (tokens) for better processing.

The next step is to index the tokenized chunks within the retriever. This will allow the retriever to efficiently search for relevant documents based on a query. The `index()` method is used for this purpose, and it takes the tokenized data as input.

After tokenizing and indexing the data, we are ready to use the BM25 retriever to retrieve relevant documents based on a query. The `retrieve()` method can be used to perform the retrieval, and it will return a list of relevant documents along with their scores.

In [None]:
# Instantiate the retriever by passing the corpus data
# BM25_RETRIEVER is an instance of the BM25 class from the bm25s library, which is used for information retrieval.
# It is initialized with the corpus, which is a list of strings where each string represents a document in the dataset.
BM25_RETRIEVER = bm25s.BM25(corpus=corpus)

# Define a list of stopwords to be used during tokenization, this is optional but may improve retrieval performance
stopwords = ["a", "the", "and", "is", "to", "of", "in", "that", "it", "for", "on", "with", "as", "this", "by", "at", "from"]
# Tokenize the chunks, which means breaking down the text into smaller units (tokens) for better processing
tokenized_data = bm25s.tokenize(corpus, stopwords=stopwords)

# Index the tokenized chunks within the retriever, this is done automatically when you call the `BM25` constructor, but you can also do it explicitly
BM25_RETRIEVER.index(tokenized_data)

# Check the content of tokenized_data
print("Content of tokenized_data:", tokenized_data[:3])

Split strings:   0%|          | 0/870 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/870 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/870 [00:00<?, ?it/s]



In [None]:
# Tokenize the same query used in the previous exercise
sample_query = "What are the recent news about GDP?"
tokenized_sample_query = bm25s.tokenize(sample_query)

# Get the retrieved results and their respective scores
results, scores = BM25_RETRIEVER.retrieve(tokenized_sample_query, k=5)

"""
Note the actual results and scores are 'lists of lists', where each inner list corresponds to a query. 
Since we only have one query, we can access the first element of each list by calling results[0] and scores[0].
"""

for r,s in zip(results[0], scores[0]):
    print(f"Document: {r} \n Score: {s}")

print(f"Results for query: {sample_query}\n")
for doc in results[0]:
  print(f"Document retrieved {corpus.index(doc)} : {doc}\n")

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:01<?, ?it/s]

Document: GDP and the Dow Are Up. But What About American Well-Being? The standard ways of measuring economic growth don’t capture what life is like for real people. A new metric offers a better alternative, especially for seeing disparities across the country. 
 Score: 5.03474760055542
Document: What the GDP Report Says About Inflation: A Hot First Quarter Thursday’s gross domestic product report suggests that a widely watched inflation reading due Friday could be worse than expected. 
 Score: 4.879597187042236
 Score: 3.8771421909332275
Document: Do the GDP and Dow Reflect American Well-Being? Do the GDP and Dow Reflect American Well-Being? 
 Score: 3.334899663925171
Document: Auto Safety Regulator Investigating Tesla Recall of Autopilot The National Highway Safety Administration said it had concerns about how Tesla handled the recall based on recent crashes and testing of cars that had been updated. 
 Score: 2.951777935028076
Results for query: What are the recent news about GDP?

D

In [68]:
# we need the doc numbers(indices) to perform fusion ranking later
indices_bm25 = [corpus.index(doc) for doc in results[0]]
print("Indices of retrieved documents:", indices_bm25)

Indices of retrieved documents: [752, 673, 289, 626, 43]


!NOTE 

The output from the BM25 retriever will be the same for the same query, regardless of the order of the documents in the corpus and the word order of the query. This is because BM25 is a statistical model that calculates the relevance of each document to the query based on the term frequency and inverse document frequency. In comparison, a semantic search model might produce different results based on the context and meaning of the words in the query and documents. Order of the documents in the corpus and the word order matters for semantic search, but not for BM25.





### Semantic Search Retrieval
Now that we have a basic understanding of how to retrieve data by index and using BM25, we can move on to semantic search. Semantic search is a more advanced technique that uses natural language processing (NLP) to understand the meaning of the query and the documents in the corpus. It can provide more relevant results than BM25, especially for complex queries.

A key component of semantic search is the use of embeddings, which are vector representations of text. These embeddings capture semantic meaning, allowing us to compare text based on context. One common way to measure the similarity between these vectors is through cosine similarity, which calculates how close two vectors are in high-dimensional space. This approach helps in finding content that is contextually similar to the user's query, leading to more accurate and meaningful search results.

In contrast, BM25 uses a sparse representation based on keyword matching. During the indexing stage, BM25 tokenizes documents, builds an inverted index (mapping each term to the documents it appears in along with term frequencies), and stores statistics such as document length. During the retrieval stage, BM25 tokenizes the query in the same way, looks up documents containing those terms in the inverted index, and calculates a relevance score using the BM25 formula, which combines term frequency (TF), inverse document frequency (IDF), and length normalization.

In [69]:
!pip install joblib



In [73]:
import joblib
from sentence_transformers import SentenceTransformer

In [None]:
# for the first time, we need to create embeddings for the corpus

EMBEDDINGS = joblib.load("./resource/sample_data/embeddings.joblib")


In [74]:
model_name = "BAAI/bge-base-en-v1.5" 
model = SentenceTransformer(model_name)

In [75]:
# Example usage
query = "RAG is awesome"
# Using, but truncating the result to not pollute the output, don't truncate it in the exercise.
model.encode(query)[:40]

  return forward_call(*args, **kwargs)


array([ 0.00886302, -0.04775141, -0.00156084,  0.01309997, -0.00206938,
       -0.06157259,  0.0138469 ,  0.00101493, -0.04903951, -0.04762559,
       -0.0362819 ,  0.00478037, -0.03492177,  0.05323153,  0.02193962,
        0.03645133,  0.04029364, -0.0045364 ,  0.01883797, -0.03367385,
        0.02516189, -0.04843633, -0.04047945,  0.02590899,  0.02175233,
        0.03160366,  0.03937932, -0.03640464, -0.03113294, -0.0124723 ,
        0.03661648, -0.00458203, -0.00100169, -0.03188792,  0.02957135,
        0.01986158, -0.00737469,  0.02370172, -0.02151619, -0.07361358],
      dtype=float32)

In [77]:
def cosine_similarity(a, b):
    """
    Calculate the cosine similarity between two vectors.

    Parameters:
    a (np.ndarray): First vector.
    b (np.ndarray): Second vector.

    Returns:
    float: Cosine similarity between the two vectors.
    """
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))   
    

In [79]:
query1 = "What are the primary colors"
query2 = "Yellow, red and blue"
query3 = "Cats are friendly animals"

query1_embed = model.encode(query1)
query2_embed = model.encode(query2)
query3_embed = model.encode(query3)

print(f"Similarity between '{query1}' and '{query2}' = {cosine_similarity(query1_embed, query2_embed)}")
print(f"Similarity between '{query1}' and '{query3}' = {cosine_similarity(query1_embed, query3_embed)}")

Similarity between 'What are the primary colors' and 'Yellow, red and blue' = 0.7377139329910278
Similarity between 'What are the primary colors' and 'Cats are friendly animals' = 0.4508620500564575


In [82]:
def semantic_search(query, embeddings, model, top_k=5):
    """
    Perform semantic search to find the most relevant documents to a given query.

    Parameters:
    query (str): The search query.
    embeddings (np.ndarray): A 2D array where each row is the embedding of a document in the corpus.
    model (SentenceTransformer): A pre-trained SentenceTransformer model used to encode the query.
    top_k (int): The number of top relevant documents to return.

    Returns:
    list: Indices of the top_k most relevant documents in the corpus.
    """
    
    # Encode the query to get its embedding
    query_embedding = model.encode(query)
    
    # Compute cosine similarities between the query embedding and all document embeddings
    # the input embeddings is a 2D array, where each row is a document embedding
    # the query_embedding is a 1D array
    # we use broadcasting to compute the cosine similarity between the query and all documents
    cosine_similarities = np.dot(embeddings, query_embedding) / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding))
    
    # Get the indices of the top_k most similar documents
    top_k_indices = np.argsort(cosine_similarities)[-top_k:][::-1]
    
    return top_k_indices.tolist()   

The broadcasting mechanism in NumPy allows us to perform operations on arrays of different shapes. In this case, we can compute the cosine similarity between the query embedding and all document embeddings without needing to reshape the arrays explicitly. A step-by-step breakdown of the cosine similarity calculation is as follows:

1. `dot_product = np.dot(EMBEDDINGS, query_embed)`: This computes the dot product between the query embedding and each document embedding. For this particular model, say if we have 1000 documents and each embedding is a 768-dimensional vector, the `EMBEDDINGS` array would have a shape of (1000, 768), and the `query_embed` would have a shape of (768,). The resulting `dot_product` would be a 1D array of shape (1000,), where each element represents the dot product between the query and a corresponding document.

2. `norms = np.linalg.norm(EMBEDDINGS, axis=1) * np.linalg.norm(query_embed)`: This computes the norms of the document embeddings and the query embedding. For 1000 documents, the `norms` array would have a shape of (1000,), where each element represents the norm of a corresponding document embedding multiplied by the norm of the query embedding. The query embedding's norm is a single scalar value and it is the same for all documents, so it can be broadcasted across the document norms.

3. `cosine_similarities = dot_product / norms`: This computes the cosine similarities by dividing the dot products by the norms. The result is a 1D array where each element represents the cosine similarity between the query and a corresponding document.

demostration of dot matrix multiplication and cosine similarity calculation with  3*2 matrix and 2*1 vector:


```math
\begin{align*}
\text{Dot Product} & : \begin{bmatrix}
1 & 2 \\
3 & 4 \\
5 & 6
\end{bmatrix} \cdot \begin{bmatrix}
7 \\
8
\end{bmatrix} = \begin{bmatrix}
1 \cdot 7 + 2 \cdot 8 \\
3 \cdot 7 + 4 \cdot 8 \\
5 \cdot 7 + 6 \cdot 8
\end{bmatrix} = \begin{bmatrix}
23 \\
83 \\
143
\end{bmatrix} \\

\end{align*}

In [120]:
query = "trend of economy"
top_k_indices_semantic = semantic_search(query, EMBEDDINGS, model, top_k=7)
print(f"Top {len(top_k_indices_semantic)} indices for query '{query}': {top_k_indices_semantic}")

Top 7 indices for query 'trend of economy': [772, 754, 743, 673, 752, 303, 289]


In [121]:
#run bm25 again
results, scores = BM25_RETRIEVER.retrieve(bm25s.tokenize(query), k=7)
indices_bm25 = [corpus.index(doc) for doc in results[0]]
print("Indices of retrieved documents:", indices_bm25)

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

Indices of retrieved documents: [303, 754, 289, 643, 141, 756, 251]


## RRF

RRF (Reciprocal Rank Fusion) is a method that combines the results from multiple retrieval models to improve the overall performance of the retrieval system. It works by assigning a score to each document based on its rank in the results of different models, and then combining these scores to produce a final ranking.

The RRF algorithm works as follows:
1. For each retrieval model, retrieve a ranked list of documents for a given query.

2. For each document in the ranked list, assign a score based on its rank. The score is typically calculated as `1 / (k + rank)`, where `k` is a constant (often set to 60) and `rank` is the position of the document in the ranked list.

3. Sum the scores for each document across all retrieval models.

4. Rank the documents based on their combined scores to produce the final result.

The formula for the RRF score of a document `d` is given by:

```math
\text{RRF}(d) = \sum_{i=1}^{N} \frac{1}{k + \text{rank}_i(d)}
```



In [165]:
def reciprocal_rank_fusion(list1, list2, top_k=5, K=50):
    """
    Fuse rank from multiple IR systems using Reciprocal Rank Fusion.

    Args:
        list1 (list[int]): A list of indices of the top-k documents that match the query.
        list2 (list[int]): Another list of indices of the top-k documents that match the query.
        top_k (int): The number of top documents to consider from each list for fusion. Defaults to 5.
        K (int): A constant used in the RRF formula. Defaults to 60.

    Returns:
        list[int]: A list of indices of the top-k documents sorted by their RRF scores.
    """

    ### START CODE HERE ###

    # Create a dictionary to store the RRF scores for each document index
    rrf_scores = {}

    # Iterate over each document list
    for lst in [list1, list2]:
        # Calculate the RRF score for each document index
        for rank, item in enumerate(lst, start=1): # Start = 1 set the first element as 1 and not 0. 
                                                   # This is a convention on how ranks work (the first element in ranking is denoted by 1 and not 0 as in lists)
            # If the item is not in the dictionary, initialize its score to 0
            if item not in rrf_scores:
                rrf_scores[item] = 0
                
            # Update the RRF score for each document index using the formula 1 / (rank + K)
            current_score = 1/(rank+K)
            
            #print(f"Document {item} from {lst} has RRF score: {current_score:.4f}")
            rrf_scores[item] += current_score
            #print(f"Updated RRF score for document {item}: {rrf_scores[item]:.4f}")

    # Sort the document indices based on their RRF scores in descending order
    sorted_items = sorted(rrf_scores, key=rrf_scores.get, reverse = True)

    # Slice the list to get the top-k document indices
    top_k_indices = [int(x) for x in sorted_items[:top_k]]

    ### END CODE HERE ###

    return top_k_indices


In [129]:

# Example usage of reciprocal_rank_fusion
top_k_indices_bm25 = indices_bm25[:5]  # Assuming we want the top 5 from BM25
top_k_indices_semantic = top_k_indices_semantic[:5]  # Assuming we want the top 5 from semantic search
top_k_indices_fused = reciprocal_rank_fusion(top_k_indices_bm25, top_k_indices_semantic, top_k=5)

print(f"Top {len(top_k_indices_bm25)} indices from BM25: {top_k_indices_bm25}")
print(f"Top {len(top_k_indices_semantic)} indices from semantic search: {top_k_indices_semantic}")
print(f"Top {len(top_k_indices_fused)} indices after RRF fusion: {top_k_indices_fused}")    

Document 303 from [303, 754, 289, 643, 141] has RRF score: 0.0196
Updated RRF score for document 303: 0.0196
Document 754 from [303, 754, 289, 643, 141] has RRF score: 0.0192
Updated RRF score for document 754: 0.0192
Document 289 from [303, 754, 289, 643, 141] has RRF score: 0.0189
Updated RRF score for document 289: 0.0189
Document 643 from [303, 754, 289, 643, 141] has RRF score: 0.0185
Updated RRF score for document 643: 0.0185
Document 141 from [303, 754, 289, 643, 141] has RRF score: 0.0182
Updated RRF score for document 141: 0.0182
Document 772 from [772, 754, 743, 673, 752] has RRF score: 0.0196
Updated RRF score for document 772: 0.0196
Document 754 from [772, 754, 743, 673, 752] has RRF score: 0.0192
Updated RRF score for document 754: 0.0385
Document 743 from [772, 754, 743, 673, 752] has RRF score: 0.0189
Updated RRF score for document 743: 0.0189
Document 673 from [772, 754, 743, 673, 752] has RRF score: 0.0185
Updated RRF score for document 673: 0.0185
Document 752 from [

In [135]:
# retrieve the documents using the fused indices
fused_documents = query_news(top_k_indices_fused, NEWS_DATA)
print("Fused documents retrieved:")
for doc in fused_documents:

    print(f"Title: {doc['title']},\n Description: {doc['description']}, Published at: {doc['published_at']}")

Fused documents retrieved:
Title: America's Economy Is No. 1. That Means Trouble.,
 Description: Solid growth, big deficits and a strong dollar stir memories of past crises., Published at: 2024-04-25 16:40:00+00
Title: The U.S. economy grew at a sharply slower pace in the first quarter and inflation topped Wall Street's expectations, dimming investor hopes for a quick Fed rate cut and sending the stock and bond markets down.,
 Description: The U.S. economy grew at a sharply slower pace in the first quarter and inflation topped Wall Street’s expectations, dimming investor hopes for a quick Fed rate cut and sending the stock and bond markets down., Published at: 2024-04-26 02:49:00+00
Title: America's Economy Is No. 1. That Means Trouble,
 Description: If you want a single number to capture America’s economic stature, here it is: This year, the U.S. will account for 26.3% of the global gross domestic product, the highest in almost two decades. That’s based on the latest projections from 

### Assemble the Pipeline
Now that we have the retrieval functions, we can assemble the pipeline. The pipeline will take a query as input, retrieve the relevant documents using the BM25 and semantic search retrieval functions, and then generate a response using the LLM.

In [None]:
def context_pipeline(query, top_k=5):
    bm25_results, bm25_scores = BM25_RETRIEVER.retrieve(bm25s.tokenize(query), k=top_k)
    bm25_indices = [corpus.index(doc) for doc in bm25_results[0]]

    semantic_indices = semantic_search(query, EMBEDDINGS, model, top_k=top_k)

    fused_indices = reciprocal_rank_fusion(bm25_indices, semantic_indices, top_k=top_k)
    
    fused_documents = query_news(fused_indices, NEWS_DATA)
    
    return fused_documents

# fused_documents is a list of dictionaries, each dictionary contains the title, description, and published_at of the document


## Assemble the RAG system

With the "context" for the query ready, we can now assemble the RAG system by combining the retrieval and generation components. 

We will use OpenAI's `GPT-4o-mini` model for the generation component, which is a powerful language model that can generate human-like text based on the context provided.

In [136]:
# Initializing
import openai
import os
from openai import OpenAI
from IPython.display import Markdown, display

api_key = os.getenv("OPENAI_API_KEY") # Use correct env var name from your .env or manually set it
print("OpenAI API Key:", api_key[:10] + "..." if api_key else "Not found")  # Only show first 10 chars for security
base_url = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")  # Use correct env var name from your .env
print("OpenAI Base URL:", base_url)

# Initialize OpenAI client with API key and base URL
client = OpenAI(
    api_key=api_key,
    base_url=base_url
)

OpenAI API Key: sk-nzH4YCa...
OpenAI Base URL: https://xiaoai.plus/v1


In [None]:
def gen_prompt(prompt, context=None):
    # context in this case is a list of dictionaries, each dictionary contains the title, description, and published_at of the document
    # convert context to a string
    if context:
        context = "\n\n".join([f"Title: {doc['title']}\n Description: {doc['description']}\n Published at: {doc['published_at']} \n URL: {doc['url']}" for doc in context])
    return f"""
    You are a helpful assistant with access to the following context:
    ====
    Context: {context}\n
    ====\n

    User query: {prompt};
    Use the context to answer the user's question as accurately as possible.
    If the context does not provide enough information, politely inform the user that you cannot answer based on the provided context.
    """

In [None]:
query = "What are the recent news about wheat in China?"
context = context_pipeline(query)
prompt = gen_prompt(query, context = context)
print(prompt)

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

You are a helpful assistant with access to the following context:
    ====
    Context: Title: China's Gold Consumption Rises on Safe-Haven Demand
 Description: Chinese buyers, spooked by a protracted property slump and a recent stock-market rout, are rushing toward gold as economic uncertainty looms, propelling a global bullion rally.
 Published at: 2024-04-26 09:40:00+00 
 URL: https://www.wsj.com/articles/chinas-gold-consumption-rises-on-safe-haven-demand-60654fff

Title: Blinken’s Visit to China: What to Know
 Description: Secretary of State Antony J. Blinken is in China this week as tensions have risen over trade, security, Russia’s war on Ukraine and the Middle East crisis.
 Published at: 2024-04-25 04:46:13+00 
 URL: https://www.nytimes.com/2024/04/25/world/asia/blinken-china-united-states-thaw.html

Title: Blinken to warn China against helping Russia in Ukraine conflict
 Description: US Secretary of State Antony Blinken will visit Beijing next week and pressure Chinese leaders 

In [None]:
def generate_llm_with_rag(prompt, api_key=api_key, base_url=base_url, use_RAG=True, top_k=3):

    client = OpenAI(api_key=api_key, base_url=base_url)
    if use_RAG:
        context = context_pipeline(query, top_k=top_k)  # Get the context using the index_pipeline function
        prompt = gen_prompt(prompt, context=context)  # Use the sys_prompt function to format the prompt with RAG data
        response = client.chat.completions.create(
        model="gpt-4o-mini",  # or "gpt-4o-mini-preview" for the preview version
        temperature=0.7,
        max_tokens=500,
        top_p=1.0,
        messages=[{"role": "user", "content": prompt}]
    )
        return response.choices[0].message.content
    else:
        # If not using RAG, just return the prompt
        response = client.chat.completions.create(
        model="gpt-4o-mini",  # or "gpt-4o-mini-preview" for the preview version
        temperature=0.7,
        max_tokens=500,
        top_p=1.0,
        messages=[{"role": "user", "content": prompt}]
    )
        return response.choices[0].message.content

In [178]:
# Generate response using the LLM with RAG
 
user_input = input("You: ")
print("You:", user_input)
# Initialize messages with the user's input
response_generate_llm_with_rag = generate_llm_with_rag(user_input, api_key=api_key, base_url=base_url, use_RAG=True)
print("RAG Assistant:")
display(Markdown(response_generate_llm_with_rag))


You: tax


Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

  return forward_call(*args, **kwargs)


RAG Assistant:


I'm sorry, but the provided context does not contain any information related to taxes. If you have a specific question about taxes, feel free to ask!

In [None]:
# Generate response using the LLM without RAG
user_input = input("You: ")
print("You:", user_input)

response_generate_llm_with_rag_no_rag = generate_llm_with_rag(user_input, api_key=api_key, base_url=base_url, use_RAG=False)
print("Assistant (without RAG):")
display(Markdown(response_generate_llm_with_rag_no_rag))