In [1]:
import shutil
import os

# Delete HateDeRC directory if it exists
if os.path.exists('HateDeRC'):
  shutil.rmtree('HateDeRC')
!git clone https://github.com/jamesalv/HateDeRC
%cd HateDeRC

Cloning into 'HateDeRC'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (33/33), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 33 (delta 12), reused 31 (delta 10), pack-reused 0 (from 0)[K
Receiving objects: 100% (33/33), 4.52 MiB | 3.88 MiB/s, done.
Resolving deltas: 100% (12/12), done.
/content/HateDeRC


In [None]:
from TrainingConfig import TrainingConfig
from typing import Dict, Any, Tuple, List
import numpy as np
import torch
from transformers import AutoTokenizer
import json

In [None]:
data_path = 'Data/dataset.json'

In [None]:
config = TrainingConfig()

In [None]:
# Seed all randomness for reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(config.seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(config.seed)
np.random.seed(config.seed)

# Preprocessing

In [None]:
import re
import string

def deobfuscate_text(text):
    """
    Normalize common text obfuscation patterns to reveal original words.
    Useful for hate speech detection and content analysis.

    Args:
        text (str): Input text with potential obfuscations

    Returns:
        str: Text with obfuscations normalized
    """
    if not isinstance(text, str):
        return text

    # Make a copy to work with
    result = text.lower()

    # 1. Handle asterisk/symbol replacements
    symbol_patterns = {
        # Common profanity
        r'f\*+c?k': 'fuck',
        r'f\*+': 'fuck',
        r's\*+t': 'shit',
        r'b\*+ch': 'bitch',
        r'a\*+s': 'ass',
        r'd\*+n': 'damn',
        r'h\*+l': 'hell',
        r'c\*+p': 'crap',

        # Slurs and hate speech terms (be comprehensive for detection)
        r'n\*+g+[aer]+': 'nigger',  # Various n-word obfuscations
        r'f\*+g+[ot]*': 'faggot',
        r'r\*+[dt]ard': 'retard',
        r'sp\*+c': 'spic',

        # Other symbols
        r'@ss': 'ass',
        r'b@tch': 'bitch',
        r'sh!t': 'shit',
        r'f#ck': 'fuck',
        r'd@mn': 'damn',
    }

    for pattern, replacement in symbol_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 2. Handle character spacing (f u c k -> fuck)
    spacing_patterns = {
        r'\bf\s+u\s+c\s+k\b': 'fuck',
        r'\bs\s+h\s+i\s+t\b': 'shit',
        r'\bd\s+a\s+m\s+n\b': 'damn',
        r'\bh\s+e\s+l\s+l\b': 'hell',
        r'\ba\s+s\s+s\b': 'ass',
        r'\bc\s+r\s+a\s+p\b': 'crap',
    }

    for pattern, replacement in spacing_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 3. Handle number/letter substitutions
    leet_patterns = {
        # Basic leet speak
        r'\b3\s*1\s*1\s*3\b': 'elle',  # 3113 -> elle
        r'\bf4g\b': 'fag',
        r'\bf4gg0t\b': 'faggot',
        r'\bn00b\b': 'noob',
        r'\bl33t\b': 'leet',
        r'\bh4t3\b': 'hate',
        r'\b5h1t\b': 'shit',
        r'\bf0ck\b': 'fock',
    }

    for pattern, replacement in leet_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 4. Handle repeated characters and separators
    # Remove excessive punctuation between letters
    result = re.sub(r'([a-z])[^\w\s]+([a-z])', r'\1\2', result)

    # Handle underscore separation
    result = re.sub(r'([a-z])_+([a-z])', r'\1\2', result)

    # Handle dot separation
    result = re.sub(r'([a-z])\.+([a-z])', r'\1\2', result)

    # 5. Handle common misspellings/variations used for evasion
    evasion_patterns = {
        r'\bfuk\b': 'fuck',
        r'\bfuq\b': 'fuck',
        r'\bfck\b': 'fuck',
        r'\bshyt\b': 'shit',
        r'\bshit\b': 'shit',
        r'\bbiatch\b': 'bitch',
        r'\bbeatch\b': 'bitch',
        r'\basshole\b': 'asshole',
        r'\ba55hole\b': 'asshole',
        r'\btard\b': 'retard',
        r'\bfagg\b': 'fag',
    }

    for pattern, replacement in evasion_patterns.items():
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    # 6. Clean up multiple spaces
    result = re.sub(r'\s+', ' ', result).strip()

    return result

In [None]:
def aggregate_rationales(rationales, labels, post_length, drop_abnormal=False):
    """
    If all 3 annotators are normal → 3 zero spans → average (all zeros).
    If k annotators are non-normal and k spans exist → average the k spans (no added zeros).
    If k non-normal but fewer than k spans:
        If the missing annotators are non-normal → do not fill with zeros; average only existing spans and record rationale_support = #spans.
        If the missing annotators are normal (e.g., 2 hate + 1 normal + 2 spans) → append one zero span for the normal.
    """
    count_normal = labels.count(0)
    count_hate = labels.count(1)
    count_rationales = len(rationales)
    pad = np.zeros(post_length, dtype="int").tolist()

    # If there are hate labels but no rationales, something is wrong
    if count_hate > 0 and count_rationales == 0:
        if drop_abnormal:
            return None

        # Else just fill with 0
        return np.zeros(post_length).tolist()

    # If all annotators are normal, return all zeros
    if count_normal == 3:
        return np.zeros(post_length).tolist()

    # If we have hate annotators
    if count_hate > 0:
        # Case 1: Number of rationales matches number of hate annotators
        if count_rationales == count_hate:
            return np.average(rationales, axis=0).tolist()

        # Case 2: Fewer rationales than hate annotators
        elif count_rationales < count_hate:
            # Add zero padding for normal annotators only
            rationales_copy = rationales.copy()
            zeros_to_add = count_normal
            for _ in range(zeros_to_add):
                rationales_copy.append(pad)
            return np.average(rationales_copy, axis=0).tolist()

        # Case 3: More rationales than hate annotators (shouldn't happen normally)
        else:
            # Just average what we have
            return np.average(rationales, axis=0).tolist()

    # Fallback: return zeros if no clear case matches
    return np.zeros(post_length).tolist()

In [None]:
from typing import List, Tuple

def preprocess_text(raw_text):
    preprocessed_text = raw_text
    # # Remove HTML tags <>
    preprocessed_text = preprocessed_text.replace("<", "").replace(">", "")
    # # De-Obsfucate Patterns
    preprocessed_text = deobfuscate_text(preprocessed_text)

    return preprocessed_text


def create_text_segment(
    text_tokens: List[str], rationale_mask: List[int]
) -> List[Tuple[List[str], int]]:
    """
    Process a rationale mask to identify contiguous segments of highlighted text.
    Then create a segmented representation of the tokens

    Args:
        text_tokens: Original text tokens
        mask: Binary mask where 1 indicates a highlighted token (this consists of mask from 3 annotators)

    Returns:
        A list of tuples (text segment, mask value)
    """
    # Handle case where mask is empty (no rationale provided), usually this is normal classification
    mask = rationale_mask

    # for mask in all_rationale_mask:
    # Find breakpoints (transitions between highlighted/1 and non-highlighted/0)
    breakpoints = []
    mask_values = []

    # Always start with position 0
    breakpoints.append(0)
    mask_values.append(mask[0])

    # Find transitions in the mask
    for i in range(1, len(mask)):
        if mask[i] != mask[i - 1]:
            breakpoints.append(i)
            mask_values.append(mask[i])

    # Always end with the length of the text
    if breakpoints[-1] != len(mask):
        breakpoints.append(len(mask))

    # Create segments based on breakpoints
    segments = []
    for i in range(len(breakpoints) - 1):
        start = breakpoints[i]
        end = breakpoints[i + 1]
        segments.append((text_tokens[start:end], mask_values[i]))

    return segments


def align_rationales(tokens, rationales, tokenizer, max_length=128):
    """
    Align rationales with tokenized text while handling different tokenizer formats.

    Args:
        tokens: Original text tokens
        rationales: Original rationale masks
        tokenizer: The tokenizer to use
        max_length: Maximum sequence length

    Returns:
        Dictionary with tokenized inputs and aligned rationale masks
    """
    segments = create_text_segment(tokens, rationales)
    all_human_rationales = []
    all_input_ids = []
    all_attention_mask = []
    all_token_type_ids = []
    all_rationales = []
    for text_segment, rationale_value in segments:
        inputs = {}
        concatenated_text = " ".join(text_segment)
        processed_segment = preprocess_text(concatenated_text)
        tokenized = tokenizer(
            processed_segment, add_special_tokens=False, return_tensors="pt"
        )

        # Extract the relevant data
        segment_input_ids = tokenized["input_ids"][0]
        segment_attention_mask = tokenized["attention_mask"][0]
        # Handle token_type_ids if present
        if "token_type_ids" in tokenized:
            segment_token_type_ids = tokenized["token_type_ids"][0]
            all_token_type_ids.extend(segment_token_type_ids)

        # Add input IDs and attention mask
        all_input_ids.extend(segment_input_ids)
        all_attention_mask.extend(segment_attention_mask)

        # Add rationales (excluding special tokens)
        segment_rationales = [rationale_value] * len(segment_input_ids)
        all_rationales.extend(segment_rationales)
    # Get special token IDs
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Add special tokens at the beginning and end
    all_input_ids = [cls_token_id] + all_input_ids + [sep_token_id]
    all_attention_mask = [1] + all_attention_mask + [1]

    # Handle token_type_ids if the model requires it
    if hasattr(tokenizer, "create_token_type_ids_from_sequences"):
        all_token_type_ids = tokenizer.create_token_type_ids_from_sequences(
            all_input_ids[1:-1]
        )
    elif all_token_type_ids:
        all_token_type_ids = [0] + all_token_type_ids + [0]
    else:
        all_token_type_ids = [0] * len(all_input_ids)

    # Check tokenized vs rationales length
    if len(all_input_ids) != len(all_attention_mask):
        print("Warning: length of tokens and rationales do not match")

    # Add zero rationale values for special tokens
    all_rationales = [0] + all_rationales + [0]

    # Truncate to max length if needed
    if len(all_input_ids) > max_length:
        print("WARNING: NEED TO TRUNCATE")
        all_input_ids = all_input_ids[:max_length]
        all_attention_mask = all_attention_mask[:max_length]
        all_token_type_ids = all_token_type_ids[:max_length]
        all_rationales = all_rationales[:max_length]

    # Pad to max_length if needed
    pad_token_id = tokenizer.pad_token_id
    padding_length = max_length - len(all_input_ids)

    if padding_length > 0:
        all_input_ids = all_input_ids + [pad_token_id] * padding_length
        all_attention_mask = all_attention_mask + [0] * padding_length
        all_token_type_ids = all_token_type_ids + [0] * padding_length
        all_rationales = all_rationales + [0] * padding_length

    # Convert lists to tensors
    inputs = {
        "input_ids": torch.tensor([all_input_ids], dtype=torch.long),
        "attention_mask": torch.tensor([all_attention_mask], dtype=torch.long),
        "token_type_ids": (
            torch.tensor([all_token_type_ids], dtype=torch.long)
            if "token_type_ids" in tokenizer.model_input_names
            else None
        ),
        "rationales": torch.tensor([all_rationales], dtype=torch.float32),
    }

    # Remove None values
    inputs = {k: v for k, v in inputs.items() if v is not None}
    return inputs

In [None]:
import re
import string
from collections import Counter
from tqdm import tqdm
invalid_rationales_key = []
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
def process_raw_entries(data, drop_abnormal = False):
    """
    Process raw data entries
    """
    print("Processing raw entries...")
    processed_entries = {}
    dropped = 0

    for key, value in tqdm(data.items()):
        try:
            # Extract labels (1 = hate/offensive, 0 = normal)
            labels = []
            target_groups = []
            for annot in value["annotators"]:
                label = annot["label"]
                target = annot["target"]

                labels.append(1 if label in ['hatespeech', 'offensive'] else 0)
                target_groups.append(target)

            # Process rationales
            rationales = value.get("rationales", [])
            aggregated_rationale = aggregate_rationales(rationales, labels, len(value["post_tokens"]),  drop_abnormal=drop_abnormal)


            if aggregated_rationale is None:
                dropped += 1
                continue
            inputs = align_rationales(value['post_tokens'], aggregated_rationale, tokenizer)

            # Determine final label
            # Majority vote for hard label
            hard_label = Counter(labels).most_common(1)[0][0]
            # Average for soft label
            soft_label = sum(labels) / len(labels)

            # Determine Targets --> Considered as targets when there are at least 2 mention
            target_groups = [
              t
              for annot in value['annotators']
              for t in annot['target']
            ]
            filtered = [k for k, v in Counter(target_groups).items() if v > 2]

            # Store inputs and labels
            processed_entries[key] = {
                'input_ids': inputs['input_ids'],
                'attention_mask': inputs['attention_mask'],
                'rationales': inputs['rationales'],
                'raw_text': " ".join(value['post_tokens']), # Keep raw text for potential debugging/analysis
                'hard_label': hard_label,
                'soft_label': soft_label,
                'target_groups': filtered
            }

        except Exception as e:
            dropped += 1
            print(f"Error processing {key}: {e}")

    print(f"Processed: {len(processed_entries)}, Dropped: {dropped}")
    return processed_entries

In [None]:
with open(data_path, 'r') as file:
  data = json.load(file)
processed_data = process_raw_entries(data)

Processing raw entries...


 37%|███▋      | 7414/20148 [00:06<00:10, 1228.89it/s]



 89%|████████▉ | 18028/20148 [00:14<00:01, 1083.75it/s]

Error processing 24439295_gab: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.


100%|██████████| 20148/20148 [00:16<00:00, 1193.99it/s]

Processed: 20147, Dropped: 1





In [None]:
with open('Data/post_id_divisions.json') as file:
  post_id_divisions = json.load(file)

# Train
train_data = []
train_missing = 0
for train_key in post_id_divisions['train']:
  try:
    train_data.append(processed_data[train_key])
  except Exception as e:
    train_missing += 1
print(f"Train missing: {train_missing}")

# Val
val_data = []
val_missing = 0
for val_key in post_id_divisions['val']:
  try:
    val_data.append(processed_data[val_key])
  except Exception as e:
    val_missing += 1
print(f"Val missing: {val_missing}")

# Test
test_data = []
test_missing = 0
for test_key in post_id_divisions['test']:
  try:
    test_data.append(processed_data[test_key])
  except Exception as e:
    test_missing += 1
print(f"Test missing: {test_missing}")

Train missing: 1
Val missing: 0
Test missing: 0


In [None]:
from HateDataset import HateDataset

# Create datasets with pre-tokenized data
train_dataset = HateDataset(data=train_data)
val_dataset = HateDataset(data=val_data)
test_dataset = HateDataset(data=test_data)

In [None]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) # Use shuffle=False for validation
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # Use shuffle=False for testing

