In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
RAG pipeline robusto (solo URLs -> scraping -> embeddings Nomic -> FAISS -> LLM)
- Usa SOLO URLs en columna `source`.
- Scraping robusto + limpieza HTML.
- Embeddings por batch con reintentos (gestiona 502).
- FAISS IndexFlatIP sobre vectores L2-normalizados (cosine sim).
- Generación de definiciones por batch con un pipeline de transformers.
"""

import os
import re
import time
import csv
import json
import math
import requests
import pandas as pd
import numpy as np
from tqdm import tqdm
from bs4 import BeautifulSoup
import faiss
from transformers import pipeline
from nomic import embed
import torch
import gc
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

# ---------------- CONFIG ----------------
_HF_TOKEN_PLACEHOLDER = "hf_token"
# Hugging Face token: a token already exported in the environment takes
# precedence, so a real secret never has to be hard-coded here and is never
# overwritten by the placeholder.
HF_TOKEN = (
    os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    or os.environ.get("HF_TOKEN")
    or _HF_TOKEN_PLACEHOLDER
)
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN
if HF_TOKEN != _HF_TOKEN_PLACEHOLDER:
    # huggingface_hub / transformers read HF_TOKEN when downloading models.
    # The placeholder is not exported there: an invalid token can make even
    # public downloads fail.
    os.environ.setdefault("HF_TOKEN", HF_TOKEN)

MODEL_EMBEDDINGS = "nomic-embed-text-v1.5"
MODEL_LLM = "mistralai/Mistral-7B-v0.1"

CSV_INPUT = "/home/jovyan/lrec_2026/data_10/civil_law/glossary_civil_law_info_semantica_es.csv"
OUTPUT_CSV = "/home/jovyan/lrec_2026/RAG/celex_rag_definitions_civil_law_mistral.csv"

# Tunable according to available resources
SCRAPE_TIMEOUT = 12    # seconds per HTTP request
SCRAPE_RETRIES = 3     # urllib3 retries per URL
SCRAPE_BACKOFF = 1.5   # urllib3 backoff factor between retries

EMB_BATCH_SIZE = 8     # documents per Nomic embeddings request
EMB_RETRIES = 4        # retries for a failed embeddings batch
EMB_BACKOFF = 2.0      # base of the exponential wait between embedding retries

TOP_K = 3              # documents retrieved per term
LLM_BATCH_SIZE = 4     # terms generated between GPU cache cleanups
OUTPUT_LANG = "spanish"  # not referenced elsewhere in this script; the prompt itself is in Spanish

# ---------------- UTILIDADES ----------------
def make_requests_session(retries=3, backoff_factor=0.5, status_forcelist=(429, 500, 502, 503, 504)):
    """
    Build a requests.Session that retries failed requests and presents browser-like headers.

    Args:
        retries: maximum number of retries (used for total, connect and read).
        backoff_factor: urllib3 exponential backoff factor between retries.
        status_forcelist: HTTP status codes that trigger a retry.

    Returns:
        A requests.Session with one retrying HTTPAdapter mounted for both
        http:// and https://, and browser-like default headers.
    """
    retry_policy = Retry(
        total=retries,
        connect=retries,
        read=retries,
        status_forcelist=status_forcelist,
        backoff_factor=backoff_factor,
        allowed_methods=False,  # falsy value: retry regardless of the HTTP method
    )
    # Basic headers imitating a desktop browser.
    browser_headers = {
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "es-ES,es;q=0.9,en;q=0.8",
    }
    http = requests.Session()
    http.headers.update(browser_headers)
    adapter = HTTPAdapter(max_retries=retry_policy)
    for prefix in ('http://', 'https://'):
        http.mount(prefix, adapter)
    return http

# Shared HTTP session used by scrape_url, configured from the scraping settings above.
session = make_requests_session(
    backoff_factor=SCRAPE_BACKOFF,
    retries=SCRAPE_RETRIES,
)

def clean_text(text):
    """
    Collapse every run of whitespace (spaces, tabs, newlines, ...) into a single
    space and trim both ends.

    None yields ""; any other value is converted with str() first.
    """
    if text is None:
        return ""
    return " ".join(str(text).split())

# Matches http(s) URLs up to the first whitespace, quote or angle bracket.
_URL_PATTERN = re.compile(r'https?://[^\s"\'<>]+')
# Punctuation that commonly sticks to the end of a URL written in prose.
_URL_TRAILING_PUNCT = '.,;">'


def _strip_url_trailing_punct(url):
    """Drop trailing punctuation from *url*, keeping a ')' that closes a '(' inside it."""
    while url:
        last = url[-1]
        if last in _URL_TRAILING_PUNCT:
            url = url[:-1]
        elif last == ')' and url.count('(') < url.count(')'):
            # Unbalanced ')' closes a parenthesis around the URL, not part of it.
            url = url[:-1]
        else:
            break
    return url


def extract_urls_from_field(source_field):
    """
    Extract the http/https URLs contained in a `source` cell.

    Trailing '.', ',', ';', '"' and '>' are removed. A trailing ')' is removed
    only when it is unbalanced, so URLs such as
    ``https://es.wikipedia.org/wiki/Contrato_(derecho)`` stay intact.
    Matches with nothing after the scheme (e.g. "https://") are discarded.

    Returns:
        The URLs de-duplicated in order of first appearance, or [] for a
        missing value (None / NaN).
    """
    if pd.api.types.is_scalar(source_field) and pd.isna(source_field):
        return []
    urls = []
    for raw in _URL_PATTERN.findall(str(source_field)):
        url = _strip_url_trailing_punct(raw.strip())
        if url.split('://', 1)[1]:
            urls.append(url)
    return list(dict.fromkeys(urls))  # de-duplicate, keeping order

def is_pdf_response(resp):
    """Return True when the response's Content-Type header mentions "pdf" (case-insensitive)."""
    content_type = resp.headers.get('Content-Type', '')
    return content_type.lower().find('pdf') != -1

