In [None]:
import numpy as np
import re
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
from datasets import concatenate_datasets
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer as Trainer
from transformers import Seq2SeqTrainingArguments as TrainingArguments
import torch
import evaluate
import gc
import psutil
import json
import os
from sklearn.model_selection import train_test_split
from copy import deepcopy
from torch.utils.data import ConcatDataset

In [None]:
# Load the sarcasm dataset; missing cells become "" so later string handling never sees NaN.
data = pd.read_csv("sarcasm_KD_final.csv")
data = data.fillna("")

# Hold out 20% of rows for validation, with a fixed seed so the split is reproducible.
train_df, val_df = train_test_split(
    data,
    test_size=0.2,
    random_state=4213,
)

print(train_df.head())

In [None]:
# Instruction prefixes prepended to every model input ("<prompt>Text: <text>").
# Each one ends with a blank line so the "Text: ..." segment starts on its own line.

# Task A: point at the words/phrases that carry the sarcasm.
PROMPT_A = "".join((
    "In exactly 1-2 sentences, identify the specific words or phrases that make the text sarcastic ",
    "and explain how they create the sarcastic effect. ",
    "Focus only on observable linguistic elements without adding interpretation beyond what's directly evident in the text.\n\n",
))

# Task B: restate what the speaker really means, without the sarcasm.
PROMPT_B = "".join((
    "In exactly 1-2 sentences, explain what the speaker actually means by removing the sarcasm ",
    "and stating their true intended message directly. ",
    "Focus on the genuine sentiment or opinion being expressed beneath the sarcastic language.\n\n",
))

def build_target(row, column_name):
    """Build the training label for one record.

    Args:
        row: a mapping-like record (e.g. a DataFrame row) indexed by column name.
        column_name: the field of ``row`` that holds the reference text.

    Returns:
        ``"Explanation: "`` followed by the field's text, converted with ``str`` and
        stripped of surrounding whitespace.
    """
    explanation = str(row[column_name])
    return "Explanation: " + explanation.strip()

def _task_dataset(frame, target_col):
    """Return a HF Dataset with exactly the columns ["text", "target_text"] for one task.

    Args:
        frame: a split DataFrame that has a "text" column and ``target_col``.
        target_col: the per-task label column to expose as "target_text".
    """
    subset = frame[["text", target_col]].rename(columns={target_col: "target_text"})
    # train_test_split leaves a shuffled, non-range index; preserve_index=False keeps it
    # from being stored as an extra "__index_level_0__" column in the Dataset.
    return Dataset.from_pandas(subset, preserve_index=False)


# Task A — sarcasm cue identification
train_df["target_text_A"] = train_df.apply(lambda r: build_target(r, "part_sarcastic"), axis=1)
val_df["target_text_A"]   = val_df.apply(lambda r: build_target(r, "part_sarcastic"), axis=1)

# Task B — true intent explanation
train_df["target_text_B"] = train_df.apply(lambda r: build_target(r, "sarcasm_explanation"), axis=1)
val_df["target_text_B"]   = val_df.apply(lambda r: build_target(r, "sarcasm_explanation"), axis=1)

taskA_train_ds = _task_dataset(train_df, "target_text_A")
taskA_val_ds   = _task_dataset(val_df, "target_text_A")

taskB_train_ds = _task_dataset(train_df, "target_text_B")
taskB_val_ds   = _task_dataset(val_df, "target_text_B")

model_name = "./flan_t5_full_sarcasm_final"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token budgets: encoder input (prompt + text) and decoder labels; longer sequences are truncated.
MAX_SRC_LEN = 128
MAX_TGT_LEN = 64

