<a href="https://colab.research.google.com/github/jamesalv/HateDeRC/blob/master/HateDeRC_Full.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import shutil
# import os

# # Delete HateDeRC directory if it exists
# if os.path.exists('HateDeRC'):
#   shutil.rmtree('HateDeRC')
# !git clone https://github.com/jamesalv/HateDeRC
# %cd HateDeRC

In [2]:
from TrainingConfig import TrainingConfig
from typing import Dict, Any, Tuple, List
import numpy as np
import torch
from transformers import AutoTokenizer
import json

In [3]:
data_path = 'Data/dataset.json'

In [4]:
config = TrainingConfig()

config.class_weighting = True
config.num_epochs = 2
config.hidden_dropout_prob = 0.2

# Attention Training Configurations
config.train_attention = True
config.lambda_attn = 1
config.ranking_margin = 0.1        # Minimum margin between token pairs
config.ranking_threshold = 0.05    # Min difference to consider pairs significant

# 
config.use_multi_layer_loss = True

In [5]:
# Seed all randomness for reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(config.seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(config.seed)
np.random.seed(config.seed)

# Preprocessing

In [6]:
import re
import string

def deobfuscate_text(text):
    """
    Normalize common text obfuscation patterns to reveal original words.
    Useful for hate speech detection and content analysis.

    Args:
        text (str): Input text with potential obfuscations

    Returns:
        str: Text with obfuscations normalized
    """
    if not isinstance(text, str):
        return text

    # Make a copy to work with
    result = text.lower()

    # 1. Handle asterisk/symbol replacements
    symbol_patterns = {
        # Common profanity
        r'f\*+c?k': 'fuck',
        r'f\*+': 'fuck',
        r's\*+t': 'shit',
        r'b\*+ch': 'bitch',
        r'a\*+s': 'ass',
        r'd\*+n': 'damn',
        r'h\*+l': 'hell',
        r'c\*+p': 'crap',

        # Slurs and hate speech terms (be comprehensive for detection)
        r'n\*+g+[aer]+': 'nigger',  # Various n-word obfuscations
        r'f\*+g+[ot]*': 'faggot',
        r'r\*+[dt]ard': 'retard',
        r'sp\*+c': 'spic',

        # Other symbols
        r'@ss': 'ass',
        r'b@tch': 'bitch',
        r'sh!t': 'shit',
        r'f#ck': 'fuck',
        r'd@mn': 'damn',
    }

    for pattern, replacement in symbol_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 2. Handle character spacing (f u c k -> fuck)
    spacing_patterns = {
        r'\bf\s+u\s+c\s+k\b': 'fuck',
        r'\bs\s+h\s+i\s+t\b': 'shit',
        r'\bd\s+a\s+m\s+n\b': 'damn',
        r'\bh\s+e\s+l\s+l\b': 'hell',
        r'\ba\s+s\s+s\b': 'ass',
        r'\bc\s+r\s+a\s+p\b': 'crap',
    }

    for pattern, replacement in spacing_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 3. Handle number/letter substitutions
    leet_patterns = {
        # Basic leet speak
        r'\b3\s*1\s*1\s*3\b': 'elle',  # 3113 -> elle
        r'\bf4g\b': 'fag',
        r'\bf4gg0t\b': 'faggot',
        r'\bn00b\b': 'noob',
        r'\bl33t\b': 'leet',
        r'\bh4t3\b': 'hate',
        r'\b5h1t\b': 'shit',
        r'\bf0ck\b': 'fock',
    }

    for pattern, replacement in leet_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 4. Handle repeated characters and separators
    # Remove excessive punctuation between letters
    result = re.sub(r'([a-z])[^\w\s]+([a-z])', r'\1\2', result)

    # Handle underscore separation
    result = re.sub(r'([a-z])_+([a-z])', r'\1\2', result)

    # Handle dot separation
    result = re.sub(r'([a-z])\.+([a-z])', r'\1\2', result)

    # 5. Handle common misspellings/variations used for evasion
    evasion_patterns = {
        r'\bfuk\b': 'fuck',
        r'\bfuq\b': 'fuck',
        r'\bfck\b': 'fuck',
        r'\bshyt\b': 'shit',
        r'\bshit\b': 'shit',
        r'\bbiatch\b': 'bitch',
        r'\bbeatch\b': 'bitch',
        r'\basshole\b': 'asshole',
        r'\ba55hole\b': 'asshole',
        r'\btard\b': 'retard',
        r'\bfagg\b': 'fag',
    }

    for pattern, replacement in evasion_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 6. Clean up multiple spaces
    result = re.sub(r'\s+', ' ', result).strip()

    return result

In [7]:
def aggregate_rationales(rationales, labels, post_length, drop_abnormal=False):
    """
    If all 3 annotators are normal → 3 zero spans → average (all zeros).
    If k annotators are non-normal and k spans exist → average the k spans (no added zeros).
    If k non-normal but fewer than k spans:
        If the missing annotators are non-normal → do not fill with zeros; average only existing spans and record rationale_support = #spans.
        If the missing annotators are normal (e.g., 2 hate + 1 normal + 2 spans) → append one zero span for the normal.
    """
    count_normal = labels.count(0)
    count_hate = labels.count(1)
    count_rationales = len(rationales)
    pad = np.zeros(post_length, dtype="int").tolist()

    # If there are hate labels but no rationales, something is wrong
    if count_hate > 0 and count_rationales == 0:
        if drop_abnormal:
            return None

        # Else just fill with 0
        return np.zeros(post_length).tolist()

    # If all annotators are normal, return all zeros
    if count_normal == 3:
        return np.zeros(post_length).tolist()

    # If we have hate annotators
    if count_hate > 0:
        # Case 1: Number of rationales matches number of hate annotators
        if count_rationales == count_hate:
            return np.average(rationales, axis=0).tolist()

        # Case 2: Fewer rationales than hate annotators
        elif count_rationales < count_hate:
            # Add zero padding for normal annotators only
            rationales_copy = rationales.copy()
            zeros_to_add = count_normal
            for _ in range(zeros_to_add):
                rationales_copy.append(pad)
            return np.average(rationales_copy, axis=0).tolist()

        # Case 3: More rationales than hate annotators (shouldn't happen normally)
        else:
            # Just average what we have
            return np.average(rationales, axis=0).tolist()

    # Fallback: return zeros if no clear case matches
    return np.zeros(post_length).tolist()

In [8]:
from typing import List, Tuple

def preprocess_text(raw_text):
    preprocessed_text = raw_text
    # # Remove HTML tags <>
    preprocessed_text = preprocessed_text.replace("<", "").replace(">", "")
    # # De-Obsfucate Patterns
    preprocessed_text = deobfuscate_text(preprocessed_text)

    return preprocessed_text


def create_text_segment(
    text_tokens: List[str], rationale_mask: List[int]
) -> List[Tuple[List[str], int]]:
    """
    Process a rationale mask to identify contiguous segments of highlighted text.
    Then create a segmented representation of the tokens

    Args:
        text_tokens: Original text tokens
        mask: Binary mask where 1 indicates a highlighted token (this consists of mask from 3 annotators)

    Returns:
        A list of tuples (text segment, mask value)
    """
    # Handle case where mask is empty (no rationale provided), usually this is normal classification
    mask = rationale_mask

    # for mask in all_rationale_mask:
    # Find breakpoints (transitions between highlighted/1 and non-highlighted/0)
    breakpoints = []
    mask_values = []

    # Always start with position 0
    breakpoints.append(0)
    mask_values.append(mask[0])

    # Find transitions in the mask
    for i in range(1, len(mask)):
        if mask[i] != mask[i - 1]:
            breakpoints.append(i)
            mask_values.append(mask[i])

    # Always end with the length of the text
    if breakpoints[-1] != len(mask):
        breakpoints.append(len(mask))

    # Create segments based on breakpoints
    segments = []
    for i in range(len(breakpoints) - 1):
        start = breakpoints[i]
        end = breakpoints[i + 1]
        segments.append((text_tokens[start:end], mask_values[i]))

    return segments


