In [1]:
import shutil
import os

# Delete HateDeRC directory if it exists
if os.path.exists('HateDeRC'):
  shutil.rmtree('HateDeRC')
!git clone https://github.com/jamesalv/HateDeRC
%cd HateDeRC

Cloning into 'HateDeRC'...
remote: Enumerating objects: 36, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 36 (delta 14), reused 29 (delta 10), pack-reused 0 (from 0)[K
Receiving objects: 100% (36/36), 4.55 MiB | 8.98 MiB/s, done.
Resolving deltas: 100% (14/14), done.
/content/HateDeRC


In [7]:
from TrainingConfig import TrainingConfig
from typing import Dict, Any, Tuple, List
import numpy as np
import torch
from transformers import AutoTokenizer
import json

In [8]:
data_path = 'Data/dataset.json'

In [9]:
config = TrainingConfig()

In [10]:
# Seed all randomness for reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(config.seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(config.seed)
np.random.seed(config.seed)

# Preprocessing

In [11]:
import re
import string

def deobfuscate_text(text):
    """
    Normalize common text obfuscation patterns to reveal original words.
    Useful for hate speech detection and content analysis.

    Args:
        text (str): Input text with potential obfuscations

    Returns:
        str: Text with obfuscations normalized
    """
    if not isinstance(text, str):
        return text

    # Make a copy to work with
    result = text.lower()

    # 1. Handle asterisk/symbol replacements
    symbol_patterns = {
        # Common profanity
        r'f\*+c?k': 'fuck',
        r'f\*+': 'fuck',
        r's\*+t': 'shit',
        r'b\*+ch': 'bitch',
        r'a\*+s': 'ass',
        r'd\*+n': 'damn',
        r'h\*+l': 'hell',
        r'c\*+p': 'crap',

        # Slurs and hate speech terms (be comprehensive for detection)
        r'n\*+g+[aer]+': 'nigger',  # Various n-word obfuscations
        r'f\*+g+[ot]*': 'faggot',
        r'r\*+[dt]ard': 'retard',
        r'sp\*+c': 'spic',

        # Other symbols
        r'@ss': 'ass',
        r'b@tch': 'bitch',
        r'sh!t': 'shit',
        r'f#ck': 'fuck',
        r'd@mn': 'damn',
    }

    for pattern, replacement in symbol_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 2. Handle character spacing (f u c k -> fuck)
    spacing_patterns = {
        r'\bf\s+u\s+c\s+k\b': 'fuck',
        r'\bs\s+h\s+i\s+t\b': 'shit',
        r'\bd\s+a\s+m\s+n\b': 'damn',
        r'\bh\s+e\s+l\s+l\b': 'hell',
        r'\ba\s+s\s+s\b': 'ass',
        r'\bc\s+r\s+a\s+p\b': 'crap',
    }

    for pattern, replacement in spacing_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 3. Handle number/letter substitutions
    leet_patterns = {
        # Basic leet speak
        r'\b3\s*1\s*1\s*3\b': 'elle',  # 3113 -> elle
        r'\bf4g\b': 'fag',
        r'\bf4gg0t\b': 'faggot',
        r'\bn00b\b': 'noob',
        r'\bl33t\b': 'leet',
        r'\bh4t3\b': 'hate',
        r'\b5h1t\b': 'shit',
        r'\bf0ck\b': 'fock',
    }

    for pattern, replacement in leet_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 4. Handle repeated characters and separators
    # Remove excessive punctuation between letters
    result = re.sub(r'([a-z])[^\w\s]+([a-z])', r'\1\2', result)

    # Handle underscore separation
    result = re.sub(r'([a-z])_+([a-z])', r'\1\2', result)

    # Handle dot separation
    result = re.sub(r'([a-z])\.+([a-z])', r'\1\2', result)

    # 5. Handle common misspellings/variations used for evasion
    evasion_patterns = {
        r'\bfuk\b': 'fuck',
        r'\bfuq\b': 'fuck',
        r'\bfck\b': 'fuck',
        r'\bshyt\b': 'shit',
        r'\bshit\b': 'shit',
        r'\bbiatch\b': 'bitch',
        r'\bbeatch\b': 'bitch',
        r'\basshole\b': 'asshole',
        r'\ba55hole\b': 'asshole',
        r'\btard\b': 'retard',
        r'\bfagg\b': 'fag',
    }

    for pattern, replacement in evasion_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 6. Clean up multiple spaces
    result = re.sub(r'\s+', ' ', result).strip()

    return result

In [12]:
def aggregate_rationales(rationales, labels, post_length, drop_abnormal=False):
    """
    If all 3 annotators are normal → 3 zero spans → average (all zeros).
    If k annotators are non-normal and k spans exist → average the k spans (no added zeros).
    If k non-normal but fewer than k spans:
        If the missing annotators are non-normal → do not fill with zeros; average only existing spans and record rationale_support = #spans.
        If the missing annotators are normal (e.g., 2 hate + 1 normal + 2 spans) → append one zero span for the normal.
    """
    count_normal = labels.count(0)
    count_hate = labels.count(1)
    count_rationales = len(rationales)
    pad = np.zeros(post_length, dtype="int").tolist()

    # If there are hate labels but no rationales, something is wrong
    if count_hate > 0 and count_rationales == 0:
        if drop_abnormal:
            return None

        # Else just fill with 0
        return np.zeros(post_length).tolist()

    # If all annotators are normal, return all zeros
    if count_normal == 3:
        return np.zeros(post_length).tolist()

    # If we have hate annotators
    if count_hate > 0:
        # Case 1: Number of rationales matches number of hate annotators
        if count_rationales == count_hate:
            return np.average(rationales, axis=0).tolist()

        # Case 2: Fewer rationales than hate annotators
        elif count_rationales < count_hate:
            # Add zero padding for normal annotators only
            rationales_copy = rationales.copy()
            zeros_to_add = count_normal
            for _ in range(zeros_to_add):
                rationales_copy.append(pad)
            return np.average(rationales_copy, axis=0).tolist()

        # Case 3: More rationales than hate annotators (shouldn't happen normally)
        else:
            # Just average what we have
            return np.average(rationales, axis=0).tolist()

    # Fallback: return zeros if no clear case matches
    return np.zeros(post_length).tolist()

In [13]:
from typing import List, Tuple

def preprocess_text(raw_text):
    preprocessed_text = raw_text
    # # Remove HTML tags <>
    preprocessed_text = preprocessed_text.replace("<", "").replace(">", "")
    # # De-Obsfucate Patterns
    preprocessed_text = deobfuscate_text(preprocessed_text)

    return preprocessed_text


def create_text_segment(
    text_tokens: List[str], rationale_mask: List[int]
) -> List[Tuple[List[str], int]]:
    """
    Process a rationale mask to identify contiguous segments of highlighted text.
    Then create a segmented representation of the tokens

    Args:
        text_tokens: Original text tokens
        mask: Binary mask where 1 indicates a highlighted token (this consists of mask from 3 annotators)

    Returns:
        A list of tuples (text segment, mask value)
    """
    # Handle case where mask is empty (no rationale provided), usually this is normal classification
    mask = rationale_mask

    # for mask in all_rationale_mask:
    # Find breakpoints (transitions between highlighted/1 and non-highlighted/0)
    breakpoints = []
    mask_values = []

    # Always start with position 0
    breakpoints.append(0)
    mask_values.append(mask[0])

    # Find transitions in the mask
    for i in range(1, len(mask)):
        if mask[i] != mask[i - 1]:
            breakpoints.append(i)
            mask_values.append(mask[i])

    # Always end with the length of the text
    if breakpoints[-1] != len(mask):
        breakpoints.append(len(mask))

    # Create segments based on breakpoints
    segments = []
    for i in range(len(breakpoints) - 1):
        start = breakpoints[i]
        end = breakpoints[i + 1]
        segments.append((text_tokens[start:end], mask_values[i]))

    return segments


def align_rationales(tokens, rationales, tokenizer, max_length=128):
    """
    Align rationales with tokenized text while handling different tokenizer formats.

    Args:
        tokens: Original text tokens
        rationales: Original rationale masks
        tokenizer: The tokenizer to use
        max_length: Maximum sequence length

    Returns:
        Dictionary with tokenized inputs and aligned rationale masks
    """
    segments = create_text_segment(tokens, rationales)
    all_human_rationales = []
    all_input_ids = []
    all_attention_mask = []
    all_token_type_ids = []
    all_rationales = []
    for text_segment, rationale_value in segments:
        inputs = {}
        concatenated_text = " ".join(text_segment)
        processed_segment = preprocess_text(concatenated_text)
        tokenized = tokenizer(
            processed_segment, add_special_tokens=False, return_tensors="pt"
        )

        # Extract the relevant data
        segment_input_ids = tokenized["input_ids"][0]
        segment_attention_mask = tokenized["attention_mask"][0]
        # Handle token_type_ids if present
        if "token_type_ids" in tokenized:
            segment_token_type_ids = tokenized["token_type_ids"][0]
            all_token_type_ids.extend(segment_token_type_ids)

        # Add input IDs and attention mask
        all_input_ids.extend(segment_input_ids)
        all_attention_mask.extend(segment_attention_mask)

        # Add rationales (excluding special tokens)
        segment_rationales = [rationale_value] * len(segment_input_ids)
        all_rationales.extend(segment_rationales)
    # Get special token IDs
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Add special tokens at the beginning and end
    all_input_ids = [cls_token_id] + all_input_ids + [sep_token_id]
    all_attention_mask = [1] + all_attention_mask + [1]

    # Handle token_type_ids if the model requires it
    if hasattr(tokenizer, "create_token_type_ids_from_sequences"):
        all_token_type_ids = tokenizer.create_token_type_ids_from_sequences(
            all_input_ids[1:-1]
        )
    elif all_token_type_ids:
        all_token_type_ids = [0] + all_token_type_ids + [0]
    else:
        all_token_type_ids = [0] * len(all_input_ids)

    # Check tokenized vs rationales length
    if len(all_input_ids) != len(all_attention_mask):
        print("Warning: length of tokens and rationales do not match")

    # Add zero rationale values for special tokens
    all_rationales = [0] + all_rationales + [0]

    # Truncate to max length if needed
    if len(all_input_ids) > max_length:
        print("WARNING: NEED TO TRUNCATE")
        all_input_ids = all_input_ids[:max_length]
        all_attention_mask = all_attention_mask[:max_length]
        all_token_type_ids = all_token_type_ids[:max_length]
        all_rationales = all_rationales[:max_length]

    # Pad to max_length if needed
    pad_token_id = tokenizer.pad_token_id
    padding_length = max_length - len(all_input_ids)

    if padding_length > 0:
        all_input_ids = all_input_ids + [pad_token_id] * padding_length
        all_attention_mask = all_attention_mask + [0] * padding_length
        all_token_type_ids = all_token_type_ids + [0] * padding_length
        all_rationales = all_rationales + [0] * padding_length

    # Convert lists to tensors
    inputs = {
        "input_ids": torch.tensor([all_input_ids], dtype=torch.long),
        "attention_mask": torch.tensor([all_attention_mask], dtype=torch.long),
        "token_type_ids": (
            torch.tensor([all_token_type_ids], dtype=torch.long)
            if "token_type_ids" in tokenizer.model_input_names
            else None
        ),
        "rationales": torch.tensor([all_rationales], dtype=torch.float32),
    }

    # Remove None values
    inputs = {k: v for k, v in inputs.items() if v is not None}
    return inputs

In [14]:
import re
import string
from collections import Counter
from tqdm import tqdm
invalid_rationales_key = []
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
def process_raw_entries(data, drop_abnormal = False):
    """
    Process raw data entries
    """
    print("Processing raw entries...")
    processed_entries = {}
    dropped = 0

    for key, value in tqdm(data.items()):
        try:
            # Extract labels (1 = hate/offensive, 0 = normal)
            labels = []
            target_groups = []
            for annot in value["annotators"]:
                label = annot["label"]
                target = annot["target"]

                labels.append(1 if label in ['hatespeech', 'offensive'] else 0)
                target_groups.append(target)

            # Process rationales
            rationales = value.get("rationales", [])
            aggregated_rationale = aggregate_rationales(rationales, labels, len(value["post_tokens"]),  drop_abnormal=drop_abnormal)


            if aggregated_rationale is None:
                dropped += 1
                continue
            inputs = align_rationales(value['post_tokens'], aggregated_rationale, tokenizer)

            # Determine final label
            # Majority vote for hard label
            hard_label = Counter(labels).most_common(1)[0][0]
            # Average for soft label
            soft_label = sum(labels) / len(labels)

            # Determine Targets --> Considered as targets when there are at least 2 mention
            target_groups = [
              t
              for annot in value['annotators']
              for t in annot['target']
            ]
            filtered = [k for k, v in Counter(target_groups).items() if v > 2]

            # Store inputs and labels
            processed_entries[key] = {
                'input_ids': inputs['input_ids'],
                'attention_mask': inputs['attention_mask'],
                'rationales': inputs['rationales'],
                'raw_text': " ".join(value['post_tokens']), # Keep raw text for potential debugging/analysis
                'hard_label': hard_label,
                'soft_label': soft_label,
                'target_groups': filtered
            }

        except Exception as e:
            dropped += 1
            print(f"Error processing {key}: {e}")

    print(f"Processed: {len(processed_entries)}, Dropped: {dropped}")
    return processed_entries

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [15]:
with open(data_path, 'r') as file:
  data = json.load(file)
processed_data = process_raw_entries(data)

Processing raw entries...


 37%|███▋      | 7397/20148 [00:06<00:10, 1270.12it/s]



 90%|████████▉ | 18101/20148 [00:17<00:02, 1012.24it/s]

Error processing 24439295_gab: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.


100%|██████████| 20148/20148 [00:19<00:00, 1017.18it/s]

Processed: 20147, Dropped: 1





In [36]:
# Dataset conversion to ERASER format for later XAI analysis
processed_data['7067204_gab']['rationales']

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]])