def make_preprocess_fn(prompt, tokenizer, max_src_len=None, max_tgt_len=None):
    """Return a batched ``Dataset.map`` function that tokenizes one task.

    Every input becomes ``prompt + "Text: " + text``, and ``target_text`` is tokenized
    as the decoder labels.

    Args:
        prompt: instruction prefix (e.g. PROMPT_A or PROMPT_B).
        tokenizer: a Hugging Face tokenizer.
        max_src_len: truncation length for inputs; None means use MAX_SRC_LEN.
        max_tgt_len: truncation length for labels; None means use MAX_TGT_LEN.

    Returns:
        ``preprocess(examples)``, which returns the tokenizer's input fields plus "labels"
        (the label token ids).
    """
    def preprocess(examples):
        # Module-level defaults are resolved at call time, as in the original global lookup.
        src_len = MAX_SRC_LEN if max_src_len is None else max_src_len
        tgt_len = MAX_TGT_LEN if max_tgt_len is None else max_tgt_len

        inputs = [prompt + "Text: " + t for t in examples["text"]]
        model_inputs = tokenizer(inputs, max_length=src_len, truncation=True)
        # text_target= routes labels through the tokenizer's target-side processing. It needs
        # its own call so the labels get their own (shorter) max_length.
        labels = tokenizer(text_target=examples["target_text"], max_length=tgt_len, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    return preprocess

# One preprocessing function per task, each with its own instruction prefix baked in.
preprocess_A = make_preprocess_fn(PROMPT_A, tokenizer)
preprocess_B = make_preprocess_fn(PROMPT_B, tokenizer)


def _tokenize_split(raw_ds, preprocess_fn):
    """Tokenize ``raw_ds`` in batches and drop all of its original columns."""
    return raw_ds.map(preprocess_fn, batched=True, remove_columns=raw_ds.column_names)


# Task A (cue identification)
taskA_train_tok = _tokenize_split(taskA_train_ds, preprocess_A)
taskA_val_tok = _tokenize_split(taskA_val_ds, preprocess_A)

# Task B (true-intent explanation)
taskB_train_tok = _tokenize_split(taskB_train_ds, preprocess_B)
taskB_val_tok = _tokenize_split(taskB_val_ds, preprocess_B)

In [None]:
# Sequentially train the generalist: fine-tune on Task A, then continue on Task B

import gc,psutil
# ============================================================
# 🧹 Memory Cleanup Helper
# ============================================================
def clear_memory(tag=""):
    """Run Python GC, release cached CUDA memory, and print current memory usage.

    Args:
        tag: free-form label included in the log line (e.g. "(after Task A)").
    """
    print(f"\n🧹 Clearing memory {tag} ...")
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued kernels first, so the blocks they still use can be returned by empty_cache.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        mem_alloc = torch.cuda.memory_allocated() / 1024**2
        mem_reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"   CUDA memory allocated: {mem_alloc:.2f} MB | reserved: {mem_reserved:.2f} MB")
    process = psutil.Process(os.getpid())
    print(f"   CPU RSS: {process.memory_info().rss / 1024**2:.2f} MB\n")


def _read_tuned_params(path):
    """Return the hyperparameter dict stored as JSON at ``path``."""
    with open(path, "r") as fh:
        return json.load(fh)


# Hyperparameter keys, in the order they are unpacked below.
_TUNED_KEYS = ("learning_rate", "batch_size", "dropout_rate", "weight_decay", "warmup_ratio")

# ---- Task A (cue identification) ----
params_path_A = "./optuna_results/best_t5_taskA_params.json"
bestA = _read_tuned_params(params_path_A)
learning_rate_A, batch_size_A, dropout_rate_A, weight_decay_A, warmup_ratio_A = (
    bestA[key] for key in _TUNED_KEYS
)
print("Loaded Task A tuned params:")
print(json.dumps(bestA, indent=4))

# ---- Task B (true-intent explanation) ----
params_path_B = "./optuna_results/best_t5_taskB_params.json"
bestB = _read_tuned_params(params_path_B)
learning_rate_B, batch_size_B, dropout_rate_B, weight_decay_B, warmup_ratio_B = (
    bestB[key] for key in _TUNED_KEYS
)
print("Loaded Task B tuned params:")
print(json.dumps(bestB, indent=4))