def align_rationales(tokens, rationales, tokenizer, max_length=128):
    """
    Align rationales with tokenized text while handling different tokenizer formats.

    Args:
        tokens: Original text tokens
        rationales: Original rationale masks
        tokenizer: The tokenizer to use
        max_length: Maximum sequence length

    Returns:
        Dictionary with tokenized inputs and aligned rationale masks
    """
    segments = create_text_segment(tokens, rationales)
    all_human_rationales = []
    all_input_ids = []
    all_attention_mask = []
    all_token_type_ids = []
    all_rationales = []
    for text_segment, rationale_value in segments:
        inputs = {}
        concatenated_text = " ".join(text_segment)
        processed_segment = preprocess_text(concatenated_text)
        tokenized = tokenizer(
            processed_segment, add_special_tokens=False, return_tensors="pt"
        )

        # Extract the relevant data
        segment_input_ids = tokenized["input_ids"][0]
        segment_attention_mask = tokenized["attention_mask"][0]
        # Handle token_type_ids if present
        if "token_type_ids" in tokenized:
            segment_token_type_ids = tokenized["token_type_ids"][0]
            all_token_type_ids.extend(segment_token_type_ids)

        # Add input IDs and attention mask
        all_input_ids.extend(segment_input_ids)
        all_attention_mask.extend(segment_attention_mask)

        # Add rationales (excluding special tokens)
        segment_rationales = [rationale_value] * len(segment_input_ids)
        all_rationales.extend(segment_rationales)
    # Get special token IDs
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Add special tokens at the beginning and end
    all_input_ids = [cls_token_id] + all_input_ids + [sep_token_id]
    all_attention_mask = [1] + all_attention_mask + [1]

    # Handle token_type_ids if the model requires it
    if hasattr(tokenizer, "create_token_type_ids_from_sequences"):
        all_token_type_ids = tokenizer.create_token_type_ids_from_sequences(
            all_input_ids[1:-1]
        )
    elif all_token_type_ids:
        all_token_type_ids = [0] + all_token_type_ids + [0]
    else:
        all_token_type_ids = [0] * len(all_input_ids)

    # Check tokenized vs rationales length
    if len(all_input_ids) != len(all_attention_mask):
        print("Warning: length of tokens and rationales do not match")

    # Add zero rationale values for special tokens
    all_rationales = [0] + all_rationales + [0]

    # Truncate to max length if needed
    if len(all_input_ids) > max_length:
        print("WARNING: NEED TO TRUNCATE")
        all_input_ids = all_input_ids[:max_length]
        all_attention_mask = all_attention_mask[:max_length]
        all_token_type_ids = all_token_type_ids[:max_length]
        all_rationales = all_rationales[:max_length]

    # Pad to max_length if needed
    pad_token_id = tokenizer.pad_token_id
    padding_length = max_length - len(all_input_ids)

    if padding_length > 0:
        all_input_ids = all_input_ids + [pad_token_id] * padding_length
        all_attention_mask = all_attention_mask + [0] * padding_length
        all_token_type_ids = all_token_type_ids + [0] * padding_length
        all_rationales = all_rationales + [0] * padding_length

    # Convert lists to tensors
    inputs = {
        "input_ids": torch.tensor([all_input_ids], dtype=torch.long),
        "attention_mask": torch.tensor([all_attention_mask], dtype=torch.long),
        "token_type_ids": (
            torch.tensor([all_token_type_ids], dtype=torch.long)
            if "token_type_ids" in tokenizer.model_input_names
            else None
        ),
        "rationales": torch.tensor([all_rationales], dtype=torch.float32),
    }

    # Remove None values
    inputs = {k: v for k, v in inputs.items() if v is not None}
    return inputs

In [9]:
import re
import json
import os
import string
from collections import Counter
from tqdm import tqdm
import more_itertools as mit


def find_ranges(iterable):
    """Yield range of consecutive numbers."""
    for group in mit.consecutive_groups(iterable):
        group = list(group)
        if len(group) == 1:
            yield group[0]
        else:
            yield group[0], group[-1]


def process_and_convert_data(
    data,
    tokenizer,
    post_id_divisions,
    save_path="Data/explanations/",
    drop_abnormal=False,
):
    """
    Combined function that processes raw entries and converts to ERASER format in one pass.
    Also splits data into train/val/test sets.
    """
    print("Processing and converting data...")

    # Initialize outputs
    train_data = []
    val_data = []
    test_data = []
    dropped = 0

    # Create directories if saving splits
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        os.makedirs(os.path.join(save_path, "docs"), exist_ok=True)
        train_fp = open(os.path.join(save_path, "train.jsonl"), "w")
        val_fp = open(os.path.join(save_path, "val.jsonl"), "w")
        test_fp = open(os.path.join(save_path, "test.jsonl"), "w")

    for key, value in tqdm(data.items()):
        try:
            # Extract labels
            labels = [
                1 if annot["label"] in ["hatespeech", "offensive"] else 0
                for annot in value["annotators"]
            ]

            # Process rationales
            rationales = value.get("rationales", [])
            aggregated_rationale = aggregate_rationales(
                rationales,
                labels,
                len(value["post_tokens"]),
                drop_abnormal=drop_abnormal,
            )

            if aggregated_rationale is None:
                dropped += 1
                continue

            inputs = align_rationales(
                value["post_tokens"], aggregated_rationale, tokenizer
            )

            # Calculate labels
            hard_label = Counter(labels).most_common(1)[0][0]
            soft_label = sum(labels) / len(labels)

            # Determine target groups (mentioned at least 3 times)
            target_groups = [
                t for annot in value["annotators"] for t in annot["target"]
            ]
            filtered_targets = [k for k, v in Counter(target_groups).items() if v > 2]

            # Create processed entry
            processed_entry = {
                "post_id": key,
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"],
                "rationales": inputs["rationales"],
                "raw_text": " ".join(value["post_tokens"]),
                "hard_label": hard_label,
                "soft_label": soft_label,
                "target_groups": filtered_targets,
            }

            # Convert to ERASER format if it's hateful/offensive content
            if hard_label == 1 and save_path:
                input_ids_list = inputs["input_ids"].squeeze().tolist()
                rationales_list = inputs["rationales"].squeeze().ceil().int().tolist()

                # Build evidences
                evidences = []
                indexes = sorted(
                    [i for i, each in enumerate(rationales_list) if each == 1]
                )
                for span in find_ranges(indexes):
                    if isinstance(span, int):
                        start, end = span, span + 1
                    else:
                        start, end = span[0], span[1] + 1

                    evidences.append(
                        {
                            "docid": key,
                            "end_sentence": -1,
                            "end_token": end,
                            "start_sentence": -1,
                            "start_token": start,
                            "text": " ".join(
                                [str(x) for x in input_ids_list[start:end]]
                            ),
                        }
                    )

                eraser_entry = {
                    "annotation_id": key,
                    "classification": str(hard_label),
                    "evidences": [evidences],
                    "query": "What is the class?",
                    "query_type": None,
                }

                # Save document
                with open(os.path.join(save_path, "docs", key), "w") as fp:
                    fp.write(" ".join([str(x) for x in input_ids_list if x > 0]))

                # Write to appropriate split
                if key in post_id_divisions["train"]:
                    train_fp.write(json.dumps(eraser_entry) + "\n")
                elif key in post_id_divisions["val"]:
                    val_fp.write(json.dumps(eraser_entry) + "\n")
                elif key in post_id_divisions["test"]:
                    test_fp.write(json.dumps(eraser_entry) + "\n")

            # Add to appropriate split list
            if key in post_id_divisions["train"]:
                train_data.append(processed_entry)
            elif key in post_id_divisions["val"]:
                val_data.append(processed_entry)
            elif key in post_id_divisions["test"]:
                test_data.append(processed_entry)

        except Exception as e:
            dropped += 1
            print(f"Error processing {key}: {e}")

    if save_path:
        train_fp.close()
        val_fp.close()
        test_fp.close()

    print(
        f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}, Dropped: {dropped}"
    )

    return {"train": train_data, "val": val_data, "test": test_data}

In [10]:
with open(data_path, 'r') as file:
    data = json.load(file)

with open('Data/post_id_divisions.json') as file:
    post_id_divisions = json.load(file)

# Process everything in one pass
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
splits = process_and_convert_data(
    data=data,
    tokenizer=tokenizer,
    post_id_divisions=post_id_divisions,
    save_path='Data/explanations/',
    drop_abnormal=False
)

# Access splits directly
train_data = splits['train']
val_data = splits['val']
test_data = splits['test']

Processing and converting data...


 37%|███▋      | 7406/20148 [00:09<00:16, 774.86it/s]



 89%|████████▉ | 17990/20148 [00:24<00:03, 692.51it/s]

Error processing 24439295_gab: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.


100%|██████████| 20148/20148 [00:27<00:00, 738.93it/s]

Train: 15382, Val: 1922, Test: 1924, Dropped: 1





# Training

In [11]:
from HateDataset import HateDataset

# Create datasets with pre-tokenized data
train_dataset = HateDataset(data=train_data)
val_dataset = HateDataset(data=val_data)
test_dataset = HateDataset(data=test_data)

In [12]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) # Use shuffle=False for validation
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # Use shuffle=False for testing

## Bias

In [13]:
def get_bias_evaluation_samples(data, method, group):
    """
    Get positive and negative sample IDs for bias evaluation based on method and group

    Args:
        data: list of data entries
        method: Bias evaluation method ('subgroup', 'bpsn', or 'bnsp')
        group: Target group to evaluate

    Returns:
        Tuple of (positive_ids, negative_ids)
    """
    positive_ids = []
    negative_ids = []

    for idx, row in enumerate(data):
        target_groups = row['target_groups']
        if target_groups is None:
            continue

        is_in_group = group in target_groups

        # Convert various label formats to binary toxic/non-toxic
        if 'hard_label' in row:
            is_toxic = row['hard_label'] == 1
        else:
            continue

        if method == 'subgroup':
            # Only consider samples mentioning the group
            if is_in_group:
                if is_toxic:
                    positive_ids.append(idx)
                else:
                    negative_ids.append(idx)

        elif method == 'bpsn':
            # Compare non-toxic posts mentioning the group with toxic posts NOT mentioning the group
            if is_in_group and not is_toxic:
                negative_ids.append(idx)
            elif not is_in_group and is_toxic:
                positive_ids.append(idx)

        elif method == 'bnsp':
            # Compare toxic posts mentioning the group with non-toxic posts NOT mentioning the group
            if is_in_group and is_toxic:
                positive_ids.append(idx)
            elif not is_in_group and not is_toxic:
                negative_ids.append(idx)

    return positive_ids, negative_ids

In [14]:
from collections import defaultdict
from sklearn.metrics import roc_auc_score