In [None]:
from HateClassifier import HateClassifier
model = HateClassifier(config)

In [None]:
history = model.train(train_dataloader=train_loader, val_dataloader=val_loader)

Training on device: cuda
Model: distilbert-base-uncased
Epochs: 2
Batch size: 32
Gradient accumulation steps: 1
Effective batch size: 32
Learning rate: 1e-05
Mixed precision (AMP): True
Gradient clipping: 1.0

Epoch 1/2


Training: 100%|██████████| 481/481 [01:38<00:00,  4.86batch/s, loss=0.517]
Evaluating: 100%|██████████| 61/61 [00:03<00:00, 17.85batch/s]



Epoch 1 Summary:
  Train Loss: 0.5174
  Val Loss:   0.4670
  Val Acc:    0.7690
  Val F1:     0.7652
  ✓ New best model saved! (F1: 0.7652)

Epoch 2/2


Training: 100%|██████████| 481/481 [01:38<00:00,  4.91batch/s, loss=0.415]
Evaluating: 100%|██████████| 61/61 [00:03<00:00, 18.05batch/s]



Epoch 2 Summary:
  Train Loss: 0.4155
  Val Loss:   0.4440
  Val Acc:    0.7893
  Val F1:     0.7813
  ✓ New best model saved! (F1: 0.7813)

