In [None]:
import kagglehub
import numpy as np
import re
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer as Trainer
from transformers import Seq2SeqTrainingArguments as TrainingArguments
import torch
import evaluate
import optuna
from optuna.exceptions import TrialPruned
import json
import os
import gc

In [None]:
# MuSE corpus: three headerless TSV splits read straight from the project's GitHub repo.
base_url = "https://raw.githubusercontent.com/LCS2-IIITD/Multimodal-Sarcasm-Explanation-MuSE/main/Dataset/"
muse_columns = ["PID", "Caption", "Explanation"]
MuSE_train_df, MuSE_val_df, MuSE_test_df = (
    pd.read_csv(base_url + split_file, sep="\t", header=None, names=muse_columns)
    for split_file in ("train_df.tsv", "val_df.tsv", "test_df.tsv")
)

# Quick look at every split: first rows, then sizes.
muse_frames = (MuSE_train_df, MuSE_val_df, MuSE_test_df)
for frame in muse_frames:
    print(frame.head())
for frame in muse_frames:
    print(frame.shape)

# Kaggle corpus: downloaded through kagglehub; these TSVs carry a header row.
path = kagglehub.dataset_download("prayag007/sarcasm-explanation")
print("Path to dataset files:", path)

kag_train_df, kag_val_df = (
    pd.read_csv(split_path, sep="\t", header=0)
    for split_path in (f"{path}/train_df.tsv", f"{path}/val_df.tsv")
)

kaggle_frames = (kag_train_df, kag_val_df)
for frame in kaggle_frames:
    print(frame.head())
for frame in kaggle_frames:
    print(frame.shape)


In [None]:
# Bring both corpora onto one schema: text / explanation / source.
# Each frame is rebuilt from its input and absent columns are tolerated, so the cell
# is idempotent: re-running it no longer raises KeyError because "PID" was already
# dropped by an earlier run.
MUSE_RENAMES = {"Caption": "text", "Explanation": "explanation"}
KAGGLE_ID_COLUMNS = ["Unnamed: 0", "pid"]
REQUIRED_COLUMNS = {"text", "explanation"}

MuSE_train_df = MuSE_train_df.drop(columns=["PID"], errors="ignore")
MuSE_val_df   = MuSE_val_df.drop(columns=["PID"], errors="ignore")
MuSE_test_df  = MuSE_test_df.drop(columns=["PID"], errors="ignore")

# Index/id columns are only dropped when the Kaggle export actually contains them.
kag_train_df = kag_train_df.drop(columns=KAGGLE_ID_COLUMNS, errors="ignore")
kag_val_df   = kag_val_df.drop(columns=KAGGLE_ID_COLUMNS, errors="ignore")

print(kag_train_df.columns)

MuSE_train_df = MuSE_train_df.rename(columns=MUSE_RENAMES).assign(source="muse")
MuSE_val_df   = MuSE_val_df.rename(columns=MUSE_RENAMES).assign(source="muse")
MuSE_test_df  = MuSE_test_df.rename(columns=MUSE_RENAMES).assign(source="muse")

# The Kaggle frames are expected to already use the "text"/"explanation" names
# (the previous rename mapped each name onto itself), so they only get tagged.
kag_train_df = kag_train_df.assign(source="kaggle")
kag_val_df   = kag_val_df.assign(source="kaggle")

# Fail loudly on a schema mismatch: pd.concat would otherwise fill the missing
# column with NaN and those rows would silently become "nan" training targets.
for frame_name, frame in [
    ("MuSE_train_df", MuSE_train_df),
    ("MuSE_val_df", MuSE_val_df),
    ("MuSE_test_df", MuSE_test_df),
    ("kag_train_df", kag_train_df),
    ("kag_val_df", kag_val_df),
]:
    missing = REQUIRED_COLUMNS - set(frame.columns)
    if missing:
        raise KeyError(f"{frame_name} is missing expected column(s): {sorted(missing)}")

combined_train = pd.concat([MuSE_train_df, kag_train_df], ignore_index=True)
combined_val   = pd.concat([MuSE_val_df, kag_val_df], ignore_index=True)

print(combined_train.shape)
print(combined_val.shape)
combined_train.sample(3)