def calculate_gmb_metrics(
    test_data: List[Dict[str, Any]],
    probabilities: np.ndarray,
    target_groups: List[str]
):
    """
    Calculate GMB (Generalized Mean of Bias) AUC metrics from model predictions

    Args:
        probabilities: Model's probability outputs
        test_data: List of test data entries
        target_groups: List of target groups to evaluate

    Returns:
        Dictionary with GMB metrics
    """
    # Create mappings from post_id to predictions and ground truth
    prediction_scores = defaultdict(lambda: defaultdict(dict))
    ground_truth = {}

    for idx, row in enumerate(test_data):
        prediction_scores[idx] = probabilities[idx, 1]
        ground_truth[idx] = row['hard_label']

    # Calculate metrics for each target group and method
    bias_metrics = {}
    methods = ['subgroup', 'bpsn', 'bnsp']

    for method in methods:
        bias_metrics[method] = {}  # Initialize nested dictionary for each method
        for group in target_groups:
            # Get positive and negative samples based on the method
            positive_ids, negative_ids = get_bias_evaluation_samples(test_data, method, group)

            if len(positive_ids) == 0 or len(negative_ids) == 0:
                print(f"Skipping {method} for group {group}: no samples found")
                continue  # Skip if no samples for this group/method

            # Collect ground truth and predictions
            y_true = []
            y_score = []

            for post_id in positive_ids:
                if post_id in ground_truth and post_id in prediction_scores:
                    y_true.append(ground_truth[post_id])
                    y_score.append(prediction_scores[post_id])

            for post_id in negative_ids:
                if post_id in ground_truth and post_id in prediction_scores:
                    y_true.append(ground_truth[post_id])
                    y_score.append(prediction_scores[post_id])

            # Calculate AUC if we have enough samples with both classes
            if len(y_true) > 10 and len(set(y_true)) > 1:
                try:
                    auc = roc_auc_score(y_true, y_score)
                    bias_metrics[method][group] = auc
                except ValueError:
                    print(f"Could not compute AUC for {method} and group {group} due to ValueError")
                    pass

    # Calculate GMB for each method
    gmb_metrics = {}
    power = -5  # Power parameter for generalized mean

    for method in methods:
        if not bias_metrics[method]:
            continue

        scores = list(bias_metrics[method].values())
        if not scores:
            continue

        # Calculate generalized mean with p=-5
        power_mean = np.mean([score ** power for score in scores]) ** (1/power)
        gmb_metrics[f'GMB-{method.upper()}-AUC'] = power_mean

    # Calculate a combined GMB score that includes all methods
    all_scores = []
    for method in methods:
        all_scores.extend(list(bias_metrics[method].values()))

    if all_scores:
        gmb_metrics['GMB-COMBINED-AUC'] = np.mean([score ** power for score in all_scores]) ** (1/power)

    return gmb_metrics, bias_metrics

## XAI

In [15]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from typing import List, Dict, Tuple
import json
from sklearn.metrics import (
    precision_recall_curve,
    auc,
    f1_score,
    precision_score,
    recall_score,
)


