# RAG Pipeline with Ray Data

[![Ray](https://img.shields.io/badge/Ray-Data-blue)](https://docs.ray.io/en/latest/data/data.html)
[![Python](https://img.shields.io/badge/Python-3.8+-green)](https://python.org)

Build a scalable **Retrieval-Augmented Generation (RAG)** pipeline using Ray Data for distributed processing.

## Overview

This notebook demonstrates how to build a production-ready RAG pipeline that:
- Scales across multiple nodes and GPUs
- Uses vector databases for semantic search
- Combines retrieval with LLM generation

## Learning Objectives

By the end of this notebook, you will be able to:

1. **Understand RAG Architecture**: Learn how Retrieval-Augmented Generation combines vector search with LLMs
2. **Build Scalable Pipelines**: Use Ray Data to create distributed data processing pipelines
3. **Implement Key Components**: Create embedding generators, vector database readers, and LLM inference stages
4. **Chain Operations**: Connect multiple processing stages using Ray Data's `map_batches` API

## What is RAG?

RAG is a technique that enhances Large Language Models (LLMs) by providing them with relevant context retrieved from a knowledge base:

```
┌────────────┐      ┌────────────┐      ┌────────────┐
│   User     │      │  Retrieved │      │    LLM     │
│  Question  │  +   │  Context   │  =   │  Response  │
└────────────┘      └────────────┘      └────────────┘
```

This approach helps:
- **Reduce hallucinations**: Ground responses in actual documents
- **Enable domain knowledge**: Answer questions about specific documents
- **Stay up-to-date**: Use current information not in training data

## Why Ray Data for RAG?

| Feature | Benefit |
|---------|---------|
| **Scalability** | Distribute processing across multiple nodes and GPUs |
| **Streaming** | Process data without loading everything into memory |
| **Actor Pools** | Efficiently manage stateful resources like ML models |
| **Resource Management** | Fine-grained control over CPU, GPU, and memory |

## Step 1: Import Required Libraries



Before we begin, let's import the necessary libraries:

| Library | Purpose |
|---------|---------|
| `sentence_transformers` | Generate text embeddings (convert text to numerical vectors) |
| `chromadb` | Vector database for storing and querying embeddings |
| `ray` | Distributed computing framework for scaling our pipeline |
| `transformers` | Hugging Face library for LLM inference |
| `numpy` | Numerical operations on arrays |

> **Key Concept**: Embeddings are numerical representations of text that capture semantic meaning. Similar texts will have similar embeddings (close in vector space).

In [1]:
from sentence_transformers import SentenceTransformer
import chromadb
import ray
from transformers import pipeline
import numpy as np
import uuid
import shutil

### RAG Pipeline Architecture

Our goal is to build a complete RAG pipeline using Ray Data. Here's the high-level architecture:

```
┌─────────────────────────────────────────────────────────────────────────────────┐
│                         RAG PIPELINE ARCHITECTURE                                │
└─────────────────────────────────────────────────────────────────────────────────┘

                    ┌──────────────────────────────────────────┐
                    │           DOCUMENT INGESTION              │
                    │            (One-time Setup)               │
                    └──────────────────────────────────────────┘
                                       │
    ┌──────────────┐     ┌──────────────┐     ┌──────────────┐
    │  Source Text │ ──▶ │ DocEmbedder  │ ──▶ │   ChromaDB   │
    │   (around.   │     │  (Generate   │     │   (Vector    │
    │     txt)     │     │  Embeddings) │     │   Database)  │
    └──────────────┘     └──────────────┘     └──────────────┘
                                                     │
                    ┌──────────────────────────────────────────┐
                    │           QUERY PIPELINE                  │
                    │         (Runtime Execution)               │
                    └──────────────────────────────────────────┘
                                       │
    ┌──────────────┐     ┌──────────────┐     ┌──────────────┐
    │    User      │ ──▶ │   Embedder   │ ──▶ │ ChromaDB     │
    │   Prompts    │     │  (Vectorize  │     │   Reader     │
    │  (Parquet)   │     │   Queries)   │     │ (Retrieve)   │
    └──────────────┘     └──────────────┘     └──────────────┘
                                                     │
                                                     ▼
                         ┌──────────────┐     ┌──────────────┐
                         │   Prompt     │ ──▶ │    Chat      │
                         │  Enhancer    │     │   (LLM)      │
                         │ (Augment)    │     │  Response    │
                         └──────────────┘     └──────────────┘
                                                     │
                                                     ▼
                                          ┌──────────────────┐
                                          │  Output (Parquet │
                                          │   or Memory)     │
                                          └──────────────────┘
```

### Pipeline Stages Overview

| Stage | Component | Description | Resource |
|-------|-----------|-------------|----------|
| 1 | `ray.data.read_parquet()` | Load user questions from storage | CPU |
| 2 | `Embedder` class | Convert questions to vectors | CPU |
| 3 | `ChromaDBReader` class | Query vector database for similar documents | CPU |
| 4 | `PromptEnhancer` class | Combine retrieved context with original question | CPU |
| 5 | `Chat` class | Generate responses using an LLM | **GPU** |
| 6 | `write_parquet()` | Save results to storage | CPU |


## Step 2: Configure Model Names

We define our model names as constants for easy configuration. This makes it simple to swap models later.

In [2]:

# Embedding model: converts text to dense vectors
EMBEDDER_MODEL = 'all-MiniLM-L6-v2'

# Chat/LLM model: generates responses from prompts
CHAT_MODEL = 'Qwen/Qwen2.5-0.5B-Instruct'

print(f"Embedding Model: {EMBEDDER_MODEL}")
print(f"Chat Model: {CHAT_MODEL}")

Embedding Model: all-MiniLM-L6-v2
Chat Model: Qwen/Qwen2.5-0.5B-Instruct


**Understanding the Models**:

- **`all-MiniLM-L6-v2`**: A lightweight embedding model (23M parameters) that converts text into 384-dimensional vectors. It's fast and efficient, making it ideal for demos and smaller deployments. Runs well on CPU.

- **`Qwen/Qwen2.5-0.5B-Instruct`**: A small but capable instruction-tuned LLM (500M parameters). We use a smaller model here for efficiency, but you can swap it for larger models in production.

**Note on Model Selection**:
- The embedding model must be the SAME for both document ingestion and query embedding
- Using mismatched models results in incompatible vector spaces and poor retrieval
- For production, consider larger models like `hkunlp/instructor-large` (with GPU) for better quality



## Step 3: Load the Input Data

We use `ray.data.read_parquet()` to load our prompts. Ray Data creates a **lazy dataset** that only reads data when needed.

In [3]:
!cp prompts.parquet /mnt/cluster_storage/prompts.parquet

In [4]:
data = ray.data.read_parquet('/mnt/cluster_storage/prompts.parquet')
data.take_batch(4)

2026-01-22 01:45:28,806	INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 10.0.9.248:6379...
2026-01-22 01:45:28,821	INFO worker.py:1998 -- Connected to Ray cluster. View the dashboard at [1m[32mhttps://session-v4klp1kjtnk9yrxwdcz5ah11ub.i.anyscaleuserdata.com [39m[22m
2026-01-22 01:45:28,851	INFO packaging.py:463 -- Pushing file package 'gcs://_ray_pkg_5a796cdbf12b68826eeb8c56dddb942d58f41f86.zip' (10.48MiB) to Ray cluster...
2026-01-22 01:45:28,895	INFO packaging.py:476 -- Successfully pushed file package 'gcs://_ray_pkg_5a796cdbf12b68826eeb8c56dddb942d58f41f86.zip'.
2026-01-22 01:45:29,330	INFO logging.py:397 -- Registered dataset logger for dataset dataset_183_0
2026-01-22 01:45:29,355	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_183_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:45:29,356	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_183_0: InputDataBuffer[Inp

{'prompt': array(['Describe the body of water in Utah?',
        'Tell as much as you can about the robbery?',
        'Did Phileas Fogg really rob the bank?',
        'Who is the main protagonist of Around the World in 80 Days?'],
       dtype=object)}

The output shows us sample prompts from our dataset. Notice these are questions about various topics - this is what users might ask in a RAG system.


## Step 4: Create the Embedding Generator



### What are Embeddings?

Embeddings are dense vector representations of text that capture semantic meaning:
- Similar texts produce similar vectors (close in vector space)
- We can measure similarity using cosine similarity or Euclidean distance
- This allows us to find relevant documents by comparing query embeddings to document embeddings


> **Why this pattern?** Loading a 100MB+ model for every batch would be extremely slow. By loading once in `__init__`, we amortize the cost across many batches.

In [5]:
# ============================================================================
# Embedder Class: Converting Text to Vector Embeddings
# ============================================================================

class Embedder:
    """
    Converts text prompts into dense vector embeddings using a sentence transformer.
    
    The embedding model is loaded once when the actor starts and reused for all batches.
    This is much more efficient than loading the model for each batch.
    """
    def __init__(self, model: str):
        self._device = 'cpu'
        self._model = SentenceTransformer(model, device=self._device)
        print(f"Embedder initialized with model '{model}' on device: {self._device}")
        
    def __call__(self, batch):
        # Generate embeddings for all prompts in the batch
        # The encode() method handles batching internally for efficiency
        batch['prompt_embedding'] = self._model.encode(
            batch['prompt'], 
            batch_size=32
        )
        return batch

### Running the Embedder with Ray Data

Now let's run our Embedder using `map_batches`. Key parameters explained:

| Parameter | Value | Explanation |
|-----------|-------|-------------|
| `fn_constructor_args` | `[EMBEDDER_MODEL]` | Arguments passed to `__init__` |
| `compute` | `ActorPoolStrategy(size=2)` | Use 2 persistent actors (not tasks) |
| `batch_size` | `4` | Process 4 prompts at a time |

### Why ActorPoolStrategy?

```
┌────────────────────────────────────────────────────────────────────────┐
│                 TASKS vs ACTORS COMPARISON                             │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  TASKS (stateless):                  ACTORS (stateful):                │
│  ┌─────┐ ┌─────┐ ┌─────┐            ┌─────────────────────┐            │
│  │Load │ │Load │ │Load │            │ Load Model ONCE     │            │
│  │Model│ │Model│ │Model│            │     ↓               │            │
│  │  ↓  │ │  ↓  │ │  ↓  │            │ Process Batch 1     │            │
│  │Batch│ │Batch│ │Batch│            │ Process Batch 2     │            │
│  │  1  │ │  2  │ │  3  │            │ Process Batch 3     │            │
│  └─────┘ └─────┘ └─────┘            │     ...             │            │
│     ↑        ↑       ↑              └─────────────────────┘            │
│  SLOW! Model loaded 3x              FAST! Model loaded 1x              │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘
```

- **Actors** are long-lived processes that maintain state (our loaded model)
- **Tasks** would load the model for each batch - very inefficient!
- `size=2` means we run 2 actors in parallel for throughput

In [6]:
embedder_output = data.map_batches(Embedder, 
                                   fn_constructor_args=[EMBEDDER_MODEL],
                                   compute=ray.data.ActorPoolStrategy(size=2),
                                    batch_size=4      # Process 4 prompts per batch
                                   ).take_batch(4)

2026-01-22 01:45:40,743	INFO logging.py:397 -- Registered dataset logger for dataset dataset_185_0
2026-01-22 01:45:40,751	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_185_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:45:40,752	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_185_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=4] -> ActorPoolMapOperator[MapBatches(Embedder)]
2026-01-22 01:45:40,943	INFO progress_bar.py:213 -- === Ray Data Progress {ListFiles} ===
2026-01-22 01:45:40,944	INFO progress_bar.py:215 -- ListFiles: Tasks: 1; Actors: 0; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 384.0MiB object store: Progress Completed 0 / ?
2026-01-22 01:45:40,945	INFO progress_bar.py:213 -- === Ray Data Progress {ReadFiles} ===
2026-01-22 01:45:40,946	INFO progress_bar.py:215 -- ReadFiles: Tasks: 0; Actors: 0; Queued blocks: 

[36m(MapWorker(MapBatches(Embedder)) pid=38215, ip=10.0.50.252)[0m Embedder initialized with model 'all-MiniLM-L6-v2' on device: cpu


The output shows our embeddings - 384-dimensional vectors (one for each prompt). These vectors capture the semantic meaning of our questions.

## Step 5: Set Up the Vector Database (ChromaDB)



### What is a Vector Database?

A vector database is specialized storage for embeddings that enables:
- **Fast similarity search**: Find similar vectors using approximate nearest neighbor (ANN) algorithms
- **Persistence**: Store embeddings across sessions
- **Scalability**: Handle millions or billions of vectors

### ChromaDB Concepts

| Concept | Description |
|---------|-------------|
| **Client** | Connection to the database (persistent or in-memory) |
| **Collection** | A group of embeddings (like a table in SQL) |
| **Documents** | The original text associated with each embedding |
| **Query** | Find similar documents by comparing embeddings |

> **Note**: In a real RAG system, you would pre-populate the vector database with your knowledge base. Here we create a fresh collection for demonstration.

In [7]:
# ============================================================================
# Initialize Vector Database with Fresh Collection
# ============================================================================

import chromadb
import shutil

# Remove existing vector store for a clean start (just for the DEMO)
shutil.rmtree("/mnt/cluster_storage/vector_store", ignore_errors=True)

# Create a persistent ChromaDB client
client = chromadb.PersistentClient(path="/mnt/cluster_storage/vector_store")
collection = client.create_collection("persistent_text_chunks")

print(f"Created collection 'persistent_text_chunks' at /mnt/cluster_storage/vector_store")

Created collection 'persistent_text_chunks' at /mnt/cluster_storage/vector_store


### Step 5.1: Populate the Vector Database with Document Embeddings

**CRITICAL**: Before the RAG pipeline can retrieve relevant documents, we must first populate the vector database with embeddings of our knowledge base documents.

```
┌─────────────────────────────────────────────────────────────────────────┐
│                    DOCUMENT INGESTION FLOW                               │
└─────────────────────────────────────────────────────────────────────────┘

    around.txt                  DocEmbedder                   ChromaDB
    ┌─────────┐                ┌───────────┐                ┌──────────┐
    │ Line 1  │ ──────────────▶│           │──────────────▶│ ID: abc  │
    │ Line 2  │     Filter     │ Generate  │    Upsert     │ Vec:[...]│
    │ Line 3  │   (len > 10)   │ Embeddings│               │ Doc:"..."│
    │   ...   │                │ + UUIDs   │               │          │
    │ Line N  │                │           │               │ ID: xyz  │
    └─────────┘                └───────────┘               │ Vec:[...]│
                                                           │ Doc:"..."│
                                                           └──────────┘
```

This step uses Ray Data to:
1. **Read** the source text file
2. **Generate embeddings** for each paragraph using the same embedding model
3. **Store** the embeddings in ChromaDB with unique IDs

**Why is this important?**
- The RAG pipeline queries the vector database to find similar documents
- Without pre-populated embeddings, there's nothing to retrieve!
- This is typically a one-time ingestion process done before serving queries

> **Production Note**: In production, you might run ingestion on a schedule to add new documents, use incremental updates, or monitor collection size and embedding quality.

In [8]:
# ============================================================================
# Document Embedder for Ingestion
# ============================================================================

class DocEmbedder:
    """
    Generates embeddings for document chunks to be stored in the vector database.
    """
    def __init__(self, model: str):
        self._model = SentenceTransformer(model, device='cpu')
        print(f"DocEmbedder initialized with model: {model} on CPU")
        
    def __call__(self, batch):
        # Generate embeddings for all text in the batch
        embeddings = self._model.encode(batch['text'], batch_size=32)
        
        # Generate unique IDs for each document (required by ChromaDB)
        ids = np.array([uuid.uuid1().hex for _ in batch['text']])
        
        # Return dict with columns needed for ChromaDB: doc, vec, id
        return {
            'doc': batch['text'],      # Original text for retrieval
            'vec': embeddings,          # Vector embeddings
            'id': ids                   # Unique identifiers
        }

# ============================================================================
# ChromaDB Writer Class
# ============================================================================

class ChromaDBWriter:
    """
    Writes document embeddings to ChromaDB in batches.
    
    Maintains a persistent connection to ChromaDB across batch operations.
    """
    def __init__(self, collection_name: str):
        # Connect to the persistent ChromaDB instance
        self._client = chromadb.PersistentClient(path="/mnt/cluster_storage/vector_store")
        self._collection = self._client.get_collection(collection_name)
        print(f"ChromaDBWriter connected to collection: {collection_name}")
        
    def __call__(self, batch):
        # Upsert embeddings (insert or update if ID exists)
        self._collection.upsert(
            embeddings=batch['vec'].tolist(),
            documents=batch['doc'].tolist(),
            ids=batch['id'].tolist()
        )
        # Return count of documents written (for monitoring)
        return {'docs_written': np.array([len(batch['id'])])}

# ============================================================================
# Run the Ingestion Pipeline
# ============================================================================
# This pipeline reads text, generates embeddings, and stores them in ChromaDB.
# Using CPU workers for the lightweight MiniLM model is efficient and stable.
# ============================================================================

# Read the source text file - each line becomes a "document chunk"
# Filter out empty lines to avoid storing useless embeddings
result = ray.data.read_text("/mnt/cluster_storage/around.txt") \
    .filter(lambda row: len(row['text'].strip()) > 10) \
    .map_batches(
        DocEmbedder, 
        fn_constructor_args=[EMBEDDER_MODEL],
        compute=ray.data.ActorPoolStrategy(size=4),
        batch_size=64      # Process 64 documents at a time
    ) \
    .map_batches(
        ChromaDBWriter,
        fn_constructor_args=['persistent_text_chunks'],
        compute=ray.data.ActorPoolStrategy(size=1),  
        batch_size=100
    ) \
    .sum('docs_written')

print(f"\nIngestion complete! Total documents indexed: {result}")

2026-01-22 01:45:50,229	INFO dataset.py:3641 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2026-01-22 01:45:50,234	INFO logging.py:397 -- Registered dataset logger for dataset dataset_191_0
2026-01-22 01:45:50,246	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_191_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:45:50,246	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_191_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Filter(<lambda>)] -> ActorPoolMapOperator[MapBatches(DocEmbedder)] -> ActorPoolMapOperator[MapBatches(ChromaDBWriter)] -> HashAggregateOperator[HashAggregate(key_columns=(), num_partitions=1)] -> LimitOperator[limit=1]
2026-01-22 01:45:50,526	INFO progress_bar.py:213 -- === Ray Data Progress {ListFiles} ===
2026-01-22 01:45:50,527	INFO progress_bar.py

[36m(MapWorker(MapBatches(ChromaDBWriter)) pid=38309, ip=10.0.50.252)[0m ChromaDBWriter connected to collection: persistent_text_chunks


2026-01-22 01:45:55,550	INFO progress_bar.py:215 -- ListFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 1 / 1
2026-01-22 01:45:55,551	INFO progress_bar.py:215 -- ReadFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 1654 / 1654
2026-01-22 01:45:55,552	INFO progress_bar.py:215 -- Filter(<lambda>): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 368.0KiB object store: Progress Completed 1614 / 1614
2026-01-22 01:45:55,553	INFO progress_bar.py:215 -- MapBatches(DocEmbedder): Tasks: 0; Actors: 4 (running=0, restarting=0, pending=4); Queued blocks: 1 (368.0KiB); Resources: 0.0 CPU, 0.0B object store; [all objects local]: Progress Completed 0 / ?
2026-01-22 01:45:55,554	INFO progress_bar.py:215 -- MapBatches(ChromaDBWriter): Tasks: 0; Actors: 1; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 0.0B object store; [all objects local]: Progress Completed 0 / 

[36m(MapWorker(MapBatches(DocEmbedder)) pid=34573, ip=10.0.45.231)[0m DocEmbedder initialized with model: all-MiniLM-L6-v2 on CPU


2026-01-22 01:46:00,565	INFO progress_bar.py:215 -- ListFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 1 / 1
2026-01-22 01:46:00,566	INFO progress_bar.py:215 -- ReadFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 1654 / 1654
2026-01-22 01:46:00,567	INFO progress_bar.py:215 -- Filter(<lambda>): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 368.0KiB object store: Progress Completed 1614 / 1614
2026-01-22 01:46:00,568	INFO progress_bar.py:215 -- MapBatches(DocEmbedder): Tasks: 1; Actors: 1; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 384.0MiB object store; [0/1 objects local]: Progress Completed 0 / ?
2026-01-22 01:46:00,569	INFO progress_bar.py:215 -- MapBatches(ChromaDBWriter): Tasks: 0; Actors: 1; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 0.0B object store; [all objects local]: Progress Completed 0 / ?
2026-01-22 01:46:00,570	INFO progre


Ingestion complete! Total documents indexed: 1614


In [9]:
# ============================================================================
# Verify the Ingestion
# ============================================================================
# Let's confirm the vector database was populated correctly by:
# 1. Checking the document count
# 2. Running a test query to see if retrieval works
# ============================================================================

# Connect to ChromaDB and check collection stats
verify_client = chromadb.PersistentClient(path="/mnt/cluster_storage/vector_store")
verify_collection = verify_client.get_collection("persistent_text_chunks")

doc_count = verify_collection.count()
print(f"Total documents in collection: {doc_count}")

# Run a test query to verify retrieval works
# First, create an embedding for a test question using the SAME model
test_model = SentenceTransformer(EMBEDDER_MODEL)
test_query = test_model.encode("What is the Great Salt Lake?").tolist()

# Query ChromaDB for similar documents
test_results = verify_collection.query(
    query_embeddings=[test_query],
    n_results=2
)

print(f"\nTest Query: 'What is the Great Salt Lake?'")
print(f"Retrieved {len(test_results['documents'][0])} documents:")
for i, doc in enumerate(test_results['documents'][0]):
    print(f"\n  Document {i+1}:")
    print(f"    {doc[:200]}..." if len(doc) > 200 else f"    {doc}")

Total documents in collection: 1614

Test Query: 'What is the Great Salt Lake?'
Retrieved 2 documents:

  Document 1:
    The Salt Lake, seventy miles long and thirty-five wide, is situated three miles eight hundred feet above the sea. Quite different from Lake Asphaltite, whose depression is twelve hundred feet below th...

  Document 2:
    The track up to this time had reached its highest elevation at the Great Salt Lake. From this point it described a long curve, descending towards Bitter Creek Valley, to rise again to the dividing rid...


### The ChromaDBReader Class

This class queries ChromaDB to find documents similar to our query embeddings. It follows the same pattern as `Embedder`:

- **`__init__`**: Connect to ChromaDB and get the collection (done once per actor)
- **`__call__`**: Query for similar documents for each batch of embeddings

The `top_n` parameter controls how many similar documents to retrieve for each query.

**Key Method**: `collection.query(query_embeddings=vecs, n_results=self._top_n)` 
- Takes embedding vectors and returns the `n` most similar documents
- Uses approximate nearest neighbor search for speed

In [10]:
# ============================================================================
# ChromaDBReader: Vector Database Retrieval
# ============================================================================

class ChromaDBReader:
    """
    Retrieves similar documents from ChromaDB based on query embeddings.
    
    The database connection is established once per actor and reused.
    """
    def __init__(self, collection: str, top_n: int):
        # Connect to the persistent ChromaDB instance
        chroma_client = chromadb.PersistentClient(path="/mnt/cluster_storage/vector_store")
        self._coll = chroma_client.get_collection(collection)
        self._top_n = top_n
    
    def __call__(self, batch):
        """
        Query ChromaDB for similar documents for each embedding in the batch.
        
        Args:
            batch: Dict with 'prompt_embedding' containing query vectors
            
        Returns:
            Dict with 'responsive_documents' containing retrieved documents
        """
        # Convert numpy array to list for ChromaDB API
        vecs = list(batch['prompt_embedding'])
        
        # Query ChromaDB for similar documents
        # Returns: {'ids': [...], 'documents': [...], 'distances': [...], ...}
        results = self._coll.query(
            query_embeddings=vecs, 
            n_results=self._top_n
        )
        
        # Add retrieved documents to the batch
        # results['documents'] is a list of lists (one list per query)
        batch['responsive_documents'] = results['documents']
        return batch

### Chaining Pipeline Stages

Now we chain `Embedder` and `ChromaDBReader` together using multiple `map_batches` calls:

```
┌─────────────────────────────────────────────────────────────────────────────────────────────────────┐
│                                    DATA FLOW THROUGH PIPELINE                                       │
└─────────────────────────────────────────────────────────────────────────────────────────────────────┘

    Input Batch                    After Embedder                      After ChromaDBReader
    ┌──────────────┐              ┌─────────────────────────┐         ┌──────────────────────────────┐
    │ {            │              │ {                       │         │ {                            │
    │   'prompt':  │  ────────▶   │   'prompt': [...],      │ ──────▶ │   'prompt': [...],           │
    │     [...]    │   Stage 1    │   'prompt_embedding':   │ Stage 2 │   'prompt_embedding':        │
    │ }            │              │     [[0.1, ...], ...]   │         │     [[0.1, ...], ...],       │
    └──────────────┘              │ }                       │         │   'retrieved_documents':     │
                                  └─────────────────────────┘         │     [["doc1", ...], ...]     │
                                                                      │ }                            │
                                                                      └──────────────────────────────┘
```
**How it works**:
1. Each batch flows through `Embedder`, which adds `prompt_embedding` to the batch
2. The enriched batch then flows to `ChromaDBReader`, which adds `responsive_documents`
3. Ray Data handles all the data transfer, parallelization, and backpressure automatically

In [11]:
# ============================================================================
# Demo: Embedder + ChromaDBReader (Two-Stage Pipeline)
# ============================================================================
# Now let's chain two stages together to see document retrieval in action.
#
# Data Flow:
# 1. Prompts -> Embedder: Converts text to vectors (CPU)
# 2. Vectors -> ChromaDBReader: Finds similar documents in the database (CPU)
#
# Note: This won't return useful documents if the database is empty!
# Make sure you ran the ingestion step (Step 5.1) first.
# ============================================================================

two_stage_output = data.map_batches(
    Embedder, 
    fn_constructor_args=[EMBEDDER_MODEL], 
    compute=ray.data.ActorPoolStrategy(size=2),
    batch_size=4
).map_batches(
    ChromaDBReader, 
    fn_constructor_args=['persistent_text_chunks', 3],  # Retrieve top 3 documents
    compute=ray.data.ActorPoolStrategy(size=2)
).take_batch(4)

# Display results
print("Two-stage pipeline output:")
print("-" * 60)
for i, (prompt, docs) in enumerate(zip(two_stage_output['prompt'], two_stage_output['responsive_documents'])):
    print(f"\nQuestion: {prompt}")
    print(f"Retrieved {len(docs)} documents:")
    for j, doc in enumerate(docs[:2]):  # Show first 2 docs
        print(f"  {j+1}. {doc[:100]}..." if len(doc) > 100 else f"  {j+1}. {doc}")

2026-01-22 01:47:24,032	INFO logging.py:397 -- Registered dataset logger for dataset dataset_194_0
2026-01-22 01:47:24,037	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_194_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:47:24,038	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_194_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=4] -> ActorPoolMapOperator[MapBatches(Embedder)] -> ActorPoolMapOperator[MapBatches(ChromaDBReader)]
2026-01-22 01:47:24,254	INFO progress_bar.py:213 -- === Ray Data Progress {ListFiles} ===
2026-01-22 01:47:24,256	INFO progress_bar.py:215 -- ListFiles: Tasks: 1; Actors: 0; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 384.0MiB object store: Progress Completed 0 / ?
2026-01-22 01:47:24,257	INFO progress_bar.py:213 -- === Ray Data Progress {ReadFiles} ===
2026-01-22 01:47:24,257	INFO progress_bar.py:21

[36m(MapWorker(MapBatches(Embedder)) pid=39005, ip=10.0.50.252)[0m Embedder initialized with model 'all-MiniLM-L6-v2' on device: cpu


2026-01-22 01:47:30,338	INFO streaming_executor.py:305 -- ✔️  Dataset dataset_194_0 execution finished in 6.30 seconds
INFO:openlineage.client.transport.composite:Stopping OpenLineage CompositeTransport emission after the first successful delivery because `continue_on_success=False`. Transport that emitted the event: <HttpTransport(name=first, kind=http, priority=1)>


Two-stage pipeline output:
------------------------------------------------------------

Question: Describe the body of water in Utah?
Retrieved 3 documents:
  1. The Salt Lake, seventy miles long and thirty-five wide, is situated three miles eight hundred feet a...
  2. During the lecture the train had been making good progress, and towards half-past twelve it reached ...

Question: Tell as much as you can about the robbery?
Retrieved 3 documents:
  1. “Listen. On the 28th of last September a robbery of fifty-five thousand pounds was committed at the ...
  2. “Well, Ralph,” said Thomas Flanagan, “what about that robbery?”

Question: Did Phileas Fogg really rob the bank?
Retrieved 3 documents:
  1. Phileas Fogg had won his wager of twenty thousand pounds!
  2. Phileas Fogg did not betray the least disappointment; but the situation was a grave one. It was not ...

Question: Who is the main protagonist of Around the World in 80 Days?
Retrieved 3 documents:
  1. Around the World in Eighty

**Note**: The `responsive_documents` field may be empty if no matching documents were found. In a production system, you would have pre-populated the vector database with relevant documents.

## Step 6: Enhance Prompts with Retrieved Context



### The "Augmentation" in RAG

This is where RAG gets its name - we **augment** the user's question with retrieved context:

```
┌─────────────────────────────────────────────────────────────────────────┐
│                    PROMPT AUGMENTATION                                   │
└─────────────────────────────────────────────────────────────────────────┘

    Original Question              Retrieved Documents          Enhanced Prompt
    ┌──────────────────┐          ┌──────────────────┐        ┌─────────────────────┐
    │ "What is the     │          │ "The Salt Lake,  │        │ System: You are a   │
    │  Great Salt      │    +     │  seventy miles   │   =    │  helpful assistant  │
    │  Lake?"          │          │  long..."        │        │                     │
    └──────────────────┘          │ "It reached the  │        │ User: Context:      │
                                  │  Great Salt Lake │        │ - "The Salt Lake..."│
                                  │  towards..."     │        │ - "It reached..."   │
                                  └──────────────────┘        │                     │
                                                              │ Question: What is   │
                                                              │ the Great Salt Lake?│
                                                              └─────────────────────┘
```

### Prompt Engineering Best Practices

The enhanced prompt includes:
- **System message**: Tells the LLM its role ("You are a helpful assistant")
- **Context**: The retrieved documents that may help answer the question
- **Instructions**: Tells the LLM to admit when it doesn't know something
- **User question**: The original question

> **Tip**: Experiment with different system prompts to see how they affect output quality.

In [12]:
# ============================================================================
# PromptEnhancer: The "Augmentation" in RAG
# ============================================================================
# This class combines:
# 1. The user's original question
# 2. Retrieved relevant documents from the vector database
# 3. System instructions for the LLM
#
# This creates a "context-aware" prompt that helps the LLM answer questions
# about specific documents it wasn't trained on.
# ============================================================================

class PromptEnhancer:
    """
    Enhances user prompts with retrieved document context for RAG.
    
    The enhanced prompt follows a chat format:
    - System message: Sets the LLM's behavior
    - User message: Contains retrieved context + original question
    """
    def __init__(self):
        # Base template for the user message
        # The {context} placeholder will be replaced with retrieved documents
        self._user_template = """You are a helpful assistant who can answer questions about a text based on your existing knowledge and documents supplied here.

                                When answering questions, use the following relevant excerpts from the text:
                                {context}

                                If you don't have information to answer a question, please say you don't know. Don't make up an answer.

                                Question: {question}"""
    
    def __call__(self, batch):
        original_prompts = batch['prompt']
        enhanced_prompts = []
        
        for ix, original_prompt in enumerate(original_prompts):
            # Get retrieved documents for this prompt
            docs = batch['responsive_documents'][ix]
            
            # Format the context from retrieved documents
            # Handle both list and numpy array cases safely
            if docs is not None and len(docs) > 0:
                context = "\n".join([f"- {doc}" for doc in docs])
            else:
                context = "No relevant documents found."
            
            # Build the enhanced prompt using safe string formatting 
            user_content = self._user_template.format(
                context=context,
                question=original_prompt
            )
            
            # Create chat-formatted message for the LLM
            enhanced_prompts.append([
                {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context."},
                {"role": "user", "content": user_content}
            ])

        batch['enhanced_prompt'] = enhanced_prompts
        return batch

### Running the Three-Stage Pipeline

Now we chain three stages together:
1. **Embedder** → Generates embeddings
2. **ChromaDBReader** → Retrieves relevant documents
3. **PromptEnhancer** → Creates augmented prompts

The output will include an `enhanced_prompt` field with the complete prompt ready for the LLM.

In [13]:
# ============================================================================
# Demo: Three-Stage Pipeline (Embedder + ChromaDBReader + PromptEnhancer)
# ============================================================================
# Let's see what the enhanced prompts look like before sending to the LLM.
# This shows the "augmentation" step - combining user questions with context.
# ============================================================================

three_stage_output = data.map_batches(
                            Embedder, 
                            fn_constructor_args=[EMBEDDER_MODEL], 
                            compute=ray.data.ActorPoolStrategy(size=2),
                            batch_size=4
                        ).map_batches(
                            ChromaDBReader, 
                            fn_constructor_args=['persistent_text_chunks', 3], 
                            compute=ray.data.ActorPoolStrategy(size=2)
                        ).map_batches(
                            PromptEnhancer, 
                            compute=ray.data.ActorPoolStrategy(size=2)
                        ).take_batch(3)

# Display the enhanced prompt structure
print("Three-stage pipeline output (showing enhanced prompt structure):")
print("=" * 60)
for i, (prompt, enhanced) in enumerate(zip(three_stage_output['prompt'][:2], 
                                            three_stage_output['enhanced_prompt'][:2])):
    print(f"\n--- Example {i+1} ---")
    print(f"Original Question: {prompt}")
    print(f"\nEnhanced Prompt Structure:")
    for msg in enhanced:
        print(f"  Role: {msg['role']}")
        content_preview = msg['content'][:200] + "..." if len(msg['content']) > 200 else msg['content']
        print(f"  Content: {content_preview}")
    print()

2026-01-22 01:47:30,476	INFO logging.py:397 -- Registered dataset logger for dataset dataset_198_0
2026-01-22 01:47:30,478	INFO limit_pushdown.py:140 -- Skipping push down of limit 3 through map MapBatches[MapBatches(Embedder)] because it requires 4 rows to produce stable outputs
2026-01-22 01:47:30,479	INFO limit_pushdown.py:140 -- Skipping push down of limit 3 through map MapBatches[MapBatches(Embedder)] because it requires 4 rows to produce stable outputs
2026-01-22 01:47:30,484	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_198_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:47:30,484	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_198_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> ActorPoolMapOperator[MapBatches(Embedder)] -> LimitOperator[limit=3] -> ActorPoolMapOperator[MapBatches(ChromaDBReader)] -> ActorPoolMapOperator[MapBatches(Promp

[36m(MapWorker(MapBatches(Embedder)) pid=35171, ip=10.0.45.231)[0m Embedder initialized with model 'all-MiniLM-L6-v2' on device: cpu


2026-01-22 01:47:37,099	INFO streaming_executor.py:305 -- ✔️  Dataset dataset_198_0 execution finished in 6.61 seconds
INFO:openlineage.client.transport.composite:Stopping OpenLineage CompositeTransport emission after the first successful delivery because `continue_on_success=False`. Transport that emitted the event: <HttpTransport(name=first, kind=http, priority=1)>


Three-stage pipeline output (showing enhanced prompt structure):

--- Example 1 ---
Original Question: Describe the body of water in Utah?

Enhanced Prompt Structure:
  Role: system
  Content: You are a helpful assistant that answers questions based on provided context.
  Role: user
  Content: You are a helpful assistant who can answer questions about a text based on your existing knowledge and documents supplied here.

                                When answering questions, use the follo...


--- Example 2 ---
Original Question: Tell as much as you can about the robbery?

Enhanced Prompt Structure:
  Role: system
  Content: You are a helpful assistant that answers questions based on provided context.
  Role: user
  Content: You are a helpful assistant who can answer questions about a text based on your existing knowledge and documents supplied here.

                                When answering questions, use the follo...



## Step 7: Add LLM Inference to the Pipeline



### The Chat Class

Now we add the final piece - the LLM that generates responses. The Chat class:

- Uses Hugging Face's `pipeline` API for text generation
- Loads the model to GPU for fast inference
- Processes batches of enhanced prompts

### Key Parameters

| Parameter | Value | Explanation |
|-----------|-------|-------------|
| `max_new_tokens` | `200` | Maximum length of generated response |
| `truncation` | `True` | Truncate long inputs to fit model context |
| `cache_dir` | `/mnt/local_storage` | Cache model weights locally |

> **Performance Note**: `max_new_tokens` affects both response quality and latency. More tokens = longer responses but slower generation.

In [14]:
# ============================================================================
# Chat Class: LLM Response Generation
# ============================================================================

class Chat:
    """
    Generates LLM responses for enhanced prompts using Hugging Face pipelines.
    
    The model is loaded once on initialization and reused for all batches.
    This avoids the significant overhead of loading a model for each batch.
    """
    def __init__(self, model: str):
        """
        Initialize the chat model.
        
        Args:
            model: Hugging Face model ID (e.g., 'Qwen/Qwen2.5-0.5B-Instruct')
        """
        # Create a text generation pipeline with the specified model
        # device='cuda:0' ensures we use the GPU for faster inference
        self.pipe = pipeline(
            "text-generation", 
            model=model, 
            device='cuda:0',
            # Cache model weights locally to avoid re-downloading
            model_kwargs={"cache_dir": "/mnt/local_storage"}
        )
        print(f"Chat model '{model}' loaded on GPU")
    
    def __call__(self, batch):
        """
        Generate responses for a batch of enhanced prompts.
        
        Args:
            batch: Dict with 'enhanced_prompt' containing chat-formatted messages
            
        Returns:
            Dict with 'responses' containing the LLM outputs
        """
        # Convert enhanced prompts from arrays to lists for the pipeline
        # Each prompt is a list of messages [{"role": "system", ...}, {"role": "user", ...}]
        enhanced_prompts = [[msg for msg in prompt] for prompt in batch['enhanced_prompt']]
        batch['responses'] = self.pipe(
            enhanced_prompts, 
            max_new_tokens=200,    # Max tokens to generate
            truncation=True        # Truncate long inputs
        )
        return batch

[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39275, ip=10.0.50.252)[0m Failed to convert column 'responsive_documents' into pyarrow array due to: Error converting data to Arrow: column: 'responsive_documents', shape: (3,), dtype: object, data: [array(['The Salt Lake, seventy miles long and thirty-five wide, is situated three miles eight hundred feet above the sea. Quite diffe...; falling back to serialize as pickled python objects
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39275, ip=10.0.50.252)[0m Traceback (most recent call last):
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39275, ip=10.0.50.252)[0m   File "/home/ray/anaconda3/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py", line 774, in from_numpy
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39275, ip=10.0.50.252)[0m     return cls._from_numpy(arr)
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39275, ip=10.0.50.252)[0m            ^^^^^^^^^^^^^^^^^^^^
[36m(MapWorker(MapBatches(PromptEnhanc

### Running the Complete RAG Pipeline

Now we run the full 4-stage pipeline:

```
Embedder → ChromaDBReader → PromptEnhancer → Chat
```

**Note on `concurrency` vs `ActorPoolStrategy`**: 
- `concurrency=4` is shorthand for `ActorPoolStrategy(size=4)` with autoscaling
- Use explicit `ActorPoolStrategy` when you need precise control over actor count

We store the output in a Python variable for inspection. In production, you would typically stream results or write to storage.

In [15]:
# ============================================================================
# Complete RAG Pipeline Execution
# ============================================================================
# This runs the full 4-stage pipeline and stores results in memory for inspection.
#
# Pipeline Flow:
# [Prompts] -> [Embedder] -> [ChromaDBReader] -> [PromptEnhancer] -> [Chat] -> [Output]
#
# Resource Allocation:
# - Embedder: CPU only (MiniLM is efficient on CPU)
# - ChromaDBReader: CPU only (vector DB queries)
# - PromptEnhancer: CPU only (string operations)
# - Chat: GPU required (LLM inference)
#
# IMPORTANT: We use compute=ray.data.ActorPoolStrategy() instead of the deprecated
# 'concurrency' parameter. This is the recommended approach in Ray 2.53.0+.
# ============================================================================

output = data \
    .map_batches(
        Embedder, 
        fn_constructor_args=[EMBEDDER_MODEL], 
        # CPU-based embedding with the lightweight MiniLM model
        compute=ray.data.ActorPoolStrategy(size=4),
        batch_size=4
    ) \
    .map_batches(
        ChromaDBReader, 
        fn_constructor_args=['persistent_text_chunks', 3],  # Collection name, top_n results
        compute=ray.data.ActorPoolStrategy(size=2)
    ) \
    .map_batches(
        PromptEnhancer, 
        compute=ray.data.ActorPoolStrategy(size=2)
    ) \
    .map_batches(
        Chat, 
        compute=ray.data.ActorPoolStrategy(size=1),  # Single GPU actor
        fn_constructor_args=[CHAT_MODEL], 
        num_gpus=1,         # LLM inference needs GPU
        batch_size=4        # Keep batch small to avoid OOM
    ) \
    .take_batch(23)         # Retrieve all 23 prompts for inspection

print(f"Pipeline complete! Retrieved {len(output['prompt'])} responses.")

2026-01-22 01:47:37,232	INFO logging.py:397 -- Registered dataset logger for dataset dataset_203_0
2026-01-22 01:47:37,238	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_203_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:47:37,238	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_203_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=23] -> ActorPoolMapOperator[MapBatches(Embedder)] -> ActorPoolMapOperator[MapBatches(ChromaDBReader)] -> ActorPoolMapOperator[MapBatches(PromptEnhancer)] -> ActorPoolMapOperator[MapBatches(Chat)]
2026-01-22 01:47:37,563	INFO progress_bar.py:213 -- === Ray Data Progress {ListFiles} ===
2026-01-22 01:47:37,564	INFO progress_bar.py:215 -- ListFiles: Tasks: 1; Actors: 0; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 384.0MiB object store: Progress Completed 0 / ?
2026-01-22 01:47:37,565	INFO progress_bar.

[36m(MapWorker(MapBatches(Embedder)) pid=35475, ip=10.0.45.231)[0m Embedder initialized with model 'all-MiniLM-L6-v2' on device: cpu


[36m(MapWorker(MapBatches(Chat)) pid=39506, ip=10.0.50.252)[0m Device set to use cuda:0
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39505, ip=10.0.50.252)[0m Failed to convert column 'responsive_documents' into pyarrow array due to: Error converting data to Arrow: column: 'responsive_documents', shape: (23,), dtype: object, data: [array(['The Salt Lake, seventy miles long and thirty-five wide, is situated three miles eight hundred feet above the sea. Quite diff...; falling back to serialize as pickled python objects
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39505, ip=10.0.50.252)[0m Traceback (most recent call last):
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39505, ip=10.0.50.252)[0m   File "/home/ray/anaconda3/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py", line 774, in from_numpy
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39505, ip=10.0.50.252)[0m     return cls._from_numpy(arr)
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=39505, i

[36m(MapWorker(MapBatches(Chat)) pid=39506, ip=10.0.50.252)[0m Chat model 'Qwen/Qwen2.5-0.5B-Instruct' loaded on GPU


2026-01-22 01:47:47,704	INFO progress_bar.py:215 -- ListFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 1 / 1
2026-01-22 01:47:47,705	INFO progress_bar.py:215 -- ReadFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 23 / 23
2026-01-22 01:47:47,705	INFO progress_bar.py:215 -- limit=23: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 23 / 23
2026-01-22 01:47:47,706	INFO progress_bar.py:215 -- MapBatches(Embedder): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store; [0/1 objects local]: Progress Completed 23 / 23
2026-01-22 01:47:47,707	INFO progress_bar.py:215 -- MapBatches(ChromaDBReader): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store; [0/1 objects local]: Progress Completed 23 / 23
2026-01-22 01:47:47,708	INFO progress_bar.py:215 -- MapBat

Pipeline complete! Retrieved 23 responses.


### Inspecting the Results

Let's create a helper function to visualize the LLM's responses. This extracts and prints:
- The user's original question
- The LLM's generated response

> **Quality Check**: Analyze the responses to see if they're accurate. Does the model make things up when it doesn't have information?

In [16]:
# ============================================================================
# Helper Function: Visualize RAG Results
# ============================================================================
# This function formats the LLM responses for easy reading.
# It extracts and displays:
# 1. The original user question
# 2. The LLM's generated answer
# ============================================================================

def print_visual_eval(batch, max_responses=5):
    """
    Pretty-print the RAG pipeline results for evaluation.
    
    Args:
        batch: Output batch containing 'prompt' and 'responses'
        max_responses: Maximum number of responses to display (default 5)
    """
    print("=" * 80)
    print("RAG PIPELINE RESULTS")
    print("=" * 80)
    
    for i, (prompt, response) in enumerate(zip(batch['prompt'][:max_responses], 
                                                batch['responses'][:max_responses])):
        print(f"\n{'─' * 80}")
        print(f"QUESTION {i+1}: {prompt}")
        print(f"{'─' * 80}")
        
        # Extract the generated answer from the response structure
        # The response format is: [{'generated_text': [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', 'content': ...}]}]
        try:
            generated_text = response[0]['generated_text']
            # Find the assistant's response (last message in the conversation)
            assistant_response = None
            for msg in generated_text:
                if isinstance(msg, dict) and msg.get('role') == 'assistant':
                    assistant_response = msg.get('content', 'No content')
            
            if assistant_response:
                print(f"ANSWER: {assistant_response}")
            else:
                # Fallback: show the last message content
                print(f"ANSWER: {generated_text[-1].get('content', str(generated_text[-1]))}")
        except (KeyError, IndexError, TypeError) as e:
            print(f"ANSWER: [Error parsing response: {e}]")
            print(f"Raw response: {response}")
    
    print(f"\n{'=' * 80}")
    print(f"Displayed {min(max_responses, len(batch['prompt']))} of {len(batch['prompt'])} results")
    print("=" * 80)

In [17]:
print_visual_eval(output)

RAG PIPELINE RESULTS

────────────────────────────────────────────────────────────────────────────────
QUESTION 1: Describe the body of water in Utah?
────────────────────────────────────────────────────────────────────────────────
ANSWER: The body of water described in the excerpt is called the Dead Sea.

────────────────────────────────────────────────────────────────────────────────
QUESTION 2: Tell as much as you can about the robbery?
────────────────────────────────────────────────────────────────────────────────
ANSWER: Based on the excerpt provided, I can tell you that:

- It occurred on the 28th of last September.
- The amount stolen was £55,000.
- The perpetrator's name was likely Phileas Fogg, but there was no indication of his identity or any connection to the story.
- The robbery took place at the Bank of England.
- It involved the theft of money.
- No evidence of professional involvement was found during the investigation.
- There were suspicions that the thief was unprof

## Step 8: Production Deployment - Writing to Storage


In production, you typically want to:
1. Process larger datasets
2. Write results to persistent storage (Parquet, databases, data lakes)
3. Run the pipeline as a batch job

### Writing Results to Parquet

Instead of `.take_batch()` (which loads results into memory), we use `.write_parquet()` to stream results directly to storage. This:
- Handles datasets larger than memory
- Provides durability (results won't be lost if something fails)
- Enables downstream processing

> **Other Options**: `write_json()`, `write_csv()`, or writing to databases are also available.

In [18]:
# ============================================================================
# Production Deployment: Write Results to Parquet
# ============================================================================
# In production, you typically want to persist results to storage rather than
# loading them into memory. This approach:
#
# 1. Handles datasets larger than memory (streaming)
# 2. Provides durability (results survive failures)
# 3. Enables downstream processing (other jobs can read the output)
#
# write_parquet() streams data directly to storage without materializing
# the entire dataset in memory.
# ============================================================================

import shutil

# Clean up previous output (for demo purposes)
shutil.rmtree('/mnt/cluster_storage/batch_output_1.parquet', ignore_errors=True)

print("Running complete RAG pipeline and writing results to Parquet...")

# Run the complete pipeline and write to Parquet
ray.data.read_parquet('/mnt/cluster_storage/prompts.parquet') \
    .map_batches(
        Embedder, 
        fn_constructor_args=[EMBEDDER_MODEL], 
        compute=ray.data.ActorPoolStrategy(size=4),
        batch_size=4
    ) \
    .map_batches(
        ChromaDBReader, 
        fn_constructor_args=['persistent_text_chunks', 3], 
        compute=ray.data.ActorPoolStrategy(size=2)
    ) \
    .map_batches(
        PromptEnhancer, 
        compute=ray.data.ActorPoolStrategy(size=2)
    ) \
    .map_batches(
        Chat, 
        compute=ray.data.ActorPoolStrategy(size=1), 
        fn_constructor_args=[CHAT_MODEL], 
        num_gpus=1, 
        batch_size=4
    ) \
    .write_parquet('/mnt/cluster_storage/batch_output_1.parquet')

print("\nPipeline complete! Results written to /mnt/cluster_storage/batch_output_1.parquet")

2026-01-22 01:48:31,197	INFO logging.py:397 -- Registered dataset logger for dataset dataset_210_0
2026-01-22 01:48:31,202	INFO streaming_executor.py:178 -- Starting execution of Dataset dataset_210_0. Full logs are in /tmp/ray/session_2026-01-21_22-31-19_329458_2360/logs/ray-data
2026-01-22 01:48:31,202	INFO streaming_executor.py:179 -- Execution plan of Dataset dataset_210_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> ActorPoolMapOperator[MapBatches(Embedder)] -> ActorPoolMapOperator[MapBatches(ChromaDBReader)] -> ActorPoolMapOperator[MapBatches(PromptEnhancer)] -> ActorPoolMapOperator[MapBatches(Chat)] -> TaskPoolMapOperator[Write]


Running complete RAG pipeline and writing results to Parquet...


2026-01-22 01:48:31,520	INFO progress_bar.py:213 -- === Ray Data Progress {ListFiles} ===
2026-01-22 01:48:31,521	INFO progress_bar.py:215 -- ListFiles: Tasks: 1; Actors: 0; Queued blocks: 0 (0.0B); Resources: 1.0 CPU, 384.0MiB object store: Progress Completed 0 / ?
2026-01-22 01:48:31,522	INFO progress_bar.py:213 -- === Ray Data Progress {ReadFiles} ===
2026-01-22 01:48:31,523	INFO progress_bar.py:215 -- ReadFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 0 / ?
2026-01-22 01:48:31,524	INFO progress_bar.py:213 -- === Ray Data Progress {MapBatches(Embedder)} ===
2026-01-22 01:48:31,525	INFO progress_bar.py:215 -- MapBatches(Embedder): Tasks: 0; Actors: 4 (running=0, restarting=0, pending=4); Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store; [all objects local]: Progress Completed 0 / ?
2026-01-22 01:48:31,526	INFO progress_bar.py:213 -- === Ray Data Progress {MapBatches(ChromaDBReader)} ===
2026-01-22 01:48:31,

[36m(MapWorker(MapBatches(Embedder)) pid=35886, ip=10.0.45.231)[0m Embedder initialized with model 'all-MiniLM-L6-v2' on device: cpu


[36m(MapWorker(MapBatches(Chat)) pid=40081, ip=10.0.50.252)[0m Device set to use cuda:0
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=40080, ip=10.0.50.252)[0m Failed to convert column 'responsive_documents' into pyarrow array due to: Error converting data to Arrow: column: 'responsive_documents', shape: (23,), dtype: object, data: [array(['The Salt Lake, seventy miles long and thirty-five wide, is situated three miles eight hundred feet above the sea. Quite diff...; falling back to serialize as pickled python objects
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=40080, ip=10.0.50.252)[0m Traceback (most recent call last):
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=40080, ip=10.0.50.252)[0m   File "/home/ray/anaconda3/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py", line 774, in from_numpy
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=40080, ip=10.0.50.252)[0m     return cls._from_numpy(arr)
[36m(MapWorker(MapBatches(PromptEnhancer)) pid=40080, i

[36m(MapWorker(MapBatches(Chat)) pid=40081, ip=10.0.50.252)[0m Chat model 'Qwen/Qwen2.5-0.5B-Instruct' loaded on GPU


2026-01-22 01:48:41,640	INFO progress_bar.py:215 -- ListFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 1 / 1
2026-01-22 01:48:41,640	INFO progress_bar.py:215 -- ReadFiles: Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store: Progress Completed 23 / 23
2026-01-22 01:48:41,641	INFO progress_bar.py:215 -- MapBatches(Embedder): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store; [0/1 objects local]: Progress Completed 23 / 23
2026-01-22 01:48:41,642	INFO progress_bar.py:215 -- MapBatches(ChromaDBReader): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 0.0B object store; [0/1 objects local]: Progress Completed 23 / 23
2026-01-22 01:48:41,642	INFO progress_bar.py:215 -- MapBatches(PromptEnhancer): Tasks: 0; Actors: 0; Queued blocks: 0 (0.0B); Resources: 0.0 CPU, 84.8KiB object store; [0/1 objects local]: Progress Completed 23 / 23
2026-01-22 01:48


Pipeline complete! Results written to /mnt/cluster_storage/batch_output_1.parquet


## Summary and Key Takeaways


### What We Built

We implemented a complete RAG (Retrieval-Augmented Generation) pipeline using Ray Data:

```
┌─────────────────────────────────────────────────────────────────────────────────┐
│                         COMPLETE RAG PIPELINE                                   │
└─────────────────────────────────────────────────────────────────────────────────┘

  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐
  │ Prompts  │───▶│ Embedder │───▶│ ChromaDB │───▶│  Prompt  │───▶│   Chat   │
  │ (Parquet)│    │  (CPU)   │    │  Reader  │    │ Enhancer │    │  (GPU)   │
  └──────────┘    └──────────┘    │  (CPU)   │    │  (CPU)   │    └──────────┘
                                  └──────────┘    └──────────┘         │
                                                                       ▼
                                                                 ┌──────────┐
                                                                 │  Output  │
                                                                 │ (Parquet)│
                                                                 └──────────┘
```

| Stage | Class | Purpose | Resource |
|-------|-------|---------|----------|
| 0 | `DocEmbedder` + `ChromaDBWriter` | Ingest documents into vector DB | CPU |
| 1 | `Embedder` | Convert user queries to vector embeddings | CPU |
| 2 | `ChromaDBReader` | Retrieve similar documents from vector DB | CPU |
| 3 | `PromptEnhancer` | Combine query with retrieved context | CPU |
| 4 | `Chat` | Generate responses using an LLM | **GPU** |

### Key Ray Data Concepts

```
┌─────────────────────────────────────────────────────────────────────────┐
│                    KEY CONCEPTS SUMMARY                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  1. map_batches()          → Apply transformations to data batches      │
│  2. Callable Classes       → __init__ for setup, __call__ for batches   │
│  3. ActorPoolStrategy      → Keep workers alive across batches          │
│  4. Resource Allocation    → num_gpus for GPU, defaults to CPU          │
│  5. Chaining               → .map_batches(...).map_batches(...)         │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
```

### API Note (Ray 2.53.0+)

- **DEPRECATED**: `concurrency=N` parameter
- **USE INSTEAD**: `compute=ray.data.ActorPoolStrategy(size=N)`

### Best Practices

| Practice | Why It Matters |
|----------|----------------|
| Load models in `__init__` | Avoids loading model for every batch |
| Use same embedding model for ingestion and queries | Ensures vector space compatibility |
| Use `ActorPoolStrategy` for stateful operations | Actors persist between batches |
| Use fractional GPUs when appropriate | Allows multiple models to share GPU |
| Write to storage with `write_parquet()` | Handles datasets larger than memory |
| Safe string formatting (not `eval()`) | Security and reliability |

### Common Issues and Solutions

| Issue | Solution |
|-------|----------|
| CUDA out of memory | Reduce batch_size, use fractional GPUs, or smaller models |
| Empty retrieval results | Check ingestion used same embedding model |
| Slow embedding | Use GPU for larger models, or more CPU actors |
| `concurrency` deprecation warning | Use `compute=ray.data.ActorPoolStrategy()` |

### Next Steps

- **Experiment with models**: Try `instructor-large` (with GPU) for better quality
- **Improve chunking**: Use smarter document chunking (semantic boundaries)
- **Add metadata**: Store chapter numbers, page references in ChromaDB
- **Implement caching**: Cache embeddings for frequently asked questions
- **Add evaluation**: Implement automated response quality evaluation

### References

- [Ray Data Documentation](https://docs.ray.io/en/latest/data/batch_inference.html)
- [Ray Data map_batches API](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html)
- [ChromaDB Documentation](https://docs.trychroma.com/)
- [Sentence Transformers](https://www.sbert.net/)