# Small Prompt Fine-Tuning

A concise Hugging Face `Seq2SeqTrainer` setup for fine-tuning on either the short or the very-short prompt dataset. Adjust the parameters in the cells below before launching training (ideally on a GPU-enabled runtime such as Google Colab).

In [None]:
from __future__ import annotations

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from pathlib import Path
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq,
                          Seq2SeqTrainer, Seq2SeqTrainingArguments)

# Parent of the current working directory, resolved to an absolute path.
PROJECT_ROOT = Path('..').resolve()

# Train/test CSV locations for the short ("dsp") and very-short ("dvsp") sets.
SHORT_TRAIN = PROJECT_ROOT.joinpath('src', 'training_data', 'dsp-train.csv')
SHORT_TEST = PROJECT_ROOT.joinpath('src', 'training_data', 'dsp-test.csv')
VERY_SHORT_TRAIN = PROJECT_ROOT.joinpath('src', 'training_data', 'dvsp-train.csv')
VERY_SHORT_TEST = PROJECT_ROOT.joinpath('src', 'training_data', 'dvsp-test.csv')

# Checkpoint to fine-tune and the token limits used when truncating inputs/targets.
BASE_MODEL = 'Falconsai/text_summarization'
MAX_SOURCE_LENGTH = 512
MAX_TARGET_LENGTH = 128

# Seed NumPy and PyTorch (CPU, plus every CUDA device when one is present).
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)


def load_datasets(use_very_short: bool = False):
    """Load one train/validation CSV pair through ``load_dataset('csv', ...)``.

    Args:
        use_very_short: When true, read the very-short prompt CSVs; otherwise
            read the short prompt CSVs.

    Returns:
        The object returned by ``load_dataset`` for a ``data_files`` mapping
        with ``'train'`` and ``'validation'`` entries.
    """
    train_file, eval_file = (
        (VERY_SHORT_TRAIN, VERY_SHORT_TEST) if use_very_short else (SHORT_TRAIN, SHORT_TEST)
    )
    return load_dataset(
        'csv',
        data_files={'train': str(train_file), 'validation': str(eval_file)},
    )


def prepare_tokenizer():
    """Return the tokenizer for ``BASE_MODEL``, requesting the fast implementation."""
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    return tokenizer


def preprocess_factory(tokenizer):
    """Build a batch-level preprocessing function bound to ``tokenizer``.

    The returned callable tokenizes the ``'original'`` column as model input
    (truncated to ``MAX_SOURCE_LENGTH``) and the ``'compressed_prompt'`` column
    as labels (truncated to ``MAX_TARGET_LENGTH``).
    """

    def preprocess(batch):
        encoded = tokenizer(batch['original'], max_length=MAX_SOURCE_LENGTH, truncation=True)
        target_ids = tokenizer(
            batch['compressed_prompt'], max_length=MAX_TARGET_LENGTH, truncation=True
        )['input_ids']
        # The trainer reads the target token ids from the 'labels' field.
        encoded['labels'] = target_ids
        return encoded

    return preprocess

def build_trainer(dataset: DatasetDict, tokenizer):
    """Assemble a ``Seq2SeqTrainer`` for ``BASE_MODEL`` over a tokenized dataset.

    Args:
        dataset: Must expose ``'train'`` and ``'validation'`` splits.
        tokenizer: Handed to the data collator and to the trainer as its
            processing class.

    Returns:
        The configured trainer. Training is not started here.
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
    collator = DataCollatorForSeq2Seq(tokenizer, model=seq2seq_model)

    args = Seq2SeqTrainingArguments(
        output_dir='results',
        # Evaluate, checkpoint and log once per epoch.
        eval_strategy='epoch',
        save_strategy='epoch',
        logging_strategy='epoch',
        save_total_limit=2,
        # Optimisation settings.
        learning_rate=3e-5,
        weight_decay=0.01,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        # Generate during evaluation, capped at the target length.
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LENGTH,
        # Reload the checkpoint with the lowest eval loss when training ends.
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',
        greater_is_better=False,
    )

    return Seq2SeqTrainer(
        model=seq2seq_model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],  # type: ignore
        processing_class=tokenizer,  # ok with recent transformers; alternative: tokenizer=tokenizer
        data_collator=collator,
    )

In [None]:
# Choose which CSV pair to load, then print the loaded object so the splits
# and columns can be checked before tokenizing.
use_very_short = False  # Set True to train on the ≤64 token subset
datasets = load_datasets(use_very_short=use_very_short)
print(datasets)

In [None]:
# Tokenize every split, keep only the fields the model consumes, and build the trainer.
tokenizer = prepare_tokenizer()
processed = datasets.map(preprocess_factory(tokenizer), batched=True)
keep_columns = ['input_ids', 'attention_mask', 'labels']
if processed is None:
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel
    # down without saying what went wrong.
    raise RuntimeError('datasets.map() returned None; cannot build the trainer.')
# Drop the raw text columns (and anything else not listed in keep_columns).
processed = processed.remove_columns(
    [c for c in processed['train'].column_names if c not in keep_columns]
)
trainer = build_trainer(processed, tokenizer)
print('Trainer ready. Uncomment trainer.train() when running on GPU.')

In [None]:
# trainer.train()
# trainer.save_model('results/final')

In [None]:
# This is the actual working code direct from last run in Colab
# Above was an attempt to clean it up, but to be safe retaining original here

# %%
# 🔧 Fine-tuning for Prompt Compression (no control tokens) — length-aware + no-worse-than-original (+safe retry)
#     + hard filter to keep inputs ≤512 tokens (and optionally targets ≤ max_target_length)

# 0) Repro + device helpers
import os, random, numpy as np, torch, datetime as dt, pandas as pd, matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Seed all RNGs before any data shuffling or model initialisation happens.
set_seed(42)

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bf16 is enabled only when the CUDA device reports a compute-capability major version >= 8.
supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8  # Ampere+

# 1) Load data
# The "original" and "compression" columns of this CSV are read further down.
csv_path = 'sample_data/dolly-summarization-data-rouge.csv'
df = pd.read_csv(csv_path)

# --- Tokenizer for length checks & base checkpoint ---
from transformers import AutoTokenizer
base_ckpt = "Falconsai/text_summarization"
# Temporary tokenizer used only for counting tokens during filtering.
_tmp_tok = AutoTokenizer.from_pretrained(base_ckpt)

# --- Helper: token count WITHOUT truncation (to avoid undercounting) ---
def tok_len_no_trunc(text):
    """Return the number of token ids ``_tmp_tok`` produces for ``str(text)``.

    No truncation is requested, and special tokens are included
    (``add_special_tokens=True``) so lengths match model inputs.
    """
    encoding = _tmp_tok(str(text), add_special_tokens=True)
    return len(encoding.input_ids)

# === HARD FILTER: keep only rows within model limits ===
max_input_length  = 512
max_target_length = 256
# Untruncated token counts for every source and target.
df["src_len"]  = df["original"].astype(str).apply(tok_len_no_trunc)
df["tgt_len"]  = df["compression"].astype(str).apply(tok_len_no_trunc)

before_n = len(df)
# If you want to ignore target length filtering, set filter_targets=False
filter_targets = False
if filter_targets:
    df = df[(df["src_len"] <= max_input_length) & (df["tgt_len"] <= max_target_length)].copy()
else:
    df = df[df["src_len"] <= max_input_length].copy()

after_n = len(df)
print(f"✅ Kept {after_n}/{before_n} rows (dropped {before_n - after_n} that exceeded limits).")

# Split BEFORE oversampling. Duplicating short rows first and splitting afterwards
# puts copies of the same example in both train and test, which leaks test data
# into training and inflates the evaluation metrics.
base_train_df, test_df = train_test_split(df, test_size=0.15, random_state=42)

# --- Oversample short inputs (focus on failure mode), training split only ---
SHORT_THRESH = 40  # tokens; tune as needed
short_df = base_train_df[base_train_df["src_len"] <= SHORT_THRESH]
# Two extra copies of each short training row, so short rows appear 3x in total.
oversampled_df = pd.concat([base_train_df, short_df, short_df], ignore_index=True)

# We no longer need the src_len/tgt_len helper cols downstream
oversampled_df = oversampled_df.drop(columns=["src_len", "tgt_len"])
test_df = test_df.drop(columns=["src_len", "tgt_len"])

# Shuffle so the appended duplicates are not clustered at the end of the training frame.
train_df = oversampled_df.sample(frac=1.0, random_state=42).reset_index(drop=True)

# 2) Build HF datasets
from datasets import Dataset, DatasetDict
# Convert both pandas frames to HF datasets; preserve_index=False keeps the
# pandas index from being added as a column.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
test_dataset  = Dataset.from_pandas(test_df,  preserve_index=False)
# The held-out split is stored under the key 'test' and is used as the eval set below.
dataset = DatasetDict({'train': train_dataset, 'test': test_dataset})

# 3) Tokenizer & model
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    EarlyStoppingCallback, TrainerCallback
)
# Load the tokenizer and seq2seq model from the base checkpoint; the model is moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(base_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(base_ckpt).to(device)

# Generation defaults (used by Trainer unless we override per-batch)
# These assignments mutate the model's own generation_config object in place.
gen_conf = model.generation_config
gen_conf.num_beams = 4
gen_conf.no_repeat_ngram_size = 3
gen_conf.length_penalty = 1.0  # neutral by default (avoid brevity bias on short inputs)

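# 4) Preprocess/tokenize (inputs were already filtered to ≤512 tokens; targets are only length-filtered when filter_targets=True, otherwise they are truncated to max_target_length here)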
def preprocess_function(examples):
    """Tokenize a batch of ``original``/``compression`` pairs.

    Returns a dict holding the source encoding's fields plus:
      * ``labels``  - target token ids (truncated to ``max_target_length``)
      * ``src_len`` - number of source token ids per example, after truncation
        to ``max_input_length``
    """
    sources = list(map(str, examples["original"]))
    references = list(map(str, examples["compression"]))
    encoded_src = tokenizer(sources, max_length=max_input_length, truncation=True)
    encoded_tgt = tokenizer(text_target=references, max_length=max_target_length, truncation=True)
    return {
        **encoded_src,
        "labels": encoded_tgt["input_ids"],
        "src_len": [len(ids) for ids in encoded_src["input_ids"]],
    }

# Tokenize both splits. remove_columns lists every column except the raw
# "original" and "compression" text, so those two stay alongside the new fields.
tokenized_datasets = dataset.map(
    preprocess_function, batched=True,
    remove_columns=[c for c in dataset["train"].column_names if c not in ("original","compression")]
)

# 5) Metrics (ROUGE + compression ratio diagnostics)
import evaluate
rouge = evaluate.load("rouge")

# Source lengths of the eval split, read positionally by compute_metrics; that
# pairing is only valid if predictions arrive in this split's row order.
eval_src_lens = tokenized_datasets["test"]["src_len"]

def summarize_compression(decoded_preds, src_lens):
    """Summarise how long the predictions are relative to their sources.

    Args:
        decoded_preds: Predicted texts; each is re-tokenized (with special
            tokens) to get its length.
        src_lens: Source token counts, paired with ``decoded_preds`` by position.

    Returns:
        A dict with the mean and 90th percentile of ``pred_len / max(1, src_len)``
        and ``pct_violations``, the fraction (0-1, not a percentage) of
        predictions longer than their source.
    """
    ratios, violations = [], []
    for pred_text, src_len in zip(decoded_preds, src_lens):
        pred_len = len(tokenizer(pred_text, add_special_tokens=True).input_ids)
        ratios.append(pred_len / max(1, src_len))  # max() guards against a zero source length
        violations.append(1 if pred_len > src_len else 0)
    return {
        "comp_ratio_mean": float(np.mean(ratios)),
        "comp_ratio_p90":  float(np.percentile(ratios, 90)),
        "pct_violations":  float(np.mean(violations))
    }

import numpy as np

def compute_metrics(eval_pred):
    """Compute ROUGE and compression diagnostics for one evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. If ``predictions`` is a
            tuple, only its first element is used.

    Returns:
        ROUGE scores (the mid F-measure when a score exposes ``.mid``) merged
        with the output of ``summarize_compression``. Source lengths come from
        the module-level ``eval_src_lens`` and are paired by position.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id

    def as_decodable_ids(values):
        # Cast to an integer array and replace negative ids with the pad id
        # before decoding.
        ids = np.asarray(values)
        if not np.issubdtype(ids.dtype, np.integer):
            ids = ids.astype(np.int64, copy=False)
        return np.where(ids < 0, pad_id, ids)

    predictions = as_decodable_ids(predictions)
    labels = as_decodable_ids(labels)

    pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=pred_texts, references=label_texts, use_stemmer=True)
    metrics = {}
    for name, score in rouge_scores.items():
        metrics[name] = score.mid.fmeasure if hasattr(score, "mid") else score
    metrics.update(summarize_compression(pred_texts, eval_src_lens))
    return metrics

# === Length-aware generation + post-filter + safe retry ===

def length_aware_gen_kwargs(input_ids, short_tok=14, ratio_long=0.75, hard_cap=max_target_length):
    """Pick generation kwargs from the source length ``input_ids.shape[1]``.

    Short inputs (``<= short_tok`` tokens) may produce up to as many new tokens
    as the source has, with 6 beams and a neutral length penalty.
    Longer inputs target ``ratio_long`` of the source length (at least 8) plus
    2 tokens of slack, clamped to the source length, with 4 beams and a mild
    preference for shorter output. Both cases are also clamped to ``hard_cap``.

    NOTE(review): for a padded batch, ``shape[1]`` is presumably the padded
    width rather than each row's true length - confirm against callers.
    """
    src_len = int(input_ids.shape[1])

    if src_len <= short_tok:
        # Equal-length ceiling: no compression pressure on very short inputs.
        cap = min(hard_cap, src_len)
        beams, penalty = 6, 1.0
    else:
        target = int(max(8, ratio_long * src_len) + 2)
        cap = min(hard_cap, target, src_len)  # never exceed source
        beams, penalty = 4, 0.9

    return {
        "num_beams": beams,
        "no_repeat_ngram_size": 3,
        "length_penalty": penalty,
        "max_new_tokens": cap,
        "repetition_penalty": 1.03,
    }

def compress_postfilter(src_text, pred_text, tok, allow_equal=True):
    """Return ``pred_text`` unless it is no better than ``src_text``.

    Token counts come from ``tok(text, add_special_tokens=True).input_ids``.
    The source is returned when the prediction has more tokens, or when
    ``allow_equal`` is true, the counts match, and the two texts are identical
    after stripping whitespace and trailing ``?.!,:;`` characters.
    """
    def n_tokens(text):
        return len(tok(text, add_special_tokens=True).input_ids)

    trailing = '?.!,:;'
    src_len, pred_len = n_tokens(src_text), n_tokens(pred_text)

    if pred_len > src_len:
        return src_text

    same_length = pred_len == src_len
    if allow_equal and same_length:
        if pred_text.strip().rstrip(trailing) == src_text.strip().rstrip(trailing):
            return src_text
    return pred_text

@torch.inference_mode()
def safe_generate_with_retry(model, enc, gkw, eos_id=None):
    """Generate once, then retry with a little more room if any output looks truncated.

    A row counts as truncated when it is at least ``max_new_tokens`` long and
    contains no EOS token. Every row of the batch is checked (previously only
    the first was), and ``eos_token_id`` may be a single id or a list of ids.

    Args:
        model: Object exposing ``generate(**kwargs)`` and ``config``.
        enc: Mapping of encoder inputs; must contain ``"input_ids"``.
        gkw: Generation kwargs; ``max_new_tokens`` is read as the cap.
        eos_id: EOS id (or ids) to look for. Falls back to
            ``model.config.eos_token_id`` when omitted.

    Returns:
        The output of the first ``generate`` call, or of the retry when one
        was made. The retry adds at most 4 tokens and never raises the cap
        above the source length ``enc["input_ids"].shape[1]``.
    """
    out = model.generate(**enc, **gkw)

    cap = gkw.get("max_new_tokens", None)
    eos = eos_id if eos_id is not None else getattr(model.config, "eos_token_id", None)
    if cap is None or eos is None:
        return out

    eos_ids = set(eos) if isinstance(eos, (list, tuple, set)) else {eos}
    # generate() may return a plain tensor or an output object carrying `.sequences`.
    sequences = getattr(out, "sequences", out)
    truncated = any(
        len(row) >= cap and eos_ids.isdisjoint(row)
        for row in sequences.tolist()
    )
    if not truncated:
        return out

    src_len = int(enc["input_ids"].shape[1])
    wiggle = min(4, max(0, src_len - cap))  # add up to +4 tokens but never exceed source length
    if wiggle > 0:
        out = model.generate(**enc, **{**gkw, "max_new_tokens": cap + wiggle})
    return out

# 6) Baseline eval (base model) with length-aware gen + post-filter + retry
@torch.inference_mode()
def baseline_eval(texts, refs, batch_size=8):
    """Evaluate the current ``model`` on ``texts`` with length-aware generation.

    Each batch is tokenized with padding, generated via
    ``safe_generate_with_retry``, decoded, and passed through
    ``compress_postfilter`` so no prediction is longer than its source.

    Args:
        texts: Source strings.
        refs: Reference compressions, aligned with ``texts``.
        batch_size: Number of texts per generation call.

    Returns:
        ``(scores, preds)``: ROUGE scores merged with the compression summary,
        and the list of post-filtered predictions.
    """
    preds, src_lens = [], []
    eos_id = getattr(model.config, "eos_token_id", None)
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = tokenizer(batch, max_length=max_input_length, truncation=True, padding=True, return_tensors="pt").to(device)
        # padding=True pads every row to the batch's longest sequence, so the
        # row length is the padded width. Count attention-mask ones to get each
        # row's real token count; otherwise ratios and violations are understated.
        src_lens.extend(enc["attention_mask"].sum(dim=1).tolist())
        gkw = length_aware_gen_kwargs(enc["input_ids"])
        out = safe_generate_with_retry(model, enc, gkw, eos_id=eos_id)
        decoded = tokenizer.batch_decode(out, skip_special_tokens=True)
        preds.extend(compress_postfilter(s, p, tokenizer, allow_equal=True) for s, p in zip(batch, decoded))
    scores = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    scores = {k: (v.mid.fmeasure if hasattr(v, "mid") else v) for k, v in scores.items()}
    scores.update(summarize_compression(preds, src_lens))
    return scores, preds

# Score the model on the held-out split before any fine-tuning, to get a
# reference point for the fine-tuned metrics.
test_texts = test_df["original"].astype(str).tolist()
test_refs  = test_df["compression"].astype(str).tolist()
baseline_scores, baseline_preds = baseline_eval(test_texts, test_refs)
print("📊 Baseline (no FT) — ROUGE & compression:", baseline_scores)

# 7) Training args
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    # --- Output and Hub ---
    output_dir="./results",
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="dotslashderek/small-prompt-compression",
    report_to="none",
    # --- Evaluate, checkpoint and log once per epoch ---
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    # --- Best-checkpoint selection: highest ROUGE-L ---
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,
    # --- Generation during Trainer evaluation ---
    # The sample callback and baseline_eval pass their own max_new_tokens instead.
    predict_with_generate=True,
    generation_max_length=max_target_length,
    generation_num_beams=gen_conf.num_beams,
    # --- Optimisation ---
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,  # the optional stage-2 block further down would add one more
    weight_decay=0.01,
    label_smoothing_factor=0.1,
    warmup_ratio=0.1,
    # --- Memory and precision: bf16 when supported, otherwise fp16 on CUDA ---
    gradient_checkpointing=True,
    fp16=(torch.cuda.is_available() and not supports_bf16),
    bf16=supports_bf16,
)

# 8) Data collator
# Collator built from the tokenizer and the model being trained; passed to the trainer below.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# 9) Rolling sample-predictions callback (length-aware + post-filter + retry)
from transformers import TrainerCallback
import datetime as dt
import string

class RollingSamplePredictionCallback(TrainerCallback):
    """After each evaluation, generate a few sample predictions and log them.

    Samples are drawn at random from ``raw_eval_dataset``, generated with the
    length-aware settings, post-filtered, printed, and appended to
    ``<output_dir>/samples_all.txt``.
    """

    def __init__(self, tokenizer, raw_eval_dataset, num_samples=3, max_len=128, output_dir="./results"):
        """Store the settings and create ``output_dir`` if needed.

        Args:
            tokenizer: Tokenizer used for encoding inputs and decoding outputs.
            raw_eval_dataset: Indexable dataset whose rows have "original" and
                "compression" fields.
            num_samples: Maximum number of examples to show per evaluation.
            max_len: Hard cap passed to ``length_aware_gen_kwargs``.
            output_dir: Directory that receives ``samples_all.txt``.
        """
        self.tokenizer = tokenizer
        self.raw_eval_dataset = raw_eval_dataset  # has "original"/"compression"
        self.num_samples = num_samples
        self.max_len = max_len
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.roll_path = os.path.join(self.output_dir, "samples_all.txt")

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """Generate, print and append sample predictions; does nothing when ``model`` is None."""
        if model is None:
            return
        model.eval()
        epoch = int(state.epoch or 0)
        stamp = dt.datetime.now().isoformat(timespec="seconds")
        header = f"\n\n📘 Epoch {epoch} — {stamp}\n" + ("-" * 100) + "\n"
        print(header)

        dataset_size = len(self.raw_eval_dataset)
        idxs = random.sample(range(dataset_size), k=min(self.num_samples, dataset_size))
        eos_id = getattr(model.config, "eos_token_id", None)

        with open(self.roll_path, "a", encoding="utf-8") as f:
            f.write(header)
            for i, idx in enumerate(idxs, start=1):
                ex = self.raw_eval_dataset[int(idx)]
                inp = str(ex["original"])
                ref = str(ex["compression"])
                # Drop one trailing punctuation mark from the reference. The
                # truthiness check avoids an IndexError on an empty reference.
                if ref and ref[-1] in string.punctuation:
                    ref = ref[:-1]
                enc = self.tokenizer(inp, return_tensors="pt", truncation=True, max_length=512).to(model.device)
                gkw = length_aware_gen_kwargs(enc["input_ids"], hard_cap=self.max_len)
                with torch.no_grad():
                    out = safe_generate_with_retry(model, enc, gkw, eos_id=eos_id)
                pred = self.tokenizer.decode(out[0], skip_special_tokens=True)
                pred = compress_postfilter(inp, pred, self.tokenizer, allow_equal=True)
                entry = (
                    f"🟢 Sample {i}\n"
                    f"Input({enc['input_ids'].shape[1]} tok): {inp[:500]}...\n"
                    f"Pred ({len(self.tokenizer(pred, add_special_tokens=True).input_ids)} tok): {pred[:500]}...\n"
                    f"Ref  ({len(self.tokenizer(ref, add_special_tokens=True).input_ids)} tok): {ref[:500]}...\n"
                    + ("-" * 100) + "\n"
                )
                print(entry)
                f.write(entry)
        print(f"✅ Appended to {self.roll_path}")

# 10) Trainer
from transformers import Seq2SeqTrainer, EarlyStoppingCallback
# Wire model, arguments, tokenized splits, collator and metrics together.
# Evaluation runs on the "test" split.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,  # ok with recent transformers; alternative: tokenizer=tokenizer
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        # Early stopping is configured with a patience of 2.
        EarlyStoppingCallback(early_stopping_patience=2),
        # The sample callback gets the untokenized test split so it can read the raw text fields.
        RollingSamplePredictionCallback(tokenizer, dataset["test"], num_samples=3, max_len=256, output_dir="./results"),
    ],
)

# 11) Train (stage 1)
trainer.train()

# 12) Small LR drop + 1 extra epoch (stage 2) - currently disabled; kept for reference
#for g in trainer.optimizer.param_groups:
#    g["lr"] = 2e-5  # lower for stabilization

#trainer.args.generation_num_beams = 4
#trainer.args.generation_max_length = 160
#trainer.args.repetition_penalty = 1.1
#trainer.args.num_train_epochs += 1
#trainer.train(resume_from_checkpoint=True)

# 13) Push
# Uploads to the Hub repository configured through hub_model_id in training_args.
trainer.push_to_hub()

# 14) Plots (loss & ROUGE + compression)
# Dump the trainer's full log history to CSV, then split it into training rows
# (those with a "loss" value) and evaluation rows (those with an "eval_loss" value).
logs = pd.DataFrame(trainer.state.log_history)
logs.to_csv("./results/trainer_log_history.csv", index=False)

train_logs = logs[logs["loss"].notna()][["step", "loss"]].reset_index(drop=True)
eval_logs  = logs[logs["eval_loss"].notna()].reset_index(drop=True)

# Figure 1: training loss per step.
plt.figure()
plt.plot(train_logs["step"], train_logs["loss"])
plt.title("Training Loss vs Step")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

# Figure 2: validation loss per epoch.
plt.figure()
plt.plot(eval_logs["epoch"], eval_logs["eval_loss"], marker="o")
plt.title("Validation Loss vs Epoch")
plt.xlabel("Epoch")
plt.ylabel("Eval Loss")
plt.grid(True)
plt.show()

# Figure 3: whichever of the ROUGE / compression metrics were actually logged.
plt.figure()
for k, label in [
    ("eval_rougeL", "ROUGE-L"),
    ("eval_rouge1", "ROUGE-1"),
    ("eval_rouge2", "ROUGE-2"),
    ("eval_comp_ratio_mean", "CompRatio-mean"),
    ("eval_pct_violations", "%Violations"),
]:
    if k in eval_logs:
        plt.plot(eval_logs["epoch"], eval_logs[k], marker="o", label=label)
plt.title("Validation Metrics vs Epoch")
plt.xlabel("Epoch")
plt.legend()
plt.grid(True)
plt.show()

print("Saved raw trainer logs to ./results/trainer_log_history.csv")