# Juego de Tronos Libro 1 — RAG extremo + Generación de Imagen (Notebook autocontenido)

Este notebook está pensado para exprimir calidad al máximo:
- Retrieval híbrido (dense + BM25) con MMR.
- Expansión multi-consulta opcional.
- QA grounded con verificación/filtro de fidelidad.
- Planner de escena robusto + prompt visual avanzado.
- Generación multi-seed y selección automática por score semántico con CLIP.


In [None]:
# Install/upgrade the notebook's dependencies quietly (IPython shell escape).
# NOTE(review): `-U` pulls the latest releases on every run, so versions are not pinned.
!pip -q install -U pandas pyarrow beautifulsoup4 lxml faiss-cpu sentence-transformers transformers accelerate diffusers safetensors rank-bm25


In [None]:
# Mount Google Drive (Colab-only) so the EPUB can be read from MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Source EPUB inside the mounted Drive; extract_chapters() raises FileNotFoundError if it is missing.
EPUB_PATH = '/content/drive/MyDrive/juego_de_tronos.epub'


In [None]:
import json
import os
import re
import zipfile
from typing import Optional

import faiss
import numpy as np
import pandas as pd
import torch
from bs4 import BeautifulSoup
from diffusers import StableDiffusion3Pipeline
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub checkpoints used by the pipeline.
# Chat LLM for query expansion, answering, verification and scene planning (see run_chat).
# NOTE(review): this is a "Thinking" checkpoint; presumably its generations contain a
# reasoning segment before the final answer — confirm against the model card.
QWEN_MODEL_ID = 'Qwen/Qwen3-30B-A3B-Thinking-2507-FP8'
# Dense embedder for chunks and queries (SentenceTransformer).
EMBED_MODEL_ID = 'BAAI/bge-m3'
# Cross-encoder that reranks the hybrid candidate pool.
RERANKER_MODEL_ID = 'BAAI/bge-reranker-large'
# Text-to-image model loaded with StableDiffusion3Pipeline.
SD3_MODEL_ID = 'stabilityai/stable-diffusion-3.5-large'
# CLIP model used to score images and keep the best seed.
CLIP_MODEL_ID = 'openai/clip-vit-large-patch14'


