# RAG Tutorial: Loading MediaTech's datasets from Hugging Face and using them in a RAG pipeline

This notebook demonstrates how to build a **Retrieval-Augmented Generation (RAG)** pipeline using:
- **Hugging Face Datasets**: Load pre-embedded legal documents from the LEGI dataset
- **Qdrant**: Vector database for efficient similarity search (dense + sparse vectors)
- **OpenAI-compatible API**: OpenGateLLM, the French government's open-source LLM API, used here for embeddings and inference

## Prerequisites
- A running Qdrant instance (default: `localhost:6333`)
- An API key for the OpenAI-compatible API (here, https://albert.api.etalab.gouv.fr)
- Required packages: `fastembed`, `qdrant-client`, `openai`, `datasets`, `pandas`, `tqdm`

---

## 1. Configuration

In [None]:
# Install required packages (uncomment the next line on first run)
# %pip install pandas fastembed qdrant-client openai datasets tqdm

In [None]:
from fastembed import SparseTextEmbedding
from qdrant_client import QdrantClient

# === API Configuration ===
# OpenGateLLM is a French government OpenAI-compatible API for LLMs and embeddings
API_KEY = "changeme"  # Replace with your actual OpenGateLLM key
API_URL = "https://albert.api.etalab.gouv.fr/v1"

# === Vector Database Configuration ===
# Connect to local Qdrant instance for storing and searching vectors
client = QdrantClient(url="http://localhost", port=6333)

# === Sparse Embedding Model ===
# BM25 model for keyword-based sparse embeddings (used in hybrid search)
bm25_embedding_model = SparseTextEmbedding("Qdrant/bm25")

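Hardcoding the key works for a quick demo, but it is easy to leak when the notebook is shared. The following is a minimal sketch that reads the key from an environment variable instead. The variable name `OPENGATELLM_API_KEY` and the helper `load_api_key` are assumptions made for this example, not an OpenGateLLM convention.

```python
import os


def load_api_key(env_var: str = "OPENGATELLM_API_KEY", default: str = "changeme") -> str:
    """Return the key stored in `env_var`, or `default` when it is unset or empty.

    The variable name is a hypothetical choice for this sketch.
    """
    return os.environ.get(env_var) or default


API_KEY = load_api_key()
if API_KEY == "changeme":
    # The placeholder will be rejected by the API, so warn early
    print("Warning: no API key found in the environment, using the placeholder value.")
```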
---

## 2. Core Functions

This section defines the main functions for:
1. **Embedding generation**: Convert text to dense vectors using BGE-M3
2. **Inference**: Generate responses using an LLM with streaming output
3. **Retrieval**: Search the vector database using hybrid search (dense + sparse)
4. **Prompt construction**: Build RAG prompts with retrieved context

### 2.1 Embedding Generation

Generate dense embeddings using the BGE-M3 model via the OpenGateLLM API. This model is multilingual and works well for French legal text.

In [None]:
from openai import OpenAI


def generate_embeddings(
    data: str | list[str], model: str = "BAAI/bge-m3"
) -> list[list[float]]:
    """
    Generates dense embeddings for one or several texts using the specified model.

    Args:
        data (str or list[str]): The text, or list of texts, to embed.
        model (str, optional): The model identifier to use for generating embeddings. Defaults to "BAAI/bge-m3".

    Returns:
        list[list[float]]: One embedding vector per input text.
            A single string still yields a list containing one vector.

    Raises:
        openai.OpenAIError: If the request to the embeddings endpoint fails.

    Note:
        Requires properly configured API_URL and API_KEY for the OpenAI client.
    """
    client_openai = OpenAI(base_url=API_URL, api_key=API_KEY)
    response = client_openai.embeddings.create(
        input=data, model=model, encoding_format="float"
    )
    # The API returns one item per input text
    return [item.embedding for item in response.data]

### 2.2 LLM Inference

Stream responses from the Mistral model via OpenGateLLM. Streaming improves the user experience by displaying tokens as they are generated.

In [None]:
def inference(
    chat_messages: list[dict],
    model: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",  # Change to your preferred model
    return_output: bool = False,
    print_inference: bool = True,
    print_prompt: bool = False,
    max_tokens: int = 2000,
) -> str | None:
    """
    Performs inference using a chat-based model with streaming output.
    Args:
        chat_messages (list[dict]): The chat messages to send to the model.
        model (str, optional): The model name to use for inference. Defaults to "mistralai/Mistral-Small-3.2-24B-Instruct-2506".
        return_output (bool, optional): Whether to return the full output as a string. Defaults to False.
        print_inference (bool, optional): Whether to print the inference output in real-time. Defaults to True.
        print_prompt (bool, optional): Whether to print the prompt messages. Defaults to False.
        max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 2000.
    Returns:
        str | None: The full generated text if return_output is True, otherwise None.
    """
    # Named client_openai to avoid confusion with the global Qdrant `client`
    client_openai = OpenAI(
        api_key=API_KEY,
        base_url=API_URL,
    )

    if print_prompt:
        print(chat_messages)

    # Request a streamed chat completion
    chat_response = client_openai.chat.completions.create(
        model=model,  # Must match a model name deployed on the API server
        stream=True,
        # top_p=0.9,
        temperature=0.1,
        max_tokens=max_tokens,
        messages=chat_messages,
    )
    output = ""
    for chunk in chat_response:
        # Skip chunks that carry no choices or no text (e.g. role-only or final chunks)
        if not chunk.choices or not chunk.choices[0].delta.content:
            continue
        token = chunk.choices[0].delta.content
        output += token
        if print_inference:
            print(token, flush=True, end="")

    if return_output:
        return output


# Example usage: the answer is streamed to stdout, so no outer print() is needed
inference(chat_messages=[{"role": "user", "content": "Salut ca va ?"}])

### 2.3 Hybrid Retrieval

Retrieve relevant documents using **Reciprocal Rank Fusion (RRF)** to combine:
- **Dense search**: Semantic similarity using BGE-M3 embeddings
- **Sparse search**: Keyword matching using BM25

This hybrid approach improves retrieval quality by combining semantic similarity with exact keyword matching.

In [None]:
from qdrant_client import models


def retrieval(
    query: str,
    collection_name: str = "legi_code_travail",
    hybrid_search: bool = True,
    limit: int = 10,
):
    """
    Retrieves relevant documents from a Qdrant collection based on a query.
    Args:
        query (str): The search query.
        collection_name (str, optional): The name of the Qdrant collection to search. Defaults to "legi_code_travail".
        hybrid_search (bool, optional): Whether to use hybrid search (embedding + sparse). Defaults to True.
        limit (int, optional): The maximum number of results to return. Defaults to 10.
    """
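    # generate_embeddings returns one vector per input, so take the first
    embedding = generate_embeddings(query)[0]

    if hybrid_search:
        # The sparse BM25 query vector is only needed for hybrid search
        sparse_query_vector = next(bm25_embedding_model.query_embed(query))
        # Prefetch dense and sparse candidates, then fuse both rankings with RRF
        search_results = client.query_points(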
            collection_name=collection_name,
            prefetch=[
                models.Prefetch(
                    query=embedding,
                    using="BAAI/bge-m3",
                    limit=2*limit,
                ),
                models.Prefetch(
                    query=models.SparseVector(**sparse_query_vector.as_object()),
                    using="bm25",
                    limit=2*limit,
                ),
            ],
            with_payload=True,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=limit,
        )
    else:
        # Dense-only search on the BGE-M3 vectors
        search_results = client.query_points(
            collection_name=collection_name,
            query=embedding,
            using="BAAI/bge-m3",
            limit=limit,
            with_payload=True,
        )

    # Collect the payload and score of every hit.
    # Always return a list so callers such as make_prompt never receive None.
    results = [
        {"payload": point.payload, "score": point.score}
        for point in search_results.points
    ]
    if not results:
        print("No results found")
    return results


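To make the fusion step in `retrieval` more concrete, here is a minimal pure-Python sketch of Reciprocal Rank Fusion. It assumes the textbook formula, where each document scores `1 / (k + rank)` in every list it appears in, with `k = 60` and ranks starting at 1. Qdrant's internal constant and rank offset may differ, and the document IDs are made up for illustration.

```python
from collections import defaultdict


def reciprocal_rank_fusion(
    rankings: list[list[str]], k: int = 60
) -> list[tuple[str, float]]:
    """Fuse several ranked lists of document IDs into one ranking.

    Each document receives 1 / (k + rank) per list it appears in (rank starts at 1).
    """
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Hypothetical result lists from the dense and sparse searches
dense_ranking = ["art_A", "art_B", "art_C"]
sparse_ranking = ["art_B", "art_D", "art_A"]

fused = reciprocal_rank_fusion([dense_ranking, sparse_ranking])
print([doc_id for doc_id, _ in fused])  # ['art_B', 'art_A', 'art_D', 'art_C']
```

Documents found by both searches (`art_A`, `art_B`) end up ahead of those found by only one, which is why the prefetch limits in `retrieval` are larger than the final `limit`.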
### 2.4 RAG Prompt Construction

Build a prompt that includes retrieved documents as context. The LLM will use these documents to generate an informed response.

In [None]:
def make_prompt(
    query: str,
    system_prompt: str = "Tu es un assistant IA utile qui répond aux questions des utilisateurs en utilisant des documents pertinents fournis.",
    hybrid_search: bool = True,
    collection_name: str = "legi_code_travail",
    limit: int = 5,
):
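    # Pass `limit` through so retrieval returns exactly the number of chunks requested
    # (previously retrieval used its own default of 10, which capped larger limits)
    results = retrieval(
        query=query,
        collection_name=collection_name,
        hybrid_search=hybrid_search,
        limit=limit,
    )
    top_chunks = [result["payload"] for result in results][:limit]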
    chat_messages = [{"role": "system", "content": system_prompt}]

    prompt = f"""
    Voici ci dessous les documents pertinents pour répondre à la question suivante : {query}\n
    """
    for chunk in top_chunks:
        prompt += f"""
        <<< {chunk.get("chunk_text", "")} >>>
        """

    chat_messages.append({"role": "user", "content": prompt})

    return chat_messages, top_chunks

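The triple-quoted f-strings in `make_prompt` carry their source indentation into the prompt text. As a hedged alternative, the sketch below assembles the user message line by line so that no stray whitespace reaches the model. The helper name `build_user_prompt` and the sample chunks are hypothetical and only serve the illustration.

```python
def build_user_prompt(query: str, chunks: list[dict]) -> str:
    """Assemble the user message: the question first, then one delimited block per chunk."""
    lines = [
        f"Voici ci-dessous les documents pertinents pour répondre à la question suivante : {query}",
        "",
    ]
    for chunk in chunks:
        # A missing text falls back to an empty string, as in make_prompt
        lines.append(f"<<< {chunk.get('chunk_text', '')} >>>")
    return "\n".join(lines)


# Placeholder chunks, not real dataset content
sample_chunks = [
    {"chunk_text": "Premier extrait."},
    {"chunk_text": "Second extrait."},
]
print(build_user_prompt("Question de test ?", sample_chunks))
```

The `<<< ... >>>` delimiters are kept from `make_prompt`, so the model still sees clear boundaries between documents.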
### 2.5 Utility Functions

A helper that derives a deterministic UUID from each chunk ID, so that the same chunk always maps to the same Qdrant point ID.

In [None]:
import hashlib
import uuid


def string_to_uuid(s: str) -> str:
    """Maps a string to a deterministic UUID built from the first 16 bytes of its SHA-256 hash."""
    hash_bytes = hashlib.sha256(str(s).encode()).digest()[:16]
    return str(uuid.UUID(bytes=hash_bytes))

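Qdrant point IDs must be unsigned integers or UUIDs, which is why string chunk IDs are converted. Because the mapping is deterministic, re-running the upload overwrites existing points instead of duplicating them. The short check below repeats the helper so that it runs on its own; the chunk IDs are hypothetical.

```python
import hashlib
import uuid


def string_to_uuid(s: str) -> str:
    # Same logic as the helper above: first 16 bytes of the SHA-256 digest
    hash_bytes = hashlib.sha256(str(s).encode()).digest()[:16]
    return str(uuid.UUID(bytes=hash_bytes))


first = string_to_uuid("example_chunk_0")   # hypothetical chunk ID
second = string_to_uuid("example_chunk_0")
other = string_to_uuid("example_chunk_1")

print(first == second)  # True: the same input always gives the same UUID
print(first == other)   # False: different inputs give different digests
```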
---

## 3. Loading Dataset and Creating Vector Database

We load the **French Labor Code** ("Code du Travail") from the [AgentPublic/legi](https://huggingface.co/datasets/AgentPublic/legi) dataset on Hugging Face. This dataset contains:
- Pre-chunked legal articles
- Pre-computed BGE-M3 embeddings
- Metadata (status, article ID, etc.)

We filter to keep only articles that are currently in force (`VIGUEUR`) or will be repealed in the future (`ABROGE_DIFF`).

In [None]:
import json

import pandas as pd
from datasets import load_dataset

# Load the French Labor Code subset from the LEGI dataset
# The dataset is available at: https://huggingface.co/datasets/AgentPublic/legi
dataset = load_dataset(
    "AgentPublic/legi", data_files="data/legi-latest/legi_code_du_travail/*.parquet"
)

df = pd.DataFrame(dataset["train"])
print(f"Total chunks loaded: {len(df)}")

# Filter to keep only chunks from valid articles:
# - VIGUEUR: Currently in force
# - ABROGE_DIFF: Will be repealed at a future date (still valid now)
# .copy() avoids pandas' SettingWithCopyWarning when a column is reassigned below
df = df[df["status"].isin(["VIGUEUR", "ABROGE_DIFF"])].copy()
print(f"Chunks after filtering: {len(df)}")

# Parse the pre-computed embeddings from JSON strings to lists
df["embeddings_bge-m3"] = df["embeddings_bge-m3"].apply(json.loads)

# Preview the dataset structure
df.head()

### Create Qdrant Collection with Hybrid Vectors

We create a Qdrant collection with two vector types:
1. **Dense vectors** (`BAAI/bge-m3`): Pre-computed semantic embeddings from the dataset
2. **Sparse vectors** (`bm25`): Computed on-the-fly using the BM25 model with IDF weighting

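The two vector types are scored differently. The toy sketch below, which uses made-up values rather than real embeddings, contrasts cosine similarity on dense vectors with a dot product over the shared indices of sparse vectors. It is a simplification: the IDF modifier configured in the collection adds a corpus-dependent weight that this sketch leaves out.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two dense vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def sparse_dot(query: dict[int, float], doc: dict[int, float]) -> float:
    """Dot product restricted to the token indices present in both sparse vectors."""
    return sum(weight * doc[idx] for idx, weight in query.items() if idx in doc)


# Dense vectors: every dimension is stored
dense_query, dense_doc = [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]
print(cosine_similarity(dense_query, dense_doc))  # 0.5 up to float rounding

# Sparse vectors: only non-zero token indices are stored (index -> weight)
sparse_query = {3: 1.0, 17: 1.0}
sparse_doc = {3: 2.0, 42: 0.5}
print(sparse_dot(sparse_query, sparse_doc))  # 2.0, only index 3 is shared
```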
In [None]:
from qdrant_client import models
from qdrant_client.models import PointStruct
from tqdm import tqdm

collection_name = "legi_code_travail"
embedding_dim = len(df["embeddings_bge-m3"].iloc[0])

# Create collection if it doesn't exist
if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            # Dense vector configuration for semantic search
            "BAAI/bge-m3": models.VectorParams(
                size=embedding_dim, distance=models.Distance.COSINE
            )
        },
        sparse_vectors_config={
            # Sparse vector configuration for BM25 keyword search
            "bm25": models.SparseVectorParams(modifier=models.Modifier.IDF)
        },
    )
    print(f"Created new collection: {collection_name}")
else:
    print(f"Collection '{collection_name}' already exists")

# Prepare points with both dense and sparse vectors
print("Preparing points with embeddings...")
points = []
for _, row in tqdm(df.iterrows(), total=len(df), desc="Computing BM25 embeddings"):
    # Compute BM25 sparse embeddings for hybrid search
    bm25_embeddings = list(bm25_embedding_model.passage_embed(row["chunk_text"]))
    
    points.append(
        PointStruct(
            id=string_to_uuid(row["chunk_id"]),
            vector={
                "BAAI/bge-m3": row["embeddings_bge-m3"],  # Dense vector
                "bm25": bm25_embeddings[0].as_object(),   # Sparse vector
            },
            payload={
                "chunk_text": row["chunk_text"],
                # Include all metadata columns except embeddings
                **{
                    col: row[col]
                    for col in df.columns
                    if col not in ["embeddings_bge-m3", "chunk_text"]
                },
            },
        )
    )

# Upsert points in batches for efficiency
batch_size = 100
print("Upserting points to Qdrant...")
for i in tqdm(range(0, len(points), batch_size), desc="Uploading batches"):
    client.upsert(collection_name=collection_name, points=points[i : i + batch_size])

print(f"\nCollection '{collection_name}' ready with {len(points)} vectors (dimension: {embedding_dim})")

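The upload loop slices a list that already holds every point in memory. For larger subsets of the dataset, batches could instead be drawn lazily from a generator. The sketch below shows one way to do that with a hypothetical `batched` helper (Python 3.12 ships a similar `itertools.batched` that yields tuples). Each yielded batch could then be passed to `client.upsert` exactly as in the loop above.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    """Yield consecutive lists of at most `batch_size` items from any iterable."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    iterator = iter(items)
    # islice consumes the iterator step by step, so nothing is materialized up front
    while batch := list(islice(iterator, batch_size)):
        yield batch


# 250 items in batches of 100: two full batches and one partial batch
print([len(batch) for batch in batched(range(250), 100)])  # [100, 100, 50]
```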
---

## 4. Testing the RAG Pipeline

Now let's test our RAG system with a question about French labor law. The pipeline will:
1. **Retrieve** relevant legal articles using hybrid search
2. **Augment** the prompt with the retrieved context
3. **Generate** an informed response using the LLM

In [None]:
# Define the system prompt for the legal assistant
system_prompt = """Tu es un assistant IA utile et expert dans le domaine juridique qui répond aux questions des utilisateurs en utilisant des documents pertinents fournis.
Si tu ne sais pas, réponds que tu ne sais pas. 
"""

# Ask a question about French labor law
question = "Quelle est la durée journalière légale du travail en France ?"

# Build the RAG prompt with retrieved context
chat_messages, top_chunks = make_prompt(
    query=question,
    system_prompt=system_prompt,
    collection_name="legi_code_travail",
    hybrid_search=True,  # Use both dense and sparse search
    limit=7,             # Retrieve top 7 documents
)

# Display retrieved documents (optional - uncomment to see sources)
# print(f"Retrieved {len(top_chunks)} relevant documents\n")
# for k, chunk in enumerate(top_chunks):
#     print(f"---- Document {k+1} ----")
#     print(chunk.get("chunk_text")[:300] + "..." if len(chunk.get("chunk_text", "")) > 300 else chunk.get("chunk_text"))
#     print()

# Generate response using the LLM
print("=" * 50)
print("Response:")
print("=" * 50)
inference(
    chat_messages=chat_messages,
    model="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
    return_output=False,
    print_inference=True,
    print_prompt=False,
    max_tokens=2000,
)