In [15]:
!pip install wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33myuchenzoe-xu[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import transformers
import re
import wandb


In [17]:
REWARD_MODEL = "tmrcnl/SarcasmRewardModel"
DATASET_PATH = "marcbishara/sarcasm-on-reddit"
SFT_MODEL = "Zoe3324/gpt2-sft-full-v2"
GPT2_MODEL = "gpt2"
MAX_LENGTH = 128
BATCH_SIZE = 32
SAMPLE_SIZE = 1000
device = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
wandb.init(
    entity="zoe_123",
    project="gst_sarcasm_rm_eval",
    name="sft_vs_gpt2_avg_reward_1",
    config={
        "batch_size": BATCH_SIZE,
        "sample_size": SAMPLE_SIZE,
        "reward_model": REWARD_MODEL,
        "sft_model": SFT_MODEL
    },
    resume=False
)

0,1
batch_avg_reward/GPT2,▁▄▃█
batch_avg_reward/SFT,▄▆▁█
global_step,▁▃▆█▁▃▆█
overall_avg_reward/GPT2,▁
overall_avg_reward/SFT,▁

0,1
batch_avg_reward/GPT2,0.32945
batch_avg_reward/SFT,0.80405
global_step,3.0
overall_avg_reward/GPT2,0.25003
overall_avg_reward/SFT,0.73558


In [19]:
# Load dataset
dataset = load_dataset("marcbishara/sarcasm-on-reddit", split="holdout")
data = (dataset.shuffle(seed=42).select(range(SAMPLE_SIZE)))
parent_comments = data["parent_comment"]
print(f"Loaded {SAMPLE_SIZE} test samples")

Loaded 1000 test samples


In [20]:
# Load RM tokenizer/model
rm_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL)
if rm_tokenizer.pad_token is None:
    rm_tokenizer.pad_token = rm_tokenizer.eos_token
rm_model = AutoModelForSequenceClassification.from_pretrained(REWARD_MODEL).to(device)

# Load GPT2&SFT tokenizers/models
gpt2_tokenizer = AutoTokenizer.from_pretrained(GPT2_MODEL)
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
gpt2_model = AutoModelForCausalLM.from_pretrained(GPT2_MODEL).to(device)
gpt2_model.eval()

sft_tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL)
sft_tokenizer.pad_token = sft_tokenizer.eos_token
sft_model = AutoModelForCausalLM.from_pretrained(SFT_MODEL).to(device)
sft_model.eval();

In [21]:
 # Add tags to prompt
def build_prompt(parent_text: str) -> str:
    return f"<PARENT>{parent_text.strip()}</PARENT>\n<RESPONSE>"

# Remove output tags
def extract_clean_response(full_output: str, prompt: str) -> str:
    # Remove parent comment and parent tag
    full_output = re.sub(r"<PARENT>.*?</PARENT>", "", full_output, flags=re.DOTALL)
    # Fetch text in between response tag
    m = re.search(r"<RESPONSE>(.*?)</RESPONSE>", full_output, flags=re.DOTALL)
    if m:
        return m.group(1).strip()
    # fallback for output without </RESPONSE>
    if full_output.startswith(prompt):
        return full_output[len(prompt):].strip()

    # fallback for plain text
    return full_output.strip()

In [None]:
# Generate a model response given a parent comment, return cleaned reponses
def generate_responses(model, tokenizer, parent_comments):
    responses = []
    # Loop through each parent comment
    for text in tqdm(parent_comments, desc="Generating", unit="sample"):
        prompt = build_prompt(text)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=80,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        full_output = tokenizer.decode(output[0], skip_special_tokens=True)
        clean_output = extract_clean_response(full_output, prompt)
        responses.append(clean_output)
    return responses

# Compute average rewards for responses using reward model
def calculate_avg_reward(prompts, responses, rm_tokenizer, rm_model, device, model_label):
    all_scores = []         # all individual reward scores
    batch_avg_rewards = []  # per-batch average reward scores

    for local_step, i in enumerate(range(0, len(prompts), BATCH_SIZE)):
        batch_prompts = prompts[i:i + BATCH_SIZE]
        batch_responses = responses[i:i + BATCH_SIZE]
        
        # Tokenize (prompt, response) pairs for RM
        rm_inputs = rm_tokenizer(
            batch_prompts,
            batch_responses,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH
        ).to(device)

        with torch.no_grad():
            rm_outputs = rm_model(**rm_inputs)

        # Get sarcasm score(probability) for label = 1(sarcasm)
        sarcasm_scores = torch.softmax(rm_outputs.logits, dim=-1)[:, 1].cpu().tolist()
        # Compute batch average
        batch_avg = sum(sarcasm_scores) / len(sarcasm_scores)
        batch_avg_rewards.append(batch_avg)
        all_scores.extend(sarcasm_scores)
        wandb.log({
            f"batch_avg_reward/{model_label}": batch_avg,
            "global_step": local_step
        })
    # Compute overall average score
    overall_avg = sum(all_scores) / len(all_scores)
    wandb.log({f"overall_avg_reward/{model_label}": overall_avg})
    return overall_avg, batch_avg_rewards

In [23]:
print("\nEvaluating SFT model")
sft_outputs = generate_responses(sft_model, sft_tokenizer, parent_comments)
sft_avg, sft_batch_rewards = calculate_avg_reward(
    parent_comments, sft_outputs,
    rm_tokenizer, rm_model, device,
    model_label="SFT"
)

print("\nEvaluating GPT-2 model")
gpt2_outputs = generate_responses(gpt2_model, gpt2_tokenizer, parent_comments)
gpt2_avg, gpt2_batch_rewards = calculate_avg_reward(
    parent_comments, gpt2_outputs,
    rm_tokenizer, rm_model, device,
    model_label="GPT2"
)

print(f"SFT model avg reward:   {sft_avg:.4f}")
print(f"GPT-2 model avg reward: {gpt2_avg:.4f}")


Evaluating SFT model


Generating: 100%|██████████| 1000/1000 [04:30<00:00,  3.70sample/s]



Evaluating GPT-2 model


Generating: 100%|██████████| 1000/1000 [14:11<00:00,  1.17sample/s]


SFT model avg reward:   0.7290
GPT-2 model avg reward: 0.2214