def get_main_text_from_html(html):
    """
    Extract the "main" readable text of an HTML page using simple heuristics.

    Order of preference:
    1. the first <article>, if its text is longer than 100 characters;
    2. the first <main>, under the same length condition;
    3. the <div>/<section>/<body> whose descendant <p> texts, joined with
       spaces, are the longest; this is accepted if it has at least 80 characters;
    4. otherwise, the text of the whole document.

    <script>, <style>, <noscript>, <header>, <footer>, <nav>, <svg>, <img> and
    <iframe> elements are removed before any text is collected.

    NOTE: step 3 counts *all* descendant <p> elements of a container. An
    enclosing container therefore never scores lower than the containers
    nested inside it.

    Returns the extracted text normalised with clean_text().
    """
    doc = BeautifulSoup(html, "html.parser")
    boilerplate = ["script", "style", "noscript", "header", "footer", "nav", "svg", "img", "iframe"]
    for node in doc(boilerplate):
        node.decompose()

    # Semantic containers first: <article>, then <main>.
    for semantic_tag in ("article", "main"):
        container = doc.find(semantic_tag)
        if not container:
            continue
        container_text = container.get_text(separator=' ', strip=True)
        if len(container_text) > 100:
            return clean_text(container_text)

    # Fallback: the container whose paragraphs add up to the most text.
    winner_text, winner_len = "", 0
    for container in doc.find_all(['div', 'section', 'body'], recursive=True):
        paragraphs = container.find_all('p')
        if not paragraphs:
            continue
        joined = " ".join(p.get_text(separator=' ', strip=True) for p in paragraphs)
        if len(joined) > winner_len:
            winner_text, winner_len = joined, len(joined)
    if winner_len >= 80:
        return clean_text(winner_text)

    # Last resort: all remaining text of the document.
    return clean_text(doc.get_text(separator=' ', strip=True))

def scrape_url(url, session=session, timeout=SCRAPE_TIMEOUT):
    """
    Download *url* and return its main text, or None when it cannot be used.

    With the module-level session, retries on connection errors and on
    429/5xx responses are handled by its urllib3 adapter.

    Returns None when:
    - the request raises;
    - the status is not 200;
    - the body is a PDF, detected by Content-Type or by the ``%PDF-`` magic
      bytes (PDFs are not parsed; download/OCR them separately if needed);
    - the extracted text has 80 characters or fewer.
    """
    try:
        resp = session.get(url, timeout=timeout)
    except Exception:
        # Best effort: any network/URL error just means "no text for this URL".
        return None
    if resp.status_code != 200:
        return None
    # Skip PDFs, including those served with a generic binary Content-Type.
    if is_pdf_response(resp) or resp.content[:5] == b'%PDF-':
        return None
    if 'charset=' in resp.headers.get('Content-Type', '').lower():
        html = resp.text
    else:
        # Without a declared charset, requests decodes text/* bodies as
        # ISO-8859-1, which garbles UTF-8 pages (accents, "ñ"). Passing the
        # raw bytes lets BeautifulSoup use the page's <meta charset> or sniff it.
        html = resp.content
    text = get_main_text_from_html(html)
    return text if text and len(text) > 80 else None

