# 05. RAG Retrieval and Agent

## The Finale: Building the Agent
This is where everything comes together. We will build a simple AI agent that:
1.  Takes your question.
2.  **Retrieves** relevant information from our FAISS index.
3.  **Generates** a natural language answer using an LLM (Large Language Model).
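The three steps can be made concrete with a toy, dependency-free sketch of the same retrieve-then-generate flow. This is a stand-in built on assumptions, not the notebook's implementation:

- `toy_embed` is a hypothetical bag-of-words substitute for the sentence-transformer model.
- A brute-force cosine search replaces FAISS.
- The `CHUNKS` dict replaces the Delta table.
- The finished prompt is printed instead of being sent to Flan-T5.

```python
import math
import re
from collections import Counter

# Hypothetical toy corpus standing in for the chunk table (chunk_id -> chunk_text).
CHUNKS = {
    101: "Delta Lake is a storage layer that adds ACID transactions to data lakes.",
    102: "FAISS is a library for efficient similarity search over dense vectors.",
    103: "Flan-T5 is an instruction-tuned sequence-to-sequence language model.",
}

def toy_embed(text):
    """Hypothetical embedding stand-in: a bag-of-words count vector."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def toy_retrieve(query, k=2):
    """Step 2 (Retrieve): rank every chunk by similarity to the query, keep the top k IDs."""
    q = toy_embed(query)
    ranked = sorted(CHUNKS.items(), key=lambda item: cosine(q, toy_embed(item[1])), reverse=True)
    return [chunk_id for chunk_id, _ in ranked[:k]]

def toy_prompt(query, chunk_ids):
    """Step 3 (Generate): assemble the prompt that would be sent to the LLM."""
    context = "\n\n".join(CHUNKS[c] for c in chunk_ids)
    return (
        "Answer the question based on the context below. "
        'If the answer is not in the context, say "I don\'t know".\n\n'
        f"Context:\n{context}\n\nQuestion:\n{query}\n\nAnswer:"
    )

# Step 1: take the question, then retrieve and build the prompt.
ids = toy_retrieve("What is Delta Lake?")
print(ids)
print(toy_prompt("What is Delta Lake?", ids))
```

The real notebook swaps each toy piece for a proper component (embedding model, FAISS index, Delta table, Flan-T5), but the data flow is the same.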

## Step 1: Install Libraries
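We need `faiss-cpu` to read the index, `sentence-transformers` to embed the question, and `transformers` to run the LLM. If the imports in the next cell fail after installing, restart the Python process (on Databricks, `dbutils.library.restartPython()`) and run the cell again.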

In [None]:
%pip install faiss-cpu sentence-transformers transformers

## Step 2: Load Resources
We need to load everything we built in previous steps:
- The FAISS Index
- The ID Mapping
- The Embedding Model (to convert your question into a vector; it must be the same model that was used to build the index)
- The LLM (to write the answer)
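The index and the mapping were built in separate steps, so a quick consistency check after loading can prevent confusing lookups later. The sketch below uses a hypothetical helper, `find_unmapped_positions`, which is not part of FAISS. It assumes `id_mapping` maps each FAISS position `0 .. ntotal-1` to a chunk ID, either as a list or as a dict keyed by position, which is how Step 3 uses it.

```python
def find_unmapped_positions(ntotal, id_mapping):
    """Return FAISS positions in 0..ntotal-1 that have no chunk ID.

    Assumption: id_mapping is a list (position -> chunk_id) or a dict
    keyed by integer position.
    """
    if isinstance(id_mapping, dict):
        # Dict case: every position must appear as a key.
        return [pos for pos in range(ntotal) if pos not in id_mapping]
    # List case: positions beyond the end of the list are unmapped.
    return list(range(len(id_mapping), ntotal))

print(find_unmapped_positions(3, {0: "chunk-a", 2: "chunk-c"}))
```

In the notebook you could call `find_unmapped_positions(index.ntotal, id_mapping)` after the loading cell; an empty list means every vector has a chunk ID. Comparing `index.d` with `embed_model.get_sentence_embedding_dimension()` is a similar cheap check that the embedding model matches the index.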

In [None]:
import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# 1. Load Index and Mapping
index_path = "/dbfs/FileStore/rag_data/faiss_index.bin"
mapping_path = "/dbfs/FileStore/rag_data/id_mapping.pickle"

index = faiss.read_index(index_path)

with open(mapping_path, "rb") as f:
    id_mapping = pickle.load(f)

# 2. Load Embedding Model (must be the same model used to build the FAISS index)
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# 3. Load LLM (Flan-T5 Small)
# We use 'google/flan-t5-small' because it fits in the memory of the Free Edition.
# It's not the smartest model, but it proves the concept.
llm_model_name = "google/flan-t5-small" 
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)

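# Create a text-to-text generation pipeline (Flan-T5 is an encoder-decoder model).
# max_length caps the length of the generated answer, not of the prompt.
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)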

print("All systems go! Resources loaded.")
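The comment about fitting in Free Edition memory can be sanity-checked with back-of-the-envelope arithmetic: weight memory is roughly parameters × bytes per parameter. The parameter count below (about 77 million for `google/flan-t5-small`) is an approximate figure, so treat the result as a rough lower bound. It ignores activations, the tokenizer, the embedding model, and the FAISS index.

```python
def weight_memory_mb(num_params, bytes_per_param=4):
    """Approximate size of the model weights alone, in megabytes (1 MB = 10**6 bytes)."""
    return num_params * bytes_per_param / 1e6

# Approximate parameter count for google/flan-t5-small (assumption, not measured here).
FLAN_T5_SMALL_PARAMS = 77_000_000

print(f"float32:  {weight_memory_mb(FLAN_T5_SMALL_PARAMS):.0f} MB")
print(f"bfloat16: {weight_memory_mb(FLAN_T5_SMALL_PARAMS, 2):.0f} MB")
```

To use the exact count instead of the estimate, `model.num_parameters()` on the loaded Transformers model returns it.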

## Step 3: Define Retrieval Function
This function takes a user's question and finds the most relevant text chunks.
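Under the hood, `index.search` performs a k-nearest-neighbour lookup. The index type was chosen in an earlier step, so the following assumes a flat (brute-force) L2 index. Such an index compares the query with every stored vector and returns two arrays with one row per query:

- **distances:** squared L2, smallest first.
- **indices:** the positions of the matches, padded with -1 when the index holds fewer than `k` vectors.

Because there is one row per query, the retrieval code reads `indices[0]`. The pure-Python sketch below mimics this for a single query. `flat_l2_search` is a hypothetical name, not a FAISS function.

```python
def flat_l2_search(vectors, query, k):
    """Brute-force k-NN for one query, mimicking the shape of a flat L2 index search.

    Returns (distances, indices): squared L2 distances, smallest first, and the
    positions of the matching vectors. Missing slots get index -1; this sketch
    pads their distances with infinity.
    """
    scored = sorted(
        (sum((a - b) ** 2 for a, b in zip(vec, query)), pos)
        for pos, vec in enumerate(vectors)
    )[:k]
    distances = [d for d, _ in scored]
    indices = [pos for _, pos in scored]
    # Pad when the index holds fewer than k vectors.
    pad = k - len(indices)
    return distances + [float("inf")] * pad, indices + [-1] * pad

vecs = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
print(flat_l2_search(vecs, [0.9, 0.0], 5))
```

The -1 padding is why the retrieval code filters with `if i != -1` before looking positions up in `id_mapping`.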

In [None]:
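from pyspark.sql import functions as F

def retrieve_context(query, k=3):
    """
    Searches for the top 'k' chunks relevant to the query.
    Returns Spark Rows (chunk_id, chunk_text, source_file), best match first.
    """
    # 1. Convert query to vector (same embedding model that built the index)
    query_vector = embed_model.encode([query]).astype("float32")
    
    # 2. Search FAISS index
    # distances: how far each match is from the query (for an L2 index, smaller = closer)
    # indices: the internal FAISS position of each match (-1 if fewer than k results)
    distances, indices = index.search(query_vector, k)
    
    # 3. Map FAISS positions to the real Chunk IDs (indices[0] = results for our single query)
    retrieved_ids = [id_mapping[int(i)] for i in indices[0] if i != -1]
    
    if not retrieved_ids:
        return []
    
    # 4. Fetch the actual text from our Delta table.
    # isin() avoids hand-building an SQL string, so it works for numeric and string IDs.
    rows = (
        spark.table("rag_demo.gold_embeddings")
        .select("chunk_id", "chunk_text", "source_file")
        .filter(F.col("chunk_id").isin(retrieved_ids))
        .collect()
    )
    
    # The filter does not preserve order, so restore the FAISS ranking (best match first).
    rows_by_id = {row.chunk_id: row for row in rows}
    return [rows_by_id[cid] for cid in retrieved_ids if cid in rows_by_id]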

# Let's test it!
print("Testing Retrieval...")
results = retrieve_context("What is Delta Lake?")
for row in results:
    print(f"Found in {row.source_file}: {row.chunk_text[:100]}...")
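The retrieval code above never uses `distances`, so it always returns up to `k` chunks even when none of them is a good match. The sketch below shows one way to use them. `select_hits` is a hypothetical helper that walks one row of search results, skips the -1 padding, and optionally drops hits farther than a cutoff. It assumes an L2-style index where smaller means closer; for an inner-product index the comparison would flip. Any cutoff value would have to be tuned on your own data.

```python
def select_hits(distances, indices, id_mapping, max_distance=None):
    """Turn one row of FAISS search output into chunk IDs, best match first.

    Skips -1 padding and, if max_distance is given, any hit farther than it.
    Assumes an L2-style index (smaller distance = closer match).
    """
    hits = []
    for dist, pos in zip(distances, indices):
        if pos == -1:
            continue  # padding: the index had fewer than k vectors
        if max_distance is not None and dist > max_distance:
            continue  # too far from the query to be useful context
        hits.append(id_mapping[pos])
    return hits

mapping = {1: "chunk-1", 4: "chunk-4"}
print(select_hits([0.2, 0.9, 1e38], [4, 1, -1], mapping))
print(select_hits([0.2, 0.9, 1e38], [4, 1, -1], mapping, max_distance=0.5))
```

Inside `retrieve_context`, `select_hits(distances[0], indices[0], id_mapping, max_distance=...)` could replace the list comprehension in step 3.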

## Step 4: Define the RAG Agent
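This function combines retrieval and generation: it fetches the relevant chunks, places them in a prompt, and asks the LLM to answer from that context.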

In [None]:
def rag_agent(query):
    print(f"User Query: {query}")
    print("Thinking... (Retrieving context)")
    
    # 1. Retrieve Context
    context_rows = retrieve_context(query)
    
    # Combine all retrieved text into one big string
    context_text = "\n\n".join([row.chunk_text for row in context_rows])
    
    # 2. Create the Prompt
    # We tell the LLM exactly what to do
    prompt = f"""
    Answer the question based on the context below. If the answer is not in the context, say "I don't know".
    
    Context:
    {context_text}
    
    Question:
    {query}
    
    Answer:
    """
    
    # 3. Generate Answer
    response = generator(prompt)
    return response[0]['generated_text']
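The prompt above grows with every retrieved chunk. The Flan-T5 tokenizer typically reports a maximum input length of 512 tokens (`tokenizer.model_max_length`), so long chunks can push the prompt past that length. The sketch below shows one mitigation, assuming the chunks arrive best match first:

- It keeps adding chunks until a budget is reached.
- It builds the prompt with `textwrap.dedent`, so the function's indentation does not leak into the text the model sees.

`build_prompt` and its word-based budget are hypothetical; word counts only approximate token counts.

```python
import textwrap

def build_prompt(query, chunks, max_words=300):
    """Build a RAG prompt, dropping the lowest-ranked chunks once the budget is used up.

    Assumes chunks are ordered best match first; words are a rough proxy for tokens.
    """
    kept, used = [], 0
    for chunk in chunks:
        n = len(chunk.split())
        if used + n > max_words:
            break  # everything after this is lower-ranked, so stop here
        kept.append(chunk)
        used += n
    context = "\n\n".join(kept)
    # dedent strips the code indentation before the values are filled in.
    template = textwrap.dedent("""\
        Answer the question based on the context below. If the answer is not in the context, say "I don't know".

        Context:
        {context}

        Question:
        {query}

        Answer:""")
    return template.format(context=context, query=query)

print(build_prompt("What?", ["a b c", "d e", "f g h i"], max_words=5))
```

To budget in real tokens instead of words, `len(tokenizer(prompt).input_ids)` gives the exact Flan-T5 token count for a candidate prompt.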

## Step 5: Ask Questions!
Now you can ask your RAG agent questions about the documents you uploaded.

In [None]:
# Example 1
answer1 = rag_agent("Explain Delta Lake architecture")
print(f"\nAgent Answer: {answer1}\n")
print("-" * 50)

In [None]:
# Example 2
answer2 = rag_agent("What is Databricks?")
print(f"\nAgent Answer: {answer2}\n")