class FaithfulnessMetrics:
    """
    Compute faithfulness metrics using the model's existing predict() method.
    Creates modified datasets and uses DataLoader for efficient batched processing.
    """

    def __init__(self, model, tokenizer, dataset_class, batch_size=32):
        self.model = model
        self.tokenizer = tokenizer
        self.dataset_class = dataset_class
        self.batch_size = batch_size

        # Get special token IDs
        self.special_token_ids = {
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
        }
        self.special_token_ids = {x for x in self.special_token_ids if x is not None}

    def compute_all_metrics(
        self,
        test_data: List[Dict],  # Your original test data
        test_results: Dict,  # Results from prediction
        k: int = 5,  # Number of top tokens to consider
        eraser_save_path: str = "Data/eraser_formatted_results.jsonl",
    ) -> Dict[str, float]:
        """
        Compute all ERASER metrics efficiently using DataLoader approach

        Args:
            test_data: List of test instances (each with input_ids, attention_mask, rationales, labels)
            test_results: List of dictionaries containing attention scores for each instance

        Returns:
            Dictionary with all metrics
        """
        print("Computing ERASER metrics using DataLoader approach...")

        # Extract lists for easier processing
        input_ids_list = [item["input_ids"] for item in test_data]
        attention_masks_list = [item["attention_mask"] for item in test_data]
        human_rationales = [item["rationales"] for item in test_data]
        attention_scores = [item for item in test_results["attentions"]]

        # 1. Extract top-k as hard predictions
        hard_predictions = self._extract_top_k_tokens(
            attention_scores, attention_masks_list, input_ids_list, k
        )

        hard_rationale_predictions, soft_rationale_predictions = self._convert_attention_to_evidence_format(input_ids_list, attention_scores, hard_predictions)

        # 2. PLAUSIBILITY METRICS
        print("\n[1/3] Computing plausibility metrics...")
        auprc = self._compute_auprc(
            attention_scores, human_rationales, attention_masks_list, input_ids_list
        )
        token_f1, token_prec, token_rec = self._compute_token_f1(
            hard_predictions, human_rationales, attention_masks_list
        )

        # 3. FAITHFULNESS METRICS
        print("[2/3] Computing comprehensiveness scores...")
        raw_comprehensiveness, comprehensiveness_scores = (
            self._compute_comprehensiveness(test_data, test_results, hard_predictions)
        )

        print("[3/3] Computing sufficiency scores...")
        raw_sufficiency, sufficiency_scores = self._compute_sufficiency(
            test_data, test_results, hard_predictions
        )

        # 4. Convert to eraser format
        results_eraser = self._convert_result_to_eraser_format(test_results, hard_rationale_predictions, soft_rationale_predictions, raw_sufficiency, raw_comprehensiveness)
        # Convert to JSONL format
        jsonl_output = '\n'.join([json.dumps(entry) for entry in results_eraser])
        with open(eraser_save_path, 'w') as f:
            f.write(jsonl_output)

        return {
            # Plausibility
            "auprc": auprc,
            "token_f1": token_f1,
            "token_precision": token_prec,
            "token_recall": token_rec,
            # Faithfulness
            "comprehensiveness": float(np.mean(comprehensiveness_scores)),
            "sufficiency": float(np.mean(sufficiency_scores)),
            # Additional
            "avg_rationale_length": k,
        }

    def _convert_attention_to_evidence_format(self, input_ids_list, attention_scores, hard_predictions):
        # 2. Collect evidence
        hard_rationale_predictions = []
        for idx, hp in enumerate(hard_predictions):
            evidences = []
            indexes = sorted([i for i, each in enumerate(hp.tolist()) if each == 1])
            for span in find_ranges(indexes):
                if isinstance(span, int):
                    start, end = span, span + 1
                else:
                    start, end = span[0], span[1] + 1

                evidences.append({
                    "start_token": start,
                    "end_token": end,
                })
            hard_rationale_predictions.append(evidences)

        soft_rationale_predictions = []
        for att in attention_scores:
            pred = [x for x in att if x > 0]
            soft_rationale_predictions.append(pred)

        return hard_rationale_predictions, soft_rationale_predictions

    def _convert_result_to_eraser_format(
        self,
        test_result: Dict,
        hard_rationale_predictions,
        soft_rationale_predictions,
        sufficiency_scores: np.ndarray,
        comprehensiveness_scores: np.ndarray,
    ):
        all_entries = []
        for idx, data in enumerate(test_result["post_id"]):
            entry = {
            'annotation_id': data,
            'classification': str(int(test_result["predictions"][idx])),
            'classification_scores': {
                0: float(test_result["probabilities"][idx][0]),
                1: float(test_result["probabilities"][idx][1]),
            },
            'rationales': [
                {
                    "docid": data,
                    "hard_rationale_predictions": hard_rationale_predictions[idx],
                    "soft_rationale_predictions": [float(x) for x in soft_rationale_predictions[idx]],
                }
            ],
            'sufficiency_classification_scores': {
                0: float(sufficiency_scores[idx][0]),
                1: float(sufficiency_scores[idx][1])
            },
            'comprehensiveness_classification_scores': {
                0: float(comprehensiveness_scores[idx][0]),
                1: float(comprehensiveness_scores[idx][1])
            }
            }
            all_entries.append(entry)

        return all_entries

    def _calculate_average_rationale_length(
        self,
        human_rationales: List[torch.Tensor],
        attention_masks_list: List[torch.Tensor],
        input_ids_list: List[torch.Tensor],
    ) -> int:
        """Calculate average number of content rationale tokens"""
        lengths = []
        for idx, (rat, mask) in enumerate(zip(human_rationales, attention_masks_list)):
            valid_positions = mask.bool().cpu().numpy().flatten()

            # Exclude special tokens
            input_ids = input_ids_list[idx].cpu().numpy().flatten()
            is_special = np.isin(input_ids, list(self.special_token_ids))
            content_positions = valid_positions & ~is_special

            rat_count = (rat.cpu().numpy().flatten()[content_positions] == 1).sum()
            lengths.append(rat_count)

        return max(1, int(np.mean(lengths)))

    def _extract_top_k_tokens(
        self,
        attention_scores: List[np.ndarray],
        attention_masks_list: List[torch.Tensor],
        input_ids_list: List[torch.Tensor],
        k: int,
    ) -> List[np.ndarray]:
        """Extract top-k content tokens as hard predictions"""
        hard_predictions = []

        for idx, (attn, mask) in enumerate(zip(attention_scores, attention_masks_list)):
            pred_mask = np.zeros_like(attn, dtype=int)
            valid_positions = mask.bool().cpu().numpy().flatten()

            # Exclude special tokens
            input_ids = input_ids_list[idx].cpu().numpy().flatten()
            is_special = np.isin(input_ids, list(self.special_token_ids))
            content_positions = valid_positions & ~is_special

            content_attn = attn[content_positions]

            if k > 0 and len(content_attn) > 0:
                k_actual = min(k, len(content_attn))
                top_k_within_content = np.argsort(content_attn)[-k_actual:]
                content_indices = np.where(content_positions)[0]
                top_k_indices = content_indices[top_k_within_content]
                pred_mask[top_k_indices] = 1

            hard_predictions.append(pred_mask)

        return hard_predictions

    def _compute_auprc(
        self,
        attention_scores: List[np.ndarray],
        human_rationales: List[torch.Tensor],
        attention_masks_list: List[torch.Tensor],
        input_ids_list: List[torch.Tensor],
    ) -> float:
        """Compute AUPRC for soft attention scores"""
        all_scores = []
        all_labels = []

        for idx, (attn, rat, mask) in enumerate(
            zip(attention_scores, human_rationales, attention_masks_list)
        ):
            valid_positions = mask.bool().cpu().numpy().flatten()

            # Exclude special tokens
            input_ids = input_ids_list[idx].cpu().numpy().flatten()
            is_special = np.isin(input_ids, list(self.special_token_ids))
            content_positions = valid_positions & ~is_special

            all_scores.extend(attn[content_positions].tolist())
            all_labels.extend(
                rat.cpu().numpy().flatten()[content_positions].astype(int).tolist()
            )

        all_scores = np.array(all_scores, dtype=float)
        all_labels = np.array(all_labels, dtype=int)

        if len(np.unique(all_labels)) < 2:
            print(f"Warning: Only one class in labels: {np.unique(all_labels)}")
            return 0.0

        precision, recall, _ = precision_recall_curve(all_labels, all_scores)
        return auc(recall, precision)

    def _compute_token_f1(
        self,
        hard_predictions: List[np.ndarray],
        human_rationales: List[torch.Tensor],
        attention_masks_list: List[torch.Tensor],
    ) -> Tuple[float, float, float]:
        """Compute token-level F1, Precision, Recall"""
        all_preds = []
        all_labels = []

        for pred, rat, mask in zip(
            hard_predictions, human_rationales, attention_masks_list
        ):
            valid_positions = mask.bool().cpu().numpy().flatten()
            all_preds.extend(pred[valid_positions].astype(int).tolist())
            all_labels.extend(
                rat.cpu().numpy().flatten()[valid_positions].astype(int).tolist()
            )

        all_preds = np.array(all_preds, dtype=int)
        all_labels = np.array(all_labels, dtype=int)

        f1 = f1_score(all_labels, all_preds, zero_division=0)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)

        return f1, precision, recall

    def _compute_comprehensiveness(
        self,
        test_data: List[Dict],
        test_results: Dict,
        hard_predictions: List[np.ndarray],
    ) -> Tuple[float, List[float]]:
        """
        Compute comprehensiveness: how much does REMOVING rationales hurt?
        Uses DataLoader approach for efficiency
        """
        # Create modified dataset (remove rationales from attention mask)
        modified_data = []
        for item, rationale_mask in zip(test_data, hard_predictions):
            modified_item = self._create_comprehensiveness_instance(
                item, rationale_mask
            )
            modified_data.append(modified_item)

        # Create DataLoader
        modified_dataset = self.dataset_class(modified_data)
        modified_loader = DataLoader(
            modified_dataset, batch_size=self.batch_size, shuffle=False
        )

        # Get predictions using model's predict method
        results = self.model.predict(modified_loader, return_attentions=False)
        modified_probs = results["probabilities"]

        # Calculate comprehensiveness scores
        comprehensiveness_scores = []
        for idx, (prob, label) in enumerate(zip(test_results["probabilities"], test_results['labels'])):
            original_prob = prob[
                label
            ]  # Probability from normal prediction process for the label
            modified_prob = modified_probs[idx][label]

            # Comprehensiveness = original - modified (higher is better)
            comp_score = original_prob - modified_prob
            comprehensiveness_scores.append(comp_score)

        return modified_probs, comprehensiveness_scores

    def _compute_sufficiency(
        self,
        test_data: List[Dict],
        test_results: Dict,
        hard_predictions: List[np.ndarray],
    ) -> Tuple[List[float], List[float]]:
        """
        Compute sufficiency: how well do ONLY rationales predict?
        Uses DataLoader approach for efficiency
        """
        # Create modified dataset (keep only rationales in attention mask)
        modified_data = []
        for item, rationale_mask in zip(test_data, hard_predictions):
            modified_item = self._create_sufficiency_instance(item, rationale_mask)
            modified_data.append(modified_item)

        # Create DataLoader
        modified_dataset = self.dataset_class(modified_data)
        modified_loader = DataLoader(
            modified_dataset, batch_size=self.batch_size, shuffle=False
        )

        # Get predictions using model's predict method
        results = self.model.predict(modified_loader, return_attentions=False)
        modified_probs = results["probabilities"]

        # Calculate sufficiency scores
        sufficiency_scores = []
        for idx, (prob, label) in enumerate(zip(test_results["probabilities"], test_results['labels'])):
            original_prob = prob[
                label
            ]  # Probability from normal prediction process for the label
            modified_prob = modified_probs[idx][label]

            # Sufficiency = original - modified (lower/negative is better)
            suff_score = original_prob - modified_prob
            sufficiency_scores.append(suff_score)

        return modified_probs, sufficiency_scores

    def _create_comprehensiveness_instance(
        self, item: Dict, rationale_mask: np.ndarray
    ) -> Dict:
        """
        Create instance for comprehensiveness: REMOVE rationales from attention
        Keep: CLS + non-rationale content tokens + SEP
        """
        input_ids = item["input_ids"].cpu().numpy().flatten()
        orig_mask = item["attention_mask"].cpu().numpy().flatten()

        # Start with original mask
        new_mask = orig_mask.copy()

        # Zero out rationale positions (except CLS and SEP)
        for i in range(len(new_mask)):
            if rationale_mask[i] == 1:  # This is a rationale
                # Don't mask if it's CLS or SEP
                if input_ids[i] not in self.special_token_ids:
                    new_mask[i] = 0

        return {
            "post_id": item["post_id"],
            "input_ids": torch.tensor(input_ids).unsqueeze(0),
            "attention_mask": torch.tensor(new_mask).unsqueeze(0),
            "rationales": item["rationales"],
            "hard_label": item["hard_label"],
        }

    def _create_sufficiency_instance(
        self, item: Dict, rationale_mask: np.ndarray
    ) -> Dict:
        """
        Create instance for sufficiency: Keep ONLY rationales in attention
        Keep: CLS + rationale tokens + SEP
        """
        input_ids = item["input_ids"].cpu().numpy().flatten()
        orig_mask = item["attention_mask"].cpu().numpy().flatten()

        # Start with zeros
        new_mask = np.zeros_like(orig_mask)

        # Always keep CLS and SEP
        for i in range(len(new_mask)):
            if input_ids[i] in self.special_token_ids:
                new_mask[i] = 1

        # Keep rationale positions
        for i in range(len(new_mask)):
            if rationale_mask[i] == 1 and orig_mask[i] == 1:
                new_mask[i] = 1

        return {
            "post_id": item["post_id"],
            "input_ids": torch.tensor(input_ids).unsqueeze(0),
            "attention_mask": torch.tensor(new_mask).unsqueeze(0),
            "rationales": item["rationales"],
            "hard_label": item["hard_label"],
        }

# Experiment Management System

This notebook now uses a systematic experiment tracking system that organizes all outputs by experiment.

In [16]:
# ============================================================================
# FULL EXPERIMENT PIPELINE WITH TRACKING
# ============================================================================
from ExperimentManager import ExperimentManager
from HateClassifier import HateClassifier
from HateDataset import HateDataset
# 1. CREATE EXPERIMENT
experiment_manager = ExperimentManager(base_dir="./experiments")
experiment_dir = experiment_manager.create_experiment(
    config=config,
    custom_name="baseline_distilbert",  # Change this for each experiment
    description="Baseline model with distilbert-base-uncased, standard hyperparameters"
)

Created new experiment: 20251218_212137_baseline_distilbert_455539da
Directory: experiments\20251218_212137_baseline_distilbert_455539da
Description: Baseline model with distilbert-base-uncased, standard hyperparameters


In [21]:
from sklearn.utils.class_weight import compute_class_weight
y = [int(td['hard_label']) for td in train_data]

class_weights = torch.tensor(compute_class_weight(
    class_weight='balanced',
    classes=np.array([0, 1]),  # Ensure consistent order
    y=y
), dtype=torch.float32)

In [22]:
# 2. TRAIN MODEL (config.save_dir is automatically updated)
model = HateClassifier(config, class_weight=class_weights)
history = model.train(train_dataloader=train_loader, val_dataloader=val_loader)

# Save training history
experiment_manager.save_training_history(history)

Using class weighting for loss function.
Training on device: cuda
Model: distilbert-base-uncased
Epochs: 2
Batch size: 32
Gradient accumulation steps: 1
Effective batch size: 32
Learning rate: 1e-05
Mixed precision (AMP): True
Gradient clipping: 1.0

