# RAG Pipeline Demo

This Jupyter notebook implements a Retrieval-Augmented Generation (RAG) pipeline for a take-home interview project. It answers queries from two PDF datasets, `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` and `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf`, covering questions about leadership, products, and specifications.

## Objective
- Combine retrieval and generative AI for accurate answers from PDFs.
- Use lightweight models (`lightonai/GTE-ModernColBERT-v1`, `google/flan-t5-base`) for CPU compatibility.
- Provide an interactive query interface with readable output.
- Use Few-Shot Prompting with fictional examples so that the prompt does not leak dataset content and generalizes to other documents.

## Architecture
- **Knowledge Base**: PDFs are loaded with `PyPDFLoader`, split into chunks (300 chars, 50-char overlap) using `RecursiveCharacterTextSplitter`, and stored with source tracking for company filtering.
- **Semantic Layer**: Chunks and queries are embedded with `lightonai/GTE-ModernColBERT-v1`.
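- **Retrieval**: `retrieve.ColBERT` fetches the 15 best-matching chunks; when the query names a company, chunks from the other PDF are dropped, and `rank.rerank` reorders what remains.
- **Augmentation**: The top 3 reranked chunks (truncated to 600 characters in total) are combined with the query via a Few-Shot `PromptTemplate`.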
- **Generation**: `google/flan-t5-base` generates answers based on the prompt.
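
The chunking step in the Knowledge Base layer can be illustrated with a simplified sketch. This is not how `RecursiveCharacterTextSplitter` works internally (it prefers to split on separators such as paragraph and line breaks); the hypothetical `chunk_text` helper below only shows what a 300-character window with a 50-character overlap means.

```python
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> list[str]:
    """Split text into fixed-size character windows that overlap by `overlap` chars.

    Simplified illustration only: the notebook itself uses
    RecursiveCharacterTextSplitter, which also respects separators.
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    step = chunk_size - overlap  # each window starts `step` chars after the previous one
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


sample = "x" * 700
print([len(chunk) for chunk in chunk_text(sample)])  # [300, 300, 200]
```

The overlap means a fact that straddles a chunk boundary still appears whole in at least one chunk, at the cost of indexing some text twice.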

## Updates
- Fixed JSON parsing error in `prompt_template` for Cursor compatibility.
- Implemented Few-Shot Prompting with fictional examples to avoid dataset bias.
- Added robust company filtering with fallback for NeoCompute.
- Enhanced output formatting with clear sections, suppressed progress bars, and concise progress messages.
- Fixed errors for NeoCompute indexing and query handling.
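
The company filtering with fallback mentioned above can be sketched in isolation. The `filter_by_company` helper and the toy `sources` mapping are hypothetical; the notebook inlines equivalent logic in `query_rag`.

```python
def filter_by_company(doc_ids, sources, query, companies=("QuantumCore", "NeoCompute")):
    """Keep only IDs whose source matches a company named in the query.

    Falls back to the unfiltered list when no company is named, or when
    filtering would leave nothing to rerank.
    """
    query_lower = query.lower()
    company = next((c for c in companies if c.lower() in query_lower), None)
    if company is None:
        return list(doc_ids)
    matched = [doc_id for doc_id in doc_ids if sources.get(doc_id) == company]
    return matched or list(doc_ids)


# Toy data: chunk ID -> company label
sources = {"0": "QuantumCore", "1": "NeoCompute", "2": "NeoCompute"}
print(filter_by_company(["0", "1", "2"], sources, "Who leads NeoCompute?"))  # ['1', '2']
print(filter_by_company(["0"], sources, "Who leads NeoCompute?"))            # ['0'] (fallback)
```

The fallback trades precision for robustness: a query about one company can still be answered from the other PDF's chunks if retrieval returned nothing from the named source.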

## Setup
- **Dependencies**: `pylate`, `langchain`, `langchain-community`, `transformers`, `google-colab`, `pypdf`, `hf_xet`.
- **Environment**: Google Colab, CPU-friendly.
- **Datasets**: `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` (quantum computing), `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf` (AI hardware/software).

## Instructions
1. Cell 1: Install libraries.
2. Cell 2: Import libraries.
3. Cell 3: Define RAG pipeline.
4. Cell 4: Process PDFs.
5. Cell 5: Run interactive queries with formatted output.

## Cell 1: Install Dependencies

Installs Python libraries for the RAG pipeline, ensuring compatibility in Google Colab.

In [None]:
!pip install pylate langchain transformers google-colab
!pip install -U langchain-community pypdf hf_xet
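
After installation, a quick import check can confirm that the main packages are visible to the kernel before the heavier cells run. This is an optional sketch; `missing_modules` is a hypothetical helper, and the list covers only the top-level module names of the core dependencies (`google.colab` and `hf_xet` are left out).

```python
from importlib.util import find_spec


def missing_modules(module_names):
    """Return the top-level module names that cannot be located by the import system."""
    return [name for name in module_names if find_spec(name) is None]


required = ["pylate", "langchain", "langchain_community", "transformers", "pypdf"]
missing = missing_modules(required)
print("Missing modules:", missing or "none")
```

If the list is not empty in Colab, restarting the runtime after the install cell is usually the first thing to try.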

## Cell 2: Import Libraries

Imports libraries for the pipeline and suppresses warnings for clean output.

In [None]:
import warnings
from pylate import models, indexes, retrieve, rank
# Document loaders live in langchain-community (installed in Cell 1)
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from google.colab import files
import os
from transformers import pipeline
import contextlib
import io
from functools import partialmethod
from tqdm import tqdm

# Disable tqdm progress bars globally while keeping tqdm objects usable
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
warnings.filterwarnings('ignore', category=UserWarning, module='pypdf._reader')
warnings.filterwarnings('ignore', category=DeprecationWarning, module='pypdf._reader')

## Cell 3: Define RAG Pipeline

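Defines `run_rag_pipeline` (PDF processing, indexing, and model setup) and `query_rag` (retrieval, reranking, and generation). The Few-Shot prompt uses fictional examples and is escaped so that the notebook's JSON stays valid.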

In [None]:
def run_rag_pipeline(pdf_paths):
    """Initialize the RAG pipeline by processing PDFs, creating embeddings, and setting up models."""
    try:
        all_document_texts = []
        all_document_ids = []
        document_map = {}
        current_doc_id = 0
        document_sources = {}

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
        for pdf_path in pdf_paths:
            print(f'Processing PDF: {pdf_path}')
            if not os.path.exists(pdf_path):
                print(f'PDF not found: {pdf_path}')
                continue
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()
            if not documents:
                print(f'No content loaded from {pdf_path}')
                continue
            chunks = text_splitter.split_documents(documents)
            document_texts = [chunk.page_content for chunk in chunks]
            document_ids = [str(i + current_doc_id) for i in range(len(document_texts))]
            document_map.update(dict(zip(document_ids, document_texts)))
            source_name = os.path.basename(pdf_path).split('_')[0]
            document_sources.update({doc_id: source_name for doc_id in document_ids})
            all_document_texts.extend(document_texts)
            all_document_ids.extend(document_ids)
            current_doc_id += len(document_texts)
            print(f'Created {len(document_texts)} chunks from {pdf_path}')

        if not all_document_texts:
            print('No documents processed.')
            return None

        print(f'Total chunks created: {len(all_document_texts)}')

        model_name = 'lightonai/GTE-ModernColBERT-v1'
        model = models.ColBERT(model_name_or_path=model_name)

        index_folder = 'pylate-index'
        index_name = 'pdf_index'
        index = indexes.Voyager(index_folder=index_folder, index_name=index_name, override=True)

        with contextlib.redirect_stderr(io.StringIO()):
            documents_embeddings = model.encode(
                all_document_texts,
                batch_size=32,
                is_query=False,
                show_progress_bar=False
            )
        index.add_documents(all_document_ids, documents_embeddings=documents_embeddings)

        retriever = retrieve.ColBERT(index=index)

        generator = pipeline('text2text-generation', model='google/flan-t5-base', max_length=300)

        prompt_template = r"""
You are an expert assistant answering questions based solely on the provided text. Follow the instructions and examples below to generate concise and accurate responses.

**Instructions**:
1. For role-based questions (e.g., CEO, CTO), return the full name of the individual in that role.
2. For list-based questions (e.g., products, compliance standards), return a comma-separated list of names, sorted alphabetically.
3. For detail-based questions (e.g., qubit count), return the exact detail as stated.
4. For other questions, provide a brief, relevant answer.
5. If the answer is not in the text, return: "The answer could not be found in the text."
6. Parse bullet points, sentences, or headings, ignoring irrelevant details unless requested.
7. Keep answers concise, avoiding extra explanations.

**Examples**:
- Text: "- Jane Smith, CEO: 20 years in tech innovation. - SkyNet: Advanced AI processor."
  Question: Who is the CEO of Horizon Innovations?
  Answer: Jane Smith

- Text: "- CloudPeak: Scalable cloud platform. - SecureVault: Data protection suite."
  Question: What are the products offered by TechTrend Solutions?
  Answer: CloudPeak, SecureVault

- Text: "- PCI DSS certified. - FedRAMP compliant."
  Question: What compliance standards does DataSafe Inc. follow?
  Answer: FedRAMP, PCI DSS

- Text: "- AlphaCore: 100-qubit photonic architecture."
  Question: What is the qubit count of AlphaCore?
  Answer: 100-qubit photonic architecture

- Text: "- Robert Lee, CTO: Expert in cloud systems."
  Question: Who is the CIO of Horizon Innovations?
  Answer: The answer could not be found in the text.

**Text**: {context}

**Question**: {question}

**Answer**:
"""
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        return model, index, retriever, generator, PROMPT, document_map, document_sources

    except Exception as e:
        print(f'Error processing PDFs: {e}')
        return None

def query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query):
    """Process a user query with retrieval and generation, relying on Few-Shot prompt."""
    try:
        queries = [query]

        print('Encoding query...')
        with contextlib.redirect_stderr(io.StringIO()):
            query_embedding = model.encode(
                queries,
                batch_size=32,
                is_query=True,
                show_progress_bar=False
            )

        print('Retrieving documents...')
        with contextlib.redirect_stderr(io.StringIO()):
            top_k_initial = 15
            initial_results = retriever.retrieve(queries_embeddings=query_embedding, k=top_k_initial)

        if not initial_results or not initial_results[0]:
            print('No documents retrieved.')
            return None, 'No relevant documents found.'

        retrieved_doc_ids = [result['id'] for result in initial_results[0] if 'id' in result]
        if not retrieved_doc_ids:
            print('No document IDs after retrieval.')
            return None, 'No relevant documents found.'

        company = None
        filtered_doc_ids = retrieved_doc_ids
        if 'quantumcore' in query.lower():
            company = 'QuantumCore'
        elif 'neocompute' in query.lower():
            company = 'NeoCompute'
        if company:
            filtered_doc_ids = [doc_id for doc_id in retrieved_doc_ids if document_sources.get(doc_id) == company]
            if not filtered_doc_ids:
                print(f'No documents found for company: {company}. Falling back to all documents.')
                filtered_doc_ids = retrieved_doc_ids  # Fallback to unfiltered

        # Keep IDs and texts aligned so reranking scores the right chunk for each ID
        filtered_doc_ids = [doc_id for doc_id in filtered_doc_ids if document_map.get(doc_id)]
        retrieved_documents = [document_map[doc_id] for doc_id in filtered_doc_ids]

        if not retrieved_documents:
            print('No valid documents after filtering.')
            return None, 'No relevant documents found.'

        print('Reranking documents...')
        with contextlib.redirect_stderr(io.StringIO()):
            reranked_results = rank.rerank(
                documents_ids=[filtered_doc_ids],
                queries_embeddings=query_embedding,
                documents_embeddings=[model.encode(retrieved_documents, is_query=False, show_progress_bar=False)]
            )

        reranked_doc_ids = []
        if reranked_results and isinstance(reranked_results[0], list):
            for result in reranked_results[0]:
                if isinstance(result, dict) and 'id' in result:
                    reranked_doc_ids.append(result['id'])
                elif isinstance(result, str):
                    reranked_doc_ids.append(result)
        else:
            reranked_doc_ids = filtered_doc_ids[:3]

        if not reranked_doc_ids:
            print('No document IDs after reranking.')
            return None, 'No relevant documents found.'

        reranked_documents = [document_map.get(doc_id, '') for doc_id in reranked_doc_ids]
        reranked_documents = [doc for doc in reranked_documents if doc]

        max_context_length = 600
        context = '\n'.join(reranked_documents[:3])[:max_context_length]
        if not context:
            print('No context generated.')
            return None, 'No relevant context found.'

        prompt_text = PROMPT.format(context=context, question=query)

        print('Generating answer...')
        response = generator(prompt_text)[0]['generated_text']
        answer = response.strip()

        if not answer or answer.lower() == 'none':
            answer = 'The answer could not be found in the text.'

        return context, answer

    except Exception as e:
        print(f'Error processing query: {e}')
        return None, 'Error processing query.'
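
To make the reranking step in `query_rag` less opaque, here is a pure-Python sketch of MaxSim, the late-interaction score that ColBERT-style models are typically ranked with: each query token vector is matched to its best document token vector, and the best-match similarities are summed. The 2-dimensional vectors and the helper names are illustrative assumptions; the real model produces learned, higher-dimensional token embeddings, and PyLate's implementation may differ in detail.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_tokens, doc_tokens):
    """Sum, over query tokens, of the best dot-product match among document tokens."""
    return sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens)


def rerank_top_k(query_tokens, docs, k=3):
    """Rank doc IDs by MaxSim score (highest first) and keep the best k.

    `docs` maps a doc ID to its list of token vectors.
    """
    ranked = sorted(docs, key=lambda doc_id: maxsim_score(query_tokens, docs[doc_id]), reverse=True)
    return ranked[:k]


query = [[1.0, 0.0], [0.0, 1.0]]       # two toy query token vectors
docs = {
    "a": [[1.0, 0.0], [0.0, 1.0]],     # matches both query tokens exactly
    "b": [[1.0, 0.0]],                 # matches only the first query token
    "c": [[0.6, 0.8]],                 # partial match for both query tokens
}
print(rerank_top_k(query, docs, k=2))  # ['a', 'c']
```

Because every query token contributes its own best match, a chunk that covers all parts of the question tends to outrank one that matches a single term strongly.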

## Cell 4: Process PDFs

Initializes the RAG pipeline. If either PDF is missing from `/data`, the cell prompts for an upload through Colab and uses the uploaded files instead.

In [None]:
pdf_paths = [
    '/data/QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf',
    '/data/NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf'
]

if not all(os.path.exists(pdf_path) for pdf_path in pdf_paths):
    print('Please upload your PDF files (QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf and/or NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf):')
    uploaded = files.upload()
    pdf_paths = [f'/content/{name}' for name in uploaded.keys()]

result = run_rag_pipeline(pdf_paths)
if result:
    model, index, retriever, generator, PROMPT, document_map, document_sources = result
    print('RAG pipeline initialized successfully.')
else:
    print('Failed to initialize RAG pipeline.')
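
File names matter here because `run_rag_pipeline` derives each chunk's company label from the part of the file name before the first underscore. The sketch below reproduces that one expression in a hypothetical `source_label` helper to show what the labels look like.

```python
import os


def source_label(pdf_path: str) -> str:
    """Return the part of the file name before the first underscore."""
    return os.path.basename(pdf_path).split("_")[0]


paths = [
    "/data/QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf",
    "/data/NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf",
    "/content/report.pdf",  # hypothetical renamed upload
]
for path in paths:
    print(source_label(path))
# QuantumCore
# NeoCompute
# report.pdf
```

A renamed upload such as `report.pdf` gets a label that matches neither company, so company filtering finds no match for it and the query falls back to the unfiltered results.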

## Cell 5: Interactive Querying

Runs an interactive query loop that prints the retrieved context and the generated answer for each query, and reports an error if the pipeline has not been initialized.

**Example Queries**:
- Who is the CEO of QuantumCore Solutions? → 'Dr. Elena Ruiz'
- What are the products offered by NeoCompute Technologies? → 'NeoCloud, NeoSecure'
- Who is the CIO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- Who is the CEO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- What is the qubit count of QubitCore? → '50-qubit superconducting architecture'
- What compliance standards does NeoCompute follow? → 'ISO/IEC 27001, SOC 2 Type II'
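
The example queries above can double as a small smoke test. The sketch below is one possible harness, not part of the notebook: `score_answers` and `normalize` are hypothetical helpers, and `answer_fn` stands for any callable that maps a question to an answer string.

```python
EXPECTED = {
    "Who is the CEO of QuantumCore Solutions?": "Dr. Elena Ruiz",
    "What are the products offered by NeoCompute Technologies?": "NeoCloud, NeoSecure",
    "Who is the CIO of NeoCompute Technologies?": "The answer could not be found in the text.",
    "Who is the CEO of NeoCompute Technologies?": "The answer could not be found in the text.",
    "What is the qubit count of QubitCore?": "50-qubit superconducting architecture",
    "What compliance standards does NeoCompute follow?": "ISO/IEC 27001, SOC 2 Type II",
}


def normalize(text: str) -> str:
    """Lowercase, collapse whitespace, and drop trailing periods before comparing."""
    return " ".join(text.lower().split()).rstrip(".")


def score_answers(answer_fn, expected):
    """Return (number_passed, failures); failures hold (question, wanted, got)."""
    failures = []
    for question, wanted in expected.items():
        got = answer_fn(question)
        if normalize(got) != normalize(wanted):
            failures.append((question, wanted, got))
    return len(expected) - len(failures), failures


# Stand-in for the real pipeline, so the harness runs on its own
passed, failures = score_answers(lambda q: EXPECTED.get(q, ""), EXPECTED)
print(f"{passed}/{len(EXPECTED)} passed")  # 6/6 passed
```

In the notebook, `answer_fn` could wrap the pipeline after Cell 4 has run, for example `lambda q: query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, q)[1]`. Exact-match comparison is strict, so a failure may indicate a formatting difference rather than a wrong answer.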

In [None]:
def interactive_query():
    """Run an interactive query loop with formatted output."""
    print('=====================================')
    print('Interactive RAG Query Interface')
    print('=====================================')
    print('Enter your query (or type "exit" to quit):\n')
    while True:
        query = input('Query: ').strip()
        if query.lower() == 'exit':
            print('\nExiting query interface.')
            break
        if not result:
            print('\nError: RAG pipeline not initialized. Please run Cell 4 first.')
            break
        context, answer = query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query)
        print('\n====================================')
        print(f'Query: {query}')
        print('====================================')
        print('\n**Context Retrieved**:\n')
        if context is None:
            print('    Error: No context retrieved.')
        else:
            indented_context = context.replace('\n', '\n    ')
            print(f'    {indented_context}')
        print('\n---')
        print('\n**Answer**:\n')
        print(f'    {answer}')
        print('\n====================================\n')
        print('Enter your next query (or type "exit" to quit):\n')

interactive_query()