# Model Bias Comparison

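This notebook compares bias across models, showing how strongly each model favors spatial over descriptive outputs (or the reverse). Bias is defined as spatial minus descriptive, so positive values favor spatial and negative values favor descriptive.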

In [None]:
import sys
import json
import warnings
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

sys.path.append("../")
from bias_steering.data.load_dataset import load_dataframe_from_json

# IPython autoreload: mode 2 reloads imported modules before each cell runs,
# so edits to the bias_steering package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# NOTE(review): this silences *all* warnings process-wide (no category filter),
# including numpy RuntimeWarnings from empty/NaN bias arrays; consider narrowing it.
warnings.filterwarnings("ignore")

In [None]:
# Per-model colours (Plotly's D3 qualitative palette); plots index into this
# list by model position, wrapping around when there are more models than colours.
import plotly.express as px
MODEL_COLORS = px.colors.qualitative.D3

# Models to compare: display name -> run directory that contains
# `datasplits/val.json`. Uncomment or add entries as more runs become available.
MODELS_TO_COMPARE = {
    "gpt2": Path("../runs_vision/gpt2"),
    # "Qwen-1_8B-chat": Path("../runs_vision/Qwen-1_8B-chat"),
    # "Mistral": Path("../runs_vision/Mistral-Nemo-Instruct-2407"),
}

## Load Bias Data for Each Model

In [None]:
def load_model_bias_data(artifact_path):
    """Return the validation split stored under *artifact_path*, or None on failure.

    Reads ``<artifact_path>/datasplits/val.json`` with ``load_dataframe_from_json``.
    Any exception raised while loading (including a non-Path ``artifact_path``)
    is reported with a printed warning and swallowed, so the caller can skip
    that model and carry on with the rest.
    """
    try:
        frame = load_dataframe_from_json(artifact_path / "datasplits/val.json")
    except Exception as err:
        print(f"⚠️  Could not load data from {artifact_path}: {err}")
        frame = None
    return frame

# Load every configured model; runs whose data cannot be loaded are reported
# by load_model_bias_data and left out of model_data.
model_data = {}
for model_name, artifact_path in MODELS_TO_COMPARE.items():
    data = load_model_bias_data(artifact_path)
    if data is None:
        continue
    model_data[model_name] = data
    print(f"✓ {model_name}: {len(data)} validation examples")

print(f"\n✓ Loaded data for {len(model_data)} models")

## Compute Bias Statistics

In [None]:
def compute_bias_stats(data, threshold=0.01):
    """Summarise the per-example bias scores of one model.

    Bias is read from the ``"bias"`` column. Following this notebook's axis
    labels it is (spatial - descriptive): positive values favour spatial and
    negative values favour descriptive.

    Args:
        data: DataFrame with a ``"bias"`` column.
        threshold: Dead zone around zero. An example counts as favouring a side
            only when its bias is strictly greater than ``+threshold`` or strictly
            less than ``-threshold``. Defaults to 0.01, the previously hard-coded value.

    Returns:
        Dict with mean, std (population, ddof=0), rms, abs_mean, max, min,
        median, spatial_favored_pct and descriptive_favored_pct (percentages
        in 0-100). Returns ``None`` when the ``"bias"`` column is missing or
        has no non-missing values.
    """
    if "bias" not in data.columns:
        return None

    # Drop missing values (NaN / JSON null) before aggregating. Otherwise every
    # numpy statistic becomes NaN, and np.max raises on an empty array.
    biases = data["bias"].dropna().to_numpy(dtype=float)
    if biases.size == 0:
        return None

    n = biases.size
    return {
        "mean": np.mean(biases),
        "std": np.std(biases),
        "rms": np.sqrt(np.mean(biases**2)),  # RMS bias
        "abs_mean": np.mean(np.abs(biases)),  # Mean absolute bias
        "max": np.max(biases),
        "min": np.min(biases),
        "median": np.median(biases),
        "spatial_favored_pct": np.sum(biases > threshold) / n * 100,
        "descriptive_favored_pct": np.sum(biases < -threshold) / n * 100,
    }

# Per-model summary statistics. Models for which compute_bias_stats returns
# None (e.g. no "bias" column) are skipped and produce no printout.
bias_stats = {}
for model_name, data in model_data.items():
    stats = compute_bias_stats(data)
    if not stats:
        continue
    bias_stats[model_name] = stats
    print(f"\n{model_name}:")
    print(f"  RMS Bias: {stats['rms']:.4f}")
    print(f"  Mean Bias: {stats['mean']:.4f}")
    print(f"  Spatial favored: {stats['spatial_favored_pct']:.1f}%")
    print(f"  Descriptive favored: {stats['descriptive_favored_pct']:.1f}%")

## Plot 1: Bias Distribution Comparison

In [None]:
def plot_bias_distribution(model_data, bias_stats, width=600, height=400):
    """Overlay per-model bias histograms, each with a dashed line at its mean.

    Args:
        model_data: Mapping of model name -> DataFrame with a ``"bias"`` column.
        bias_stats: Mapping of model name -> stats dict (must contain ``"mean"``).
            Models in ``model_data`` that are absent here are skipped rather
            than raising KeyError.
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Keep only models that have statistics. When bias_stats is built from
    # model_data in order (as above), this keeps the colour index aligned with
    # the RMS bar chart, which enumerates bias_stats.
    plotted = [(name, frame) for name, frame in model_data.items() if name in bias_stats]

    for idx, (model_name, data) in enumerate(plotted):
        color = MODEL_COLORS[idx % len(MODEL_COLORS)]
        stats = bias_stats[model_name]

        fig.add_trace(go.Histogram(
            x=data["bias"].values,
            name=model_name,
            marker_color=color,
            opacity=0.7,
            nbinsx=30,
            histnorm='probability density'
        ))

        # Dashed line at the mean; alternate label position to reduce overlap.
        fig.add_vline(
            x=stats["mean"],
            line_dash="dash",
            line_color=color,
            annotation_text=f"{model_name} mean: {stats['mean']:.3f}",
            annotation_position="top" if idx % 2 == 0 else "bottom"
        )

    fig.update_layout(
        width=width, height=height,
        plot_bgcolor='white',
        margin=dict(l=10, r=10, t=40, b=25),
        font=dict(size=13),
        title_text="Bias Distribution Comparison Across Models",
        title_font=dict(size=16), title_x=0.5, title_y=0.98,
        legend=dict(yanchor="top", y=0.98, xanchor="left", x=0.02,
                   bordercolor="darkgrey", borderwidth=1),
        # Without this Plotly groups histogram bars side by side; the 0.7
        # opacity only makes sense when the distributions are overlaid.
        barmode='overlay'
    )

    fig.update_xaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='black',
        title_text="Bias (spatial - descriptive)",
        title_font=dict(size=14), tickfont=dict(size=12),
        showline=True, linewidth=1, linecolor='darkgrey'
    )

    fig.update_yaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='darkgrey',
        title_text="Density",
        title_font=dict(size=14), tickfont=dict(size=12),
        showline=True, linewidth=1, linecolor='darkgrey'
    )

    return fig

fig = plot_bias_distribution(model_data, bias_stats)
fig.show()

## Plot 2: RMS Bias Comparison (Bar Chart)

In [None]:
def plot_rms_bias_comparison(bias_stats, width=500, height=350):
    """Bar chart of each model's RMS bias, labelled with the value.

    Args:
        bias_stats: Mapping of model name -> stats dict containing ``"rms"``.
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    models = list(bias_stats.keys())
    rms_values = [bias_stats[m]["rms"] for m in models]

    # Upper y-limit with 20% headroom, presumably so the 'outside' text labels
    # fit. If there is no positive finite RMS value, use autorange (None):
    # [0, 0] would collapse the axis and [0, nan] is invalid.
    finite_rms = [v for v in rms_values if np.isfinite(v)]
    top = max(finite_rms, default=0.0)
    y_range = [0, top * 1.2] if top > 0 else None

    fig = go.Figure()

    fig.add_trace(go.Bar(
        x=models,
        y=rms_values,
        marker_color=[MODEL_COLORS[i % len(MODEL_COLORS)] for i in range(len(models))],
        text=[f"{v:.4f}" for v in rms_values],
        textposition='outside',
        showlegend=False
    ))

    fig.update_layout(
        width=width, height=height,
        plot_bgcolor='white',
        margin=dict(l=10, r=10, t=40, b=25),
        font=dict(size=13),
        title_text="RMS Bias Comparison (Higher = More Biased)",
        title_font=dict(size=16), title_x=0.5, title_y=0.98
    )

    fig.update_xaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        title_text="Model",
        title_font=dict(size=14), tickfont=dict(size=12),
        showline=True, linewidth=1, linecolor='darkgrey'
    )

    fig.update_yaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='darkgrey',
        title_text="RMS Bias",
        title_font=dict(size=14), tickfont=dict(size=12),
        showline=True, linewidth=1, linecolor='darkgrey',
        range=y_range
    )

    return fig

fig = plot_rms_bias_comparison(bias_stats)
fig.show()

## Plot 3: Bias Statistics Summary Table

In [None]:
# One row per model with the headline statistics, pre-formatted as strings
# for a plain-text printout (not for further numeric analysis).
summary_data = [
    {
        "Model": model_name,
        "RMS Bias": f"{stats['rms']:.4f}",
        "Mean Bias": f"{stats['mean']:.4f}",
        "Std Bias": f"{stats['std']:.4f}",
        "Spatial Favored %": f"{stats['spatial_favored_pct']:.1f}%",
        "Descriptive Favored %": f"{stats['descriptive_favored_pct']:.1f}%",
    }
    for model_name, stats in bias_stats.items()
]

summary_df = pd.DataFrame(summary_data)
print("\nBias Statistics Summary:")
print(summary_df.to_string(index=False))

## Plot 4: Box Plot Comparison

In [None]:
def plot_bias_boxplot(model_data, width=600, height=400):
    """Side-by-side box plots of each model's bias distribution.

    Args:
        model_data: Mapping of model name -> DataFrame. Models without a
            ``"bias"`` column are skipped instead of raising KeyError, which
            matches the criterion compute_bias_stats uses to reject a model.
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Enumerate only plottable models so colours stay aligned with the other
    # per-model plots when a model is skipped.
    plotted = [(name, frame) for name, frame in model_data.items() if "bias" in frame.columns]

    for idx, (model_name, data) in enumerate(plotted):
        fig.add_trace(go.Box(
            y=data["bias"].values,
            name=model_name,
            marker_color=MODEL_COLORS[idx % len(MODEL_COLORS)],
            boxmean='sd'  # Show mean and standard deviation
        ))

    fig.update_layout(
        width=width, height=height,
        plot_bgcolor='white',
        margin=dict(l=10, r=10, t=40, b=25),
        font=dict(size=13),
        title_text="Bias Distribution Box Plot Comparison",
        title_font=dict(size=16), title_x=0.5, title_y=0.98,
        showlegend=False
    )

    fig.update_xaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        title_text="Model",
        title_font=dict(size=14), tickfont=dict(size=12),
        showline=True, linewidth=1, linecolor='darkgrey'
    )

    fig.update_yaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='black',
        title_text="Bias (spatial - descriptive)",
        title_font=dict(size=14), tickfont=dict(size=12),
        showline=True, linewidth=1, linecolor='darkgrey'
    )

    return fig

fig = plot_bias_boxplot(model_data)
fig.show()