Loss Configuration:
  Multi-layer loss: True
    - Auxiliary (layer 3): α=0.5
    - Main (final layer): β=0.5
  Attention supervision: True
    - Ranking loss: λ=1
    - Margin: 0.1, Threshold: 0.05

Epoch 1/2


Training:   4%|▍         | 21/481 [00:05<02:02,  3.76batch/s, total=0.826, main=0.694, aux=0.713, attn=0.122]


KeyboardInterrupt: 

In [None]:
# 3. EVALUATE MODEL
# Load best model
model.load_model('best_model')

# Evaluate model
result = model.predict(test_dataloader=test_loader, return_attentions=True)

# Save predictions
experiment_manager.save_predictions(result, filename="test_predictions.pkl")

In [None]:
from collections import Counter
from itertools import chain

all_target_groups = chain.from_iterable([group['target_groups'] for group in train_data])
all_target_groups = [group for group in all_target_groups if group != 'None' and group != 'Other']
counter = Counter(all_target_groups)

n_common = 10
bias_target_groups = [tg[0] for tg in counter.most_common(n_common)]

# 4. BIAS EVALUATION
gmb_metrics, bias_details = calculate_gmb_metrics(
    test_data=test_data,
    probabilities=result['probabilities'],
    target_groups=bias_target_groups
)

# Save bias metrics
experiment_manager.save_bias_metrics(gmb_metrics, bias_details)


In [None]:

# 5. XAI EVALUATION (only on hate samples)
test_data_hate_only = []
test_results_hate_only = {'attentions': [], 'probabilities': [], 'predictions': [], 'post_id': [], 'labels': []}
for idx, td in enumerate(test_data):
    if td['hard_label'] == 1:
        test_data_hate_only.append(td)
        test_results_hate_only['attentions'].append(result['attentions'][idx])
        test_results_hate_only['probabilities'].append(result['probabilities'][idx])
        test_results_hate_only['predictions'].append(result['predictions'][idx])
        test_results_hate_only['post_id'].append(result['post_ids'][idx])
        test_results_hate_only['labels'].append(result['labels'][idx])

calculator = FaithfulnessMetrics(
    model=model,
    tokenizer=tokenizer,
    dataset_class=HateDataset,
    batch_size=32
)

k = 5
eraser_save_path = f"{experiment_dir}/results/test_explain_output.jsonl"
xai_results = calculator.compute_all_metrics(test_data_hate_only, test_results_hate_only, k, eraser_save_path)

# Save XAI metrics
experiment_manager.save_xai_metrics(xai_results)


In [None]:
!git clone https://github.com/jayded/eraserbenchmark.git
!sed -i "285s/.*/    labels=['0', '1']/" eraserbenchmark/rationale_benchmark/metrics.py
!sed -i "286s/.*/    label_to_int = {'0':0, '1': 1}/" eraserbenchmark/rationale_benchmark/metrics.py

In [None]:
score_file = f"{experiment_dir}/results/eraser_result.json"
!PYTHONPATH=./eraserbenchmark:%PYTHONPATH% && python eraserbenchmark/rationale_benchmark/metrics.py --split test --strict --data_dir Data/explanations --results {eraser_save_path} --score_file {score_file}

In [None]:
# 6. CREATE FINAL SUMMARY
final_summary = {
    "test_accuracy": float(result['accuracy']),
    "test_f1": float(result['f1']),
    "test_loss": float(result['loss']),
    "gmb_metrics": gmb_metrics,
    "xai_metrics": xai_results,
    "total_params": sum(p.numel() for p in model.base_model.parameters()),
}

experiment_manager.save_final_metrics(final_summary)


In [None]:
for key, value in final_summary.items():
    if isinstance(value, dict):
        for k, v in value.items():
            print(f"{key}.{k}: {v}")
        print()
    else:
        print(f"{key}: {value}")
    print()

In [None]:
# 7. MARK EXPERIMENT AS COMPLETE
experiment_manager.mark_complete(
    status="completed",
    notes="Baseline experiment completed successfully"
)

print("\n" + "="*80)
print("EXPERIMENT COMPLETED!")
print(f"All results saved to: {experiment_dir}")
print("="*80)

## Experiment Management Utilities

Useful commands for managing and comparing experiments:

In [None]:
# View all experiments
experiment_manager = ExperimentManager()
experiment_manager.print_experiment_summary()

In [None]:
# List only completed experiments
completed_experiments = experiment_manager.list_experiments(status="completed")
print(f"Found {len(completed_experiments)} completed experiments")
for exp in completed_experiments:
    print(f"  - {exp['experiment_id']}: {exp.get('description', 'No description')}")

In [None]:
experiment_dir

In [None]:
# Compare multiple experiments
experiment_ids = [exp['experiment_id'] for exp in completed_experiments]
comparison = experiment_manager.compare_experiments(experiment_ids)

# Display comparison
for exp in comparison["experiments"]:
    print(f"\nExperiment: {exp['experiment_id']}")
    print(f"  Model: {exp['config'].get('model_name', 'N/A')}")
    print(f"  Learning Rate: {exp['config'].get('learning_rate', 'N/A')}")
    print(f"  Test F1: {exp['metrics'].get('test_f1', 'N/A')}")
    print(f"  Test Accuracy: {exp['metrics'].get('test_accuracy', 'N/A')}")

## Visualization Tools

Visualize and compare experiment results:

In [None]:
from experiment_visualization import (
    plot_training_curves,
    plot_metrics_comparison,
    plot_bias_metrics,
    plot_xai_metrics,
    create_experiment_report
)

In [None]:
# Uncomment to use:
plot_training_curves(experiment_ids, save_path="training_curves.png")

In [None]:
# Compare final metrics across experiments
plot_metrics_comparison(experiment_ids, save_path="metrics_comparison.png")

In [None]:
# Visualize bias metrics
plot_bias_metrics(experiment_ids, save_path="bias_comparison.png")

In [None]:
# Visualize XAI metrics
plot_xai_metrics(experiment_ids, save_path="xai_comparison.png")

# Task
Enable class weighting for the `hard_label` by calculating inverse class frequency weights from `train_data`, updating `config.class_weighting` to `True`, modifying the `HateClassifier` class definition in `HateClassifier.py` to accept and utilize these weights in `nn.CrossEntropyLoss`, and finally, passing the calculated weights to the `HateClassifier` during its instantiation.

## Calculate Class Weights

### Subtask:
Calculate the inverse class frequency weights for the `hard_label` in the `train_data` using `sklearn.utils.class_weight.compute_class_weight` and convert them to a PyTorch tensor. This code will be added to the selected cell.


**Reasoning**:
To calculate the inverse class frequency weights, I need to import the `compute_class_weight` function, extract the `hard_label`s from the training data, define the classes, compute the weights, and then convert them to a PyTorch tensor.



In [None]:
from sklearn.utils.class_weight import compute_class_weight

# Extract hard_label values from train_data
labels = [item['hard_label'] for item in train_data]

# Define unique classes
classes = [0, 1]

# Calculate inverse class frequency weights
class_weights_np = compute_class_weight(class_weight='balanced', classes=classes, y=labels)

# Convert to PyTorch tensor and move to device
class_weights = torch.tensor(class_weights_np, dtype=torch.float32).to(device)

print("Calculated class weights:", class_weights)
print("Class weights moved to device:", device)

**Reasoning**:
The error message indicates that the `classes` parameter of `compute_class_weight` expects a NumPy array, not a Python list. I will convert the `classes` list to a NumPy array to fix this.



In [None]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np # Import numpy

# Extract hard_label values from train_data
labels = [item['hard_label'] for item in train_data]

# Define unique classes as a NumPy array
classes = np.array([0, 1])

# Calculate inverse class frequency weights
class_weights_np = compute_class_weight(class_weight='balanced', classes=classes, y=labels)

# Convert to PyTorch tensor and move to device
class_weights = torch.tensor(class_weights_np, dtype=torch.float32).to(device)

print("Calculated class weights:", class_weights)
print("Class weights moved to device:", device)

**Reasoning**:
The subtask requires setting `config.class_weighting` to `True`. This step updates the configuration object accordingly.



In [None]:
config = TrainingConfig()
config.train_attention = True
config.num_epochs = 5
config.class_weighting = True
config.lambda_attn = 1
config.use_multi_layer_loss = True
config.hidden_dropout_prob = 0.2


## Update HateClassifier Class

### Subtask:
Modify the `HateClassifier` class definition (`HateClassifier.py`) to accept a `class_weights` parameter in its `__init__` method and use these weights when initializing `nn.CrossEntropyLoss` within the `_setup_loss_functions` method.


**Reasoning**:
To guide the user in modifying the `HateClassifier.py` file, I need to first display its current content so they know what to edit. This helps confirm the current state of the file before providing the changes.



In [None]:
import inspect
import os

# Get the path to the HateClassifier.py file
file_path = os.path.join('HateDeRC', 'HateClassifier.py')

# Read and print the content of the file
with open(file_path, 'r') as f:
    print(f.read())

**Reasoning**:
The previous attempt to read `HateClassifier.py` failed because the path was constructed incorrectly relative to the current working directory. The working directory is already inside `HateDeRC`, so the file path should just be `HateClassifier.py`.



In [None]:
import inspect
import os

# Get the path to the HateClassifier.py file
file_path = 'HateClassifier.py'

# Read and print the content of the file
with open(file_path, 'r') as f:
    print(f.read())

**Reasoning**:
The previous output showed the content of the `HateClassifier.py` file. Now, I will generate the code to modify this file by adding `class_weights` to the `__init__` method and ensuring it's used in the `CrossEntropyLoss` initialization. I will use a shell command to overwrite the file with the modified content.