Training completed!
Best F1 Score: 0.7813
Training history saved to: ./checkpoints/training_history.json


# Evaluation

In [None]:
import pickle

result = model.predict(test_dataloader=test_loader, return_attentions=True)

# Save the result to a file
with open('prediction_results.pkl', 'wb') as f:
  pickle.dump(result, f)

print("Results saved to prediction_results.pkl")

Running inference on 61 batches...


Testing: 100%|██████████| 61/61 [00:03<00:00, 17.07batch/s]


Test Results:
  Test Loss:     0.4325
  Test Accuracy: 0.8035
  Test F1:       0.7956
Results saved to prediction_results.pkl





## Bias

In [None]:
def get_bias_evaluation_samples(data, method, group):
    """
    Get positive and negative sample IDs for bias evaluation based on method and group

    Args:
        data: list of data entries
        method: Bias evaluation method ('subgroup', 'bpsn', or 'bnsp')
        group: Target group to evaluate

    Returns:
        Tuple of (positive_ids, negative_ids)
    """
    positive_ids = []
    negative_ids = []

    for idx, row in enumerate(data):
        target_groups = row['target_groups']
        if target_groups is None:
            continue

        is_in_group = group in target_groups

        # Convert various label formats to binary toxic/non-toxic
        if 'hard_label' in row:
            is_toxic = row['hard_label'] == 1
        else:
            continue

        if method == 'subgroup':
            # Only consider samples mentioning the group
            if is_in_group:
                if is_toxic:
                    positive_ids.append(idx)
                else:
                    negative_ids.append(idx)

        elif method == 'bpsn':
            # Compare non-toxic posts mentioning the group with toxic posts NOT mentioning the group
            if is_in_group and not is_toxic:
                negative_ids.append(idx)
            elif not is_in_group and is_toxic:
                positive_ids.append(idx)

        elif method == 'bnsp':
            # Compare toxic posts mentioning the group with non-toxic posts NOT mentioning the group
            if is_in_group and is_toxic:
                positive_ids.append(idx)
            elif not is_in_group and not is_toxic:
                negative_ids.append(idx)

    return positive_ids, negative_ids