In [None]:
def clean_text(text: str) -> str:
    """Normalise whitespace in text extracted from an EPUB page.

    Steps, in order:
    1. Convert CRLF / lone CR line endings to LF.
    2. Collapse any run of blank lines into a single empty line.
    3. Squeeze runs of spaces/tabs into one space.
    The result is stripped of leading/trailing whitespace.
    """
    rules = (
        (r"\r\n?", "\n"),
        (r"\n\s*\n+", "\n\n"),
        (r"[ \t]+", " "),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text.strip()


def extract_title_and_pov(text: str) -> tuple[Optional[str], Optional[str]]:
    """Find a chapter heading such as ``'BRAN (1)'`` near the top of *text*.

    Only the first 20 non-blank lines are inspected. A heading is a line that
    starts with an uppercase letter (A-Z or ÁÉÍÓÚÑÜ) and ends in a
    parenthesised number. Returns ``(heading, pov)``, where *pov* is the text
    before the ``(``, or ``(None, None)`` when no heading is found.
    """
    head_lines = [s for s in (raw.strip() for raw in text.split("\n")) if s][:20]
    heading = next(
        (ln for ln in head_lines if re.match(r"^[A-ZÁÉÍÓÚÑÜ]+.*\(\d+\)$", ln)),
        None,
    )
    if heading is None:
        return None, None
    return heading, heading.partition("(")[0].strip()


def _natural_sort_key(name: str) -> list:
    """Sort key that compares embedded integers numerically ('c2' < 'c10')."""
    # re.split with a capturing group alternates text / digit-run tokens, so odd
    # positions are always digit runs and even positions are always text.
    return [
        int(tok) if pos % 2 else tok.lower()
        for pos, tok in enumerate(re.split(r'(\d+)', name))
    ]


def _epub_spine_paths(zf: zipfile.ZipFile) -> list[str]:
    """Return archive paths of the OPF spine items in reading order ([] if unavailable)."""
    import posixpath
    import xml.etree.ElementTree as ET
    from urllib.parse import unquote

    def local(tag) -> str:
        # Strip the '{namespace}' prefix; comments/PIs have non-str tags.
        return tag.rsplit('}', 1)[-1] if isinstance(tag, str) else ''

    try:
        container = ET.fromstring(zf.read('META-INF/container.xml'))
        opf_path = next(
            (el.get('full-path') for el in container.iter()
             if local(el.tag) == 'rootfile' and el.get('full-path')),
            None,
        )
        if not opf_path:
            return []
        opf = ET.fromstring(zf.read(opf_path))
    except (KeyError, ET.ParseError):
        # Missing container/OPF or malformed XML: caller falls back to name order.
        return []

    # Manifest hrefs are relative to the OPF file and may be URL-encoded.
    base_dir = posixpath.dirname(opf_path)
    manifest = {}
    for el in opf.iter():
        if local(el.tag) == 'item' and el.get('id') and el.get('href'):
            href = unquote(el.get('href').split('#', 1)[0])
            manifest[el.get('id')] = posixpath.normpath(posixpath.join(base_dir, href))

    ordered = [
        manifest[el.get('idref')]
        for el in opf.iter()
        if local(el.tag) == 'itemref' and el.get('idref') in manifest
    ]
    return list(dict.fromkeys(ordered))


def list_xhtml_text_files(zf: zipfile.ZipFile) -> list[str]:
    """List the (X)HTML content files of an EPUB in reading order.

    Files under a ``text/`` or ``texto/`` folder (case-insensitive, at any
    depth including the archive root) are preferred. If none exist, every
    ``.xhtml``/``.html`` file is used.

    Ordering follows the OPF spine when the EPUB declares one. Plain
    lexicographic order would put ``Capitulo10`` before ``Capitulo2``.
    Files not referenced by the spine, or all files when no spine can be
    read, are appended in natural (numeric-aware) order, so no content file
    is dropped.
    """
    candidates = [f for f in zf.namelist() if f.lower().endswith((".xhtml", ".html"))]
    # Prefix '/' so a top-level 'Text/...' folder is recognised as well.
    preferred = [
        f for f in candidates
        if any(folder in '/' + f.lower() for folder in ('/text/', '/texto/'))
    ]
    pool = preferred if preferred else candidates

    in_pool = set(pool)
    ordered = [p for p in _epub_spine_paths(zf) if p in in_pool]
    seen = set(ordered)
    leftovers = sorted((p for p in pool if p not in seen), key=_natural_sort_key)
    return ordered + leftovers


def extract_chapters(epub_path: str, min_chars: int = 800) -> pd.DataFrame:
    """Read an EPUB and return one row per sufficiently long content file.

    Args:
        epub_path: path to the ``.epub`` archive.
        min_chars: files whose cleaned text is shorter than this are skipped
            (cover, copyright, table-of-contents pages, ...). Defaults to the
            previous hard-coded 800.

    Returns:
        DataFrame with columns ``chapter_id`` (0-based, in reading order of
        ``list_xhtml_text_files``), ``epub_file``, ``title``, ``pov`` (from
        ``extract_title_and_pov``, may be None), ``text`` and ``n_chars``.
        Empty (no columns) when nothing qualifies.

    Raises:
        FileNotFoundError: if *epub_path* does not exist.
    """
    if not os.path.exists(epub_path):
        raise FileNotFoundError(f'No existe EPUB_PATH: {epub_path}')

    chapters = []
    with zipfile.ZipFile(epub_path, 'r') as zf:
        for file_name in list_xhtml_text_files(zf):
            soup = BeautifulSoup(zf.read(file_name), 'lxml')
            # Only the body is prose: <head> (incl. <title>, typically the book
            # title on every page) would otherwise be prepended to each chapter.
            root = soup.body or soup
            for tag in root.find_all(['script', 'style']):
                tag.decompose()
            text = clean_text(root.get_text('\n'))
            if len(text) < min_chars:
                continue
            title, pov = extract_title_and_pov(text)
            chapters.append({
                'chapter_id': len(chapters),
                'epub_file': file_name,
                'title': title,
                'pov': pov,
                'text': text,
                'n_chars': len(text),
            })
    return pd.DataFrame(chapters)


def chunk_text(text: str, chunk_size: int = 3500, overlap: int = 500):
    """Split *text* into overlapping fixed-size character windows.

    Returns a list of ``(start, end, piece)`` tuples. ``text[start:end]`` is
    the raw window and ``piece`` is that window stripped of surrounding
    whitespace; whitespace-only windows are skipped. Consecutive windows
    start ``chunk_size - overlap`` characters apart, and the last emitted
    window is the first one that reaches ``len(text)``.

    Raises:
        ValueError: if ``chunk_size <= overlap`` (the window would never advance).
    """
    if chunk_size <= overlap:
        raise ValueError('chunk_size debe ser mayor que overlap')
    step = chunk_size - overlap
    total = len(text)
    chunks = []
    start = 0
    while start < total:
        end = min(start + chunk_size, total)
        piece = text[start:end].strip()
        if piece:
            chunks.append((start, end, piece))
        if end == total:
            # This window already covers the end of the text; stepping again
            # would only emit a tail fully contained in it (a duplicate chunk).
            break
        start += step
    return chunks


def build_chunks(chapters_df: pd.DataFrame, chunk_size: int = 3500, overlap: int = 500) -> pd.DataFrame:
    """Explode every chapter into overlapping chunks, one row per chunk.

    Each row carries the chapter metadata (``chapter_id``, ``epub_file``,
    ``title``, ``pov``) and the raw window span ``start_char``/``end_char``
    inside the chapter text. It also holds the stripped chunk ``text`` and
    its length ``n_chars``. ``chunk_id`` is
    ``"<chapter_id>_<index within chapter>"``.
    """
    def _chapter_rows(chapter: dict):
        chapter_id = int(chapter['chapter_id'])
        windows = chunk_text(chapter['text'], chunk_size, overlap)
        for position, (begin, finish, body) in enumerate(windows):
            yield {
                'chunk_id': f"{chapter_id}_{position}",
                'chapter_id': chapter_id,
                'epub_file': chapter['epub_file'],
                'title': chapter['title'],
                'pov': chapter['pov'],
                'start_char': begin,
                'end_char': finish,
                'text': body,
                'n_chars': len(body),
            }

    records = [
        row
        for chapter in chapters_df.to_dict('records')
        for row in _chapter_rows(chapter)
    ]
    return pd.DataFrame(records)


In [None]:
# Build the corpus: EPUB -> chapters -> overlapping character chunks.
chapters_df = extract_chapters(EPUB_PATH)
chunks_df = build_chunks(chapters_df)

# Persist both tables as Parquet in the working directory.
# NOTE(review): nothing visible in this notebook reads them back; presumably kept for inspection/reuse.
chapters_df.to_parquet('chapters.parquet', index=False)
chunks_df.to_parquet('chunks.parquet', index=False)
print('Capítulos:', len(chapters_df), '| Chunks:', len(chunks_df))


In [None]:
# Retrieval models, loaded once and used as globals by embed_texts / hybrid_retrieve.
embedder = SentenceTransformer(EMBED_MODEL_ID)
reranker = CrossEncoder(RERANKER_MODEL_ID)


def embed_texts(texts, batch_size=32):
    """Encode *texts* with the global SentenceTransformer ``embedder``.

    Embeddings are L2-normalised, so inner product equals cosine similarity.
    They are returned as a float32 NumPy array with one row per input text,
    which is the dtype FAISS expects.
    """
    encode_options = {
        'batch_size': batch_size,
        'show_progress_bar': True,
        'convert_to_numpy': True,
        'normalize_embeddings': True,
    }
    vectors = embedder.encode(texts, **encode_options)
    return vectors.astype('float32')

# Dense index: one normalised embedding per chunk; FAISS row i <-> chunks_df row i.
# IndexFlatIP is exact (brute-force) inner-product search; with the L2-normalised
# vectors from embed_texts this ranks chunks by cosine similarity.
emb = embed_texts(chunks_df['text'].tolist())
faiss_index = faiss.IndexFlatIP(emb.shape[1])
faiss_index.add(emb)

# Lexical index (BM25)
def bm25_tokenize(text: str):
    """Lower-case *text* and return its word tokens (runs of \\w characters).

    The same tokenizer is applied to indexed chunks and to queries so BM25
    term matching is consistent; accented letters stay inside tokens.
    """
    word_pattern = re.compile(r"\w+", re.UNICODE)
    return word_pattern.findall(text.lower())

# BM25 over the same chunks, tokenised exactly like queries in hybrid_retrieve;
# BM25 document i <-> chunks_df row i.
bm25_corpus = [bm25_tokenize(t) for t in chunks_df['text'].tolist()]
bm25 = BM25Okapi(bm25_corpus)

# Sanity check: both indexes should hold one entry per chunk.
print('FAISS:', faiss_index.ntotal, '| BM25 docs:', len(bm25_corpus))


In [None]:
# Chat LLM and its tokenizer, loaded once; run_chat() uses both as globals.
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, use_fast=True)
# device_map='auto' lets accelerate place the weights on the available device(s).
# NOTE(review): QWEN_MODEL_ID is an FP8 checkpoint while torch_dtype requests bfloat16 —
# confirm how the installed transformers version handles this combination.
model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)