In [42]:
import more_itertools as mit
# https://stackoverflow.com/questions/2154249/identify-groups-of-continuous-numbers-in-a-list
def find_ranges(iterable):
    """Yield range of consecutive numbers."""
    for group in mit.consecutive_groups(iterable):
        group = list(group)
        if len(group) == 1:
            yield group[0]
        else:
            yield group[0], group[-1]

# Convert dataset into ERASER format: https://github.com/jayded/eraserbenchmark/blob/master/rationale_benchmark/utils.py
def get_evidence(post_id, anno_text, explanations):
    output = []

    indexes = sorted([i for i, each in enumerate(explanations) if each==1])
    span_list = list(find_ranges(indexes))

    for each in span_list:
        if type(each)== int:
            start = each
            end = each+1
        elif len(each) == 2:
            start = each[0]
            end = each[1]+1
        else:
            print('error')

        output.append({"docid":post_id,
              "end_sentence": -1,
              "end_token": end,
              "start_sentence": -1,
              "start_token": start,
              "text": ' '.join([str(x) for x in anno_text[start:end]])})
    return output

# To use the metrices defined in ERASER, we will have to convert the dataset
def convert_to_eraser_format(dataset, save_split, save_path, division_file):
    final_output = []

    if save_split:
        # Create the save_path directory if it does not exist
        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)

        train_fp = open(save_path+'train.jsonl', 'w')
        val_fp = open(save_path+'val.jsonl', 'w')
        test_fp = open(save_path+'test.jsonl', 'w')

    for key, value in dataset.items():

        temp = {}
        post_id = key
        post_class = value['hard_label']
        input_ids = value['input_ids'].squeeze().tolist()
        # Rounding up the values in the rationales, assuming we do union method
        rationales = value['rationales'].squeeze().ceil().int().tolist()


        temp['annotation_id'] = post_id
        temp['classification'] = post_class
        temp['evidences'] = [get_evidence(post_id, input_ids, rationales)]
        temp['query'] = "What is the class?"
        temp['query_type'] = None
        final_output.append(temp)

        if save_split:
          with open(division_file) as fp:
            id_division = json.load(fp)

          if not os.path.exists(save_path+'docs'):
              os.makedirs(save_path+'docs')

          with open(save_path+'docs/'+post_id, 'w') as fp:
              fp.write(' '.join([str(x) for x in input_ids]))

          if post_id in id_division['train']:
              train_fp.write(json.dumps(temp)+'\n')

          elif post_id in id_division['val']:
              val_fp.write(json.dumps(temp)+'\n')

          elif post_id in id_division['test']:
              test_fp.write(json.dumps(temp)+'\n')
          else:
              print(post_id)

    if save_split:
        train_fp.close()
        val_fp.close()
        test_fp.close()

    return final_output