In [None]:
from collections import defaultdict
from sklearn.metrics import roc_auc_score

def calculate_gmb_metrics(
    test_data: List[Dict[str, Any]],
    probabilities: np.ndarray,
    target_groups: List[str]
):
    """
    Calculate GMB (Generalized Mean of Bias) AUC metrics from model predictions

    Args:
        probabilities: Model's probability outputs
        test_data: List of test data entries
        target_groups: List of target groups to evaluate

    Returns:
        Dictionary with GMB metrics
    """
    # Create mappings from post_id to predictions and ground truth
    prediction_scores = defaultdict(lambda: defaultdict(dict))
    ground_truth = {}

    for idx, row in enumerate(test_data):
        prediction_scores[idx] = probabilities[idx, 1]
        ground_truth[idx] = row['hard_label']

    # Calculate metrics for each target group and method
    bias_metrics = {}
    methods = ['subgroup', 'bpsn', 'bnsp']

    for method in methods:
        bias_metrics[method] = {}  # Initialize nested dictionary for each method
        for group in target_groups:
            # Get positive and negative samples based on the method
            positive_ids, negative_ids = get_bias_evaluation_samples(test_data, method, group)

            if len(positive_ids) == 0 or len(negative_ids) == 0:
                print(f"Skipping {method} for group {group}: no samples found")
                continue  # Skip if no samples for this group/method

            # Collect ground truth and predictions
            y_true = []
            y_score = []

            for post_id in positive_ids:
                if post_id in ground_truth and post_id in prediction_scores:
                    y_true.append(ground_truth[post_id])
                    y_score.append(prediction_scores[post_id])

            for post_id in negative_ids:
                if post_id in ground_truth and post_id in prediction_scores:
                    y_true.append(ground_truth[post_id])
                    y_score.append(prediction_scores[post_id])

            # Calculate AUC if we have enough samples with both classes
            if len(y_true) > 10 and len(set(y_true)) > 1:
                try:
                    auc = roc_auc_score(y_true, y_score)
                    bias_metrics[method][group] = auc
                except ValueError:
                    print(f"Could not compute AUC for {method} and group {group} due to ValueError")
                    pass

    # Calculate GMB for each method
    gmb_metrics = {}
    power = -5  # Power parameter for generalized mean

    for method in methods:
        if not bias_metrics[method]:
            continue

        scores = list(bias_metrics[method].values())
        if not scores:
            continue

        # Calculate generalized mean with p=-5
        power_mean = np.mean([score ** power for score in scores]) ** (1/power)
        gmb_metrics[f'GMB-{method.upper()}-AUC'] = power_mean

    # Calculate a combined GMB score that includes all methods
    all_scores = []
    for method in methods:
        all_scores.extend(list(bias_metrics[method].values()))

    if all_scores:
        gmb_metrics['GMB-COMBINED-AUC'] = np.mean([score ** power for score in all_scores]) ** (1/power)

    return gmb_metrics, bias_metrics

