In [1]:
import re
from datasets import load_dataset

# Load only the training split. Passing split="train" asks for that split
# directly instead of building the full split dictionary and indexing into it.
ds = load_dataset("Anthropic/hh-rlhf", split="train")

def extract_assistant_only(text):
    """Return only the assistant's side of a "Human:" / "Assistant:" transcript.

    The transcript is cut immediately before every "Human:" or "Assistant:"
    marker (zero-width lookahead, so each marker stays at the start of its
    piece). Pieces that begin with "Assistant:" have the label removed and
    surrounding whitespace stripped; pieces that end up empty are dropped.
    The remaining replies are joined with a blank line between them.
    """
    label = "Assistant:"
    pieces = re.split(r'(?=Human:|Assistant:)', text)

    # Strip the label from assistant pieces first, then discard the blank ones.
    stripped = (piece.replace(label, "").strip() for piece in pieces if piece.startswith(label))
    replies = [reply for reply in stripped if reply]

    return "\n\n".join(replies).strip()

# One record per dataset row, holding the assistant-only text of its "chosen" conversation
cleaned_data = [
    {"assistant_response": extract_assistant_only(example["chosen"])}
    for example in ds
]

In [2]:
import re
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import textstat

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the GPT-2 tokenizer and causal LM, then move the model to the chosen device
MODEL_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)

