# RAG Pipeline Demo

This Jupyter Notebook implements a Retrieval-Augmented Generation (RAG) pipeline for a take-home interview project. It answers queries over two PDF datasets, `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` and `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf`, covering questions about leadership, products, and specifications.

## Objective
- Combine retrieval and generative AI for accurate answers from PDFs.
- Use lightweight models (`lightonai/GTE-ModernColBERT-v1`, `google/flan-t5-base`) for CPU compatibility.
- Apply regex post-processing for roles (e.g., CEO names) and products (e.g., product lists) with deduplication.
- Provide an interactive query interface with readable output.

## Architecture
- **Knowledge Base**: PDFs are loaded with `PyPDFLoader`, split into chunks (300 chars, 50-char overlap) using `RecursiveCharacterTextSplitter`, and stored with source tracking for company filtering.
- **Semantic Layer**: Chunks and queries are embedded with `lightonai/GTE-ModernColBERT-v1`.
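- **Retrieval**: `retrieve.ColBERT` fetches the 15 nearest chunks, which are filtered by company when the query names one and then reordered by `rank.rerank`.
- **Augmentation**: The top 3 reranked chunks are joined, truncated to 600 characters, and combined with the query via `PromptTemplate`.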
- **Generation**: `google/flan-t5-base` generates answers, post-processed for roles, products, or lists.
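
The reranking step relies on ColBERT-style late interaction. As a rough illustration of that idea (not pylate's actual tensor implementation), the sketch below scores a document by taking, for each query token vector, its best dot product against any document token vector and summing those maxima. The function name `maxsim_score` and the toy 2-D vectors are assumptions made for illustration.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_score(query_tokens: List[Vector], doc_tokens: List[Vector]) -> float:
    """Sum, over query tokens, of the best match among the document tokens."""
    return sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens)


# Toy token embeddings (hypothetical values, 2 dimensions for readability).
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[0.9, 0.1], [0.1, 0.8]]  # has a good match for both query tokens
doc_b = [[0.9, 0.1], [0.8, 0.2]]  # has a good match for the first token only

scores = {"doc_a": maxsim_score(query, doc_a), "doc_b": maxsim_score(query, doc_b)}
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)
```

Because every query token looks for its own best match, a chunk that covers all parts of the question tends to outrank one that matches a single term strongly.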

## Updates
- Improved role extraction regex for names (e.g., 'Dr. Elena Ruiz' for CEO).
- Refined product regex to filter non-products (e.g., 'Compliance').
- Added company filtering (QuantumCore/NeoCompute) based on query keywords.
- Enhanced output formatting with clear sections, suppressed progress bars, and concise progress messages.
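
To make the regex post-processing concrete, here is a minimal sketch run on made-up context lines. The `- Name, ROLE: description` layout is an assumption inferred from the role regex, `Jane Doe` is a placeholder name, and the helper `extract_role` is hypothetical (the notebook inlines this logic in `query_rag`).

```python
import re
from typing import Optional

# Hypothetical context lines; the bullet layout is assumed, not copied from the PDFs.
context = (
    "- Dr. Elena Ruiz, CEO: Sets company strategy.\n"
    "- Jane Doe, CTO: Leads engineering.\n"
)


def extract_role(text: str, role: str) -> Optional[str]:
    """Return the name that precedes ', ROLE:' on a bullet line, or None."""
    pattern = r"- ([^,]+?),\s*" + re.escape(role) + r"\s*:"
    match = re.search(pattern, text, re.IGNORECASE)
    return match.group(1).strip() if match else None


# Product filtering: drop known non-product headings, then deduplicate and sort.
non_product_terms = {"Compliance", "Features"}
candidates = ["NeoCloud", "Compliance", "NeoSecure", "NeoCloud"]
products = sorted({name for name in candidates if name not in non_product_terms})

print(extract_role(context, "CEO"))
print(extract_role(context, "CFO"))
print(", ".join(products))
```

When no bullet line names the requested role, the helper returns `None`, which corresponds to the notebook's fallback answer.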

## Setup
- **Dependencies**: `pylate`, `langchain`, `langchain-community`, `transformers`, `google-colab`, `pypdf`, `hf_xet`.
- **Environment**: Google Colab, CPU-friendly.
- **Datasets**: `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` (quantum computing), `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf` (AI hardware/software).

## Instructions
1. Cell 1: Install libraries.
2. Cell 2: Import libraries.
3. Cell 3: Define RAG pipeline.
4. Cell 4: Process PDFs.
5. Cell 5: Run interactive queries with formatted output.

## Cell 1: Install Dependencies

Installs the Python libraries for the RAG pipeline in Google Colab: `pylate` (ColBERT embeddings/retrieval), `langchain` (document processing), `transformers` (FLAN-T5), `google-colab` (Colab utilities), `langchain-community` and `pypdf` (PDF processing), and `hf_xet` (Hugging Face integration).

In [None]:
!pip install pylate langchain transformers google-colab
!pip install -U langchain-community pypdf hf_xet

## Cell 2: Import Libraries

Imports the libraries used by the pipeline and suppresses `pypdf` warnings for clean output. Includes `pylate` (embedding/retrieval), `langchain` (PDF loading, text splitting, prompts), `transformers` (FLAN-T5), `google.colab.files` (Colab uploads), `os` (file paths), and `re` (regex).

In [None]:
import warnings
from pylate import models, indexes, retrieve, rank
from langchain_community.document_loaders import PyPDFLoader  # document loaders live in langchain-community in current LangChain releases
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from google.colab import files
import os
from transformers import pipeline
import re

# The function is warnings.filterwarnings (all lowercase); filterWarnings raises AttributeError
warnings.filterwarnings('ignore', category=UserWarning, module='pypdf._reader')
warnings.filterwarnings('ignore', category=DeprecationWarning, module='pypdf._reader')

## Cell 3: Define RAG Pipeline

Defines the RAG pipeline functions:
- `run_rag_pipeline`: Loads, chunks, embeds, and indexes PDFs, initializes retriever and generator, tracks source PDFs for company filtering.
- `query_rag`: Encodes queries, retrieves/reranks chunks, generates answers, and post-processes for roles, products, or lists. Features:
  - Company filtering (QuantumCore/NeoCompute).
  - Improved role regex (e.g., CEO names).
  - Refined product regex (excludes non-products).
  - Suppressed progress bars, custom progress messages.
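
The company filtering described above can be sketched in isolation as follows. This is a simplified illustration under stated assumptions: the helper names `detect_company` and `filter_by_company` and the toy IDs are hypothetical, and the notebook performs the same steps inline inside `query_rag`.

```python
from typing import Dict, List, Optional

# Query keyword -> source label, mirroring the two companies in the datasets.
COMPANY_KEYWORDS = {"quantumcore": "QuantumCore", "neocompute": "NeoCompute"}


def detect_company(query: str) -> Optional[str]:
    """Return the company label if the query mentions one, else None."""
    lowered = query.lower()
    for keyword, label in COMPANY_KEYWORDS.items():
        if keyword in lowered:
            return label
    return None


def filter_by_company(
    doc_ids: List[str], document_sources: Dict[str, str], company: Optional[str]
) -> List[str]:
    """Keep only chunk IDs whose source matches the company; keep all if None."""
    if company is None:
        return list(doc_ids)
    return [doc_id for doc_id in doc_ids if document_sources.get(doc_id) == company]


# Toy data: chunk IDs mapped to the PDF they came from.
document_sources = {"0": "QuantumCore", "1": "QuantumCore", "2": "NeoCompute"}
retrieved = ["2", "0", "1"]

company = detect_company("Who is the CEO of QuantumCore Solutions?")
filtered = filter_by_company(retrieved, document_sources, company)
print(company, filtered)
```

The filter runs after the initial retrieval, so it can only narrow the 15 candidates. If none of them come from the named company, the filtered list is empty and the caller has to handle that case.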


In [None]:
def run_rag_pipeline(pdf_paths):
    """Initialize the RAG pipeline by processing PDFs, creating embeddings, and setting up models."""
    try:
        all_document_texts = []
        all_document_ids = []
        document_map = {}
        current_doc_id = 0
        document_sources = {}

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
        for pdf_path in pdf_paths:
            print(f'Processing PDF: {pdf_path}')
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()
            chunks = text_splitter.split_documents(documents)
            document_texts = [chunk.page_content for chunk in chunks]
            document_ids = [str(i + current_doc_id) for i in range(len(document_texts))]
            document_map.update(dict(zip(document_ids, document_texts)))
            source_name = os.path.basename(pdf_path).split('_')[0]
            document_sources.update({doc_id: source_name for doc_id in document_ids})
            all_document_texts.extend(document_texts)
            all_document_ids.extend(document_ids)
            current_doc_id += len(document_texts)
            print(f'Created {len(document_texts)} chunks from {pdf_path}')

        print(f'Total chunks created: {len(all_document_texts)}')

        model_name = 'lightonai/GTE-ModernColBERT-v1'
        model = models.ColBERT(model_name_or_path=model_name)

        index_folder = 'pylate-index'
        index_name = 'pdf_index'
        index = indexes.Voyager(index_folder=index_folder, index_name=index_name, override=True)

        documents_embeddings = model.encode(
            all_document_texts,
            batch_size=32,
            is_query=False,
            show_progress_bar=False
        )
        index.add_documents(all_document_ids, documents_embeddings=documents_embeddings)

        retriever = retrieve.ColBERT(index=index)

        generator = pipeline('text2text-generation', model='google/flan-t5-base', max_length=300)

        prompt_template = "Using only the provided text, answer the user's question with a concise and accurate response. For questions about specific roles (e.g., CEO, CTO, CFO), return only the full name of the individual in that role. For questions about lists (e.g., products), return all items as a comma-separated list of names only. Exclude any details not directly relevant to the question, such as technical specifications, unless explicitly requested. If the answer is not in the text, respond with 'The answer could not be found in the text.'\n\nText: {context}\n\nQuestion: {question}\n\nAnswer:"
        PROMPT = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])

        return model, index, retriever, generator, PROMPT, document_map, document_sources

    except Exception as e:
        print(f'Error processing PDFs: {e}')
        return None

def query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query):
    """Process a user query with retrieval, generation, and post-processing."""
    try:
        queries = [query]

        print('Encoding query...')
        query_embedding = model.encode(
            queries,
            batch_size=32,
            is_query=True,
            show_progress_bar=False
        )

        print('Retrieving documents...')
        top_k_initial = 15
        initial_results = retriever.retrieve(queries_embeddings=query_embedding, k=top_k_initial)
        retrieved_doc_ids = [result['id'] for result in initial_results[0]]

        company = None
        if 'quantumcore' in query.lower():
            company = 'QuantumCore'
        elif 'neocompute' in query.lower():
            company = 'NeoCompute'
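        if company:
            retrieved_doc_ids = [doc_id for doc_id in retrieved_doc_ids if document_sources.get(doc_id) == company]
            # Nothing left to rerank if no retrieved chunk comes from the named company
            if not retrieved_doc_ids:
                return '', 'The answer could not be found in the text.'
        retrieved_documents = [document_map[doc_id] for doc_id in retrieved_doc_ids]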

        print('Reranking documents...')
        reranked_results = rank.rerank(
            documents_ids=[retrieved_doc_ids],
            queries_embeddings=query_embedding,
            documents_embeddings=[model.encode(retrieved_documents, is_query=False, show_progress_bar=False)]
        )

        reranked_doc_ids = []
        if reranked_results and isinstance(reranked_results[0], list):
            for result in reranked_results[0]:
                if isinstance(result, dict) and 'id' in result:
                    reranked_doc_ids.append(result['id'])
                elif isinstance(result, str):
                    reranked_doc_ids.append(result)
        else:
            reranked_doc_ids = retrieved_doc_ids

        reranked_documents = [document_map[doc_id] for doc_id in reranked_doc_ids]

        max_context_length = 600
        context = '\n'.join(reranked_documents[:3])[:max_context_length]
        prompt_text = PROMPT.format(context=context, question=query)

        print('Generating answer...')
        response = generator(prompt_text)[0]['generated_text']
        answer = response.strip()

        non_product_terms = {'Compliance', 'Cooling', 'Features', 'Storage', 'Networking', 'Frameworks', 'Uptime', 'Encryption', 'Certifications', 'Software'}
        # Match role acronyms as whole words so that, e.g., 'director' does not trigger the CTO branch
        role_match = re.search(r'\b(ceo|cto|cfo|cio)\b', query, re.IGNORECASE)
        if role_match:
            role = role_match.group(1).upper()
            if role:
                match = re.search(r'- ([^,]+?),\s*' + re.escape(role) + r'\s*:', context, re.IGNORECASE)
                if match:
                    answer = match.group(1).strip()
                else:
                    answer = 'The answer could not be found in the text.'
        elif 'product' in query.lower():
            product_names = re.findall(r'- (\w+): (?:Quantum|Cloud-based|processing unit|platform|cryptographic security|cloud-native|security modules)', context, re.IGNORECASE)
            product_names = [name for name in product_names if name not in non_product_terms]
            if product_names:
                answer = ', '.join(sorted(set(product_names)))
            else:
                answer = 'The answer could not be found in the text.'
        elif 'compliance' in query.lower():
            compliance_standards = re.findall(r'- ([^\n]+)', context, re.IGNORECASE)
            compliance_standards = [std.strip() for std in compliance_standards if any(term in std for term in ['SOC', 'ISO', 'HIPAA', 'GDPR'])]
            if compliance_standards:
                answer = ', '.join(sorted(set(compliance_standards)))
            else:
                answer = 'The answer could not be found in the text.'
        elif ', ' in answer:
            items = set(answer.split(', '))
            items = [item for item in items if item not in non_product_terms]
            answer = ', '.join(sorted(items)) if items else 'The answer could not be found in the text.'

        return context, answer

    except Exception as e:
        print(f'Error processing query: {e}')
        return None, 'Error processing query.'

## Cell 4: Process PDFs

Specifies PDF paths (`QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf`, `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf`) and initializes the RAG pipeline. Checks for missing files and prompts for uploads in Colab. Initializes ColBERT model, Voyager index, retriever, FLAN-T5 generator, prompt template, document map, and source tracking.
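
Company filtering depends on the source label that `run_rag_pipeline` derives from each filename (the text before the first underscore). The short sketch below shows that derivation on the two dataset paths and on a renamed file; the helper name `source_label` and the filename `report.pdf` are hypothetical.

```python
import os


def source_label(pdf_path: str) -> str:
    """Return the part of the filename before the first underscore."""
    return os.path.basename(pdf_path).split("_")[0]


paths = [
    "/data/QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf",
    "/data/NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf",
    "report.pdf",  # hypothetical renamed upload with no underscore
]
labels = [source_label(path) for path in paths]
print(labels)
```

A file whose name does not start with `QuantumCore_` or `NeoCompute_` gets a label that never matches either company, so its chunks would be dropped whenever a query names a company.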

In [None]:
pdf_paths = [
    '/data/QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf',
    '/data/NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf'
]

if not all(os.path.exists(pdf_path) for pdf_path in pdf_paths):
    print('Please upload the PDF files (QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf and/or NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf):')
    uploaded = files.upload()
    pdf_paths = list(uploaded.keys())

result = run_rag_pipeline(pdf_paths)
if result:
    model, index, retriever, generator, PROMPT, document_map, document_sources = result
    print('RAG pipeline initialized successfully.')
else:
    print('Failed to initialize RAG pipeline.')

## Cell 5: Interactive Querying

Provides an interactive query interface with formatted output. Features:
- Clear sections with headers, horizontal lines, and indentation.
- Suppressed progress bars, replaced with concise messages.
- Example queries test roles, products, specifications, and compliance.
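
As an illustration of the output layout, the sketch below builds the same sections as a single string using `textwrap.indent`. The helper `format_result` and the placeholder text for a missing context are assumptions; the notebook prints these sections line by line instead.

```python
import textwrap
from typing import Optional

RULE = "=" * 36


def format_result(query: str, context: Optional[str], answer: str) -> str:
    """Return the query, indented context, and answer as one formatted block."""
    body = context if context else "No context retrieved."
    lines = [
        RULE,
        f"Query: {query}",
        RULE,
        "",
        "**Context Retrieved**:",
        "",
        textwrap.indent(body, "    "),  # indent every non-empty context line
        "",
        "---",
        "",
        "**Answer**:",
        "",
        f"    {answer}",
        "",
        RULE,
    ]
    return "\n".join(lines)


# Hypothetical two-line context, used only to show the indentation.
output = format_result("Example query", "line one\nline two", "Example answer")
print(output)
```

Returning a string rather than printing directly keeps the formatting testable and lets the caller handle a missing context without special cases in the loop.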

**Example Queries**:
- Who is the CEO of QuantumCore Solutions? → 'Dr. Elena Ruiz'
- What are the products offered by NeoCompute Technologies? → 'NeoCloud, NeoSecure'
- Who is the CIO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- Who is the CEO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- What is the qubit count of QubitCore? → '50-qubit superconducting architecture'
- What compliance standards does NeoCompute follow? → 'SOC 2 Type II, ISO/IEC 27001, HIPAA, GDPR'

In [None]:
def interactive_query():
    """Run an interactive query loop with formatted output."""
    print('====================================')
    print('RAG Query Interface')
    print('====================================')
    print('Enter your query (or type "exit" to quit):\n')
    while True:
        query = input('Query: ')
        if query.lower() == 'exit':
            print('\nExiting query interface.')
            break
        if not result:
            print('\nRAG pipeline not initialized. Please run Cell 4 first.')
            break
        context, answer = query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query)
        print('\n====================================')
        print(f'Query: {query}')
        print('====================================')
        print('\n**Context Retrieved**:\n')
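        # Indent outside the f-string: backslashes inside f-string expressions are a
        # SyntaxError before Python 3.12, and context is None when query_rag fails
        indented_context = (context or 'No context retrieved.').replace('\n', '\n    ')
        print(f'    {indented_context}')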
        print('\n---')
        print('\n**Answer**:\n')
        print(f'    {answer}')
        print('\n====================================\n')
        print('Enter your next query (or type "exit" to quit):\n')

interactive_query()