# RAG Pipeline Demo

This Jupyter Notebook implements a Retrieval-Augmented Generation (RAG) pipeline for a take-home interview project. It answers queries about leadership, products, and specifications using two PDF datasets: `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` and `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf`.

## Objective
- Combine retrieval and generative AI to provide accurate answers from PDFs.
- Use lightweight models (`lightonai/GTE-ModernColBERT-v1`, `google/flan-t5-base`) for CPU compatibility.
- Provide an interactive query interface with readable output.
- Use few-shot prompting with fictional examples, so the prompt demonstrates the answer format without leaking content from either dataset.
- Include detailed code comments for improved readability.

## Architecture
- **Knowledge Base**: PDFs are loaded with `PyPDFLoader`, split into chunks (300 chars, 50-char overlap) using `RecursiveCharacterTextSplitter`, and stored with source tracking for company filtering.
- **Semantic Layer**: Chunks and queries are embedded with `lightonai/GTE-ModernColBERT-v1`.
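- **Retrieval**: `retrieve.ColBERT` fetches the 15 best-matching chunks. If the query names a company, they are filtered by source; `rank.rerank` then reorders them and the top 3 are kept.
- **Augmentation**: The top 3 chunks are joined, truncated to 600 characters, and combined with the query via a few-shot `PromptTemplate`.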
- **Generation**: `google/flan-t5-base` generates answers based on the prompt.
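
To make the chunking parameters concrete, here is a minimal sketch of fixed-size windows with 300 characters and a 50-character overlap. It is a simplification: `RecursiveCharacterTextSplitter` prefers to break on separators such as paragraphs and sentences, so its chunks are often shorter than 300 characters. The helper name `chunk_text` is hypothetical and not part of the notebook.

```python
def chunk_text(text: str, chunk_size: int = 300, chunk_overlap: int = 50) -> list[str]:
    """Split text into fixed-size windows that overlap by chunk_overlap characters.

    Hypothetical helper for illustration; the notebook itself uses
    RecursiveCharacterTextSplitter, which also respects separators.
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # each window starts 250 chars after the previous one
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


# A 700-character sample yields three windows; neighbours share 50 characters.
sample = "".join(str(i % 10) for i in range(700))
chunks = chunk_text(sample)
print([len(c) for c in chunks])
```

The overlap means a fact that straddles a chunk boundary still appears whole in at least one chunk, at the cost of indexing some text twice.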

## Updates
- Added detailed English comments in code cells for better readability.
- Fixed JSON parsing error in `prompt_template` for Cursor compatibility.
- Implemented Few-Shot Prompting with fictional examples to avoid dataset bias.
- Added robust company filtering with fallback for NeoCompute.
- Enhanced output formatting with clear sections, suppressed progress bars, and concise progress messages.
- Fixed errors for NeoCompute indexing and query handling.
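
The company filter with fallback mentioned above can be sketched in isolation as follows. This mirrors the logic in `query_rag` under simplified assumptions; the function name `filter_by_company` and the `companies` parameter are hypothetical.

```python
def filter_by_company(doc_ids, document_sources, query,
                      companies=("QuantumCore", "NeoCompute")):
    """Keep only chunks whose source matches the company named in the query.

    Returns (doc_ids_to_use, detected_company). Falls back to the unfiltered
    list when the query names no company or when no chunk matches.
    """
    query_lower = query.lower()
    # Detect the first known company name that appears in the query
    company = next((c for c in companies if c.lower() in query_lower), None)
    if company is None:
        return list(doc_ids), None
    matching = [d for d in doc_ids if document_sources.get(d) == company]
    # Fallback: an empty filter result would leave nothing to rerank
    return (matching or list(doc_ids)), company


sources = {"0": "QuantumCore", "1": "NeoCompute", "2": "NeoCompute"}
print(filter_by_company(["0", "1", "2"], sources, "Who is the CTO of NeoCompute?"))
```

Because the filter runs only over the 15 retrieved chunks, a query can name a company and still match none of them; the fallback keeps the pipeline answering from whatever was retrieved.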

## Setup
- **Dependencies**: `pylate`, `langchain`, `langchain-community`, `transformers`, `google-colab`, `pypdf`, `hf_xet`.
- **Environment**: Google Colab, CPU-friendly.
- **Datasets**: `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` (quantum computing), `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf` (AI hardware/software).

## Instructions
1. Cell 1: Install libraries.
2. Cell 2: Import libraries.
3. Cell 3: Define RAG pipeline with detailed comments.
4. Cell 4: Process PDFs with explanatory comments.
5. Cell 5: Run interactive queries with formatted output and comments.

## Cell 1: Install Dependencies

Installs Python libraries required for the RAG pipeline, ensuring compatibility in Google Colab.

In [None]:
!pip install pylate langchain transformers google-colab
!pip install -U langchain-community pypdf hf_xet

## Cell 2: Import Libraries

Imports libraries for the pipeline and suppresses warnings for clean output.

In [None]:
# Import required libraries for the RAG pipeline
import warnings
from pylate import models, indexes, retrieve, rank  # Pylate modules for embedding, indexing, and retrieval
from langchain_community.document_loaders import PyPDFLoader  # Load PDFs (from the langchain-community package)
from langchain_text_splitters import RecursiveCharacterTextSplitter  # Split text into chunks
from langchain_core.prompts import PromptTemplate  # Create prompt templates
from google.colab import files  # Handle file uploads in Colab
import os  # File system operations
from transformers import pipeline  # Hugging Face pipeline for text generation
import contextlib  # Redirect stderr for clean output
import io  # StringIO for stderr redirection
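from functools import partialmethod  # Pre-bind default arguments on a method
from tqdm import tqdm  # Progress bar class (disabled below)

# Disable tqdm progress bars by default for cleaner output.
# Replacing __init__ with a no-op would leave tqdm objects uninitialized and break iteration.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)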

# Suppress PDF reader warnings to avoid cluttering output
warnings.filterwarnings('ignore', category=UserWarning, module='pypdf._reader')
warnings.filterwarnings('ignore', category=DeprecationWarning, module='pypdf._reader')

## Cell 3: Define RAG Pipeline

Defines the RAG pipeline functions (`run_rag_pipeline`, `query_rag`) with a Few-Shot prompt using fictional examples. Includes detailed comments to explain each step.

In [None]:
def run_rag_pipeline(pdf_paths):
    """Initialize the RAG pipeline by processing PDFs, creating embeddings, and setting up models."""
    try:
        # --- Initialize Data Structures ---
        # Lists to store document texts and IDs
        all_document_texts = []
        all_document_ids = []
        # Dictionary to map document IDs to their text content
        document_map = {}
        # Counter for generating unique document IDs
        current_doc_id = 0
        # Dictionary to track the source (QuantumCore/NeoCompute) of each document
        document_sources = {}

        # --- Process PDFs ---
        # Initialize text splitter with 300-char chunks and 50-char overlap
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
        for pdf_path in pdf_paths:
            print(f'Processing PDF: {pdf_path}')
            # Check if PDF file exists
            if not os.path.exists(pdf_path):
                print(f'PDF not found: {pdf_path}')
                continue
            # Load PDF content
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()
            # Verify that content was loaded
            if not documents:
                print(f'No content loaded from {pdf_path}')
                continue
            # Split documents into chunks
            chunks = text_splitter.split_documents(documents)
            document_texts = [chunk.page_content for chunk in chunks]
            # Generate unique IDs for chunks
            document_ids = [str(i + current_doc_id) for i in range(len(document_texts))]
            # Map IDs to texts
            document_map.update(dict(zip(document_ids, document_texts)))
            # Extract source name (QuantumCore or NeoCompute) from filename
            source_name = os.path.basename(pdf_path).split('_')[0]
            # Associate IDs with source
            document_sources.update({doc_id: source_name for doc_id in document_ids})
            # Add texts and IDs to main lists
            all_document_texts.extend(document_texts)
            all_document_ids.extend(document_ids)
            # Update ID counter
            current_doc_id += len(document_texts)
            print(f'Created {len(document_texts)} chunks from {pdf_path}')

        # --- Validate Document Processing ---
        if not all_document_texts:
            print('No documents processed.')
            return None

        print(f'Total chunks created: {len(all_document_texts)}')

        # --- Set Up Embedding Model ---
        # Initialize ColBERT model for embeddings
        model_name = 'lightonai/GTE-ModernColBERT-v1'
        model = models.ColBERT(model_name_or_path=model_name)

        # --- Create Index for Retrieval ---
        # Set up Voyager index to store embeddings
        index_folder = 'pylate-index'
        index_name = 'pdf_index'
        index = indexes.Voyager(index_folder=index_folder, index_name=index_name, override=True)

        # --- Generate and Store Embeddings ---
        # Encode document texts into embeddings, suppressing stderr for clean output
        with contextlib.redirect_stderr(io.StringIO()):
            documents_embeddings = model.encode(
                all_document_texts,
                batch_size=32,
                is_query=False,
                show_progress_bar=False
            )
        # Add embeddings to the index
        index.add_documents(all_document_ids, documents_embeddings=documents_embeddings)

        # --- Initialize Retriever ---
        # Set up ColBERT retriever for fetching relevant chunks
        retriever = retrieve.ColBERT(index=index)

        # --- Initialize Generator ---
        # Set up FLAN-T5 model for text generation
        generator = pipeline('text2text-generation', model='google/flan-t5-base', max_length=300)

        # --- Define Few-Shot Prompt ---
        # Create a prompt template with fictional examples to guide answer extraction
        prompt_template = r"""
You are an expert assistant answering questions based solely on the provided text. Follow the instructions and examples below to generate concise and accurate responses.

**Instructions**:
1. For role-based questions (e.g., CEO, CTO), return the full name of the individual in that role.
2. For list-based questions (e.g., products, compliance standards), return a comma-separated list of names, sorted alphabetically.
3. For detail-based questions (e.g., qubit count), return the exact detail as stated.
4. For other questions, provide a brief, relevant answer.
5. If the answer is not in the text, return: "The answer could not be found in the text."
6. Parse bullet points, sentences, or headings, ignoring irrelevant details unless requested.
7. Keep answers concise, avoiding extra explanations.

**Examples**:
- Text: "- Jane Smith, CEO: 20 years in tech innovation. - SkyNet: Advanced AI processor."
  Question: Who is the CEO of Horizon Innovations?
  Answer: Jane Smith

- Text: "- CloudPeak: Scalable cloud platform. - SecureVault: Data protection suite."
  Question: What are the products offered by TechTrend Solutions?
  Answer: CloudPeak, SecureVault

- Text: "- PCI DSS certified. - FedRAMP compliant."
  Question: What compliance standards does DataSafe Inc. follow?
  Answer: FedRAMP, PCI DSS

- Text: "- AlphaCore: 100-qubit photonic architecture."
  Question: What is the qubit count of AlphaCore?
  Answer: 100-qubit photonic architecture

- Text: "- Robert Lee, CTO: Expert in cloud systems."
  Question: Who is the CIO of Horizon Innovations?
  Answer: The answer could not be found in the text.

**Text**: {context}

**Question**: {question}

**Answer**:
"""
        # Create PromptTemplate object with input variables
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # --- Return Pipeline Components ---
        # Return all components needed for querying
        return model, index, retriever, generator, PROMPT, document_map, document_sources

    except Exception as e:
        # Handle any errors during pipeline setup
        print(f'Error processing PDFs: {e}')
        return None

def query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query):
    """Process a user query with retrieval and generation, relying on Few-Shot prompt."""
    try:
        # --- Prepare Query ---
        # Convert query to a list for batch processing
        queries = [query]

        # --- Encode Query ---
        print('Encoding query...')
        # Encode the query into an embedding, suppressing stderr
        with contextlib.redirect_stderr(io.StringIO()):
            query_embedding = model.encode(
                queries,
                batch_size=32,
                is_query=True,
                show_progress_bar=False
            )

        # --- Retrieve Documents ---
        print('Retrieving documents...')
        # Retrieve top 15 relevant document chunks
        with contextlib.redirect_stderr(io.StringIO()):
            top_k_initial = 15
            initial_results = retriever.retrieve(queries_embeddings=query_embedding, k=top_k_initial)

        # Check if any documents were retrieved
        if not initial_results or not initial_results[0]:
            print('No documents retrieved.')
            return None, 'No relevant documents found.'

        # Extract document IDs from results
        retrieved_doc_ids = [result['id'] for result in initial_results[0] if 'id' in result]
        if not retrieved_doc_ids:
            print('No document IDs after retrieval.')
            return None, 'No relevant documents found.'

        # --- Filter by Company ---
        # Determine company based on query keywords
        company = None
        filtered_doc_ids = retrieved_doc_ids
        if 'quantumcore' in query.lower():
            company = 'QuantumCore'
        elif 'neocompute' in query.lower():
            company = 'NeoCompute'
        if company:
            # Filter documents by company source
            filtered_doc_ids = [doc_id for doc_id in retrieved_doc_ids if document_sources.get(doc_id) == company]
            if not filtered_doc_ids:
                # Fallback to all documents if no company-specific chunks found
                print(f'No documents found for company: {company}. Falling back to all documents.')
                filtered_doc_ids = retrieved_doc_ids

        # --- Retrieve Document Texts ---
        # Get text content for filtered document IDs
        retrieved_documents = [document_map.get(doc_id, '') for doc_id in filtered_doc_ids]
        retrieved_documents = [doc for doc in retrieved_documents if doc]

        # Check if any valid documents remain
        if not retrieved_documents:
            print('No valid documents after filtering.')
            return None, 'No relevant documents found.'

        # --- Rerank Documents ---
        print('Reranking documents...')
        # Rerank documents to select top 3 most relevant
        with contextlib.redirect_stderr(io.StringIO()):
            reranked_results = rank.rerank(
                documents_ids=[filtered_doc_ids],
                queries_embeddings=query_embedding,
                documents_embeddings=[model.encode(retrieved_documents, is_query=False, show_progress_bar=False)]
            )

        # Extract reranked document IDs
        reranked_doc_ids = []
        if reranked_results and isinstance(reranked_results[0], list):
            for result in reranked_results[0]:
                if isinstance(result, dict) and 'id' in result:
                    reranked_doc_ids.append(result['id'])
                elif isinstance(result, str):
                    reranked_doc_ids.append(result)
        else:
            # Fallback to top 3 filtered IDs if reranking fails
            reranked_doc_ids = filtered_doc_ids[:3]

        # Check if reranked IDs are valid
        if not reranked_doc_ids:
            print('No document IDs after reranking.')
            return None, 'No relevant documents found.'

        # Get texts for reranked documents
        reranked_documents = [document_map.get(doc_id, '') for doc_id in reranked_doc_ids]
        reranked_documents = [doc for doc in reranked_documents if doc]

        # --- Build Context ---
        # Combine top 3 documents into context, limiting to 600 characters
        max_context_length = 600
        context = '\n'.join(reranked_documents[:3])[:max_context_length]
        if not context:
            print('No context generated.')
            return None, 'No relevant context found.'

        # --- Generate Prompt ---
        # Format the prompt with context and question
        prompt_text = PROMPT.format(context=context, question=query)

        # --- Generate Answer ---
        print('Generating answer...')
        # Use FLAN-T5 to generate the answer
        response = generator(prompt_text)[0]['generated_text']
        answer = response.strip()

        # Handle empty or invalid responses
        if not answer or answer.lower() == 'none':
            answer = 'The answer could not be found in the text.'

        # Return context and answer
        return context, answer

    except Exception as e:
        # Handle any errors during query processing
        print(f'Error processing query: {e}')
        return None, 'Error processing query.'
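
For intuition about the rerank step, the sketch below shows MaxSim late-interaction scoring, the scoring rule ColBERT-style models are built around: every query token vector is matched with its best document token vector, and the per-token maxima are summed. It uses plain Python lists and toy 2-dimensional vectors. That `rank.rerank` computes exactly this internally is an assumption, and the names `maxsim_score` and `rerank_top_k` are hypothetical.

```python
def maxsim_score(query_vecs, doc_vecs):
    """Sum, over query tokens, of the best dot product with any document token."""
    return sum(
        max(sum(q * d for q, d in zip(q_vec, d_vec)) for d_vec in doc_vecs)
        for q_vec in query_vecs
    )


def rerank_top_k(query_vecs, docs, k=3):
    """docs maps doc_id -> list of token vectors; return the k best ids by MaxSim."""
    ranked = sorted(docs, key=lambda doc_id: maxsim_score(query_vecs, docs[doc_id]),
                    reverse=True)
    return ranked[:k]


# Toy data: two query tokens, four candidate documents
query = [[1.0, 0.0], [0.0, 1.0]]
docs = {
    "a": [[1.0, 0.0], [0.0, 1.0]],  # matches both query tokens exactly
    "b": [[1.0, 0.0]],              # matches only the first query token
    "c": [[0.6, 0.6]],              # partial match for both query tokens
    "d": [[0.0, 0.0]],              # matches nothing
}
print(rerank_top_k(query, docs))
```

Unlike a single-vector embedding, this keeps one vector per token, which is why the notebook passes lists of embeddings per document to the reranker.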

## Cell 4: Process PDFs

Initializes the RAG pipeline by processing PDFs, with comments explaining each step.

In [None]:
# Define paths to PDF datasets
pdf_paths = [
    '/data/QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf',
    '/data/NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf'
]

# --- Handle Missing PDFs ---
# Check if PDFs exist; prompt for upload if not found
if not all(os.path.exists(pdf_path) for pdf_path in pdf_paths):
    print('Please upload the PDF files (keep the original names; the company prefix before the first "_" is used for filtering):')
    # Allow user to upload PDFs in Colab
    uploaded = files.upload()
    # Update paths to uploaded files
    pdf_paths = [f'/content/{name}' for name in uploaded.keys()]

# --- Initialize Pipeline ---
# Run the RAG pipeline with the PDF paths
result = run_rag_pipeline(pdf_paths)
if result:
    # Unpack pipeline components if successful
    model, index, retriever, generator, PROMPT, document_map, document_sources = result
    print('RAG pipeline initialized successfully.')
else:
    # Report failure if pipeline initialization fails
    print('Failed to initialize RAG pipeline.')
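
Company filtering depends on how the source label is derived from each file name, which matters when files are uploaded under different names. The sketch below isolates that rule from `run_rag_pipeline`; the helper name `source_from_path` and the second and third example paths are hypothetical.

```python
import os


def source_from_path(pdf_path: str) -> str:
    """Return the text before the first underscore in the file name."""
    return os.path.basename(pdf_path).split('_')[0]


# The original dataset name yields the label the company filter compares against
print(source_from_path('/data/QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf'))
# A shortened name still works as long as the company prefix is kept
print(source_from_path('/content/NeoCompute_v2.pdf'))
# Without an underscore the whole file name becomes the label, so the filter never matches
print(source_from_path('/content/report.pdf'))
```

If a renamed file loses its company prefix, queries that name the company fall through to the unfiltered fallback in `query_rag` rather than failing.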

## Cell 5: Interactive Querying

Provides an interactive query interface with formatted output and detailed comments.

**Example Queries**:
- Who is the CEO of QuantumCore Solutions? → 'Dr. Elena Ruiz'
- What are the products offered by NeoCompute Technologies? → 'NeoCloud, NeoSecure'
- Who is the CIO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- Who is the CEO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- What is the qubit count of QubitCore? → '50-qubit superconducting architecture'
- What compliance standards does NeoCompute follow? → 'ISO/IEC 27001, SOC 2 Type II'

In [None]:
def interactive_query():
    """Run an interactive query loop with formatted output for user queries."""
    # --- Display Interface Header ---
    # Print a clean header for the query interface
    print('=====================================')
    print('Interactive RAG Query Interface')
    print('=====================================')
    print('Enter your query (or type "exit" to quit):\n')

    # --- Query Loop ---
    while True:
        # Prompt user for a query
        query = input('Query: ').strip()
        # Check if user wants to exit
        if query.lower() == 'exit':
            print('\nExiting query interface.')
            break
        # Verify that the pipeline is initialized (also covers the case where Cell 4 never ran)
        if not globals().get('result'):
            print('\nError: RAG pipeline not initialized. Please run Cell 4 first.')
            break

        # --- Process Query ---
        # Run the query through the RAG pipeline
        context, answer = query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query)

        # --- Display Results ---
        # Print formatted query results
        print('\n====================================')
        print(f'Query: {query}')
        print('====================================')
        print('\n**Context Retrieved**:\n')
        if context is None:
            # Handle case where no context was retrieved
            print('    Error: No context retrieved.')
        else:
            # Format context with indentation for readability
            indented_context = context.replace('\n', '\n    ')
            print(f'    {indented_context}')
        print('\n---')
        print('\n**Answer**:\n')
        # Display the generated answer
        print(f'    {answer}')
        print('\n====================================\n')
        print('Enter your next query (or type "exit" to quit):\n')

# --- Run Interactive Query Interface ---
interactive_query()