In [43]:
final_output = convert_to_eraser_format(dataset=processed_data, save_split=True, save_path='Data/explanations', division_file='Data/post_id_divisions.json')

14971751_gab
1179080731642990592_twitter
25605196_gab
1179083428924231680_twitter
2098760_gab
1179077088025989120_twitter
1001155_gab
1179095481563123713_twitter
24748005_gab
1179096375826321410_twitter
1178172311565938688_twitter
24414450_gab
1179076835562463233_twitter
1179043060598083585_twitter
1179093707687157761_twitter
1179030004539252742_twitter
1179040820026118144_twitter
1179004510993289216_twitter
1178825627413221376_twitter
1178842439253020673_twitter
1179102616846032897_twitter
1178764150811455488_twitter
24348147_gab
24356318_gab
1178985525040037889_twitter
1178994611190210560_twitter
1178559091482021893_twitter
1179077511407427586_twitter
1177662316835475457_twitter
1178434057270697985_twitter
1178771184701362177_twitter
1179004811716382720_twitter
1178429287424253970_twitter
1179100190822866944_twitter
10007406_gab
10010765_gab
10039905_gab
10110793_gab
10156663_gab
10348812_gab
10360775_gab
10375689_gab
10548673_gab
10579761_gab
10652908_gab
10653951_gab
10666566_gab
1