def run_chat(messages, max_new_tokens=400, do_sample=False, temperature=0.0):
    """Generate one assistant reply for *messages* with the global chat model.

    Args:
        messages: chat turns as ``{'role': ..., 'content': ...}`` dicts. They
            are rendered with the tokenizer's chat template plus a generation
            prompt.
        max_new_tokens: generation budget. For a reasoning ("Thinking")
            checkpoint this budget also covers the reasoning segment, so small
            values can stop generation before the final answer.
        do_sample, temperature: sampling is used only when ``do_sample`` is
            True and ``temperature > 0``; otherwise decoding is greedy.
            Transformers rejects a non-positive temperature when sampling.

    Returns:
        The decoded reply, stripped. If the output contains a ``</think>``
        marker, only the text after the last one is returned, so the model's
        reasoning never leaks into answers or into the JSON parsers downstream.
        If no marker is present (non-reasoning model, or reasoning cut off by
        *max_new_tokens*), the whole decoded text is returned.
    """
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    gen_kwargs = {'max_new_tokens': max_new_tokens, 'repetition_penalty': 1.03}
    if do_sample and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=1.0)
    else:
        gen_kwargs['do_sample'] = False

    out = model.generate(**inputs, **gen_kwargs)
    new_tokens = out[0][inputs['input_ids'].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Qwen3 "Thinking" models emit reasoning terminated by '</think>'; the
    # chat template may open the block itself, so only the closing tag is
    # guaranteed to appear in the generated text.
    marker = '</think>'
    if marker in text:
        text = text.rsplit(marker, 1)[1]
    return text.strip()


In [None]:
def expand_query(question: str, enable=True, n=3):
    """Return search queries for *question*: the question itself plus LLM variants.

    When *enable* is False, no LLM call is made and ``[question]`` is
    returned. Otherwise the LLM is asked for JSON ``{"queries": [...]}``. The
    string variants are stripped, empty strings dropped and duplicates
    removed (keeping first occurrence, question first). At most ``n + 1``
    queries are returned. Any parsing problem falls back to ``[question]``.
    """
    if not enable:
        return [question]

    system_turn = {
        'role': 'system',
        'content': 'Genera variantes cortas de búsqueda para recuperar pasajes de novela. Devuelve JSON con clave queries.',
    }
    user_turn = {'role': 'user', 'content': f'Pregunta: {question} | n={n}'}
    raw = run_chat([system_turn, user_turn], max_new_tokens=180)

    try:
        # Take the outermost {...} span; models often wrap JSON in prose/fences.
        parsed = json.loads(raw[raw.find('{'):raw.rfind('}') + 1])
        variants = [question]
        variants.extend(item.strip() for item in parsed.get('queries', []) if isinstance(item, str))
        unique = list(dict.fromkeys(v for v in variants if v))
        return unique[:n + 1]
    except Exception:
        return [question]


def minmax_norm(arr):
    """Scale values linearly to [0, 1] as a float32 array.

    An empty input is returned as an empty float32 array. If all values are
    (numerically) equal, i.e. the range is below 1e-8, an all-zero array is
    returned instead of dividing by ~0.
    """
    values = np.array(arr, dtype=np.float32)
    if len(values) == 0:
        return values
    low = float(values.min())
    high = float(values.max())
    span = high - low
    return np.zeros_like(values) if span < 1e-8 else (values - low) / span


def mmr_select(candidates_df: pd.DataFrame, query_emb: np.ndarray, top_k=12, lambda_mult=0.75):
    """Pick up to *top_k* rows by Maximal Marginal Relevance.

    ``candidates_df['dense_vec']`` holds one embedding per row. Relevance is
    ``vec @ query_emb``. The first pick is the most relevant row. Each
    following pick maximises
    ``lambda_mult * relevance - (1 - lambda_mult) * max_sim_to_already_picked``.
    Ties go to the earliest row. If there are at most *top_k* rows, all are
    returned unchanged. The result has a fresh RangeIndex and is in pick order.
    """
    if len(candidates_df) <= top_k:
        return candidates_df.reset_index(drop=True)

    vectors = np.vstack(candidates_df['dense_vec'].tolist())
    relevance = vectors @ query_emb

    pool = list(range(len(candidates_df)))
    picked = []
    # For every row still in the pool: highest similarity to any picked row.
    redundancy = {}

    while pool and len(picked) < top_k:
        if not picked:
            choice = max(pool, key=lambda i: relevance[i])
        else:
            latest = vectors[picked[-1]]
            for i in pool:
                sim = float(vectors[i] @ latest)
                redundancy[i] = max(redundancy[i], sim) if i in redundancy else sim
            choice = max(
                pool,
                key=lambda i: lambda_mult * float(relevance[i]) - (1 - lambda_mult) * redundancy[i],
            )
        picked.append(choice)
        pool.remove(choice)

    return candidates_df.iloc[picked].reset_index(drop=True)


def hybrid_retrieve(question: str, top_k=12, faiss_k=120, bm25_k=120, alpha=0.6, use_query_expansion=True):
    """Retrieve passages for *question* by fusing dense (FAISS) and lexical (BM25) evidence.

    Pipeline:
    1. Optional multi-query expansion (``expand_query``).
    2. Dense and BM25 candidate pools; each chunk keeps its best score over
       all query variants.
    3. Min-max normalised fusion ``alpha * dense + (1 - alpha) * bm25``.
       A chunk missing from one pool gets that pool's lowest observed score.
    4. Cross-encoder rerank against the original question. The best
       ``max(3 * top_k, 24)`` rows are kept.
    5. MMR diversification down to *top_k* rows.

    Returns a copy of the selected ``chunks_df`` rows with added
    ``dense_score``, ``bm25_score``, ``hybrid_score`` and ``rerank_score``
    columns. If no candidate is found, returns ``chunks_df.head(top_k)``
    without score columns.
    """
    queries = expand_query(question, enable=use_query_expansion, n=3)

    # One batch: row 0 is the original question (used for MMR), rows 1.. the variants.
    query_embs = embed_texts([question, *queries], batch_size=len(queries) + 1)
    query_emb_main = query_embs[0]
    variant_embs = np.ascontiguousarray(query_embs[1:])

    # Dense pool. Clamp k to the index size: FAISS pads missing results with
    # id -1, which iloc/emb indexing would silently map to the LAST chunk.
    dense_scores: dict[int, float] = {}
    k_dense = min(int(faiss_k), int(faiss_index.ntotal))
    if k_dense > 0 and len(variant_embs):
        scores, idxs = faiss_index.search(variant_embs, k_dense)
        for s, idx in zip(scores.ravel().tolist(), idxs.ravel().tolist()):
            if idx < 0:
                continue
            dense_scores[idx] = max(dense_scores.get(idx, float('-inf')), float(s))

    # Lexical pool. Sorted descending, so once a score hits 0 no query term
    # matched the remaining docs; those would only add arbitrary candidates.
    bm25_scores: dict[int, float] = {}
    for q in queries:
        bm = bm25.get_scores(bm25_tokenize(q))
        for idx in np.argsort(bm)[::-1][:bm25_k]:
            score = float(bm[idx])
            if score <= 0.0:
                break
            key = int(idx)
            bm25_scores[key] = max(bm25_scores.get(key, float('-inf')), score)

    union = sorted(dense_scores.keys() | bm25_scores.keys())
    if not union:
        return chunks_df.head(top_k).copy()

    # Floors computed once (previously re-evaluated for every element).
    dense_floor = min(dense_scores.values()) if dense_scores else 0.0
    bm25_floor = min(bm25_scores.values()) if bm25_scores else 0.0
    d = minmax_norm([dense_scores.get(i, dense_floor) for i in union])
    b = minmax_norm([bm25_scores.get(i, bm25_floor) for i in union])
    h = alpha * d + (1 - alpha) * b

    cand = chunks_df.iloc[union].copy().reset_index(drop=True)
    cand['dense_score'] = d
    cand['bm25_score'] = b
    cand['hybrid_score'] = h
    cand['dense_vec'] = [emb[i] for i in union]

    # Rerank with the cross-encoder against the original question.
    pairs = [(question, t) for t in cand['text'].tolist()]
    cand['rerank_score'] = reranker.predict(pairs)
    cand = (
        cand.sort_values(['rerank_score', 'hybrid_score'], ascending=False)
        .head(max(top_k * 3, 24))
        .reset_index(drop=True)
    )

    # Diversify the reranked shortlist.
    cand = mmr_select(cand, query_emb_main, top_k=top_k, lambda_mult=0.75)

    return cand.drop(columns=['dense_vec'])


In [None]:
def build_context(passages_df: pd.DataFrame, max_chars_each: int = 1500) -> str:
    """Format passages as an LLM context string.

    Each passage becomes ``"[chunk_id] (pov | title)\\n<text>"``. The text is
    cut to its first *max_chars_each* characters and then stripped. Passages
    are separated by a blank line. Missing pov/title values render as
    ``None``.
    """
    def _render(row: dict) -> str:
        excerpt = row['text'][:max_chars_each].strip()
        return f"[{row['chunk_id']}] ({row['pov']} | {row['title']})\n{excerpt}"

    return '\n\n'.join(_render(row) for row in passages_df.to_dict('records'))


def answer_question(question: str, passages_df: pd.DataFrame) -> str:
    """Answer *question* with the chat LLM, grounded only in *passages_df*.

    The system prompt (Spanish) requires a short answer, evidence bullets and
    ``[chunk_id]`` citations. It also fixes the refusal text to use when the
    fragments are insufficient. The reply may be up to 520 new tokens.
    """
    instructions = (
        "Eres experto en 'Juego de Tronos' (Libro 1). "
        "Responde únicamente con hechos sustentados en fragmentos. "
        "Formato obligatorio: 1) respuesta breve y clara, 2) viñetas con evidencia textual, "
        "3) referencias [chunk_id] por afirmación. "
        "Si no hay evidencia suficiente: 'No encontrado en los fragmentos proporcionados'."
    )
    evidence = build_context(passages_df)
    user_turn = f'Pregunta: {question}\n\nFragmentos:\n{evidence}'
    conversation = [
        {'role': 'system', 'content': instructions},
        {'role': 'user', 'content': user_turn},
    ]
    return run_chat(conversation, max_new_tokens=520)


def verify_answer(question: str, answer: str, passages_df: pd.DataFrame, max_new_tokens: int = 800):
    """Ask the LLM to audit *answer* against the evidence and optionally rewrite it.

    Returns ``(final_answer, issues)``:

    - ``final_answer`` is the LLM's ``rewrite`` when it judged the answer
      unfaithful and supplied a non-empty rewrite. Otherwise it is the
      original *answer*.
    - ``issues`` is always a list of strings. It is
      ``['No se pudo verificar automáticamente']`` when no JSON verdict could
      be parsed.

    *max_new_tokens* must leave room for a complete rewrite. The verdict JSON
    embeds a rewritten answer, and answer_question may emit up to 520 tokens.
    The previous 260-token budget could truncate the JSON, which silently
    disabled the check.
    """
    def _truthy(value) -> bool:
        # Models sometimes emit "false"/"true" as strings; bool("false") is True.
        if isinstance(value, str):
            return value.strip().lower() in {'true', 'yes', 'sí', 'si', '1'}
        return bool(value)

    context = build_context(passages_df, max_chars_each=1200)
    messages = [
        {
            'role': 'system',
            'content': 'Evalúa fidelidad factual. Devuelve SOLO JSON: {"faithful": bool, "issues": [str], "rewrite": str}',
        },
        {
            'role': 'user',
            'content': f'Pregunta: {question}\n\nRespuesta: {answer}\n\nContexto:\n{context}',
        },
    ]
    raw = run_chat(messages, max_new_tokens=max_new_tokens)

    try:
        # Outermost {...} span; the model may wrap JSON in prose or fences.
        verdict = json.loads(raw[raw.find('{'):raw.rfind('}') + 1])
    except (ValueError, RecursionError):
        return answer, ['No se pudo verificar automáticamente']
    if not isinstance(verdict, dict):
        return answer, ['No se pudo verificar automáticamente']

    raw_issues = verdict.get('issues') or []
    if not isinstance(raw_issues, list):
        raw_issues = [raw_issues]
    issues = [str(item) for item in raw_issues]

    faithful = _truthy(verdict.get('faithful', False))
    rewrite = verdict.get('rewrite', answer)
    if not faithful and isinstance(rewrite, str) and rewrite.strip():
        return rewrite.strip(), issues
    return answer, issues


In [None]:
# Target structure shown (as JSON) to the LLM in plan_scene and in its repair call.
# Values are type/format hints for the model, not defaults; defaults come from
# _scene_fallback via normalize_scene.
scene_schema = {
    'style': 'string',
    'subject': 'string',
    'setting': 'string',
    'time_of_day': 'day|night|dawn|dusk|unknown',
    'weather': 'string',
    'mood': 'string',
    'characters': [{'name': 'string', 'appearance': 'string', 'clothing': 'string'}],
    'action': 'string',
    'camera': 'string',
    'palette': 'string',
    'important_objects': ['string'],
    'avoid': ['string'],
}
def _balanced_json_substring(raw: str) -> str | None:
    start = raw.find('{')
    if start < 0:
        return None
    depth, in_str, esc = 0, False, False
    for i in range(start, len(raw)):
        ch = raw[i]
        if in_str:
            if esc:
                esc = False
            elif ch == '\\':
                esc = True
            elif ch == '"':
                in_str = False
            continue
        if ch == '"':
            in_str = True
            continue
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return raw[start:i+1]
    return None
def extract_first_json(raw: str) -> dict | None:
    raw = (raw or '').strip()
    if not raw:
        return None
    for candidate in [raw, raw.replace('```json', '').replace('```JSON', '').replace('```', '').strip()]:
        try:
            obj = json.loads(candidate)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass
        sub = _balanced_json_substring(candidate)
        if sub:
            try:
                obj = json.loads(sub)
                if isinstance(obj, dict):
                    return obj
            except Exception:
                pass
    return None
def _scene_fallback(question: str, answer: str, passages_df: pd.DataFrame) -> dict:
    top = passages_df.iloc[0] if len(passages_df) else None
    loc = 'Westeros medieval fantasy setting'
    if top is not None:
        t = str(top.get('title') or '').strip()
        p = str(top.get('pov') or '').strip()
        if t or p:
            loc = f'{t} ({p})'
    return {
        'style': 'cinematic realism, epic medieval fantasy',
        'subject': question,
        'setting': loc,
        'time_of_day': 'unknown',
        'weather': 'moody atmosphere',
        'mood': 'dramatic',
        'characters': [],
        'action': re.sub(r'\s+', ' ', answer)[:220],
        'camera': 'dynamic medium shot, 35mm lens',
        'palette': 'cold desaturated tones',
        'important_objects': [],
        'avoid': ['tv actors', 'celebrity likeness', 'modern clothing', 'watermark', 'text'],
    }
def normalize_scene(scene: dict, question: str, answer: str, passages_df: pd.DataFrame) -> dict:
    base = _scene_fallback(question, answer, passages_df)
    if not isinstance(scene, dict):
        return base
    for k, v in base.items():
        if k not in scene or scene[k] in (None, ''):
            scene[k] = v
    for k in ['characters', 'important_objects', 'avoid']:
        if not isinstance(scene.get(k), list):
            scene[k] = base[k]
    return scene
def plan_scene(question: str, answer: str, passages_df: pd.DataFrame, debug=False) -> dict:
    """Ask the LLM to turn the answer and its evidence into a structured scene dict.

    One planning call is constrained to ``scene_schema``. If no JSON object
    can be extracted from it, a second "fixer" call converts the raw output
    into JSON. Whatever results (possibly None) goes through
    ``normalize_scene``, so the returned dict always has every schema key.

    With ``debug=True`` the first 1000 characters of both raw outputs are
    printed, but only when the repair path was taken.
    """
    schema_txt = json.dumps(scene_schema, ensure_ascii=False)
    evidence = build_context(passages_df, max_chars_each=1200)

    planner_system = (
        'You are a cinematic art director. Return ONLY one valid JSON object. '
        'Ground details in context. No TV actor names.'
    )
    planner_user = (
        f'Question: {question}\n\nAnswer: {answer}\n\nContext:\n{evidence}\n\n'
        f'Schema: {schema_txt}\n'
        'Output strictly JSON only.'
    )
    draft = run_chat(
        [
            {'role': 'system', 'content': planner_system},
            {'role': 'user', 'content': planner_user},
        ],
        max_new_tokens=420,
    )
    parsed = extract_first_json(draft)
    if parsed is not None:
        return normalize_scene(parsed, question, answer, passages_df)

    # Repair path: ask the model to re-emit its own output as strict JSON.
    repaired = run_chat(
        [
            {'role': 'system', 'content': 'Convert content into exactly one valid JSON object.'},
            {'role': 'user', 'content': f'Schema: {schema_txt}\n\nContent:\n{draft}'},
        ],
        max_new_tokens=260,
    )
    parsed = extract_first_json(repaired)
    if debug:
        print(draft[:1000])
        print(repaired[:1000])
    return normalize_scene(parsed, question, answer, passages_df)
def scene_to_prompt(scene: dict) -> tuple[str, str]:
    """Render a scene dict into ``(positive_prompt, negative_prompt)`` strings.

    The positive prompt is a fixed quality booster, followed in order by
    style, subject, action, setting, time, weather, mood, palette, up to 4
    characters, camera and props. Labelled fields whose value is empty are
    omitted rather than emitted as a bare ``"label:"``.

    The negative prompt is the scene's ``avoid`` terms followed by fixed
    quality/anachronism terms, de-duplicated in order and comma-joined.

    Values are coerced defensively. Lists are comma-joined and non-string
    scalars stringified. A single string is accepted where a list is
    expected, and characters may be dicts or bare strings. Malformed LLM
    output therefore cannot raise here.
    """
    def _txt(value) -> str:
        if value is None:
            return ''
        if isinstance(value, (list, tuple)):
            return ', '.join(t for t in (_txt(v) for v in value) if t)
        return str(value).strip()

    def _as_list(value) -> list:
        if value is None or value == '':
            return []
        return list(value) if isinstance(value, (list, tuple)) else [value]

    def _labelled(label: str, value) -> str:
        text = _txt(value)
        return f'{label}: {text}' if text else ''

    cinematic_booster = 'masterpiece, cinematic still, ultra detailed, volumetric light, film grain, sharp focus'

    chars = []
    for c in _as_list(scene.get('characters'))[:4]:
        fields = [c.get('name'), c.get('appearance'), c.get('clothing')] if isinstance(c, dict) else [c]
        desc = ', '.join(t for t in (_txt(f) for f in fields) if t)
        if desc:
            chars.append(desc)

    props = [t for t in (_txt(o) for o in _as_list(scene.get('important_objects'))) if t]

    parts = [
        cinematic_booster,
        _txt(scene.get('style')),
        _labelled('subject', scene.get('subject')),
        _txt(scene.get('action')),
        _labelled('setting', scene.get('setting')),
        _labelled('time', scene.get('time_of_day')),
        _labelled('weather', scene.get('weather')),
        _labelled('mood', scene.get('mood')),
        _labelled('palette', scene.get('palette')),
        f"characters: {'; '.join(chars)}" if chars else '',
        _labelled('camera', scene.get('camera')),
        f"props: {', '.join(props)}" if props else '',
    ]
    prompt = ', '.join(p for p in parts if p)

    avoid = [t for t in (_txt(a) for a in _as_list(scene.get('avoid'))) if t]
    negative = avoid + [
        'worst quality, low quality, blurry, jpeg artifacts',
        'text, watermark, logo, subtitles',
        'modern objects, modern clothes, cars, smartphones',
        'tv actors, celebrity face',
        'extra fingers, malformed hands, bad anatomy',
    ]
    return prompt, ', '.join(dict.fromkeys(negative))


In [None]:
# Text-to-image pipeline: SD 3.5 Large in bfloat16, moved to the GPU.
# NOTE(review): this cell, the CLIP model below and the chat LLM (device_map='auto')
# all live on the GPU at once — confirm the device has enough memory.
image_pipe = StableDiffusion3Pipeline.from_pretrained(
    SD3_MODEL_ID,
    torch_dtype=torch.bfloat16,
).to('cuda')
# Lower peak attention memory at some speed cost.
# NOTE(review): confirm attention slicing actually applies to the SD3 transformer
# in the installed diffusers version.
image_pipe.enable_attention_slicing()

# CLIP is used to pick the best image among several seeds (see generate_best_image).
from transformers import CLIPProcessor, CLIPModel
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to('cuda')
clip_proc = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)