In [None]:
def build_target(row):
    """Build the completion text for one example.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with an "explanation" entry.

    Returns:
        The explanation converted to ``str``, stripped of surrounding whitespace
        and prefixed with a single space so it can follow "Explanation:" directly.
    """
    cleaned = str(row["explanation"]).strip()
    return " " + cleaned

# Attach the completion string that follows "Explanation:" in the training prompt.
combined_train["target_text"] = combined_train.apply(build_target, axis=1)
combined_val["target_text"]   = combined_val.apply(build_target, axis=1)

# Only the sentence and its completion are needed for tokenisation.
train_ds = Dataset.from_pandas(combined_train[["text", "target_text"]])
val_ds   = Dataset.from_pandas(combined_val[["text", "target_text"]])

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so fixed-length padding is possible.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the right for training. GPT-2 uses absolute position embeddings and the
# training forward pass does not derive positions from the attention mask, so
# left-padding to MAX_LEN would push every real token to a high position index
# that an unpadded inference prompt (positions starting at 0) never uses.
tokenizer.padding_side = "right"

# Fixed sequence length (prompt + explanation) used by `preprocess`.
MAX_LEN = 160

def preprocess(examples):
    """Tokenise a batch of examples into fixed-length causal-LM inputs.

    Args:
        examples: Batch mapping with parallel "text" and "target_text" lists.

    Returns:
        The tokenizer output (padded/truncated to MAX_LEN) with an extra
        "labels" entry that is a copy of "input_ids".
    """
    full_texts = []
    for sentence, completion in zip(examples["text"], examples["target_text"]):
        # Instruction, sentence and gold explanation form one training sequence.
        full_texts.append(
            f"Explain why the following sentence sounds sarcastic:\n"
            f"Sentence: {sentence}\n"
            f"Explanation:{completion}"
        )

    encoded = tokenizer(
        full_texts,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
    )
    # Causal-LM objective: the targets are the inputs themselves.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

# Tokenise both splits in batches, keeping only the tokenizer outputs.
train_tok, val_tok = (
    split.map(preprocess, batched=True, remove_columns=split.column_names)
    for split in (train_ds, val_ds)
)

In [None]:
def subset_dataset(ds, fraction=0.2, seed=4213):
    """Return a reproducible random subset of ``ds``.

    Args:
        ds: Dataset-like object supporting ``len()`` and ``select(list_of_indices)``.
        fraction: Share of rows to keep, between 0 and 1; the size is
            ``int(len(ds) * fraction)`` (rounded down).
        seed: Seed for the sampling RNG.

    Returns:
        ``ds.select(...)`` applied to the sampled indices (drawn without replacement).

    Raises:
        ValueError: If ``fraction`` is outside ``[0, 1]``.
    """
    if not 0 <= fraction <= 1:
        raise ValueError(f"fraction must be within [0, 1], got {fraction!r}")

    total = len(ds)
    subset_size = int(total * fraction)
    # A private RandomState produces the same draw as seeding the global NumPy RNG
    # with `seed`, but leaves the global random stream untouched for other code.
    rng = np.random.RandomState(seed)
    indices = rng.choice(total, subset_size, replace=False)
    return ds.select(indices.tolist())

# 20% slices of each tokenised split, drawn with subset_dataset's default seed.
train_tok_sub, val_tok_sub = (
    subset_dataset(split, fraction=0.2) for split in (train_tok, val_tok)
)

# ROUGE scorer used by compute_metrics.
rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    """Score predictions against references with ROUGE-L.

    Args:
        eval_preds: Pair ``(predictions, labels)`` of token-id arrays; the
            predictions may be wrapped in a tuple, in which case its first
            element is used.

    Returns:
        ``{"rougeL": score}`` with the score rounded to four decimals.
    """
    predictions, references = eval_preds
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 marks ignored positions and cannot be decoded; map it to the pad token.
    predictions = np.where(predictions != -100, predictions, pad_id)
    references = np.where(references != -100, references, pad_id)

    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_references = tokenizer.batch_decode(references, skip_special_tokens=True)

    scores = rouge.compute(predictions=decoded_predictions, references=decoded_references)
    return {"rougeL": round(scores["rougeL"], 4)}