In [11]:
with open('Data/post_id_divisions.json') as file:
  post_id_divisions = json.load(file)

# Train
train_data = []
train_missing = 0
for train_key in post_id_divisions['train']:
  try:
    train_data.append(processed_data[train_key])
  except Exception as e:
    train_missing += 1
print(f"Train missing: {train_missing}")

# Val
val_data = []
val_missing = 0
for val_key in post_id_divisions['val']:
  try:
    val_data.append(processed_data[val_key])
  except Exception as e:
    val_missing += 1
print(f"Val missing: {val_missing}")

# Test
test_data = []
test_missing = 0
for test_key in post_id_divisions['test']:
  try:
    test_data.append(processed_data[test_key])
  except Exception as e:
    test_missing += 1
print(f"Test missing: {test_missing}")

Train missing: 1
Val missing: 0
Test missing: 0


In [12]:
from HateDataset import HateDataset

# Create datasets with pre-tokenized data
train_dataset = HateDataset(data=train_data)
val_dataset = HateDataset(data=val_data)
test_dataset = HateDataset(data=test_data)

In [13]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) # Use shuffle=False for validation
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # Use shuffle=False for testing

In [14]:
from HateClassifier import HateClassifier
model = HateClassifier(config)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [15]:
history = model.train(train_dataloader=train_loader, val_dataloader=val_loader)

