# RAG Pipeline for PDF Query Answering

This notebook implements a Retrieval-Augmented Generation (RAG) pipeline that answers queries about two PDF datasets: QuantumCore Solutions and NeoCompute Technologies. The pipeline loads and chunks the PDFs, embeds the chunks, retrieves the most relevant ones, and generates answers with a Few-Shot prompt. The models are small enough to run on CPU, so no GPU runtime is required.

## Objective
- Answer user queries about company details (e.g., leadership, products, specifications).
- Support questions about QuantumCore, NeoCompute, or both, with source tracking.
- Handle role-based, list-based, detail-based, and general questions.
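
Questions about one company, or about both, depend on recognizing which company a query names. Below is a minimal sketch that assumes a case-insensitive substring match on the two company names. The helper names `mentioned_companies` and `filter_by_company` are hypothetical and are not part of the notebook.

```python
# Hypothetical helpers: detect the companies a query names and keep only their chunks.
# Assumes a plain case-insensitive substring match, not entity recognition.
COMPANIES = ("QuantumCore", "NeoCompute")


def mentioned_companies(query: str) -> set[str]:
    """Return the set of known company names that appear in the query."""
    lowered = query.lower()
    return {name for name in COMPANIES if name.lower() in lowered}


def filter_by_company(doc_ids: list[str], sources: dict[str, str], query: str) -> list[str]:
    """Keep chunks whose source company is named in the query.

    Queries that name no known company are left unfiltered.
    """
    wanted = mentioned_companies(query)
    if not wanted:
        return list(doc_ids)
    return [doc_id for doc_id in doc_ids if sources.get(doc_id) in wanted]


sources = {"q_0": "QuantumCore", "n_0": "NeoCompute", "q_1": "QuantumCore"}
ids = ["q_0", "n_0", "q_1"]
print(filter_by_company(ids, sources, "Who is the CEO of QuantumCore?"))      # ['q_0', 'q_1']
print(filter_by_company(ids, sources, "Compare QuantumCore and NeoCompute"))  # ['q_0', 'n_0', 'q_1']
```

A query that names neither company is left unfiltered, so general questions can draw on both PDFs.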

## Architecture
1. **PDF Processing**: Load the PDFs and split them into chunks of at most 300 characters, with a 50-character overlap.
2. **Embedding & Indexing**: Use `lightonai/GTE-ModernColBERT-v1` to embed chunks and store in a `Voyager` index.
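3. **Retrieval & Reranking**: Retrieve the top 15 chunks with PyLate's `ColBERT` retriever, rerank them, and keep the top 3. When the query names a company, only that company's chunks are kept.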
4. **Generation**: Use `google/flan-t5-base` with a Few-Shot prompt to generate answers.
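
The sketch below makes the chunking parameters concrete. It is a simplified fixed-window chunker that assumes the `chunk_size=300` and `chunk_overlap=50` values used in Cell 3. It is a stand-in, not `RecursiveCharacterTextSplitter`. The LangChain splitter also prefers to break on paragraph, line, and word boundaries, so its chunks are usually somewhat shorter than 300 characters and its overlap is approximate.

```python
def chunk_text(text: str, chunk_size: int = 300, chunk_overlap: int = 50) -> list[str]:
    """Split text into fixed-size character windows that overlap.

    Simplified stand-in for RecursiveCharacterTextSplitter: it ignores
    separators and just slides a window of chunk_size characters,
    advancing by chunk_size - chunk_overlap each time.
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window reaches the end of the text
        if start + chunk_size >= len(text):
            break
    return chunks


print([len(c) for c in chunk_text("x" * 700)])  # [300, 300, 200]
```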

## Updates
- Optimized Few-Shot prompt to reduce token count.
- Added input truncation to prevent token length errors.
- Enhanced error handling for robustness.
- Added token count logging for debugging.
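
Truncation is safest when it removes context rather than the end of the prompt, because the end holds the question and the `Answer:` cue. The sketch below shows that budgeting. It assumes caller-supplied `tokenize` and `detokenize` callables: whitespace splitting here, and the flan-t5-base tokenizer in the notebook. `fit_context` is a hypothetical helper.

```python
from typing import Callable


def fit_context(
    template: str,
    context: str,
    question: str,
    tokenize: Callable[[str], list],
    detokenize: Callable[[list], str],
    max_tokens: int = 512,
) -> str:
    """Trim the context so the formatted prompt fits within max_tokens.

    The fixed parts of the prompt (instructions, question, "Answer:" cue)
    are never cut; only the tail of the context is dropped.
    """
    # Tokens used by the prompt when the context is empty
    overhead = len(tokenize(template.format(context="", question=question)))
    budget = max(max_tokens - overhead, 0)
    tokens = tokenize(context)
    if len(tokens) <= budget:
        return context
    return detokenize(tokens[:budget])


template = "Text: {context}\nQuestion: {question}\nAnswer:"
context = "Jane Smith is the CEO of QuantumCore and founded it in 2019"
trimmed = fit_context(template, context, "Who is the CEO?", str.split, " ".join, max_tokens=12)
print(trimmed)  # Jane Smith is the CEO
```

With a Hugging Face tokenizer, `tokenize` could be `lambda s: tokenizer.encode(s, add_special_tokens=False)` and `detokenize` could be `tokenizer.decode`. In that case, leave a few tokens of margin for the end-of-sequence token and for small differences when decoded text is re-tokenized.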

## Setup
- Run cells sequentially.
- Ensure PDFs are available in the working directory or uploaded via Colab.
- Tested on Google Colab with CPU.

In [None]:
# Cell 1: Install Dependencies
!pip install -q pylate langchain langchain-community langchain-text-splitters transformers pypdf
# Optional: `pip install -q hf_xet` enables faster downloads from Xet-backed Hugging Face repositories.

In [None]:
# Cell 2: Import Libraries
import warnings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from transformers import pipeline, AutoTokenizer
from google.colab import files
from pylate import models, indexes, retrieve, rank
import contextlib
import io

# Suppress pypdf warnings for cleaner output
warnings.filterwarnings('ignore', category=DeprecationWarning, module='pypdf._reader')

In [None]:
# Cell 3: Define RAG Pipeline

# Optimized Few-Shot Prompt
prompt_template = r"""
You are an expert assistant answering questions based solely on the provided text. Follow these rules:
1. For roles (e.g., CEO), return the full name.
2. For lists (e.g., products), return a comma-separated list, sorted alphabetically.
3. For details (e.g., specifications), return the exact detail.
4. For other questions, provide a brief answer.
5. If no answer is found, return: "The answer could not be found in the text."

**Examples**:
- Text: "Jane Smith, CEO." Question: Who is the CEO? Answer: Jane Smith
- Text: "CloudPeak, SecureVault." Question: What products? Answer: CloudPeak, SecureVault
- Text: "AlphaCore: 100 qubits." Question: Qubit count? Answer: 100 qubits

**Text**: {context}

**Question**: {question}

**Answer**:
"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

def run_rag_pipeline(pdf_paths):
    """
    Initialize the RAG pipeline by processing PDFs and setting up models.
    
    Args:
        pdf_paths: List of paths to PDF files.
    
    Returns:
        Tuple of (model, index, retriever, generator, PROMPT, document_map, document_sources)
        or None if processing fails.
    """
    try:
        document_map = {}
        document_sources = {}
        all_splits = []

        # Initialize text splitter
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

        # Process each PDF
        for pdf_path in pdf_paths:
            print(f'Processing PDF: {pdf_path}')
            try:
                loader = PyPDFLoader(pdf_path)
                documents = loader.load()
                if not documents:
                    print(f'No content extracted from {pdf_path}')
                    continue

                # Split documents into chunks
                splits = text_splitter.split_documents(documents)
                company_name = 'QuantumCore' if 'QuantumCore' in pdf_path else 'NeoCompute'

                # Store splits and track sources
                for i, split in enumerate(splits):
                    doc_id = f'{pdf_path}_{i}'
                    document_map[doc_id] = split.page_content
                    document_sources[doc_id] = company_name
                    all_splits.append(split.page_content)

            except FileNotFoundError:
                print(f'PDF not found: {pdf_path}')
                continue
            except Exception as e:
                print(f'Error processing {pdf_path}: {e}')
                continue

        if not all_splits:
            print('No valid documents processed.')
            return None

        # Initialize embedding model
        print('Initializing embedding model...')
        with contextlib.redirect_stderr(io.StringIO()):
            model = models.ColBERT(model_name_or_path='lightonai/GTE-ModernColBERT-v1')

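        # Create the index and add every chunk under its document ID, so retrieval
        # results can be mapped back through document_map and document_sources
        print('Creating index...')
        index = indexes.Voyager(index_folder='pylate-index', index_name='rag_index', override=True)
        document_ids = list(document_map.keys())
        with contextlib.redirect_stderr(io.StringIO()):
            documents_embeddings = model.encode(
                list(document_map.values()), batch_size=32, is_query=False, show_progress_bar=False
            )
        index.add_documents(documents_ids=document_ids, documents_embeddings=documents_embeddings)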

        # Initialize retriever and generator
        print('Initializing retriever and generator...')
        retriever = retrieve.ColBERT(index=index)
        generator = pipeline('text2text-generation', model='google/flan-t5-base', max_length=300)

        return model, index, retriever, generator, PROMPT, document_map, document_sources

    except Exception as e:
        print(f'Error initializing pipeline: {e}')
        return None

def query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query):
    """
    Process a user query with retrieval and generation, ensuring input fits within token limits.
    
    Args:
        model: ColBERT model for encoding queries.
        index: Voyager index for document embeddings.
        retriever: ColBERT retriever for fetching chunks.
        generator: FLAN-T5 pipeline for text generation.
        PROMPT: PromptTemplate with Few-Shot examples.
        document_map: Dict mapping document IDs to text.
        document_sources: Dict mapping document IDs to source (QuantumCore/NeoCompute).
        query: User query string.
    
    Returns:
        Generated answer or error message.
    """
    try:
        # Reuse the generator's flan-t5-base tokenizer instead of reloading it on every query
        tokenizer = generator.tokenizer
        
        # Encode query
        print('Encoding query...')
        with contextlib.redirect_stderr(io.StringIO()):
            query_embedding = model.encode([query], is_query=True, show_progress_bar=False)
        
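        # Retrieve the top 15 candidate chunks. Results come back per query,
        # as a list of {'id': ..., 'score': ...} dicts
        print('Retrieving chunks...')
        retrieved = retriever.retrieve(queries_embeddings=query_embedding, k=15)
        candidate_ids = [hit['id'] for hit in retrieved[0]]

        # Filter by company before reranking so company-specific queries keep
        # enough candidates; a query naming both companies keeps both
        mentioned = {name for name in ('QuantumCore', 'NeoCompute') if name.lower() in query.lower()}
        if mentioned:
            candidate_ids = [doc_id for doc_id in candidate_ids if document_sources[doc_id] in mentioned]
        if not candidate_ids:
            return "The answer could not be found in the text."

        # Rerank the candidates with ColBERT late-interaction scores and keep the top 3
        print('Reranking chunks...')
        with contextlib.redirect_stderr(io.StringIO()):
            candidate_embeddings = model.encode(
                [[document_map[doc_id] for doc_id in candidate_ids]], is_query=False, show_progress_bar=False
            )
        reranked = rank.rerank(
            documents_ids=[candidate_ids],
            queries_embeddings=query_embedding,
            documents_embeddings=candidate_embeddings,
        )
        top_ids = [hit['id'] for hit in reranked[0][:3]]

        # Combine context from the top chunks
        context = "\n".join(document_map[doc_id] for doc_id in top_ids)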
        
        # Keep the prompt within 512 tokens (flan-t5-base's input length) by trimming
        # the context rather than the whole prompt, so the question and "Answer:" cue survive
        max_tokens = 512
        overhead_tokens = len(tokenizer.encode(PROMPT.format(context='', question=query)))
        context_budget = max(max_tokens - overhead_tokens - 8, 0)  # small margin for re-tokenization drift
        context_ids = tokenizer.encode(context, add_special_tokens=False)
        if len(context_ids) > context_budget:
            context = tokenizer.decode(context_ids[:context_budget], skip_special_tokens=True)

        # Format prompt
        prompt = PROMPT.format(context=context, question=query)

        # Log token counts for debugging
        context_tokens = min(len(context_ids), context_budget)
        print(f'Token counts: Prompt without context={overhead_tokens}, Context={context_tokens}, Total={len(tokenizer.encode(prompt))}')

        # Generate answer
        print('Generating answer...')
        with contextlib.redirect_stderr(io.StringIO()):
            answer = generator(prompt, max_length=300)[0]['generated_text']
        
        # Post-process answer
        if not answer.strip():
            return "The answer could not be found in the text."
        
        return answer.strip()
    
    except Exception as e:
        # Any failure in retrieval or generation falls back to the standard "not found" answer
        print(f'Error processing query: {e}')
        return "The answer could not be found in the text."

# Example usage (uncomment to test)
# pdf_paths = ['QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf', 'NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf']
# pipeline_components = run_rag_pipeline(pdf_paths)
# if pipeline_components:
#     model, index, retriever, generator, PROMPT, document_map, document_sources = pipeline_components
#     query = "Who is the CEO of QuantumCore?"
#     answer = query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query)
#     print(f'Answer: {answer}')
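
Both the ColBERT retriever and `pylate.rank.rerank` rank chunks by late-interaction (MaxSim) similarity between per-token embeddings. The toy pure-Python illustration below uses hand-written 2-D vectors. Real ColBERT embeddings are higher-dimensional and produced by the model. `maxsim` and `rerank_maxsim` are hypothetical names, not PyLate functions.

```python
def maxsim(query_vecs, doc_vecs):
    """Late-interaction score: each query token keeps its best dot product
    with any document token, and those maxima are summed."""
    return sum(
        max(sum(q * d for q, d in zip(q_vec, d_vec)) for d_vec in doc_vecs)
        for q_vec in query_vecs
    )


def rerank_maxsim(query_vecs, candidates, k=3):
    """Sort (doc_id, doc_vecs) pairs by MaxSim score and return the top-k ids."""
    scored = sorted(
        ((maxsim(query_vecs, vecs), doc_id) for doc_id, vecs in candidates),
        key=lambda pair: pair[0],
        reverse=True,
    )
    return [doc_id for _, doc_id in scored[:k]]


# Toy 2-D "token embeddings": two query tokens, four candidate chunks
query = [[1, 0], [0, 1]]
candidates = [
    ("a", [[1, 0], [0, 0]]),  # matches only the first query token -> 1
    ("b", [[1, 0], [0, 1]]),  # matches both query tokens -> 2
    ("c", [[0.4, 0.4]]),      # weak partial match -> 0.8
    ("d", [[0, 0]]),          # no match -> 0
]
print(rerank_maxsim(query, candidates, k=3))  # ['b', 'a', 'c']
```

Each query token contributes only its best match in the chunk. A chunk that covers more of the query's tokens therefore tends to score higher.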

In [None]:
# Cell 4: Upload PDFs
# Upload the QuantumCore and NeoCompute PDFs with the Colab file picker, then build the pipeline
from google.colab import files

print('Please upload the PDF files:')
uploaded = files.upload()
pdf_paths = list(uploaded.keys())

if not pdf_paths:
    print('No PDFs uploaded. Please upload the required files.')
else:
    print(f'Uploaded PDFs: {pdf_paths}')
    pipeline_components = run_rag_pipeline(pdf_paths)
    if pipeline_components:
        model, index, retriever, generator, PROMPT, document_map, document_sources = pipeline_components
    else:
        print('Pipeline initialization failed.')
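
Cell 3 labels every chunk with a company taken from the PDF file name. It falls back to NeoCompute whenever "QuantumCore" is absent, so an unrelated upload would be indexed as NeoCompute. The sketch below checks the uploaded names before indexing, assuming the same two company names. `infer_company` is a hypothetical helper.

```python
from pathlib import Path
from typing import Optional

KNOWN_COMPANIES = ("QuantumCore", "NeoCompute")


def infer_company(pdf_path: str) -> Optional[str]:
    """Return the known company named in the file name, or None if there is none."""
    name = Path(pdf_path).name.lower()
    for company in KNOWN_COMPANIES:
        if company.lower() in name:
            return company
    return None


uploads = [
    "QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf",
    "NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf",
    "notes.pdf",
]
unknown = [path for path in uploads if infer_company(path) is None]
print(unknown)  # ['notes.pdf']
```

Files listed in `unknown` could then be skipped or renamed before `run_rag_pipeline` is called.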

In [None]:
# Cell 5: Interactive Query Interface
# Prompt for queries until the user types "exit"
if 'pipeline_components' in locals() and pipeline_components:
    while True:
        query = input('Enter your query (or type "exit" to quit): ').strip()
        if query.lower() == 'exit':
            break
        answer = query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query)
        print(f'Answer: {answer}\n')
else:
    print('Pipeline not initialized. Please run previous cells.')
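
Rule 2 of the prompt asks for alphabetically sorted, comma-separated lists, but a model as small as flan-t5-base may not follow it reliably. The hypothetical post-processing step below assumes it is applied only to answers already known to be lists. Applied to an ordinary sentence that contains commas, it would reorder the sentence's parts.

```python
def normalize_list_answer(answer: str) -> str:
    """Split a comma-separated answer, drop blanks and exact duplicates,
    and sort the items case-insensitively."""
    items = {item.strip() for item in answer.split(",") if item.strip()}
    return ", ".join(sorted(items, key=str.lower))


print(normalize_list_answer("SecureVault, CloudPeak , SecureVault"))  # CloudPeak, SecureVault
```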