def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids aligned with ``labels``.

    Keeping only the argmax ids avoids carrying full vocabulary-sized logits
    through evaluation. In a causal LM the logits at position ``i`` score the
    token at position ``i + 1``, so the ids are shifted one step to the right to
    line up with ``labels``; positions that ``labels`` marks as ignored (-100),
    and the first position (which has no prediction), are set to -100 so that
    they do not end up as junk tokens in the decoded prediction.

    Args:
        logits: Tensor of shape ``(..., seq_len, vocab)`` or a tuple whose first
            element is that tensor.
        labels: Tensor of label ids with the same shape as the argmax output,
            or ``None``.

    Returns:
        Integer tensor of token ids. When ``labels`` is ``None`` or its shape
        does not match, the plain (unshifted) argmax ids are returned.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    token_ids = torch.argmax(logits, dim=-1)

    if labels is None or labels.shape != token_ids.shape:
        return token_ids

    aligned = torch.full_like(token_ids, -100)
    aligned[..., 1:] = token_ids[..., :-1]
    return aligned.masked_fill(labels == -100, -100)


def objective(trial):
    """Optuna objective: fine-tune GPT-2 for one epoch on the training subset.

    Samples learning rate, batch size, dropout, weight decay and warmup ratio,
    trains on ``train_tok_sub`` and scores on ``val_tok_sub``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The ``eval_rougeL`` value reported by the trainer (0.0 if it is absent).

    Raises:
        TrialPruned: If training fails with an out-of-memory RuntimeError.
        RuntimeError: Any other runtime error is re-raised unchanged.
    """
    learning_rate = trial.suggest_categorical("learning_rate", [1e-5, 3e-5, 1e-4, 3e-4])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    dropout_rate = trial.suggest_categorical("dropout_rate", [0.1, 0.2, 0.3])
    weight_decay = trial.suggest_categorical("weight_decay", [0.0, 0.01, 0.05])
    warmup_ratio = trial.suggest_categorical("warmup_ratio", [0.03, 0.06, 0.1])

    model = None
    trainer = None
    try:
        # Dropout modules are created when the model is instantiated, so the sampled
        # rate has to be passed as config overrides at load time. Assigning to
        # model.config afterwards leaves the existing layers at their default rate,
        # which made this hyperparameter a no-op.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            attn_pdrop=dropout_rate,
            resid_pdrop=dropout_rate,
            embd_pdrop=dropout_rate,
        )

        model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.use_cache = False
        model.to("cpu")

        collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        training_args = Seq2SeqTrainingArguments(
            # Explicit output directory: required by transformers releases where
            # `output_dir` has no default.
            output_dir="./optuna_results/trial_runs",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=max(2, batch_size // 2),
            num_train_epochs=1,
            weight_decay=weight_decay,
            warmup_ratio=warmup_ratio,
            eval_strategy="epoch",
            save_strategy="no",
            logging_strategy="epoch",
            # Every eval input already contains the reference explanation, so
            # generating from it cannot measure explanation quality (and generation
            # in Seq2SeqTrainer targets encoder-decoder models). Score the
            # teacher-forced predictions instead.
            predict_with_generate=False,
            # Keep the effective batch size at >= 32 regardless of the sampled size.
            gradient_accumulation_steps=max(1, 32 // batch_size),
            fp16=False,
            report_to="none"
        )

        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_tok_sub,
            eval_dataset=val_tok_sub,
            data_collator=collator,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
            # Reduce logits to token ids per batch so evaluation does not have to
            # hold vocabulary-sized logits for the whole eval set.
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        trainer.train()
        eval_results = trainer.evaluate()
        return eval_results.get("eval_rougeL", 0.0)

    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"Trial {trial.number} failed due to OOM — skipping.")
            raise TrialPruned()
        # Bare raise keeps the original traceback.
        raise

    finally:
        # Drop this trial's model/trainer before the next trial builds its own.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()



# Hyperparameter search: six trials, maximising the validation ROUGE-L.
study = optuna.create_study(
    direction="maximize",
    study_name="gpt2_full_finetune_optuna",
)
study.optimize(objective, n_trials=6)

# Persist the winning configuration (plus its score) for the final-training cell.
os.makedirs("./optuna_results", exist_ok=True)
json_path = "./optuna_results/best_gpt2_params.json"

best_trial = study.best_trial
best_params = best_trial.params
best_params.update(best_rougeL=best_trial.value)

with open(json_path, "w") as f:
    f.write(json.dumps(best_params, indent=4))

print(f"\nSaved best GPT-2 parameters to {json_path}")
print(json.dumps(best_params, indent=4))

In [None]:
# Load the hyperparameters chosen by the Optuna search.
params_path = "./optuna_results/best_gpt2_params.json"
if not os.path.exists(params_path):
    raise FileNotFoundError(f"Cannot find {params_path}. Run Optuna tuning first.")

with open(params_path, "r") as f:
    best_params = json.load(f)

learning_rate = best_params["learning_rate"]
batch_size    = best_params["batch_size"]
dropout_rate  = best_params["dropout_rate"]
weight_decay  = best_params["weight_decay"]
warmup_ratio  = best_params["warmup_ratio"]

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token of its own; reuse EOS.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Dropout modules are created when the model is instantiated, so the tuned rate is
# passed as config overrides at load time. Assigning to model.config afterwards
# would leave the existing layers at their default rate.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_pdrop=dropout_rate,
    resid_pdrop=dropout_rate,
    embd_pdrop=dropout_rate,
)