# ROUGE scorer used by compute_metrics, plus the same base checkpoint/tokenizer as the preprocessing cell.
rouge = evaluate.load("rouge")
model_name = "./flan_t5_full_sarcasm_final"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def compute_metrics(eval_preds):
    """Score generated text against the references with ROUGE-L.

    Args:
        eval_preds: a (predictions, label_ids) pair from the Trainer. Predictions may be a
            tuple whose first element holds the token ids.

    Returns:
        {"rougeL": score rounded to 4 decimals}.
    """
    raw_preds, raw_labels = eval_preds
    if isinstance(raw_preds, tuple):
        raw_preds = raw_preds[0]

    # -100 marks ignored positions; swap it for the pad id so batch_decode can handle it.
    pad_id = tokenizer.pad_token_id
    clean_preds = np.where(raw_preds == -100, pad_id, raw_preds)
    clean_labels = np.where(raw_labels == -100, pad_id, raw_labels)

    decoded_preds = tokenizer.batch_decode(clean_preds, skip_special_tokens=True)
    decoded_refs = tokenizer.batch_decode(clean_labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(predictions=decoded_preds, references=decoded_refs)
    return {"rougeL": round(rouge_scores["rougeL"], 4)}


# ---- Sequential generalist, step 1: fine-tune the base model on Task A (cue identification) ----
# The tuned dropout is passed to from_pretrained. T5 builds its dropout layers from
# config.dropout_rate when the model is constructed, so assigning model.config.dropout_rate
# after loading only changes the saved config, not training. T5Config has no separate
# attention_dropout_rate; its attention dropout also follows dropout_rate.
modelA = AutoModelForSeq2SeqLM.from_pretrained(model_name, dropout_rate=dropout_rate_A)

collatorA = DataCollatorForSeq2Seq(tokenizer, model=modelA)

argsA = Seq2SeqTrainingArguments(
    # Run-specific checkpoint directory. Without one, older transformers releases reject the
    # call and newer ones fall back to a single shared default directory, so the epoch
    # checkpoints of the trainers in this notebook would overwrite each other.
    output_dir="./checkpoints/seq_taskA",
    learning_rate=learning_rate_A,
    per_device_train_batch_size=batch_size_A,
    per_device_eval_batch_size=max(2, batch_size_A // 2),  # prevent OOM during generate-based eval
    num_train_epochs=3,
    weight_decay=weight_decay_A,
    warmup_ratio=warmup_ratio_A,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    predict_with_generate=True,  # compute_metrics decodes generated ids, not logits
    report_to="none",
)

trainerA = Seq2SeqTrainer(
    model=modelA,
    args=argsA,
    train_dataset=taskA_train_tok,
    eval_dataset=taskA_val_tok,
    data_collator=collatorA,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print(" Starting fine-tuning on Task A...")
trainerA.train()
trainerA.save_model("./model_final_taskA")

# Free memory before Task B reloads the saved checkpoint.
del modelA, trainerA, collatorA, argsA
clear_memory("(after Task A)")


# ---- Sequential generalist, step 2: continue from the Task A checkpoint on Task B ----
# Task B's tuned dropout is passed at load time. T5 constructs its dropout layers from
# config.dropout_rate, so editing model.config after loading would leave the Task A value
# in effect. T5Config has no separate attention_dropout_rate.
modelB = AutoModelForSeq2SeqLM.from_pretrained("./model_final_taskA", dropout_rate=dropout_rate_B)

collatorB = DataCollatorForSeq2Seq(tokenizer, model=modelB)

argsB = Seq2SeqTrainingArguments(
    # Own checkpoint directory, so this run never overwrites the Task A (or other) checkpoints.
    output_dir="./checkpoints/seq_taskB",
    learning_rate=learning_rate_B,
    per_device_train_batch_size=batch_size_B,
    per_device_eval_batch_size=max(2, batch_size_B // 2),  # prevent OOM during generate-based eval
    num_train_epochs=3,
    weight_decay=weight_decay_B,
    warmup_ratio=warmup_ratio_B,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    predict_with_generate=True,  # compute_metrics decodes generated ids, not logits
    report_to="none",
)

trainerB = Seq2SeqTrainer(
    model=modelB,
    args=argsB,
    train_dataset=taskB_train_tok,
    eval_dataset=taskB_val_tok,
    data_collator=collatorB,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print("Continuing fine-tuning on Task B...")
trainerB.train()
trainerB.save_model("./model_final_dualtask")

del modelB, trainerB, collatorB, argsB
clear_memory("(after Task B)")

print("Training complete. Sequential Generalist Model saved at ./model_final_dualtask")

In [None]:
# Stage 1 — Joint Learner + Generator Training on Task A, then Task B training with Task A replay
from copy import deepcopy
from torch.utils.data import ConcatDataset

# Re-read the tuned hyperparameters so this cell does not rely on the sequential cell's globals.
params_path_A = "./optuna_results/best_t5_taskA_params.json"
with open(params_path_A, "r") as fp:
    bestA = json.load(fp)
learning_rate_A, batch_size_A, dropout_rate_A, weight_decay_A, warmup_ratio_A = [
    bestA[key]
    for key in ("learning_rate", "batch_size", "dropout_rate", "weight_decay", "warmup_ratio")
]
print("Loaded Task A tuned params:")
print(json.dumps(bestA, indent=4))

params_path_B = "./optuna_results/best_t5_taskB_params.json"
with open(params_path_B, "r") as fp:
    bestB = json.load(fp)
learning_rate_B, batch_size_B, dropout_rate_B, weight_decay_B, warmup_ratio_B = [
    bestB[key]
    for key in ("learning_rate", "batch_size", "dropout_rate", "weight_decay", "warmup_ratio")
]
print("Loaded Task B tuned params:")
print(json.dumps(bestB, indent=4))

# ROUGE scorer and base checkpoint/tokenizer (same values as the sequential cell).
rouge = evaluate.load("rouge")
model_name = "./flan_t5_full_sarcasm_final"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def compute_metrics(eval_preds):
    """ROUGE-L between generated sequences and reference labels, rounded to 4 decimals.

    Args:
        eval_preds: a (predictions, label_ids) pair; predictions may be a tuple whose first
            element holds the token ids.
    """
    predictions, label_ids = eval_preds
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Replace the -100 "ignore" marker with the pad id before decoding each array.
    pred_texts, label_texts = (
        tokenizer.batch_decode(
            np.where(ids != -100, ids, tokenizer.pad_token_id),
            skip_special_tokens=True,
        )
        for ids in (predictions, label_ids)
    )

    score = rouge.compute(predictions=pred_texts, references=label_texts)
    return {"rougeL": round(score["rougeL"], 4)}

# ---- LAMOL-style run, step 1: train a learner and a generator on Task A ----
# Learner. dropout_rate is passed at load time because T5 constructs its dropout layers from
# the config; assigning model.config.dropout_rate afterwards would not affect training.
modelA = AutoModelForSeq2SeqLM.from_pretrained(model_name, dropout_rate=dropout_rate_A)

# Generator: an independent copy of the (still untrained) learner. Its own Trainer gives it a
# separate optimizer.
generatorA = deepcopy(modelA)

collatorA = DataCollatorForSeq2Seq(tokenizer, model=modelA)

# Hyperparameters shared by the learner and generator runs. Only output_dir differs, so the two
# runs don't write their epoch checkpoints into the same directory (a deepcopy of one
# TrainingArguments object would share it).
shared_args_A = dict(
    learning_rate=learning_rate_A,
    per_device_train_batch_size=batch_size_A,
    per_device_eval_batch_size=max(2, batch_size_A // 2),
    num_train_epochs=3,
    weight_decay=weight_decay_A,
    warmup_ratio=warmup_ratio_A,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    predict_with_generate=True,
    report_to="none",
)
argsA = Seq2SeqTrainingArguments(output_dir="./checkpoints/lamol_taskA_learner", **shared_args_A)
argsG = Seq2SeqTrainingArguments(output_dir="./checkpoints/lamol_taskA_generator", **shared_args_A)

# Trainer for the learner
trainerA = Seq2SeqTrainer(
    model=modelA,
    args=argsA,
    train_dataset=taskA_train_tok,
    eval_dataset=taskA_val_tok,
    data_collator=collatorA,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Trainer for the generator.
# NOTE(review): the generator is trained on exactly the same supervised Task A
# (prompt -> explanation) pairs as the learner. There is no reconstruction or
# input-generation objective, so it ends up close to a second copy of the learner. Below it
# is only used to pseudo-label Task A inputs for replay.
trainerG = Seq2SeqTrainer(
    model=generatorA,
    args=argsG,
    train_dataset=taskA_train_tok,
    eval_dataset=taskA_val_tok,
    data_collator=collatorA,
    tokenizer=tokenizer,
)

print("Joint training learner + generator on Task A...")
trainerA.train()
trainerG.train()

# Save both models
trainerA.save_model("./model_final_taskA_learner")
trainerG.save_model("./model_final_taskA_generator")

# Both models are reloaded from disk below, so drop the in-memory copies (and the objects
# that still reference them) first.
del trainerA, trainerG, modelA, generatorA, collatorA, argsA, argsG
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# The learner continues on Task B, so it gets Task B's tuned dropout (as in the sequential run).
learnerB = AutoModelForSeq2SeqLM.from_pretrained("./model_final_taskA_learner", dropout_rate=dropout_rate_B)
# The generator is only used for inference (pseudo-labelling replay samples).
generatorB = AutoModelForSeq2SeqLM.from_pretrained("./model_final_taskA_generator")


def generate_taskA_examples(generator, tokenizer, raw_dataset, num_samples=200,
                            prompt=None, device=None, batch_size=16,
                            max_src_len=None, max_new_tokens=64):
    """Create pseudo-labelled Task A replay samples with ``generator``.

    Takes the first ``num_samples`` rows of ``raw_dataset`` and wraps each text in the Task A
    prompt, using the same "<prompt>Text: <text>" format as make_preprocess_fn. The generator
    then writes the label for each.

    Args:
        generator: seq2seq model used to produce the pseudo-labels.
        tokenizer: tokenizer matching ``generator``.
        raw_dataset: untokenized dataset with a "text" column (needs ``len`` and ``select``).
        num_samples: maximum number of rows to use; fewer if the dataset is smaller.
        prompt: instruction prefix; None means PROMPT_A, the prompt Task A was trained with.
        device: where to run generation; None means CUDA when available, else CPU.
        batch_size: prompts per ``generate`` call, which bounds peak memory.
        max_src_len: input truncation length; None means MAX_SRC_LEN, as in training.
        max_new_tokens: generation budget per sample.

    Returns:
        A list of {"text": raw input text, "target_text": generated label}. "text" is the raw
        text rather than the full prompt, because preprocess_A adds the prompt itself.
    """
    if prompt is None:
        prompt = PROMPT_A
    if max_src_len is None:
        max_src_len = MAX_SRC_LEN
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = max(1, batch_size)

    generator = generator.to(device)
    generator.eval()

    # Select subset (first rows; the split was already shuffled by train_test_split)
    n_rows = min(num_samples, len(raw_dataset))
    texts = list(raw_dataset.select(range(n_rows))["text"])

    replay_data = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        prompts = [prompt + "Text: " + t for t in batch_texts]
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_src_len,
        ).to(device)
        with torch.no_grad():
            outputs = generator.generate(**inputs, max_new_tokens=max_new_tokens)
        gens = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        replay_data.extend({"text": t, "target_text": g} for t, g in zip(batch_texts, gens))

    return replay_data

# Pseudo-label Task A *training* inputs for replay. Drawing from taskA_val_ds would turn Task A's
# validation inputs into training data for the final model.
print("Generating replay samples from Task A...")
replayA = generate_taskA_examples(generatorB, tokenizer, taskA_train_ds, num_samples=200)

# The generator is not used after this point.
del generatorB
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

replayA_ds = Dataset.from_list(replayA)
replayA_tok = replayA_ds.map(
    preprocess_A,
    batched=True,
    remove_columns=replayA_ds.column_names
)

combined_train = concatenate_datasets([taskB_train_tok, replayA_tok])
# Validate on the real Task B and Task A validation sets. The replay samples are the model's own
# generations and are already in combined_train, so scoring on them would measure memorisation
# rather than how well Task A is retained.
combined_val = concatenate_datasets([taskB_val_tok, taskA_val_tok])

print(f"Combined training size: {len(combined_train)} samples")


# ---- LAMOL-style run, step 2: train the learner on Task B plus Task A replay ----
collatorB = DataCollatorForSeq2Seq(tokenizer, model=learnerB)

argsB = Seq2SeqTrainingArguments(
    # Own checkpoint directory, so this run does not overwrite the other trainers' checkpoints
    # (without output_dir, newer transformers versions share one default directory).
    output_dir="./checkpoints/lamol_taskB",
    learning_rate=learning_rate_B,
    per_device_train_batch_size=batch_size_B,
    per_device_eval_batch_size=max(2, batch_size_B // 2),
    num_train_epochs=3,
    weight_decay=weight_decay_B,
    warmup_ratio=warmup_ratio_B,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    predict_with_generate=True,
    report_to="none",
)

trainerB = Seq2SeqTrainer(
    model=learnerB,
    args=argsB,
    train_dataset=combined_train,
    eval_dataset=combined_val,
    data_collator=collatorB,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainerB.train()
trainerB.save_model("./model_final_lamol")


In [None]:
# Generation for Task A (cue task): the Task A specialist vs. both generalists.

# Checkpoint directories written by the training cells.
specilist_cue = "./model_final_taskA"
generalist_seq = "./model_final_dualtask"
generalist_lamol = "./model_final_lamol"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All three models descend from the same base checkpoint, so one tokenizer serves them all.
tokenizer = AutoTokenizer.from_pretrained(specilist_cue)

specialist_model, generalist_seq_model, generalist_lamol_model = (
    AutoModelForSeq2SeqLM.from_pretrained(path).to(device)
    for path in (specilist_cue, generalist_seq, generalist_lamol)
)


def generate_response(model, sentence, **generate_overrides):
    """Generate a Task A (sarcasm-cue) explanation for ``sentence``.

    The input is built exactly like the training inputs (PROMPT_A + "Text: " + sentence).
    Any other wrapper, such as a quoted "Sentence:" field with no blank line after the
    instruction, gives the model an input format it never saw during fine-tuning.

    Args:
        model: seq2seq model already placed on ``device``.
        sentence: raw text to explain.
        **generate_overrides: keyword arguments that replace the default decoding settings.

    Returns:
        The decoded text with a leading "Explanation:", "Answer:" or "Response:" label removed.
    """
    prompt = PROMPT_A + "Text: " + sentence

    gen_kwargs = dict(
        max_new_tokens=80,
        do_sample=True,
        # NOTE(review): temperature 2.3 with sampling inside a 10-beam search is very noisy;
        # consider passing do_sample=False for a deterministic model comparison.
        temperature=2.3,
        top_p=0.6,
        top_k=60,
        num_beams=10,
        no_repeat_ngram_size=3,
        repetition_penalty=1.4,
        length_penalty=1.0,
    )
    gen_kwargs.update(generate_overrides)

    # Truncate to the same input budget used during training.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SRC_LEN).to(device)
    outputs = model.generate(**inputs, **gen_kwargs)

    text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Training labels start with "Explanation: "; strip such a label from the output.
    text = re.sub(r"^(Explanation|Answer|Response)\s*:\s*", "", text, flags=re.IGNORECASE)
    return text


# Probe sentences for a qualitative side-by-side comparison.
sentences = [
    "Yay my shoe broke!",
    "If the shooter shouldn't have been able to get a gun, the solution is obviously more guns, right? <URL>",
    "baby feels foolish after realizing stranger waving at toddler next seat over",
    "Wonderful, traffic is even worse than yesterday!",
    "I'm so sorry I can't read sarcasm over the internet",
    "Crying before I go into work... This is going to be a great night. #Sarcasm #WishItWasTrue",
    "In a cab on the way home from the airport. What a long day. Work tomorrow is going to be AMAZING. Should be home by 3AM.",
    "hey <user> thanks for making it easy for me to take my music with me . # ihateyourupdates",
    "It could confuse your muscles and make muscle grow in places where you didn't actually work out.",
    "Yay, 2-hour traffic for a 10-minute errand. Exactly what I needed 🙃",
    "This guy gets a gold star for such excellent parking in the handicap lot!",
    "How else will we feel superior if not by our amazing taste in phones?",
    "Received a compliment today that I look very relaxed. If only this person knew just how much effort it takes to look this relaxed.",
    "My phone dying at 5% is the highlight of my day.",
    "overweight man repeatedly introduced to overweight woman at party",
    "How dare you type out Obama's name and not praise him you racist",
    "No, it's perfectly safe for nurses to shove pills in your mouth without any education",
    "Great, another inspirational quote on LinkedIn. Just what I needed.",
    "even aside from the blatant misogyny, this is great because we have so much space in our prisons!",
    "DSA4213 is sooo easy even my grandma can score A+",  # comma needed: otherwise merged with the next string
    "Trains are delayed on both directions. Instead of seeing people rushing to take the bus or cab, they were taking pictures. Haha..",
    "I have never felt more alive than during DSA4213 finetuning, nothing like a few CUDA OOMs to keep the adrenaline going.",
    "I love how DSA4213 keeps me humble, every single assignment and quizzes",
    "DSA4213 is not that hard, I just needed 4 GPUs and divine intervention"
]

# Decoding samples randomly. Re-seeding before every call makes each output reproducible and
# gives all three models the same starting RNG state.
GEN_SEED = 4213

for s in sentences:
    torch.manual_seed(GEN_SEED)
    s_out = generate_response(specialist_model, s)
    torch.manual_seed(GEN_SEED)
    gseq_out = generate_response(generalist_seq_model, s)
    torch.manual_seed(GEN_SEED)
    gl_out = generate_response(generalist_lamol_model, s)

    print(f"\n Sentence: {s}")
    print(f"Cue task Specialist: {s_out}")
    print(f"Generalist Seq: {gseq_out}")
    print(f"Generalist LAMOL: {gl_out}")

In [None]:
# Generation for Task B (Explain task): the Task B specialist vs. both generalists.

# NOTE(review): ./model_final_taskB is not saved by any training cell shown here; it must come
# from a separate Task B specialist run. Confirm it exists before running this cell.
specilist_explain = "./model_final_taskB"
generalist_seq = "./model_final_dualtask"
generalist_lamol = "./model_final_lamol"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer is loaded from a path defined in this cell, so the cell does not depend on
# variables from the Task A generation cell. The sequential-generalist checkpoint was saved by
# a Trainer that was given the tokenizer, and every model here shares the same base tokenizer.
tokenizer = AutoTokenizer.from_pretrained(generalist_seq)

specialist_model = AutoModelForSeq2SeqLM.from_pretrained(specilist_explain).to(device)
generalist_seq_model = AutoModelForSeq2SeqLM.from_pretrained(generalist_seq).to(device)
generalist_lamol_model = AutoModelForSeq2SeqLM.from_pretrained(generalist_lamol).to(device)



def generate_response(model, sentence, **generate_overrides):
    """Generate a Task B (true-intent) explanation for ``sentence``.

    The input is built exactly like the training inputs (PROMPT_B + "Text: " + sentence).
    A quoted "Sentence:" wrapper would give the model an input format it never saw during
    fine-tuning.

    Args:
        model: seq2seq model already placed on ``device``.
        sentence: raw text to explain.
        **generate_overrides: keyword arguments that replace the default decoding settings.

    Returns:
        The decoded text with a leading "Explanation:", "Answer:" or "Response:" label removed.
    """
    prompt = PROMPT_B + "Text: " + sentence

    gen_kwargs = dict(
        max_new_tokens=80,
        do_sample=True,
        # NOTE(review): temperature 2.3 with sampling inside a 10-beam search is very noisy;
        # consider passing do_sample=False for a deterministic model comparison.
        temperature=2.3,
        top_p=0.6,
        top_k=60,
        num_beams=10,
        no_repeat_ngram_size=3,
        repetition_penalty=1.4,
        length_penalty=1.0,
    )
    gen_kwargs.update(generate_overrides)

    # Truncate to the same input budget used during training.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SRC_LEN).to(device)
    outputs = model.generate(**inputs, **gen_kwargs)

    text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Training labels start with "Explanation: "; strip such a label from the output.
    text = re.sub(r"^(Explanation|Answer|Response)\s*:\s*", "", text, flags=re.IGNORECASE)
    return text


# Probe sentences for a qualitative side-by-side comparison.
sentences = [
    "Yay my shoe broke!",
    "If the shooter shouldn't have been able to get a gun, the solution is obviously more guns, right? <URL>",
    "baby feels foolish after realizing stranger waving at toddler next seat over",
    "Wonderful, traffic is even worse than yesterday!",
    "I'm so sorry I can't read sarcasm over the internet",
    "Crying before I go into work... This is going to be a great night. #Sarcasm #WishItWasTrue",
    "In a cab on the way home from the airport. What a long day. Work tomorrow is going to be AMAZING. Should be home by 3AM.",
    "hey <user> thanks for making it easy for me to take my music with me . # ihateyourupdates",
    "It could confuse your muscles and make muscle grow in places where you didn't actually work out.",
    "Yay, 2-hour traffic for a 10-minute errand. Exactly what I needed 🙃",
    "This guy gets a gold star for such excellent parking in the handicap lot!",
    "How else will we feel superior if not by our amazing taste in phones?",
    "Received a compliment today that I look very relaxed. If only this person knew just how much effort it takes to look this relaxed.",
    "My phone dying at 5% is the highlight of my day.",
    "overweight man repeatedly introduced to overweight woman at party",
    "How dare you type out Obama's name and not praise him you racist",
    "No, it's perfectly safe for nurses to shove pills in your mouth without any education",
    "Great, another inspirational quote on LinkedIn. Just what I needed.",
    "even aside from the blatant misogyny, this is great because we have so much space in our prisons!",
    "DSA4213 is sooo easy even my grandma can score A+",  # comma needed: otherwise merged with the next string
    "Trains are delayed on both directions. Instead of seeing people rushing to take the bus or cab, they were taking pictures. Haha..",
    "I have never felt more alive than during DSA4213 finetuning, nothing like a few CUDA OOMs to keep the adrenaline going.",
    "I love how DSA4213 keeps me humble, every single assignment and quizzes",
    "DSA4213 is not that hard, I just needed 4 GPUs and divine intervention"
]

# Decoding samples randomly. Re-seeding before every call makes each output reproducible and
# gives all three models the same starting RNG state.
GEN_SEED = 4213

for s in sentences:
    torch.manual_seed(GEN_SEED)
    s_out = generate_response(specialist_model, s)
    torch.manual_seed(GEN_SEED)
    gseq_out = generate_response(generalist_seq_model, s)
    torch.manual_seed(GEN_SEED)
    gl_out = generate_response(generalist_lamol_model, s)

    print(f"\n Sentence: {s}")
    print(f"Explain task Specialist: {s_out}")
    print(f"Generalist Seq: {gseq_out}")
    print(f"Generalist LAMOL: {gl_out}")