# ---------------- EMBEDDINGS (NOMIC) con retries por batch ----------------
def get_nomic_embeddings_batch(texts, model=MODEL_EMBEDDINGS, retries=EMB_RETRIES, backoff=EMB_BACKOFF,
                               task_type="search_document"):
    """
    Embed a batch of texts with the Nomic embedding API, retrying on failure.

    Args:
        texts: list of strings to embed.
        model: Nomic embedding model name.
        retries: extra attempts after the first failure, so at most
            ``retries + 1`` calls are made.
        backoff: base of the exponential wait; before retry k the function
            sleeps ``backoff ** k`` seconds. There is no sleep after the last
            attempt.
        task_type: Nomic task type. It defaults to "search_document" (texts to
            be indexed).

    Returns:
        np.ndarray of dtype float32, as built from ``res["embeddings"]``.

    Raises:
        RuntimeError: when every attempt fails. The last error is chained as
        ``__cause__`` and included in the message.
    """
    max_attempts = retries + 1
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            res = embed.text(texts=texts, model=model, task_type=task_type)
            return np.array(res["embeddings"], dtype="float32")
        except Exception as e:  # any failure (e.g. a transient HTTP 502) is retried
            last_error = e
            if attempt == max_attempts:
                break
            wait = backoff ** attempt
            print(f"⚠️ Embeddings error (attempt {attempt}/{max_attempts}). Retrying in {wait:.1f}s. Error: {e}")
            time.sleep(wait)
    raise RuntimeError(f"Failed to get embeddings after {retries} retries: {last_error}") from last_error

def l2_normalize(a, axis=1, eps=1e-10):
    """
    Scale *a* to unit Euclidean norm along *axis*.

    *eps* is added to every norm so all-zero vectors stay zero instead of
    dividing by zero.
    """
    denom = np.linalg.norm(a, axis=axis, keepdims=True) + eps
    return a / denom

# ---------------- CARGA CSV y extracción URLs ----------------
print("📥 Cargando CSV de entrada...")
df = pd.read_csv(
    CSV_INPUT,
    sep=';',
    encoding='utf-8',
    engine='python',
    on_bad_lines='skip',  # malformed rows are dropped without notice
)

# Keep only the rows whose 'source' column contains at least one http(s) URL.
df['urls_scrapables'] = df['source'].apply(extract_urls_from_field)
has_urls = df['urls_scrapables'].map(len) > 0
df_scrap = df.loc[has_urls].reset_index(drop=True)
print(f"✅ Filas con URLs scrapables: {len(df_scrap)}")

# ---------------- SCRAPING y colección de documentos (por término) ----------------
all_docs = []       # scraped text, one entry per term that yielded any text
terms_scrap = []    # term corresponding to each entry of all_docs
sources_scrap = []  # JSON-encoded list of the URLs actually used for each entry

# Upper bound on the text kept per term, so huge pages are not sent whole to
# the embeddings API.
MAX_CHARS = 100000
# Glossary terms often cite the same source page. Results (including
# failures, after the session's own retries) are memoised per URL, so each
# page is downloaded, and paused after, only once.
scrape_cache = {}

print("📄 Haciendo scraping por término (headers + retries)...")
for _, row in tqdm(df_scrap.iterrows(), total=len(df_scrap), desc="Scraping términos"):
    term = row['term']
    term_texts = []
    used_urls = []
    for url in row['urls_scrapables']:
        if url not in scrape_cache:
            try:
                scrape_cache[url] = scrape_url(url)
            except Exception:
                scrape_cache[url] = None
            # short pause between real requests to avoid overloading servers
            time.sleep(0.2)
        page_text = scrape_cache[url]
        if page_text:
            term_texts.append(page_text)
            used_urls.append(url)
    if not term_texts:
        continue
    # concatenate the texts of all usable URLs for this term, then truncate
    all_docs.append("\n\n".join(term_texts)[:MAX_CHARS])
    terms_scrap.append(term)
    sources_scrap.append(json.dumps(used_urls, ensure_ascii=False))
