In [14]:
# ── cell 1 ── Setup & imports
import os
import yaml
import torch
import pandas as pd
from transformers import T5ForConditionalGeneration, T5Tokenizer
from data_processing import load_dataset_from_disk

pd.set_option('display.max_columns', None)
# show all rows (be careful with very large tables!)
pd.set_option('display.max_rows', None)
# don’t truncate column contents
pd.set_option('display.max_colwidth', None)
# allow the display to use the full browser width
pd.set_option('display.width', None)

# ── cell 2 ── Configuration
# Path to your config and model output directory
cfg = yaml.safe_load(open("main_config.yml", "r"))
output_dir = cfg["sft_params"]["output_dir"]  # e.g. "./outputs/SFT"
output_dir = output_dir + "/checkpoint-6800"
# Load model & tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = T5ForConditionalGeneration.from_pretrained(output_dir).to(device)
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-base")

# ── cell 3 ── Load & sample the evaluation split
eval_ds = load_dataset_from_disk("test_dataset")

# Number of examples you want to inspect
n_trials = 10  
# Shuffle & pick
sampled = eval_ds.shuffle(seed=42).select(range(n_trials))

# ── cell 4 ── Run generation & collect results
results = []
for ex in sampled:
    orig = ex["toxic"]
    ref  = ex["neutral"]
    inp  = f"detoxify: {orig}"

    # Tokenize + move to device
    enc = tokenizer(inp, return_tensors="pt", truncation=True, padding=True).to(device)
    # Generate
    out_ids = model.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_length=64,
        num_beams=4,           # optional: improve quality
        early_stopping=True
    )
    gen = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    results.append({
        "original_toxic": orig,
        "reference_neutral": ref,
        "generated_neutral": gen
    })

# ── cell 5 ── Display as a DataFrame
df = pd.DataFrame(results)
df 



Unnamed: 0,original_toxic,reference_neutral,generated_neutral
0,mane told a fatass hyperbole lol rt just got home from our game tired ass hell,mane told a hyperbole lol rt just got home from our game tired.,mane told a hyperbole lol rt just got home from our game tired
1,fuck man that suxxer had less of a punishment than my mom did to me .,He had less of a punishment than my mom did to me.,Man that person had less of a punishment than my mom did to me
2,what a chicken crap excuse for a reason .,what a bad excuse for a reason.,What an excuse for a reason
3,only white people do fucked up shit like this .,Only white people do like this,Only white people do things like this.
4,que the mobsters controlling the government and doing the same shit .,Que the mobsters controlling the government and doing the same thing.,Que the mobsters controlling the government and doing the same thing
5,they aren 't happy about tony abbott pulling this shit .,they aren't happy about tony abbott doing this.,They aren't happy about tony abbott pulling this.
6,she know 's too much about me & i don 't know shit about her,she know 's too much about me & i don 't know anything about her,She knows too much about me and I don't know anything about her.
7,"sure , its the economy , stupid .","sure , its the economy.","Sure, its the economy."
8,for fuck sake ! ireland beat england but we can 't ? !,ireland beat england but we can 't ?,IRELAND beat England but we can't?
9,"at least you agree that legally , she 's fucked .","at least you agree that legally, she’s in a problem","At least you agree that legally, she's in trouble."


In [13]:
# ── cell 6 ── Try your own inputs
# List your own toxic sentences here:
custom_inputs = [
    "Jews are abomination for this world!",
    "I like fucking cute Hungarians",
    "I wanna fuck Hungarians' fat ass ",
    # add as many as you like...
]

# Generate detoxified outputs
custom_results = []
for orig in custom_inputs:
    inp = f"detoxify: {orig}"
    enc = tokenizer(inp, return_tensors="pt", truncation=True, padding=True).to(device)
    out_ids = model.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_length=64,
        num_beams=4,
        early_stopping=True
    )
    gen = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    custom_results.append({
        "original_toxic": orig,
        "generated_neutral": gen
    })

# Display them
pd.DataFrame(custom_results)

Unnamed: 0,original_toxic,generated_neutral
0,Jews are abomination for this world!,Jews are bad for this world!
1,I like fucking cute Hungarians,I like cute Hungarys
2,I wanna fuck Hungarians' fat ass,I wanna punish Hungarians