model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False

# Full fine-tuning: every parameter stays trainable.
model.requires_grad_(True)

# ROUGE scorer used by compute_metrics.
rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    """Compute ROUGE-L between decoded predictions and decoded labels.

    Args:
        eval_preds: ``(predictions, labels)`` token-id arrays; if the
            predictions arrive as a tuple, only its first element is used.

    Returns:
        A dict with a single "rougeL" entry rounded to four decimal places.
    """
    raw_preds, raw_labels = eval_preds
    if isinstance(raw_preds, tuple):
        raw_preds = raw_preds[0]

    def _decode(token_ids):
        # -100 is the ignore marker; swap in the pad id so decoding works.
        cleaned = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(cleaned, skip_special_tokens=True)

    pred_texts = _decode(raw_preds)
    label_texts = _decode(raw_labels)
    result = rouge.compute(predictions=pred_texts, references=label_texts)
    return {"rougeL": round(result["rougeL"], 4)}

def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids aligned with ``labels``.

    Only the argmax ids are kept, so evaluation does not carry full
    vocabulary-sized logits. In a causal LM the logits at position ``i`` score
    the token at position ``i + 1``; the ids are therefore shifted one step to
    the right to line up with ``labels``. Positions that ``labels`` marks as
    ignored (-100), and the first position (which has no prediction), are set
    to -100 so they do not show up as junk tokens in the decoded prediction.

    Args:
        logits: Tensor of shape ``(..., seq_len, vocab)`` or a tuple whose first
            element is that tensor.
        labels: Tensor of label ids with the same shape as the argmax output,
            or ``None``.

    Returns:
        Integer tensor of token ids. When ``labels`` is ``None`` or its shape
        does not match, the plain (unshifted) argmax ids are returned.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    token_ids = torch.argmax(logits, dim=-1)

    if labels is None or labels.shape != token_ids.shape:
        return token_ids

    aligned = torch.full_like(token_ids, -100)
    aligned[..., 1:] = token_ids[..., :-1]
    return aligned.masked_fill(labels == -100, -100)


