In [1]:
%pip install "torch>=2.4.0" tensorboard

%pip install "transformers>=4.51.3"
%pip install vllm
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.21.0" \
  "peft==0.14.0" \
  protobuf \
  sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
%pip install flash-attn


Collecting datasets==3.3.2
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting accelerate==1.4.0
  Downloading accelerate-1.4.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate==0.4.3
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting bitsandbytes==0.45.3
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting trl==0.21.0
  Downloading trl-0.21.0-py3-none-any.whl.metadata (11 kB)
Collecting peft==0.14.0
  Downloading peft-0.14.0-py3-none-any.whl.metadata (13 kB)
Collecting protobuf
  Downloading protobuf-6.33.1-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m41.3 MB/s[0m eta [36m0:00:00

Collecting flash-attn
  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/8.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━[0m [32m6.2/8.4 MB[0m [31m187.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m8.4/8.4 MB[0m [31m195.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m117.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl size=256040057 sha256=f25da18657a87fc83dc1bfb8b7751b82246e9db355510226b674fd437c34b5fb
  Stored in directory: /root/.cache/pip/wheels/3d/59/46/f282c12c73dd4bb3c2e3fe199f1

In [3]:
from google.colab import userdata
from huggingface_hub import login
import json

hf_token = userdata.get('HF_TOKEN')
login(hf_token)

In [4]:
import pandas as pd
import json

# Note: This script requires the following libraries to be installed:
# pip install pandas pyarrow fsspec huggingface_hub

# 1. Load the dataset into a pandas DataFrame (your preferred method)
print("Loading the dataset into a pandas DataFrame...")
try:
    df = pd.read_parquet("hf://datasets/ucberkeley-dlab/measuring-hate-speech/measuring-hate-speech.parquet")
    print(f"Dataset loaded successfully with {len(df)} rows and {len(df.columns)} columns.")
except Exception as e:
    print(f"An error occurred while loading the data: {e}")
    print("Please ensure you have run 'pip install pandas pyarrow fsspec huggingface_hub'")
    exit()

# 2. Define the columns we want to extract
# These are the actual facet and target columns present in the dataset
FACET_COLUMNS = [
    'sentiment', 'respect', 'insult', 'humiliate', 'status', 'dehumanize',
    'violence', 'genocide', 'attack_defend', 'hatespeech'
]

TARGET_COLUMNS = [
    "target_race_asian", "target_race_black", "target_race_latinx", "target_race_middle_eastern",
    "target_race_native_american", "target_race_pacific_islander", "target_race_white", "target_race_other",
    "target_religion_atheist", "target_religion_buddhist", "target_religion_christian", "target_religion_hindu",
    "target_religion_jewish", "target_religion_mormon", "target_religion_muslim", "target_religion_other",
    "target_origin_immigrant",  "target_origin_migrant_worker", "target_origin_specific_country",
    "target_origin_undocumented", "target_origin_other", "target_gender_men", "target_gender_non_binary",
    "target_gender_transgender_men", "target_gender_transgender_unspecified", "target_gender_transgender_women",
    "target_gender_women", "target_gender_other", "target_sexuality_bisexual", "target_sexuality_gay",
    "target_sexuality_lesbian", "target_sexuality_straight", "target_sexuality_other",
    "target_age_children", "target_age_teenagers", "target_age_young_adults", "target_age_middle_aged", "target_age_seniors",
    "target_age_other",  "target_disability_physical", "target_disability_cognitive",
    "target_disability_neurological", "target_disability_visually_impaired",
    "target_disability_hearing_impaired", "target_disability_unspecific", "target_disability_other"
]

# Helper function to classify the score based on the paper's logic
def classify_score(score):
    if score > 0.5:
        return "hateful"
    if score < -1.0:
        return "supportive"
    return "neutral"

# 3. Create the function to transform each row of the DataFrame
def create_gold_standard_record(row):
    # Create the 'overall' object
    overall = {
        "label": classify_score(row['hate_speech_score']),
        "hate_speech_score": row['hate_speech_score']
    }

    # Create the 'facets' object
    facets = {col: row[col] for col in FACET_COLUMNS}

    # Create the 'targets' object (ensuring columns exist)
    targets = {col: bool(row[col]) for col in TARGET_COLUMNS if col in row}

    # Assemble the final record as a dictionary
    return {
        "comment_id": row['comment_id'],
        "text": row['text'],
        "overall": overall,
        "facets": facets,
        "targets": targets
    }

# 4. Apply the function to each row of the DataFrame
print("Applying the transformation to each row of the DataFrame...")
gold_records = df.apply(create_gold_standard_record, axis=1).tolist()

print("\nExample of a processed record:")
print(json.dumps(gold_records[0], indent=2))

output_file = "gold_benchmark_dataset.jsonl"
print(f"\nSaving the {len(gold_records)} records to '{output_file}'...")

with open(output_file, 'w') as f:
    for record in gold_records:
        f.write(json.dumps(record) + '\n')
import json
import pandas as pd
from sklearn.model_selection import train_test_split

INPUT_FILE = "gold_benchmark_dataset.jsonl"
TRAIN_FILE = "train_aggregated.jsonl"
VAL_FILE = "val_aggregated.jsonl"
TEST_FILE = "test_aggregated.jsonl"

def load_data(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

def save_jsonl(data, path):
    with open(path, "w", encoding="utf-8") as f:
        for item in data:
            f.write(json.dumps(item) + "\n")

def aggregate_annotations(records):
    """Aggregate multiple annotations for the same comment into one record."""
    if len(records) == 1:
        return records[0]

    aggregated = {
        "comment_id": records[0]["comment_id"],
        "text": records[0]["text"],
        "overall": {},
        "facets": {},
        "targets": {}
    }

    # Aggregate hate_speech_score (mean)
    hate_scores = [r["overall"]["hate_speech_score"] for r in records]
    avg_hate_score = sum(hate_scores) / len(hate_scores)

    # Classify based on averaged score
    if avg_hate_score > 0.5:
        label = "hateful"
    elif avg_hate_score < -1.0:
        label = "supportive"
    else:
        label = "neutral"

    aggregated["overall"] = {
        "label": label,
        "hate_speech_score": avg_hate_score
    }

    # Aggregate facets (mean then round)
    facet_keys = records[0]["facets"].keys()
    for key in facet_keys:
        values = [r["facets"][key] for r in records]
        aggregated["facets"][key] = round(sum(values) / len(values))

    # Aggregate targets (OR logic)
    target_keys = records[0]["targets"].keys()
    for key in target_keys:
        aggregated["targets"][key] = any(r["targets"][key] for r in records)

    return aggregated

def main():
    # Load all data
    data = load_data(INPUT_FILE)
    print(f"Loaded {len(data)} total records")

    # Group by comment_id
    comment_groups = {}
    for record in data:
        comment_id = record["comment_id"]
        if comment_id not in comment_groups:
            comment_groups[comment_id] = []
        comment_groups[comment_id].append(record)

    unique_comment_ids = list(comment_groups.keys())
    print(f"Found {len(unique_comment_ids)} unique comments")

    # Split comment_ids: 80% train, 10% val, 10% test
    train_ids, temp_ids = train_test_split(unique_comment_ids, test_size=0.2, random_state=42)
    val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

    # ✅ CHANGE: Aggregate ALL splits (not just val/test)
    train_data = [aggregate_annotations(comment_groups[comment_id]) for comment_id in train_ids]
    val_data = [aggregate_annotations(comment_groups[comment_id]) for comment_id in val_ids]
    test_data = [aggregate_annotations(comment_groups[comment_id]) for comment_id in test_ids]

    # Save datasets
    save_jsonl(train_data, TRAIN_FILE)
    save_jsonl(val_data, VAL_FILE)
    save_jsonl(test_data, TEST_FILE)

    print(f"\n✅ Data split complete:")
    print(f"Train: {len(train_data)} aggregated records ({len(train_ids)} unique comments)")
    print(f"Val: {len(val_data)} aggregated records ({len(val_ids)} unique comments)")
    print(f"Test: {len(test_data)} aggregated records ({len(test_ids)} unique comments)")
    print(f"Total: {len(train_data) + len(val_data) + len(test_data)} records")
    print(f"Reduction: {len(data)} → {len(train_data) + len(val_data) + len(test_data)} (~{100*(1 - (len(train_data) + len(val_data) + len(test_data))/len(data)):.1f}% reduction)")

    # Verify no overlap
    assert len(set(train_ids) & set(val_ids)) == 0, "Train/Val overlap detected!"
    assert len(set(train_ids) & set(test_ids)) == 0, "Train/Test overlap detected!"
    assert len(set(val_ids) & set(test_ids)) == 0, "Val/Test overlap detected!"
    print("\n✅ Verified: No comment_id overlap between splits")

if __name__ == "__main__":
    main()

Loading the dataset into a pandas DataFrame...
Dataset loaded successfully with 135556 rows and 131 columns.
Applying the transformation to each row of the DataFrame...

Example of a processed record:
{
  "comment_id": 47777,
  "text": "Yes indeed. She sort of reminds me of the elder lady that played the part in the movie \"Titanic\" who was telling her story!!! And I wouldn't have wanted to cover who I really am!! I would be proud!!!! WE should be proud of our race no matter what it is!!",
  "overall": {
    "label": "supportive",
    "hate_speech_score": -3.9
  },
  "facets": {
    "sentiment": 0.0,
    "respect": 0.0,
    "insult": 0.0,
    "humiliate": 0.0,
    "status": 2.0,
    "dehumanize": 0.0,
    "violence": 0.0,
    "genocide": 0.0,
    "attack_defend": 0.0,
    "hatespeech": 0.0
  },
  "targets": {
    "target_race_asian": true,
    "target_race_black": true,
    "target_race_latinx": true,
    "target_race_middle_eastern": true,
    "target_race_native_american": true,
   

In [5]:

# Your instruction template
INSTRUCTION = """
You are an expert hate speech analyst. Your task is to analyze the provided text and return ONLY a valid JSON object that strictly follows the schema below.
Do not include any explanations, markdown formatting, or text outside of the JSON object.

=========================
IMPORTANT INSTRUCTIONS
=========================
1. Output must be **valid JSON only** - no markdown, no commentary.
2. Use exact field names and types from the schema.
3. Type requirements:
   - `overall.label` → string ("hateful", "supportive", or "neutral")
   - `overall.hate_speech_score` → float (e.g., `1.52`, `-1.35`, `0.12`)
   - All `facets` values → integers 0-4 only (e.g., `3`, not `3.0`)
   - All `targets` values → booleans (`true` or `false`)

=========================
OVERALL ASSESSMENT
=========================
Provide two values in the `"overall"` object:

1. **`"hate_speech_score"`** - A float representing hate speech intensity:
   - **Hateful content** → positive float > 0.5 (e.g., `0.8`, `1.52`, `2.3`)
   - **Neutral content** → float between -1.0 and 0.5 inclusive (e.g., `-0.8`, `0.12`, `0.4`)
   - **Supportive content** → negative float < -1.0 (e.g., `-1.35`, `-2.1`)

2. **`"label"`** - A categorical string derived from the hate_speech_score:
   - `"hateful"` if hate_speech_score > 0.5
   - `"neutral"` if hate_speech_score is between -1.0 and 0.5 (inclusive)
   - `"supportive"` if hate_speech_score < -1.0

=========================
FACETS (0-4 SCALE)
=========================
Each facet is an **integer** from 0 to 4:
- 0 = Absent
- 1 = Mild
- 2 = Clear
- 3 = Severe
- 4 = Extreme

Example: `"insult": 2` (not `2.0` or `"2"`)

=========================
TARGETS (BOOLEAN FLAGS)
=========================
Set to `true` only if that group is explicitly targeted in the text.

=========================
JSON SCHEMA (MUST MATCH EXACTLY)
=========================
{{
  "overall": {{
    "label": "neutral",
    "hate_speech_score": 0.00
  }},
  "facets": {{
    "sentiment": 0,
    "respect": 0,
    "insult": 0,
    "humiliate": 0,
    "status": 0,
    "dehumanize": 0,
    "violence": 0,
    "genocide": 0,
    "attack_defend": 0,
    "hatespeech": 0
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": false,
    "target_race_latinx": false,
    "target_race_middle_eastern": false,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": false,
    "target_race_other": false,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": false,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": false,
    "target_origin_undocumented": false,
    "target_origin_other": false,
    "target_gender_men": false,
    "target_gender_non_binary": false,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": false,
    "target_gender_other": false,
    "target_sexuality_bisexual": false,
    "target_sexuality_gay": false,
    "target_sexuality_lesbian": false,
    "target_sexuality_straight": false,
    "target_sexuality_other": false,
    "target_age_children": false,
    "target_age_teenagers": false,
    "target_age_young_adults": false,
    "target_age_middle_aged": false,
    "target_age_seniors": false,
    "target_age_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}

=========================
TEXT TO ANALYZE
=========================
{{text}}
"""


In [6]:
from datasets import load_dataset
test_data = load_dataset("json", data_files="test_aggregated.jsonl", split="train")

def create_conversation(sample):
    model_response = json.dumps({
        "overall": sample["overall"],
        "facets": sample["facets"],
        "targets": sample["targets"]
    }, ensure_ascii=False)

    user_content = f"{INSTRUCTION}\n\n<text>\n{sample['text']}\n</text>"

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "model", "content": model_response}
        ],
        "comment_id": sample["comment_id"],  # Keep this for evaluation!
        "text": sample["text"]
    }

test_data = test_data.map(create_conversation, remove_columns=test_data.column_names)
print(f"Test dataset size: {len(test_data)}")


Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/3957 [00:00<?, ? examples/s]

Test dataset size: 3957


# Models

In [9]:
import json
import torch
from datasets import load_dataset
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Your instruction template
INSTRUCTION = """
You are an expert hate speech analyst. Your task is to analyze the provided text and return ONLY a valid JSON object that strictly follows the schema below.
Do not include any explanations, markdown formatting, or text outside of the JSON object.

=========================
IMPORTANT INSTRUCTIONS
=========================
1. Output must be **valid JSON only** - no markdown, no commentary.
2. Use exact field names and types from the schema.
3. Type requirements:
   - `overall.label` → string ("hateful", "supportive", or "neutral")
   - `overall.hate_speech_score` → float (e.g., `1.52`, `-1.35`, `0.12`)
   - All `facets` values → integers 0-4 only (e.g., `3`, not `3.0`)
   - All `targets` values → booleans (`true` or `false`)

=========================
OVERALL ASSESSMENT
=========================
Provide two values in the `"overall"` object:

1. **`"hate_speech_score"`** - A float representing hate speech intensity:
   - **Hateful content** → positive float > 0.5 (e.g., `0.8`, `1.52`, `2.3`)
   - **Neutral content** → float between -1.0 and 0.5 inclusive (e.g., `-0.8`, `0.12`, `0.4`)
   - **Supportive content** → negative float < -1.0 (e.g., `-1.35`, `-2.1`)

2. **`"label"`** - A categorical string derived from the hate_speech_score:
   - `"hateful"` if hate_speech_score > 0.5
   - `"neutral"` if hate_speech_score is between -1.0 and 0.5 (inclusive)
   - `"supportive"` if hate_speech_score < -1.0

=========================
FACETS (0-4 SCALE)
=========================
Each facet is an **integer** from 0 to 4:
- 0 = Absent
- 1 = Mild
- 2 = Clear
- 3 = Severe
- 4 = Extreme

Example: `"insult": 2` (not `2.0` or `"2"`)

=========================
TARGETS (BOOLEAN FLAGS)
=========================
Set to `true` only if that group is explicitly targeted in the text.

=========================
JSON SCHEMA (MUST MATCH EXACTLY)
=========================
{{
  "overall": {{
    "label": "neutral",
    "hate_speech_score": 0.00
  }},
  "facets": {{
    "sentiment": 0,
    "respect": 0,
    "insult": 0,
    "humiliate": 0,
    "status": 0,
    "dehumanize": 0,
    "violence": 0,
    "genocide": 0,
    "attack_defend": 0,
    "hatespeech": 0
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": false,
    "target_race_latinx": false,
    "target_race_middle_eastern": false,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": false,
    "target_race_other": false,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": false,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": false,
    "target_origin_undocumented": false,
    "target_origin_other": false,
    "target_gender_men": false,
    "target_gender_non_binary": false,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": false,
    "target_gender_other": false,
    "target_sexuality_bisexual": false,
    "target_sexuality_gay": false,
    "target_sexuality_lesbian": false,
    "target_sexuality_straight": false,
    "target_sexuality_other": false,
    "target_age_children": false,
    "target_age_teenagers": false,
    "target_age_young_adults": false,
    "target_age_middle_aged": false,
    "target_age_seniors": false,
    "target_age_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}

=========================
EXAMPLE 1 (SHOT 1)
=========================

<example_text>
death to israel and british and usa
</example_text>

<example_output>
{{
  "overall": {{
    "label": "hateful",
    "hate_speech_score": 1.57
  }},
  "facets": {{
    "sentiment": 4,
    "respect": 4,
    "insult": 3,
    "humiliate": 2,
    "status": 3,
    "dehumanize": 2,
    "violence": 4,
    "genocide": 4,
    "attack_defend": 4,
    "hatespeech": 1
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": true,
    "target_race_latinx": false,
    "target_race_middle_eastern": true,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": true,
    "target_race_other": true,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": false,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": true,
    "target_origin_undocumented": false,
    "target_origin_other": true,
    "target_gender_men": false,
    "target_gender_non_binary": false,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": false,
    "target_gender_other": false,
    "target_sexuality_bisexual": false,
    "target_sexuality_gay": false,
    "target_sexuality_lesbian": false,
    "target_sexuality_straight": false,
    "target_sexuality_other": false,
    "target_age_children": false,
    "target_age_teenagers": false,
    "target_age_young_adults": false,
    "target_age_middle_aged": false,
    "target_age_seniors": false,
    "target_age_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}
</example_output>

=========================
EXAMPLE 2 (SHOT 2)
=========================

<example_text>
Im bi and I'm dating all of u
</example_text>

<example_output>
{{
  "overall": {{
    "label": "neutral",
    "hate_speech_score": -0.47
  }},
  "facets": {{
    "sentiment": 3,
    "respect": 3,
    "insult": 3,
    "humiliate": 2,
    "status": 2,
    "dehumanize": 2,
    "violence": 2,
    "genocide": 1,
    "attack_defend": 3,
    "hatespeech": 1
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": false,
    "target_race_latinx": false,
    "target_race_middle_eastern": false,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": false,
    "target_race_other": false,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": false,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": false,
    "target_origin_undocumented": false,
    "target_origin_other": false,
    "target_gender_men": true,
    "target_gender_non_binary": true,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": true,
    "target_gender_other": false,
    "target_sexuality_bisexual": true,
    "target_sexuality_gay": true,
    "target_sexuality_lesbian": true,
    "target_sexuality_straight": true,
    "target_sexuality_other": false,
    "target_age_children": false,
    "target_age_teenagers": false,
    "target_age_young_adults": false,
    "target_age_middle_aged": false,
    "target_age_seniors": false,
    "target_age_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}
</example_output>

=========================
EXAMPLE 3 (SHOT 3)
=========================

<example_text>
From a white woman thank you for speaking truth.
</example_text>

<example_output>
{{
  "overall": {{
    "label": "supportive",
    "hate_speech_score": -3.63
  }},
  "facets": {{
    "sentiment": 2,
    "respect": 0,
    "insult": 0,
    "humiliate": 0,
    "status": 2,
    "dehumanize": 0,
    "violence": 0,
    "genocide": 0,
    "attack_defend": 2,
    "hatespeech": 0
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": false,
    "target_race_latinx": false,
    "target_race_middle_eastern": false,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": true,
    "target_race_other": false,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": false,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": false,
    "target_origin_undocumented": false,
    "target_origin_other": false,
    "target_gender_men": false,
    "target_gender_non_binary": false,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": false,
    "target_gender_other": false,
    "target_sexuality_bisexual": false,
    "target_sexuality_gay": false,
    "target_sexuality_lesbian": false,
    "target_sexuality_straight": false,
    "target_sexuality_other": false,
    "target_age_children": false,
    "target_age_teenagers": false,
    "target_age_young_adults": false,
    "target_age_middle_aged": false,
    "target_age_seniors": false,
    "target_age_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}
</example_output>

=========================
EXAMPLE 4 (SHOT 4)
=========================

<example_text>
Google 'Catholic answers'
</example_text>

<example_output>
{{
  "overall": {{
    "label": "supportive",
    "hate_speech_score": -2.11
  }},
  "facets": {{
    "sentiment": 2,
    "respect": 2,
    "insult": 2,
    "humiliate": 2,
    "status": 2,
    "dehumanize": 2,
    "violence": 1,
    "genocide": 1,
    "attack_defend": 2,
    "hatespeech": 0
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": false,
    "target_race_latinx": false,
    "target_race_middle_eastern": false,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": false,
    "target_race_other": false,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": true,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": false,
    "target_origin_undocumented": false,
    "target_origin_other": false,
    "target_gender_men": false,
    "target_gender_non_binary": false,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": false,
    "target_gender_other": false,
    "target_sexuality_bisexual": false,
    "target_sexuality_gay": false,
    "target_sexuality_lesbian": false,
    "target_sexuality_straight": false,
    "target_sexuality_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}
</example_output>

=========================
EXAMPLE 5 (SHOT 5)
=========================

<example_text>
Stupid white priv racist crackas
</example_text>

<example_output>
{{
  "overall": {{
    "label": "hateful",
    "hate_speech_score": 0.82
  }},
  "facets": {{
    "sentiment": 4,
    "respect": 4,
    "insult": 3,
    "humiliate": 2,
    "status": 3,
    "dehumanize": 2,
    "violence": 0,
    "genocide": 0,
    "attack_defend": 3,
    "hatespeech": 2
  }},
  "targets": {{
    "target_race_asian": false,
    "target_race_black": false,
    "target_race_latinx": false,
    "target_race_middle_eastern": false,
    "target_race_native_american": false,
    "target_race_pacific_islander": false,
    "target_race_white": true,
    "target_race_other": false,
    "target_religion_atheist": false,
    "target_religion_buddhist": false,
    "target_religion_christian": false,
    "target_religion_hindu": false,
    "target_religion_jewish": false,
    "target_religion_mormon": false,
    "target_religion_muslim": false,
    "target_religion_other": false,
    "target_origin_immigrant": false,
    "target_origin_migrant_worker": false,
    "target_origin_specific_country": false,
    "target_origin_undocumented": false,
    "target_origin_other": false,
    "target_gender_men": false,
    "target_gender_non_binary": false,
    "target_gender_transgender_men": false,
    "target_gender_transgender_unspecified": false,
    "target_gender_transgender_women": false,
    "target_gender_women": false,
    "target_gender_other": false,
    "target_sexuality_bisexual": false,
    "target_sexuality_gay": false,
    "target_sexuality_lesbian": false,
    "target_sexuality_straight": false,
    "target_sexuality_other": false,
    "target_disability_physical": false,
    "target_disability_cognitive": false,
    "target_disability_neurological": false,
    "target_disability_visually_impaired": false,
    "target_disability_hearing_impaired": false,
    "target_disability_unspecific": false,
    "target_disability_other": false
  }}
}}
</example_output>
"""

def extract_outer_json(text: str) -> str:
    """Extract JSON object from model output, stripping markdown and handling malformed JSON."""
    # Remove markdown code fences
    text = text.replace("```json", "").replace("```", "").strip()

    s = text.find("{")
    if s == -1:
        raise ValueError("No JSON object found in output")

    # Find matching closing brace
    brace_count = 0
    for i in range(s, len(text)):
        if text[i] == '{':
            brace_count += 1
        elif text[i] == '}':
            brace_count -= 1
            if brace_count == 0:
                return text[s:i+1]

    # Fallback to rfind if no matching brace found
    e = text.rfind("}")
    if e == -1 or e <= s:
        raise ValueError("No closing brace found for JSON object")
    return text[s:e+1]

def to_bool(x):
    """Robust boolean conversion."""
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, float)):
        return bool(int(x))
    if isinstance(x, str):
        return x.strip().lower() in {"1", "true", "yes", "y", "t"}
    return False

def derive_label_from_score(score: float) -> str:
    """Derive label from hate_speech_score."""
    if score > 0.5:
        return "hateful"
    elif score < -1.0:
        return "supportive"
    else:
        return "neutral"

def normalize_schema(data):
    """Normalize data to match expected schema, handling dot notation and variations."""
    if isinstance(data, str):
        data = json.loads(data)

    # Handle flat dot notation (e.g., "overall.label" -> nested structure)
    if "overall.label" in data or "overall.hate_speech_score" in data:
        flat_data = data.copy()
        data = {"overall": {}, "facets": {}, "targets": {}}

        for key, value in flat_data.items():
            if key.startswith("overall."):
                field = key.replace("overall.", "")
                data["overall"][field] = value
            elif key == "facets":
                data["facets"] = value
            elif key == "targets":
                data["targets"] = value

    overall = data.get("overall", {})

    # Handle score field variations
    if "hate_speech_score" not in overall and "score" in overall:
        overall["hate_speech_score"] = float(overall.pop("score"))
    if "hate_speech_score" in overall:
        overall["hate_speech_score"] = float(overall["hate_speech_score"])

    # Fix wrong label values
    if "label" in overall:
        label = overall["label"]
        if label == "hate_speech":  # Base model sometimes uses this
            overall["label"] = "hateful"
        elif label not in ["hateful", "neutral", "supportive"]:
            if "hate_speech_score" in overall:
                overall["label"] = derive_label_from_score(overall["hate_speech_score"])

    # Derive label if missing
    if "label" not in overall and "hate_speech_score" in overall:
        overall["label"] = derive_label_from_score(overall["hate_speech_score"])

    data["overall"] = overall

    # Normalize facets to integers 0-4
    facets = data.get("facets", {})
    fixed_facets = {}
    for key, value in facets.items():
        try:
            fixed_facets[key] = max(0, min(4, int(float(value))))
        except Exception:
            fixed_facets[key] = 0
    data["facets"] = fixed_facets

    # Normalize targets to booleans
    targets = data.get("targets", {})
    data["targets"] = {key: to_bool(value) for key, value in targets.items()}

    return data


def run_base_llama_benchmark():
    """Benchmark base Llama-3.2-1B-Instruct model using vLLM."""

    print("="*60)
    print("BENCHMARKING BASE LLAMA-3.2-1B-INSTRUCT MODEL")
    print("="*60)

    # Load base model tokenizer
    print("\n📥 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

    # Initialize vLLM with base model
    print("🚀 Loading base model with vLLM...")
    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        dtype="bfloat16",
        gpu_memory_utilization=0.9,
        max_model_len=8192,
        trust_remote_code=True,
    )

    # Sampling parameters
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=2048,  # Sufficient for full JSON output
        stop=["<|eot_id|>", "</s>", "\n}\n", "}\n\n"],  # Llama stop tokens
        stop_token_ids=[tokenizer.eos_token_id],
    )

    # Load test data
    print("📂 Loading test data...")
    test_data = load_dataset("json", data_files="test_aggregated.jsonl", split="train")
    print(f"   Test dataset size: {len(test_data)}")

    # Prepare prompts with Llama chat template
    print("📝 Preparing prompts...")
    prompts = []
    for sample in test_data:
        # Use system + user message format for Llama
        messages = [
            {"role": "system", "content": INSTRUCTION.strip()},
            {"role": "user", "content": sample['text']}
        ]

        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        prompts.append(formatted_prompt)

    # Run inference
    print(f"⚡ Running inference on {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    # Process results
    print("🔄 Processing outputs...")
    predictions = []
    failed_samples = []

    for i, output in enumerate(tqdm(outputs, desc="Processing")):
        sample = test_data[i]
        generated_text = output.outputs[0].text.strip()

        try:
            predicted_json_str = extract_outer_json(generated_text)
            normalized_prediction = normalize_schema(predicted_json_str)
            normalized_expected = normalize_schema(sample)

            predictions.append({
                "comment_id": sample.get("comment_id"),
                "text": sample.get("text"),
                "expected": normalized_expected,
                "predicted": normalized_prediction,
                "raw_output": generated_text,
                "success": True
            })
        except Exception as e:
            failed_samples.append({
                "comment_id": sample.get("comment_id"),
                "text": sample.get("text"),
                "expected": normalize_schema(sample),
                "raw_output": generated_text,
                "error": str(e)
            })
            predictions.append({
                "comment_id": sample.get("comment_id"),
                "text": sample.get("text"),
                "success": False
            })

    # Save results
    print("💾 Saving results...")
    output_file = "base_llama_predictions.jsonl"
    with open(output_file, "w") as f:
        for pred in predictions:
            f.write(json.dumps(pred) + "\n")

    if failed_samples:
        failed_file = "base_llama_failed.jsonl"
        with open(failed_file, "w") as f:
            for failed in failed_samples:
                f.write(json.dumps(failed) + "\n")
        print(f"   Failed samples saved to: {failed_file}")

    # Print summary
    success_rate = (len(predictions) - len(failed_samples)) / len(predictions)
    print("\n" + "="*60)
    print("✅ BENCHMARK COMPLETE!")
    print("="*60)
    print(f"   Total samples:    {len(predictions)}")
    print(f"   Successful:       {len(predictions) - len(failed_samples)} ({success_rate:.1%})")
    print(f"   Failed:           {len(failed_samples)}")
    print(f"   Results saved to: {output_file}")

    return predictions


if __name__ == "__main__":
    run_base_llama_benchmark()

BENCHMARKING BASE LLAMA-3.2-1B-INSTRUCT MODEL

📥 Loading tokenizer...
🚀 Loading base model with vLLM...
INFO 11-18 04:56:44 [utils.py:233] non-default args: {'trust_remote_code': True, 'dtype': 'bfloat16', 'max_model_len': 8192, 'disable_log_stats': True, 'model': 'meta-llama/Llama-3.2-1B-Instruct'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 11-18 04:57:03 [model.py:547] Resolved architecture: LlamaForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 11-18 04:57:03 [model.py:1510] Using max model len 8192
INFO 11-18 04:57:07 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=8192.


generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

INFO 11-18 04:58:03 [llm.py:306] Supported_tasks: ['generate']
📂 Loading test data...
   Test dataset size: 3957
📝 Preparing prompts...
⚡ Running inference on 3957 samples...


Adding requests:   0%|          | 0/3957 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/3957 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

🔄 Processing outputs...


Processing: 100%|██████████| 3957/3957 [00:01<00:00, 2309.07it/s]


💾 Saving results...
   Failed samples saved to: base_llama_failed.jsonl

✅ BENCHMARK COMPLETE!
   Total samples:    3957
   Successful:       2646 (66.9%)
   Failed:           1311
   Results saved to: base_llama_predictions.jsonl


In [10]:
import json
import re
from tqdm import tqdm

def aggressive_json_fix(raw_output):
    """Try to salvage JSON from base model output with aggressive fixes."""
    try:
        # Remove markdown
        text = raw_output.replace("```json", "").replace("```", "").strip()

        # Remove duplicate fields (keep first occurrence)
        lines = text.split('\n')
        seen_keys = set()
        cleaned_lines = []

        for line in lines:
            match = re.match(r'\s*"([^"]+)":', line)
            if match:
                key = match.group(1)
                if key not in seen_keys:
                    seen_keys.add(key)
                    cleaned_lines.append(line)
            else:
                cleaned_lines.append(line)

        text = '\n'.join(cleaned_lines)

        # Parse JSON
        data = json.loads(text)

        # Convert dot notation to nested structure
        if "overall.label" in data or "overall.hate_speech_score" in data:
            flat_data = data.copy()
            data = {"overall": {}, "facets": flat_data.get("facets", {}), "targets": flat_data.get("targets", {})}

            for key, value in flat_data.items():
                if key.startswith("overall."):
                    field = key.replace("overall.", "")
                    data["overall"][field] = value

        # Fix label
        if "label" in data.get("overall", {}):
            label = data["overall"]["label"]
            if label == "hate_speech":
                data["overall"]["label"] = "hateful"

        # Derive label from score if missing
        overall = data.get("overall", {})
        if "label" not in overall and "hate_speech_score" in overall:
            score = float(overall["hate_speech_score"])
            if score > 0.5:
                overall["label"] = "hateful"
            elif score < -1.0:
                overall["label"] = "supportive"
            else:
                overall["label"] = "neutral"
            data["overall"] = overall

        return data, None

    except Exception as e:
        return None, str(e)


def main():
    print("="*60)
    print("RECOVERY: Parsing Failed Base Llama Outputs")
    print("="*60)

    # Check if failed file exists
    try:
        with open("base_llama_failed.jsonl", "r") as f:
            failed_samples = [json.loads(line) for line in f]
    except FileNotFoundError:
        print("\n❌ base_llama_failed.jsonl not found!")
        print("   Run benchmark_base_llama.py first")
        return

    print(f"\n📁 Found {len(failed_samples)} failed samples")
    print("🔧 Attempting aggressive recovery...\n")

    recovered = []
    still_failed = []

    for sample in tqdm(failed_samples, desc="Recovering"):
        raw_output = sample.get("raw_output", "")
        fixed_data, error = aggressive_json_fix(raw_output)

        if fixed_data:
            recovered.append({
                "comment_id": sample.get("comment_id"),
                "text": sample.get("text"),
                "expected": sample.get("expected"),
                "predicted": fixed_data,
                "raw_output": raw_output,
                "success": True
            })
        else:
            still_failed.append({
                **sample,
                "recovery_error": error
            })

    # Save recovered samples
    if recovered:
        with open("base_llama_recovered.jsonl", "w") as f:
            for item in recovered:
                f.write(json.dumps(item) + "\n")
        print(f"\n✅ Recovered {len(recovered)} samples → base_llama_recovered.jsonl")

    if still_failed:
        with open("base_llama_still_failed.jsonl", "w") as f:
            for item in still_failed:
                f.write(json.dumps(item) + "\n")
        print(f"⚠️  Still failed: {len(still_failed)} samples → base_llama_still_failed.jsonl")

    # Calculate recovery rate
    recovery_rate = len(recovered) / len(failed_samples) if failed_samples else 0

    print("\n" + "="*60)
    print("RECOVERY SUMMARY")
    print("="*60)
    print(f"Original failures:      {len(failed_samples)}")
    print(f"Successfully recovered: {len(recovered)} ({recovery_rate:.1%})")
    print(f"Still failed:           {len(still_failed)} ({(1-recovery_rate):.1%})")

    if recovery_rate < 0.5:
        print("\n⚠️  Recovery rate < 50%")
        print("   The base Llama model struggles with this JSON format.")
        print("   Consider skipping base model benchmark.")
    elif recovery_rate < 0.9:
        print("\n✓ Partial recovery successful")
        print("  You can combine base_llama_predictions.jsonl + base_llama_recovered.jsonl")
    else:
        print("\n✅ High recovery rate!")
        print("   Merge recovered samples with predictions for evaluation")

    # If recovery rate is decent, create merged file
    if recovery_rate >= 0.5:
        print("\n🔗 Creating merged predictions file...")
        try:
            # Load original predictions
            with open("base_llama_predictions.jsonl", "r") as f:
                original = [json.loads(line) for line in f]

            # Merge: keep successful from original, add recovered
            merged = [p for p in original if p.get("success")] + recovered

            with open("base_llama_predictions_merged.jsonl", "w") as f:
                for item in merged:
                    f.write(json.dumps(item) + "\n")

            print(f"✅ Created base_llama_predictions_merged.jsonl ({len(merged)} samples)")
            print("   Use this file for evaluation")
        except Exception as e:
            print(f"⚠️  Could not create merged file: {e}")


if __name__ == "__main__":
    main()

RECOVERY: Parsing Failed Base Llama Outputs

📁 Found 1311 failed samples
🔧 Attempting aggressive recovery...



Recovering: 100%|██████████| 1311/1311 [00:00<00:00, 24203.87it/s]

⚠️  Still failed: 1311 samples → base_llama_still_failed.jsonl

RECOVERY SUMMARY
Original failures:      1311
Successfully recovered: 0 (0.0%)
Still failed:           1311 (100.0%)

⚠️  Recovery rate < 50%
   The base Llama model struggles with this JSON format.
   Consider skipping base model benchmark.





In [11]:
import json
import numpy as np
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    mean_absolute_error, mean_squared_error,
    hamming_loss, accuracy_score, classification_report
)
from scipy.stats import spearmanr
import os

# ============================================
# UTILITY FUNCTIONS
# ============================================

def to_bool(x):
    """Robust boolean coercion."""
    if isinstance(x, bool): return x
    if isinstance(x, (int, float)): return bool(int(x))
    if isinstance(x, str): return x.strip().lower() in {"1", "true", "yes", "y", "t"}
    return False

def derive_label_from_score(score: float) -> str:
    """Derive label using the same logic as training data."""
    if score > 0.5: return "hateful"
    if score < -1.0: return "supportive"
    return "neutral"

def normalize_schema(data):
    """Ensure data conforms to required types, handling dot notation."""
    if isinstance(data, str):
        data = json.loads(data)

    # Handle flat dot notation
    if "overall.label" in data or "overall.hate_speech_score" in data:
        flat_data = data.copy()
        data = {"overall": {}, "facets": {}, "targets": {}}

        for key, value in flat_data.items():
            if key.startswith("overall."):
                field = key.replace("overall.", "")
                data["overall"][field] = value
            elif key == "facets":
                data["facets"] = value
            elif key == "targets":
                data["targets"] = value

    overall = data.get("overall", {})

    # Normalize score field
    if "score" in overall and "hate_speech_score" not in overall:
        overall["hate_speech_score"] = float(overall.pop("score"))
    if "hate_speech_score" in overall:
        overall["hate_speech_score"] = float(overall["hate_speech_score"])
    else:
        overall["hate_speech_score"] = 0.0

    # Fix wrong label values
    if "label" in overall:
        label = overall["label"]
        if label == "hate_speech":
            overall["label"] = "hateful"
        elif label not in ["hateful", "neutral", "supportive"]:
            overall["label"] = derive_label_from_score(overall["hate_speech_score"])

    # Guarantee label exists
    if "label" not in overall:
        overall["label"] = derive_label_from_score(overall["hate_speech_score"])

    data["overall"] = overall

    facets = data.get("facets", {})
    fixed_facets = {}
    for key, value in facets.items():
        try:
            fixed_facets[key] = max(0, min(4, int(float(value))))
        except (ValueError, TypeError):
            fixed_facets[key] = 0
    data["facets"] = fixed_facets

    targets = data.get("targets", {})
    data["targets"] = {key: to_bool(value) for key, value in targets.items()}

    return data

def safe_get_score(obj):
    """Safely extract score from overall."""
    overall = obj.get("overall", {})
    if "hate_speech_score" in overall: return float(overall["hate_speech_score"])
    if "score" in overall: return float(overall["score"])
    raise KeyError("Neither 'hate_speech_score' nor 'score' found in overall")


# ============================================
# LOAD DATA
# ============================================

# Auto-detect which predictions file to use
if os.path.exists("base_llama_predictions_merged.jsonl"):
    PREDICTIONS_FILE = "base_llama_predictions_merged.jsonl"
    print("🔗 Using merged predictions (includes recovered samples)")
elif os.path.exists("base_llama_predictions.jsonl"):
    PREDICTIONS_FILE = "base_llama_predictions.jsonl"
else:
    print("❌ No predictions file found!")
    print("   Run benchmark_base_llama.py first")
    exit(1)

print("="*60)
print("BASE LLAMA-3.2-1B-INSTRUCT EVALUATION")
print("="*60)
print(f"\n📁 Loading predictions from: {PREDICTIONS_FILE}")

predictions = []
with open(PREDICTIONS_FILE, "r") as f:
    for line in f:
        predictions.append(json.loads(line))

# Separate successful and failed predictions
valid_preds = [p for p in predictions if p.get("success")]
failed_samples = [p for p in predictions if not p.get("success")]

# Re-normalize both expected and predicted data
for p in valid_preds:
    p["expected"] = normalize_schema(p["expected"])
    p["predicted"] = normalize_schema(p["predicted"])

print(f"\n📊 Data Summary:")
print(f"   Total samples:        {len(predictions)}")
print(f"   Valid predictions:    {len(valid_preds)} ({100*len(valid_preds)/len(predictions):.1f}%)")
print(f"   Failed predictions:   {len(failed_samples)} ({100*len(failed_samples)/len(predictions):.1f}%)")

if not valid_preds:
    print("\n❌ No valid predictions found. Cannot evaluate.")
    exit(1)

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)


# ============================================
# 1. OVERALL LABEL CLASSIFICATION
# ============================================

print("\n" + "="*60)
print("OVERALL: Label Classification")
print("="*60)

y_true_labels = [p["expected"]["overall"]["label"] for p in valid_preds]
y_pred_labels = [p["predicted"]["overall"]["label"] for p in valid_preds]

overall_accuracy = accuracy_score(y_true_labels, y_pred_labels)
overall_micro_f1 = f1_score(y_true_labels, y_pred_labels, average="micro", zero_division=0)
overall_macro_f1 = f1_score(y_true_labels, y_pred_labels, average="macro", zero_division=0)
overall_precision = precision_score(y_true_labels, y_pred_labels, average="macro", zero_division=0)
overall_recall = recall_score(y_true_labels, y_pred_labels, average="macro", zero_division=0)

print(f"Accuracy:        {overall_accuracy:.4f}")
print(f"Macro F1:        {overall_macro_f1:.4f}")
print(f"Micro F1:        {overall_micro_f1:.4f}")
print(f"Macro Precision: {overall_precision:.4f}")
print(f"Macro Recall:    {overall_recall:.4f}")
print("\nPer-class breakdown:")
print(classification_report(y_true_labels, y_pred_labels, zero_division=0))

# Score correlation
y_true_scores = [safe_get_score(p["expected"]) for p in valid_preds]
y_pred_scores = [safe_get_score(p["predicted"]) for p in valid_preds]
score_corr = spearmanr(y_true_scores, y_pred_scores).correlation
print(f"Score Spearman correlation: {score_corr:.4f}")


# ============================================
# 2. FACETS EVALUATION
# ============================================

print("\n" + "="*60)
print("FACETS: Ordinal Ratings (0-4 scale)")
print("="*60)

facet_names = list(valid_preds[0]["expected"]["facets"].keys())
facet_results = {}

for facet in facet_names:
    y_true = np.array([p["expected"]["facets"].get(facet, 0) for p in valid_preds])
    y_pred = np.array([p["predicted"]["facets"].get(facet, 0) for p in valid_preds])

    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    spearman = spearmanr(y_true, y_pred).correlation
    exact_match = accuracy_score(y_true, y_pred)
    within_1 = np.mean(np.abs(y_true - y_pred) <= 1)

    facet_results[facet] = {
        "mae": mae, "mse": mse, "spearman": spearman,
        "exact_match": exact_match, "within_1_accuracy": within_1
    }

mean_mae = np.mean([r["mae"] for r in facet_results.values()])
mean_mse = np.mean([r["mse"] for r in facet_results.values()])
mean_spearman = np.mean([r["spearman"] for r in facet_results.values()])
mean_exact = np.mean([r["exact_match"] for r in facet_results.values()])
mean_within_1 = np.mean([r["within_1_accuracy"] for r in facet_results.values()])

print(f"Mean MAE:               {mean_mae:.4f}")
print(f"Mean MSE:               {mean_mse:.4f}")
print(f"Mean Spearman:          {mean_spearman:.4f}")
print(f"Mean Exact Match:       {mean_exact:.4f}")
print(f"Mean Within-1 Accuracy: {mean_within_1:.4f}")

print("\nPer-facet breakdown:")
print(f"{'Facet':<20} {'MAE':<8} {'Exact':<8} {'Within-1':<10} {'Spearman':<10}")
print("-" * 60)
for facet in facet_names:
    r = facet_results[facet]
    print(f"{facet:<20} {r['mae']:<8.3f} {r['exact_match']:<8.3f} {r['within_1_accuracy']:<10.3f} {r['spearman']:<10.3f}")


# ============================================
# 3. TARGETS EVALUATION
# ============================================

print("\n" + "="*60)
print("TARGETS: Multi-label Classification")
print("="*60)

target_names = list(valid_preds[0]["expected"]["targets"].keys())

y_true_targets = np.array([[int(p["expected"]["targets"].get(t, False)) for t in target_names] for p in valid_preds])
y_pred_targets = np.array([[int(p["predicted"]["targets"].get(t, False)) for t in target_names] for p in valid_preds])

targets_micro_f1 = f1_score(y_true_targets, y_pred_targets, average="micro", zero_division=0)
targets_macro_f1 = f1_score(y_true_targets, y_pred_targets, average="macro", zero_division=0)
targets_micro_precision = precision_score(y_true_targets, y_pred_targets, average="micro", zero_division=0)
targets_micro_recall = recall_score(y_true_targets, y_pred_targets, average="micro", zero_division=0)
targets_hamming = hamming_loss(y_true_targets, y_pred_targets)
exact_match_ratio = np.mean(np.all(y_true_targets == y_pred_targets, axis=1))

print(f"Micro F1:          {targets_micro_f1:.4f}")
print(f"Macro F1:          {targets_macro_f1:.4f}")
print(f"Micro Precision:   {targets_micro_precision:.4f}")
print(f"Micro Recall:      {targets_micro_recall:.4f}")
print(f"Hamming Loss:      {targets_hamming:.4f}")
print(f"Exact Match Ratio: {exact_match_ratio:.4f} ({int(exact_match_ratio*len(valid_preds))}/{len(valid_preds)})")

print("\nPer-target F1 scores (bottom 10):")
per_target_f1 = {target: f1_score(y_true_targets[:, i], y_pred_targets[:, i], zero_division=0) for i, target in enumerate(target_names)}
sorted_targets = sorted(per_target_f1.items(), key=lambda x: x[1])
for target, f1 in sorted_targets[:10]:
    print(f"  {target:<40} {f1:.3f}")


# ============================================
# SAVE EVALUATION SUMMARY
# ============================================

print("\n" + "="*60)
print("SAVING EVALUATION SUMMARY")
print("="*60)

eval_summary = {
    "metadata": {
        "model": "meta-llama/Llama-3.2-1B-Instruct (base)",
        "prediction_file": PREDICTIONS_FILE,
        "total_samples": len(predictions),
        "valid_predictions": len(valid_preds),
        "failed_predictions": len(failed_samples),
        "success_rate": len(valid_preds) / len(predictions) if predictions else 0
    },
    "overall": {
        "accuracy": overall_accuracy,
        "macro_f1": overall_macro_f1,
        "micro_f1": overall_micro_f1,
        "precision": overall_precision,
        "recall": overall_recall,
        "score_spearman": score_corr
    },
    "facets": {
        "mean_mae": mean_mae,
        "mean_mse": mean_mse,
        "mean_spearman": mean_spearman,
        "mean_exact_match": mean_exact,
        "mean_within_1_accuracy": mean_within_1,
        "per_facet": facet_results
    },
    "targets": {
        "micro_f1": targets_micro_f1,
        "macro_f1": targets_macro_f1,
        "precision": targets_micro_precision,
        "recall": targets_micro_recall,
        "hamming_loss": targets_hamming,
        "exact_match_ratio": exact_match_ratio,
        "per_target_f1": per_target_f1
    }
}

output_file = "base_llama_evaluation.json"
with open(output_file, "w") as f:
    json.dump(eval_summary, f, indent=2)

print(f"✅ Summary saved to: {output_file}")
if failed_samples:
    print(f"⚠️  {len(failed_samples)} samples failed - check base_llama_failed.jsonl")

print("\n" + "="*60)
print("EVALUATION COMPLETE!")
print("="*60)

# Print key metrics summary
print("\n📈 KEY METRICS SUMMARY:")
print(f"   Overall Accuracy:  {overall_accuracy:.2%}")
print(f"   Overall Macro F1:  {overall_macro_f1:.4f}")
print(f"   Facets Mean MAE:   {mean_mae:.4f}")
print(f"   Targets Micro F1:  {targets_micro_f1:.4f}")
print(f"   Success Rate:      {len(valid_preds)/len(predictions):.2%}")

BASE LLAMA-3.2-1B-INSTRUCT EVALUATION

📁 Loading predictions from: base_llama_predictions.jsonl

📊 Data Summary:
   Total samples:        3957
   Valid predictions:    2646 (66.9%)
   Failed predictions:   1311 (33.1%)

EVALUATION RESULTS

OVERALL: Label Classification
Accuracy:        0.3042
Macro F1:        0.2137
Micro F1:        0.3042
Macro Precision: 0.4193
Macro Recall:    0.3715

Per-class breakdown:
              precision    recall  f1-score   support

     hateful       0.27      1.00      0.43       660
     neutral       0.01      0.00      0.00       737
  supportive       0.97      0.12      0.21      1249

    accuracy                           0.30      2646
   macro avg       0.42      0.37      0.21      2646
weighted avg       0.53      0.30      0.21      2646

Score Spearman correlation: 0.3542

FACETS: Ordinal Ratings (0-4 scale)
Mean MAE:               0.9549
Mean MSE:               1.8870
Mean Spearman:          0.2154
Mean Exact Match:       0.3887
Mean Within