In [None]:
#................................... FIC Score Script

In [None]:
##########################..... Simulated Cases .......................###################

In [None]:
#... With all FIC Benmarking Tiers for alpha

In [26]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "fic_results"
os.makedirs(output_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. DATA SIMULATION (unchanged)
# ============================================

def generate_healthcare_data(n=5000):
    np.random.seed(42)
    gender = np.random.choice(['Male', 'Female', 'Non-binary'], size=n, p=[0.455, 0.449, 0.096])
    age = np.clip(np.random.normal(loc=45, scale=15, size=n), 18, 85)
    income = np.clip(np.random.lognormal(mean=10.5, sigma=0.8, size=n), 20000, 200000)
    marital_status = np.random.choice(['Single', 'Married', 'Divorced', 'Widowed'], size=n, p=[0.3, 0.4, 0.2, 0.1])
    immigration_status = np.random.choice(['Citizen', 'Immigrant', 'Refugee'], size=n, p=[0.7, 0.25, 0.05])
    education = np.random.choice(['High School', 'College', 'Bachelor', 'Master', 'PhD'], size=n, p=[0.2, 0.3, 0.3, 0.15, 0.05])
    job_status = np.random.choice(['Employed', 'Unemployed', 'Student', 'Retired'], size=n, p=[0.6, 0.15, 0.15, 0.1])

    depression_prob = np.zeros(n)
    for i in range(n):
        base = 0.15 if gender[i] == 'Male' else 0.20 if gender[i] == 'Female' else 0.35
        prob = base
        if job_status[i] == 'Unemployed': prob += 0.20
        if income[i] < 30000: prob += 0.15
        if age[i] < 25 or age[i] > 65: prob += 0.10
        depression_prob[i] = np.clip(prob, 0, 0.95)
    depression = np.random.binomial(1, depression_prob)

    data = pd.DataFrame({
        'age': age, 'income': income, 'marital_status': marital_status,
        'immigration_status': immigration_status, 'education': education,
        'job_status': job_status, 'gender': gender, 'depression': depression
    })
    return data

def generate_criminal_justice_data(n=8000):
    np.random.seed(42)
    regions = ['Africa', 'EU', 'South America', 'North America', 'Arab/Middle East', 'Asia', 'Oceania']
    region_weights = [0.20, 0.25, 0.15, 0.10, 0.10, 0.15, 0.05]
    region = np.random.choice(regions, size=n, p=region_weights)
    gender = np.random.choice(['Male', 'Female'], size=n, p=[0.7, 0.3])
    age = np.clip(np.random.normal(loc=35, scale=12, size=n), 18, 70)
    income = np.clip(np.random.lognormal(mean=10.0, sigma=0.9, size=n), 15000, 150000)
    prior_convictions = np.clip(np.random.poisson(lam=1.5, size=n), 0, 10)
    education = np.random.choice(['Less than HS', 'High School', 'Some College', 'College', 'Graduate'], size=n, p=[0.1, 0.3, 0.25, 0.25, 0.1])
    employment = np.random.choice(['Employed', 'Unemployed', 'Student', 'Other'], size=n, p=[0.5, 0.25, 0.15, 0.1])
    asylum_seeker = np.random.choice([0, 1], size=n, p=[0.85, 0.15])

    high_risk_prob = np.zeros(n)
    for i in range(n):
        base_rates = {'Africa': 0.25, 'EU': 0.10, 'South America': 0.20, 'North America': 0.15,
                      'Arab/Middle East': 0.22, 'Asia': 0.18, 'Oceania': 0.12}
        base = base_rates[region[i]]
        prob = base
        if asylum_seeker[i] == 1: prob += 0.15
        if employment[i] == 'Unemployed': prob += 0.12
        high_risk_prob[i] = np.clip(prob, 0, 0.95)
    high_risk = np.random.binomial(1, high_risk_prob)

    data = pd.DataFrame({
        'gender': gender, 'age': age, 'income': income, 'prior_convictions': prior_convictions,
        'education': education, 'employment': employment, 'asylum_seeker': asylum_seeker,
        'region': region, 'high_risk': high_risk
    })
    return data

# ============================================
# 2-3. MODEL & FIC (unchanged for logic)
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. IMPROVED VISUALIZATIONS
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Colorbar
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('FIC Score', fontsize=14, fontweight='bold')
        cbar.ax.tick_params(labelsize=12)

        # Bold text inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.3f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'alphaF = {alphaF}', fontsize=18, fontweight='bold', pad=20)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), dpi=400, bbox_inches='tight')
    plt.close()

def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a separate figure for each alphaF value
        fig, ax = plt.subplots(figsize=(18, 10))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.5, width=0.7)
        
        # Add tier threshold lines
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.5, 
                   label='Optimum (>0.75)')
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.5, 
                   label='Acceptable (>0.50)')
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.5, 
                   label='Unacceptable (≤0.00)')
        
        # Add value and tier labels on bars
        for bar, score, tier in zip(bars, fic_scores, tiers):
            height = bar.get_height()
            # Position text based on bar height
            if height >= 0:
                ax.text(bar.get_x() + bar.get_width()/2, height + 0.02,
                        f'{score:.3f}\n({tier})',
                        ha='center', va='bottom',
                        fontsize=11, fontweight='bold', color='black')
            else:
                ax.text(bar.get_x() + bar.get_width()/2, height - 0.05,
                        f'{score:.3f}\n({tier})',
                        ha='center', va='top',
                        fontsize=11, fontweight='bold', color='black')
        
        # Customize axes
        ax.set_xlabel('Group Pairs', fontsize=16, fontweight='bold', labelpad=15)
        ax.set_ylabel('FIC Score', fontsize=16, fontweight='bold', labelpad=15)
        ax.set_title(f'{dataset_name}: FIC Benchmarking Tiers ({metric}, alphaF={alphaF})',
                    fontsize=20, fontweight='bold', pad=20)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=13, fontweight='bold')
        
        # Set consistent y-axis limits
        ax.set_ylim(-0.3, 1.15)
        
        # Add grid
        ax.grid(True, axis='y', alpha=0.4, linestyle='--')
        
        # Add legend
        ax.legend(fontsize=14, loc='upper left')
        
        # Adjust layout
        plt.tight_layout()
        
        # Save the figure with alphaF in the filename
        plt.savefig(os.path.join(output_dir, f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'), 
                    dpi=400, bbox_inches='tight')
        plt.close()

# ============================================
# 5-7. REST OF CODE (analyze_dataset, run_complete_analysis, etc.)
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    fic_results = fic_framework.analyze_fairness(baseline_metrics, 'accuracy')

    # FIC table
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("FIC ANALYSIS TABLE:")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis.csv'), index=False)

    # Tier classification
    tier_data = []
    print("TIER CLASSIFICATION:")
    for af in fic_framework.alphaF_values:
        print(f"\nFor alphaF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification.csv'), index=False)

    print("GENERATING VISUALIZATIONS...")
    plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}')
    plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}')

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS")
    print("="*80)

    healthcare_results = analyze_dataset(
        dataset_name="Healthcare - Depression Diagnosis",
        data_generator=lambda: generate_healthcare_data(5000),
        target_col='depression',
        protected_col='gender',
        case_number=1
    )

    criminal_results = analyze_dataset(
        dataset_name="Criminal Justice - Recidivism Risk",
        data_generator=lambda: generate_criminal_justice_data(8000),
        target_col='high_risk',
        protected_col='region',
        case_number=2
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT")
    print("="*80)

    print("CASE 1 - HEALTHCARE KEY FINDINGS:")
    print("-" * 60)
    for af in [0.05, 0.10, 0.20]:
        if af in healthcare_results['fic_results'] and healthcare_results['fic_results'][af]:
            omegas = [d['omega'] for d in healthcare_results['fic_results'][af].values()]
            max_o = max(omegas)
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            fic = FairnessInformationCriterion()
            for d in healthcare_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"alphaF={af}: omega_max = {max_o:.4f}")
            print(f"  Tier distribution: {tiers}")

    print("CASE 2 - CRIMINAL JUSTICE KEY FINDINGS:")
    print("-" * 60)
    for af in [0.05, 0.10, 0.20]:
        if af in criminal_results['fic_results'] and criminal_results['fic_results'][af]:
            items = list(criminal_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}: omega_max = {max_o:.4f} ({worst_pair})")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)

    return healthcare_results, criminal_results

if __name__ == "__main__":
    healthcare_results, criminal_results = run_complete_analysis()

    print("All done")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS

CASE 1: Healthcare - Depression Diagnosis
GROUP METRICS TABLE (Baseline Logistic Regression):
            accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
Male          0.7470          0.0281  0.0581  0.9821  0.0179  0.9419  0.5263  0.7534  0.1047  0.6563
Female        0.6834          0.0148  0.0143  0.9850  0.0150  0.9857  0.3000  0.6892  0.0273  0.6029
Non-binary    0.5338          0.0338  0.0294  0.9625  0.0375  0.9706  0.4000  0.5385  0.0548  0.6287
FIC ANALYSIS TABLE:
         Group Pair              alphaF=0.05 Hypothesis alphaF=0.05               alphaF=0.1    Hypothesis alphaF=0.1              alphaF=0.15   Hypothesis alphaF=0.15               alphaF=0.2    Hypothesis alphaF=0.2