In [None]:
from collections import Counter
# Get top 10 most common target groups in the full dataset
all_target_groups = []
for _, value in processed_data.items():
  all_target_groups.extend(value['target_groups'])

# Remove None
all_target_groups = [group for group in all_target_groups if group != 'None' and group != 'Other']
counter = Counter(all_target_groups)

n_common = 10
bias_target_groups = [tg[0] for tg in counter.most_common(n_common)]

In [None]:
gmb_metrics, bias_details = calculate_gmb_metrics(
  test_data=test_data,
  probabilities=result['probabilities'],
  target_groups=bias_target_groups
)

In [None]:
print('GMB-Metrics')
for key, value in gmb_metrics.items():
  print(f'{key}: {value}')

GMB-Metrics
GMB-SUBGROUP-AUC: 0.8534172646685065
GMB-BPSN-AUC: 0.7304033627976568
GMB-BNSP-AUC: 0.8263201958983752
GMB-COMBINED-AUC: 0.7921965069548527


In [None]:
print('Bias Details')
print()
for key, entry in bias_details.items():
  print(f"Metrics: {key}")
  for subgroup, value in entry.items():
    print(f'{subgroup}: {value}')
  print()

Bias Details

Metrics: subgroup
African: 0.9025875190258752
Jewish: 0.763157894736842
Islam: 0.9626168224299064
Homosexual: 0.8809523809523809
Women: 0.7076502732240437
Refugee: 0.8181818181818181
Arab: 0.8620689655172413
Hispanic: 1.0
Asian: 1.0
Caucasian: 1.0