Training on device: cuda
Model: distilbert-base-uncased
Epochs: 2
Batch size: 32
Gradient accumulation steps: 1
Effective batch size: 32
Learning rate: 1e-05
Mixed precision (AMP): True
Gradient clipping: 1.0

Epoch 1/2


Training: 100%|██████████| 481/481 [00:52<00:00,  9.08batch/s, loss=0.517]
Evaluating: 100%|██████████| 61/61 [00:01<00:00, 32.58batch/s]



Epoch 1 Summary:
  Train Loss: 0.5167
  Val Loss:   0.4763
  Val Acc:    0.7653
  Val F1:     0.7632
  ✓ New best model saved! (F1: 0.7632)

Epoch 2/2


Training: 100%|██████████| 481/481 [00:53<00:00,  9.02batch/s, loss=0.419]
Evaluating: 100%|██████████| 61/61 [00:01<00:00, 32.91batch/s]



Epoch 2 Summary:
  Train Loss: 0.4193
  Val Loss:   0.4474
  Val Acc:    0.7846
  Val F1:     0.7769
  ✓ New best model saved! (F1: 0.7769)

Training completed!
Best F1 Score: 0.7769
Training history saved to: ./checkpoints/training_history.json


# Evaluation

In [16]:
import pickle

result = model.predict(test_dataloader=test_loader, return_attentions=True)

# Save the result to a file
with open('prediction_results.pkl', 'wb') as f:
  pickle.dump(result, f)

print("Results saved to prediction_results.pkl")

Running inference on 61 batches...


Testing: 100%|██████████| 61/61 [00:01<00:00, 32.37batch/s]


Test Results:
  Test Loss:     0.4320
  Test Accuracy: 0.7994
  Test F1:       0.7913
Results saved to prediction_results.pkl





## Bias