Female - Non-binary omega=0.1496, FIC=-1.993     Reject H₀ (Unfair) omega=0.1496, FIC=-0.496       Reject H₀ (Unfair)  omega=0.1496, FIC=0.002 Fail to reject Ho (Fair)  omega=0.1496, FIC=0.252 Fail to reject Ho (Fair

In [None]:
#..................With all PLots

In [68]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "fic_results_ALL_PLOTS"
os.makedirs(output_dir, exist_ok=True)

# Also create PDF subdirectory
pdf_dir = os.path.join(output_dir, "PDF_plots")
os.makedirs(pdf_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. DATA SIMULATION
# ============================================

def generate_healthcare_data(n=5000):
    """
    Generate healthcare data for depression diagnosis
    """
    np.random.seed(42)
    
    # Gender distribution with realistic proportions
    gender = np.random.choice(['Male', 'Female', 'Non-binary'], 
                             size=n, 
                             p=[0.455, 0.449, 0.096])
    
    # Age and income with realistic distributions
    age = np.clip(np.random.normal(loc=45, scale=15, size=n), 18, 85)
    income = np.clip(np.random.lognormal(mean=10.5, sigma=0.8, size=n), 
                    20000, 200000)
    
    # Demographic variables
    marital_status = np.random.choice(['Single', 'Married', 'Divorced', 'Widowed'], 
                                     size=n, 
                                     p=[0.3, 0.4, 0.2, 0.1])
    
    immigration_status = np.random.choice(['Citizen', 'Immigrant', 'Refugee'], 
                                         size=n, 
                                         p=[0.7, 0.25, 0.05])
    
    education = np.random.choice(['High School', 'College', 'Bachelor', 'Master', 'PhD'], 
                                size=n, 
                                p=[0.2, 0.3, 0.3, 0.15, 0.05])
    
    job_status = np.random.choice(['Employed', 'Unemployed', 'Student', 'Retired'], 
                                 size=n, 
                                 p=[0.6, 0.15, 0.15, 0.1])

    # Generate depression probabilities with realistic biases
    depression_prob = np.zeros(n)
    for i in range(n):
        base = 0.15 if gender[i] == 'Male' else 0.20 if gender[i] == 'Female' else 0.35
        
        prob = base
        if job_status[i] == 'Unemployed': 
            prob += 0.20
        if income[i] < 30000: 
            prob += 0.15
        if age[i] < 25 or age[i] > 65: 
            prob += 0.10
        
        depression_prob[i] = np.clip(prob, 0, 0.95)
    
    # Generate binary depression outcome
    depression = np.random.binomial(1, depression_prob)

    # Create DataFrame
    data = pd.DataFrame({
        'age': age, 
        'income': income, 
        'marital_status': marital_status,
        'immigration_status': immigration_status, 
        'education': education,
        'job_status': job_status, 
        'gender': gender, 
        'depression': depression
    })
    
    print(f"Generated healthcare data: {len(data)} samples")
    print(f"Depression prevalence: {depression.mean():.3f}")
    print(f"Gender distribution:")
    print(data['gender'].value_counts(normalize=True).round(3))
    
    return data

def generate_criminal_justice_data(n=8000):
    """
    Generate criminal justice data for recidivism risk prediction
    """
    np.random.seed(42)
    
    # Region distribution
    regions = ['Africa', 'EU', 'South America', 'North America', 
               'Arab/Middle East', 'Asia', 'Oceania']
    region_weights = [0.20, 0.25, 0.15, 0.10, 0.10, 0.15, 0.05]
    region = np.random.choice(regions, size=n, p=region_weights)
    
    # Demographic variables
    gender = np.random.choice(['Male', 'Female'], size=n, p=[0.7, 0.3])
    age = np.clip(np.random.normal(loc=35, scale=12, size=n), 18, 70)
    income = np.clip(np.random.lognormal(mean=10.0, sigma=0.9, size=n), 
                    15000, 150000)
    
    prior_convictions = np.clip(np.random.poisson(lam=1.5, size=n), 0, 10)
    
    education = np.random.choice(['Less than HS', 'High School', 'Some College', 
                                 'College', 'Graduate'], 
                                size=n, 
                                p=[0.1, 0.3, 0.25, 0.25, 0.1])
    
    employment = np.random.choice(['Employed', 'Unemployed', 'Student', 'Other'], 
                                 size=n, 
                                 p=[0.5, 0.25, 0.15, 0.1])
    
    asylum_seeker = np.random.choice([0, 1], size=n, p=[0.85, 0.15])

    # Generate high-risk probabilities with regional biases
    high_risk_prob = np.zeros(n)
    base_rates = {
        'Africa': 0.25, 
        'EU': 0.10, 
        'South America': 0.20, 
        'North America': 0.15,
        'Arab/Middle East': 0.22, 
        'Asia': 0.18, 
        'Oceania': 0.12
    }
    
    for i in range(n):
        base = base_rates[region[i]]
        prob = base
        if asylum_seeker[i] == 1: 
            prob += 0.15
        if employment[i] == 'Unemployed': 
            prob += 0.12
        if prior_convictions[i] > 3: 
            prob += 0.10
        
        high_risk_prob[i] = np.clip(prob, 0, 0.95)
    
    # Generate binary high-risk outcome
    high_risk = np.random.binomial(1, high_risk_prob)

    # Create DataFrame
    data = pd.DataFrame({
        'gender': gender, 
        'age': age, 
        'income': income, 
        'prior_convictions': prior_convictions,
        'education': education, 
        'employment': employment, 
        'asylum_seeker': asylum_seeker,
        'region': region, 
        'high_risk': high_risk
    })
    
    print(f"Generated criminal justice data: {len(data)} samples")
    print(f"High risk prevalence: {high_risk.mean():.3f}")
    print(f"Region distribution:")
    print(data['region'].value_counts(normalize=True).round(3))
    
    return data

# ============================================
# 2-3. MODEL & FIC
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    """
    Compute comprehensive performance metrics
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    """
    Train and evaluate logistic regression models
    """
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    # Preprocessing pipeline
    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    # Train-test split with stratification
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )
    protected_test = data.loc[X_test.index, protected_col]

    # Preprocess data
    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    # Model selection
    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', 
                                  random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, 
                                  max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    # Train model
    model.fit(X_train_processed, y_train)
    
    # Make predictions
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    # Compute group-wise metrics
    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(
                y_test[mask], y_pred[mask], y_prob[mask]
            )

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    """
    Fairness Information Criterion (FIC) framework
    """
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        """Compute unfairness magnitude (ω)"""
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        """Compute FIC score"""
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        """Classify FIC score into fairness tiers"""
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        """
        Analyze fairness across all group pairs for all alphaF values
        """
        results = {}
        groups = list(group_metrics.keys())
        
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        
                        results[alphaF][pair] = {
                            'omega': omega, 
                            'fic_score': fic_score, 
                            'tier': tier,
                            'metric1': m1, 
                            'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS - IMPROVED WITH PDF SUPPORT
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    """
    Create FIC heatmaps for all alphaF values
    """
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure for publication quality
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different αF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        # Fill matrix with FIC scores
        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        # Create heatmap
        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        # Customize axes
        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.78, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on colorbar
    cbar.ax.text(1.8, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.8, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.8, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.8, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.6)

    plt.tight_layout(rect=[0, 0.03, 0.78, 0.95])
    
    # Save as PNG and PDF
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), 
                dpi=400, bbox_inches='tight')
    plt.savefig(os.path.join(pdf_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.pdf'), 
                format='pdf', bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    """
    Create benchmarking tier plots for each alphaF value
    """
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {
        'Optimum': '#2E8B57', 
        'Acceptable': '#FFD700', 
        'Questionable': '#FF8C00', 
        'Unacceptable': '#DC143C'
    }
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            continue
        
        # Create figure with expanded width for legend
        fig, ax = plt.subplots(figsize=(20, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find dynamic y-axis limits
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure minimum range
        if y_max - y_min < 0.5:
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bars
        bar_colors = [colors[t] for t in tiers]
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, alpha=0.7)
        
        # Customize axes
        ax.set_xlabel('Group Pairs', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set y-axis limits
        ax.set_ylim(y_min, y_max)
        
        # Format y-tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        
        # Create tier legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', 
                  label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', 
                  label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', 
                  label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', 
                  label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create threshold line legend
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, 
                   label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, 
                   label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, 
                   label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.05, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.05, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |M₁ - M₂|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout
        plt.tight_layout(rect=[0, 0, 0.80, 1])
        
        # Save figures
        png_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'
        pdf_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.pdf'
        
        plt.savefig(os.path.join(output_dir, png_filename), 
                    dpi=400, bbox_inches='tight')
        plt.savefig(os.path.join(pdf_dir, pdf_filename), 
                    format='pdf', bbox_inches='tight')
        plt.close()

# ============================================
# 5. ANALYSIS FUNCTIONS - FOR ALL METRICS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, 
                   case_number=1, model_types=['baseline', 'l1', 'l2']):
    """
    Complete analysis for a dataset
    """
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    # Generate data
    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    # Train baseline model
    baseline_metrics, _ = train_and_evaluate_models(data, target_col, 
                                                   protected_col, 'baseline')

    # Create metrics table
    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 
                            'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, 
                                  f'Case{case_number}_Group_Metrics.csv'))

    print("\nGENERATING VISUALIZATIONS FOR ALL METRICS...")
    
    # List of all metrics to analyze
    all_metrics = ['accuracy', 'selection_rate', 'tpr', 'tnr', 
                  'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']
    
    # Dictionary to store all FIC results
    all_fic_results = {}
    
    for metric in all_metrics:
        print(f"\n{'='*60}")
        print(f"ANALYZING METRIC: {metric.upper()}")
        print(f"{'='*60}")
        
        # Analyze fairness for this metric
        fic_results = fic_framework.analyze_fairness(baseline_metrics, metric)
        all_fic_results[metric] = fic_results
        
        # Generate heatmaps
        plot_fic_heatmaps(fic_results, 
                         f'Case{case_number}_{dataset_name}_{metric}', 
                         metric)
        
        # Generate benchmarking tiers
        plot_benchmarking_tiers(fic_results, 
                               f'Case{case_number}_{dataset_name}_{metric}', 
                               metric)
        
        # Print summary
        print(f"Summary for {metric}:")
        for af in fic_framework.alphaF_values:
            if af in fic_results and fic_results[af]:
                omegas = [d['omega'] for d in fic_results[af].values()]
                max_o = max(omegas)
                avg_o = np.mean(omegas)
                tiers = {'Optimum': 0, 'Acceptable': 0, 
                        'Questionable': 0, 'Unacceptable': 0}
                for d in fic_results[af].values():
                    tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                print(f"  αF={af}: ω_max={max_o:.4f}, ω_avg={avg_o:.4f}, "
                      f"Tiers={tiers}")

    # Store FIC results for accuracy (original metric)
    fic_results = all_fic_results['accuracy']
    
    # FIC table for accuracy
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = ("Fail to reject H₀ (Fair)" 
                                                  if d['omega'] <= af 
                                                  else "Reject H₀ (Unfair)")
            else:
                row[f'alphaF={af}'] = "N/A"
                row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    
    fic_df = pd.DataFrame(fic_table)
    print("\nFIC ANALYSIS TABLE (Accuracy):")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, 
                              f'Case{case_number}_FIC_Analysis_accuracy.csv'), 
                 index=False)

    # Tier classification for accuracy
    tier_data = []
    print("\nTIER CLASSIFICATION (Accuracy):")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = (tier if d['fic_score'] <= 0.75 
                       else f"{tier} (omega_max < {0.25*af:.4f})")
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 
                                 'ω': d['omega'], 'FIC': d['fic_score'], 
                                 'Tier': tier})
    
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, 
                               f'Case{case_number}_Tier_Classification_accuracy.csv'), 
                  index=False)

    # Model comparison
    print("\nMODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, 
                                                   protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = (np.mean([d['fic_score'] for d in model_fic[0.10].values()]) 
                  if 0.10 in model_fic and model_fic[0.10] else np.nan)
        max_omega = (max([d['omega'] for d in model_fic[0.10].values()]) 
                    if 0.10 in model_fic and model_fic[0.10] else np.nan)
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC (αF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (αF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, 
                                     f'Case{case_number}_Model_Comparison.csv'), 
                        index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'all_fic_results': all_fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    """
    Run complete FIC analysis for simulated datasets
    """
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - SIMULATED DATASETS")
    print("="*80)

    # Healthcare dataset analysis
    healthcare_results = analyze_dataset(
        dataset_name="Healthcare - Depression Diagnosis",
        data_generator=lambda: generate_healthcare_data(5000),
        target_col='depression',
        protected_col='gender',
        case_number=1
    )

    # Criminal justice dataset analysis
    criminal_results = analyze_dataset(
        dataset_name="Criminal Justice - Recidivism Risk",
        data_generator=lambda: generate_criminal_justice_data(8000),
        target_col='high_risk',
        protected_col='region',
        case_number=2
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - SIMULATED DATASETS")
    print("="*80)

    print("CASE 1 - HEALTHCARE KEY FINDINGS:")
    print("-" * 60)
    data = healthcare_results['data']
    print(f"Total samples: {len(data)}")
    print(f"Depression prevalence: {data['depression'].mean():.3f}")
    print("\nGender distribution:")
    gender_dist = data['gender'].value_counts()
    for gender, count in gender_dist.items():
        prop = count / len(data)
        print(f"  {gender}: {count} ({prop:.3f})")
    
    print("\nDepression by gender:")
    for gender in sorted(data['gender'].unique()):
        subset = data[data['gender'] == gender]
        depression_prop = subset['depression'].mean()
        print(f"  {gender}: {depression_prop:.3f}")

    print("\nCASE 2 - CRIMINAL JUSTICE KEY FINDINGS:")
    print("-" * 60)
    data = criminal_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk prevalence: {data['high_risk'].mean():.3f}")
    print("\nRegion distribution:")
    region_dist = data['region'].value_counts()
    for region, count in region_dist.items():
        prop = count / len(data)
        print(f"  {region}: {count} ({prop:.3f})")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)
    print(f"Generated plots for all metrics: accuracy, selection_rate, tpr, tnr, "
          f"fpr, fnr, ppv, npv, f1, auc")
    print(f"Each metric has:")
    print(f"  - 1 heatmap figure (2x2 grid for all alphaF values)")
    print(f"  - 4 benchmarking tier plots (one for each alphaF: 0.05, 0.10, 0.15, 0.20)")
    print(f"\nAll plots saved in both PNG and PDF formats.")
    print(f"PNG files saved in: {output_dir}/")
    print(f"PDF files saved in: {pdf_dir}/")

    return healthcare_results, criminal_results

# ============================================
# 7. EXECUTION
# ============================================

if __name__ == "__main__":
    # Run the complete analysis
    healthcare_results, criminal_results = run_complete_analysis()

    print("All analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print(f"PDF files saved to: {pdf_dir}/")
    print("Files include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables for accuracy (CSV)")
    print("  - Tier classification for accuracy (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps for ALL 10 metrics (PNG + PDF)")
    print("  - Benchmarking tiers for ALL 10 metrics (4 plots per metric = 40 PNG + 40 PDF files)")
    print(f"Total plots generated: {10 + 40} PNG files + {10 + 40} PDF files = {100} total files")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - SIMULATED DATASETS

CASE 1: Healthcare - Depression Diagnosis
Generated healthcare data: 5000 samples
Depression prevalence: 0.300
Gender distribution:
gender
Male          0.461
Female        0.447
Non-binary    0.093
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
            accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
Male          0.7470          0.0281  0.0581  0.9821  0.0179  0.9419  0.5263  0.7534  0.1047  0.6563
Female        0.6834          0.0148  0.0143  0.9850  0.0150  0.9857  0.3000  0.6892  0.0273  0.6029
Non-binary    0.5338          0.0338  0.0294  0.9625  0.0375  0.9706  0.4000  0.5385  0.0548  0.6287

GENERATING VISUALIZATIONS FOR ALL METRICS...

ANALYZING METRIC: ACCURACY
Summary for accuracy:
  αF=0.05: ω_max=0.2133, ω_avg=0.1422, Tiers={'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 3}
  αF=0.1: ω_max=0.2133, ω_avg=0.

In [None]:
#........ With EXCEL NUMERICAL METRIC VALUES

In [70]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.styles import PatternFill, Font, Alignment, Border, Side

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "fic_results_ALL_METRICS_EXCEL"
os.makedirs(output_dir, exist_ok=True)

# Also create PDF subdirectory
pdf_dir = os.path.join(output_dir, "PDF_plots")
os.makedirs(pdf_dir, exist_ok=True)

# Create Excel subdirectory
excel_dir = os.path.join(output_dir, "Excel_results")
os.makedirs(excel_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. DATA SIMULATION
# ============================================

def generate_healthcare_data(n=5000):
    """
    Generate healthcare data for depression diagnosis
    """
    np.random.seed(42)
    
    # Gender distribution with realistic proportions
    gender = np.random.choice(['Male', 'Female', 'Non-binary'], 
                             size=n, 
                             p=[0.455, 0.449, 0.096])
    
    # Age and income with realistic distributions
    age = np.clip(np.random.normal(loc=45, scale=15, size=n), 18, 85)
    income = np.clip(np.random.lognormal(mean=10.5, sigma=0.8, size=n), 
                    20000, 200000)
    
    # Demographic variables
    marital_status = np.random.choice(['Single', 'Married', 'Divorced', 'Widowed'], 
                                     size=n, 
                                     p=[0.3, 0.4, 0.2, 0.1])
    
    immigration_status = np.random.choice(['Citizen', 'Immigrant', 'Refugee'], 
                                         size=n, 
                                         p=[0.7, 0.25, 0.05])
    
    education = np.random.choice(['High School', 'College', 'Bachelor', 'Master', 'PhD'], 
                                size=n, 
                                p=[0.2, 0.3, 0.3, 0.15, 0.05])
    
    job_status = np.random.choice(['Employed', 'Unemployed', 'Student', 'Retired'], 
                                 size=n, 
                                 p=[0.6, 0.15, 0.15, 0.1])

    # Generate depression probabilities with realistic biases
    depression_prob = np.zeros(n)
    for i in range(n):
        base = 0.15 if gender[i] == 'Male' else 0.20 if gender[i] == 'Female' else 0.35
        
        prob = base
        if job_status[i] == 'Unemployed': 
            prob += 0.20
        if income[i] < 30000: 
            prob += 0.15
        if age[i] < 25 or age[i] > 65: 
            prob += 0.10
        
        depression_prob[i] = np.clip(prob, 0, 0.95)
    
    # Generate binary depression outcome
    depression = np.random.binomial(1, depression_prob)

    # Create DataFrame
    data = pd.DataFrame({
        'age': age, 
        'income': income, 
        'marital_status': marital_status,
        'immigration_status': immigration_status, 
        'education': education,
        'job_status': job_status, 
        'gender': gender, 
        'depression': depression
    })
    
    print(f"Generated healthcare data: {len(data)} samples")
    print(f"Depression prevalence: {depression.mean():.3f}")
    print(f"Gender distribution:")
    print(data['gender'].value_counts(normalize=True).round(3))
    
    return data

def generate_criminal_justice_data(n=8000):
    """
    Generate criminal justice data for recidivism risk prediction
    """
    np.random.seed(42)
    
    # Region distribution
    regions = ['Africa', 'EU', 'South America', 'North America', 
               'Arab/Middle East', 'Asia', 'Oceania']
    region_weights = [0.20, 0.25, 0.15, 0.10, 0.10, 0.15, 0.05]
    region = np.random.choice(regions, size=n, p=region_weights)
    
    # Demographic variables
    gender = np.random.choice(['Male', 'Female'], size=n, p=[0.7, 0.3])
    age = np.clip(np.random.normal(loc=35, scale=12, size=n), 18, 70)
    income = np.clip(np.random.lognormal(mean=10.0, sigma=0.9, size=n), 
                    15000, 150000)
    
    prior_convictions = np.clip(np.random.poisson(lam=1.5, size=n), 0, 10)
    
    education = np.random.choice(['Less than HS', 'High School', 'Some College', 
                                 'College', 'Graduate'], 
                                size=n, 
                                p=[0.1, 0.3, 0.25, 0.25, 0.1])
    
    employment = np.random.choice(['Employed', 'Unemployed', 'Student', 'Other'], 
                                 size=n, 
                                 p=[0.5, 0.25, 0.15, 0.1])
    
    asylum_seeker = np.random.choice([0, 1], size=n, p=[0.85, 0.15])

    # Generate high-risk probabilities with regional biases
    high_risk_prob = np.zeros(n)
    base_rates = {
        'Africa': 0.25, 
        'EU': 0.10, 
        'South America': 0.20, 
        'North America': 0.15,
        'Arab/Middle East': 0.22, 
        'Asia': 0.18, 
        'Oceania': 0.12
    }
    
    for i in range(n):
        base = base_rates[region[i]]
        prob = base
        if asylum_seeker[i] == 1: 
            prob += 0.15
        if employment[i] == 'Unemployed': 
            prob += 0.12
        if prior_convictions[i] > 3: 
            prob += 0.10
        
        high_risk_prob[i] = np.clip(prob, 0, 0.95)
    
    # Generate binary high-risk outcome
    high_risk = np.random.binomial(1, high_risk_prob)

    # Create DataFrame
    data = pd.DataFrame({
        'gender': gender, 
        'age': age, 
        'income': income, 
        'prior_convictions': prior_convictions,
        'education': education, 
        'employment': employment, 
        'asylum_seeker': asylum_seeker,
        'region': region, 
        'high_risk': high_risk
    })
    
    print(f"Generated criminal justice data: {len(data)} samples")
    print(f"High risk prevalence: {high_risk.mean():.3f}")
    print(f"Region distribution:")
    print(data['region'].value_counts(normalize=True).round(3))
    
    return data

# ============================================
# 2-3. MODEL & FIC
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    """
    Compute comprehensive performance metrics
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    """
    Train and evaluate logistic regression models
    """
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    # Preprocessing pipeline
    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    # Train-test split with stratification
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )
    protected_test = data.loc[X_test.index, protected_col]

    # Preprocess data
    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    # Model selection
    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', 
                                  random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, 
                                  max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    # Train model
    model.fit(X_train_processed, y_train)
    
    # Make predictions
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    # Compute group-wise metrics
    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(
                y_test[mask], y_pred[mask], y_prob[mask]
            )

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    """
    Fairness Information Criterion (FIC) framework
    """
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        """Compute unfairness magnitude (ω)"""
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        """Compute FIC score"""
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        """Classify FIC score into fairness tiers"""
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        """
        Analyze fairness across all group pairs for all alphaF values
        """
        results = {}
        groups = list(group_metrics.keys())
        
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        
                        results[alphaF][pair] = {
                            'omega': omega, 
                            'fic_score': fic_score, 
                            'tier': tier,
                            'metric1': m1, 
                            'metric2': m2
                        }
        return results

# ============================================
# 4. EXCEL EXPORT FUNCTIONS
# ============================================

def save_to_excel_with_formatting(data_dict, filename, sheet_name_prefix=""):
    """
    Save multiple dataframes to Excel with formatting
    """
    excel_path = os.path.join(excel_dir, filename)
    
    with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
        # Save each dataframe to a separate sheet
        for sheet_name, df in data_dict.items():
            # Create full sheet name
            full_sheet_name = f"{sheet_name_prefix}_{sheet_name}" if sheet_name_prefix else sheet_name
            
            # Truncate sheet name if too long (Excel limit is 31 characters)
            if len(full_sheet_name) > 31:
                full_sheet_name = full_sheet_name[:31]
            
            # Write dataframe to Excel
            df.to_excel(writer, sheet_name=full_sheet_name, index=False)
            
            # Get the worksheet for formatting
            worksheet = writer.sheets[full_sheet_name]
            
            # Apply formatting
            for column in worksheet.columns:
                max_length = 0
                column_letter = column[0].column_letter
                for cell in column:
                    try:
                        if len(str(cell.value)) > max_length:
                            max_length = len(str(cell.value))
                    except:
                        pass
                adjusted_width = min(max_length + 2, 50)
                worksheet.column_dimensions[column_letter].width = adjusted_width
            
            # Freeze the first row
            worksheet.freeze_panes = 'A2'
    
    print(f"  ✓ Excel file saved: {filename}")
    return excel_path

def create_comprehensive_excel_report(results_dict, case_name, case_number, metrics_list):
    """
    Create a comprehensive Excel report with all numerical values
    """
    print(f"\n{'='*80}")
    print(f"CREATING COMPREHENSIVE EXCEL REPORT FOR {case_name.upper()}")
    print(f"{'='*80}")
    
    data_dict = {}
    
    # Extract results
    baseline_metrics = results_dict['baseline_metrics']
    all_fic_results = results_dict['all_fic_results']
    data = results_dict['data']
    
    # 1. Group Metrics Table
    print("1. Saving Group Metrics Table...")
    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    data_dict['Group_Metrics'] = metrics_df.reset_index().rename(columns={'index': 'Protected_Group'})
    
    # 2. FIC Analysis Tables for all metrics
    print("2. Saving FIC Analysis Tables for all metrics...")
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            # Create comprehensive FIC table for this metric
            fic_table = []
            pairs = list(set(p for a in fic_results.values() for p in a.keys()))
            
            for pair in sorted(pairs):
                row = {'Group_Pair': pair}
                for af in [0.05, 0.10, 0.15, 0.20]:
                    if af in fic_results and pair in fic_results[af]:
                        d = fic_results[af][pair]
                        row[f'omega_alphaF_{af}'] = d['omega']
                        row[f'FIC_alphaF_{af}'] = d['fic_score']
                        row[f'Tier_alphaF_{af}'] = d['tier']
                        row[f'Metric1_{af}'] = d['metric1']
                        row[f'Metric2_{af}'] = d['metric2']
                    else:
                        row[f'omega_alphaF_{af}'] = np.nan
                        row[f'FIC_alphaF_{af}'] = np.nan
                        row[f'Tier_alphaF_{af}'] = "N/A"
                        row[f'Metric1_{af}'] = np.nan
                        row[f'Metric2_{af}'] = np.nan
                fic_table.append(row)
            
            if fic_table:
                fic_df = pd.DataFrame(fic_table)
                data_dict[f'FIC_Analysis_{metric}'] = fic_df
    
    # 3. Tier Classification Summary for all metrics
    print("3. Saving Tier Classification Summary for all metrics...")
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            tier_summary = []
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                    fic_framework = FairnessInformationCriterion()
                    for d in fic_results[af].values():
                        tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                    
                    total_pairs = sum(tiers.values())
                    for tier_name, count in tiers.items():
                        tier_summary.append({
                            'Metric': metric,
                            'alphaF': af,
                            'Tier': tier_name,
                            'Count': count,
                            'Percentage': (count / total_pairs * 100) if total_pairs > 0 else 0
                        })
            
            if tier_summary:
                tier_summary_df = pd.DataFrame(tier_summary)
                data_dict[f'Tier_Summary_{metric}'] = tier_summary_df
    
    # 4. Benchmarking Tiers Numerical Values
    print("4. Saving Benchmarking Tiers Numerical Values...")
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            benchmark_data = []
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    for pair, d in fic_results[af].items():
                        benchmark_data.append({
                            'Metric': metric,
                            'alphaF': af,
                            'Group_Pair': pair,
                            'FIC_Score': d['fic_score'],
                            'Tier': d['tier'],
                            'omega': d['omega'],
                            'Metric_Value_Group1': d['metric1'],
                            'Metric_Value_Group2': d['metric2']
                        })
            
            if benchmark_data:
                benchmark_df = pd.DataFrame(benchmark_data)
                data_dict[f'Benchmark_Tiers_{metric}'] = benchmark_df
    
    # 5. Summary Statistics for each metric
    print("5. Saving Summary Statistics for each metric...")
    summary_stats = []
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    fic_scores = [d['fic_score'] for d in fic_results[af].values()]
                    omegas = [d['omega'] for d in fic_results[af].values()]
                    
                    summary_stats.append({
                        'Metric': metric,
                        'alphaF': af,
                        'FIC_Mean': np.mean(fic_scores),
                        'FIC_Std': np.std(fic_scores),
                        'FIC_Min': np.min(fic_scores),
                        'FIC_Max': np.max(fic_scores),
                        'omega_Mean': np.mean(omegas),
                        'omega_Std': np.std(omegas),
                        'omega_Min': np.min(omegas),
                        'omega_Max': np.max(omegas),
                        'Num_Pairs': len(fic_scores)
                    })
    
    if summary_stats:
        summary_df = pd.DataFrame(summary_stats)
        data_dict['Summary_Statistics'] = summary_df
    
    # 6. Model Comparison
    print("6. Saving Model Comparison...")
    if 'comparison_df' in results_dict:
        data_dict['Model_Comparison'] = results_dict['comparison_df']
    
    # 7. Dataset Statistics
    print("7. Saving Dataset Statistics...")
    dataset_stats_data = []
    
    # Basic statistics
    dataset_stats_data.append({
        'Statistic': 'Total_Samples',
        'Value': len(data)
    })
    
    # Target variable statistics
    target_col = 'depression' if 'depression' in data.columns else 'high_risk'
    dataset_stats_data.append({
        'Statistic': f'{target_col.capitalize()}_Prevalence',
        'Value': data[target_col].mean()
    })
    
    # Protected attribute statistics
    protected_col = 'gender' if 'gender' in data.columns else 'region'
    for group in data[protected_col].unique():
        count = (data[protected_col] == group).sum()
        proportion = count / len(data)
        dataset_stats_data.append({
            'Statistic': f'{protected_col.capitalize()}_{group}_Count',
            'Value': count
        })
        dataset_stats_data.append({
            'Statistic': f'{protected_col.capitalize()}_{group}_Proportion',
            'Value': proportion
        })
    
    dataset_stats_df = pd.DataFrame(dataset_stats_data)
    data_dict['Dataset_Statistics'] = dataset_stats_df
    
    # 8. Fairness Assessment Matrix
    print("8. Creating Fairness Assessment Matrix...")
    fairness_matrix = []
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    # Count pairs in each tier
                    fic_framework = FairnessInformationCriterion()
                    tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                    for d in fic_results[af].values():
                        tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                    
                    # Determine overall fairness status
                    if tiers['Unacceptable'] > 0:
                        overall_status = "Unfair"
                    elif tiers['Questionable'] > 0:
                        overall_status = "Questionable"
                    elif tiers['Acceptable'] > 0:
                        overall_status = "Acceptable"
                    else:
                        overall_status = "Optimum"
                    
                    fairness_matrix.append({
                        'Metric': metric,
                        'alphaF': af,
                        'Overall_Fairness': overall_status,
                        'Optimum_Pairs': tiers['Optimum'],
                        'Acceptable_Pairs': tiers['Acceptable'],
                        'Questionable_Pairs': tiers['Questionable'],
                        'Unacceptable_Pairs': tiers['Unacceptable'],
                        'Total_Pairs': sum(tiers.values())
                    })
    
    if fairness_matrix:
        fairness_df = pd.DataFrame(fairness_matrix)
        data_dict['Fairness_Assessment'] = fairness_df
    
    # Save all to Excel
    excel_filename = f"{case_name.replace(' ', '_')}_FIC_Complete_Analysis.xlsx"
    excel_file = save_to_excel_with_formatting(data_dict, excel_filename, f"Case{case_number}")
    
    print(f"\n✓ Excel report saved: {excel_file}")
    print(f"  Total sheets: {len(data_dict)}")
    
    # Print sheet names
    print("\nExcel sheets created:")
    for i, sheet_name in enumerate(data_dict.keys(), 1):
        print(f"  {i:2d}. {sheet_name}")
    
    return excel_file

# ============================================
# 5. VISUALIZATIONS - UPDATED FOR PDF AND EXPANDED LEGEND
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    """
    Create FIC heatmaps for all alphaF values
    """
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different αF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        # Fill matrix with FIC scores
        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        # Create heatmap
        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        # Customize axes
        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.78, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on colorbar
    cbar.ax.text(1.6, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.6, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.6, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.6, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.6)

    plt.tight_layout(rect=[0, 0.03, 0.78, 0.95])
    
    # Save as PNG and PDF
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), 
                dpi=400, bbox_inches='tight')
    plt.savefig(os.path.join(pdf_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.pdf'), 
                format='pdf', bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    """
    Create benchmarking tier plots for each alphaF value
    """
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {
        'Optimum': '#2E8B57', 
        'Acceptable': '#FFD700', 
        'Questionable': '#FF8C00', 
        'Unacceptable': '#DC143C'
    }
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            continue
        
        # Create figure with expanded width for legend
        fig, ax = plt.subplots(figsize=(20, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find dynamic y-axis limits
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure minimum range
        if y_max - y_min < 0.5:
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bars
        bar_colors = [colors[t] for t in tiers]
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, alpha=0.7)
        
        # Customize axes
        ax.set_xlabel('Inter-Group', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set y-axis limits
        ax.set_ylim(y_min, y_max)
        
        # Format y-tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        
        # Create tier legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', 
                  label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', 
                  label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', 
                  label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', 
                  label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create threshold line legend
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, 
                   label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, 
                   label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, 
                   label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.05, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.05, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |M₁ - M₂|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout
        plt.tight_layout(rect=[0, 0, 0.80, 1])
        
        # Save figures
        png_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'
        pdf_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.pdf'
        
        plt.savefig(os.path.join(output_dir, png_filename), 
                    dpi=400, bbox_inches='tight')
        plt.savefig(os.path.join(pdf_dir, pdf_filename), 
                    format='pdf', bbox_inches='tight')
        plt.close()

# ============================================
# 6. ANALYSIS FUNCTIONS - UPDATED FOR ALL METRICS
# ============================================

def analyze_simulated_dataset(dataset_name, data_generator, target_col, protected_col, 
                            case_number=1, model_types=['baseline', 'l1', 'l2']):
    """
    Complete analysis for a simulated dataset
    """
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    # Generate data
    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    # Train baseline model
    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    # Create metrics table
    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 
                            'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    
    # Save metrics to CSV
    metrics_csv_path = os.path.join(output_dir, f'Case{case_number}_{dataset_name.replace(" ", "_")}_Group_Metrics.csv')
    metrics_df.to_csv(metrics_csv_path)
    print(f"Group metrics saved to: {metrics_csv_path}")

    print("GENERATING VISUALIZATIONS FOR ALL METRICS...")
    
    # List of all metrics to analyze
    all_metrics = ['accuracy', 'selection_rate', 'tpr', 'tnr', 
                  'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']
    
    # Dictionary to store all FIC results
    all_fic_results = {}
    
    # Dictionary to store metric summaries
    metric_summaries = {}
    
    for metric in all_metrics:
        print(f"\n{'='*60}")
        print(f"ANALYZING METRIC: {metric.upper()}")
        print(f"{'='*60}")
        
        # Analyze fairness for this metric
        fic_results = fic_framework.analyze_fairness(baseline_metrics, metric)
        all_fic_results[metric] = fic_results
        
        # Generate heatmaps for this metric
        plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Generate benchmarking tiers for this metric
        plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Store summary for this metric
        metric_summary = {}
        for af in fic_framework.alphaF_values:
            if af in fic_results and fic_results[af]:
                omegas = [d['omega'] for d in fic_results[af].values()]
                fic_scores = [d['fic_score'] for d in fic_results[af].values()]
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                fic = FairnessInformationCriterion()
                for d in fic_results[af].values():
                    tiers[fic.classify_tier(d['fic_score'])] += 1
                
                metric_summary[f'alphaF_{af}'] = {
                    'omega_max': max(omegas),
                    'omega_avg': np.mean(omegas),
                    'omega_min': min(omegas),
                    'fic_max': max(fic_scores),
                    'fic_avg': np.mean(fic_scores),
                    'fic_min': min(fic_scores),
                    'tiers': tiers
                }
        
        metric_summaries[metric] = metric_summary
        
        # Print summary for this metric
        print(f"Summary for {metric}:")
        for af in fic_framework.alphaF_values:
            if af in metric_summary:
                summary = metric_summary[f'alphaF_{af}']
                print(f"  αF={af}: ω_max={summary['omega_max']:.4f}, ω_avg={summary['omega_avg']:.4f}, "
                      f"FIC_avg={summary['fic_avg']:.3f}, Tiers={summary['tiers']}")

    # Store FIC results for accuracy (original metric)
    fic_results = all_fic_results['accuracy']
    
    # FIC table for accuracy
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject H₀ (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"
                row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    
    fic_df = pd.DataFrame(fic_table)
    print("FIC ANALYSIS TABLE (Accuracy):")
    print(fic_df.to_string(index=False))
    
    # Save FIC analysis to CSV
    fic_csv_path = os.path.join(output_dir, f'Case{case_number}_{dataset_name.replace(" ", "_")}_FIC_Analysis_accuracy.csv')
    fic_df.to_csv(fic_csv_path, index=False)
    print(f"FIC analysis saved to: {fic_csv_path}")

    # Tier classification for accuracy
    tier_data = []
    print("TIER CLASSIFICATION (Accuracy):")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    
    tier_df = pd.DataFrame(tier_data)
    
    # Save tier classification to CSV
    tier_csv_path = os.path.join(output_dir, f'Case{case_number}_{dataset_name.replace(" ", "_")}_Tier_Classification_accuracy.csv')
    tier_df.to_csv(tier_csv_path, index=False)
    print(f"✓ Tier classification saved to: {tier_csv_path}")

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC (αF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (αF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    
    # Save model comparison to CSV
    comparison_csv_path = os.path.join(output_dir, f'Case{case_number}_{dataset_name.replace(" ", "_")}_Model_Comparison.csv')
    comparison_df.to_csv(comparison_csv_path, index=False)
    print(f"Model comparison saved to: {comparison_csv_path}")

    # Create comprehensive Excel report
    excel_file = create_comprehensive_excel_report(
        {
            'data': data,
            'baseline_metrics': baseline_metrics,
            'all_fic_results': all_fic_results,
            'metrics_df': metrics_df,
            'comparison_df': comparison_df
        },
        dataset_name,
        case_number,
        all_metrics
    )

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'all_fic_results': all_fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df,
        'excel_file': excel_file,
        'metric_summaries': metric_summaries
    }

# ============================================
# 7. MAIN ANALYSIS
# ============================================

def run_complete_simulated_analysis():
    """
    Run complete FIC analysis for simulated datasets
    """
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - SIMULATED DATASETS")
    print("="*80)
    print(f"Output directory: {output_dir}")
    print(f"PDF directory: {pdf_dir}")
    print(f"Excel directory: {excel_dir}")

    print("\n" + "="*80)
    print("ANALYZING SIMULATED DATASETS")
    print("="*80)
    
    # Case 1: Healthcare Dataset
    healthcare_results = analyze_simulated_dataset(
        dataset_name="Healthcare - Depression Diagnosis",
        data_generator=lambda: generate_healthcare_data(5000),
        target_col='depression',
        protected_col='gender',
        case_number=1
    )
    
    # Case 2: Criminal Justice Dataset
    criminal_results = analyze_simulated_dataset(
        dataset_name="Criminal Justice - Recidivism Risk",
        data_generator=lambda: generate_criminal_justice_data(8000),
        target_col='high_risk',
        protected_col='region',
        case_number=2
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - SIMULATED DATASETS")
    print("="*80)

    print("CASE 1 - HEALTHCARE DATASET KEY FINDINGS:")
    print("-" * 60)
    data = healthcare_results['data']
    print(f"Total samples: {len(data)}")
    print(f"Depression prevalence: {data['depression'].mean():.3f}")
    print(f"Gender distribution:")
    gender_dist = data['gender'].value_counts()
    for gender, count in gender_dist.items():
        prop = count / len(data)
        print(f"  {gender}: {count} ({prop:.3f})")
    
    print("Depression by gender:")
    for gender in sorted(data['gender'].unique()):
        subset = data[data['gender'] == gender]
        depression_prop = subset['depression'].mean()
        print(f"  {gender}: {depression_prop:.3f}")

    print("FIC ANALYSIS SUMMARY (Accuracy - Healthcare):")
    print("-" * 60)
    fic_results = healthcare_results['all_fic_results']['accuracy']
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in fic_results and fic_results[af]:
            items = list(fic_results[af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in fic_results[af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("CASE 2 - CRIMINAL JUSTICE DATASET KEY FINDINGS:")
    print("-" * 60)
    data = criminal_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk prevalence: {data['high_risk'].mean():.3f}")
    print(f"Region distribution:")
    region_dist = data['region'].value_counts()
    for region, count in region_dist.items():
        prop = count / len(data)
        print(f"  {region}: {count} ({prop:.3f})")

    print("FIC ANALYSIS SUMMARY (Accuracy - Criminal Justice):")
    print("-" * 60)
    fic_results = criminal_results['all_fic_results']['accuracy']
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in fic_results and fic_results[af]:
            items = list(fic_results[af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in fic_results[af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - ALL RESULTS SAVED")
    print("="*80)
    
    print(f"VISUALIZATIONS:")
    print(f"  Generated plots for 2 cases × 10 metrics = 20 metric analyses")
    print(f"  Each metric analysis has:")
    print(f"    - 1 heatmap figure (2x2 grid for all alphaF values)")
    print(f"    - 4 benchmarking tier plots (one for each alphaF: 0.05, 0.10, 0.15, 0.20)")
    print(f"  Total plots: {2 * (10 + 40)} PNG files + {2 * (10 + 40)} PDF files = {200} total files")
    
    print(f"NUMERICAL RESULTS:")
    print(f"  CSV files saved in: {output_dir}/")
    print(f"  Comprehensive Excel reports:")
    print(f"    - {healthcare_results['excel_file']}")
    print(f"    - {criminal_results['excel_file']}")
       
    return healthcare_results, criminal_results

# ============================================
# 8. EXECUTION
# ============================================

if __name__ == "__main__":
    # Run the complete analysis for simulated datasets
    healthcare_results, criminal_results = run_complete_simulated_analysis()

    print("\n" + "="*80)
    print("ALL SIMULATED CASES ANALYSIS COMPLETED SUCCESSFULLY!")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - SIMULATED DATASETS
Output directory: fic_results_ALL_METRICS_EXCEL
PDF directory: fic_results_ALL_METRICS_EXCEL\PDF_plots
Excel directory: fic_results_ALL_METRICS_EXCEL\Excel_results

ANALYZING SIMULATED DATASETS

CASE 1: Healthcare - Depression Diagnosis
Generated healthcare data: 5000 samples
Depression prevalence: 0.300
Gender distribution:
gender
Male          0.461
Female        0.447
Non-binary    0.093
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
            accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
Male          0.7470          0.0281  0.0581  0.9821  0.0179  0.9419  0.5263  0.7534  0.1047  0.6563
Female        0.6834          0.0148  0.0143  0.9850  0.0150  0.9857  0.3000  0.6892  0.0273  0.6029
Non-binary    0.5338          0.0338  0.0294  0.9625  0.0375  0.9706  0.4000  0.5385  0.0548  0.6287
Group metrics saved to: fic_results_ALL_METRICS_E

In [None]:
#..................................................................................................#
#..........................        REAL-LIFE COMPAS        .........................................#
#...................................................................................................#

In [None]:
#.... First try

In [27]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "compas_fic_results"
os.makedirs(output_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset
    Source: https://github.com/propublica/compas-analysis
    """
    # Try to load from local file first
    try:
        compas_df = pd.read_csv("compas-scores-two-years.csv")
        print("Loaded COMPAS dataset from local file")
    except:
        # If local file doesn't exist, download from GitHub
        print("Downloading COMPAS dataset from GitHub...")
        import requests
        url = "https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv"
        response = requests.get(url)
        with open("compas-scores-two-years.csv", "wb") as f:
            f.write(response.content)
        compas_df = pd.read_csv("compas-scores-two-years.csv")
        print("COMPAS dataset downloaded and loaded")
    
    # Basic preprocessing
    print(f"Original dataset shape: {compas_df.shape}")
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC (unchanged)
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Colorbar
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('FIC Score', fontsize=14, fontweight='bold')
        cbar.ax.tick_params(labelsize=12)

        # Bold text inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.3f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), dpi=400, bbox_inches='tight')
    plt.close()

def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a more compact figure
        fig, ax = plt.subplots(figsize=(14, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7, label='Optimum (FIC > 0.75)')
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7, label='Acceptable (FIC > 0.50)')
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7, label='Unacceptable (FIC ≤ 0)')
        
        # Add value and tier labels on bars - more compact
        for bar, score, tier in zip(bars, fic_scores, tiers):
            height = bar.get_height()
            # Position text based on bar height
            if height >= 0:
                ax.text(bar.get_x() + bar.get_width()/2, height + 0.015,
                        f'{score:.2f}',
                        ha='center', va='bottom',
                        fontsize=10, fontweight='bold', color='black')
                # Add tier label at the bottom of positive bars
                ax.text(bar.get_x() + bar.get_width()/2, -0.05,
                        f'{tier[:3]}',  # Show first 3 letters of tier
                        ha='center', va='top',
                        fontsize=9, fontweight='bold', color='black',
                        rotation=90)
            else:
                ax.text(bar.get_x() + bar.get_width()/2, height - 0.03,
                        f'{score:.2f}',
                        ha='center', va='top',
                        fontsize=10, fontweight='bold', color='black')
        
        # Customize axes with better labels
        ax.set_xlabel('Group Pairs', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        # Shorten pair labels if they're too long
        shortened_pairs = []
        for pair in pairs:
            if len(pair) > 15:
                # Take first part of each group name
                g1, g2 = pair.split(' - ')
                g1_short = g1[:3] if len(g1) > 3 else g1
                g2_short = g2[:3] if len(g2) > 3 else g2
                shortened_pairs.append(f'{g1_short}-{g2_short}')
            else:
                shortened_pairs.append(pair)
        
        ax.set_xticklabels(shortened_pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set consistent y-axis limits
        ax.set_ylim(-0.25, 1.05)
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend
        # Create custom legend for tier colors
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable'),
            plt.Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold'),
            plt.Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold'),
            plt.Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold')
        ]
        
        # Place legend outside the plot
        ax.legend(handles=legend_elements, fontsize=10, 
                  loc='center left', bbox_to_anchor=(1.02, 0.5),
                  frameon=True, framealpha=0.9, edgecolor='black')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |metric₁ - metric₂|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        
        # Save the figure with alphaF in the filename
        plt.savefig(os.path.join(output_dir, f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'), 
                    dpi=400, bbox_inches='tight')
        plt.close()
        
        print(f"  Saved benchmarking tiers plot for alphaF={alphaF}")

# ============================================
# 5. ANALYSIS FUNCTIONS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    fic_results = fic_framework.analyze_fairness(baseline_metrics, 'accuracy')

    # FIC table
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("FIC ANALYSIS TABLE:")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis.csv'), index=False)

    # Tier classification
    tier_data = []
    print("TIER CLASSIFICATION:")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification.csv'), index=False)

    print("GENERATING VISUALIZATIONS...")
    plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}')
    plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}')

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY:")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)

    return compas_results

if __name__ == "__main__":
    # Check if dataset exists or download it
    compas_results = run_complete_analysis()

    print("\nAll analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print("Files include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables (CSV)")
    print("  - Tier classification (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps (PNG)")
    print("  - Benchmarking tiers for all alphaF values (PNG)")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET

CASE 1: COMPAS - Recidivism Risk Prediction
Loaded COMPAS dataset from local file
Original dataset shape: (7214, 53)
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540  0.4502  0.7781  0.6568  0.6443  0.7555
Caucasian           0.7976          0.1787  0.4531  0.9139  0.0861  0.5469  0.6397  0.8320  0.5305  0.8210
Hispanic            0.8305          0.1525  0.4615  0.9348  0.0652  0.5385  0.6667  0.8600  0.5455  0.8443
Other_Race   

In [None]:
#.... Better visual and legends

In [33]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "compas_fic_results"
os.makedirs(output_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset
    Source: https://github.com/propublica/compas-analysis
    """
    # Try to load from local file first
    try:
        compas_df = pd.read_csv("compas-scores-two-years.csv")
        print("Loaded COMPAS dataset from local file")
    except:
        # If local file doesn't exist, download from GitHub
        print("Downloading COMPAS dataset from GitHub...")
        import requests
        url = "https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv"
        response = requests.get(url)
        with open("compas-scores-two-years.csv", "wb") as f:
            f.write(response.content)
        compas_df = pd.read_csv("compas-scores-two-years.csv")
        print("COMPAS dataset downloaded and loaded")
    
    # Basic preprocessing
    print(f"Original dataset shape: {compas_df.shape}")
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC (unchanged)
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells (optional - can comment out if too busy)
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=12, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Add tier annotations on the colorbar
    cbar.ax.text(1.5, 0.875, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.5, 0.625, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.5, 0.375, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.5, 0.125, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2, xmax=0.8)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2, xmax=0.8)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2, xmax=0.8)

    plt.tight_layout(rect=[0, 0.03, 0.9, 0.95])
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), dpi=400, bbox_inches='tight')
    plt.close()

def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a more compact figure
        fig, ax = plt.subplots(figsize=(14, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
       # ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
       #            alpha=0.7, label='Optimum Threshold')
       # ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
       #            alpha=0.7, label='Acceptable Threshold')
       # ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
       #            alpha=0.7, label='Unacceptable Threshold')
        
        # REMOVED value labels on bars as requested
        # Only show tier labels at the bottom
        for bar, tier in zip(bars, tiers):
            ax.text(bar.get_x() + bar.get_width()/2, -0.05,
                    tier[:4],  # Show first 4 letters of tier
                    ha='center', va='top',
                    fontsize=10, fontweight='bold', color='black',
                    rotation=45)
        
        # Customize axes with better labels
        ax.set_xlabel('Group Pairs', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set consistent y-axis limits
        ax.set_ylim(-0.25, 1.05)
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend
        # Create custom legend for tier colors
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)'),
            plt.Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold'),
            plt.Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold'),
            plt.Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold')
        ]
        
        # Place legend outside the plot
        ax.legend(handles=legend_elements, fontsize=9, 
                  loc='center left', bbox_to_anchor=(1.02, 0.5),
                  frameon=True, framealpha=0.9, edgecolor='black',
                  title='FIC Tiers', title_fontsize=10)
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |metric₁ - metric₂|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        
        # Save the figure with alphaF in the filename
        plt.savefig(os.path.join(output_dir, f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'), 
                    dpi=400, bbox_inches='tight')
        plt.close()
        
        print(f"Saved benchmarking tiers plot for alphaF={alphaF}")

# ============================================
# 5. ANALYSIS FUNCTIONS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    fic_results = fic_framework.analyze_fairness(baseline_metrics, 'accuracy')

    # FIC table
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("FIC ANALYSIS TABLE:")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis.csv'), index=False)

    # Tier classification
    tier_data = []
    print("TIER CLASSIFICATION:")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification.csv'), index=False)

    print("GENERATING VISUALIZATIONS...")
    plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}')
    plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}')

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY:")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)

    return compas_results

if __name__ == "__main__":
    # Check if dataset exists or download it
    compas_results = run_complete_analysis()

    print("\nAll analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print("Files include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables (CSV)")
    print("  - Tier classification (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps (PNG)")
    print("  - Benchmarking tiers for all alphaF values (PNG)")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET

CASE 1: COMPAS - Recidivism Risk Prediction
Loaded COMPAS dataset from local file
Original dataset shape: (7214, 53)
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540  0.4502  0.7781  0.6568  0.6443  0.7555
Caucasian           0.7976          0.1787  0.4531  0.9139  0.0861  0.5469  0.6397  0.8320  0.5305  0.8210
Hispanic            0.8305          0.1525  0.4615  0.9348  0.0652  0.5385  0.6667  0.8600  0.5455  0.8443
Other_Race   

In [37]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "compas_fic_results_BOLD"
os.makedirs(output_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset
    Source: https://github.com/propublica/compas-analysis
    """
    # Try to load from local file first
    try:
        compas_df = pd.read_csv("compas-scores-two-years.csv")
        print("Loaded COMPAS dataset from local file")
    except:
        # If local file doesn't exist, download from GitHub
        print("Downloading COMPAS dataset from GitHub...")
        import requests
        url = "https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv"
        response = requests.get(url)
        with open("compas-scores-two-years.csv", "wb") as f:
            f.write(response.content)
        compas_df = pd.read_csv("compas-scores-two-years.csv")
        print("COMPAS dataset downloaded and loaded")
    
    # Basic preprocessing
    print(f"Original dataset shape: {compas_df.shape}")
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC (unchanged)
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold the colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on the colorbar
    cbar.ax.text(1.5, 0.875, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.5, 0.625, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.5, 0.375, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=11, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.5, 0.125, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=12, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.8)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.8)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.8)

    plt.tight_layout(rect=[0, 0.03, 0.9, 0.95])
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), dpi=400, bbox_inches='tight')
    plt.close()

def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a figure with more width to accommodate legend
        fig, ax = plt.subplots(figsize=(16, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        
        # Customize axes with better labels
        ax.set_xlabel('Group Pairs', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set consistent y-axis limits
        ax.set_ylim(-0.25, 1.05)
        
        # Bold the y-axis tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend - moved to top right with fewer items
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create a separate legend for threshold lines
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend at upper left - FIXED: removed title_fontweight
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.02, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        # Make the legend title bold
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend at upper left below tier legend - FIXED: removed title_fontweight
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.02, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        # Make the legend title bold
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |r"$M_1 - M_2$"|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        
        # Save the figure with alphaF in the filename
        plt.savefig(os.path.join(output_dir, f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'), 
                    dpi=400, bbox_inches='tight')
        plt.close()
        
        print(f"Saved benchmarking tiers plot for alphaF={alphaF}")

# ============================================
# 5. ANALYSIS FUNCTIONS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    fic_results = fic_framework.analyze_fairness(baseline_metrics, 'accuracy')

    # FIC table
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("FIC ANALYSIS TABLE:")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis.csv'), index=False)

    # Tier classification
    tier_data = []
    print("TIER CLASSIFICATION:")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification.csv'), index=False)

    print("GENERATING VISUALIZATIONS...")
    plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}')
    plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}')

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY:")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)

    return compas_results

if __name__ == "__main__":
    # Check if dataset exists or download it
    compas_results = run_complete_analysis()

    print("\nAll analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print("Files include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables (CSV)")
    print("  - Tier classification (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps (PNG)")
    print("  - Benchmarking tiers for all alphaF values (PNG)")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET

CASE 1: COMPAS - Recidivism Risk Prediction
Loaded COMPAS dataset from local file
Original dataset shape: (7214, 53)
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540  0.4502  0.7781  0.6568  0.6443  0.7555
Caucasian           0.7976          0.1787  0.4531  0.9139  0.0861  0.5469  0.6397  0.8320  0.5305  0.8210
Hispanic            0.8305          0.1525  0.4615  0.9348  0.0652  0.5385  0.6667  0.8600  0.5455  0.8443
Other_Race   

In [None]:
#.... Compact and Well

'c:\\Users\\Dr. Akin\\OneDrive\\2025\\Paper_2025\\PHD_Work'

In [57]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "compas_fic_results_NLEGEND"
os.makedirs(output_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset from local folder
    """
    # Define your specific folder path
    data_folder = r'C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work'
    data_file = "compas-scores-two-years.csv"
    data_path = os.path.join(data_folder, data_file)
    
    print(f"Looking for COMPAS dataset at: {data_path}")
    
    # Try to load from your specified folder
    compas_df = pd.read_csv(data_path)
    print("Loaded COMPAS dataset from specified folder")
   
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.75, 0.15, 0.03, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold the colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on the colorbar
    cbar.ax.text(1.1, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.1, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.1, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.1, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.8)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.8)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.8)

    plt.tight_layout(rect=[0, 0.03, 0.8, 0.95])
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), dpi=400, bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a figure with more width to accommodate legend
        fig, ax = plt.subplots(figsize=(16, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find max positive and max negative values
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding (10% on positive side, 10% on negative side)
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure at least some range for visualization
        if y_max - y_min < 0.5:
            # If range is too small, center it around the data
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        
        # Customize axes with better labels
        ax.set_xlabel('Group Pairs', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set dynamic y-axis limits based on actual max positive and max negative - CHANGED
        ax.set_ylim(y_min, y_max)
        
        # Bold the y-axis tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend - moved to top right with fewer items
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create a separate legend for threshold lines
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend at upper left
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.02, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        # Make the legend title bold
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend at upper left below tier legend
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.02, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        # Make the legend title bold
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |M₁ - M₂|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        
        # Save the figure with alphaF in the filename
        plt.savefig(os.path.join(output_dir, f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'), 
                    dpi=400, bbox_inches='tight')
        plt.close()
        
        print(f"Saved benchmarking tiers plot for alphaF={alphaF}")


# ============================================
# 5. ANALYSIS FUNCTIONS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    fic_results = fic_framework.analyze_fairness(baseline_metrics, 'accuracy')

    # FIC table
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("FIC ANALYSIS TABLE:")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis.csv'), index=False)

    # Tier classification
    tier_data = []
    print("TIER CLASSIFICATION:")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification.csv'), index=False)

    print("GENERATING VISUALIZATIONS...")
    plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}')
    plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}')

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY:")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)

    return compas_results

if __name__ == "__main__":
    # Check if dataset exists or download it
    compas_results = run_complete_analysis()

    print("\nAll analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print("Files include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables (CSV)")
    print("  - Tier classification (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps (PNG)")
    print("  - Benchmarking tiers for all alphaF values (PNG)")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET

CASE 1: COMPAS - Recidivism Risk Prediction
Looking for COMPAS dataset at: C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work\compas-scores-two-years.csv
Loaded COMPAS dataset from specified folder
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540  0.4502  0.7781  0.6568  0.6443  0.7555
Caucasian           0.7976          0.1787  0.4531  0.9139  0.0861  0.5469  0.6397  0.8320  0.5305  0.8210
Hispanic            0.8305          0.

In [None]:
#..... COMPLETED WITH ALL METRICS PLOTS

In [58]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "compas_fic_results_NLEGEND_ALL_METRICS"
os.makedirs(output_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset from local folder
    """
    # Define your specific folder path
    data_folder = r'C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work'
    data_file = "compas-scores-two-years.csv"
    data_path = os.path.join(data_folder, data_file)
    
    print(f"Looking for COMPAS dataset at: {data_path}")
    
    # Try to load from your specified folder
    compas_df = pd.read_csv(data_path)
    print("Loaded COMPAS dataset from specified folder")
   
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS - UPDATED FOR ALL METRICS
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'{dataset_name}: FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.75, 0.15, 0.03, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold the colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on the colorbar
    cbar.ax.text(1.1, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.1, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.1, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.1, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.8)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.8)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.8)

    plt.tight_layout(rect=[0, 0.03, 0.8, 0.95])
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), dpi=400, bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a figure with more width to accommodate legend
        fig, ax = plt.subplots(figsize=(16, 8))
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find max positive and max negative values
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding (10% on positive side, 10% on negative side)
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure at least some range for visualization
        if y_max - y_min < 0.5:
            # If range is too small, center it around the data
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        
        # Customize axes with better labels
        ax.set_xlabel('Group Pairs', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{dataset_name}\nFIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set dynamic y-axis limits based on actual max positive and max negative
        ax.set_ylim(y_min, y_max)
        
        # Bold the y-axis tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend - moved to top right with fewer items
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create a separate legend for threshold lines
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend at upper left
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.02, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        # Make the legend title bold
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend at upper left below tier legend
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.02, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        # Make the legend title bold
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |M₁ - M₂|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        
        # Save the figure with alphaF in the filename
        plt.savefig(os.path.join(output_dir, f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'), 
                    dpi=400, bbox_inches='tight')
        plt.close()
        
        print(f"  Saved benchmarking tiers plot for alphaF={alphaF} ({metric})")

# ============================================
# 5. ANALYSIS FUNCTIONS - UPDATED FOR ALL METRICS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    print("\nGENERATING VISUALIZATIONS FOR ALL METRICS...")
    
    # List of all metrics to analyze
    all_metrics = ['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']
    
    # Dictionary to store all FIC results
    all_fic_results = {}
    
    for metric in all_metrics:
        print(f"\n{'='*60}")
        print(f"ANALYZING METRIC: {metric.upper()}")
        print(f"{'='*60}")
        
        # Analyze fairness for this metric
        fic_results = fic_framework.analyze_fairness(baseline_metrics, metric)
        all_fic_results[metric] = fic_results
        
        # Generate heatmaps for this metric
        plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Generate benchmarking tiers for this metric
        plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Print summary for this metric
        print(f"Summary for {metric}:")
        for af in fic_framework.alphaF_values:
            if af in fic_results and fic_results[af]:
                omegas = [d['omega'] for d in fic_results[af].values()]
                max_o = max(omegas)
                avg_o = np.mean(omegas)
                fic = FairnessInformationCriterion()
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                for d in fic_results[af].values():
                    tiers[fic.classify_tier(d['fic_score'])] += 1
                print(f"  αF={af}: ω_max={max_o:.4f}, ω_avg={avg_o:.4f}, Tiers={tiers}")

    # Store FIC results for accuracy (original metric) for backward compatibility
    fic_results = all_fic_results['accuracy']
    
    # FIC table for accuracy (original)
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("\nFIC ANALYSIS TABLE (Accuracy):")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis_accuracy.csv'), index=False)

    # Tier classification for accuracy (original)
    tier_data = []
    print("\nTIER CLASSIFICATION (Accuracy):")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification_accuracy.csv'), index=False)

    # Model comparison
    print("\nMODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'all_fic_results': all_fic_results,  # Store all metrics results
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY (Accuracy):")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)
    print(f"Generated plots for all metrics: accuracy, selection_rate, tpr, tnr, fpr, fnr, ppv, npv, f1, auc")
    print(f"Each metric has:")
    print(f"  - 1 heatmap figure (2x2 grid for all alphaF values)")
    print(f"  - 4 benchmarking tier plots (one for each alphaF: 0.05, 0.10, 0.15, 0.20)")

    return compas_results

if __name__ == "__main__":
    # Check if dataset exists or download it
    compas_results = run_complete_analysis()

    print("\nAll analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print("Files include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables for accuracy (CSV)")
    print("  - Tier classification for accuracy (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps for ALL 10 metrics (PNG)")
    print("  - Benchmarking tiers for ALL 10 metrics (4 plots per metric = 40 PNG files)")
    print(f"\nTotal plots generated: {10 + 40} files")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET

CASE 1: COMPAS - Recidivism Risk Prediction
Looking for COMPAS dataset at: C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work\compas-scores-two-years.csv
Loaded COMPAS dataset from specified folder
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540  0.4502  0.7781  0.6568  0.6443  0.7555
Caucasian           0.7976          0.1787  0.4531  0.9139  0.0861  0.5469  0.6397  0.8320  0.5305  0.8210
Hispanic            0.8305          0.

In [67]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "compas_fic_results_NLEGEND_ALL_METRICS_PDF"
os.makedirs(output_dir, exist_ok=True)

# Also create PDF subdirectory
pdf_dir = os.path.join(output_dir, "PDF_plots")
os.makedirs(pdf_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset from local folder
    """
    # Define your specific folder path
    data_folder = r'C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work'
    data_file = "compas-scores-two-years.csv"
    data_path = os.path.join(data_folder, data_file)
    
    print(f"Looking for COMPAS dataset at: {data_path}")
    
    # Try to load from your specified folder
    compas_df = pd.read_csv(data_path)
    print("Loaded COMPAS dataset from specified folder")
   
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. VISUALIZATIONS - UPDATED FOR PDF AND EXPANDED LEGEND
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.78, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold the colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on the colorbar with more space
    cbar.ax.text(1.6, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.6, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.6, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.6, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.6)

    plt.tight_layout(rect=[0, 0.03, 0.78, 0.95])
    
    # Save as PNG
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), 
                dpi=400, bbox_inches='tight')
    # Save as PDF
    plt.savefig(os.path.join(pdf_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.pdf'), 
                format='pdf', bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a figure with EXPANDED width to prevent legend cutoff
        fig, ax = plt.subplots(figsize=(20, 8))  # Increased width from 16 to 20
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find max positive and max negative values
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding (10% on positive side, 10% on negative side)
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure at least some range for visualization
        if y_max - y_min < 0.5:
            # If range is too small, center it around the data
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        
        # Customize axes with better labels
        ax.set_xlabel('Inter-Group', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'FIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set dynamic y-axis limits based on actual max positive and max negative
        ax.set_ylim(y_min, y_max)
        
        # Bold the y-axis tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend - moved to top right with fewer items
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)')
        ]
        
        
        # Create a separate legend for threshold lines
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend at upper left - MORE SPACE with bbox_to_anchor
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.05, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        # Make the legend title bold
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend at upper left below tier legend - MORE SPACE
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.05, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        # Make the legend title bold
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |$M₁ - M₂$|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend - MORE SPACE allocated
        plt.tight_layout(rect=[0, 0, 0.80, 1])  # Changed from 0.85 to 0.80 for more legend space
        
        # Save the figure with alphaF in the filename - BOTH PNG AND PDF
        png_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'
        pdf_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.pdf'
        
        plt.savefig(os.path.join(output_dir, png_filename), 
                    dpi=400, bbox_inches='tight')
        plt.savefig(os.path.join(pdf_dir, pdf_filename), 
                    format='pdf', bbox_inches='tight')
        plt.close()
        
        print(f"  Saved benchmarking tiers plot for alphaF={alphaF} ({metric})")


# ============================================
# 5. ANALYSIS FUNCTIONS - UPDATED FOR ALL METRICS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    metrics_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv'))

    print("\nGENERATING VISUALIZATIONS FOR ALL METRICS...")
    
    # List of all metrics to analyze
    all_metrics = ['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']
    
    # Dictionary to store all FIC results
    all_fic_results = {}
    
    for metric in all_metrics:
        print(f"\n{'='*60}")
        print(f"ANALYZING METRIC: {metric.upper()}")
        print(f"{'='*60}")
        
        # Analyze fairness for this metric
        fic_results = fic_framework.analyze_fairness(baseline_metrics, metric)
        all_fic_results[metric] = fic_results
        
        # Generate heatmaps for this metric
        plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Generate benchmarking tiers for this metric
        plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Print summary for this metric
        print(f"Summary for {metric}:")
        for af in fic_framework.alphaF_values:
            if af in fic_results and fic_results[af]:
                omegas = [d['omega'] for d in fic_results[af].values()]
                max_o = max(omegas)
                avg_o = np.mean(omegas)
                fic = FairnessInformationCriterion()
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                for d in fic_results[af].values():
                    tiers[fic.classify_tier(d['fic_score'])] += 1
                print(f"  αF={af}: ω_max={max_o:.4f}, ω_avg={avg_o:.4f}, Tiers={tiers}")

    # Store FIC results for accuracy (original metric) for backward compatibility
    fic_results = all_fic_results['accuracy']
    
    # FIC table for accuracy (original)
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"; row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("\nFIC ANALYSIS TABLE (Accuracy):")
    print(fic_df.to_string(index=False))
    fic_df.to_csv(os.path.join(output_dir, f'Case{case_number}_FIC_Analysis_accuracy.csv'), index=False)

    # Tier classification for accuracy (original)
    tier_data = []
    print("\nTIER CLASSIFICATION (Accuracy):")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    tier_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Tier_Classification_accuracy.csv'), index=False)

    # Model comparison
    print("\nMODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC alphaF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (alphaF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    comparison_df.to_csv(os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv'), index=False)

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'all_fic_results': all_fic_results,  # Store all metrics results
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df
    }

# ============================================
# 6. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY (Accuracy):")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - HIGH-QUALITY PLOTS SAVED")
    print("="*80)
    print(f"Generated plots for all metrics: accuracy, selection_rate, tpr, tnr, fpr, fnr, ppv, npv, f1, auc")
    print(f"Each metric has:")
    print(f"  - 1 heatmap figure (2x2 grid for all alphaF values)")
    print(f"  - 4 benchmarking tier plots (one for each alphaF: 0.05, 0.10, 0.15, 0.20)")
    print(f"\nAll plots saved in both PNG and PDF formats.")
    print(f"PNG files saved in: {output_dir}/")
    print(f"PDF files saved in: {pdf_dir}/")

    return compas_results

if __name__ == "__main__":
    # Check if dataset exists or download it
    compas_results = run_complete_analysis()

    print("\nAll analysis completed!")
    print(f"Results saved to: {output_dir}/")
    print(f"PDF files saved to: {pdf_dir}/")
    print("\nFiles include:")
    print("  - Group metrics (CSV)")
    print("  - FIC analysis tables for accuracy (CSV)")
    print("  - Tier classification for accuracy (CSV)")
    print("  - Model comparison (CSV)")
    print("  - FIC heatmaps for ALL 10 metrics (PNG + PDF)")
    print("  - Benchmarking tiers for ALL 10 metrics (4 plots per metric = 40 PNG + 40 PDF files)")
    print(f"Total plots generated: {10 + 40} PNG files + {10 + 40} PDF files = {100} total files")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET

CASE 1: COMPAS - Recidivism Risk Prediction
Looking for COMPAS dataset at: C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work\compas-scores-two-years.csv
Loaded COMPAS dataset from specified folder
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540  0.4502  0.7781  0.6568  0.6443  0.7555
Caucasian           0.7976          0.1787  0.4531  0.9139  0.0861  0.5469  0.6397  0.8320  0.5305  0.8210
Hispanic            0.8305          0.

In [None]:
#.... COMPLETED ALL METRICS PLOTS AND POINT FAIRNESS HYPOTHESIS

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.styles import PatternFill, Font, Alignment, Border, Side

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "Compas_NLEGEND_ALL_METRICS_PDF_EXCEL"
os.makedirs(output_dir, exist_ok=True)

# Also create PDF subdirectory
pdf_dir = os.path.join(output_dir, "PDF_plots")
os.makedirs(pdf_dir, exist_ok=True)

# Create Excel subdirectory
excel_dir = os.path.join(output_dir, "Excel_results")
os.makedirs(excel_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS COMPAS DATASET
# ============================================

def load_compas_data():
    """
    Load and preprocess COMPAS ProPublica dataset from local folder
    """
    # Define your specific folder path
    data_folder = r'C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work'
    data_file = "compas-scores-two-years.csv"
    data_path = os.path.join(data_folder, data_file)
    
    print(f"Looking for COMPAS dataset at: {data_path}")
    
    # Try to load from your specified folder
    compas_df = pd.read_csv(data_path)
    print("Loaded COMPAS dataset from specified folder")
   
    
    # Filter relevant columns
    relevant_columns = [
        'age', 'sex', 'race', 'priors_count', 'c_charge_degree',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
        'decile_score', 'two_year_recid'
    ]
    
    # Check which columns exist in the dataset
    available_columns = [col for col in relevant_columns if col in compas_df.columns]
    compas_df = compas_df[available_columns].copy()
    
    # Drop rows with missing values
    compas_df = compas_df.dropna()
    
    # Create high_risk target: 0-5 as low risk, 6-10 as high risk
    compas_df['high_risk'] = (compas_df['decile_score'] >= 6).astype(int)
    
    # Consolidate race categories
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'african' in race or 'black' in race:
            return 'African_American'
        elif 'caucasian' in race or 'white' in race:
            return 'Caucasian'
        elif 'hispanic' in race or 'latino' in race:
            return 'Hispanic'
        elif 'asian' in race or 'arab' in race or 'native' in race or 'other' in race:
            return 'Other_Race'
        else:
            return 'Other_Race'
    
    compas_df['race_group'] = compas_df['race'].apply(consolidate_race)
    
    # Filter to keep only our target race groups
    target_races = ['African_American', 'Caucasian', 'Hispanic', 'Other_Race']
    compas_df = compas_df[compas_df['race_group'].isin(target_races)].copy()
    
    # Create additional features for better prediction
    compas_df['total_juvenile_charges'] = compas_df['juv_fel_count'] + compas_df['juv_misd_count'] + compas_df['juv_other_count']
    compas_df['is_felony'] = (compas_df['c_charge_degree'] == 'F').astype(int)
    compas_df['age_group'] = pd.cut(compas_df['age'], 
                                     bins=[0, 25, 35, 45, 55, 100],
                                     labels=['18-25', '26-35', '36-45', '46-55', '56+'])
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_group', 'priors_count', 'is_felony',
        'total_juvenile_charges', 'age_group', 'high_risk'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in compas_df.columns]
    compas_df = compas_df[final_columns]
    
    print(f"Processed dataset shape: {compas_df.shape}")
    print(f"Target distribution (high_risk):")
    print(compas_df['high_risk'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(compas_df['race_group'].value_counts(normalize=True))
    
    return compas_df

def generate_compas_data(n_samples=None):
    """
    Wrapper function to load COMPAS data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_compas_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'selection_rate': (tp + fp) / len(y_true),
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first'), categorical_cols)
    ])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000)
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0)
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0)
    else:
        model = LogisticRegression(random_state=42, max_iter=1000)

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_metrics[group] = compute_all_metrics(y_test[mask], y_pred[mask], y_prob[mask])

    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. EXCEL EXPORT FUNCTIONS
# ============================================

def save_to_excel_with_formatting(data_dict, filename, sheet_name_prefix="Case1"):
    """
    Save multiple dataframes to Excel with formatting
    """
    excel_path = os.path.join(excel_dir, filename)
    
    with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
        # Save each dataframe to a separate sheet
        for sheet_name, df in data_dict.items():
            # Create full sheet name
            full_sheet_name = f"{sheet_name_prefix}_{sheet_name}" if sheet_name_prefix else sheet_name
            
            # Truncate sheet name if too long (Excel limit is 31 characters)
            if len(full_sheet_name) > 31:
                full_sheet_name = full_sheet_name[:31]
            
            # Write dataframe to Excel
            df.to_excel(writer, sheet_name=full_sheet_name, index=False)
            
            # Get the worksheet for formatting
            worksheet = writer.sheets[full_sheet_name]
            
            # Apply formatting
            for column in worksheet.columns:
                max_length = 0
                column_letter = column[0].column_letter
                for cell in column:
                    try:
                        if len(str(cell.value)) > max_length:
                            max_length = len(str(cell.value))
                    except:
                        pass
                adjusted_width = min(max_length + 2, 50)
                worksheet.column_dimensions[column_letter].width = adjusted_width
            
            # Freeze the first row
            worksheet.freeze_panes = 'A2'
    
    print(f"  Saved Excel file: {filename}")
    return excel_path

def create_comprehensive_excel_report(compas_results, all_fic_results, metrics_list):
    """
    Create a comprehensive Excel report with all numerical values
    """
    print("\n" + "="*80)
    print("CREATING COMPREHENSIVE EXCEL REPORT")
    print("="*80)
    
    data_dict = {}
    
    # 1. Group Metrics Table
    print("1. Saving Group Metrics Table...")
    data_dict['Group_Metrics'] = compas_results['metrics_df'].reset_index().rename(columns={'index': 'Race_Group'})
    
    # 2. FIC Analysis Tables for all metrics
    print("2. Saving FIC Analysis Tables for all metrics...")
    for metric in metrics_list:
        fic_results = all_fic_results[metric]
        
        # Create comprehensive FIC table for this metric
        fic_table = []
        for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
            row = {'Group_Pair': pair}
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and pair in fic_results[af]:
                    d = fic_results[af][pair]
                    row[f'omega_alphaF_{af}'] = d['omega']
                    row[f'FIC_alphaF_{af}'] = d['fic_score']
                    row[f'Tier_alphaF_{af}'] = d['tier']
                    row[f'Metric1_{af}'] = d['metric1']
                    row[f'Metric2_{af}'] = d['metric2']
                else:
                    row[f'omega_alphaF_{af}'] = np.nan
                    row[f'FIC_alphaF_{af}'] = np.nan
                    row[f'Tier_alphaF_{af}'] = "N/A"
                    row[f'Metric1_{af}'] = np.nan
                    row[f'Metric2_{af}'] = np.nan
            fic_table.append(row)
        
        fic_df = pd.DataFrame(fic_table)
        data_dict[f'FIC_Analysis_{metric}'] = fic_df
    
    # 3. Tier Classification Summary for all metrics
    print("3. Saving Tier Classification Summary for all metrics...")
    for metric in metrics_list:
        fic_results = all_fic_results[metric]
        
        tier_summary = []
        for af in [0.05, 0.10, 0.15, 0.20]:
            if af in fic_results and fic_results[af]:
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                fic_framework = FairnessInformationCriterion()
                for d in fic_results[af].values():
                    tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                
                total_pairs = sum(tiers.values())
                for tier_name, count in tiers.items():
                    tier_summary.append({
                        'Metric': metric,
                        'alphaF': af,
                        'Tier': tier_name,
                        'Count': count,
                        'Percentage': (count / total_pairs * 100) if total_pairs > 0 else 0
                    })
        
        tier_summary_df = pd.DataFrame(tier_summary)
        data_dict[f'Tier_Summary_{metric}'] = tier_summary_df
    
    # 4. Benchmarking Tiers Numerical Values
    print("4. Saving Benchmarking Tiers Numerical Values...")
    for metric in metrics_list:
        fic_results = all_fic_results[metric]
        
        benchmark_data = []
        for af in [0.05, 0.10, 0.15, 0.20]:
            if af in fic_results and fic_results[af]:
                for pair, d in fic_results[af].items():
                    benchmark_data.append({
                        'Metric': metric,
                        'alphaF': af,
                        'Group_Pair': pair,
                        'FIC_Score': d['fic_score'],
                        'Tier': d['tier'],
                        'omega': d['omega'],
                        'Metric_Value_Group1': d['metric1'],
                        'Metric_Value_Group2': d['metric2']
                    })
        
        benchmark_df = pd.DataFrame(benchmark_data)
        data_dict[f'Benchmark_Tiers_{metric}'] = benchmark_df
    
    # 5. Summary Statistics for each metric
    print("5. Saving Summary Statistics for each metric...")
    summary_stats = []
    for metric in metrics_list:
        fic_results = all_fic_results[metric]
        
        for af in [0.05, 0.10, 0.15, 0.20]:
            if af in fic_results and fic_results[af]:
                fic_scores = [d['fic_score'] for d in fic_results[af].values()]
                omegas = [d['omega'] for d in fic_results[af].values()]
                
                summary_stats.append({
                    'Metric': metric,
                    'alphaF': af,
                    'FIC_Mean': np.mean(fic_scores),
                    'FIC_Std': np.std(fic_scores),
                    'FIC_Min': np.min(fic_scores),
                    'FIC_Max': np.max(fic_scores),
                    'omega_Mean': np.mean(omegas),
                    'omega_Std': np.std(omegas),
                    'omega_Min': np.min(omegas),
                    'omega_Max': np.max(omegas),
                    'Num_Pairs': len(fic_scores)
                })
    
    summary_df = pd.DataFrame(summary_stats)
    data_dict['Summary_Statistics'] = summary_df
    
    # 6. Model Comparison
    print("6. Saving Model Comparison...")
    data_dict['Model_Comparison'] = compas_results['comparison_df']
    
    # 7. Dataset Statistics
    print("7. Saving Dataset Statistics...")
    data = compas_results['data']
    dataset_stats = pd.DataFrame({
        'Statistic': ['Total_Samples', 'High_Risk_Proportion', 
                      'African_American_Count', 'Caucasian_Count',
                      'Hispanic_Count', 'Other_Race_Count',
                      'Male_Count', 'Female_Count'],
        'Value': [
            len(data),
            data['high_risk'].mean(),
            (data['race_group'] == 'African_American').sum(),
            (data['race_group'] == 'Caucasian').sum(),
            (data['race_group'] == 'Hispanic').sum(),
            (data['race_group'] == 'Other_Race').sum(),
            (data['sex'] == 'Male').sum() if 'sex' in data.columns else np.nan,
            (data['sex'] == 'Female').sum() if 'sex' in data.columns else np.nan
        ]
    })
    data_dict['Dataset_Statistics'] = dataset_stats
    
    # 8. Fairness Assessment Matrix
    print("8. Creating Fairness Assessment Matrix...")
    fairness_matrix = []
    for metric in metrics_list:
        fic_results = all_fic_results[metric]
        
        for af in [0.05, 0.10, 0.15, 0.20]:
            if af in fic_results and fic_results[af]:
                # Count pairs in each tier
                fic_framework = FairnessInformationCriterion()
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                for d in fic_results[af].values():
                    tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                
                # Determine overall fairness status
                if tiers['Unacceptable'] > 0:
                    overall_status = "Unfair"
                elif tiers['Questionable'] > 0:
                    overall_status = "Questionable"
                elif tiers['Acceptable'] > 0:
                    overall_status = "Acceptable"
                else:
                    overall_status = "Optimum"
                
                fairness_matrix.append({
                    'Metric': metric,
                    'alphaF': af,
                    'Overall_Fairness': overall_status,
                    'Optimum_Pairs': tiers['Optimum'],
                    'Acceptable_Pairs': tiers['Acceptable'],
                    'Questionable_Pairs': tiers['Questionable'],
                    'Unacceptable_Pairs': tiers['Unacceptable'],
                    'Total_Pairs': sum(tiers.values())
                })
    
    fairness_df = pd.DataFrame(fairness_matrix)
    data_dict['Fairness_Assessment'] = fairness_df
    
    # Save all to Excel
    excel_file = save_to_excel_with_formatting(data_dict, "COMPAS_FIC_Complete_Analysis.xlsx", "COMPAS")
    
    print(f"\n✓ Excel report saved: {excel_file}")
    print(f"  Total sheets: {len(data_dict)}")
    
    # Print sheet names
    print("\nExcel sheets created:")
    for i, sheet_name in enumerate(data_dict.keys(), 1):
        print(f"  {i:2d}. {sheet_name}")
    
    return excel_file

# ============================================
# 5. VISUALIZATIONS - UPDATED FOR PDF AND EXPANDED LEGEND
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.78, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold the colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on the colorbar with more space
    cbar.ax.text(1.6, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.6, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.6, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.6, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.6)

    plt.tight_layout(rect=[0, 0.03, 0.78, 0.95])
    
    # Save as PNG
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), 
                dpi=400, bbox_inches='tight')
    # Save as PDF
    plt.savefig(os.path.join(pdf_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.pdf'), 
                format='pdf', bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a figure with EXPANDED width to prevent legend cutoff
        fig, ax = plt.subplots(figsize=(20, 8))  # Increased width from 16 to 20
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find max positive and max negative values
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding (10% on positive side, 10% on negative side)
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure at least some range for visualization
        if y_max - y_min < 0.5:
            # If range is too small, center it around the data
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        
        # Customize axes with better labels
        ax.set_xlabel('Inter-Group', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'FIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set dynamic y-axis limits based on actual max positive and max negative
        ax.set_ylim(y_min, y_max)
        
        # Bold the y-axis tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend - moved to top right with fewer items
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create a separate legend for threshold lines
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend at upper left - MORE SPACE with bbox_to_anchor
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.05, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        # Make the legend title bold
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend at upper left below tier legend - MORE SPACE
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.05, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        # Make the legend title bold
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |$M₁ - M₂$|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend - MORE SPACE allocated
        plt.tight_layout(rect=[0, 0, 0.80, 1])  # Changed from 0.85 to 0.80 for more legend space
        
        # Save the figure with alphaF in the filename - BOTH PNG AND PDF
        png_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'
        pdf_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.pdf'
        
        plt.savefig(os.path.join(output_dir, png_filename), 
                    dpi=400, bbox_inches='tight')
        plt.savefig(os.path.join(pdf_dir, pdf_filename), 
                    format='pdf', bbox_inches='tight')
        plt.close()
        
        print(f"  Saved benchmarking tiers plot for alphaF={alphaF} ({metric})")

# ============================================
# 6. ANALYSIS FUNCTIONS - UPDATED FOR ALL METRICS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    
    # Save metrics to CSV
    metrics_csv_path = os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv')
    metrics_df.to_csv(metrics_csv_path)
    print(f"✓ Group metrics saved to: {metrics_csv_path}")

    print("\nGENERATING VISUALIZATIONS FOR ALL METRICS...")
    
    # List of all metrics to analyze
    all_metrics = ['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']
    
    # Dictionary to store all FIC results
    all_fic_results = {}
    
    # Dictionary to store metric summaries
    metric_summaries = {}
    
    for metric in all_metrics:
        print(f"\n{'='*60}")
        print(f"ANALYZING METRIC: {metric.upper()}")
        print(f"{'='*60}")
        
        # Analyze fairness for this metric
        fic_results = fic_framework.analyze_fairness(baseline_metrics, metric)
        all_fic_results[metric] = fic_results
        
        # Generate heatmaps for this metric
        plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Generate benchmarking tiers for this metric
        plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Store summary for this metric
        metric_summary = {}
        for af in fic_framework.alphaF_values:
            if af in fic_results and fic_results[af]:
                omegas = [d['omega'] for d in fic_results[af].values()]
                fic_scores = [d['fic_score'] for d in fic_results[af].values()]
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                fic = FairnessInformationCriterion()
                for d in fic_results[af].values():
                    tiers[fic.classify_tier(d['fic_score'])] += 1
                
                metric_summary[f'alphaF_{af}'] = {
                    'omega_max': max(omegas),
                    'omega_avg': np.mean(omegas),
                    'omega_min': min(omegas),
                    'fic_max': max(fic_scores),
                    'fic_avg': np.mean(fic_scores),
                    'fic_min': min(fic_scores),
                    'tiers': tiers
                }
        
        metric_summaries[metric] = metric_summary
        
        # Print summary for this metric
        print(f"Summary for {metric}:")
        for af in fic_framework.alphaF_values:
            if af in metric_summary:
                summary = metric_summary[f'alphaF_{af}']
                print(f"  αF={af}: ω_max={summary['omega_max']:.4f}, ω_avg={summary['omega_avg']:.4f}, "
                      f"FIC_avg={summary['fic_avg']:.3f}, Tiers={summary['tiers']}")

    # Store FIC results for accuracy (original metric) for backward compatibility
    fic_results = all_fic_results['accuracy']
    
    # FIC table for accuracy (original)
    fic_table = []
    for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
        row = {'Group Pair': pair}
        for af in fic_framework.alphaF_values:
            if af in fic_results and pair in fic_results[af]:
                d = fic_results[af][pair]
                row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
            else:
                row[f'alphaF={af}'] = "N/A"
                row[f'Hypothesis alphaF={af}'] = "N/A"
        fic_table.append(row)
    fic_df = pd.DataFrame(fic_table)
    print("\nFIC ANALYSIS TABLE (Accuracy):")
    print(fic_df.to_string(index=False))
    
    # Save FIC analysis to CSV
    fic_csv_path = os.path.join(output_dir, f'Case{case_number}_FIC_Analysis_accuracy.csv')
    fic_df.to_csv(fic_csv_path, index=False)
    print(f"✓ FIC analysis saved to: {fic_csv_path}")

    # Tier classification for accuracy (original)
    tier_data = []
    print("\nTIER CLASSIFICATION (Accuracy):")
    for af in fic_framework.alphaF_values:
        print(f"\nFor αF = {af}:")
        print("-" * 50)
        if af in fic_results:
            for pair, d in fic_results[af].items():
                tier = fic_framework.classify_tier(d['fic_score'])
                msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
    tier_df = pd.DataFrame(tier_data)
    
    # Save tier classification to CSV
    tier_csv_path = os.path.join(output_dir, f'Case{case_number}_Tier_Classification_accuracy.csv')
    tier_df.to_csv(tier_csv_path, index=False)
    print(f"✓ Tier classification saved to: {tier_csv_path}")

    # Model comparison
    print("\nMODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
        avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
        _, y_test, _, y_pred, _ = test_data
        acc = accuracy_score(y_test, y_pred)
        comparison.append({
            'Model': mt.upper(),
            'Overall Accuracy': f"{acc:.4f}",
            'Avg FIC (αF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
            'ω_max (αF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
        })
    comparison_df = pd.DataFrame(comparison)
    print(comparison_df.to_string(index=False))
    
    # Save model comparison to CSV
    comparison_csv_path = os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv')
    comparison_df.to_csv(comparison_csv_path, index=False)
    print(f"✓ Model comparison saved to: {comparison_csv_path}")

    # Create comprehensive Excel report
    excel_file = create_comprehensive_excel_report(
        {
            'data': data,
            'baseline_metrics': baseline_metrics,
            'fic_results': fic_results,
            'all_fic_results': all_fic_results,
            'metrics_df': metrics_df,
            'fic_df': fic_df,
            'tier_df': tier_df,
            'comparison_df': comparison_df
        },
        all_fic_results,
        all_metrics
    )

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'all_fic_results': all_fic_results,  # Store all metrics results
        'metrics_df': metrics_df,
        'fic_df': fic_df,
        'tier_df': tier_df,
        'comparison_df': comparison_df,
        'excel_file': excel_file,
        'metric_summaries': metric_summaries
    }

# ============================================
# 7. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET")
    print("="*80)
    print(f"Output directory: {output_dir}")
    print(f"PDF directory: {pdf_dir}")
    print(f"Excel directory: {excel_dir}")

    compas_results = analyze_dataset(
        dataset_name="COMPAS - Recidivism Risk Prediction",
        data_generator=lambda: generate_compas_data(8000),
        target_col='high_risk',
        protected_col='race_group',
        case_number=1
    )

    print("\n" + "="*80)
    print("SUMMARY REPORT - COMPAS DATASET")
    print("="*80)

    print("COMPAS DATASET KEY FINDINGS:")
    print("-" * 60)
    data = compas_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High risk proportion: {data['high_risk'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_group'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh risk by race group:")
    for race in sorted(data['race_group'].unique()):
        subset = data[data['race_group'] == race]
        risk_prop = subset['high_risk'].mean()
        print(f"  {race}: {risk_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY (Accuracy):")
    print("-" * 60)
    for af in [0.05, 0.10, 0.15, 0.20]:
        if af in compas_results['fic_results'] and compas_results['fic_results'][af]:
            items = list(compas_results['fic_results'][af].items())
            max_o = max(d['omega'] for _, d in items)
            min_o = min(d['omega'] for _, d in items)
            avg_o = np.mean([d['omega'] for _, d in items])
            worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
            best_pair = min(items, key=lambda x: x[1]['omega'])[0]
            print(f"alphaF={af}:")
            print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
            print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
            print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
            
            # Tier distribution
            fic = FairnessInformationCriterion()
            tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
            for d in compas_results['fic_results'][af].values():
                tiers[fic.classify_tier(d['fic_score'])] += 1
            print(f"  Tier distribution: {tiers}")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - ALL RESULTS SAVED")
    
    return compas_results

if __name__ == "__main__":
    # Run the complete analysis
    compas_results = run_complete_analysis()

    print("\n" + "="*80)
    print("ALL ANALYSIS COMPLETED SUCCESSFULLY!")


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - COMPAS DATASET
Output directory: Compas_NLEGEND_ALL_METRICS_PDF_EXCEL
PDF directory: Compas_NLEGEND_ALL_METRICS_PDF_EXCEL\PDF_plots
Excel directory: Compas_NLEGEND_ALL_METRICS_PDF_EXCEL\Excel_results

CASE 1: COMPAS - Recidivism Risk Prediction
Looking for COMPAS dataset at: C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work\compas-scores-two-years.csv
Loaded COMPAS dataset from specified folder
Processed dataset shape: (7214, 8)
Target distribution (high_risk):
high_risk
0    0.634599
1    0.365401
Name: proportion, dtype: float64

Race group distribution:
race_group
African_American    0.512337
Caucasian           0.340172
Hispanic            0.088301
Other_Race          0.059190
Name: proportion, dtype: float64
GROUP METRICS TABLE (Baseline Logistic Regression):
                  accuracy  selection_rate     tpr     tnr     fpr     fnr     ppv     npv      f1     auc
African_American    0.6993          0.3501  0.5498  0.8460  0.1540 

In [None]:
#...... CREDIT SCORING ASSESSMENT

In [74]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import warnings
import os
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.styles import PatternFill, Font, Alignment, Border, Side

warnings.filterwarnings('ignore')

# Create output directory
output_dir = "Adult_NLEGEND_ALL_METRICS_PDF_EXCEL2"
os.makedirs(output_dir, exist_ok=True)

# Also create PDF subdirectory
pdf_dir = os.path.join(output_dir, "PDF_plots")
os.makedirs(pdf_dir, exist_ok=True)

# Create Excel subdirectory
excel_dir = os.path.join(output_dir, "Excel_results")
os.makedirs(excel_dir, exist_ok=True)

# Set style for publication quality
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Global font settings for consistency
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

# ============================================
# 1. LOAD AND PREPROCESS ADULT DATASET
# ============================================

def load_adult_data():
    """
    Load and preprocess Adult dataset from UCI
    """
    # Define your specific folder path
    data_folder = r'C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work'
    data_file = "adult.csv"
    data_path = os.path.join(data_folder, data_file)
    
    print(f"Looking for Adult dataset at: {data_path}")
    
    # Try to load from your specified folder
    adult_df = pd.read_csv(data_path)
    print("Loaded Adult dataset from specified folder")
    
    # Check column names (Adult dataset might have spaces or different names)
    # Standardize column names
    adult_df.columns = [col.strip().replace('-', '_').lower() for col in adult_df.columns]
    
    # Show available columns
    print(f"Available columns: {adult_df.columns.tolist()}")
    print(f"Initial dataset shape: {adult_df.shape}")
    
    # Drop rows with missing values (indicated by '?' in Adult dataset)
    adult_df = adult_df.replace('?', np.nan)
    adult_df = adult_df.dropna()
    
    # Clean string columns (remove extra spaces)
    string_columns = adult_df.select_dtypes(include=['object']).columns
    for col in string_columns:
        adult_df[col] = adult_df[col].str.strip()
    
    # Create binary target: >50K as high income (1), <=50K as low income (0)
    adult_df['high_income'] = adult_df['income'].apply(lambda x: 1 if '>50K' in str(x) else 0)
    
    # Consolidate race categories as specified
    def consolidate_race(race):
        race = str(race).strip().lower()
        if 'white' in race:
            return 'White'
        elif 'black' in race:
            return 'Black'
        elif 'asian' in race or 'pac' in race or 'islander' in race:
            return 'Asian_Pac_Islander'
        elif 'indian' in race or 'eskimo' in race or 'aleut' in race:
            return 'Amer_Indian_Eskimo'
        else:
            return 'Other'
    
    adult_df['race_group'] = adult_df['race'].apply(consolidate_race)
    
    # Combine Asian-Pac-Islander and Amer-Indian-Eskimo as specified
    def combine_race_groups(race_group):
        if race_group in ['Asian_Pac_Islander', 'Amer_Indian_Eskimo']:
            return 'APAI' # Asian_Pacific_Amer_Indian
        else:
            return race_group
    
    adult_df['race_combined'] = adult_df['race_group'].apply(combine_race_groups)
    
    # Filter to keep only our target race groups
    target_races = ['White', 'Black', 'APAI', 'Other']
    adult_df = adult_df[adult_df['race_combined'].isin(target_races)].copy()
    
    # Check and ensure minimum sample size for each race group
    print("\nRace group distribution before filtering:")
    race_counts = adult_df['race_combined'].value_counts()
    print(race_counts)
    
    # Remove groups with too few samples
    min_samples = 100  # Minimum samples per group
    valid_races = race_counts[race_counts >= min_samples].index.tolist()
    adult_df = adult_df[adult_df['race_combined'].isin(valid_races)].copy()
    
    print(f"\nKeeping race groups with at least {min_samples} samples: {valid_races}")
    
    # Create additional features for better prediction
    # Create age groups
    adult_df['age_group'] = pd.cut(adult_df['age'], 
                                   bins=[0, 25, 35, 45, 55, 65, 100],
                                   labels=['18-25', '26-35', '36-45', '46-55', '56-65', '66+'])
    
    # Create education level groups
    def categorize_education(edu):
        edu_str = str(edu).lower()
        if any(x in edu_str for x in ['preschool', '1st', '2nd', '3rd', '4th', '5th', '6th', '7th', '8th', '9th']):
            return 'Primary'
        elif any(x in edu_str for x in ['10th', '11th', '12th', 'hs']):
            return 'High_School'
        elif any(x in edu_str for x in ['some', 'assoc']):
            return 'Some_College'
        elif any(x in edu_str for x in ['bachelors']):
            return 'Bachelors'
        elif any(x in edu_str for x in ['masters', 'prof', 'doctorate']):
            return 'Graduate'
        else:
            return 'Other'
    
    adult_df['education_level'] = adult_df['education'].apply(categorize_education)
    
    # Create work hours category
    adult_df['work_hours_category'] = pd.cut(adult_df['hours_per_week'],
                                            bins=[0, 20, 40, 60, 168],
                                            labels=['Part_Time', 'Full_Time', 'Overtime', 'Excessive'])
    
    # Create capital gain/loss indicator
    adult_df['has_capital_gain'] = (adult_df['capital_gain'] > 0).astype(int)
    adult_df['has_capital_loss'] = (adult_df['capital_loss'] > 0).astype(int)
    
    # Select final columns for analysis
    final_columns = [
        'age', 'sex', 'race_combined', 'education_num', 'marital_status',
        'occupation', 'relationship', 'hours_per_week', 'workclass',
        'education_level', 'age_group', 'work_hours_category',
        'has_capital_gain', 'has_capital_loss', 'high_income'
    ]
    
    # Ensure all columns exist
    final_columns = [col for col in final_columns if col in adult_df.columns]
    adult_df = adult_df[final_columns]
    
    # Handle rare categories by combining them into 'Other'
    categorical_cols = adult_df.select_dtypes(include=['object']).columns
    for col in categorical_cols:
        # For columns with many unique values, combine rare categories
        if adult_df[col].nunique() > 10:
            value_counts = adult_df[col].value_counts()
            rare_categories = value_counts[value_counts < 50].index
            adult_df[col] = adult_df[col].apply(lambda x: 'Other' if x in rare_categories else x)
    
    print(f"\nProcessed dataset shape: {adult_df.shape}")
    print(f"Target distribution (high_income):")
    print(adult_df['high_income'].value_counts(normalize=True))
    print(f"\nRace group distribution:")
    print(adult_df['race_combined'].value_counts(normalize=True))
    
    return adult_df

def generate_adult_data(n_samples=None):
    """
    Wrapper function to load Adult data
    n_samples parameter is kept for compatibility but not used
    """
    data = load_adult_data()
    
    # If n_samples is specified and smaller than dataset, sample it
    if n_samples and n_samples < len(data):
        data = data.sample(n=n_samples, random_state=42)
    
    return data

# ============================================
# 2-3. MODEL & FIC (WITH FIX FOR SMALL GROUPS)
# ============================================

def compute_all_metrics(y_true, y_pred, y_prob):
    """
    Compute comprehensive performance metrics with edge case handling
    """
    # Handle edge cases where confusion matrix might not have all 4 values
    cm = confusion_matrix(y_true, y_pred)
    
    # Get unique classes
    classes = np.unique(y_true)
    
    if len(classes) == 2:
        # Binary classification with both classes present
        tn, fp, fn, tp = cm.ravel()
    elif len(classes) == 1:
        # Only one class present (all 0s or all 1s)
        if classes[0] == 0:
            # All negatives
            tn, fp, fn, tp = cm[0, 0], 0, 0, 0
        else:
            # All positives
            tn, fp, fn, tp = 0, 0, 0, cm[0, 0]
    else:
        # Should not happen for binary classification
        tn, fp, fn, tp = 0, 0, 0, 0
    
    # Calculate metrics with edge case protection
    n = len(y_true)
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred) if n > 0 else np.nan,
        'selection_rate': (tp + fp) / n if n > 0 else np.nan,
        'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
        'fnr': fn / (tp + fn) if (tp + fn) > 0 else 0,
        'ppv': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'npv': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'f1': f1_score(y_true, y_pred, zero_division=0) if n > 0 else np.nan,
        'auc': roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan
    }
    return metrics

def train_and_evaluate_models(data, target_col, protected_col, model_type='baseline'):
    X = data.drop(columns=[target_col, protected_col])
    y = data[target_col]
    categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
    numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    # Use handle_unknown='ignore' to handle unseen categories in test set
    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first', handle_unknown='ignore'), categorical_cols)
    ])

    # Use larger test size to ensure each group has enough samples
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42, stratify=y)
    protected_test = data.loc[X_test.index, protected_col]

    X_train_processed = preprocessor.fit_transform(X_train)
    X_test_processed = preprocessor.transform(X_test)

    if model_type == 'baseline':
        model = LogisticRegression(random_state=42, max_iter=1000, class_weight='balanced')
    elif model_type == 'l1':
        model = LogisticRegression(penalty='l1', solver='liblinear', random_state=42, max_iter=1000, C=1.0, class_weight='balanced')
    elif model_type == 'l2':
        model = LogisticRegression(penalty='l2', random_state=42, max_iter=1000, C=1.0, class_weight='balanced')
    else:
        model = LogisticRegression(random_state=42, max_iter=1000, class_weight='balanced')

    model.fit(X_train_processed, y_train)
    y_pred = model.predict(X_test_processed)
    y_prob = model.predict_proba(X_test_processed)[:, 1]

    group_metrics = {}
    for group in protected_test.unique():
        mask = protected_test == group
        if mask.sum() > 0:
            group_y_true = y_test[mask]
            group_y_pred = y_pred[mask]
            group_y_prob = y_prob[mask]
            
            # Check if group has enough samples and both classes
            if len(group_y_true) >= 10:  # Minimum samples for reliable metrics
                group_metrics[group] = compute_all_metrics(group_y_true, group_y_pred, group_y_prob)
            else:
                print(f"  Warning: Group {group} has only {len(group_y_true)} samples in test set - skipping")
    
    return group_metrics, (X_test, y_test, protected_test, y_pred, y_prob)

class FairnessInformationCriterion:
    def __init__(self, alphaF_values=[0.05, 0.10, 0.15, 0.20]):
        self.alphaF_values = alphaF_values

    def compute_omega(self, metric1, metric2):
        return abs(metric1 - metric2)

    def compute_fic(self, omega, alphaF):
        return 1 - (omega / alphaF)

    def classify_tier(self, fic_score):
        if fic_score > 0.75:
            return "Optimum"
        elif fic_score > 0.50:
            return "Acceptable"
        elif fic_score > 0:
            return "Questionable"
        else:
            return "Unacceptable"

    def analyze_fairness(self, group_metrics, metric_name='accuracy'):
        results = {}
        groups = list(group_metrics.keys())
        for alphaF in self.alphaF_values:
            results[alphaF] = {}
            for i, g1 in enumerate(groups):
                for g2 in groups[i+1:]:
                    pair = f"{g1} - {g2}"
                    m1 = group_metrics[g1].get(metric_name, np.nan)
                    m2 = group_metrics[g2].get(metric_name, np.nan)
                    
                    # Only compute if both metrics are valid numbers
                    if not np.isnan(m1) and not np.isnan(m2):
                        omega = self.compute_omega(m1, m2)
                        fic_score = self.compute_fic(omega, alphaF)
                        tier = self.classify_tier(fic_score)
                        results[alphaF][pair] = {
                            'omega': omega, 'fic_score': fic_score, 'tier': tier,
                            'metric1': m1, 'metric2': m2
                        }
        return results

# ============================================
# 4. EXCEL EXPORT FUNCTIONS (UNCHANGED)
# ============================================

def save_to_excel_with_formatting(data_dict, filename, sheet_name_prefix="Case1"):
    """
    Save multiple dataframes to Excel with formatting
    """
    excel_path = os.path.join(excel_dir, filename)
    
    with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
        # Save each dataframe to a separate sheet
        for sheet_name, df in data_dict.items():
            # Create full sheet name
            full_sheet_name = f"{sheet_name_prefix}_{sheet_name}" if sheet_name_prefix else sheet_name
            
            # Truncate sheet name if too long (Excel limit is 31 characters)
            if len(full_sheet_name) > 31:
                full_sheet_name = full_sheet_name[:31]
            
            # Write dataframe to Excel
            df.to_excel(writer, sheet_name=full_sheet_name, index=False)
            
            # Get the worksheet for formatting
            worksheet = writer.sheets[full_sheet_name]
            
            # Apply formatting
            for column in worksheet.columns:
                max_length = 0
                column_letter = column[0].column_letter
                for cell in column:
                    try:
                        if len(str(cell.value)) > max_length:
                            max_length = len(str(cell.value))
                    except:
                        pass
                adjusted_width = min(max_length + 2, 50)
                worksheet.column_dimensions[column_letter].width = adjusted_width
            
            # Freeze the first row
            worksheet.freeze_panes = 'A2'
    
    print(f"  Saved Excel file: {filename}")
    return excel_path

def create_comprehensive_excel_report(adult_results, all_fic_results, metrics_list):
    """
    Create a comprehensive Excel report with all numerical values
    """
    print("\n" + "="*80)
    print("CREATING COMPREHENSIVE EXCEL REPORT")
    print("="*80)
    
    data_dict = {}
    
    # 1. Group Metrics Table
    print("1. Saving Group Metrics Table...")
    data_dict['Group_Metrics'] = adult_results['metrics_df'].reset_index().rename(columns={'index': 'Race_Group'})
    
    # 2. FIC Analysis Tables for all metrics
    print("2. Saving FIC Analysis Tables for all metrics...")
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            # Create comprehensive FIC table for this metric
            fic_table = []
            pairs = list(set(p for a in fic_results.values() for p in a.keys()))
            
            for pair in sorted(pairs):
                row = {'Group_Pair': pair}
                for af in [0.05, 0.10, 0.15, 0.20]:
                    if af in fic_results and pair in fic_results[af]:
                        d = fic_results[af][pair]
                        row[f'omega_alphaF_{af}'] = d['omega']
                        row[f'FIC_alphaF_{af}'] = d['fic_score']
                        row[f'Tier_alphaF_{af}'] = d['tier']
                        row[f'Metric1_{af}'] = d['metric1']
                        row[f'Metric2_{af}'] = d['metric2']
                    else:
                        row[f'omega_alphaF_{af}'] = np.nan
                        row[f'FIC_alphaF_{af}'] = np.nan
                        row[f'Tier_alphaF_{af}'] = "N/A"
                        row[f'Metric1_{af}'] = np.nan
                        row[f'Metric2_{af}'] = np.nan
                fic_table.append(row)
            
            if fic_table:
                fic_df = pd.DataFrame(fic_table)
                data_dict[f'FIC_Analysis_{metric}'] = fic_df
    
    # 3. Tier Classification Summary for all metrics
    print("3. Saving Tier Classification Summary for all metrics...")
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            tier_summary = []
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                    fic_framework = FairnessInformationCriterion()
                    for d in fic_results[af].values():
                        tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                    
                    total_pairs = sum(tiers.values())
                    for tier_name, count in tiers.items():
                        tier_summary.append({
                            'Metric': metric,
                            'alphaF': af,
                            'Tier': tier_name,
                            'Count': count,
                            'Percentage': (count / total_pairs * 100) if total_pairs > 0 else 0
                        })
            
            if tier_summary:
                tier_summary_df = pd.DataFrame(tier_summary)
                data_dict[f'Tier_Summary_{metric}'] = tier_summary_df
    
    # 4. Benchmarking Tiers Numerical Values
    print("4. Saving Benchmarking Tiers Numerical Values...")
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            benchmark_data = []
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    for pair, d in fic_results[af].items():
                        benchmark_data.append({
                            'Metric': metric,
                            'alphaF': af,
                            'Group_Pair': pair,
                            'FIC_Score': d['fic_score'],
                            'Tier': d['tier'],
                            'omega': d['omega'],
                            'Metric_Value_Group1': d['metric1'],
                            'Metric_Value_Group2': d['metric2']
                        })
            
            if benchmark_data:
                benchmark_df = pd.DataFrame(benchmark_data)
                data_dict[f'Benchmark_Tiers_{metric}'] = benchmark_df
    
    # 5. Summary Statistics for each metric
    print("5. Saving Summary Statistics for each metric...")
    summary_stats = []
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    fic_scores = [d['fic_score'] for d in fic_results[af].values()]
                    omegas = [d['omega'] for d in fic_results[af].values()]
                    
                    summary_stats.append({
                        'Metric': metric,
                        'alphaF': af,
                        'FIC_Mean': np.mean(fic_scores),
                        'FIC_Std': np.std(fic_scores),
                        'FIC_Min': np.min(fic_scores),
                        'FIC_Max': np.max(fic_scores),
                        'omega_Mean': np.mean(omegas),
                        'omega_Std': np.std(omegas),
                        'omega_Min': np.min(omegas),
                        'omega_Max': np.max(omegas),
                        'Num_Pairs': len(fic_scores)
                    })
    
    if summary_stats:
        summary_df = pd.DataFrame(summary_stats)
        data_dict['Summary_Statistics'] = summary_df
    
    # 6. Model Comparison
    print("6. Saving Model Comparison...")
    if 'comparison_df' in adult_results:
        data_dict['Model_Comparison'] = adult_results['comparison_df']
    
    # 7. Dataset Statistics
    print("7. Saving Dataset Statistics...")
    data = adult_results['data']
    dataset_stats_data = []
    
    # Basic statistics
    dataset_stats_data.append({
        'Statistic': 'Total_Samples',
        'Value': len(data)
    })
    
    # Target variable statistics
    target_col = 'high_income'
    dataset_stats_data.append({
        'Statistic': f'{target_col.capitalize()}_Proportion',
        'Value': data[target_col].mean()
    })
    
    # Protected attribute statistics
    protected_col = 'race_combined'
    for group in data[protected_col].unique():
        count = (data[protected_col] == group).sum()
        proportion = count / len(data)
        dataset_stats_data.append({
            'Statistic': f'{protected_col.capitalize()}_{group}_Count',
            'Value': count
        })
        dataset_stats_data.append({
            'Statistic': f'{protected_col.capitalize()}_{group}_Proportion',
            'Value': proportion
        })
    
    dataset_stats_df = pd.DataFrame(dataset_stats_data)
    data_dict['Dataset_Statistics'] = dataset_stats_df
    
    # 8. Fairness Assessment Matrix
    print("8. Creating Fairness Assessment Matrix...")
    fairness_matrix = []
    for metric in metrics_list:
        if metric in all_fic_results:
            fic_results = all_fic_results[metric]
            
            for af in [0.05, 0.10, 0.15, 0.20]:
                if af in fic_results and fic_results[af]:
                    # Count pairs in each tier
                    fic_framework = FairnessInformationCriterion()
                    tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                    for d in fic_results[af].values():
                        tiers[fic_framework.classify_tier(d['fic_score'])] += 1
                    
                    # Determine overall fairness status
                    if tiers['Unacceptable'] > 0:
                        overall_status = "Unfair"
                    elif tiers['Questionable'] > 0:
                        overall_status = "Questionable"
                    elif tiers['Acceptable'] > 0:
                        overall_status = "Acceptable"
                    else:
                        overall_status = "Optimum"
                    
                    fairness_matrix.append({
                        'Metric': metric,
                        'alphaF': af,
                        'Overall_Fairness': overall_status,
                        'Optimum_Pairs': tiers['Optimum'],
                        'Acceptable_Pairs': tiers['Acceptable'],
                        'Questionable_Pairs': tiers['Questionable'],
                        'Unacceptable_Pairs': tiers['Unacceptable'],
                        'Total_Pairs': sum(tiers.values())
                    })
    
    if fairness_matrix:
        fairness_df = pd.DataFrame(fairness_matrix)
        data_dict['Fairness_Assessment'] = fairness_df
    
    # Save all to Excel
    excel_filename = "ADULT_FIC_Complete_Analysis.xlsx"
    excel_file = save_to_excel_with_formatting(data_dict, excel_filename, "ADULT")
    
    print(f"\n✓ Excel report saved: {excel_file}")
    print(f"  Total sheets: {len(data_dict)}")
    
    # Print sheet names
    print("\nExcel sheets created:")
    for i, sheet_name in enumerate(data_dict.keys(), 1):
        print(f"  {i:2d}. {sheet_name}")
    
    return excel_file

# ============================================
# 5. VISUALIZATIONS - UPDATED FOR PDF AND EXPANDED LEGEND
# ============================================

def plot_fic_heatmaps(fic_results, dataset_name, metric='accuracy'):
    alphaF_values = sorted(fic_results.keys())
    if not alphaF_values:
        return

    pairs = list(fic_results[alphaF_values[0]].keys())
    all_groups = sorted(set(g for p in pairs for g in p.split(' - ')))

    # Larger figure
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle(f'FIC Heatmaps for Different alphaF Values ({metric})',
                 fontsize=20, fontweight='bold', y=0.98)

    axes = axes.flatten()

    for idx, alphaF in enumerate(alphaF_values):
        ax = axes[idx]
        n = len(all_groups)
        mat = np.full((n, n), np.nan)
        group_idx = {g: i for i, g in enumerate(all_groups)}

        for pair, d in fic_results[alphaF].items():
            g1, g2 = pair.split(' - ')
            i, j = group_idx[g1], group_idx[g2]
            mat[i, j] = mat[j, i] = d['fic_score']

        im = ax.imshow(mat, cmap='RdYlGn', vmin=-1, vmax=1, aspect='equal')

        # Add value labels inside cells
        for i in range(n):
            for j in range(n):
                if i != j and not np.isnan(mat[i, j]):
                    text = ax.text(j, i, f'{mat[i,j]:.2f}',
                                   ha='center', va='center',
                                   fontsize=14, fontweight='bold',
                                   color='white' if abs(mat[i,j]) > 0.5 else 'black')

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(all_groups, rotation=45, ha='right', fontsize=13, fontweight='bold')
        ax.set_yticklabels(all_groups, fontsize=13, fontweight='bold')
        ax.set_title(f'αF = {alphaF}', fontsize=18, fontweight='bold', pad=20)
        
        # Add grid
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

    # Add a single comprehensive colorbar with tier labels
    cbar_ax = fig.add_axes([0.78, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('FIC Score', fontsize=14, fontweight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=12)
    
    # Bold the colorbar tick labels
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    
    # Add tier annotations on the colorbar with more space
    cbar.ax.text(1.6, 0.90, 'Optimum', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkgreen')
    cbar.ax.text(1.6, 0.60, 'Acceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='goldenrod')
    cbar.ax.text(1.6, 0.350, 'Questionable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkorange')
    cbar.ax.text(1.6, 0.100, 'Unacceptable', transform=cbar.ax.transAxes, 
                 fontsize=14, fontweight='bold', va='center', ha='left', color='darkred')
    
    # Add tier threshold lines on colorbar
    cbar.ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=3, xmax=0.6)
    cbar.ax.axhline(0.00, color='darkred', linestyle='--', linewidth=3, xmax=0.6)

    plt.tight_layout(rect=[0, 0.03, 0.78, 0.95])
    
    # Save as PNG
    plt.savefig(os.path.join(output_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.png'), 
                dpi=400, bbox_inches='tight')
    # Save as PDF
    plt.savefig(os.path.join(pdf_dir, f'{dataset_name}_FIC_Heatmaps_{metric}.pdf'), 
                format='pdf', bbox_inches='tight')
    plt.close()


def plot_benchmarking_tiers(fic_results, dataset_name, metric='accuracy'):
    # Sort alphaF values to ensure consistent order
    alphaF_values = sorted(fic_results.keys())
    
    # Define colors for tiers
    colors = {'Optimum': '#2E8B57', 'Acceptable': '#FFD700', 
              'Questionable': '#FF8C00', 'Unacceptable': '#DC143C'}
    
    for alphaF in alphaF_values:
        if alphaF not in fic_results or not fic_results[alphaF]:
            print(f"No data for alphaF={alphaF} in benchmarking tiers")
            continue
        
        # Create a figure with EXPANDED width to prevent legend cutoff
        fig, ax = plt.subplots(figsize=(20, 8))  # Increased width from 16 to 20
        
        data = fic_results[alphaF]
        pairs = list(data.keys())
        fic_scores = [data[p]['fic_score'] for p in pairs]
        tiers = [data[p]['tier'] for p in pairs]
        
        # Find max positive and max negative values
        max_positive = max(fic_scores) if fic_scores else 1.0
        min_negative = min(fic_scores) if fic_scores else -0.25
        
        # Add padding (10% on positive side, 10% on negative side)
        y_max = max_positive * 1.10 if max_positive > 0 else 0.10
        y_min = min_negative * 1.10 if min_negative < 0 else -0.10
        
        # Ensure at least some range for visualization
        if y_max - y_min < 0.5:
            # If range is too small, center it around the data
            center = (max_positive + min_negative) / 2
            y_max = center + 0.25
            y_min = center - 0.25
        
        # Create bar colors based on tiers
        bar_colors = [colors[t] for t in tiers]
        
        # Create bars with smaller width for more compact look
        bars = ax.bar(range(len(pairs)), fic_scores, color=bar_colors, 
                      edgecolor='black', linewidth=1.2, width=0.6)
        
        # Add tier threshold lines with better styling
        ax.axhline(0.75, color='darkgreen', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.50, color='goldenrod', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        ax.axhline(0.00, color='darkred', linestyle='--', linewidth=2.0, 
                   alpha=0.7)
        
        # Customize axes with better labels
        ax.set_xlabel('Inter-Group', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('FIC Score', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'FIC Benchmarking Tiers ({metric}, αF = {alphaF})',
                    fontsize=16, fontweight='bold', pad=15)
        
        # Set x-ticks with rotation for readability
        ax.set_xticks(range(len(pairs)))
        ax.set_xticklabels(pairs, rotation=45, ha='right', fontsize=11, fontweight='bold')
        
        # Set dynamic y-axis limits based on actual max positive and max negative
        ax.set_ylim(y_min, y_max)
        
        # Bold the y-axis tick labels
        y_ticks = ax.get_yticks()
        ax.set_yticklabels([f'{tick:.2f}' for tick in y_ticks], fontsize=11, fontweight='bold')
        
        # Add grid with lighter style
        ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
        ax.grid(True, axis='x', alpha=0.1, linestyle='-', linewidth=0.5)
        
        # Add better legend - moved to top right with fewer items
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=colors['Optimum'], edgecolor='black', label='Optimum (FIC > 0.75)'),
            Patch(facecolor=colors['Acceptable'], edgecolor='black', label='Acceptable (0.50 < FIC ≤ 0.75)'),
            Patch(facecolor=colors['Questionable'], edgecolor='black', label='Questionable (0 < FIC ≤ 0.50)'),
            Patch(facecolor=colors['Unacceptable'], edgecolor='black', label='Unacceptable (FIC ≤ 0)')
        ]
        
        # Create a separate legend for threshold lines
        from matplotlib.lines import Line2D
        line_legend_elements = [
            Line2D([0], [0], color='darkgreen', linestyle='--', linewidth=2, label='Optimum Threshold (0.75)'),
            Line2D([0], [0], color='goldenrod', linestyle='--', linewidth=2, label='Acceptable Threshold (0.50)'),
            Line2D([0], [0], color='darkred', linestyle='--', linewidth=2, label='Unacceptable Threshold (0.00)')
        ]
        
        # Place tier legend at upper left - MORE SPACE with bbox_to_anchor
        tier_legend = ax.legend(handles=legend_elements, fontsize=10, 
                                loc='upper left', bbox_to_anchor=(1.05, 1.0),
                                frameon=True, framealpha=0.9, edgecolor='black',
                                title='FIC Tiers', title_fontsize=11)
        # Make the legend title bold
        tier_legend.get_title().set_fontweight('bold')
        ax.add_artist(tier_legend)
        
        # Place threshold legend at upper left below tier legend - MORE SPACE
        threshold_legend = ax.legend(handles=line_legend_elements, fontsize=9, 
                                     loc='upper left', bbox_to_anchor=(1.05, 0.65),
                                     frameon=True, framealpha=0.9, edgecolor='black',
                                     title='Thresholds', title_fontsize=10)
        # Make the legend title bold
        threshold_legend.get_title().set_fontweight('bold')
        
        # Add annotation for alphaF interpretation
        annotation_text = f'αF = {alphaF}\nFIC = 1 - (ω/αF)\nω = |$M₁ - M₂$|'
        ax.text(0.02, 0.98, annotation_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        # Adjust layout to make room for legend - MORE SPACE allocated
        plt.tight_layout(rect=[0, 0, 0.80, 1])  # Changed from 0.85 to 0.80 for more legend space
        
        # Save the figure with alphaF in the filename - BOTH PNG AND PDF
        png_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.png'
        pdf_filename = f'{dataset_name}_Benchmarking_Tiers_alphaF_{alphaF}_{metric}.pdf'
        
        plt.savefig(os.path.join(output_dir, png_filename), 
                    dpi=400, bbox_inches='tight')
        plt.savefig(os.path.join(pdf_dir, pdf_filename), 
                    format='pdf', bbox_inches='tight')
        plt.close()
        
        print(f"  Saved benchmarking tiers plot for alphaF={alphaF} ({metric})")

# ============================================
# 6. ANALYSIS FUNCTIONS - UPDATED FOR ALL METRICS
# ============================================

def analyze_dataset(dataset_name, data_generator, target_col, protected_col, case_number=1, model_types=['baseline', 'l1', 'l2']):
    print(f"\n{'='*80}")
    print(f"CASE {case_number}: {dataset_name}")
    print(f"{'='*80}")

    data = data_generator()
    fic_framework = FairnessInformationCriterion()

    baseline_metrics, _ = train_and_evaluate_models(data, target_col, protected_col, 'baseline')
    
    # Check if we have valid metrics
    if not baseline_metrics:
        print("Warning: No valid group metrics computed. Dataset may be too small or imbalanced.")
        return None

    metrics_df = pd.DataFrame.from_dict(baseline_metrics, orient='index')
    metrics_df = metrics_df[['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']]
    print("GROUP METRICS TABLE (Baseline Logistic Regression):")
    print(metrics_df.round(4).to_string())
    
    # Save metrics to CSV
    metrics_csv_path = os.path.join(output_dir, f'Case{case_number}_Group_Metrics.csv')
    metrics_df.to_csv(metrics_csv_path)
    print(f"Group metrics saved to: {metrics_csv_path}")

    print("GENERATING VISUALIZATIONS FOR ALL METRICS...")
    
    # List of all metrics to analyze
    all_metrics = ['accuracy', 'selection_rate', 'tpr', 'tnr', 'fpr', 'fnr', 'ppv', 'npv', 'f1', 'auc']
    
    # Dictionary to store all FIC results
    all_fic_results = {}
    
    # Dictionary to store metric summaries
    metric_summaries = {}
    
    for metric in all_metrics:
        print(f"\n{'='*60}")
        print(f"ANALYZING METRIC: {metric.upper()}")
        print(f"{'='*60}")
        
        # Analyze fairness for this metric
        fic_results = fic_framework.analyze_fairness(baseline_metrics, metric)
        all_fic_results[metric] = fic_results
        
        # Only generate visualizations if we have results
        if fic_results and any(fic_results.values()):
            # Generate heatmaps for this metric
            plot_fic_heatmaps(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
            
            # Generate benchmarking tiers for this metric
            plot_benchmarking_tiers(fic_results, f'Case{case_number}_{dataset_name}_{metric}', metric)
        
        # Store summary for this metric
        metric_summary = {}
        for af in fic_framework.alphaF_values:
            if af in fic_results and fic_results[af]:
                omegas = [d['omega'] for d in fic_results[af].values()]
                fic_scores = [d['fic_score'] for d in fic_results[af].values()]
                tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                fic = FairnessInformationCriterion()
                for d in fic_results[af].values():
                    tiers[fic.classify_tier(d['fic_score'])] += 1
                
                metric_summary[f'alphaF_{af}'] = {
                    'omega_max': max(omegas) if omegas else np.nan,
                    'omega_avg': np.mean(omegas) if omegas else np.nan,
                    'omega_min': min(omegas) if omegas else np.nan,
                    'fic_max': max(fic_scores) if fic_scores else np.nan,
                    'fic_avg': np.mean(fic_scores) if fic_scores else np.nan,
                    'fic_min': min(fic_scores) if fic_scores else np.nan,
                    'tiers': tiers
                }
        
        metric_summaries[metric] = metric_summary
        
        # Print summary for this metric
        print(f"Summary for {metric}:")
        for af in fic_framework.alphaF_values:
            if af in metric_summary:
                summary = metric_summary[f'alphaF_{af}']
                print(f"  αF={af}: ω_max={summary['omega_max']:.4f}, ω_avg={summary['omega_avg']:.4f}, "
                      f"FIC_avg={summary['fic_avg']:.3f}, Tiers={summary['tiers']}")

    # Store FIC results for accuracy (original metric) for backward compatibility
    fic_results = all_fic_results.get('accuracy', {})
    
    if fic_results:
        # FIC table for accuracy (original)
        fic_table = []
        for pair in sorted(set(p for a in fic_results.values() for p in a.keys())):
            row = {'Group Pair': pair}
            for af in fic_framework.alphaF_values:
                if af in fic_results and pair in fic_results[af]:
                    d = fic_results[af][pair]
                    row[f'alphaF={af}'] = f"omega={d['omega']:.4f}, FIC={d['fic_score']:.3f}"
                    row[f'Hypothesis alphaF={af}'] = "Fail to reject Ho (Fair)" if d['omega'] <= af else "Reject H₀ (Unfair)"
                else:
                    row[f'alphaF={af}'] = "N/A"
                    row[f'Hypothesis alphaF={af}'] = "N/A"
            fic_table.append(row)
        fic_df = pd.DataFrame(fic_table)
        print("FIC ANALYSIS TABLE (Accuracy):")
        print(fic_df.to_string(index=False))
        
        # Save FIC analysis to CSV
        fic_csv_path = os.path.join(output_dir, f'Case{case_number}_FIC_Analysis_accuracy.csv')
        fic_df.to_csv(fic_csv_path, index=False)
        print(f"FIC analysis saved to: {fic_csv_path}")

        # Tier classification for accuracy
        tier_data = []
        print("TIER CLASSIFICATION (Accuracy):")
        for af in fic_framework.alphaF_values:
            print(f"\nFor αF = {af}:")
            print("-" * 50)
            if af in fic_results:
                for pair, d in fic_results[af].items():
                    tier = fic_framework.classify_tier(d['fic_score'])
                    msg = tier if d['fic_score'] <= 0.75 else f"{tier} (omega_max < {0.25*af:.4f})"
                    print(f"{pair}: ω={d['omega']:.4f}, FIC={d['fic_score']:.3f} → {msg}")
                    tier_data.append({'alphaF': af, 'Group Pair': pair, 'ω': d['omega'], 'FIC': d['fic_score'], 'Tier': tier})
        
        tier_df = pd.DataFrame(tier_data)
        
        # Save tier classification to CSV
        tier_csv_path = os.path.join(output_dir, f'Case{case_number}_Tier_Classification_accuracy.csv')
        tier_df.to_csv(tier_csv_path, index=False)
        print(f"Tier classification saved to: {tier_csv_path}")
    else:
        print("No FIC results for accuracy metric - skipping FIC analysis table")

    # Model comparison
    print("MODEL COMPARISON:")
    comparison = []
    for mt in model_types:
        mets, test_data = train_and_evaluate_models(data, target_col, protected_col, mt)
        if mets:  # Only if we got valid metrics
            model_fic = fic_framework.analyze_fairness(mets, 'accuracy')
            avg_fic = np.mean([d['fic_score'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
            max_omega = max([d['omega'] for d in model_fic[0.10].values()]) if 0.10 in model_fic and model_fic[0.10] else np.nan
            _, y_test, _, y_pred, _ = test_data
            acc = accuracy_score(y_test, y_pred)
            comparison.append({
                'Model': mt.upper(),
                'Overall Accuracy': f"{acc:.4f}",
                'Avg FIC (αF=0.10)': f"{avg_fic:.3f}" if not np.isnan(avg_fic) else "N/A",
                'ω_max (αF=0.10)': f"{max_omega:.4f}" if not np.isnan(max_omega) else "N/A"
            })
    
    if comparison:
        comparison_df = pd.DataFrame(comparison)
        print(comparison_df.to_string(index=False))
        
        # Save model comparison to CSV
        comparison_csv_path = os.path.join(output_dir, f'Case{case_number}_Model_Comparison.csv')
        comparison_df.to_csv(comparison_csv_path, index=False)
        print(f"✓ Model comparison saved to: {comparison_csv_path}")
    else:
        comparison_df = pd.DataFrame()

    # Create comprehensive Excel report
    excel_file = create_comprehensive_excel_report(
        {
            'data': data,
            'baseline_metrics': baseline_metrics,
            'fic_results': fic_results,
            'all_fic_results': all_fic_results,
            'metrics_df': metrics_df,
            'comparison_df': comparison_df
        },
        all_fic_results,
        all_metrics
    )

    return {
        'data': data,
        'baseline_metrics': baseline_metrics,
        'fic_results': fic_results,
        'all_fic_results': all_fic_results,  # Store all metrics results
        'metrics_df': metrics_df,
        'comparison_df': comparison_df,
        'excel_file': excel_file,
        'metric_summaries': metric_summaries
    }

# ============================================
# 7. MAIN ANALYSIS
# ============================================

def run_complete_analysis():
    print("\n" + "="*80)
    print("FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - ADULT DATASET")
    print("="*80)
    print(f"Output directory: {output_dir}")
    print(f"PDF directory: {pdf_dir}")
    print(f"Excel directory: {excel_dir}")

    adult_results = analyze_dataset(
        dataset_name="ADULT - Income Prediction",
        data_generator=lambda: generate_adult_data(8000),
        target_col='high_income',
        protected_col='race_combined',
        case_number=1
    )
    
    if adult_results is None:
        print("\nAnalysis failed. Check dataset and try again.")
        return None

    print("\n" + "="*80)
    print("SUMMARY REPORT - ADULT DATASET")
    print("="*80)

    print("ADULT DATASET KEY FINDINGS:")
    print("-" * 60)
    data = adult_results['data']
    print(f"Total samples: {len(data)}")
    print(f"High income proportion (>50K): {data['high_income'].mean():.3f}")
    print("\nRace group distribution:")
    race_dist = data['race_combined'].value_counts()
    for race, count in race_dist.items():
        prop = count / len(data)
        print(f"  {race}: {count} ({prop:.3f})")
    
    print("\nHigh income by race group:")
    for race in sorted(data['race_combined'].unique()):
        subset = data[data['race_combined'] == race]
        income_prop = subset['high_income'].mean()
        print(f"  {race}: {income_prop:.3f}")

    print("\nFIC ANALYSIS SUMMARY (Accuracy):")
    print("-" * 60)
    fic_results = adult_results.get('fic_results', {})
    if fic_results:
        for af in [0.05, 0.10, 0.15, 0.20]:
            if af in fic_results and fic_results[af]:
                items = list(fic_results[af].items())
                if items:
                    max_o = max(d['omega'] for _, d in items)
                    min_o = min(d['omega'] for _, d in items)
                    avg_o = np.mean([d['omega'] for _, d in items])
                    worst_pair = max(items, key=lambda x: x[1]['omega'])[0]
                    best_pair = min(items, key=lambda x: x[1]['omega'])[0]
                    print(f"alphaF={af}:")
                    print(f"  omega range: [{min_o:.4f}, {max_o:.4f}], avg: {avg_o:.4f}")
                    print(f"  Most unfair pair: {worst_pair} (ω={max_o:.4f})")
                    print(f"  Most fair pair: {best_pair} (ω={min_o:.4f})")
                    
                    # Tier distribution
                    fic = FairnessInformationCriterion()
                    tiers = {'Optimum': 0, 'Acceptable': 0, 'Questionable': 0, 'Unacceptable': 0}
                    for d in fic_results[af].values():
                        tiers[fic.classify_tier(d['fic_score'])] += 1
                    print(f"  Tier distribution: {tiers}")
    else:
        print("No FIC results available.")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - ALL RESULTS SAVED")
    
    return adult_results

if __name__ == "__main__":
    # Run the complete analysis
    adult_results = run_complete_analysis()
    
    if adult_results:
        print("\n" + "="*80)


FAIRNESS INFORMATION CRITERION (FIC) ANALYSIS - ADULT DATASET
Output directory: Adult_NLEGEND_ALL_METRICS_PDF_EXCEL2
PDF directory: Adult_NLEGEND_ALL_METRICS_PDF_EXCEL2\PDF_plots
Excel directory: Adult_NLEGEND_ALL_METRICS_PDF_EXCEL2\Excel_results

CASE 1: ADULT - Income Prediction
Looking for Adult dataset at: C:\Users\Dr. Akin\OneDrive\2025\Paper_2025\PHD_Work\adult.csv
Loaded Adult dataset from specified folder
Available columns: ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country', 'income']
Initial dataset shape: (48842, 15)

Race group distribution before filtering:
race_combined
White    41762
Black     4685
APAI      1989
Other      406
Name: count, dtype: int64

Keeping race groups with at least 100 samples: ['White', 'Black', 'APAI', 'Other']

Processed dataset shape: (48842, 15)
Target distribution (high_income):
high_income
0    0.760718


In [None]:
# Execute the script
#python adult_fic_analysis_fixed2.py