In [None]:
# Task 6: Interactive Notebook Demo, Jupyter notebook for querying the RAG system

# Cell 1: Imports and Setup
import json
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import pickle
from IPython.display import display, HTML, Markdown

# Cell 2: Load Model and Index
print("Loading model and index...")
# Sentence-embedding model; search_papers uses it to encode each query.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# FAISS index read from the current working directory.
index = faiss.read_index("faiss_index.bin")

# NOTE(review): pickle.load can run arbitrary code embedded in the file, so
# only open a chunks.pkl that was produced by a trusted pipeline.
# search_papers looks chunks up by the positions returned from index.search,
# so this list is presumably stored in the same order as the index vectors
# -- TODO confirm against the script that built both files.
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)

# Summary of what was loaded; comparing the vector count with the chunk
# count is a quick way to spot an index/chunks mismatch.
print(f"✓ Model loaded: {model.get_sentence_embedding_dimension()}D embeddings")
print(f"✓ Index loaded: {index.ntotal} vectors")
print(f"✓ Chunks loaded: {len(chunks)} chunks")

# Cell 3: Define Search Function
def search_papers(query, k=3):
    """
    Search for relevant paper chunks

    Args:
        query: Search query string
        k: Number of results to return

    Returns:
        List of result dicts in the order returned by the index, each with
        the keys 'chunk_id', 'source', 'text', 'distance' and 'similarity'
        (the last two as plain Python floats). The list may hold fewer than
        k entries when the index cannot supply k valid neighbours.
    """
    # Embed the query (a one-element batch, so row 0 below is our query)
    query_embedding = model.encode([query], convert_to_numpy=True)

    # Search the index
    distances, indices = index.search(query_embedding, k)

    # Collect results
    results = []
    for idx, distance in zip(indices[0], distances[0]):
        # FAISS fills missing neighbours with the label -1 (e.g. when the
        # index holds fewer than k vectors). A plain `idx < len(chunks)`
        # check lets -1 through and chunks[-1] would silently return the
        # last chunk as a bogus hit, so the lower bound is checked too.
        if not 0 <= idx < len(chunks):
            continue

        chunk = chunks[idx]
        # Convert once so both derived fields are plain floats rather than
        # NumPy scalars.
        distance = float(distance)
        results.append({
            'chunk_id': chunk['chunk_id'],
            'source': chunk['source_id'],
            'text': chunk['text'],
            'distance': distance,
            'similarity': 1.0 / (1.0 + distance)  # Convert distance to similarity
        })

    return results

# Cell 4: Helper Function for Display
def display_results(query, results):
    """
    Print the query banner followed by one formatted section per result.

    Args:
        query: The query string shown in the banner.
        results: Iterable of result dicts providing the keys 'source',
            'chunk_id', 'similarity' and 'text'.
    """
    banner_rule = "=" * 80
    section_rule = "─" * 80

    # Banner: rule, query, rule, then one blank line.
    print(banner_rule, f"QUERY: {query}", banner_rule, "", sep="\n")

    for rank, hit in enumerate(results, start=1):
        # One print per result; the trailing "" produces the blank line
        # that separates consecutive results.
        print(
            section_rule,
            f"RESULT #{rank}",
            section_rule,
            f"Source Paper: {hit['source']}",
            f"Chunk ID: {hit['chunk_id']}",
            f"Similarity Score: {hit['similarity']:.4f}",
            "\nContent:",
            hit['text'],
            "",
            sep="\n",
        )

# Cell 5: Example Queries
# Try different queries to test the system

# Query 1: Transformers
# Each example retrieves the top 3 chunks for a fixed question and prints them.
# results1 stays at module level because the "extended context" example in
# Cell 10 reads results1[0].
query1 = "What are transformer models and how do they work?"
results1 = search_papers(query1, k=3)
display_results(query1, results1)

# Cell 6: Query 2 - Attention Mechanism
query2 = "Explain the attention mechanism in neural networks"
results2 = search_papers(query2, k=3)
display_results(query2, results2)

# Cell 7: Query 3 - Training Methods
query3 = "How to train large language models efficiently?"
results3 = search_papers(query3, k=3)
display_results(query3, results3)