In [17]:
def get_bias_evaluation_samples(data, method, group):
    """
    Get positive and negative sample IDs for bias evaluation based on method and group

    Args:
        data: list of data entries
        method: Bias evaluation method ('subgroup', 'bpsn', or 'bnsp')
        group: Target group to evaluate

    Returns:
        Tuple of (positive_ids, negative_ids)
    """
    positive_ids = []
    negative_ids = []

    for idx, row in enumerate(data):
        target_groups = row['target_groups']
        if target_groups is None:
            continue

        is_in_group = group in target_groups

        # Convert various label formats to binary toxic/non-toxic
        if 'hard_label' in row:
            is_toxic = row['hard_label'] == 1
        else:
            continue

        if method == 'subgroup':
            # Only consider samples mentioning the group
            if is_in_group:
                if is_toxic:
                    positive_ids.append(idx)
                else:
                    negative_ids.append(idx)

        elif method == 'bpsn':
            # Compare non-toxic posts mentioning the group with toxic posts NOT mentioning the group
            if is_in_group and not is_toxic:
                negative_ids.append(idx)
            elif not is_in_group and is_toxic:
                positive_ids.append(idx)

        elif method == 'bnsp':
            # Compare toxic posts mentioning the group with non-toxic posts NOT mentioning the group
            if is_in_group and is_toxic:
                positive_ids.append(idx)
            elif not is_in_group and not is_toxic:
                negative_ids.append(idx)

    return positive_ids, negative_ids

In [18]:
from collections import defaultdict
from sklearn.metrics import roc_auc_score

def calculate_gmb_metrics(
    test_data: List[Dict[str, Any]],
    probabilities: np.ndarray,
    target_groups: List[str]
):
    """
    Calculate GMB (Generalized Mean of Bias) AUC metrics from model predictions

    Args:
        probabilities: Model's probability outputs
        test_data: List of test data entries
        target_groups: List of target groups to evaluate

    Returns:
        Dictionary with GMB metrics
    """
    # Create mappings from post_id to predictions and ground truth
    prediction_scores = defaultdict(lambda: defaultdict(dict))
    ground_truth = {}

    for idx, row in enumerate(test_data):
        prediction_scores[idx] = probabilities[idx, 1]
        ground_truth[idx] = row['hard_label']

    # Calculate metrics for each target group and method
    bias_metrics = {}
    methods = ['subgroup', 'bpsn', 'bnsp']

    for method in methods:
        bias_metrics[method] = {}  # Initialize nested dictionary for each method
        for group in target_groups:
            # Get positive and negative samples based on the method
            positive_ids, negative_ids = get_bias_evaluation_samples(test_data, method, group)

            if len(positive_ids) == 0 or len(negative_ids) == 0:
                print(f"Skipping {method} for group {group}: no samples found")
                continue  # Skip if no samples for this group/method

            # Collect ground truth and predictions
            y_true = []
            y_score = []

            for post_id in positive_ids:
                if post_id in ground_truth and post_id in prediction_scores:
                    y_true.append(ground_truth[post_id])
                    y_score.append(prediction_scores[post_id])

            for post_id in negative_ids:
                if post_id in ground_truth and post_id in prediction_scores:
                    y_true.append(ground_truth[post_id])
                    y_score.append(prediction_scores[post_id])

            # Calculate AUC if we have enough samples with both classes
            if len(y_true) > 10 and len(set(y_true)) > 1:
                try:
                    auc = roc_auc_score(y_true, y_score)
                    bias_metrics[method][group] = auc
                except ValueError:
                    print(f"Could not compute AUC for {method} and group {group} due to ValueError")
                    pass

    # Calculate GMB for each method
    gmb_metrics = {}
    power = -5  # Power parameter for generalized mean

    for method in methods:
        if not bias_metrics[method]:
            continue

        scores = list(bias_metrics[method].values())
        if not scores:
            continue

        # Calculate generalized mean with p=-5
        power_mean = np.mean([score ** power for score in scores]) ** (1/power)
        gmb_metrics[f'GMB-{method.upper()}-AUC'] = power_mean

    # Calculate a combined GMB score that includes all methods
    all_scores = []
    for method in methods:
        all_scores.extend(list(bias_metrics[method].values()))

    if all_scores:
        gmb_metrics['GMB-COMBINED-AUC'] = np.mean([score ** power for score in all_scores]) ** (1/power)

    return gmb_metrics, bias_metrics