print(f"✅ Scraping completado. Documentos válidos obtenidos: {len(all_docs)}")

if len(all_docs) == 0:
    raise SystemExit("No se obtuvieron documentos scrapeados. Revisa las URLs en tu CSV.")

# ---------------- GENERAR EMBEDDINGS EN BATCHES y construir matriz completa ----------------
print("🔹 Generando embeddings en batches...")
embeddings_list = []
# Positions (in all_docs) of the documents that actually received an embedding.
embedded_positions = []
for start in tqdm(range(0, len(all_docs), EMB_BATCH_SIZE), desc="Embeddings batches"):
    batch_texts = all_docs[start:start + EMB_BATCH_SIZE]
    try:
        emb_batch = l2_normalize(get_nomic_embeddings_batch(batch_texts))
    except Exception as e:
        print(f"⚠️ Error en embeddings batch para índices {start}-{start+len(batch_texts)-1}: {e}")
        # continue with the batches that do work
        continue
    if emb_batch.shape[0] != len(batch_texts):
        print(f"⚠️ Batch {start}: {emb_batch.shape[0]} vectores para {len(batch_texts)} textos; se omite")
        continue
    embeddings_list.append(emb_batch)
    embedded_positions.extend(range(start, start + len(batch_texts)))
    # memory cleanup
    torch.cuda.empty_cache()
    gc.collect()

if len(embeddings_list) == 0:
    raise RuntimeError("No se pudieron generar embeddings para ningún documento.")

embeddings = np.vstack(embeddings_list).astype('float32')

# Retrieval maps FAISS row i back to all_docs[i] / sources_scrap[i]. If some
# batch failed, drop the un-embedded documents from those two lists so the
# mapping stays aligned. terms_scrap is deliberately left untouched, so every
# scraped term still gets a definition from whatever documents were indexed.
if len(embedded_positions) != len(all_docs):
    all_docs = [all_docs[p] for p in embedded_positions]
    sources_scrap = [sources_scrap[p] for p in embedded_positions]
print(f"✅ Embeddings generados. Shape: {embeddings.shape}")

# ---------------- FAISS (IndexFlatIP sobre vectores normalizados) ----------------
# Exact inner-product index. The stored vectors were L2-normalised when they
# were computed, so inner product equals cosine similarity.
n_vectors, dim = embeddings.shape
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
print("✅ Índice FAISS creado y poblado")

# ---------------- RAG: función de recuperación ----------------
def rag_retrieve_by_term(term, top_k=TOP_K):
    """
    Retrieve the scraped documents most similar to *term*.

    The query text "Término: {term}" is embedded, L2-normalised and searched
    in the global FAISS ``index``.

    Returns:
        (retrieved_docs, retrieved_sources): parallel lists holding, for each
        hit in FAISS result order, the text from ``all_docs`` and the
        JSON-encoded URL list from ``sources_scrap``. They hold fewer than
        ``top_k`` items when the index contains fewer vectors.
    """
    # NOTE(review): the query is embedded through get_nomic_embeddings_batch,
    # i.e. with the same Nomic task type ("search_document") as the indexed
    # documents. Nomic recommends "search_query" for queries; worth evaluating.
    query_text = f"Término: {term}"
    q_emb = l2_normalize(get_nomic_embeddings_batch([query_text]))
    k = min(top_k, index.ntotal)
    if k <= 0:
        return [], []
    _, ids = index.search(q_emb, k)
    retrieved_docs = []
    retrieved_sources = []
    for idx in ids[0]:
        # FAISS pads missing results with -1. Without the lower bound, -1
        # would silently select the last document.
        if 0 <= idx < len(all_docs):
            retrieved_docs.append(all_docs[idx])
            retrieved_sources.append(sources_scrap[idx])
    return retrieved_docs, retrieved_sources