# Causal-LM collator (mlm=False).
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = Seq2SeqTrainingArguments(
    # Explicit output directory: required by transformers releases where
    # `output_dir` has no default.
    output_dir="./gpt2_full_sarcasm_trainer",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=max(2, batch_size // 2),
    weight_decay=weight_decay,
    warmup_ratio=warmup_ratio,
    num_train_epochs=5,
    eval_strategy="epoch",
    save_strategy="no",
    logging_strategy="epoch",
    # Every eval input already contains the reference explanation, so generating
    # from it cannot measure explanation quality (and generation in
    # Seq2SeqTrainer targets encoder-decoder models). Score the teacher-forced
    # predictions instead.
    predict_with_generate=False,
    # Keep the effective batch size at >= 32 regardless of the tuned batch size.
    gradient_accumulation_steps=max(1, 32 // batch_size),
    fp16=False,
    report_to="none"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Reduce logits to token ids per batch so evaluation does not have to hold
    # vocabulary-sized logits for the whole validation set.
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)


gc.collect()
trainer.train()

save_dir = "./gpt2_full_sarcasm_final"
os.makedirs(save_dir, exist_ok=True)

# The cache was switched off for training only; re-enable it so the saved config
# does not disable key/value caching when the checkpoint is used for generation.
model.config.use_cache = True

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
# Extra raw state-dict copy alongside the Hugging Face checkpoint files.
torch.save(model.state_dict(), f"{save_dir}/gpt2_full_sarcasm_final.pt")

# Full fine-tuned generation

In [None]:
# Checkpoints to compare: stock GPT-2 versus the fine-tuned weights saved above.
base_model_name = "gpt2"
fine_tuned_path = "./gpt2_full_sarcasm_final"

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One tokenizer (loaded from the fine-tuned checkpoint) serves both models.
tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

base_model, fine_tuned_model = (
    AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
    for checkpoint in (base_model_name, fine_tuned_path)
)

def generate_response(model, sentence, max_new_tokens=100):
    """Generate an explanation of why ``sentence`` is sarcastic.

    Args:
        model: Causal language model exposing ``generate``.
        sentence: The sarcastic sentence to explain.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated explanation with the prompt and any leading
        "Explanation:" / "Answer:" / "Response:" label removed.
    """
    # Same template that `preprocess` used for fine-tuning, so the fine-tuned model
    # is queried in the format it was trained on. The previous prompt used different
    # wording and was missing the newline between its first two lines.
    prompt = (
        "Explain why the following sentence sounds sarcastic:\n"
        f"Sentence: {sentence}\n"
        "Explanation:"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    # Pure sampling (num_beams=1); `early_stopping` only applies to beam search
    # and is therefore not passed.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        num_beams=1,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Slicing them off by
    # length keeps the whole continuation even if the model itself writes
    # "Explanation:" again, where splitting the decoded text on that marker
    # would have discarded everything before its last occurrence.
    completion_ids = outputs[0][prompt_length:]
    text = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    # Remove a stray label the model may emit at the very start.
    text = re.sub(r"^(Explanation|Answer|Response)\s*[:\-]\s*", "", text, flags=re.IGNORECASE)
    return text.strip()


# Probe sentences for a qualitative base-vs-fine-tuned comparison.
# Emoji and curly quotes are restored here: they had been stored as mojibake
# (UTF-8 bytes decoded as cp1252), which fed garbage characters to the models.
sentences = [
    "Oh perfect, the fire alarm goes off right when I start my presentation.",
    "I just love when my boss schedules a meeting during lunch.",
    "Oh great, the printer jammed right before the deadline.",
    "Wonderful, traffic is even worse than yesterday!",
    "Yeah, because everyone totally loves working overtime for free.",
    "Crying before I go into work... This is going to be a great night. #Sarcasm #WishItWasTrue",
    "Oh sure, because staying up till 3am totally helps with productivity 😒",
    "hey <user> thanks for making it easy for me to take my music with me . # ihateyourupdates",
    "It could confuse your muscles and make muscle grow in places where you didn't actually work out.",
    "Yay, 2-hour traffic for a 10-minute errand. Exactly what I needed 🙃",
    "This guy gets a gold star for such excellent parking in the handicap lot!",
    "How else will we feel superior if not by our amazing taste in phones?",
    "Guess I’ll just refresh the page for the 20th time. Maybe that’ll fix it 🤡",
    "My phone dying at 5% is the highlight of my day.",
    "parrot's previous owner obviously watched a lot of the price is right",
    "Oh totally, I love when people reply ‘k’ to my long texts.",
    "Gotta save our children from the dangers of text on a screen in a rhythm game",
    "Great, another inspirational quote on LinkedIn. Just what I needed.",
    "even aside from the blatant misogyny, this is great because we have so much space in our prisons!"
]

# Print both models' explanations for every sentence.
for s in sentences:
    base_out = generate_response(base_model, s)
    fine_out = generate_response(fine_tuned_model, s)

    print(f"\n Sentence: {s}")
    print(f"Base GPT-2: {base_out}")
    print(f"Fine-tuned GPT-2: {fine_out}")