In [19]:
from collections import Counter
# Get top 10 most common target groups in the full dataset
all_target_groups = []
for _, value in processed_data.items():
  all_target_groups.extend(value['target_groups'])

# Remove None
all_target_groups = [group for group in all_target_groups if group != 'None' and group != 'Other']
counter = Counter(all_target_groups)

n_common = 10
bias_target_groups = [tg[0] for tg in counter.most_common(n_common)]

In [20]:
gmb_metrics, bias_details = calculate_gmb_metrics(
  test_data=test_data,
  probabilities=result['probabilities'],
  target_groups=bias_target_groups
)

In [21]:
print('GMB-Metrics')
for key, value in gmb_metrics.items():
  print(f'{key}: {value}')

GMB-Metrics
GMB-SUBGROUP-AUC: 0.8495980902841762
GMB-BPSN-AUC: 0.7543771924301432
GMB-BNSP-AUC: 0.8178618794907845
GMB-COMBINED-AUC: 0.8012341275858437


In [22]:
print('Bias Details')
print()
for key, entry in bias_details.items():
  print(f"Metrics: {key}")
  for subgroup, value in entry.items():
    print(f'{subgroup}: {value}')
  print()

Bias Details

Metrics: subgroup
African: 0.908675799086758
Jewish: 0.7838345864661653
Islam: 0.9672897196261682
Homosexual: 0.875
Women: 0.680327868852459
Refugee: 0.8202020202020202
Arab: 0.896551724137931
Hispanic: 1.0
Asian: 1.0
Caucasian: 0.9666666666666667

Metrics: bpsn
African: 0.7659804983748646
Jewish: 0.5401387512388504
Islam: 0.942512077294686
Homosexual: 0.8692556634304207
Women: 0.8473635522664199
Refugee: 0.9116020438833785
Arab: 0.7556154537286612
Hispanic: 0.868515205724508
Asian: 0.9964539007092199
Caucasian: 0.9902395740905058

Metrics: bnsp
African: 0.9313661701865756
Jewish: 0.9479579411253067
Islam: 0.9026980275328544
Homosexual: 0.8792293233082706
Women: 0.7785512083826263
Refugee: 0.7793647030935167
Arab: 0.9378338999514327
Hispanic: 0.9807405036278276
Asian: 0.8448417779403695
Caucasian: 0.640940170940171



# XAI

In [23]:
from HateInterpreter import HateInterpreter
interpreter = HateInterpreter(model, tokenizer, config)

In [102]:
# Filter to only include hateful post
valid_ids = []
for idx, data in enumerate(test_data):
    if data['hard_label'] == 1:
        valid_ids.append(idx)
test_data = [test_data[idx] for idx in valid_ids]

In [106]:
input_ids_list = [d['input_ids'].squeeze() for d in test_data]
attention_masks_list = [d['attention_mask'].squeeze() for d in test_data]
human_rationales = [d['rationales'].squeeze() for d in test_data]
attention_scores = [result['attentions'][idx] for idx in valid_ids]
predicted_classes = [result['labels'][idx] for idx in valid_ids]
original_probs = [result['probabilities'][idx] for idx in valid_ids]

In [107]:
def _extract_top_k_tokens(
    tokenizer,
    attention_scores: List[np.ndarray],
    attention_masks_list: List[torch.Tensor],
    input_ids_list: List[torch.Tensor] = None,
    k: int = 5,
) -> List[np.ndarray]:
    """
    Extract top-k tokens with highest attention as hard predictions

    Excludes special tokens ([CLS], [SEP], [PAD]) from selection

    Returns binary masks (1 = rationale, 0 = not rationale)
    """
    hard_predictions = []

    # Get special token IDs
    special_token_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }
    # Remove None if any tokenizer doesn't have these
    special_token_ids = {x for x in special_token_ids if x is not None}

    for idx, (attn, mask) in enumerate(zip(attention_scores, attention_masks_list)):
        # Create binary mask
        pred_mask = np.zeros_like(attn, dtype=int)

        # Only consider non-padding tokens
        valid_positions = mask.bool().cpu().numpy().flatten()

        # Also exclude special tokens if input_ids provided
        if input_ids_list is not None and idx < len(input_ids_list):
            input_ids = input_ids_list[idx].cpu().numpy().flatten()
            # Mark positions with special tokens as invalid
            is_special = np.isin(input_ids, list(special_token_ids))
            content_positions = valid_positions & ~is_special
        else:
            content_positions = valid_positions

        # Get attention scores for content tokens only
        content_attn = attn[content_positions]

        if k > 0 and len(content_attn) > 0:
            k_actual = min(k, len(content_attn))
            # Get top-k indices within content positions
            top_k_within_content = np.argsort(content_attn)[-k_actual:]

            # Map back to original positions
            content_indices = np.where(content_positions)[0]
            top_k_indices = content_indices[top_k_within_content]

            pred_mask[top_k_indices] = 1

        hard_predictions.append(pred_mask)

    return hard_predictions