def clip_text_image_score(prompt: str, image):
    """Return the cosine similarity between CLIP embeddings of *prompt* and *image*.

    A higher score means the image matches the text better.
    generate_best_image uses it to rank seeds.

    CLIP's text encoder has a fixed number of position embeddings (read from
    the model config), and the prompts built by scene_to_prompt can easily
    exceed it. The text is therefore truncated to that length; without
    truncation the forward pass fails on long prompts.
    """
    max_len = clip_model.config.text_config.max_position_embeddings
    inputs = clip_proc(
        text=[prompt],
        images=image,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_len,
    ).to(clip_model.device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
        text_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
        image_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
        score = (text_emb * image_emb).sum().item()
    return score


def generate_best_image(prompt: str, negative: str, seeds=(7, 13, 23), width=1024, height=1024, steps=32, guidance=6.5):
    """Render one image per seed and keep the one CLIP scores highest against *prompt*.

    Each seed gets its own CUDA ``torch.Generator`` so results are
    reproducible per seed.

    Returns ``(best_image, meta)``. ``meta`` holds ``best_seed``,
    ``best_clip_score`` and ``all_scores``, a list of ``(seed, score)``
    sorted best-first (stable for ties).
    """
    scored = []
    for seed in seeds:
        generator = torch.Generator(device='cuda').manual_seed(int(seed))
        result = image_pipe(
            prompt=prompt,
            negative_prompt=negative,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            generator=generator,
        )
        picture = result.images[0]
        scored.append((clip_text_image_score(prompt, picture), seed, picture))

    ranked = sorted(scored, key=lambda entry: entry[0], reverse=True)
    top_score, top_seed, top_image = ranked[0]
    meta = {
        'best_seed': top_seed,
        'best_clip_score': top_score,
        'all_scores': [(seed, float(score)) for score, seed, _ in ranked],
    }
    return top_image, meta


In [None]:
def ask_and_draw_insane(question: str, top_k=14, faiss_k=140, bm25_k=140, alpha=0.6, query_expansion=True, debug=False):
    """Run the full pipeline: retrieve, answer, verify, plan a scene, render an image.

    Image generation uses seeds (7, 17, 29, 43), 34 steps and guidance 6.5.
    Returns a dict with the question, raw and verified answers, faithfulness
    issues, retrieved passages, scene dict, positive/negative prompts, image
    metadata and the selected image.
    """
    retrieval_opts = dict(
        top_k=top_k,
        faiss_k=faiss_k,
        bm25_k=bm25_k,
        alpha=alpha,
        use_query_expansion=query_expansion,
    )
    passages = hybrid_retrieve(question, **retrieval_opts)

    draft_answer = answer_question(question, passages)
    verified_answer, issues = verify_answer(question, draft_answer, passages)

    scene = plan_scene(question, verified_answer, passages, debug=debug)
    prompt, negative = scene_to_prompt(scene)
    image, image_meta = generate_best_image(prompt, negative, seeds=(7, 17, 29, 43), steps=34, guidance=6.5)

    return dict(
        question=question,
        answer_raw=draft_answer,
        answer_final=verified_answer,
        faithfulness_issues=issues,
        passages=passages,
        scene=scene,
        prompt=prompt,
        negative_prompt=negative,
        image_meta=image_meta,
        image=image,
    )


In [None]:
# Demo: run the whole pipeline on one question. debug=True prints the raw planner
# output only if scene JSON extraction needed the repair call.
result = ask_and_draw_insane('¿Cómo escapó Tyrion del Nido de Águilas?', debug=True)
print('RESPUESTA FINAL:\n', result['answer_final'])
print('\nIssues fidelidad:', result['faithfulness_issues'])
print('\nPrompt final:\n', result['prompt'])
print('\nImage meta:', result['image_meta'])
# Last expression of the cell: the notebook displays the selected image as output.
result['image']