In [None]:
%%writefile HateClassifier.py
import torch
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import autocast, GradScaler  # type: ignore
from transformers import (
    AutoModel,  # pyright: ignore[reportPrivateImportUsage]
    AutoConfig,  # pyright: ignore[reportPrivateImportUsage]
)
from tqdm import tqdm
import numpy as np
from TrainingConfig import TrainingConfig
from sklearn.metrics import f1_score, accuracy_score
import json
import os
from pathlib import Path
from typing import Optional


class HateClassifier:
    """
    HateDeRC: Hate Speech Detection with Debiasing Residual Connections.

    This classifier implements a novel architecture for binary hate speech detection
    that reduces dependency on target-sensitive words (e.g., race, religion, gender).

    Key Components:
    ---------------
    1. **DeRC Mechanism (Debiasing Residual Connection)**:
       - Creates a residual connection from lower layer (layer 3) to final layer
       - Lower layers capture shallow, bias-prone patterns (target words)
       - Final layer learns bias-independent features by incorporating debiased residuals
       - Prevents over-reliance on lexical shortcuts

    2. **Multi-Layer Loss**:
       - Auxiliary loss from debias layer (layer 3): guides early layers to learn useful representations
       - Main loss from final layer: ensures correct final predictions
       - Configurable weighting allows balancing between auxiliary and main objectives

    3. **Ranking-Based Attention Supervision**:
       - Uses human token-level annotations as supervision signal
       - Employs pairwise ranking loss instead of cross-entropy (respects independent annotations)
       - Enforces: tokens with higher human importance should receive higher attention
       - Helps model focus on contextually relevant hate indicators

    Loss Function:
    --------------
    Total Loss = α × lower_layer_loss + β × upper_layer_loss + λ × attention_ranking_loss

    Where:
    - α (lower_loss_weight): Weight for auxiliary classification loss
    - β (upper_loss_weight): Weight for main classification loss
    - λ (lambda_attn): Weight for attention supervision loss

    Architecture:
    -------------
    Input → BERT-like Encoder (all hidden states) → Multi-layer Classifiers
                                ↓
                          Layer 3 (debias) -----(residual)----→ Final Layer
                                ↓                                    ↓
                         Auxiliary Loss                        Main Loss
    """

    def __init__(self, config: TrainingConfig, class_weights: Optional[torch.Tensor] = None, **kwargs):
        self.config = config
        self.class_weights = class_weights # Store class weights

        # Initialize device:
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        # Configure & initialize the base model
        model_config = AutoConfig.from_pretrained(
            config.model_name, output_attentions=True, output_hidden_states=True
        )
        self.base_model = AutoModel.from_pretrained(
            config.model_name, config=model_config
        )

        # Get model dimensions
        hidden_size = self.base_model.config.hidden_size
        self.num_layers = self.base_model.config.num_hidden_layers

        # Multi-layer classifier heads (one for each transformer layer)
        self.classifier_list = nn.ModuleList(
            [nn.Linear(hidden_size, config.num_labels) for _ in range(self.num_layers)]
        )

        # Dropout layer
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Layer configuration for debiasing (similar to BertDistill)
        self.debias_layer = 3  # Layer index for auxiliary loss
        self.use_multi_layer_loss = getattr(config, "use_multi_layer_loss", False)

        # Loss Weighting Configuration
        self.lower_loss_weight = getattr(
            config, "lower_loss_weight", 0.5
        )  # α: auxiliary loss weight
        self.upper_loss_weight = getattr(
            config, "upper_loss_weight", 0.5
        )  # β: main loss weight
        self.lambda_attn = getattr(
            config, "lambda_attn", 0.1
        )  # λ: attention loss weight

        # Attention Training Configuration
        self.train_attention = getattr(config, "train_attention", False)
        self.ranking_margin = getattr(
            config, "ranking_margin", 0.1
        )  # Margin for pairwise ranking
        self.ranking_threshold = getattr(
            config, "ranking_threshold", 0.05
        )  # Threshold for significant pairs

        # Move models to device
        self.base_model.to(self.device)
        self.classifier_list.to(self.device)

        # Configure loss function
        if config.class_weighting:
            if self.class_weights is not None:
                class_weight_on_device = self.class_weights.to(self.device)
                self.cls_criterion = CrossEntropyLoss(weight=class_weight_on_device)
            else:
                raise ValueError("config.class_weighting is True but no class_weights were provided.")
        else:
            self.cls_criterion = CrossEntropyLoss()

        # Configure optimizer (for base model and all classifiers)
        params = list(self.base_model.parameters()) + list(
            self.classifier_list.parameters()
        )
        self.optimizer = AdamW(params, lr=config.learning_rate)

        # Learning rate scheduler
        self.scheduler = None

        # Mixed precision training
        self.use_amp = config.use_amp and torch.cuda.is_available()
        self.scaler = GradScaler() if self.use_amp else None

        # Gradient accumulation
        self.gradient_accumulation_steps = config.gradient_accumulation_steps
        self.max_grad_norm = config.max_grad_norm

        # Torch compile (PyTorch 2.0+)
        if hasattr(config, "use_compile") and config.use_compile:
            try:
                self.base_model = torch.compile(self.base_model)
                print("✓ Model compiled with torch.compile")
            except Exception as e:
                print(f"Warning: torch.compile failed: {e}")

        # Training history
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "val_accuracy": [],
            "val_f1": [],
        }

    def train_epoch(self, train_dataloader):
        """
        Train for one epoch using the HateDeRC architecture.

        Returns:
            float: Average total loss for the epoch
        """
        self.base_model.train()
        for classifier in self.classifier_list:
            classifier.train()

        # Track individual loss components for monitoring
        total_loss = 0
        total_cls_loss = 0
        total_lower_loss = 0
        total_attn_loss = 0
        num_batches = 0

        progress_bar = tqdm(train_dataloader, desc="Training", unit="batch")
        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(self.device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
            labels = batch["labels"].to(self.device, non_blocking=True)

            # Mixed precision context
            with autocast(device_type="cuda", enabled=self.use_amp):
                # Forward Pass through base model (get all hidden states)
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=self.train_attention,
                )

                # Get hidden states from all layers
                hidden_states = outputs.hidden_states  # Tuple of (num_layers+1) tensors

                # Apply pooling to each layer's hidden state (extract CLS token)
                # Note: hidden_states[0] is embeddings, hidden_states[1:] are transformer layers
                pooled_outputs = []
                for i in range(1, len(hidden_states)):  # Skip embeddings layer
                    cls_token = hidden_states[i][:, 0, :]  # Get CLS token
                    pooled_outputs.append(cls_token)

                # Get logits from all classifier heads
                logits_list = []
                for i, pooled_output in enumerate(pooled_outputs):
                    if i == len(pooled_outputs) - 1 and self.use_multi_layer_loss:
                        # Last layer: add residual connection from debias layer
                        combined = (
                            self.dropout(pooled_output)
                            + self.dropout(pooled_outputs[self.debias_layer]).detach()
                        )
                        logits = self.classifier_list[i](combined)
                    else:
                        logits = self.classifier_list[i](self.dropout(pooled_output))
                    logits_list.append(logits)

                # Calculate unified loss with all components
                loss_dict = self._calculate_loss(
                    logits_list=logits_list,
                    labels=labels,
                    attention_mask=attention_mask,
                    attentions=outputs.attentions if self.train_attention else None,
                    human_rationales=(
                        batch.get("rationales") if self.train_attention else None
                    ),
                )

                loss = loss_dict["total_loss"]

                # Track individual loss components
                if "cls_loss" in loss_dict:
                    total_cls_loss += loss_dict["cls_loss"]
                if "lower_loss" in loss_dict:
                    total_lower_loss += loss_dict["lower_loss"]
                if "attn_loss" in loss_dict:
                    total_attn_loss += loss_dict["attn_loss"]

                # Scale loss for gradient accumulation
                loss = loss / self.gradient_accumulation_steps

            total_loss += loss.item() * self.gradient_accumulation_steps
            num_batches += 1

            # Backward pass with gradient scaling
            if self.use_amp and self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Gradient accumulation: only step optimizer every N batches
            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                # Gradient clipping
                if self.use_amp and self.scaler is not None:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.base_model.parameters(), self.max_grad_norm
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self.classifier_list.parameters(), self.max_grad_norm
                    )
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.base_model.parameters(), self.max_grad_norm
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self.classifier_list.parameters(), self.max_grad_norm
                    )
                    self.optimizer.step()

                self.optimizer.zero_grad()

                if self.scheduler:
                    self.scheduler.step()

            # Update progress bar with relevant loss components
            postfix = {"total": total_loss / num_batches}

            if self.use_multi_layer_loss:
                postfix["main"] = total_cls_loss / num_batches
                postfix["aux"] = total_lower_loss / num_batches
            else:
                postfix["cls"] = total_cls_loss / num_batches

            if self.train_attention and total_attn_loss > 0:
                postfix["attn"] = total_attn_loss / num_batches

            progress_bar.set_postfix(postfix)

        return total_loss / num_batches

    def evaluate(self, val_dataloader):
        self.base_model.eval()
        for classifier in self.classifier_list:
            classifier.eval()

        total_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Evaluating", unit="batch"):
                input_ids = batch["input_ids"].to(self.device, non_blocking=True)
                attention_mask = batch["attention_mask"].to(
                    self.device, non_blocking=True
                )
                labels = batch["labels"].to(self.device, non_blocking=True)

                # Mixed precision context for evaluation
                with autocast(device_type="cuda", enabled=self.use_amp):
                    # Forward pass through base model (no hidden states needed - more efficient)
                    outputs = self.base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )

                    # Get final layer's CLS token (for evaluation, only use final classifier)
                    cls_token = outputs.last_hidden_state[:, 0, :]

                    # Get logits from final classifier (no dropout in eval mode)
                    logits = self.classifier_list[-1](cls_token)

                    # Calculate loss
                    loss = self.cls_criterion(logits, labels)
                    total_loss += loss.item()

                # Get predictions
                preds = torch.argmax(logits, dim=1).cpu().numpy()

                # Store predictions and labels
                all_preds.extend(preds)
                all_labels.extend(labels.cpu().numpy())

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="macro")

        # Calculate loss
        avg_loss = total_loss / len(val_dataloader)

        return avg_loss, accuracy, f1

    def train(self, train_dataloader, val_dataloader):
        """
        Train the HateDeRC model with multi-component loss.

        Training Configuration:
        - DeRC: Residual connection from layer 3 to final layer
        - Multi-layer loss: Auxiliary (layer 3) + Main (final layer)
        - Attention supervision: Ranking-based loss from human annotations

        Loss Weights:
        - Lower loss: {:.2f}
        - Upper loss: {:.2f}
        - Attention: {:.2f}

        Args:
            train_dataloader: DataLoader for training data
            val_dataloader: DataLoader for validation data

        Returns:
            dict: Training history with loss and metrics per epoch
        """.format(
            self.lower_loss_weight, self.upper_loss_weight, self.lambda_attn
        )
        print(f"Training on device: {self.device}")
        print(f"Model: {self.config.model_name}")
        print(f"Epochs: {self.config.num_epochs}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Gradient accumulation steps: {self.gradient_accumulation_steps}")
        print(
            f"Effective batch size: {self.config.batch_size * self.gradient_accumulation_steps}"
        )
        print(f"Learning rate: {self.config.learning_rate}")
        print(f"Mixed precision (AMP): {self.use_amp}")
        print(f"Gradient clipping: {self.max_grad_norm}")
        print(f"\nLoss Configuration:")
        print(f"  Multi-layer loss: {self.use_multi_layer_loss}")
        if self.use_multi_layer_loss:
            print(
                f"    - Auxiliary (layer {self.debias_layer}): α={self.lower_loss_weight}"
            )
            print(f"    - Main (final layer): β={self.upper_loss_weight}")
        print(f"  Attention supervision: {self.train_attention}")
        if self.train_attention:
            print(f"    - Ranking loss: λ={self.lambda_attn}")
            print(
                f"    - Margin: {self.ranking_margin}, Threshold: {self.ranking_threshold}"
            )
        print("=" * 60)

        best_f1 = 0.0

        for epoch in range(self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            # Train for one epoch
            train_loss = self.train_epoch(train_dataloader)

            # Evaluate on validation set
            val_loss, val_accuracy, val_f1 = self.evaluate(val_dataloader)

            # Store metrics in history
            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["val_accuracy"].append(val_accuracy)
            self.history["val_f1"].append(val_f1)

            # Print epoch summary
            print(f"\nEpoch {epoch + 1} Summary:")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss:   {val_loss:.4f}")
            print(f"  Val Acc:    {val_accuracy:.4f}")
            print(f"  Val F1:     {val_f1:.4f}")

            # Save best model
            if val_f1 > best_f1:
                best_f1 = val_f1
                self.save_model("best_model")
                print(f"  ✓ New best model saved! (F1: {best_f1:.4f})")

            # Save checkpoint every epoch
            self.save_model(f"checkpoint_epoch_{epoch + 1}")

        # Save final model and training history
        self.save_model("final_model")
        self.save_history()

        print("\n" + "=" * 60)
        print(f"Training completed!")
        print(f"Best F1 Score: {best_f1:.4f}")
        print(
            f"Training history saved to: {self.config.save_dir}/training_history.json"
        )

        return self.history

    def save_model(self, name: str):
        """
        Save model checkpoint.
        """
        save_path = Path(self.config.save_dir) / name
        save_path.mkdir(parents=True, exist_ok=True)

        # Save base model
        self.base_model.save_pretrained(save_path)

        # Save all classifier heads and optimizer state
        torch.save(
            {
                "classifier_list_state_dict": self.classifier_list.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "config": self.config,
            },
            save_path / "training_state.pt",
        )

    def load_model(self, name: str):
        """
        Load model checkpoint.
        """
        load_path = Path(self.config.save_dir) / name

        # Load base model
        self.base_model = AutoModel.from_pretrained(load_path)
        self.base_model.to(self.device)

        # Load all classifier heads and optimizer state
        # Note: weights_only=False is safe for your own checkpoints
        checkpoint = torch.load(load_path / "training_state.pt", weights_only=False)
        self.classifier_list.load_state_dict(checkpoint["classifier_list_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        print(f"Model loaded from: {load_path}")

    def save_history(self):
        """
        Save training history to JSON file.
        """
        save_path = Path(self.config.save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        history_path = save_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.history, f, indent=2)

    def load_history(self):
        """
        Load training history from JSON file.
        """
        history_path = Path(self.config.save_dir) / "training_history.json"

        if history_path.exists():
            with open(history_path, "r") as f:
                self.history = json.load(f)
            print(f"Training history loaded from: {history_path}")
        else:
            print(f"No training history found at: {history_path}")

    def predict(
        self, test_dataloader, return_layer_outputs=False, return_attentions=False
    ):
        """
        Run inference on test data and return predictions with metrics.

        Args:
            test_dataloader: DataLoader for test data
            return_layer_outputs: If True, returns predictions from all layers

        Returns:
            dict: Contains predictions, true labels, probabilities, loss, accuracy, and F1 score
        """
        self.base_model.eval()
        for classifier in self.classifier_list:
            classifier.eval()

        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []
        all_attention_weights = []
        all_post_ids = []
        all_layer_preds = (
            [[] for _ in range(self.num_layers)] if return_layer_outputs else None
        )

        print(f"Running inference on {len(test_dataloader)} batches...")

        with torch.no_grad():
            for batch in tqdm(test_dataloader, desc="Testing", unit="batch"):
                input_ids = batch["input_ids"].to(self.device, non_blocking=True)
                attention_mask = batch["attention_mask"].to(
                    self.device, non_blocking=True
                )
                labels = batch["labels"].to(self.device, non_blocking=True)
                post_id = batch["post_id"]

                # Mixed precision context for inference
                with autocast(device_type="cuda", enabled=self.use_amp):
                    # Forward pass through base model (only request hidden states if needed)
                    outputs = self.base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=return_layer_outputs,
                        output_attentions=return_attentions,
                    )

                    cls_token = outputs.last_hidden_state[:, 0, :]
                    logits = self.classifier_list[-1](cls_token)

                    # Calculate loss
                    loss = self.cls_criterion(logits, labels)
                    total_loss += loss.item()

                    # Get predictions and probabilities
                    probs = F.softmax(logits, dim=1)
                    preds = torch.argmax(logits, dim=1)

                # Store predictions, labels, and probabilities
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
                all_post_ids.extend(post_id)
                # Store attentions if requested
                if return_attentions:
                    attention_result = self.extract_attention(outputs.attentions)
                    if attention_result is not None:
                        all_attention_weights.extend(attention_result)
                    else:
                        print("No attention weights extracted.")

                # Optionally get predictions from all layers
                if return_layer_outputs and all_layer_preds is not None:
                    hidden_states = outputs.hidden_states[1:]  # Skip embeddings
                    for i, hidden_state in enumerate(hidden_states):
                        layer_cls = hidden_state[:, 0, :]
                        layer_logits = self.classifier_list[i](layer_cls)
                        layer_preds = torch.argmax(layer_logits, dim=1)
                        all_layer_preds[i].extend(layer_preds.cpu().numpy())

        # Convert to numpy arrays
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="macro")
        avg_loss = total_loss / len(test_dataloader)

        # Print summary
        print("\n" + "=" * 60)
        print("Test Results:")
        print(f"  Test Loss:     {avg_loss:.4f}")
        print(f"  Test Accuracy: {accuracy:.4f}")
        print(f"  Test F1:       {f1:.4f}")
        print("=" * 60)

        results = {
            "post_ids": all_post_ids,
            "predictions": all_preds,
            "labels": all_labels,
            "probabilities": all_probs,
            "loss": avg_loss,
            "accuracy": accuracy,
            "f1": f1,
        }

        if return_layer_outputs and all_layer_preds is not None:
            results["layer_predictions"] = [
                np.array(preds) for preds in all_layer_preds
            ]

        if return_attentions:
            results["attentions"] = all_attention_weights

        return results

    def save_predictions(self, results: dict, filename: str = "test_results.json"):
        """
        Save prediction results to file.

        Args:
            results: Dictionary returned from predict() method
            filename: Name of the file to save results
        """
        save_path = Path(self.config.save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        # Convert numpy arrays to lists for JSON serialization
        results_serializable = {
            "predictions": results["predictions"].tolist(),
            "labels": results["labels"].tolist(),
            "probabilities": results["probabilities"].tolist(),
            "loss": float(results["loss"]),
            "accuracy": float(results["accuracy"]),
            "f1": float(results["f1"]),
        }

        results_path = save_path / filename
        with open(results_path, "w") as f:
            json.dump(results_serializable, f, indent=2)

        print(f"Test results saved to: {results_path}")

    def extract_attention(self, attentions, return_tensor=False):
        """
        Extract CLS token attention from last layer, averaged across all heads.

        Args:
            attentions: Tuple of attention tensors from model
            return_tensor: If True, returns tensor (for training). If False, returns numpy (for inference)

        Returns:
            (batch_size, seq_len) attention weights as tensor or numpy array
        """
        if attentions is None or len(attentions) == 0:
            print("WARNING: No attention data available.")
            return None

        # Take CLS representation from the last layer's attentions
        last_layer_attentions = attentions[
            -1
        ]  # Shape: (batch_size, num_heads, seq_len, seq_len)
        cls_attentions = last_layer_attentions[
            :, :, 0, :
        ]  # Shape: (batch_size, num_heads, seq_len)
        # Average over all heads
        avg_cls_attention = cls_attentions.mean(dim=1)  # Shape: (batch_size, seq_len)

        if return_tensor:
            return (
                avg_cls_attention  # Keep as tensor for training (preserves gradients)
            )
        else:
            return avg_cls_attention.cpu().numpy()  # Convert to numpy for inference

    def _calculate_loss(
        self,
        logits_list,
        labels,
        attention_mask,
        attentions=None,
        human_rationales=None,
    ):
        """
        Calculate unified loss with configurable component weights.

        This method computes the total loss as a weighted combination of:
        1. Classification loss (auxiliary from debias layer if multi-layer enabled)
        2. Main classification loss (from final layer)
        3. Attention ranking loss (if attention supervision enabled)

        Args:
            logits_list: List of logits from all classifier heads
            labels: Ground truth labels (batch_size,)
            attention_mask: Attention mask for padding (batch_size, seq_len)
            attentions: Tuple of attention tensors from model (optional)
            human_rationales: Human token annotations (batch_size, seq_len) (optional)

        Returns:
            dict: Dictionary containing:
                - 'total_loss': Weighted sum of all loss components
                - 'cls_loss': Main classification loss (if applicable)
                - 'lower_loss': Auxiliary classification loss (if multi-layer enabled)
                - 'attn_loss': Attention ranking loss (if attention training enabled)
        """
        loss_dict = {}

        # Main logits from final layer
        final_logits = logits_list[-1]

        # Component 1 & 2: Classification Losses
        if self.use_multi_layer_loss:
            # Auxiliary loss from debias layer (helps guide lower layer representations)
            lower_loss = self.cls_criterion(logits_list[self.debias_layer], labels)

            # Main loss from final layer (with residual connection)
            upper_loss = self.cls_criterion(final_logits, labels)

            # Weighted combination: Total = α × lower + β × upper
            cls_loss = (
                self.lower_loss_weight * lower_loss
                + self.upper_loss_weight * upper_loss
            )

            loss_dict["lower_loss"] = lower_loss.item()
            loss_dict["cls_loss"] = upper_loss.item()
        else:
            # Single loss from final layer only (baseline)
            cls_loss = self.cls_criterion(final_logits, labels)
            loss_dict["cls_loss"] = cls_loss.item()

        total_loss = cls_loss

        # Component 3: Attention Ranking Loss
        if (
            self.train_attention
            and attentions is not None
            and human_rationales is not None
        ):
            if len(attentions) > 0:
                # Extract model attention from last layer
                model_attention = self.extract_attention(attentions, return_tensor=True)

                # Move rationales to device
                human_rationales = human_rationales.to(self.device, non_blocking=True)

                # Calculate ranking loss (already weighted by lambda_attn internally)
                attn_loss = self.calculate_attention_loss(
                    human_rationales, model_attention, attention_mask
                )

                total_loss = total_loss + attn_loss
                loss_dict["attn_loss"] = attn_loss.item()

        loss_dict["total_loss"] = total_loss
        return loss_dict

    def calculate_attention_loss(
        self, human_rationales, models_attentions, attention_mask
    ):
        """
        Calculate pairwise margin ranking loss for attention supervision.

        For each pair of tokens (i, j) where human_score[i] > human_score[j],
        we enforce: attention[i] - attention[j] >= margin

        This respects the independent nature of human annotations and focuses on
        relative importance rather than absolute values.

        Args:
            human_rationales: (batch_size, seq_len) - Independent token importance scores [0-1]
            models_attentions: (batch_size, seq_len) - Model attention weights (softmax tensor)
            attention_mask: (batch_size, seq_len) - Mask for padding tokens

        Returns:
            Scalar ranking loss
        """
        # Ensure models_attentions is a tensor (should be from extract_attention with return_tensor=True)
        if isinstance(models_attentions, np.ndarray):
            models_attentions = torch.from_numpy(models_attentions).to(self.device)

        batch_size, seq_len = human_rationales.shape

        # Mask out padding positions
        human_rationales = human_rationales * attention_mask
        models_attentions = models_attentions * attention_mask

        total_loss = 0.0
        total_pairs = 0

        for b in range(batch_size):
            # Get valid (non-padding) positions for this sample
            valid_mask = attention_mask[b].bool()
            valid_indices = torch.where(valid_mask)[0]

            if len(valid_indices) < 2:
                continue  # Skip if less than 2 valid tokens

            # Get human scores and model attentions for valid tokens
            human_scores = human_rationales[b, valid_indices]  # (num_valid,)
            model_attn = models_attentions[b, valid_indices]  # (num_valid,)

            # Create all pairs: (num_valid, num_valid)
            # human_i: (num_valid, 1), human_j: (1, num_valid)
            human_i = human_scores.unsqueeze(1)  # (num_valid, 1)
            human_j = human_scores.unsqueeze(0)  # (1, num_valid)

            model_i = model_attn.unsqueeze(1)  # (num_valid, 1)
            model_j = model_attn.unsqueeze(0)  # (1, num_valid)

            # Find pairs where human_i > human_j (should have model_i > model_j)
            human_diff = human_i - human_j  # (num_valid, num_valid)
            model_diff = model_i - model_j  # (num_valid, num_valid)

            # Only consider pairs where there's a clear difference in human scores
            # (avoid pairs with very similar scores)
            significant_pairs = (
                human_diff > self.ranking_threshold
            ).float()  # Threshold to avoid noise

            # Margin ranking loss: max(0, margin - (model_i - model_j)) when human_i > human_j
            # We want: model_i - model_j >= margin when human_i > human_j
            ranking_loss = torch.relu(self.ranking_margin - model_diff)

            # Apply mask to only consider significant pairs
            ranking_loss = ranking_loss * significant_pairs

            # Accumulate
            num_pairs = significant_pairs.sum()
            if num_pairs > 0:
                total_loss += ranking_loss.sum() / num_pairs
                total_pairs += 1

        if total_pairs == 0:
            return torch.tensor(0.0, device=human_rationales.device)

        avg_loss = total_loss / total_pairs
        return self.lambda_attn * avg_loss


**Reasoning**:
The `HateClassifier.py` file has been modified in the previous step. To ensure that the Python runtime uses the updated version of the class, I need to restart the Python interpreter. This action will clear the current kernel state and reload all modules, allowing the modified `HateClassifier` class to be imported correctly in subsequent steps.



In [None]:
exit()

### Important: Restart and Re-run Cells

The Python runtime has been restarted. To proceed with the task, you must re-run all the necessary setup cells from the beginning of the notebook until this point.

Specifically, please re-execute the cells that perform the following actions:

1.  **Clone the `HateDeRC` repository and change directory** (Cell `24658442` and `8bccce96`)
2.  **Define `data_path`** (Cell `1f654143`)
3.  **Load data, tokenizer, and process data splits** (Cells `f81645fa`, `f070b12b`, `8d446ddb`, `pEC9l6Mr6xo0`, `43NMFuwv62Q4`)
4.  **Set `config` parameters** (Cell `b3618d6e` - ensure `config.class_weighting = True` is set again)
5.  **Set random seeds for reproducibility** (Cell `487ff027`)
6.  **Create `HateDataset` instances** (Cell `87c63ecf`)
7.  **Create `DataLoader` instances** (Cell `f4887f43`)
8.  **Calculate class weights** (Cell `3b61c466`)

After re-running these cells, you can continue with the next step, which will involve instantiating the `HateClassifier` with the `class_weights`.

**Reasoning**:
Now that the `HateClassifier.py` file has been modified to accept class weights, and assuming the runtime has been restarted and all previous setup cells (including `config` and `class_weights` calculation) have been re-executed, the next step is to instantiate the `HateClassifier` class, passing the `config` object and the calculated `class_weights`.



In [None]:
model = HateClassifier(config, class_weights=class_weights)
history = model.train(train_dataloader=train_loader, val_dataloader=val_loader)

# Save training history
experiment_manager.save_training_history(history)

**Reasoning**:
The `NameError` indicates that `HateClassifier` and `ExperimentManager` were not re-imported after the kernel restart. I need to explicitly import them before they can be used.



In [None]:
from ExperimentManager import ExperimentManager
from HateClassifier import HateClassifier
from HateDataset import HateDataset

# 1. CREATE EXPERIMENT
experiment_manager = ExperimentManager(base_dir="./experiments")
experiment_dir = experiment_manager.create_experiment(
    config=config,
    custom_name="baseline_distilbert",  # Change this for each experiment
    description="Baseline model with distilbert-base-uncased, standard hyperparameters"
)