# ---------------- CARGAR LLM (transformers pipeline) ----------------
print("🔹 Cargando modelo LLM (transformers pipeline)...")
# Greedy decoding (do_sample=False) keeps the generated definitions
# reproducible. `temperature` only takes effect when sampling; transformers
# ignores it (with a warning) otherwise. To sample, set do_sample=True and
# pass a temperature.
llm_pipeline = pipeline(
    "text-generation",
    model=MODEL_LLM,
    tokenizer=MODEL_LLM,
    device_map="auto",
    torch_dtype="auto",
    max_new_tokens=200,
    do_sample=False,
    repetition_penalty=1.05,
)
print("✅ LLM cargado")

# ---------------- PROMPT (Spanish) ----------------
# Prompt sent to the LLM for each term. Placeholders: {term} (the term to
# define) and {context} (the retrieved documents). It ends with "Definición:"
# so the model continues with the definition; the generation loop then splits
# the model output on that marker.
PROMPT_TEMPLATE = """Eres un experto en terminología.
Usando únicamente la información semántica proporcionada en el contexto, escribe una definición precisa y concisa en español para el término: "{term}".

Contexto (documentos recuperados):
{context}

Definición:
"""

# ---------------- GENERACIÓN DE DEFINICIONES EN BATCH ----------------
print("🧠 Generando definiciones (RAG + LLM) por batches...")
definitions = []
out_terms = []
out_sources = []

# Token budget for the retrieved context inserted into the prompt. Each
# retrieved document may hold up to 100k characters, so TOP_K of them would
# far exceed what a 7B model can usefully attend to and what fits in GPU
# memory. The budget is split evenly across the retrieved documents; tune it
# to the model and GPU.
MAX_CONTEXT_TOKENS = 3000


def _truncate_to_tokens(text, max_tokens, tokenizer):
    """Return *text* cut to at most *max_tokens* tokens of *tokenizer*."""
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(ids) <= max_tokens:
        return text
    return tokenizer.decode(ids[:max_tokens], skip_special_tokens=True)


def _build_context(docs, tokenizer, max_tokens=MAX_CONTEXT_TOKENS):
    """Join retrieved docs with separators, each truncated to an equal share of the token budget."""
    if docs:
        per_doc = max(1, max_tokens // len(docs))
        context = "\n\n---\n\n".join(_truncate_to_tokens(d, per_doc, tokenizer) for d in docs).strip()
        if context:
            return context
    return "No hay información semántica disponible."


def _extract_definition(generated, prompt):
    """Return the first line of the answer after the last 'Definición:'/'Definition:' marker."""
    if "Definición:" in generated:
        answer = generated.split("Definición:")[-1]
    elif "Definition:" in generated:
        answer = generated.split("Definition:")[-1]
    else:
        # fallback: whatever follows the prompt
        answer = generated.replace(prompt, "")
    return answer.strip().split("\n")[0].strip()


# Prompts are generated one at a time; LLM_BATCH_SIZE only sets how many terms
# are processed between GPU cache cleanups.
for i in tqdm(range(0, len(terms_scrap), LLM_BATCH_SIZE), desc="Generando batches LLM"):
    batch_terms = terms_scrap[i:i + LLM_BATCH_SIZE]
    for term in batch_terms:
        try:
            retrieved_docs, retrieved_sources = rag_retrieve_by_term(term, top_k=TOP_K)
            context = _build_context(retrieved_docs, llm_pipeline.tokenizer)
            prompt = PROMPT_TEMPLATE.format(term=term, context=context)
            out = llm_pipeline(prompt)[0]["generated_text"]
            definition = _extract_definition(out, prompt)
        except Exception as e:
            # Record the failure for this term and keep going with the rest.
            definition = f"ERROR: {e}"
            retrieved_sources = []
        definitions.append(definition)
        out_terms.append(term)
        out_sources.append(json.dumps(retrieved_sources, ensure_ascii=False))
    # free memory between batches
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    gc.collect()

# ---------------- GUARDAR CSV ----------------
df_out = pd.DataFrame({
    "term": out_terms,
    "definition": definitions,
    "retrieved_sources": out_sources,
})

# Make sure the destination folder exists: failing here would throw away the
# whole (long) generation run.
output_dir = os.path.dirname(OUTPUT_CSV)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
# Quote every field, since definitions may contain the ';' separator.
df_out.to_csv(OUTPUT_CSV, sep=';', index=False, encoding='utf-8', quoting=csv.QUOTE_ALL, escapechar='\\')
print(f"\n✅ Definiciones generadas y guardadas en: {OUTPUT_CSV}")
