# 🔍 RAG Demo Notebook: Retrieval-Augmented Generation
This notebook walks through:
- Embedding a document corpus
- Retrieving top matches per query
- Assembling prompts from the retrieved chunks
- Generating grounded answers
- Tracking citation sources and chat history

In [None]:
from pathlib import Path
import json, faiss
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from IPython.display import Markdown, display

## 📁 Step 1: Load and Chunk Documents

In [None]:
# Directory holding the Markdown corpus, relative to the working directory.
corpus_path = Path("data/corpus")
# Sort so that chunk order (and therefore positions in any index built from
# it) is reproducible across runs; raw glob order depends on the filesystem.
files = sorted(corpus_path.glob("*.md"))
# `chunks[i]` is a paragraph of text and `sources[i]` is the name of the file
# it came from; the two lists are kept index-aligned.
chunks, sources = [], []

for file in files:
    # Decode explicitly: without `encoding`, the platform default is used and
    # non-ASCII text can be mangled or fail to decode on some systems.
    # NOTE(review): assumes the corpus is stored as UTF-8 — confirm.
    text = file.read_text(encoding="utf-8")
    # Paragraphs are separated by blank lines.
    for para in text.split("\n\n"):
        # Keep only paragraphs with more than 20 whitespace-separated words.
        if len(para.split()) > 20:
            chunks.append(para.strip())
            sources.append(file.name)

print(f"Loaded {len(chunks)} content chunks from {len(files)} files.")

## 🔢 Step 2: Embed and Index Chunks

In [None]:
# Fail early with a clear message: with no chunks there is nothing to embed,
# and the embedding-dimension lookup below would fail on an empty result.
if not chunks:
    raise ValueError(
        "No chunks to embed; check that the corpus directory contains documents."
    )

embedder = SentenceTransformer('all-MiniLM-L6-v2')
# Embeddings are produced in the same order as `chunks`, so row i of the index
# corresponds to chunks[i] / sources[i].
embeddings = embedder.encode(chunks, show_progress_bar=True)
# Exact (brute-force) L2 index sized to the embedding dimension.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

## 🔍 Step 3: Define RAG Query Function

In [None]:
def rag_query(query, k=3):
    """Retrieve the chunks nearest to *query* from the vector index.

    Args:
        query: Text to embed and search for.
        k: Maximum number of chunks to return.

    Returns:
        A ``(top_chunks, top_sources)`` pair of index-aligned lists in the
        order returned by the index search. Fewer than *k* items (possibly
        none) are returned when the index holds fewer than *k* vectors.
    """
    # Never request more neighbours than the index holds: FAISS pads missing
    # results with -1, and chunks[-1] would silently return the last chunk.
    k = min(k, index.ntotal)
    if k < 1:
        return [], []

    q_embed = embedder.encode([query])
    _, indices = index.search(q_embed, k)
    # Defensive filter in case any -1 padding is still present.
    hits = [i for i in indices[0] if i >= 0]
    top_chunks = [chunks[i] for i in hits]
    top_sources = [sources[i] for i in hits]
    return top_chunks, top_sources

## 🧠 Step 4: Assemble Prompt from Retrieved Context

In [None]:
def format_prompt(chunks, query):
    """Build the prompt that asks the model to answer *query* from *chunks*.

    Args:
        chunks: Iterable of context strings; they are joined with blank lines.
        query: The user's question, inserted after the context.

    Returns:
        The full prompt string, ending in ``Answer:`` for the model to continue.
    """
    joined_context = "\n\n".join(chunks)
    return (
        "Answer the following based only on the provided context.\n\n"
        f"Context:\n{joined_context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )

## 🤖 Step 5: Generate Grounded Answer

In [None]:
# Text-generation pipeline backed by the local checkpoint.
rag_model = pipeline("text-generation", model="./checkpoints/my_model")

query = "How does the coroutine decorator work?"
chunks_used, cited_sources = rag_query(query)
prompt = format_prompt(chunks_used, query)
# return_full_text=False: by default the pipeline's `generated_text` starts
# with the prompt itself, so the displayed "answer" would repeat the whole
# context and question. Request only the newly generated continuation.
answer = rag_model(prompt, max_new_tokens=200, return_full_text=False)[0]['generated_text']

display(Markdown(f"**Query:** {query}"))
display(Markdown(f"**Answer:**\n\n{answer}"))
# dict.fromkeys de-duplicates while keeping retrieval order; set() ordering
# is arbitrary and could differ between runs.
display(Markdown(f"**Citations:** {', '.join(dict.fromkeys(cited_sources))}"))

## 💬 (Optional) Chat History Tracking

In [None]:
# One record per call to ask_rag(), in the order the questions were asked.
chat_history = []

def ask_rag(query):
    """Answer *query* with retrieval, record the exchange, and display it.

    Retrieves context via ``rag_query``, builds the prompt via
    ``format_prompt``, generates up to 150 new tokens, appends a
    ``{'query', 'response', 'sources'}`` record to ``chat_history`` and
    renders the exchange as Markdown.

    Args:
        query: The user's question.

    Returns:
        None. The result is shown via ``display`` and stored in
        ``chat_history``.
    """
    chunks_used, sources_used = rag_query(query)
    full_prompt = format_prompt(chunks_used, query)
    # return_full_text=False: by default `generated_text` includes the prompt,
    # which would store and display the whole context as part of the reply.
    response = rag_model(
        full_prompt, max_new_tokens=150, return_full_text=False
    )[0]['generated_text'].strip()
    chat_history.append({
        'query': query,
        'response': response,
        'sources': sources_used
    })
    # De-duplicate sources while keeping retrieval order (set() order is arbitrary).
    cited = ', '.join(dict.fromkeys(sources_used))
    display(Markdown(f"**You:** {query}"))
    display(Markdown(f"**Bot:** {response}\n\n**From:** {cited}"))