In [108]:
hard_predictions = _extract_top_k_tokens(tokenizer, attention_scores, attention_masks_list, input_ids_list, k=5)

In [109]:
def _compute_auprc(
    attention_scores: List[np.ndarray],
    human_rationales: List[torch.Tensor],
    attention_masks_list: List[torch.Tensor],
    input_ids_list: List[torch.Tensor] = None
) -> float:
    """
    Compute Area Under Precision-Recall Curve for soft scores

    This measures: "If I rank tokens by attention, do I recover human rationales?"
    Excludes special tokens ([CLS], [SEP], [PAD]) from evaluation.
    """
    all_scores = []
    all_labels = []

    # Get special token IDs
    special_token_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }
    special_token_ids = {x for x in special_token_ids if x is not None}

    for idx, (attn, rat, mask) in enumerate(zip(attention_scores, human_rationales, attention_masks_list)):
        # Only consider non-padding tokens
        valid_positions = mask.bool().cpu().numpy().flatten()

        # Also exclude special tokens if input_ids provided
        if input_ids_list is not None and idx < len(input_ids_list):
            input_ids = input_ids_list[idx].cpu().numpy().flatten()
            is_special = np.isin(input_ids, list(special_token_ids))
            content_positions = valid_positions & ~is_special
        else:
            content_positions = valid_positions

        all_scores.extend(attn[content_positions].tolist())
        # Convert rationales to binary integers (0 or 1)
        rationales_hard = rat.cpu().numpy().flatten()[content_positions].astype(int).tolist()
        all_labels.extend(rationales_hard)

    # Convert to numpy arrays
    all_scores = np.array(all_scores, dtype=float)
    all_labels = np.array(all_labels, dtype=int)  # Ensure integer type

    # Check if we have both classes (need at least one positive and one negative)
    if len(np.unique(all_labels)) < 2:
        print(f"Warning: Only one class present in labels. Unique values: {np.unique(all_labels)}")
        return 0.0

    # Compute precision-recall curve
    precision, recall, _ = precision_recall_curve(all_labels, all_scores)

    # Compute area under curve
    auprc_score = auc(recall, precision)

    return auprc_score

In [110]:
auprc = _compute_auprc(attention_scores, human_rationales, attention_masks_list, input_ids_list)
print(f"Token-level AUPRC: {auprc}")

Token-level AUPRC: 0.15848659513098923


In [111]:
from sklearn.metrics import precision_recall_curve, auc, f1_score, precision_score, recall_score
def _compute_token_f1(
    hard_predictions: List[np.ndarray],
    human_rationales: List[torch.Tensor],
    attention_masks_list: List[torch.Tensor]
) -> Tuple[float, float, float]:
    """
    Compute token-level F1, Precision, Recall

    Treats each token as binary classification problem
    """
    all_preds = []
    all_labels = []

    for pred, rat, mask in zip(hard_predictions, human_rationales, attention_masks_list):
        # Only consider non-padding tokens
        valid_positions = mask.bool().cpu().numpy()

        all_preds.extend(pred[valid_positions].astype(int).tolist())
        all_labels.extend(rat[valid_positions].cpu().numpy().astype(int).tolist())

    # Convert to arrays and ensure integer type
    all_preds = np.array(all_preds, dtype=int)
    all_labels = np.array(all_labels, dtype=int)

    # Compute metrics
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)

    return f1, precision, recall

In [112]:
token_f1 = _compute_token_f1(hard_predictions, human_rationales, attention_masks_list)
print(f"Token-level F1: {token_f1[0]:.4f}")
print(f"Token-level Precision: {token_f1[1]:.4f}")
print(f"Token-level Recall: {token_f1[2]:.4f}")

Token-level F1: 0.1890
Token-level Precision: 0.1756
Token-level Recall: 0.2047