In [19]:
# ── cell 1 ── Setup & imports
import os
import yaml
import torch
import pandas as pd
from transformers import T5ForConditionalGeneration, T5Tokenizer
from data_processing import load_dataset_from_disk

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', None)

# ── cell 2 ── Configuration & model/tokenizer loading
# Load config (if you still need it for other params)
cfg = yaml.safe_load(open("main_config.yml", "r"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) Your existing fine-tuned SFT model (optional, if you want to compare)
sft_output = cfg["sft_params"]["output_dir"] + "/checkpoint-6800"
model_sft    = T5ForConditionalGeneration.from_pretrained(sft_output).to(device)

# 2) Zero-shot T5-base
model_base   = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base").to(device)
tokenizer    = T5Tokenizer.from_pretrained("google/t5-v1_1-base")

# ── cell 3 ── Load & sample the evaluation split
eval_ds = load_dataset_from_disk("test_dataset")

n_trials = 10
sampled  = eval_ds.shuffle(seed=42).select(range(n_trials))

# ── cell 4 ── Define a helper to detoxify a single sentence
def detoxify_with_model(text: str, model: T5ForConditionalGeneration) -> str:
    """Prepends the ‘detoxify:’ prompt, runs generate, and decodes."""
    prompt = "Rewrite this to be non-toxic: "
    inp = f"{prompt}: {text}"
    enc = tokenizer(inp, return_tensors="pt", truncation=True, padding=True).to(device)
    out_ids = model.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_length=64,
        num_beams=4,
        early_stopping=True
    )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)

    
# ── cell 5 ── Run both models and collect into a DataFrame
results = []
for ex in sampled:
    orig = ex["toxic"]
    ref  = ex["neutral"]

    gen_sft  = detoxify_with_model(orig, model_sft)
    gen_base = detoxify_with_model(orig, model_base)

    results.append({
        "original_toxic":    orig,
        "reference_neutral": ref,
        "generated_SFT":     gen_sft,
        "generated_base":    gen_base
    })

df = pd.DataFrame(results)
df



Unnamed: 0,original_toxic,reference_neutral,generated_SFT,generated_base
0,mane told a fatass hyperbole lol rt just got home from our game tired ass hell,mane told a hyperbole lol rt just got home from our game tired.,Mane told a hyperbole lol rt just got home from our game tired,.. : : : : : : : : : : : : : : : : : : : : : : : : : : : : :
1,fuck man that suxxer had less of a punishment than my mom did to me .,He had less of a punishment than my mom did to me.,he had less of a punishment than my mom did to me.,?????????????sssss and
2,what a chicken crap excuse for a reason .,what a bad excuse for a reason.,What an excuse for a reason.,?????????????
3,only white people do fucked up shit like this .,Only white people do like this,: only white people do things like this.,. : : : : : : : : : :.
4,que the mobsters controlling the government and doing the same shit .,Que the mobsters controlling the government and doing the same thing.,: que the mobsters controlling the government and doing the same thing.,............. piata
5,they aren 't happy about tony abbott pulling this shit .,they aren't happy about tony abbott doing this.,They aren't happy about tony abbott pulling this.,.............. shingles
6,she know 's too much about me & i don 't know shit about her,she know 's too much about me & i don 't know anything about her,She knows too much about me and I don't know anything about her.,.: : : : : : : : : : : : : : : : : : : : : : : : : : : : : :
7,"sure , its the economy , stupid .","sure , its the economy.","Sure, its the economy.",: : : : : : : : : : : : : : : : : :
8,for fuck sake ! ireland beat england but we can 't ? !,ireland beat england but we can 't ?,ireland beat England but we can 't?!,?????????????????????
9,"at least you agree that legally , she 's fucked .","at least you agree that legally, she’s in a problem",At least you agree that legally she's in trouble.,. : : : : : : : : : : : : : : : : : : : : :


In [25]:
# ── cell 1 ── Setup & imports
import os
import yaml
import torch
import pandas as pd
from transformers import T5ForConditionalGeneration, T5Tokenizer
from data_processing import load_dataset_from_disk

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', None)

