# Retrieval-Augmented Generation (RAG) Pipeline Demo (TAKE_HOME_PROJECT)

This Jupyter Notebook implements a minimal Retrieval-Augmented Generation (RAG) pipeline for a take-home project interview. The system answers user queries by leveraging content from two PDF datasets: `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` and `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf`. It demonstrates versatility in handling varied queries (e.g., leadership roles, product lists, technical specifications) using a lightweight, CPU-friendly setup suitable for Google Colab.

## Objective
- **Purpose**: Combine retrieval and generative AI to provide accurate, context-grounded answers from PDF content.
- **Resource Efficiency**: Use small models (`lightonai/GTE-ModernColBERT-v1` for embeddings, `google/flan-t5-base` for generation) to ensure compatibility with CPU environments.
- **Post-Processing (moved to prompt)**: Role-based answers (e.g., extracting CEO names), product lists, and deduplication were originally handled by minimal regex-based post-processing. These rules are now expressed as instructions in the prompt, so the model returns clean outputs directly.
- **Interactivity**: Support an interactive query interface for demo purposes, with example queries to showcase functionality.
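The deduplication mentioned above no longer runs as code in this notebook, but the idea is easy to illustrate. The helper below is a hypothetical sketch; the name `dedupe_comma_list` is not from the notebook. It splits a comma-separated answer, drops blanks and case-insensitive duplicates, and sorts the remaining items alphabetically, matching the list rule the prompt now gives the model.

```python
def dedupe_comma_list(answer: str) -> str:
    """Hypothetical helper: split a comma-separated answer, drop blank items and
    case-insensitive duplicates, and return the items sorted alphabetically."""
    seen = {}
    for item in answer.split(","):
        name = item.strip()
        key = name.lower()
        if name and key not in seen:
            seen[key] = name  # keep the first spelling encountered
    return ", ".join(sorted(seen.values(), key=str.lower))


print(dedupe_comma_list("NeoSecure, NeoCloud, neocloud, "))  # NeoCloud, NeoSecure
```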

## Architecture
The pipeline follows a modular RAG design:
- **Knowledge Base**: PDFs are loaded using `PyPDFLoader` and split into chunks with `RecursiveCharacterTextSplitter`. Chunk size and overlap are configurable, and the interface defaults to 350 characters with a 50-character overlap. Chunks are stored in a dictionary mapping document IDs to text, with source tracking for company-specific filtering.
- **Semantic Layer**: Text chunks and queries are embedded into dense vectors using `lightonai/GTE-ModernColBERT-v1` for semantic similarity comparison.
- **Retrieval System**: `retrieve.ColBERT` fetches the top 15 relevant chunks based on query embeddings, which are reranked to the top 3 using `rank.rerank` for improved relevance.
- **Augmentation**: The top 3 chunks, truncated to the configured maximum context length (1,500 characters by default), are combined with the query via a few-shot `PromptTemplate` to create a contextualized input for the generative model.
- **Generation**: `google/flan-t5-base` (or another model selected in the interface) produces concise answers. The prompt rules ask it to return full names for role queries (e.g., CEO), alphabetically sorted comma-separated lists for product queries, and exact details for specification queries.
- **Fixes Implemented**:
  - Added source tracking to filter chunks by company based on query keywords.
  - Implemented few-shot prompting with fictional examples.
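In the Knowledge Base step, chunk size and overlap relate simply: each new chunk starts `chunk_size - chunk_overlap` characters after the previous one. The sketch below shows only that arithmetic, using a fixed character window. `RecursiveCharacterTextSplitter` itself prefers to split on separators such as paragraph breaks, newlines, and spaces, so its actual chunks and overlaps will differ from these. `chunk_text` is a hypothetical helper, not part of the notebook.

```python
def chunk_text(text: str, chunk_size: int = 350, chunk_overlap: int = 50) -> list[str]:
    """Fixed-window character chunking (illustration only): consecutive chunks
    start chunk_size - chunk_overlap characters apart, so neighbours share
    chunk_overlap characters."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


print(chunk_text("abcdefghij", chunk_size=4, chunk_overlap=1))  # ['abcd', 'defg', 'ghij']
```

With the interface defaults (350 and 50), consecutive chunks in this simplified scheme would start 300 characters apart.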


## Setup
- **Dependencies**: Requires `pylate`, `langchain`, `langchain-community`, `transformers`, `google-colab`, `pypdf`, `hf_xet`, `ipywidgets`, and `huggingface_hub` for PDF processing, embedding, retrieval, generation, and the widget interface.
- **Environment**: Designed for Google Colab with CPU, ensuring accessibility without GPU requirements.
- **Datasets**: Processes `QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` (quantum computing company details) and `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf` (assumed to have similar content).

## Instructions
1. **Cell 1**: Install required Python libraries to set up the environment.
2. **Cell 2**: Import libraries and suppress warnings for cleaner output.
3. **Cell 3**: Define the RAG pipeline functions (`run_rag_pipeline` and `query_rag`) with improved logic.
4. **Cell 4**: Load and process the PDFs, initializing the pipeline with models and indexes.
5. **Cell 5**: Run an interactive query interface to test the pipeline with example or custom queries.

The pipeline combines chunks from both PDFs into a single knowledge base but filters by company when specified in queries, ensuring relevant responses.
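The company filter can be sketched in isolation as follows. This version assumes the same two company names and a `doc_sources` mapping from chunk ID to source name. It omits the fuzzy `SequenceMatcher` check that `query_rag` also applies, and `filter_by_company` is a hypothetical name. Like the notebook, it checks QuantumCore first and falls back to all retrieved chunks when none match.

```python
def filter_by_company(query, doc_ids, doc_sources, companies=("QuantumCore", "NeoCompute")):
    """Return (filtered_ids, company) using case-insensitive substring matching only."""
    query_lower = query.lower()
    company = next((c for c in companies if c.lower() in query_lower), None)
    if company is None:
        return doc_ids, None  # no company named: keep every retrieved chunk
    filtered = [d for d in doc_ids if doc_sources.get(d) == company]
    # Mirror the notebook's fallback: if nothing matches, keep all retrieved chunks
    return (filtered or doc_ids), company


sources = {"0": "QuantumCore", "1": "NeoCompute", "2": "NeoCompute"}
print(filter_by_company("What compliance standards does NeoCompute follow?", ["0", "1", "2"], sources))
# (['1', '2'], 'NeoCompute')
```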

## Cell 1: Install Dependencies

This cell installs the necessary Python libraries for the RAG pipeline. It ensures compatibility in a clean Google Colab environment by installing `pylate` (for ColBERT embeddings and retrieval), `langchain` (for document loading and splitting), `transformers` (for the FLAN-T5 model), `google-colab` (for Colab utilities), and additional dependencies (`langchain-community`, `pypdf`, `hf_xet`) for PDF processing and Hugging Face integration.

In [2]:
# Install core libraries for RAG pipeline (pylate for ColBERT, langchain for document processing, transformers for generation)
!pip install pylate langchain transformers google-colab
# Install additional dependencies for PDF loading and Hugging Face integration
!pip install -U langchain-community pypdf hf_xet
!pip install nltk


In [None]:
!pip install pylate langchain transformers ipywidgets numpy huggingface_hub

In [None]:
!jupyter nbextension enable --py widgetsnbextension

## Cell 2: Import Libraries

This cell imports the required Python libraries for the pipeline and suppresses warnings to ensure cleaner output in Colab. Key libraries include:
- `pylate` for ColBERT-based embedding and retrieval (`models`, `indexes`, `retrieve`, `rank`).
- `langchain` for PDF loading (`PyPDFLoader`), text splitting (`RecursiveCharacterTextSplitter`), and prompt creation (`PromptTemplate`).
- `transformers` for the FLAN-T5 model (`pipeline`).
- `google.colab.files` for handling file uploads in Colab.
- `os`, `re` for file path handling and regex post-processing.
- Warnings from `pypdf` are suppressed to avoid cluttering the output.

## Cell 3: Define RAG Pipeline

This cell defines the core functions of the RAG pipeline:
- **`run_rag_pipeline`**: Processes PDFs by loading, chunking, embedding, and indexing them, then initializes the retriever and generator. It also tracks the source PDF for each chunk to enable company-specific filtering, and it creates a prompt template with fictional few-shot examples to guide answer extraction.
- **`query_rag`**: Handles user queries by encoding them, retrieving and reranking relevant chunks, augmenting the query with context, generating an answer, and applying post-processing. Fixes include:
  - **Company Filtering**: Filters chunks by company (QuantumCore or NeoCompute) based on query keywords.
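  - **Role Extraction**: Full names for roles (e.g., CEO) are now requested through the prompt rules; the earlier regex has been removed.
  - **Fallback Answers**: Regex post-processing has been removed. The stripped generated answer is returned directly, with a "could not be found" message if the model returns an empty answer or `none`.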

The pipeline is designed to handle errors gracefully and print clear status messages if processing fails. All strings, such as the prompt template, are properly escaped so the notebook file remains valid JSON (the earlier regex patterns have been removed).
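Reranking with a ColBERT model relies on late-interaction (MaxSim) scoring: each query token is matched to its most similar document token, and those maxima are summed. The NumPy sketch below illustrates that scoring on toy token embeddings. It is not PyLate's implementation of `rank.rerank`, and `maxsim_score` and `rerank_maxsim` are hypothetical names.

```python
import numpy as np


def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """Late-interaction score: for each query token take the best-matching
    document token (max dot product), then sum over query tokens."""
    sims = query_emb @ doc_emb.T  # shape: (num_query_tokens, num_doc_tokens)
    return float(sims.max(axis=1).sum())


def rerank_maxsim(query_emb, candidates, k=3):
    """Score (doc_id, doc_emb) candidates with MaxSim and return the top k pairs."""
    scored = [(doc_id, maxsim_score(query_emb, emb)) for doc_id, emb in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional token embeddings (not real model outputs)
query = np.array([[1.0, 0.0], [0.0, 1.0]])
candidates = [
    ("0", np.array([[1.0, 0.0], [0.0, 0.0]])),  # second row acts like zero padding
    ("1", np.eye(2)),
    ("2", np.array([[0.2, 0.2]])),
]
print(rerank_maxsim(query, candidates, k=2))  # [('1', 2.0), ('0', 1.0)]
```

The all-zero row in candidate `"0"` has a dot product of 0 with every query token, so it can only win a maximum when every real document token has negative similarity. This is one reason zero-padding the document embeddings, as `run_rag_pipeline` does, would usually leave MaxSim rankings unchanged.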

## Cell 4: Process PDFs

This cell specifies the paths to the PDF datasets (`QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf` and `NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf`) and initializes the RAG pipeline by calling `run_rag_pipeline`. If running locally, ensure the PDFs are in the `/data` directory. In Google Colab, the cell checks for missing files and prompts the user to upload them. The pipeline is initialized with the ColBERT model, Voyager index, retriever, FLAN-T5 generator, prompt template, document map, and source tracking for company-specific filtering.
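The bookkeeping behind source tracking can be sketched without any PDF or model dependencies. Assuming chunk lists have already been produced for each file, the hypothetical `build_maps` below assigns contiguous string IDs across all files. It records each chunk's source as the filename prefix before the first underscore, which is how `run_rag_pipeline` derives `QuantumCore` and `NeoCompute`.

```python
import os


def build_maps(chunks_by_file: dict[str, list[str]]):
    """Assign contiguous string IDs across files; map ID -> text and ID -> source."""
    document_map, document_sources = {}, {}
    next_id = 0
    for pdf_name, chunks in chunks_by_file.items():
        # Source name is the filename prefix before the first underscore
        source = os.path.basename(pdf_name).split("_")[0]
        for text in chunks:
            doc_id = str(next_id)
            document_map[doc_id] = text
            document_sources[doc_id] = source
            next_id += 1
    return document_map, document_sources


doc_map, doc_sources = build_maps({
    "QuantumCore_Solutions_RAG_Demo_Dataset_v1.pdf": ["chunk a", "chunk b"],
    "NeoCompute_Technologies_RAG_Demo_Dataset_v2.pdf": ["chunk c"],
})
print(doc_sources)  # {'0': 'QuantumCore', '1': 'QuantumCore', '2': 'NeoCompute'}
```

Because the IDs are contiguous integers stored as strings, `int(doc_id)` can also serve as a position in the list of document embeddings, which is what the reranking step in `query_rag` relies on.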

## Cell 5: Interactive Querying

This cell provides an interactive interface to query the RAG system. Users can enter custom queries or use provided examples. The system retrieves relevant chunks, generates an answer, and displays both the context and response. Example queries test various aspects of the pipeline:
- Role queries (e.g., CEO name) return full names, as instructed by the prompt.
- Product queries list product names, with fixes to exclude non-products.
- Specification queries (e.g., qubit count) extract specific details.
- Compliance queries return lists of standards, deduplicated and filtered.

**Example Queries**:
- Who is the CEO of QuantumCore Solutions?
- What are the products offered by NeoCompute Technologies?
- What is the qubit count of QubitCore?
- What compliance standards does NeoCompute follow?

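## Cell 6: Streamlined Query Interface

This cell combines the imports, the pipeline functions (`run_rag_pipeline` and `query_rag`), and an `ipywidgets` interface. After submitting a Hugging Face token, upload the PDFs, adjust the chunking, index, batch size, model, and context-length settings, enter a query, and click **Run Pipeline**. The pipeline is initialized on the first run and reused for later queries until **Reset** is clicked. Expected outputs for the example queries are listed below.
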
**Example Queries**:
- Who is the CEO of QuantumCore Solutions? → 'Dr. Elena Ruiz'
- What are the products offered by NeoCompute Technologies? → 'NeoCloud, NeoSecure'
- Who is the CIO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- Who is the CEO of NeoCompute Technologies? → 'The answer could not be found in the text.'
- What is the qubit count of QubitCore? → '50-qubit superconducting architecture'
- What compliance standards does NeoCompute follow? → 'ISO/IEC 27001, SOC 2 Type II'
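To check these expectations repeatedly, a small harness can compare answers against the list above after normalizing case and whitespace. This is a hypothetical sketch: `evaluate` and `normalize` are not part of the notebook, and the stub `answer_fn` here simply echoes the expected answers so the harness can run without loading any models.

```python
from typing import Callable, Dict

# Expected answers, copied from the example queries above
EXPECTED: Dict[str, str] = {
    "Who is the CEO of QuantumCore Solutions?": "Dr. Elena Ruiz",
    "What are the products offered by NeoCompute Technologies?": "NeoCloud, NeoSecure",
    "Who is the CIO of NeoCompute Technologies?": "The answer could not be found in the text.",
    "Who is the CEO of NeoCompute Technologies?": "The answer could not be found in the text.",
    "What is the qubit count of QubitCore?": "50-qubit superconducting architecture",
    "What compliance standards does NeoCompute follow?": "ISO/IEC 27001, SOC 2 Type II",
}


def normalize(text: str) -> str:
    """Lower-case and collapse whitespace so trivial formatting differences are ignored."""
    return " ".join(text.lower().split())


def evaluate(answer_fn: Callable[[str], str], expected: Dict[str, str]) -> Dict[str, bool]:
    """Run every query through answer_fn and report whether it matches the expected answer."""
    return {q: normalize(answer_fn(q)) == normalize(a) for q, a in expected.items()}


# Stub answer function so the harness runs without models
results = evaluate(lambda q: EXPECTED[q], EXPECTED)
print(f"{sum(results.values())}/{len(results)} queries matched")  # 6/6 queries matched
```

With the pipeline initialized, `answer_fn` could instead wrap `query_rag` and return the second element of its `(context, answer)` tuple.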

In [43]:
import warnings
import os
from pylate import models, indexes, retrieve, rank
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from transformers import pipeline
import contextlib
import io
import tempfile
from ipywidgets import widgets, Layout, HBox, VBox
from IPython.display import display, clear_output
import numpy as np
import torch
from huggingface_hub import login

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Prompt for Hugging Face token
hf_token_widget = widgets.Password(value='', placeholder='Enter your Hugging Face token', description='HF Token:')
def on_hf_token_submit(b):
    login(token=hf_token_widget.value)
    print('Hugging Face token set successfully.')
    hf_token_widget.layout.display = 'none'
hf_token_button = widgets.Button(description='Submit Token')
hf_token_button.on_click(on_hf_token_submit)
display(hf_token_widget, hf_token_button)

def run_rag_pipeline(pdf_files, chunk_size, chunk_overlap, index_type, batch_size, llm_model, status_output):
    """Initialize the RAG pipeline with user-specified parameters."""
    def print_status(msg):
        with status_output:
            print(msg)

    try:
        # Initialize data structures
        all_document_texts = []
        all_document_ids = []
        document_map = {}
        current_doc_id = 0
        document_sources = {}

        # Use character-based splitting
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print_status("Using character-based splitting.")

        # Process uploaded PDFs
        temp_dir = tempfile.gettempdir()
        for pdf_name, pdf_content in pdf_files.items():
            if not pdf_name.endswith('.pdf'):
                print_status(f'Skipping non-PDF file: {pdf_name}')
                continue
            if len(pdf_content['content']) > 10 * 1024 * 1024:  # Skip >10MB files
                print_status(f'Skipping large file: {pdf_name}')
                continue
            pdf_path = os.path.join(temp_dir, pdf_name)
            print_status(f'Processing PDF: {pdf_name}')
            with open(pdf_path, 'wb') as f:
                f.write(pdf_content['content'])
            # Add error handling for PDF loading
            try:
                loader = PyPDFLoader(pdf_path)
                documents = loader.load()
            except Exception as e:
                print_status(f'Error loading {pdf_name}: {e}')
                continue
            if not documents:
                print_status(f'No content loaded from {pdf_name}')
                continue
            chunks = text_splitter.split_documents(documents)
            document_texts = [chunk.page_content for chunk in chunks]
            avg_chunk_len = sum(len(t) for t in document_texts) / len(document_texts) if document_texts else 0
            max_chunk_len = max(len(t) for t in document_texts) if document_texts else 0
            print_status(f'Created {len(document_texts)} chunks from {pdf_name}, avg length: {avg_chunk_len:.1f} chars, max length: {max_chunk_len} chars')
            document_ids = [str(i + current_doc_id) for i in range(len(document_texts))]
            document_map.update(dict(zip(document_ids, document_texts)))
            source_name = os.path.basename(pdf_name).split('_')[0]
            document_sources.update({doc_id: source_name for doc_id in document_ids})
            all_document_texts.extend(document_texts)
            all_document_ids.extend(document_ids)
            current_doc_id += len(document_texts)

        if not all_document_texts:
            print_status('No documents processed.')
            return None

        print_status(f'Total chunks created: {len(all_document_texts)}')

        # Set up embedding model
        model_name = 'lightonai/GTE-ModernColBERT-v1'
        try:
            # Use recommended document_length (8192) and query_length (32) for the model
            model = models.ColBERT(model_name_or_path=model_name, document_length=8192, query_length=32)
            print_status(f"Initialized ColBERT model with document_length=8192, query_length=32.")
        except TypeError as e:
            print_status(f'Warning: Failed to set query_length explicitly ({e}). Using document_length=8192 and model default for query_length.')
            model = models.ColBERT(model_name_or_path=model_name, document_length=8192)

        # Generate and store embeddings
        print_status('Generating document embeddings...')
        with contextlib.redirect_stderr(io.StringIO()):
            documents_embeddings_list_of_tensors = model.encode(
                all_document_texts,
                batch_size=batch_size,
                is_query=False,
                show_progress_bar=False
            )

        if documents_embeddings_list_of_tensors:
            # Slice instead of indexing [1] so a single-chunk corpus does not raise IndexError
            first_shapes = [tuple(t.shape) for t in documents_embeddings_list_of_tensors[:2]]
            print_status(f'Pre-padding document embeddings: {len(documents_embeddings_list_of_tensors)} tensors, first shapes: {first_shapes}')

        # Pad embeddings to fixed length (299 tokens as per model's internal max_seq_length)
        max_token_len_for_padding = 299
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        padded_documents_embeddings_tensors = []
        for emb_tensor in documents_embeddings_list_of_tensors:
            # Ensure emb_tensor is a PyTorch tensor
            if not isinstance(emb_tensor, torch.Tensor):
                emb_tensor = torch.tensor(emb_tensor, dtype=torch.float32, device=device)
            seq_len = emb_tensor.shape[0]
            if seq_len < max_token_len_for_padding:
                dtype = emb_tensor.dtype if isinstance(emb_tensor.dtype, torch.dtype) else torch.float32
                padding = torch.zeros((max_token_len_for_padding - seq_len, emb_tensor.shape[1]),
                                      dtype=dtype, device=emb_tensor.device)
                emb_tensor = torch.cat([emb_tensor, padding], dim=0)
            elif seq_len > max_token_len_for_padding:
                emb_tensor = emb_tensor[:max_token_len_for_padding, :]
            padded_documents_embeddings_tensors.append(emb_tensor.to(device))

        # Show the shape after padding
        batched_embeddings = torch.stack(padded_documents_embeddings_tensors)
        print_status(f'Padded document embeddings shape: {batched_embeddings.shape}')

        # Create the selected index type (PyLate's class is PLAID; "Plade" does not exist)
        if index_type == 'Voyager':
            index = indexes.Voyager(index_folder='pylate-index', index_name='pdf_index', override=True)
        else:
            index = indexes.PLAID(index_folder='pylate-index', index_name='pdf_index', override=True)
        try:
            index.add_documents(all_document_ids, documents_embeddings=padded_documents_embeddings_tensors)
            print_status("Index created successfully.")
        except Exception as e:
            print_status(f'ERROR creating index: {str(e)}')
            try:
                index.add_documents(all_document_ids, documents_embeddings=batched_embeddings)
                print_status("Index created successfully with batched embeddings.")
            except Exception as e2:
                print_status(f'ERROR creating index with batched embeddings: {str(e2)}')
                raise

        # Initialize retriever
        retriever = retrieve.ColBERT(index=index)

        # Initialize generator with user-selected LLM
        print_status(f'Loading LLM model: {llm_model}')
        try:
            generator = pipeline(
                'text2text-generation',
                model=llm_model,
                max_length=150,
                device=0 if torch.cuda.is_available() else -1,
                model_kwargs={
                    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                    "low_cpu_mem_usage": True
                }
            )
        except Exception as e:
            print_status(f'Error loading optimized LLM model, trying basic configuration: {e}')
            generator = pipeline('text2text-generation', model=llm_model, max_length=150)

        prompt_template = r"""
You are an expert assistant answering questions based solely on the provided text. Prioritize information about roles (e.g., CEO, CTO) and return precise answers. Follow these rules:
1. For roles, return the full name (e.g., "Jane Smith" for CEO).
2. For lists (e.g., products), return a comma-separated list, sorted alphabetically.
3. For details (e.g., specifications), return the exact detail.
4. For comparisons, clearly compare attributes across entities.
5. If no answer is found, return: "The answer could not be found in the text."

**Examples**:
- Text: "Jane Smith, CEO of CloudPeak." Question: Who is the CEO of CloudPeak? Answer: Jane Smith
- Text: "CloudPeak, SecureVault." Question: What products? Answer: CloudPeak, SecureVault
- Text: "AlphaCore: 100 qubits." Question: Qubit count? Answer: 100 qubits
- Text: "QuantumCore: 50 qubits. NeoCompute: 75 qubits." Question: Compare qubit counts? Answer: QuantumCore has 50 qubits, NeoCompute has 75 qubits

**Text**: {context}

**Question**: {question}

**Answer**:
"""
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # Return the padded per-chunk PyTorch tensors for reranking later
        return model, index, retriever, generator, PROMPT, document_map, document_sources, padded_documents_embeddings_tensors

    except Exception as e:
        print_status(f'Error processing PDFs: {e}')
        # Return None (instead of re-raising) so the UI callback can report the failure
        return None

def query_rag(model, index, retriever, generator, PROMPT, document_map, document_sources, query, max_context_length, batch_size, documents_embeddings_tensors, status_output):
    """Process a user query with retrieval and generation."""
    def print_status(msg):
        with status_output:
            print(msg)

    try:
        queries = [query]
        print_status(f'Querying: {query}')
        print_status('Encoding query...')
        with contextlib.redirect_stderr(io.StringIO()):
            query_embedding_list_of_tensors = model.encode(
                queries,
                batch_size=batch_size,
                is_query=True,
                show_progress_bar=False
            )

        # Extract the actual query tensor
        if isinstance(query_embedding_list_of_tensors, list) and len(query_embedding_list_of_tensors) > 0:
            actual_query_tensor = query_embedding_list_of_tensors[0]
            # Convert to PyTorch tensor if it's a NumPy array
            if not isinstance(actual_query_tensor, torch.Tensor):
                print_status(f'Converting query embedding to PyTorch tensor')
                actual_query_tensor = torch.tensor(actual_query_tensor, dtype=torch.float32, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            if actual_query_tensor.ndim == 3 and actual_query_tensor.shape[0] == 1:
                actual_query_tensor = actual_query_tensor.squeeze(0)

            print_status(f'Query embedding shape (after potential list extraction/squeeze): {actual_query_tensor.shape}')
        else:
            print_status('Error: query_embedding is not a list or is empty after encoding.')
            return None, 'Query processing failed: Could not encode query for reranking.'

        print_status('Retrieving documents...')
        with contextlib.redirect_stderr(io.StringIO()):
            top_k_initial = 15
            initial_results = retriever.retrieve(queries_embeddings=query_embedding_list_of_tensors, k=top_k_initial)

        if not initial_results or not initial_results[0]:
            print_status('Error: No documents retrieved.')
            return None, 'Error: No relevant documents found.'

        retrieved_doc_ids = [result['id'] for result in initial_results[0] if 'id' in result]
        if not retrieved_doc_ids:
            print_status('No document IDs after retrieval.')
            return None, 'Error: No relevant documents found.'

        # Filter by company
        from difflib import SequenceMatcher
        def match_company(query, company_name, threshold=0.8):
            query_lower = query.lower()
            company_lower = company_name.lower()
            return (company_name.lower() in query_lower or
                    SequenceMatcher(None, query_lower, company_lower).ratio() > threshold)

        company = None
        filtered_doc_ids = retrieved_doc_ids
        if match_company(query, 'QuantumCore'):
            company = 'QuantumCore'
        elif match_company(query, 'NeoCompute'):
            company = 'NeoCompute'
        if company:
            filtered_doc_ids = [doc_id for doc_id in retrieved_doc_ids if document_sources.get(doc_id) == company]
            if not filtered_doc_ids:
                print_status(f'No documents for {company}. Falling back to all documents.')
                filtered_doc_ids = retrieved_doc_ids

        # Reranking with padded embeddings - FIXED IMPLEMENTATION
        print_status('Reranking documents...')
        try:
            with contextlib.redirect_stderr(io.StringIO()):
                valid_doc_ids_int = [int(doc_id) for doc_id in filtered_doc_ids if int(doc_id) < len(documents_embeddings_tensors)]
                if not valid_doc_ids_int:
                    print_status('No valid document IDs for reranking. Using initial results.')
                    reranked_doc_ids = filtered_doc_ids[:3]  # Changed from 2 to 3
                else:
                    selected_document_tensors = [documents_embeddings_tensors[idx] for idx in valid_doc_ids_int]

                    # Show selected embeddings shape
                    if selected_document_tensors:
                        print_status(f'Selected embeddings shape for reranking: {selected_document_tensors[0].shape}')

                    # FIXED CODE: Structure data correctly for PyLate's rerank function
                    print_status('Preparing data for reranking with correct structure...')

                    # 1. Make sure query_embeddings is a 2D tensor [seq_len, hidden_dim]
                    if actual_query_tensor.dim() == 3:
                        query_embeddings = actual_query_tensor.squeeze(0) # Remove batch dimension if present
                    else:
                        query_embeddings = actual_query_tensor # Already 2D

                    # 2. Structure the inputs as expected by rerank:
                    #    - queries_embeddings should be a list with one tensor
                    #    - documents_ids should be a list with one list of document IDs
                    #    - documents_embeddings should be a list with one list of document tensors
                    queries_embeddings_for_rerank = [query_embeddings]  # One query
                    documents_ids_for_rerank = [valid_doc_ids_int]      # One list of doc IDs
                    documents_embeddings_for_rerank = [selected_document_tensors]  # One list of doc embeddings

                    print_status(f'Ready for reranking with {len(documents_embeddings_for_rerank[0])} documents')

                    # 3. Execute reranking with properly structured inputs
                    reranked_results = rank.rerank(
                        documents_ids=documents_ids_for_rerank,
                        queries_embeddings=queries_embeddings_for_rerank,
                        documents_embeddings=documents_embeddings_for_rerank
                    )

                    # Process results: keep the top 3 reranked IDs that exist in document_map
                    reranked_doc_ids = []
                    if reranked_results and len(reranked_results) > 0:
                        for result in reranked_results[0]:  # First query results
                            original_doc_id_str = str(result['id']) if isinstance(result, dict) and 'id' in result else str(result)
                            if original_doc_id_str in document_map:
                                reranked_doc_ids.append(original_doc_id_str)
                        reranked_doc_ids = reranked_doc_ids[:3]
                    # Fall back to initial results if reranking returned nothing usable
                    if not reranked_doc_ids:
                        print_status('No valid document IDs after reranking. Using initial results.')
                        reranked_doc_ids = [str(doc_id) for doc_id in valid_doc_ids_int[:3]]
        except Exception as e:
            print_status(f'Reranking failed: {e}. Falling back to initial retrieval results.')
            reranked_doc_ids = [str(doc_id) for doc_id in filtered_doc_ids[:3]]  # Changed from 2 to 3

        print_status(f'Selected doc IDs for context: {reranked_doc_ids}')

        user_max_context_length_chars = int(max_context_length)

        full_reranked_documents = [document_map.get(doc_id, '') for doc_id in reranked_doc_ids]
        full_reranked_documents = [doc for doc in full_reranked_documents if doc]

        context_parts = []
        current_context_char_length = 0
        for doc in full_reranked_documents:
            if current_context_char_length + len(doc) <= user_max_context_length_chars:
                context_parts.append(doc)
                current_context_char_length += len(doc)
            else:
                remaining_chars = user_max_context_length_chars - current_context_char_length
                if remaining_chars > 0:
                    context_parts.append(doc[:remaining_chars])
                break

        context = '\n'.join(context_parts)
        if not context:
            print_status('No context generated.')
            return None, 'No relevant context found.'

        prompt_text = PROMPT.format(context=context, question=query)
        print_status(f'Context length used (chars): {len(context)}')
        print_status('Generating answer...')

        try:
            generation_kwargs = {
                'max_length': 150,
                'num_beams': 2,
                'early_stopping': True,
                'do_sample': False,
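                # Use the model's own pad token (T5 pads with 0, not EOS); fall back to EOS if undefined
                'pad_token_id': generator.tokenizer.pad_token_id if generator.tokenizer.pad_token_id is not None else generator.tokenizer.eos_token_id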
            }
            response = generator(prompt_text, **generation_kwargs)[0]['generated_text']
        except Exception as gen_error:
            print_status(f'Generation with optimized parameters failed: {gen_error}. Using default parameters.')
            response = generator(prompt_text, max_length=150)[0]['generated_text']

        answer = response.strip()

        if not answer or answer.lower() == 'none':
            if not context:
                answer = 'The answer could not be found due to no relevant context in the text.'
            elif company and not filtered_doc_ids:
                answer = f'The answer could not be found for {company} in the text.'
            else:
                answer = 'The answer could not be found in the retrieved text.'

        return context, answer

    except Exception as e:
        print_status(f'Error processing query: {e}')
        return None, f'Error: {e}'

# Streamlined UI with widgets
def create_rag_interface():
    # Define widgets
    upload_widget = widgets.FileUpload(accept='.pdf', multiple=True, description='Upload PDFs')
    chunk_size_widget = widgets.IntSlider(value=350, min=200, max=500, step=50, description='Chunk Size (chars):')
    chunk_overlap_widget = widgets.IntSlider(value=50, min=10, max=100, step=10, description='Chunk Overlap (chars):')
    index_type_widget = widgets.Dropdown(options=['Voyager', 'PLAID'], value='Voyager', description='Index Type:')
    batch_size_widget = widgets.IntSlider(value=32, min=8, max=64, step=8, description='Batch Size:')
    llm_model_widget = widgets.Dropdown(
        options=[
            'google/flan-t5-base',
            'google/flan-t5-large',
            'google/flan-t5-xl',
            'facebook/bart-large',
            'facebook/bart-large-cnn'
        ],
        value='google/flan-t5-large',
        description='LLM Model:'
    )
    query_widget = widgets.Text(value='', placeholder='Type your query here', description='Query:', layout={'width': '500px'})
    context_length_widget = widgets.IntSlider(value=1500, min=500, max=2500, step=100, description='Max Context Length (chars):')
    run_button = widgets.Button(description='Run Pipeline', button_style='success')
    reset_button = widgets.Button(description='Reset', button_style='warning')
    context_output = widgets.Output(layout={'border': '1px solid black', 'padding': '10px', 'width': '50%', 'height': '300px', 'overflow': 'auto'})
    answer_output = widgets.Output(layout={'border': '1px solid black', 'padding': '10px', 'width': '50%', 'height': '300px', 'overflow': 'auto'})
    status_output = widgets.Output(layout={'border': '1px solid black', 'padding': '10px', 'height': '100px', 'overflow': 'auto'})
    header_widget = widgets.HTML(value='<h2>Streamlined RAG Query Interface</h2>')
    loading_widget = widgets.HTML(value='')

    # Pipeline state
    pipeline_state = {'components': None}

    # Button callbacks
    def on_run_button_clicked(b):
        loading_widget.value = '<i>Processing...</i>'
        with status_output:
            clear_output()
            print('Starting operation...')

        if pipeline_state['components'] is None:
            with status_output:
                print('Pipeline not initialized. Initializing now (this may take a minute)...')

            if not upload_widget.value:
                with status_output:
                    print('Error: Please upload at least one PDF to initialize the pipeline.')
                loading_widget.value = ''
                return

            pdf_files = upload_widget.value
            result = run_rag_pipeline(
                pdf_files,
                chunk_size=chunk_size_widget.value,
                chunk_overlap=chunk_overlap_widget.value,
                index_type=index_type_widget.value,
                batch_size=batch_size_widget.value,
                llm_model=llm_model_widget.value,
                status_output=status_output
            )
            if result:
                pipeline_state['components'] = result
                with status_output:
                    print('RAG pipeline initialized successfully.')
            else:
                with status_output:
                    print('Failed to initialize RAG pipeline. Please check inputs and logs.')
                loading_widget.value = ''
                return
        else:
            with status_output:
                print('Pipeline already initialized. Proceeding with query.')

        if not query_widget.value.strip():
            with status_output:
                print('Error: Please enter a query.')
            loading_widget.value = ''
            return

        model, index, retriever, generator, PROMPT, document_map, document_sources, doc_embeddings_tensors = pipeline_state['components']

        with context_output:
            clear_output()
        with answer_output:
            clear_output()

        context, answer = query_rag(
            model, index, retriever, generator, PROMPT, document_map, document_sources,
            query_widget.value, context_length_widget.value, batch_size_widget.value, doc_embeddings_tensors, status_output
        )

        with context_output:
            if context is None:
                print('Error: No context retrieved.')
            else:
                print(context)
        with answer_output:
            print(answer)
        with status_output:
            if context is None and answer.startswith('Error'):
                print('Query processing failed.')
            else:
                print('Query processed successfully.')
        loading_widget.value = ''

    def on_reset_button_clicked(b):
        pipeline_state['components'] = None
        query_widget.value = ''
        with context_output:
            clear_output()
        with answer_output:
            clear_output()
        with status_output:
            clear_output()
            print('Interface reset. Uploaded files remain in the upload widget; upload new files to replace them.')
            print('Pipeline components will be re-initialized from the currently uploaded PDFs on the next "Run Pipeline" click.')

    run_button.on_click(on_run_button_clicked)
    reset_button.on_click(on_reset_button_clicked)

    config_box = VBox([
        upload_widget,
        chunk_size_widget,
        chunk_overlap_widget,
        index_type_widget,
        batch_size_widget,
        llm_model_widget,
        context_length_widget,
        query_widget,
        HBox([run_button, reset_button]),
        loading_widget
    ], layout=Layout(padding='10px'))
    output_box = HBox([context_output, answer_output], layout=Layout(padding='10px'))
    display(VBox([header_widget, config_box, output_box, status_output]))

create_rag_interface()


Adding documents to the index (bs=2000): 100%|██████████| 1/1 [01:10<00:00, 70.26s/it]
Device set to use cpu
Adding documents to the index (bs=2000): 100%|██████████| 1/1 [01:13<00:00, 73.16s/it]

