# Generación del dataset para hacer inferencia con el Discriminador

Este notebook genera textos sintéticos al estilo de Shakespeare utilizando el modelo
Mistral 7B Instruct v0.3 ajustado mediante fine-tuning con LoRA.

Se generan múltiples textos combinando distintos prompts, con el objetivo de construir un dataset de benchmark que luego se utiliza para evaluación automática (clasificador, perplexity y métricas finales).

## 1. Imports y rutas.

In [None]:
import os
import torch
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Google Drive locations: base checkpoint, LoRA adapter and output folder.
BASE_MODEL_DIR = "/content/drive/MyDrive/StoryWriter/Modelo_FineTuning/mistral-7b-instruct-v0.3"
LORA_DIR = "/content/drive/MyDrive/StoryWriter/Modelo_FineTuning/mistral-finetuneado(lora)"

# The output folder is kept as a Path and created up front (parents included).
OUTPUT_DIR = Path("/content/drive/MyDrive/StoryWriter/Data/Benchmark_data/mistral_finetune")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Use the GPU whenever one is visible to torch; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Generation settings
N_SAMPLES_PER_COMBO = 20  # texts generated per (model, prompt) pair
MAX_NEW_TOKENS = 700      # upper bound on newly generated tokens per text
BASE_SEED = 1000          # sample i is seeded with BASE_SEED + i

## 2. Prompts de generación.

In [None]:
# Minimal instruction: only length, style and originality constraints.
BASIC_PROMPT = """
Write a single paragraph between 150 and 300 words in the style of
Shakespeare's stories. The paragraph must be original,
not copied, and self-contained.
"""

# Richer instruction: adds a persona plus explicit rhythm and vocabulary cues.
BETTER_PROMPT = """
You are an expert writer imitating William Shakespeare.

Write one single self-contained paragraph between 150 and 300 words in Early Modern English,
in the style of Shakespeare’s plays and sonnets. The paragraph must be original, not copied,
and should use iambic or quasi-iambic rhythm, archaic pronouns (thee, thou, thy), and
elevated metaphors.

Avoid copying any real Shakespeare sentences; the text must be entirely new.
"""

# Registry of prompts; each key is used as the prompt tag in output filenames.
PROMPTS = dict(basic=BASIC_PROMPT, better=BETTER_PROMPT)

## 3. Carga del tokenizer y modelos.

In [None]:
# Tokenizer comes from the base checkpoint; reuse EOS as padding when the
# tokenizer does not define a pad token of its own.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 512

# Base model: half precision with automatic device placement on GPU,
# full precision with no device map on CPU.
_on_gpu = device == "cuda"
_dtype = torch.float16 if _on_gpu else torch.float32
_device_map = "auto" if _on_gpu else None
model_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_DIR,
    torch_dtype=_dtype,
    device_map=_device_map,
)
model_base.eval()

# LoRA model: the adapter stored in LORA_DIR applied on top of the base model.
model_lora = PeftModel.from_pretrained(model_base, LORA_DIR)
model_lora.eval()

# Models to benchmark; each key is used as the model tag in output filenames.
MODELS = {"lora": model_lora}

## 4. Función de generación de texto.

In [None]:
def generate_text(
    model,
    prompt,
    max_new_tokens=700,
    seed=None,
    *,
    temperature=0.9,
    top_p=0.9,
    repetition_penalty=1.1,
    return_full_text=False,
):
    """Sample a continuation of ``prompt`` from ``model``.

    Relies on the module-level ``tokenizer`` for encoding and decoding.

    Args:
        model: Object exposing ``generate`` and ``device`` (the models stored
            in ``MODELS``).
        prompt: Instruction text fed to the model.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        seed: When not ``None``, seeds torch (CPU and, if available, CUDA)
            before sampling.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        repetition_penalty: Repetition penalty passed to ``generate``.
        return_full_text: When ``True``, return the decoded prompt followed by
            the continuation (the previous behaviour). When ``False`` (the
            default), return only the newly generated text.

    Returns:
        The decoded text with special tokens removed.
    """
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    sequence = output[0]
    if not return_full_text:
        # For causal LMs the returned sequence starts with the prompt tokens.
        # Dropping them by token count keeps the instruction out of the saved
        # samples without depending on a later string-prefix match.
        prompt_length = inputs["input_ids"].shape[1]
        sequence = sequence[prompt_length:]

    return tokenizer.decode(sequence, skip_special_tokens=True)

## 5. Generación y guardado de textos.

In [None]:
# One text file per (model, prompt, sample index) combination.
for model_name, model in MODELS.items():
    for prompt_name, prompt in PROMPTS.items():
        print(f"=== Modelo: {model_name} | Prompt: {prompt_name} ===")

        for i in range(N_SAMPLES_PER_COMBO):
            # The seed depends only on the sample index, so a given index
            # uses the same seed for every model and prompt.
            text = generate_text(
                model,
                prompt,
                max_new_tokens=MAX_NEW_TOKENS,
                seed=BASE_SEED + i,
            )

            # File name pattern: <model>_<prompt>_<two-digit index>.txt
            filename = "{}_{}_{:02d}.txt".format(model_name, prompt_name, i)
            (OUTPUT_DIR / filename).write_text(text, encoding="utf-8")

            print("Guardado:", filename)

## 6. Limpieza del prompt (en caso de que el texto generado lo incluya al inicio).

In [None]:
def remove_prompt_prefix(text: str, prompt: str) -> str:
    """Return ``text`` without a leading copy of ``prompt``, stripped.

    An exact prefix match is tried first. If that fails, the comparison is
    repeated ignoring whitespace around the prompt and at the start of the
    text, because a decoded generation may not reproduce the prompt's
    leading/trailing newlines exactly.

    Args:
        text: Text that may begin with the prompt.
        prompt: Prompt to remove from the start of ``text``.

    Returns:
        The text after the prompt with surrounding whitespace removed, or
        ``text.strip()`` when the prompt is not a prefix.
    """
    if text.startswith(prompt):
        return text[len(prompt):].strip()

    # Whitespace-tolerant fallback. The emptiness guard keeps an all-blank
    # prompt from matching every text.
    bare_prompt = prompt.strip()
    body = text.lstrip()
    if bare_prompt and body.startswith(bare_prompt):
        return body[len(bare_prompt):].strip()

    return text.strip()


# Rewrite every generated .txt file in place with any leading prompt removed.
for txt_path in OUTPUT_DIR.glob("*.txt"):
    content = txt_path.read_text(encoding="utf-8")

    # Try both known prompts, in the same order as before.
    for known_prompt in (BASIC_PROMPT, BETTER_PROMPT):
        content = remove_prompt_prefix(content, known_prompt)

    txt_path.write_text(content, encoding="utf-8")