# Lab 2: Retrieval-Augmented Generation (RAG)

Objective: Build a minimal RAG pipeline using `transformers` with an encoder-only embedder (BERT) and a decoder-only generator (Llama 1B). We focus on understanding internal mechanisms, not efficiency.

This lab includes:
- A brief RAG overview
- Indexing Padua documents with BERT embeddings (token-based chunking)
- Retrieving top-n similar chunks by cosine similarity
- Generating short answers with Llama using retrieved context
- Exercises matching course goals

## Setup and Imports
We reuse patterns from Lab 1: local `cache_dir`, simple pooling for embeddings, and an optional Llama generation flag.

In [None]:
import os, json
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

print('Transformers version:', __import__('transformers').__version__)
print('Torch version:', torch.__version__)

# Every artifact of this lab lives under ./lab2: the input data files and a
# local Hugging Face model cache (created up front so downloads can land there).
BASE_DIR = 'lab2'
DATA_DIR, CACHE_DIR = (os.path.join(BASE_DIR, sub) for sub in ('data', 'models_cache'))
os.makedirs(CACHE_DIR, exist_ok=True)

# Flip to True to actually generate answers with the Llama model.
RUN_LLAMA = False

# Hugging Face Hub model identifiers (adjust if needed).
BERT_ID = 'bert-base-uncased'
LLAMA_PRIMARY = 'meta-llama/Llama-3.2-1B-Instruct'
LLAMA_FALLBACK = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

## Load Embedder (BERT) and Tokenizer
We use BERT as an encoder-only embedder. Embeddings are computed via mean pooling of the last hidden states across tokens.

In [None]:
# Encoder-only embedder: tokenizer and bare BERT encoder, cached under CACHE_DIR.
tokenizer_bert = AutoTokenizer.from_pretrained(BERT_ID, cache_dir=CACHE_DIR)
model_bert = AutoModel.from_pretrained(BERT_ID, cache_dir=CACHE_DIR)

# Quick sanity report: special tokens plus the dtype/device of the weights
# (inspected through the first parameter tensor).
print('BERT special tokens:', tokenizer_bert.special_tokens_map)
bert_first_param = next(model_bert.parameters())
print('BERT dtype:', bert_first_param.dtype)
print('BERT device:', next(model_bert.parameters()).device)

## Load Generator (Llama 1B)
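We attempt to load a 1B instruction-tuned chat model (Llama 3.2 1B Instruct). If it cannot be loaded, we fall back to TinyLlama 1.1B Chat. A common cause is that the `meta-llama` repository on the Hugging Face Hub is gated: you must accept its license and authenticate with an access token. Generation is optional and resource-intensive (a GPU is recommended).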

In [None]:
def load_llama(model_id):
    tok = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id, cache_dir=CACHE_DIR, torch_dtype=torch.float16, device_map='auto'
    )
    return tok, mdl

# Prefer the primary (possibly gated) model; on any loading failure report the
# reason (truncated) and load the fallback instead. A fallback failure propagates.
try:
    tokenizer_llama, model_llama = load_llama(LLAMA_PRIMARY)
except Exception as load_err:
    print('Fallback to TinyLlama due to:', type(load_err).__name__, str(load_err)[:300])
    tokenizer_llama, model_llama = load_llama(LLAMA_FALLBACK)
    LLAMA_ID = LLAMA_FALLBACK
else:
    LLAMA_ID = LLAMA_PRIMARY

print('Llama model:', LLAMA_ID)
print('Has chat template?', bool(getattr(tokenizer_llama, 'chat_template', None)))

## Load Documents and Prepare Chunking
We index 25 documents about Padua. Each document is split into consecutive token-based chunks using the BERT tokenizer. Chunks hold at most 128 tokens; the last chunk of a document may be shorter.

In [None]:
# Knowledge-base documents: a JSON list of records; each record is expected
# to provide 'id', 'title' and 'text' (read when chunking).
docs_path = os.path.join(DATA_DIR, 'kb_docs.json')
with open(docs_path, encoding='utf-8') as docs_file:
    DOCS = json.load(docs_file)
len(DOCS), DOCS[0]['title']

In [None]:
# Chunking utility: split text into chunks capped by max_tokens
def chunk_text(text, tokenizer, max_tokens=128):
    """Split ``text`` into consecutive chunks of at most ``max_tokens`` tokens.

    Tokenization uses no special tokens. When the tokenizer is "fast"
    (``tokenizer.is_fast``), each chunk's text is sliced from the ORIGINAL
    string via the token offset mapping. This preserves casing, accents and
    spacing. Decoding token ids with an uncased tokenizer such as
    bert-base-uncased would instead return lowercased, accent-stripped text.
    Decoding can also produce '##' word-piece fragments at chunk boundaries.
    Slow tokenizers do not support offsets, so ``tokenizer.decode`` is used
    as a fallback.

    Args:
        text: document text to split.
        tokenizer: a Hugging Face tokenizer (callable, with ``decode``).
        max_tokens: maximum number of tokens per chunk (must be >= 1).

    Returns:
        List of dicts ``{'text': str, 'token_count': int}``; empty for empty text.

    Raises:
        ValueError: if ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError(f'max_tokens must be >= 1, got {max_tokens!r}')
    # Offsets are only available from fast (Rust-backed) tokenizers.
    use_offsets = bool(getattr(tokenizer, 'is_fast', False))
    enc = tokenizer(
        text,
        add_special_tokens=False,
        truncation=False,
        return_offsets_mapping=use_offsets,
    )
    input_ids = list(enc['input_ids'])
    offsets = enc['offset_mapping'] if use_offsets else None
    chunks = []
    for start in range(0, len(input_ids), max_tokens):
        stop = min(start + max_tokens, len(input_ids))
        sub_ids = input_ids[start:stop]
        if offsets is not None:
            # From the first token's start char to the last token's end char.
            chunk_txt = text[offsets[start][0]:offsets[stop - 1][1]]
        else:
            chunk_txt = tokenizer.decode(sub_ids)
        chunks.append({'text': chunk_txt, 'token_count': len(sub_ids)})
    return chunks

# Build KB chunks: flatten every document into its 128-token chunks, tagging each
# chunk with its source document id/title and its position within that document.
KB_CHUNKS = [
    {
        'doc_id': doc['id'],
        'chunk_id': position,
        'title': doc['title'],
        'text': piece['text'],
        'token_count': piece['token_count'],
    }
    for doc in DOCS
    for position, piece in enumerate(chunk_text(doc['text'], tokenizer_bert, max_tokens=128))
]
len(KB_CHUNKS)

## Compute Embeddings and Save KB Index
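We compute a mean-pooled BERT embedding for each chunk and save a JSON KB index with one record per chunk. Each record holds the chunk's metadata (`doc_id`, `chunk_id`, `title`, `token_count`), its `text`, its `embedding` and the `embedding_dim`.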

In [None]:
def embed_texts(texts, tokenizer, model, batch_size=8):
    """Embed ``texts`` by mean-pooling the encoder's last hidden states.

    Texts are processed in batches padded to the longest member. The mean is
    weighted by the attention mask so padding positions are excluded. Without
    the mask, a text's embedding would depend on which other texts share its
    batch, and would not match an unpadded (single) query embedded the same way.

    Args:
        texts: sequence of strings.
        tokenizer: Hugging Face tokenizer returning PyTorch tensors.
        model: encoder whose output has ``last_hidden_state``.
        batch_size: number of texts per forward pass (must be >= 1).

    Returns:
        np.ndarray with one row per input text (shape ``(0,)`` if ``texts`` is empty).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size!r}')
    # Send inputs to wherever the model lives (HF models expose `.device`).
    device = getattr(model, 'device', None)
    embs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            enc = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
            if device is not None:
                enc = {k: v.to(device) for k, v in enc.items()}
            hidden = model(**enc).last_hidden_state
            mask = enc.get('attention_mask')
            if mask is None:
                # No mask available: fall back to a plain mean over positions.
                pooled = hidden.mean(dim=1)
            else:
                weights = mask.unsqueeze(-1).to(hidden.dtype)
                counts = weights.sum(dim=1).clamp(min=1e-9)
                pooled = (hidden * weights).sum(dim=1) / counts
            embs.extend(pooled.cpu().numpy().tolist())
    return np.array(embs)

# One embedding row per KB chunk, in the same order as KB_CHUNKS.
chunk_texts = [chunk['text'] for chunk in KB_CHUNKS]
EMB = embed_texts(
    chunk_texts,
    tokenizer_bert,
    model_bert,
    batch_size=8,
)
EMB.shape

In [None]:
# Save KB index: one JSON record per chunk holding its metadata, text and
# embedding (row `position` of EMB corresponds to KB_CHUNKS[position]).
def _index_record(chunk, vector):
    """Build the serialisable index entry for one chunk and its embedding list."""
    return {
        'doc_id': chunk['doc_id'],
        'chunk_id': chunk['chunk_id'],
        'title': chunk['title'],
        'text': chunk['text'],
        'embedding': vector,
        'embedding_dim': len(vector),
        'token_count': chunk['token_count'],
    }


kb_index = [_index_record(chunk, EMB[position].tolist()) for position, chunk in enumerate(KB_CHUNKS)]
index_path = os.path.join(DATA_DIR, 'kb_index.json')
with open(index_path, mode='w', encoding='utf-8') as index_file:
    json.dump(kb_index, index_file, ensure_ascii=False, indent=2)
len(kb_index), kb_index[0]['embedding_dim']

## Retrieval: Top-n Similar Chunks
We compute the cosine similarity between the query embedding and every KB chunk embedding. We then return the n most similar chunks together with their similarity scores.

In [None]:
def cosine_similarity_matrix(A, B):
    """Return the pairwise cosine similarities between the rows of A and B.

    Both inputs are 2-D arrays with the same number of columns. Row norms are
    floored at 1e-12, so all-zero rows yield similarity 0 instead of NaN.
    Result shape: (A.shape[0], B.shape[0]).
    """
    def _unit_rows(mat):
        lengths = np.linalg.norm(mat, axis=1, keepdims=True)
        return mat / np.maximum(lengths, 1e-12)

    return np.matmul(_unit_rows(A), _unit_rows(B).T)

def embed_query(query):
    """Embed a single query with the module-level BERT tokenizer and model.

    The query is encoded as a batch of one, so no padding is added and the
    plain mean over the last hidden states covers only real tokens.
    Returns a NumPy array of shape (1, hidden_dim).
    """
    encoded = tokenizer_bert([query], return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        hidden = model_bert(**encoded).last_hidden_state
        vector = hidden.mean(dim=1).cpu().numpy()
    return vector

def retrieve_top_n(query, kb_emb, kb_meta, n=3):
    """Return the ``n`` KB entries most cosine-similar to ``query``.

    ``kb_emb`` rows must be aligned with ``kb_meta`` entries. The result is a
    list of ``(similarity, meta)`` tuples sorted by decreasing similarity.
    """
    scores = cosine_similarity_matrix(embed_query(query), kb_emb)[0]
    ranked = np.argsort(-scores)
    return [(float(scores[idx]), kb_meta[idx]) for idx in ranked[:n]]

# Load queries (assumed: a JSON list of dicts with a 'query' key — TODO confirm schema)
queries_path = os.path.join(DATA_DIR, 'queries.json')
with open(queries_path, 'r', encoding='utf-8') as f:
    QUERIES = json.load(f)

# Demo retrieval on first query: show score, title and a 180-char text preview
# for each of the top-3 chunks (kb_index rows are aligned with EMB rows).
q0 = QUERIES[0]['query']
top_chunks = retrieve_top_n(q0, EMB, kb_index, n=3)
print('Query:', q0)
for sim, meta in top_chunks:
    print(f'\nsim={sim:.4f}', 'title=', meta['title'], '\ntext=', meta['text'][:180], '\n')

## Generation: Answer from Retrieved Context
We instruct Llama to answer concisely using only the provided context. If the answer is not in the context, the model should say it does not know.

In [None]:
def build_context(chunks):
    ctx = ''
    for _, meta in chunks:
        ctx += f