# Retrieval-Augmented Generation (RAG) Pipeline Demo

This Jupyter Notebook demonstrates a minimal Retrieval-Augmented Generation (RAG) pipeline designed for a take-home project interview.

The system answers user queries from two PDF datasets, `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` and `NeoCompute_Technologies_RAG_Demo_Dataset_v3.pdf`, so the same pipeline can be exercised on questions about either document.

## Objective
- Combine retrieval and generative AI to answer queries grounded in PDF content.
- Use small, CPU-friendly models (`lightonai/GTE-ModernColBERT-v1`, `google/flan-t5-base`).
- Keep post-processing minimal: regex extraction for role and product questions, plus deduplication of list answers.
- Support interactive querying for demo purposes.

## Architecture
- **Knowledge Base**: PDFs are loaded with `PyPDFLoader` and split into 300-character chunks with 50 characters of overlap (`RecursiveCharacterTextSplitter`); chunk texts are kept in an in-memory map keyed by chunk ID.
- **Semantic Layer**: Chunks and queries are encoded with `lightonai/GTE-ModernColBERT-v1`, a ColBERT-style model that produces one embedding per token; chunk embeddings are stored in a PyLate Voyager index.
- **Retrieval System**: `retrieve.ColBERT` fetches the top 15 chunks, `rank.rerank` reorders them, and the top 3 are kept.
- **Augmentation**: The top 3 chunks are joined, the combined context is truncated to 600 characters, and it is inserted with the query into a `PromptTemplate`.
- **Generation**: `google/flan-t5-base` generates the answer, which is then post-processed with regex and deduplication.
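ColBERT-style models compare a query and a chunk by late interaction (MaxSim): each query token embedding is matched to its most similar chunk token embedding, and those best-match scores are summed. The sketch below illustrates MaxSim scoring followed by the top-3 cut used here, assuming tiny hand-made 2-dimensional token vectors. The helpers `maxsim` and `rerank` are hypothetical; this is not PyLate's implementation, which works on model-produced tensors.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim(query_tokens: List[Vector], doc_tokens: List[Vector]) -> float:
    """Late-interaction score: for each query token take its best-matching
    document token, then sum those maxima over all query tokens."""
    return sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens)


def rerank(query_tokens: List[Vector],
           candidates: List[Tuple[str, List[Vector]]],
           top_n: int = 3) -> List[Tuple[str, float]]:
    """Score every (doc_id, token_vectors) candidate with MaxSim and return
    the top_n (doc_id, score) pairs, highest score first."""
    scored = [(doc_id, maxsim(query_tokens, toks)) for doc_id, toks in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


if __name__ == "__main__":
    # Toy data (hypothetical): two query tokens, three candidate chunks.
    query = [[1.0, 0.0], [0.0, 1.0]]
    candidates = [
        ("0_a.pdf", [[1.0, 0.0]]),
        ("1_a.pdf", [[0.9, 0.1], [0.1, 0.9]]),
        ("2_b.pdf", [[0.0, 0.0]]),
    ]
    print(rerank(query, candidates, top_n=2))
```

Because every query token contributes its own best match, a chunk that covers several query terms (here `1_a.pdf`) outranks one that matches a single term strongly.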

## Setup
- **Dependencies**: `pylate`, `langchain`, `langchain-community`, `transformers`, `google.colab`, `pypdf`, `hf_xet`.
- **Environment**: Google Colab with CPU.
- **Datasets**: PDFs in `/data` folder or uploaded manually.

## How to Run
1. Run Cell 1 to install dependencies.
2. Run Cell 2 to import libraries and define the pipeline.
3. Run Cell 3 to upload PDFs or specify paths.
4. Run Cell 4 to process PDFs and initialize the pipeline.
5. Run Cell 5 to query the pipeline interactively.

## Cell 1: Install Dependencies

Install required libraries for the RAG pipeline. This ensures the notebook runs in a clean Colab environment.
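Once the installs have finished, a quick check can confirm that the top-level modules the pipeline imports are available before running Cell 2. The sketch below uses only the standard library; `check_modules` is a hypothetical helper, and the module list is an assumption based on the imports in Cell 2 (recent LangChain releases ship `PyPDFLoader` in `langchain_community` and the text splitters in `langchain_text_splitters`).

```python
import importlib.util
from typing import Dict, List


def check_modules(names: List[str]) -> Dict[str, bool]:
    """Map each top-level module name to True if it can be found in this runtime.
    Only pass top-level names: find_spec on 'a.b' raises if 'a' is missing."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    required = ["pylate", "langchain_community", "langchain_text_splitters",
                "langchain_core", "pypdf", "transformers"]
    for module, present in check_modules(required).items():
        print(f"{module}: {'ok' if present else 'MISSING'}")
```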

In [None]:
!pip install pylate langchain transformers google-colab
!pip install -U langchain-community langchain-text-splitters pypdf hf_xet

## Cell 2: Import Libraries and Define RAG Pipeline

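This cell imports the libraries, suppresses noisy `pypdf` warnings, and defines the pipeline as a `RAGPipeline` class with one method per stage (chunking, model loading, indexing, retriever/generator setup, querying), plus a `main` helper that runs the setup and any optional queries.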
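To see what `chunk_size=300` and `chunk_overlap=50` mean in practice, the sketch below splits text into fixed character windows whose neighbours share 50 characters, and builds IDs in the same `<index>_<filename>` shape that `load_and_chunk_pdf` uses. It is a simplified stand-in, not LangChain's algorithm: `RecursiveCharacterTextSplitter` also tries to break on paragraph, line, and word boundaries, so its chunks are usually shorter and less regular. `sliding_window_chunks` and `chunk_ids` are hypothetical helpers.

```python
import os
from typing import List


def sliding_window_chunks(text: str, chunk_size: int = 300, chunk_overlap: int = 50) -> List[str]:
    """Split text into windows of chunk_size characters; each window starts
    chunk_size - chunk_overlap characters after the previous one."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):  # last window reached the end
            break
    return chunks


def chunk_ids(n_chunks: int, pdf_path: str) -> List[str]:
    """IDs shaped like those in load_and_chunk_pdf: '<index>_<basename>'."""
    name = os.path.basename(pdf_path)
    return [f"{i}_{name}" for i in range(n_chunks)]


if __name__ == "__main__":
    sample = "".join(chr(97 + i % 26) for i in range(700))  # 700 characters of filler
    pieces = sliding_window_chunks(sample)
    print([len(p) for p in pieces], chunk_ids(len(pieces), "/data/QuantumCore_v1.pdf"))
```

The overlap means a sentence cut at a window edge still appears whole in at least one neighbouring chunk, provided it is shorter than 50 characters.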

In [None]:
import warnings
import os
import re
from pylate import models, indexes, retrieve, rank
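# Recent LangChain releases moved these classes out of the core `langchain` package
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate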
from google.colab import files
from transformers import pipeline

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning, module='pypdf._reader')
warnings.filterwarnings('ignore', category=DeprecationWarning, module='pypdf._reader')

# Configuration dictionary for easy parameter tweaking
CONFIG = {
    'chunk_size': 300,
    'chunk_overlap': 50,
    'model_name': 'lightonai/GTE-ModernColBERT-v1',
    'index_folder': 'pylate-index',
    'index_name': 'pdf_index',
    'max_context_length': 600,
    'top_k_initial': 15,
    'batch_size': 32,
    'max_length': 300
}

class RAGPipeline:
    def __init__(self, config):
        """Initialize the RAG pipeline with configuration."""
        self.config = config
        self.model = None
        self.index = None
        self.retriever = None
        self.generator = None
        self.prompt = None
        self.document_map = {}

    def load_and_chunk_pdf(self, pdf_path):
        """Load and chunk a PDF into text segments."""
        if not os.path.exists(pdf_path):
            print(f'PDF not found: {pdf_path}')
            return None, None, None
        try:
            print(f'Processing PDF: {pdf_path}')
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()
            if not documents:
                print(f'No content found in PDF: {pdf_path}')
                return None, None, None
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.config['chunk_size'],
                chunk_overlap=self.config['chunk_overlap']
            )
            chunks = text_splitter.split_documents(documents)
            document_texts = [(chunk.page_content, {'pdf': pdf_path, 'page': chunk.metadata['page']}) for chunk in chunks]
            document_ids = [str(i) + '_' + os.path.basename(pdf_path) for i in range(len(document_texts))]
            document_map = {doc_id: text for doc_id, (text, _) in zip(document_ids, document_texts)}
            print(f'Created {len(document_texts)} chunks from {pdf_path}')
            return document_texts, document_ids, document_map
        except Exception as e:
            print(f'Error processing PDF {pdf_path}: {e}')
            return None, None, None

    def initialize_models(self):
        """Initialize the ColBERT model for embedding."""
        try:
            self.model = models.ColBERT(model_name_or_path=self.config['model_name'])
            print(f'Loaded model: {self.config["model_name"]}')
        except Exception as e:
            print(f'Failed to load ColBERT model {self.config["model_name"]}: {e}')
            return False
        return True

    def create_index(self, document_texts, document_ids):
        """Create and populate the Voyager index with document embeddings."""
        try:
            self.index = indexes.Voyager(
                index_folder=self.config['index_folder'],
                index_name=self.config['index_name'],
                override=True
            )
            document_texts_only = [text for text, _ in document_texts]
            documents_embeddings = self.model.encode(
                document_texts_only,
                batch_size=self.config['batch_size'],
                is_query=False,
                show_progress_bar=True
            )
            self.index.add_documents(document_ids, documents_embeddings=documents_embeddings)
            print(f'Indexed {len(document_ids)} documents')
            return True
        except Exception as e:
            print(f'Error creating index: {e}')
            return False

    def setup_retriever_generator(self):
        """Initialize the retriever and generator with prompt template."""
        try:
            self.retriever = retrieve.ColBERT(index=self.index)
            self.generator = pipeline('text2text-generation', model='google/flan-t5-base', max_length=self.config['max_length'])
            prompt_template = '''Using only the provided text, answer the user's question with a concise and accurate response. For questions about specific roles (e.g., CEO, CTO, CFO), return only the full name of the individual in that role. For questions about lists (e.g., products), return all items as a comma-separated list of names only. Exclude any details not directly relevant to the question, such as technical specifications, unless explicitly requested. If the answer is not in the text, respond with 'The answer could not be found in the text.'

Text: {context}

Question: {question}

Answer:''' 
            self.prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])
            print('Retriever and generator initialized')
            return True
        except Exception as e:
            print(f'Error setting up retriever or generator: {e}')
            return False

    def query_rag(self, query):
        """Process a query through the RAG pipeline."""
        try:
            queries = [query]

            # Encode Query
            query_embedding = self.model.encode(
                queries,
                batch_size=self.config['batch_size'],
                is_query=True,
                show_progress_bar=True
            )

            # Retrieve Top Documents
            initial_results = self.retriever.retrieve(queries_embeddings=query_embedding, k=self.config['top_k_initial'])
            retrieved_doc_ids = [result['id'] for result in initial_results[0]]
            retrieved_documents = [self.document_map[doc_id] for doc_id in retrieved_doc_ids]

            # Rerank Documents
            reranked_results = rank.rerank(
                documents_ids=[retrieved_doc_ids],
                queries_embeddings=query_embedding,
                documents_embeddings=[self.model.encode(retrieved_documents, is_query=False)]
            )

            # Get Reranked Documents
            reranked_doc_ids = []
            if reranked_results and isinstance(reranked_results[0], list):
                for result in reranked_results[0]:
                    if isinstance(result, dict) and 'id' in result:
                        reranked_doc_ids.append(result['id'])
                    elif isinstance(result, str):
                        reranked_doc_ids.append(result)
            else:
                reranked_doc_ids = retrieved_doc_ids

            reranked_documents = [self.document_map[doc_id] for doc_id in reranked_doc_ids]

            # Create Context
            context = '\n'.join(reranked_documents[:3])[:self.config['max_context_length']]
            prompt_text = self.prompt.format(context=context, question=query)

            # Generate Answer
            response = self.generator(prompt_text)[0]['generated_text']
            answer = response.strip()

            # Post-processing
            if ', ' in answer:
                items = set(answer.split(', '))
                answer = ', '.join(sorted(items)) if items else 'The answer could not be found in the text.'
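            # Match the role as a whole word so e.g. 'director' does not trigger 'cto',
            # and so queries phrased without 'who is' do not raise an IndexError
            role_match = re.search(r'\b(ceo|cto|cfo)\b', query.lower())
            if role_match:
                role = role_match.group(1).upper()
                # Look for a '- Name, ROLE:' line in the retrieved context
                match = re.search(rf'- ([^,\n]+), {role}:', context)
                if match:
                    answer = match.group(1).strip()
                else:
                    answer = 'The answer could not be found in the text.'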
            if 'product' in query.lower():
                product_names = re.findall(r'- (\w+):', context)
                if product_names:
                    answer = ', '.join(sorted(set(product_names)))
                else:
                    answer = 'The answer could not be found in the text.'

            return context, answer

        except Exception as e:
            print(f'Error processing query: {e}')
            return None, 'Error processing query.'

    def run(self, pdf_paths):
        """Run the pipeline for a list of PDFs."""
        all_document_texts = []
        all_document_ids = []
        for pdf_path in pdf_paths:
            texts, ids, doc_map = self.load_and_chunk_pdf(pdf_path)
            if texts:
                all_document_texts.extend(texts)
                all_document_ids.extend(ids)
                self.document_map.update(doc_map)

        if not all_document_texts:
            print('No valid PDFs processed.')
            return False

        if not self.initialize_models():
            return False

        if not self.create_index(all_document_texts, all_document_ids):
            return False

        if not self.setup_retriever_generator():
            return False

        return True

def main(pdf_paths, queries=None):
    """Orchestrate the RAG pipeline setup and optional querying."""
    pipeline = RAGPipeline(CONFIG)
    if pipeline.run(pdf_paths):
        print('Pipeline initialized successfully.')
        if queries:
            for query in queries:
                context, answer = pipeline.query_rag(query)
                print(f'Query: {query}\nAnswer: {answer}\nContext: {context}\n')
    else:
        print('Pipeline initialization failed.')

## Cell 3: Upload or Specify PDFs

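Upload the PDFs or, if they are already in `/data`, list their paths in `pdf_paths`. The default paths use the short names `QuantumCore_v1.pdf` and `NeoCompute_v3.pdf`; edit them if your copies keep the full dataset filenames listed at the top of the notebook. This cell defines the knowledge base the pipeline will index.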
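If the hard-coded paths do not match the files actually present, `load_and_chunk_pdf` only prints "PDF not found" and skips them. The sketch below shows one way to guard against that: keep the listed paths that exist and, if none do, fall back to every PDF in a directory. `resolve_pdf_paths` and its `/data` default are hypothetical, not part of the pipeline above.

```python
import glob
import os
from typing import List


def resolve_pdf_paths(candidates: List[str], fallback_dir: str = "/data") -> List[str]:
    """Return the candidate paths that exist as files; if none exist, return
    every *.pdf in fallback_dir, sorted for a stable processing order."""
    existing = [p for p in candidates if os.path.isfile(p)]
    if existing:
        return existing
    return sorted(glob.glob(os.path.join(fallback_dir, "*.pdf")))


if __name__ == "__main__":
    paths = resolve_pdf_paths(["/data/QuantumCore_v1.pdf", "/data/NeoCompute_v3.pdf"])
    print(f"PDFs to process: {paths}")
```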

In [None]:
# Specify PDF paths or upload
pdf_paths = ['/data/QuantumCore_v1.pdf', '/data/NeoCompute_v3.pdf']

# Uncomment to upload PDFs manually in Colab
# print('Please upload your PDF files (QuantumCore_v1.pdf and/or NeoCompute_v3.pdf):')
# uploaded = files.upload()
# pdf_paths = list(uploaded.keys())

print(f'PDFs to process: {pdf_paths}')

## Cell 4: Run the Pipeline

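This cell runs `main`, which builds the pipeline from the specified PDFs and reports whether initialization succeeded. Because `main` keeps the pipeline in a local variable, Cell 5 builds a separate instance for querying.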
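With the default `CONFIG`, it is worth checking how much retrieved text actually reaches `flan-t5-base`. `query_rag` joins the top 3 reranked chunks with newlines and then cuts the result to `max_context_length` characters. The sketch below reproduces those two steps in plain Python; `build_context`, `build_prompt`, and the shortened `TEMPLATE` are hypothetical stand-ins (the real template also carries the answering instructions and is formatted through `PromptTemplate`).

```python
from typing import List

# Shortened stand-in for the template defined in setup_retriever_generator
TEMPLATE = "Text: {context}\n\nQuestion: {question}\n\nAnswer:"


def build_context(chunks: List[str], top_n: int = 3, max_chars: int = 600) -> str:
    """Join the top_n chunks with newlines, then truncate the joined string
    to max_chars characters."""
    return "\n".join(chunks[:top_n])[:max_chars]


def build_prompt(chunks: List[str], question: str) -> str:
    """Fill the template with the truncated context and the user's question."""
    return TEMPLATE.format(context=build_context(chunks), question=question)


if __name__ == "__main__":
    # Three full-size 300-character chunks, as chunk_size=300 allows
    ctx = build_context(["a" * 300, "b" * 300, "c" * 300])
    print(len(ctx), ctx.count("b"), "c" in ctx)
```

Three full chunks join to 902 characters, so the 600-character cut keeps the first chunk, most of the second, and none of the third. Raising `max_context_length` lets the third chunk through, at the cost of a longer input to the generator.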

In [None]:
main(pdf_paths)

## Cell 5: Interactive Querying

This cell builds a pipeline from `pdf_paths` and answers typed queries until you enter `exit`.
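The answers printed by this loop pass through the regex post-processing in `query_rag`. The sketch below pulls those rules into standalone functions so they can be tried without loading any model. It assumes the PDFs list people as `- Name, ROLE:` and products as `- ProductName:` lines, which is what the patterns in `query_rag` expect; the function names and the sample context are hypothetical, and role matching here uses word boundaries.

```python
import re
from typing import Optional

NOT_FOUND = "The answer could not be found in the text."


def dedupe_list_answer(answer: str) -> str:
    """Collapse a comma-separated answer into sorted, unique items."""
    if ", " not in answer:
        return answer
    return ", ".join(sorted(set(answer.split(", "))))


def extract_role_holder(query: str, context: str) -> Optional[str]:
    """Return the name on a '- Name, ROLE:' line for the role in the query,
    NOT_FOUND if no such line exists, or None if the query names no role."""
    role_match = re.search(r"\b(ceo|cto|cfo)\b", query.lower())
    if not role_match:
        return None
    role = role_match.group(1).upper()
    match = re.search(rf"- ([^,\n]+), {role}:", context)
    return match.group(1).strip() if match else NOT_FOUND


def extract_products(context: str) -> str:
    """Collect single-word names from '- Name:' lines as a sorted list."""
    names = re.findall(r"- (\w+):", context)
    return ", ".join(sorted(set(names))) if names else NOT_FOUND


if __name__ == "__main__":
    # Hypothetical context in the assumed list format
    sample = "- Jane Doe, CEO: leads strategy\n- QPU100: processor\n- QLink: interconnect\n"
    print(extract_role_holder("Who is the CEO?", sample))
    print(extract_products(sample))
```

Note that `\w+` only captures single-word product names; a multi-word name such as `Quantum Link:` would be skipped by this pattern.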

In [None]:
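# Use a name other than `pipeline`: rebinding it here would shadow transformers.pipeline,
# which setup_retriever_generator calls, and make initialization fail
rag = RAGPipeline(CONFIG)
if rag.run(pdf_paths):
    while True:
        query = input('Enter your query (or "exit" to stop): ').strip()
        if query.lower() == 'exit':
            break
        if not query:
            continue  # ignore empty input
        context, answer = rag.query_rag(query)
        print(f'\nQuery: {query}\nAnswer: {answer}\nContext: {context}\n')
else:
    print('Pipeline initialization failed.')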