def compute_perplexity(text: str) -> float:
    """Return the perplexity of `text` under the module-level `model`.

    The text is tokenized with the module-level `tokenizer` (truncated to at
    most 1024 tokens) and moved to `device`. The result is the exponential of
    the loss the model reports when the input ids are also passed as labels.

    Returns:
        The perplexity as a Python float, or NaN when tokenization yields no
        tokens.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
    token_ids = encoded["input_ids"]

    # Nothing to score: report NaN instead of calling the model on an empty sequence.
    if token_ids.size(-1) == 0:
        return float("nan")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output = model(**encoded, labels=token_ids)

    return torch.exp(output.loss).item()

def lexical_diversity(text: str) -> float:
    """Return the ratio of distinct words to total words in `text`.

    Words are the `\\w+` runs of the lower-cased text, so the measure is
    case-insensitive. Text without any word yields 0.0.
    """
    words = re.findall(r"\w+", text.lower())
    if not words:
        return 0.0
    return len(set(words)) / len(words)

def evaluate_cleaned_data(data: list[dict], output_csv: str = "cleaned_response_metrics.csv"):
    """Score each assistant response and write the metrics to a CSV file.

    Args:
        data: Records that each carry an "assistant_response" entry.
        output_csv: Path of the CSV file to write.

    Returns:
        A DataFrame with the columns "readability", "perplexity" and
        "lex_diversity", one row per scored response. Missing or blank
        responses are skipped, so row positions do not necessarily line up
        with positions in `data`.
    """
    metric_columns = ["readability", "perplexity", "lex_diversity"]
    results = []
    for row in data:
        raw = row["assistant_response"]
        # Missing values (None / float NaN) would otherwise be stringified to
        # "None" / "nan" and scored as if they were real responses.
        if raw is None or (isinstance(raw, float) and raw != raw):
            continue

        resp = str(raw).strip()
        if not resp:
            continue  # Skip empty responses

        results.append({
            "readability":   textstat.flesch_reading_ease(resp),
            "perplexity":    compute_perplexity(resp),
            "lex_diversity": lexical_diversity(resp),
        })

    # Passing the column list keeps the schema (and the CSV header) stable even
    # when nothing was scored, so the file can always be read back.
    df = pd.DataFrame(results, columns=metric_columns)
    df.to_csv(output_csv, index=False)
    print(f"✔ Metrics saved to: {output_csv}")
    return df

# Score only the first 1000 cleaned responses (a slice of `cleaned_data`);
# the returned DataFrame is kept in `metrics_df`.
metrics_df = evaluate_cleaned_data(cleaned_data[:1000])


KeyboardInterrupt



In [4]:
# Read gold_standard.csv (relative path) into a DataFrame named `data`.
data = pd.read_csv("gold_standard.csv")

In [6]:
import os
import pandas as pd

# Set your input/output directories (relative paths):
# metric CSVs are read from input_dir, labeled copies are written to output_dir.
input_dir = "metric_results/neutral"
output_dir = "labeled_results/neutral"
# Create the output folder up front; exist_ok=True makes re-running this cell safe.
os.makedirs(output_dir, exist_ok=True)

# Define a function to label the metrics
def label_metrics(df, read_q=0.5, lex_q=0.25, perp_q=0.75):
    """Attach 'good'/'bad' labels to the three metric columns of `df`.

    Rows missing any of 'readability', 'lex_diversity' or 'perplexity' are
    dropped. Thresholds are quantiles of the remaining rows, so labels are
    relative to the frame that is passed in.

    Args:
        df: Frame containing the three metric columns. It is not modified.
        read_q: Quantile of 'readability' used as threshold (default: median).
        lex_q: Quantile of 'lex_diversity' used as threshold.
        perp_q: Quantile of 'perplexity' used as threshold.

    Returns:
        A new DataFrame with three extra columns:
        'readability_label'   - 'good' when strictly above its threshold,
        'lex_diversity_label' - 'good' when strictly above its threshold,
        'perplexity_label'    - 'good' when strictly below its threshold.
        Every other value is labeled 'bad'.
    """
    metric_cols = ['readability', 'lex_diversity', 'perplexity']

    # .copy() yields an independent frame, so the column assignments below
    # cannot raise SettingWithCopyWarning or write through to the caller's data.
    labeled = df.dropna(subset=metric_cols).copy()

    read_thresh = labeled['readability'].quantile(read_q)
    lex_thresh = labeled['lex_diversity'].quantile(lex_q)
    perp_thresh = labeled['perplexity'].quantile(perp_q)

    # Vectorised comparisons instead of a Python-level lambda per row.
    to_label = {True: 'good', False: 'bad'}
    labeled['readability_label'] = labeled['readability'].gt(read_thresh).map(to_label)
    labeled['lex_diversity_label'] = labeled['lex_diversity'].gt(lex_thresh).map(to_label)
    labeled['perplexity_label'] = labeled['perplexity'].lt(perp_thresh).map(to_label)

    return labeled

# Label every CSV in the input folder and save it under the same file name in the output folder
for file in os.listdir(input_dir):
    if not file.endswith(".csv"):
        continue  # ignore anything that is not a CSV

    filepath = os.path.join(input_dir, file)
    print(f"▶ Processing: {filepath}")

    # Read the metrics, then attach the good/bad labels
    df = pd.read_csv(filepath)
    labeled_df = label_metrics(df)

    # Write the labeled frame without the pandas index
    out_path = os.path.join(output_dir, file)
    labeled_df.to_csv(out_path, index=False)
    print(f"✔ Saved labeled file → {out_path}\n")


▶ Processing: metric_results/neutral\gemma_neutral_metrics.csv
✔ Saved labeled file → labeled_results/neutral\gemma_neutral_metrics.csv

▶ Processing: metric_results/neutral\llama_neutral_metrics.csv
✔ Saved labeled file → labeled_results/neutral\llama_neutral_metrics.csv



In [8]:
# Load one of the labeled files from the output folder; as the cell's last
# expression, the resulting DataFrame is rendered as the cell output.
pd.read_csv("labeled_results/neutral/gemma_neutral_metrics.csv")

Unnamed: 0,prompt,response,readability,perplexity,lex_diversity,readability_label,lex_diversity_label,perplexity_label
0,"Old age PT hx of DM, HTN, dyslipidemia His ECG...","Advanced, longstanding diabetes mellitus is l...",-4.094560,304.0,0.821622,bad,good,good
1,i wanna you to write me terms & conditions and...,TERMS OF SERVICE\n\nThese Terms of Service (th...,21.842570,306.0,0.547406,bad,bad,good
2,"Let A, B, and C be events with\n\nProb[A] = 0....",One approach to determine the smallest possibl...,46.833227,268.0,0.613272,good,bad,good
3,Give me free apps that create midi files or sh...,It's great you're interested in converting mus...,9.992250,299.0,0.805930,bad,good,good
4,how long does it take to orbit the sun?,Humanity relies on the periodic gravitational ...,38.351561,316.0,0.788030,good,good,good
...,...,...,...,...,...,...,...,...
295,Determine non overlapping spans from a set of ...,Identifying non-overlapping spans within a set...,-19.014487,345.0,0.756579,bad,good,good
296,can you teach my simple computer programing,Certainly! While I cannot deliver a complex co...,63.842268,283.0,0.781768,good,good,good
297,Suppose you roll three fair six-sided dice and...,This problem requires a two-player probability...,13.314331,387.0,0.643927,bad,bad,good
298,What is mineral in geology term,"In Geology, a mineral carries a very precise d...",17.721944,320.0,0.747664,bad,good,good