# Cell 8: Interactive Search Widget
from ipywidgets import interact, Text, IntSlider

@interact(
    query=Text(value="", description="Query:", placeholder="Enter your question..."),
    k=IntSlider(value=3, min=1, max=10, description="Results:")
)
def interactive_search(query, k):
    """
    Callback bound to the query text box and the result-count slider.

    Prints a prompt while the text box is empty; otherwise runs the search
    and prints the formatted results.
    """
    # Guard clause: nothing to search for yet.
    if not query:
        print("Enter a query to search the arXiv papers!")
        return

    hits = search_papers(query, k)
    display_results(query, hits)

# Cell 9: Statistics and Analysis
def analyze_index():
    """
    Show statistics about the indexed papers

    Reads the module-level `chunks` list and prints the paper and chunk
    counts, token-count statistics (mean, median, min, max) and up to five
    sample source ids. When `chunks` is empty only the zero counts are
    printed, followed by a short notice.
    """
    # Count papers. dict.fromkeys de-duplicates while keeping first-seen
    # order, so the sample list below is identical on every run (iterating
    # a set gives an arbitrary order).
    unique_sources = list(dict.fromkeys(chunk['source_id'] for chunk in chunks))

    # Token statistics
    token_counts = [chunk['token_count'] for chunk in chunks]

    print("Index Statistics:")
    print(f"  Total papers: {len(unique_sources)}")
    print(f"  Total chunks: {len(chunks)}")

    if not chunks:
        # With no chunks the average would divide by zero and np.min/np.max
        # raise ValueError on an empty sequence, so stop after the counts.
        print("  (index is empty - no further statistics)")
        return

    print(f"  Avg chunks per paper: {len(chunks) / len(unique_sources):.1f}")
    print("\nChunk Size Statistics:")
    print(f"  Mean tokens: {np.mean(token_counts):.1f}")
    print(f"  Median tokens: {np.median(token_counts):.1f}")
    print(f"  Min tokens: {np.min(token_counts)}")
    print(f"  Max tokens: {np.max(token_counts)}")

    # Show sample papers (the first five distinct sources, in index order)
    print("\nSample Papers:")
    for i, source in enumerate(unique_sources[:5], 1):
        print(f"  {i}. {source}")

analyze_index()

# Cell 10: Advanced: Retrieve full paper context
def get_paper_context(chunk_id, context_chunks=2):
    """
    Return the chunk with the given id together with its neighbours from
    the same paper.

    Args:
        chunk_id: Id of the chunk to centre the window on.
        context_chunks: How many chunks to include on each side.

    Returns:
        The window of chunks from the same source, ordered by
        'chunk_index', or None when no chunk has the given id.
    """
    # Locate the first chunk carrying the requested id.
    anchor = None
    for candidate in chunks:
        if candidate['chunk_id'] == chunk_id:
            anchor = candidate
            break

    if not anchor:
        return None

    # All chunks of the anchor's paper, in document order.
    same_paper = sorted(
        (c for c in chunks if c['source_id'] == anchor['source_id']),
        key=lambda c: c['chunk_index'],
    )

    # Position of the requested chunk within its paper.
    position = next(
        pos for pos, c in enumerate(same_paper) if c['chunk_id'] == chunk_id
    )

    # Clamp the window to the paper's boundaries.
    window_start = max(0, position - context_chunks)
    window_stop = min(len(same_paper), position + context_chunks + 1)
    return same_paper[window_start:window_stop]

# Example: Get context for first result
if results1:
    print("Getting extended context for first result...")
    context = get_paper_context(results1[0]['chunk_id'], context_chunks=2)

    if context:
        print(f"\nExtended context ({len(context)} chunks):")
        for chunk in context:
            # Flag the chunk that was the actual search hit; pad the rest
            # so the columns line up.
            if chunk['chunk_id'] == results1[0]['chunk_id']:
                marker = ">>> "
            else:
                marker = "    "
            # Only the first 100 characters of each chunk are shown.
            preview = chunk['text'][:100]
            print(marker + f"Chunk {chunk['chunk_index']}: {preview}...")