# ── cell 2 ── Configuration & model/tokenizer loading
cfg    = yaml.safe_load(open("main_config.yml", "r"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned SFT checkpoint (prefix-based prompt only)
sft_output = os.path.join(cfg["sft_params"]["output_dir"], "checkpoint-6800")
model_sft  = T5ForConditionalGeneration.from_pretrained(sft_output).to(device)

# Instruction-tuned FLAN-T5 for zero-shot detoxification (chat-style prompt)
model_base = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").to(device)
tokenizer  = T5Tokenizer.from_pretrained("google/flan-t5-base")

# ── cell 3 ── Load & sample the evaluation split
eval_ds  = load_dataset_from_disk("test_dataset")
n_trials = 20
sampled  = eval_ds.shuffle(seed=42).select(range(n_trials))

# ── cell 4 ── Two detox helpers: one for SFT (no instructions), one for FLAN-T5 (with instructions)
def detoxify_sft(
    sentence: str,
    model: T5ForConditionalGeneration = model_sft,
    max_new_tokens: int = 100
) -> str:
    # Simple prefix prompt, no system/user instructions
    prompt = f"detoxify: {sentence}"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

def detoxify_base(
    sentence: str,
    model: T5ForConditionalGeneration = model_base,
    max_new_tokens: int = 100
) -> str:
    # Chat-style system + user instructions
    system_inst = (
        "You are a detoxification model, not an assistant. "
        "You always rewrite toxic sentences to be non-toxic, neutral, and respectful. "
        "Do not explain, do not apologize, and do not say a sentence is inappropriate. "
    )
    user_inst   = f'Rewrite this to be non-toxic: "{sentence}"'
    prompt_text = system_inst + " " + user_inst

    inputs = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

# ── cell 5 ── Run both models & assemble DataFrame
results = []
for ex in sampled:
    orig = ex["toxic"]
    ref  = ex["neutral"]

    gen_sft  = detoxify_sft(orig)
    gen_base = detoxify_base(orig)

    results.append({
        "original_toxic":    orig,
        "reference_neutral": ref,
        "generated_SFT":     gen_sft,
        "generated_base":    gen_base
    })

df = pd.DataFrame(results)
df



Unnamed: 0,original_toxic,reference_neutral,generated_SFT,generated_base
0,mane told a fatass hyperbole lol rt just got home from our game tired ass hell,mane told a hyperbole lol rt just got home from our game tired.,Mane told a hyperbole lol rt just got home from our game tired,"""Mane told a fatass hyperbole lol rt just got home from our game tired ass hell"""
1,fuck man that suxxer had less of a punishment than my mom did to me .,He had less of a punishment than my mom did to me.,Man that person had less of a punishment than my mom did to me,"""fuck man that suxxer had less of a punishment than my mom did to me."""
2,what a chicken crap excuse for a reason .,what a bad excuse for a reason.,What an excuse for a reason,"""What a chicken crap excuse for a reason."""
3,only white people do fucked up shit like this .,Only white people do like this,Only white people do things like this.,Only white people do fucked up shit like this.
4,que the mobsters controlling the government and doing the same shit .,Que the mobsters controlling the government and doing the same thing.,que the mobsters controlling the government and doing the same thing.,"""que the mobsters controlling the government and doing the same shit."""
5,they aren 't happy about tony abbott pulling this shit .,they aren't happy about tony abbott doing this.,They aren't happy about tony abbott pulling this.,They aren 't happy about tony abbott pulling this shit.
6,she know 's too much about me & i don 't know shit about her,she know 's too much about me & i don 't know anything about her,She knows too much about me and I don't know anything about her.,"""I don 't know shit about her"""
7,"sure , its the economy , stupid .","sure , its the economy.","Sure, its the economy.",""" i'm a shit, i'm a stupid, i'm a stupid, i'm a stupid, i'm a stupid, i'm a stupid, i'm a stupid, i'm a stupid, i'm a stupid, i'm a"
8,for fuck sake ! ireland beat england but we can 't ? !,ireland beat england but we can 't ?,IRELAND beat England but we can't?,' Ireland beat England '!
9,"at least you agree that legally , she 's fucked .","at least you agree that legally, she’s in a problem",At least you agree that legally she's in trouble.,""" at least you agree that legally, she's fucked."""


In [33]:
#!/usr/bin/env python
# -- coding: utf-8 --

"""
test_detoxification.py

Loads a fine-tuned T5 detoxification model, samples from the evaluation split,
generates detoxified outputs, scores toxicity, semantic similarity, and fluency
for both the reference neutral sentences and the model’s generated sentences,
and displays the results in a pandas DataFrame.
"""

import os
import yaml
import torch
import pandas as pd
import torch.nn.functional as F

from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
    AutoModelForCausalLM
)
from data_processing import load_dataset_from_disk

# ── Display settings ────────────────────────────────────────────────────────
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', None)

# ── Configuration ──────────────────────────────────────────────────────────
cfg = yaml.safe_load(open("main_config.yml", "r"))
output_dir = cfg["sft_params"]["output_dir"].rstrip("/") + "/checkpoint-6800"
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Load fine-tuned model & tokenizer ───────────────────────────────────────
print(f"Loading model from {output_dir}...")
model     = T5ForConditionalGeneration.from_pretrained(output_dir).to(device)
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-base")

# ── Load & sample evaluation split ──────────────────────────────────────────
print("Loading and sampling evaluation dataset...")
eval_ds  = load_dataset_from_disk("eval_dataset")
n_trials = 10
sampled  = eval_ds.shuffle(seed=42).select(range(n_trials))

# ── Generate detoxified outputs ─────────────────────────────────────────────
print("Generating detoxified sentences...")
results = []
for ex in sampled:
    orig = ex["toxic"]
    ref  = ex["neutral"]
    inp  = f"detoxify: {orig}"

    enc = tokenizer(inp, return_tensors="pt", truncation=True, padding=True).to(device)
    out_ids = model.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_length=64,
        num_beams=4,
        early_stopping=True
    )
    gen = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    results.append({
        "original_toxic":    orig,
        "reference_neutral": ref,
        "generated_neutral": gen
    })

# ── Setup toxicity model ────────────────────────────────────────────────────
print("Loading toxicity model...")
tox_tok = AutoTokenizer.from_pretrained("unitary/toxic-bert")
tox_mod = AutoModelForSequenceClassification.from_pretrained(
    "unitary/toxic-bert"
).to(device).eval()
label2id = tox_mod.config.label2id
tox_label = label2id.get("toxicity", list(label2id.values())[0])

# ── Setup similarity model ──────────────────────────────────────────────────
print("Loading similarity model...")
sim_tok = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
sim_mod = AutoModel.from_pretrained(
    "princeton-nlp/sup-simcse-bert-base-uncased"
).to(device).eval()

# ── Score toxicity ──────────────────────────────────────────────────────────
print("Scoring toxicity...")
tox_inputs = tox_tok(
    [r["original_toxic"] for r in results] +
    [r["reference_neutral"] for r in results] +
    [r["generated_neutral"] for r in results],
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(device)

with torch.no_grad():
    tox_logits = tox_mod(**tox_inputs).logits
    tox_probs  = torch.sigmoid(tox_logits)

n = len(results)
orig_tox = tox_probs[:n,    tox_label].cpu().numpy()
ref_tox  = tox_probs[n:2*n, tox_label].cpu().numpy()
gen_tox  = tox_probs[2*n: , tox_label].cpu().numpy()

# ── Compute semantic similarity ──────────────────────────────────────────────
print("Computing semantic similarity...")
sim_inputs = sim_tok(
    [r["original_toxic"] for r in results] +
    [r["reference_neutral"] for r in results] +
    [r["generated_neutral"] for r in results],
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(device)

with torch.no_grad():
    emb_all = sim_mod(**sim_inputs, return_dict=True).pooler_output

emb_orig = emb_all[:n]
emb_ref  = emb_all[n:2*n]
emb_gen  = emb_all[2*n:3*n]

sim_orig_ref = F.cosine_similarity(emb_orig, emb_ref, dim=1).cpu().numpy()
sim_orig_gen = F.cosine_similarity(emb_orig, emb_gen, dim=1).cpu().numpy()

# ── Compute fluency (LM loss) ───────────────────────────────────────────────
print("Computing fluency loss with GPT-2...")
lm_tok = AutoTokenizer.from_pretrained("gpt2")
# Ensure pad_token is defined for GPT-2 tokenizer
if lm_tok.pad_token is None:
    lm_tok.pad_token = lm_tok.eos_token
    lm_tok.pad_token_id = lm_tok.eos_token_id
lm_mod = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

fluency_ref = []
fluency_gen = []
for ref, gen in zip([r["reference_neutral"] for r in results],
                     [r["generated_neutral"] for r in results]):
    # reference
    enc_ref = lm_tok(ref, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        loss_ref = lm_mod(**enc_ref, labels=enc_ref["input_ids"]).loss.item()
    fluency_ref.append(loss_ref)
    # generated
    enc_gen = lm_tok(gen, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        loss_gen = lm_mod(**enc_gen, labels=enc_gen["input_ids"]).loss.item()
    fluency_gen.append(loss_gen)

# ── Attach scores and build DataFrame ───────────────────────────────────────
for i, r in enumerate(results):
    r["toxicity_ref"]        = float(ref_tox[i])
    r["similarity_orig_ref"] = float(sim_orig_ref[i])
    r["toxicity_gen"]        = float(gen_tox[i])
    r["similarity_orig_gen"] = float(sim_orig_gen[i])
    r["fluency_ref"]         = float(fluency_ref[i])
    r["fluency_gen"]         = float(fluency_gen[i])

# ── Display results ─────────────────────────────────────────────────────────
df = pd.DataFrame(results)[[
    "original_toxic",
    "reference_neutral",
    "toxicity_ref",
    "similarity_orig_ref",
    "fluency_ref",
    "generated_neutral",
    "toxicity_gen",
    "similarity_orig_gen",
    "fluency_gen"
]]

df

Loading model from outputs/SFT/checkpoint-6800...




Loading and sampling evaluation dataset...
Generating detoxified sentences...
Loading toxicity model...
Loading similarity model...
Scoring toxicity...
Computing semantic similarity...
Computing fluency loss with GPT-2...


Using pad_token, but it is not set yet.


Unnamed: 0,original_toxic,reference_neutral,toxicity_ref,similarity_orig_ref,fluency_ref,generated_neutral,toxicity_gen,similarity_orig_gen,fluency_gen
0,you lack of understanding about jack shit means nothing .,you lack of understanding means nothing.,0.174015,0.805867,5.51551,You lack of understanding about this means nothing,0.028095,0.787423,5.443557
1,- gotta small ass cut & shits bleedin alot -,got a small cut and it’s bleeding a lot,0.026665,0.767247,5.725555,- gotta small cut & bleed alot -,0.527137,0.92593,8.395499
2,"omg , stephens is so silly .","omg , stephens is So wrong.",0.000803,0.838825,7.768141,"omg, stephens is so silly.",0.443403,1.0,5.921991
3,that 's a long ass lease .,That is a very long lease.,0.000652,0.931736,4.248515,That's a long lease.,0.000772,0.949548,4.751197
4,"well put a fuckin' coat on , son .","well put a coat on , son .",0.0207,0.908814,6.313572,"Well put a coat on, son",0.031059,0.907599,4.444399
5,and i thought south park was just making shit up .,And I thought Southpark was just making stuff up,0.000652,0.884786,4.868238,and i thought south park was just making things up.,0.000635,0.949487,5.096058
6,they would have taken out every bad guy holding a gun .,they would have taken out every guy holding a gun,0.249951,0.948061,4.473475,they would have taken out every bad guy holding a gun.,0.168187,1.0,4.101088
7,"holy fuck , do those kids need gold - lined sneakers or something ?",Do these kids need gold-lined sneakers or something?,0.001271,0.942673,5.061999,Do those kids need gold-lined sneakers or something?,0.001151,0.944684,5.017157
8,sick and tired of the laughable or poisonous stuff promoted by these religious psychos and attributing it to mary .,"I get uncomfortable with this and they attribute it to Mary, like it isn't bad",0.000803,0.56485,4.708082,sick and tired of the jokeable stuff promoted by these religious people and attributing it to mary.,0.080387,0.883124,5.127156
9,its like telling your girl friend shes fat then asking her to marry you .,It's like being disrespectful to your girlfrind and then asking her to marry you.,0.090128,0.662332,3.709056,It's like telling your friend she's fat then asking her to marry you.,0.805633,0.959673,3.648295


### COUNT Trial ###

In [15]:
import random
import torch
import torch.nn.functional as F
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# === Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model_path = "outputs/COUNT/count_e20"
model_path = "outputs/COUNT/checkpoint-6900"
tox_model_name = "unitary/toxic-bert"

# === Load models ===
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device).eval()
tox_tokenizer = AutoTokenizer.from_pretrained(tox_model_name)
tox_model = AutoModelForSequenceClassification.from_pretrained(tox_model_name).to(device).eval()

# === Load test data ===
test_data = load_from_disk("cleaned_data/test_dataset")

# === Toxicity function ===
def get_toxicity_score(text):
    inputs = tox_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = tox_model(**inputs).logits
    probs = torch.sigmoid(logits)
    return float(probs[0][0])

# === Loss function ===
def compute_losses(input_text, target_text):
    # Encode input (e.g., "detoxify: ...")
    input_enc = tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
    input_enc = {k: v.to(device) for k, v in input_enc.items()}

    # Encode target (e.g., neutral sentence)
    target_enc = tokenizer(target_text, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
    labels = target_enc["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    labels = labels.to(device)

    with torch.no_grad():
        outputs = model(**input_enc, labels=labels)
        mle_loss = outputs.loss
        logits = outputs.logits

        # Unlikelihood loss
        probs = torch.softmax(logits, dim=-1)
        neg_log_probs = torch.log(1.0 - probs + 1e-6)
        safe_labels = labels.clone()
        safe_labels[safe_labels == -100] = 0
        ul_loss = -neg_log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        pad_mask = labels != -100
        ul_loss = (ul_loss * pad_mask).sum() / pad_mask.sum()

        # Toxicity penalty (on target_text)
        tox_score = get_toxicity_score(target_text)

    return mle_loss.item(), ul_loss.item(), tox_score

# === Select random samples ===
sample_indices = random.sample(range(len(test_data)), 3)

for idx in sample_indices:
    example = test_data[idx]
    toxic_input = example["toxic"]
    ground_truth = example["neutral"]
    input_text = "detoxify: " + toxic_input

    # Generate detoxified output
    input_enc = tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
    input_enc = {k: v.to(device) for k, v in input_enc.items()}
    with torch.no_grad():
        gen_ids = model.generate(**input_enc, max_length=64, num_beams=4)
        detoxified_output = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    # === Compute all losses ===
    losses_input     = compute_losses(input_text, toxic_input)
    losses_output    = compute_losses(input_text, detoxified_output)
    losses_target    = compute_losses(input_text, ground_truth)

# === Combine losses using training weights ===
mle_w = 0.5
ul_w = 0.5
tox_w = 1.0

def combine(mle, ul, tox):
    return mle_w * mle + ul_w * ul + tox_w * tox

score_input = combine(*losses_input)
score_output = combine(*losses_output)
score_target = combine(*losses_target)

print("=" * 90)
print(f"🔴 Toxic Input:\n{toxic_input}")
print(f"📊 MLE: {losses_input[0]:.4f}, UL: {losses_input[1]:.4f}, Toxicity: {losses_input[2]:.4f}, Total: {score_input:.4f}")
print()
print(f"🟢 Detoxified Output:\n{detoxified_output}")
print(f"📊 MLE: {losses_output[0]:.4f}, UL: {losses_output[1]:.4f}, Toxicity: {losses_output[2]:.4f}, Total: {score_output:.4f}")
print()
print(f"✅ Ground Truth:\n{ground_truth}")
print(f"📊 MLE: {losses_target[0]:.4f}, UL: {losses_target[1]:.4f}, Toxicity: {losses_target[2]:.4f}, Total: {score_target:.4f}")
print()

🔴 Toxic Input:
i love you so fucking much baby
📊 MLE: 1.1604, UL: 4.1501, Toxicity: 0.9574, Total: 3.6127

🟢 Detoxified Output:
I love you so much baby
📊 MLE: 0.0650, UL: 4.0288, Toxicity: 0.0009, Total: 2.0479

✅ Ground Truth:
i love you so much baby
📊 MLE: 0.2338, UL: 5.1951, Toxicity: 0.0009, Total: 2.7154