Metrics: bpsn
African: 0.7486457204767063
Jewish: 0.5104063429137761
Islam: 0.9359903381642511
Homosexual: 0.8928802588996765
Women: 0.864168979340117
Refugee: 0.9148181544935376
Arab: 0.7511230907457322
Hispanic: 0.8738819320214669
Asian: 0.9964539007092199
Caucasian: 0.9937888198757764

Metrics: bnsp
African: 0.9340918283011237
Jewish: 0.951770493070723
Islam: 0.9048422747038897
Homosexual: 0.8699740968274344
Women: 0.7859768463748522
Refugee: 0.7801746276322548
Arab: 0.9443684047860832
Hispanic: 0.9756188647033717
Asian: 0.8476312419974392
Caucasian: 0.6585470085470085



# XAI

In [None]:
from HateInterpreter import HateInterpreter
interpreter = HateInterpreter(model, tokenizer, config)

In [None]:
result

{'predictions': array([0, 1, 1, ..., 1, 1, 1], dtype=int64),
 'labels': array([0, 0, 1, ..., 1, 1, 1], dtype=int64),
 'probabilities': array([[0.9687085 , 0.0312915 ],
        [0.2739627 , 0.72603726],
        [0.06872476, 0.93127525],
        ...,
        [0.15539865, 0.8446014 ],
        [0.01472861, 0.98527133],
        [0.18557793, 0.81442213]], dtype=float32),
 'loss': 0.43250985223738875,
 'accuracy': 0.8035343035343036,
 'f1': 0.7956465102059854,
 'attentions': [array([0.05991167, 0.17943588, 0.04367437, 0.04170617, 0.01563822,
         0.01408985, 0.02001997, 0.03406667, 0.01396812, 0.01104525,
         0.01104024, 0.02985105, 0.00876457, 0.0191968 , 0.00935086,
         0.0308141 , 0.01516355, 0.00644445, 0.00868303, 0.00992157,
         0.00775399, 0.03111218, 0.03790326, 0.01293332, 0.00713368,
         0.00786988, 0.02202763, 0.03430669, 0.0116527 , 0.01215313,
         0.00932989, 0.01112575, 0.02966489, 0.18224663, 0.        ,
         0.        , 0.        , 0.        , 

In [None]:
input_ids_list = [d['input_ids'].squeeze() for d in test_data]
attention_masks_list = [d['attention_mask'].squeeze() for d in test_data]
human_rationales = [d['rationales'].squeeze() for d in test_data]
attention_scores = result['attentions']
predicted_classes = result['labels']
original_probs = result['probabilities']

In [None]:
interpretation_results = interpreter.compute_all_metrics(
    input_ids_list=input_ids_list,
    attention_masks_list=attention_masks_list,
    human_rationales=human_rationales,
    attention_scores=attention_scores,
    predicted_classes=predicted_classes,
    original_probs=original_probs
)

Using k=2 (average human rationale length)


ValueError: continuous format is not supported