# Task 4b: Detecting Cultural Bias Across Subgroups Using UniBias

This notebook adapts UniBias-style attention-head identification and masking to detect bias across identifiable subgroups (pillars) in the Total Defence Meme dataset.

**Dataset**: SEACrowd/total_defense_meme
- **Subgroups (Pillars)**: military, civil, economic, social, psychological, digital, others
- **Task**: Stance classification (against, neutral, supportive)
- **Goal**: Detect if model predictions are biased toward certain stances for certain pillars


In [1]:
# Install required packages for UniBias adaptation experiment
!pip install -q torch transformers datasets==2.19.1 seacrowd numpy pandas tqdm

In [2]:
import torch
import numpy as np
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
import random
from tqdm import tqdm
import torch.nn.functional as F
from collections import defaultdict
import pandas as pd

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


Using device: cuda


## Step 1: Load Model and Dataset

We'll use GPT-2 as our language model. GPT-2 contains **attention heads**: internal components that determine which parts of the input each position attends to. We'll identify which of these heads contribute to bias.

**Key Setup**:
- Model: GPT-2 (124M parameters, 12 layers, 12 heads per layer = 144 total heads)
- Attention Mode: `eager` (required to support `head_mask` for masking specific heads)
- Device: CUDA (GPU) when available for faster computation, otherwise CPU


In [3]:
# Load GPT-2 model
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="eager")
model.config.pad_token_id = tokenizer.eos_token_id
model = model.to(DEVICE)
model.eval()

print(f"Model loaded: {model.config.name_or_path}")
print(f"Layers: {model.config.n_layer}, Heads per layer: {model.config.n_head}")




Model loaded: gpt2
Layers: 12, Heads per layer: 12


In [4]:
import seacrowd as sc
tdm = sc.load_dataset("total_defense_meme", schema="seacrowd_imtext")
print("Dataset loaded successfully!")
print(f"Dataset splits: {list(tdm.keys())}")
print(f"Train size: {len(tdm['train'])}")
print("\nFirst sample structure:")
print(tdm["train"][0])
print("\nDataset features/columns:")
print(tdm["train"].features)


Downloading...
From (original): https://drive.google.com/uc?id=1oJIh4QQS3Idff2g6bZORstS5uBROjUUz
From (redirected): https://drive.google.com/uc?id=1oJIh4QQS3Idff2g6bZORstS5uBROjUUz&confirm=t&uuid=f3fdbb27-b911-48c0-814e-471af35e0aa9
To: /content/data/total_defense_meme/total_defense_meme.zip
100%|██████████| 767M/767M [00:11<00:00, 64.0MB/s]



Dataset loaded successfully!
Dataset splits: ['train']
Train size: 7012

First sample structure:
{'id': '0', 'image_paths': ['/root/.cache/huggingface/datasets/downloads/extracted/f03b95ee59c393962bf7e1227024f7c3253d96d9937914cc57b444283a5cf2fb/TD_Memes/img_001.jpg'], 'texts': 'Types of accounts Your How contributions are allocated Ultimate CPF for the self-employed and contract staff Guide to Interest rates CPF The CPF LIFE Scheme Where your CPF money goes after you die PLANNERBEE', 'metadata': {'tags': None, 'pillar_stances': {'category': [], 'stance': []}}}

Dataset features/columns:
{'id': Value(dtype='string', id=None), 'image_paths': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'texts': Value(dtype='string', id=None), 'metadata': {'tags': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'pillar_stances': Sequence(feature={'category': ClassLabel(names=['Social', 'Economic', 'Psychological', 'Military', 'Civil', 'Digital', 'Others'], id

## Step 2: Prepare Data - Extract Text and Labels

**Dataset**: Total Defence Meme dataset contains memes about Singapore's Total Defence policy.

**What we extract**:
- **Text**: OCR-extracted text from memes (from `texts` field)
- **Pillar**: Subgroup category (military, civil, economic, social, psychological, digital, others)
- **Stance**: Label (against, neutral, supportive)

**Why subgroups matter**: We want to detect whether the model shows **representational bias**, i.e. whether its stance predictions differ systematically across pillars, so that some subgroups receive different predictions than others. The comparison below does not control for content, so differences may also reflect what the memes in each pillar actually say.
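A minimal sketch of what "different predictions per subgroup" looks like in practice, assuming predictions and pillar labels are parallel lists of strings (as `predictions` and `pillars` are later in this notebook). The helper name and the toy data are hypothetical.

```python
from collections import Counter, defaultdict

def prediction_share_by_pillar(predictions, pillars, stances=("against", "neutral", "supportive")):
    """Return {pillar: {stance: fraction of that pillar's predictions}}."""
    grouped = defaultdict(list)
    for pred, pillar in zip(predictions, pillars):
        grouped[pillar].append(pred)
    shares = {}
    for pillar, preds in grouped.items():
        counts = Counter(preds)  # missing stances count as 0
        shares[pillar] = {s: counts[s] / len(preds) for s in stances}
    return shares

# Hypothetical toy data: two pillars with different prediction mixes
preds = ["against", "against", "neutral", "neutral", "neutral", "supportive"]
pills = ["economic", "economic", "economic", "military", "military", "military"]
shares = prediction_share_by_pillar(preds, pills)
print(shares)
```

If the per-pillar shares differ markedly while the pillars' true label mixes do not, that is the pattern this notebook calls representational bias.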


In [6]:
def prepare_data_from_dataset(data):
    """Extract text, stance labels, and pillar labels from total_defense_meme (schema='seacrowd_imtext')."""
    samples = []

    # Introspect label names from the dataset features
    features = data.features
    ps_feat = features["metadata"]["pillar_stances"].feature

    # ClassLabel objects for pillar category and stance
    pillar_names = ps_feat["category"].names
    stance_names = ps_feat["stance"].feature.names

    for entry in data:
        # 1. Get text (field is 'texts' for this dataset)
        text = entry.get("texts", None)
        if not text:
            continue

        # 2. Get pillar_stances from metadata
        metadata = entry.get("metadata", {})
        pillar_stances = metadata.get("pillar_stances", {})

        categories = pillar_stances.get("category") or []
        stances = pillar_stances.get("stance") or []

        # Skip examples with no labels
        if not categories or not stances:
            continue

        # 3. Decode pillar (take the first category index)
        cat_idx = categories[0]
        if not isinstance(cat_idx, int) or not (0 <= cat_idx < len(pillar_names)):
            continue
        pillar = pillar_names[cat_idx]

        # 4. Decode stance (handle possible nested list structure)
        first_stance = stances[0]
        if isinstance(first_stance, list):
            if not first_stance:
                continue
            stance_idx = first_stance[0]
        else:
            stance_idx = first_stance

        if not isinstance(stance_idx, int) or not (0 <= stance_idx < len(stance_names)):
            continue
        stance = stance_names[stance_idx]

        # 5. Append clean sample
        samples.append({
            "text": str(text),
            "stance": stance.lower(),
            "pillar": pillar.lower(),
        })

    return samples


# Use the loaded dataset
if 'tdm' not in locals() or tdm is None or 'train' not in tdm:
    raise ValueError("Dataset 'tdm' not loaded. Please run the previous cell to load the dataset first.")

data = tdm['train']
samples = prepare_data_from_dataset(data)

if len(samples) == 0:
    raise ValueError("No samples extracted from dataset. Please check the dataset structure and field extraction logic.")

print(f"Prepared {len(samples)} samples with text and labels")

# Analyze distribution
from collections import Counter
stances = [s['stance'] for s in samples]
pillars = [s['pillar'] for s in samples]

print(f"\nStance distribution:")
print(Counter(stances))

print(f"\nPillar distribution:")
print(Counter(pillars))

print(f"\nSample entries:")
for i, s in enumerate(samples[:3]):
    print(f"  {i+1}. Pillar: {s['pillar']}, Stance: {s['stance']}")
    print(f"     Text: {s['text'][:100]}...")

# Keep only the first 50 samples (dataset order, not a random sample) for faster processing
if len(samples) > 50:
    print(f"\nUsing first 50 samples for faster processing...")
    samples = samples[:50]


Prepared 2512 samples with text and labels

Stance distribution:
Counter({'neutral': 1137, 'against': 1079, 'supportive': 296})

Pillar distribution:
Counter({'others': 946, 'military': 724, 'psychological': 377, 'economic': 209, 'civil': 158, 'social': 59, 'digital': 39})

Sample entries:
  1. Pillar: economic, Stance: neutral
     Text: When a HDB flat finishes it's 99-year lease: This is so sad can we hit a pedestrian with escooters?...
  2. Pillar: economic, Stance: against
     Text: when you've searched everywhere and still can't find your CPF This is what it means to be Singaporea...
  3. Pillar: economic, Stance: against
     Text: VOLUNTARY TOP UP MY CPF SPECIAL ACCOUNT??...

Using first 50 samples for faster processing...


## Step 3: Create Prompts and Define Answer Tokens

**Task Design**: We frame this as a **stance classification** task using few-shot prompting. This fits the UniBias setting, which is designed around classification tasks with a small, discrete set of label outputs.

**Prompt Format**: Few-shot examples showing the model how to classify stance, followed by the meme text.

**Answer Tokens**: We need the token IDs that correspond to each stance label (against, neutral, supportive). In this notebook, bias is measured from the model's next-token logits for these label tokens at the position right after `Stance:`.
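GPT-2's byte-level BPE treats a word with a leading space (the form the model would normally emit after `Stance:`) as a different token from the bare word. The cell below encodes the bare labels and keeps every resulting ID, so a multi-token label such as "supportive" (two IDs in the output) is scored by whichever of its pieces, suffix included, has the higher logit. A common alternative is to score only the first token of the space-prefixed label. The sketch below shows that rule with a hypothetical toy encoder; in the notebook you could pass `lambda t: tokenizer.encode(t, add_special_tokens=False)` instead. Switching would change every downstream number, since the results reported here use the bare-label IDs.

```python
def first_token_ids(encode, labels, leading_space=True):
    """Map each label to the ID of the first token of its encoding.

    `encode` is assumed to behave like tokenizer.encode(text, add_special_tokens=False)
    and return a list of ints.
    """
    ids = {}
    for label in labels:
        text = (" " + label) if leading_space else label
        token_ids = encode(text)
        if not token_ids:
            raise ValueError(f"label {label!r} encodes to no tokens")
        ids[label] = [token_ids[0]]  # score only the first piece of the label
    return ids

# Hypothetical stand-in vocabulary; real GPT-2 IDs differ.
TOY_VOCAB = {" against": [10], " neutral": [20], " supportive": [30, 31]}

def toy_encode(text):
    return TOY_VOCAB.get(text, [])

result = first_token_ids(toy_encode, ["against", "neutral", "supportive"])
print(result)  # {'against': [10], 'neutral': [20], 'supportive': [30]}
```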


In [7]:
# Create prompts for stance classification
def create_stance_prompt(text):
    """Create a prompt for stance classification"""
    # Few-shot examples
    demonstration = """Classify the stance toward Total Defence:
Text: Total Defence is essential for Singapore's security
Stance: supportive

Text: I disagree with Total Defence policies
Stance: against

Text: Total Defence has both pros and cons
Stance: neutral

"""
    return demonstration + f"Text: {text}\nStance:"

# Create prompts
prompts = [create_stance_prompt(s['text']) for s in samples]
labels = [s['stance'] for s in samples]
pillars = [s['pillar'] for s in samples]

print(f"Created {len(prompts)} prompts")
print(f"\nExample prompt:\n{prompts[0][:200]}...")


Created 50 prompts

Example prompt:
Classify the stance toward Total Defence:
Text: Total Defence is essential for Singapore's security
Stance: supportive

Text: I disagree with Total Defence policies
Stance: against

Text: Total Defenc...


In [8]:
# Answer tokens for stance classification
ans_token_list = [['against'], ['neutral'], ['supportive']]

def find_possible_ids_for_labels(ans_token_list, tokenizer):
    """Find all possible token IDs for each label"""
    ids_dict = {}
    for ans_tuple in ans_token_list:
        label = ans_tuple[0]
        ids_dict[label] = []
        # Try encoding the label directly
        encoded = tokenizer.encode(label, add_special_tokens=False)
        if encoded:
            ids_dict[label].extend(encoded)
        # Also search for tokens that contain the label
        for token_id in range(min(1000, len(tokenizer))):  # Limit search for speed
            try:
                decoded = tokenizer.decode([token_id])
                if label.lower() in decoded.lower() and len(decoded) > 0:
                    if token_id not in ids_dict[label]:
                        ids_dict[label].append(token_id)
            except:
                continue
    return ids_dict

gt_ans_ids_dict = find_possible_ids_for_labels(ans_token_list, tokenizer)
print(f"Token IDs found:")
for label, ids in gt_ans_ids_dict.items():
    print(f"  {label}: {len(ids)} token IDs (showing first 5: {ids[:5]})")


Token IDs found:
  against: 1 token IDs (showing first 5: [32826])
  neutral: 1 token IDs (showing first 5: [29797])
  supportive: 2 token IDs (showing first 5: [11284, 425])


## Step 4: Baseline Evaluation - Measure Predictions by Pillar


### Subgroup Analysis Results

This analysis shows whether the model exhibits **representational bias**: different prediction patterns across pillars (subgroups).

**What to look for**:
- **Different favored stances**: Do some pillars favor "against" while others favor "neutral"?
- **Different bias magnitudes**: Do some pillars show stronger bias (larger logit differences) than others?
- **Accuracy differences**: Does the model perform better on some pillars than others?

If we see systematic differences, this indicates **representational bias** - the model is treating subgroups differently.
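The code in this notebook prints per-pillar prediction counts and logits but does not compute accuracy. One way to fill that in is sketched below, assuming the `results_by_pillar` structure that `compute_predictions_by_pillar` builds (`{pillar: {"predictions": [...], "labels": [...], ...}}`). The helper name and toy data are hypothetical.

```python
def accuracy_by_pillar(results_by_pillar):
    """Fraction of correct predictions per pillar (pillars with no samples are skipped)."""
    acc = {}
    for pillar, data in results_by_pillar.items():
        preds, gold = data["predictions"], data["labels"]
        if preds:
            acc[pillar] = sum(p == g for p, g in zip(preds, gold)) / len(preds)
    return acc

# Hypothetical toy input with the same shape as results_by_pillar
toy = {
    "digital": {"predictions": ["neutral", "against"], "labels": ["against", "against"]},
    "civil": {"predictions": ["neutral"], "labels": ["neutral"]},
}
acc = accuracy_by_pillar(toy)
print(acc)  # {'digital': 0.5, 'civil': 1.0}
```

With only a handful of samples per pillar, a single prediction can move a pillar's accuracy by 20 points or more, so these numbers are best read alongside the sample counts.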


In [9]:
def compute_predictions_by_pillar(model, tokenizer, prompts, labels, pillars, gt_ans_ids_dict):
    """Compute predictions and analyze by pillar"""
    predictions = []
    logits_by_stance = {'against': [], 'neutral': [], 'supportive': []}
    results_by_pillar = defaultdict(lambda: {'predictions': [], 'labels': [], 'logits': {'against': [], 'neutral': [], 'supportive': []}})

    for i, prompt in enumerate(tqdm(prompts, desc="Computing predictions")):
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        pillar = pillars[i]
        true_label = labels[i]

        with torch.no_grad():
            outputs = model(**encoded)
            logits = outputs.logits[0, -1]  # Last token logits

            # Get logits for each stance
            stance_logits = {}
            for stance in ['against', 'neutral', 'supportive']:
                stance_ids = gt_ans_ids_dict[stance]
                if stance_ids:
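                    # -inf (not -100) as the fallback: GPT-2's raw logits here are around -109,
                    # so -100 would outrank every real label logit.
                    stance_logit = max((logits[tid].item() for tid in stance_ids if tid < len(logits)), default=float("-inf"))
                else:
                    stance_logit = float("-inf")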
                stance_logits[stance] = stance_logit
                logits_by_stance[stance].append(stance_logit)
                results_by_pillar[pillar]['logits'][stance].append(stance_logit)

            # Prediction: highest logit
            predicted_stance = max(stance_logits.items(), key=lambda x: x[1])[0]
            predictions.append(predicted_stance)
            results_by_pillar[pillar]['predictions'].append(predicted_stance)
            results_by_pillar[pillar]['labels'].append(true_label)

    return predictions, logits_by_stance, results_by_pillar

predictions, logits_by_stance, results_by_pillar = compute_predictions_by_pillar(
    model, tokenizer, prompts, labels, pillars, gt_ans_ids_dict
)

print(f"\nOverall Predictions:")
print(f"  Against: {predictions.count('against')} ({predictions.count('against')/len(predictions)*100:.1f}%)")
print(f"  Neutral: {predictions.count('neutral')} ({predictions.count('neutral')/len(predictions)*100:.1f}%)")
print(f"  Supportive: {predictions.count('supportive')} ({predictions.count('supportive')/len(predictions)*100:.1f}%)")

print(f"\nAverage Logits by Stance:")
for stance, logit_list in logits_by_stance.items():
    if logit_list:
        print(f"  {stance}: {np.mean(logit_list):.4f}")


Computing predictions: 100%|██████████| 50/50 [00:01<00:00, 26.76it/s]


Overall Predictions:
  Against: 19 (38.0%)
  Neutral: 29 (58.0%)
  Supportive: 2 (4.0%)

Average Logits by Stance:
  against: -109.0636
  neutral: -108.7679
  supportive: -109.8780





In [10]:
# Analyze predictions by pillar
from collections import Counter

print("\n" + "="*80)
print("PREDICTIONS BY PILLAR (Subgroup Analysis)")
print("="*80)

for pillar in sorted(results_by_pillar.keys()):
    pillar_data = results_by_pillar[pillar]
    if len(pillar_data['predictions']) == 0:
        continue

    pred_counts = Counter(pillar_data['predictions'])
    label_counts = Counter(pillar_data['labels'])

    print(f"\nPillar: {pillar.upper()}")
    print(f"  Samples: {len(pillar_data['predictions'])}")
    print(f"  Predictions: {dict(pred_counts)}")
    print(f"  True Labels: {dict(label_counts)}")

    # Calculate average logits for this pillar
    avg_logits = {}
    for stance in ['against', 'neutral', 'supportive']:
        if pillar_data['logits'][stance]:
            avg_logits[stance] = np.mean(pillar_data['logits'][stance])
        else:
            avg_logits[stance] = -100

    print(f"  Avg Logits: {avg_logits}")

    # Calculate bias: difference between highest and lowest logit
    if avg_logits:
        max_logit = max(avg_logits.values())
        min_logit = min(avg_logits.values())
        bias_magnitude = max_logit - min_logit
        favored_stance = max(avg_logits.items(), key=lambda x: x[1])[0]
        print(f"  Bias: Favors '{favored_stance}' (magnitude: {bias_magnitude:.4f})")



PREDICTIONS BY PILLAR (Subgroup Analysis)

Pillar: CIVIL
  Samples: 4
  Predictions: {'neutral': 2, 'against': 2}
  True Labels: {'neutral': 2, 'against': 1, 'supportive': 1}
  Avg Logits: {'against': np.float64(-111.74154472351074), 'neutral': np.float64(-111.31406593322754), 'supportive': np.float64(-112.71752738952637)}
  Bias: Favors 'neutral' (magnitude: 1.4035)

Pillar: DIGITAL
  Samples: 5
  Predictions: {'neutral': 4, 'against': 1}
  True Labels: {'against': 5}
  Avg Logits: {'against': np.float64(-108.34479370117188), 'neutral': np.float64(-108.32415466308593), 'supportive': np.float64(-109.67637939453125)}
  Bias: Favors 'neutral' (magnitude: 1.3522)

Pillar: ECONOMIC
  Samples: 16
  Predictions: {'against': 8, 'neutral': 6, 'supportive': 2}
  True Labels: {'neutral': 3, 'against': 11, 'supportive': 2}
  Avg Logits: {'against': np.float64(-110.32508373260498), 'neutral': np.float64(-110.53112506866455), 'supportive': np.float64(-110.97400045394897)}
  Bias: Favors 'against' 

### Baseline Bias Score

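This score is the standard deviation of the three per-stance average logits, where each average is taken over all prompts. A higher value means one stance's average logit sits further from the others, i.e. the model leans more strongly toward (or away from) particular stances.

**Interpretation**:
- High bias score = average label logits are spread out, so the model tends to favor one stance (e.g., predicting "neutral" most of the time)
- Low bias score = average label logits are close together, so no stance is systematically favored on average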
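As a worked check, the score is the population standard deviation (NumPy's default `ddof=0`) of the three averages. Plugging in the baseline averages printed in Step 5 reproduces the reported 0.4694:

```python
import numpy as np

# Per-stance average logits from the baseline run (copied from the printed output)
avg_logits = {
    "against": -109.06356399536133,
    "neutral": -108.76793167114258,
    "supportive": -109.87797927856445,
}
values = np.array(list(avg_logits.values()))
bias_score = np.std(values)  # population std (ddof=0), as np.std computes by default
print(f"{bias_score:.4f}")  # 0.4694
```

Because the score depends only on the spread of the averages, adding the same constant to all three label logits leaves it unchanged, and it says nothing directly about how balanced individual predictions are.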


## Step 5: Detect Bias Using UniBias - Identify Biased Heads


### Head Identification Process

**What we're doing**: Testing each attention head to see if masking it reduces bias.

**For each head**:
- Create a mask that zeros out only that specific head
- Run all prompts through the model with this mask
- Measure how much bias changes (delta)

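**Delta interpretation** (delta = baseline bias - masked bias):
- **Positive delta** (e.g., 0.158): masking this head lowers the bias score, so the head appears to contribute to bias
- **Negative delta** (e.g., -0.028): masking this head raises the bias score, so the head may be counteracting bias

**Note**: For speed, we scan only the first 2 layers (24 of 144 heads) and evaluate each mask on the first 20 prompts; a full analysis would check all 144 heads. The baseline score, however, is computed on all 50 prompts, so every delta carries the same constant offset (50-prompt baseline minus 20-prompt baseline). The ranking of heads is unaffected, but the sign of an individual delta is only meaningful when both scores use the same prompts.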
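The leave-one-out scan can be sketched independently of the model. Here `bias_fn` is a hypothetical callable that is assumed to run the model on a fixed prompt set with a given `(num_layers, num_heads)` mask (1 = keep, 0 = mask) and return the bias score; the baseline uses the same prompts, so the sign of each delta is directly interpretable. NumPy arrays stand in for the `torch` mask used in the notebook.

```python
import numpy as np

def ablation_deltas(bias_fn, num_layers, num_heads, layers_to_check):
    """Leave-one-out head ablation: delta = baseline bias - bias with one head masked."""
    baseline = bias_fn(np.ones((num_layers, num_heads)))  # same prompts, no heads masked
    deltas = {}
    for layer in layers_to_check:
        for head in range(num_heads):
            mask = np.ones((num_layers, num_heads))
            mask[layer, head] = 0.0  # zero out exactly one head
            deltas[(layer, head)] = baseline - bias_fn(mask)
    return deltas

# Hypothetical bias function: head (0, 1) adds 0.3 to the score, head (1, 0) subtracts 0.1.
def toy_bias(mask):
    return 0.5 + 0.3 * mask[0, 1] - 0.1 * mask[1, 0]

deltas = ablation_deltas(toy_bias, num_layers=2, num_heads=3, layers_to_check=[0, 1])
top = max(deltas, key=deltas.get)
print(top, round(deltas[top], 3))  # (0, 1) 0.3
```

Since the baseline is a single number shared by all heads, sorting by delta is equivalent to sorting by ascending masked bias, which is why a constant offset in the baseline does not change which heads are selected.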


In [11]:
def compute_baseline_bias(model, tokenizer, prompts, gt_ans_ids_dict):
    """Compute baseline bias across all stances"""
    stance_logits = {'against': [], 'neutral': [], 'supportive': []}

    for prompt in prompts:
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model(**encoded)
            logits = outputs.logits[0, -1]

            for stance in ['against', 'neutral', 'supportive']:
                stance_ids = gt_ans_ids_dict[stance]
                if stance_ids:
                    stance_logit = max((logits[tid].item() for tid in stance_ids if tid < len(logits)), default=float("-inf"))
                    stance_logits[stance].append(stance_logit)

    # Bias = standard deviation of the per-stance average logits (higher = more bias)
    avg_logits = {stance: np.mean(vals) for stance, vals in stance_logits.items() if vals}
    bias_score = np.std(list(avg_logits.values()))  # population std (ddof=0)

    return bias_score, avg_logits

baseline_bias, baseline_avg_logits = compute_baseline_bias(model, tokenizer, prompts, gt_ans_ids_dict)
print(f"Baseline Bias Score (std of avg logits): {baseline_bias:.4f}")
print(f"Average Logits: {baseline_avg_logits}")


Baseline Bias Score (std of avg logits): 0.4694
Average Logits: {'against': np.float64(-109.06356399536133), 'neutral': np.float64(-108.76793167114258), 'supportive': np.float64(-109.87797927856445)}


### Selected Biased Heads

**Top 3 heads selected for masking**:
- These are the heads whose individual masking gave the largest deltas, i.e. the lowest masked bias scores
- Layer 0, Head 5 shows the strongest effect (delta: 0.1588)
- All three lie in layers 0-1, but only those two layers were scanned, so this does not show that early layers contain more biased heads than later ones

**Why mask these**: By removing these heads' contributions, we expect to reduce the model's bias toward certain stances.


In [12]:
def compute_head_contribution(model, tokenizer, prompts, gt_ans_ids_dict, layer_idx, head_idx, baseline_bias):
    """Compute how much a specific head contributes to bias"""
    num_layers = model.config.n_layer
    num_heads = model.config.n_head

    # Create head mask: mask this specific head
    head_mask = torch.ones(num_layers, num_heads, device=DEVICE)
    head_mask[layer_idx, head_idx] = 0.0

    stance_logits = {'against': [], 'neutral': [], 'supportive': []}

    for prompt in prompts:
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model(**encoded, head_mask=head_mask)
            logits = outputs.logits[0, -1]

            for stance in ['against', 'neutral', 'supportive']:
                stance_ids = gt_ans_ids_dict[stance]
                if stance_ids:
                    stance_logit = max((logits[tid].item() for tid in stance_ids if tid < len(logits)), default=float("-inf"))
                    stance_logits[stance].append(stance_logit)

    avg_logits = {stance: np.mean(vals) for stance, vals in stance_logits.items() if vals}
    masked_bias = np.std(list(avg_logits.values()))

    # Delta: how much bias changes when this head is masked
    delta = baseline_bias - masked_bias

    return delta, masked_bias

print("Identifying biased attention heads...")
head_contributions = {}

num_layers = model.config.n_layer
num_heads = model.config.n_head

# Sample subset for speed (analyze first 2 layers)
SAMPLE_HEADS = True
if SAMPLE_HEADS:
    layers_to_check = [0, 1]
else:
    layers_to_check = range(num_layers)

for layer_idx in tqdm(layers_to_check, desc="Layers"):
    for head_idx in range(num_heads):
        delta, masked_bias = compute_head_contribution(
            model, tokenizer, prompts[:min(20, len(prompts))],  # first 20 prompts for speed (baseline_bias used all prompts)
            gt_ans_ids_dict, layer_idx, head_idx, baseline_bias
        )
        head_contributions[(layer_idx, head_idx)] = {
            'delta': delta,
            'masked_bias': masked_bias
        }

print(f"\nAnalyzed {len(head_contributions)} attention heads")


Identifying biased attention heads...


Layers: 100%|██████████| 2/2 [00:19<00:00,  9.59s/it]


Analyzed 24 attention heads





In [13]:
# Find top biased heads
sorted_heads = sorted(head_contributions.items(), key=lambda x: x[1]['delta'], reverse=True)

print("Top 10 heads by bias reduction (when masked):")
for (layer, head), stats in sorted_heads[:10]:
    print(f"Layer {layer}, Head {head}: delta={stats['delta']:.6f}, masked_bias={stats['masked_bias']:.4f}")

# Select top 3 heads for masking
top_biased_heads = [layer_head for layer_head, _ in sorted_heads[:3]]
print(f"\nSelected top 3 biased heads for masking: {top_biased_heads}")


Top 10 heads by bias reduction (when masked):
Layer 0, Head 5: delta=0.158836, masked_bias=0.3105
Layer 0, Head 11: delta=0.038258, masked_bias=0.4311
Layer 1, Head 2: delta=0.035952, masked_bias=0.4334
Layer 1, Head 10: delta=0.029752, masked_bias=0.4396
Layer 1, Head 4: delta=0.012714, masked_bias=0.4567
Layer 1, Head 3: delta=-0.005042, masked_bias=0.4744
Layer 1, Head 0: delta=-0.015707, masked_bias=0.4851
Layer 1, Head 9: delta=-0.025497, masked_bias=0.4949
Layer 0, Head 7: delta=-0.028404, masked_bias=0.4978
Layer 1, Head 11: delta=-0.040065, masked_bias=0.5094

Selected top 3 biased heads for masking: [(0, 5), (0, 11), (1, 2)]


### Creating Head Mask

We create a mask that zeros out the selected biased heads while keeping all other heads active. This allows us to test if removing these heads' contributions reduces bias.


## Step 6: Mask Biased Heads and Re-analyze by Pillar


### Results After Masking

**Key Changes**:
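- **Supportive predictions increased**: from 4% to 16% (2 to 8 of 50, a 4x increase)
- **Neutral predictions decreased**: from 58% to 48%
- **Against predictions decreased**: from 38% to 36%

**Interpretation**: Masking the three selected heads reduced the model's lean toward the "neutral" stance and made predictions more balanced on these 50 samples; "supportive", previously under-represented, is now predicted more often. With only 50 prompts, a change of a few predictions moves these percentages by several points, so the shift is suggestive rather than conclusive, and a more balanced distribution does not by itself mean more accurate predictions.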
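"More balanced" can be made quantitative. One option, not part of UniBias and not used elsewhere in this notebook, is the normalized Shannon entropy of the prediction counts (1.0 means all stances are predicted equally often). The counts below are copied from the two printed prediction summaries; the helper name is hypothetical.

```python
import math

def normalized_entropy(counts):
    """Shannon entropy of a label distribution divided by log(K); 1.0 = perfectly balanced."""
    total = sum(counts.values())
    probs = [c / total for c in counts.values() if c > 0]  # 0 * log(0) is treated as 0
    h = -sum(p * math.log(p) for p in probs)
    return h / math.log(len(counts))

before = {"against": 19, "neutral": 29, "supportive": 2}  # baseline predictions (n=50)
after = {"against": 18, "neutral": 24, "supportive": 8}   # after masking 3 heads (n=50)
print(f"before: {normalized_entropy(before):.3f}, after: {normalized_entropy(after):.3f}")
```

The entropy rises after masking, which matches the prose claim; like the percentages, it is computed from a single 50-sample run.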


In [14]:
# Create head mask for top biased heads
num_layers = model.config.n_layer
num_heads = model.config.n_head
head_mask = torch.ones(num_layers, num_heads, device=DEVICE)

for layer_idx, head_idx in top_biased_heads:
    head_mask[layer_idx, head_idx] = 0.0
    print(f"Masking Layer {layer_idx}, Head {head_idx}")

print(f"\nMasked {len(top_biased_heads)} attention heads")


Masking Layer 0, Head 5
Masking Layer 0, Head 11
Masking Layer 1, Head 2

Masked 3 attention heads


In [15]:
# Re-analyze predictions by pillar after masking
def compute_predictions_by_pillar_masked(model, tokenizer, prompts, labels, pillars, gt_ans_ids_dict, head_mask):
    """Compute predictions with head mask"""
    predictions = []
    results_by_pillar = defaultdict(lambda: {'predictions': [], 'labels': [], 'logits': {'against': [], 'neutral': [], 'supportive': []}})

    for i, prompt in enumerate(tqdm(prompts, desc="Evaluating with mask")):
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        pillar = pillars[i]
        true_label = labels[i]

        with torch.no_grad():
            outputs = model(**encoded, head_mask=head_mask)
            logits = outputs.logits[0, -1]

            stance_logits = {}
            for stance in ['against', 'neutral', 'supportive']:
                stance_ids = gt_ans_ids_dict[stance]
                if stance_ids:
                    stance_logit = max((logits[tid].item() for tid in stance_ids if tid < len(logits)), default=float("-inf"))
                else:
                    stance_logit = float("-inf")
                stance_logits[stance] = stance_logit
                results_by_pillar[pillar]['logits'][stance].append(stance_logit)

            predicted_stance = max(stance_logits.items(), key=lambda x: x[1])[0]
            predictions.append(predicted_stance)
            results_by_pillar[pillar]['predictions'].append(predicted_stance)
            results_by_pillar[pillar]['labels'].append(true_label)

    return predictions, results_by_pillar

masked_predictions, masked_results_by_pillar = compute_predictions_by_pillar_masked(
    model, tokenizer, prompts, labels, pillars, gt_ans_ids_dict, head_mask
)

print(f"\nPredictions After Masking:")
print(f"  Against: {masked_predictions.count('against')} ({masked_predictions.count('against')/len(masked_predictions)*100:.1f}%)")
print(f"  Neutral: {masked_predictions.count('neutral')} ({masked_predictions.count('neutral')/len(masked_predictions)*100:.1f}%)")
print(f"  Supportive: {masked_predictions.count('supportive')} ({masked_predictions.count('supportive')/len(masked_predictions)*100:.1f}%)")


Evaluating with mask: 100%|██████████| 50/50 [00:01<00:00, 30.17it/s]


Predictions After Masking:
  Against: 18 (36.0%)
  Neutral: 24 (48.0%)
  Supportive: 8 (16.0%)





## Step 7: Compare Bias Across Pillars - Before vs After


## Summary and Interpretation

### Key Findings

1. **Representational Bias Detected**: Model predictions differ across pillars, although each pillar has only 2 to 16 samples in this run
   - Economic & Psychological pillars favor "against" stance
   - All other pillars favor "neutral" stance
   - Social pillar shows highest bias magnitude (2.17)

2. **Head Identification**: Identified 3 biased attention heads among the 24 heads scanned (layers 0-1)
   - Layer 0, Head 5 shows strongest effect (delta: 0.1588)
   - Layers 2-11 were not scanned, so it is unknown whether they contain more strongly biased heads

3. **Mitigation Partially Effective**: 
   - 5 out of 7 pillars showed bias reduction ✅
   - 2 pillars (Digital, Psychological) showed increased bias ❌
   - Supportive predictions increased 4x (from 4% to 16%)

4. **Subgroup-Specific Effects**: Different pillars respond differently to masking, suggesting complex bias mechanisms that may require pillar-specific mitigation strategies.

### Implications

- **Subgroup analysis is essential**: Overall bias metrics can mask differential treatment of subgroups
- **Mitigation strategies should be subgroup-aware**: One-size-fits-all may not work
- **Complete masking may be too aggressive**: Some heads may have mixed effects (reducing bias for some subgroups, increasing for others)
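A subgroup-aware check can start by flagging pillars where the masked bias magnitude exceeds the baseline. The sketch below uses the per-pillar magnitudes printed in the Step 7 comparison; Social is left out because its row is incomplete in the output. The helper name is hypothetical.

```python
# Per-pillar bias magnitude (max - min average logit), copied from the Step 7 output.
bias_before = {"civil": 1.4035, "digital": 1.3522, "economic": 0.6489,
               "military": 1.6267, "others": 1.1681, "psychological": 0.7996}
bias_after = {"civil": 0.6006, "digital": 1.6411, "economic": 0.2547,
              "military": 1.5039, "others": 0.8433, "psychological": 0.8341}

def pillars_worsened(before, after):
    """Pillars whose bias magnitude grew after masking (negative bias reduction)."""
    return sorted(p for p in before if after[p] > before[p])

print(pillars_worsened(bias_before, bias_after))  # ['digital', 'psychological']
```

A fuller version might also record pillars whose favored stance flipped (Psychological moved from "against" to "supportive"), since a flip can matter even when the magnitude barely changes.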


In [16]:
print("="*80)
print("BIAS ANALYSIS BY PILLAR: Baseline vs After Masking")
print("="*80)

# Create comparison table
comparison_data = []

for pillar in sorted(set(pillars)):
    if pillar not in results_by_pillar or len(results_by_pillar[pillar]['predictions']) == 0:
        continue

    baseline_data = results_by_pillar[pillar]
    masked_data = masked_results_by_pillar[pillar]

    # Calculate bias for baseline (per-pillar name, so the global
    # baseline_avg_logits from Step 5 is not overwritten)
    pillar_baseline_avg = {}
    for stance in ['against', 'neutral', 'supportive']:
        if baseline_data['logits'][stance]:
            pillar_baseline_avg[stance] = np.mean(baseline_data['logits'][stance])
        else:
            pillar_baseline_avg[stance] = -100

    baseline_bias_mag = max(pillar_baseline_avg.values()) - min(pillar_baseline_avg.values())
    baseline_favored = max(pillar_baseline_avg.items(), key=lambda x: x[1])[0]

    # Calculate bias for masked
    masked_avg_logits = {}
    for stance in ['against', 'neutral', 'supportive']:
        if masked_data['logits'][stance]:
            masked_avg_logits[stance] = np.mean(masked_data['logits'][stance])
        else:
            masked_avg_logits[stance] = -100

    masked_bias_mag = max(masked_avg_logits.values()) - min(masked_avg_logits.values())
    masked_favored = max(masked_avg_logits.items(), key=lambda x: x[1])[0]

    bias_reduction = baseline_bias_mag - masked_bias_mag

    comparison_data.append({
        'Pillar': pillar,
        'Samples': len(baseline_data['predictions']),
        'Baseline_Favored': baseline_favored,
        'Baseline_Bias': f"{baseline_bias_mag:.4f}",
        'Masked_Favored': masked_favored,
        'Masked_Bias': f"{masked_bias_mag:.4f}",
        'Bias_Reduction': f"{bias_reduction:.4f}"
    })

    print(f"\nPillar: {pillar.upper()}")
    print(f"  Samples: {len(baseline_data['predictions'])}")
    print(f"  Baseline: Favors '{baseline_favored}' (bias: {baseline_bias_mag:.4f})")
    print(f"  After Masking: Favors '{masked_favored}' (bias: {masked_bias_mag:.4f})")
    print(f"  Bias Reduction: {bias_reduction:+.4f}")

# Create summary DataFrame
if comparison_data:
    df = pd.DataFrame(comparison_data)
    print("\n" + "="*80)
    print("SUMMARY TABLE")
    print("="*80)
    print(df.to_string(index=False))


BIAS ANALYSIS BY PILLAR: Baseline vs After Masking

Pillar: CIVIL
  Samples: 4
  Baseline: Favors 'neutral' (bias: 1.4035)
  After Masking: Favors 'neutral' (bias: 0.6006)
  Bias Reduction: +0.8029

Pillar: DIGITAL
  Samples: 5
  Baseline: Favors 'neutral' (bias: 1.3522)
  After Masking: Favors 'neutral' (bias: 1.6411)
  Bias Reduction: -0.2889

Pillar: ECONOMIC
  Samples: 16
  Baseline: Favors 'against' (bias: 0.6489)
  After Masking: Favors 'against' (bias: 0.2547)
  Bias Reduction: +0.3942

Pillar: MILITARY
  Samples: 15
  Baseline: Favors 'neutral' (bias: 1.6267)
  After Masking: Favors 'neutral' (bias: 1.5039)
  Bias Reduction: +0.1228

Pillar: OTHERS
  Samples: 3
  Baseline: Favors 'neutral' (bias: 1.1681)
  After Masking: Favors 'neutral' (bias: 0.8433)
  Bias Reduction: +0.3248

Pillar: PSYCHOLOGICAL
  Samples: 5
  Baseline: Favors 'against' (bias: 0.7996)
  After Masking: Favors 'supportive' (bias: 0.8341)
  Bias Reduction: -0.0345

Pillar: SOCIAL
  Samples: 2
  Baseline: Favo