# Looking at the Balanced Dataset (3D Stratified)

This notebook analyzes the perfectly balanced dataset created with 3D stratified sampling:
- Emotion × Race × Gender (4 emotions × 3 races × 2 genders = 24 strata)
- Each stratum has equal representation among the original (non-augmented) images

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Plot defaults shared by every figure in this notebook
sns.set_style("whitegrid")                        # seaborn white-grid theme
plt.rcParams.update({'figure.figsize': (10, 6)})  # default figure size (inches)

In [None]:
# Read the balanced dataset CSV, report its size, and preview the first rows
dataset_csv = '../FER-New-Dataset/dataset_balanced_3d.csv'
df = pd.read_csv(dataset_csv)

n_images = len(df)
print(f"Total images in balanced dataset: {n_images:,}")

df.head()  # last expression of the cell -> rendered as its output

In [None]:
# Extract dataset source from image_path: the third '/'-separated component.
# The vectorised .str accessor yields NaN for paths with fewer than three
# components and for missing paths; both are labelled 'unknown' (the previous
# per-row lambda split every path twice and crashed on a missing path).
df['dataset_source'] = df['image_path'].str.split('/').str[2].fillna('unknown')

# Filter only original images for fairness analysis.
# The mask is built once and reused for both the filter and the augmented count.
is_original = df['augmented'] == 'original'
df_orig = df[is_original].copy()

print(f"\nOriginal (non-augmented) images: {len(df_orig):,}")
print(f"Augmented images: {int((~is_original).sum()):,}")
print("\nAugmentation types in dataset:")
print(df['augmented'].value_counts())

## Overview: Perfect Balance Verification

In [None]:
print("="*70)
print("PERFECT BALANCE VERIFICATION (Original Images Only)")
print("="*70)

# The 4 x 3 x 2 = 24 strata the dataset is expected to contain
emotions = ['anger', 'fear', 'calm', 'surprise']
races = ['Caucasian', 'Asian', 'African-American']
genders = ['male', 'female']

# Count all strata in a single groupby pass instead of 24 boolean scans.
# from_product keeps the emotion -> race -> gender nesting order, and
# reindex(fill_value=0) makes strata with no rows appear explicitly as 0.
expected_strata = pd.MultiIndex.from_product(
    [emotions, races, genders], names=['emotion', 'race', 'gender']
)
observed_sizes = df_orig.groupby(['emotion', 'race', 'gender']).size()
strata_counts = observed_sizes.reindex(expected_strata, fill_value=0).tolist()

print("\nAll 24 Strata Counts:")
for (emotion, race, gender), count in zip(expected_strata, strata_counts):
    stratum_name = f"{emotion}_{race}_{gender}"
    print(f"  {stratum_name:50s} : {count:4d}")

# Rows whose labels are missing or fall outside the expected lists belong to
# none of the 24 strata; they must not be silently ignored.
unaccounted = len(df_orig) - sum(strata_counts)
distinct_counts = set(strata_counts)

# Balance is only confirmed when every stratum has the same NON-ZERO count and
# every original image was assigned to a stratum. (24 empty strata are equal
# too, but that signals a label mismatch, not a balanced dataset.)
if len(distinct_counts) == 1 and strata_counts[0] > 0 and unaccounted == 0:
    print(f"\n✓ PERFECT BALANCE CONFIRMED: All 24 strata have exactly {strata_counts[0]} images!")
else:
    print(f"\n✗ Imbalance detected: {distinct_counts}")
    if unaccounted:
        print(f"  {unaccounted} original image(s) fall outside the 24 expected strata")

## Emotion Distribution

In [None]:
# Emotion distribution: original images (left) vs. all images incl. augmented (right)
emotion_order = ['anger', 'fear', 'calm', 'surprise']
emotion_colors = ['#e74c3c', '#9b59b6', '#3498db', '#f39c12']

# reindex() instead of a [...] lookup: an emotion that is absent from the data
# is plotted as 0 rather than raising KeyError.
counts = df_orig['emotion'].value_counts().reindex(emotion_order, fill_value=0)
counts_all = df['emotion'].value_counts().reindex(emotion_order, fill_value=0)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, counts, total rows, title, vertical offset of the count label)
panels = [
    (axes[0], counts, len(df_orig), 'Emotion Distribution (Original Images)', 10),
    (axes[1], counts_all, len(df), 'Emotion Distribution (All Images)', 100),
]

for ax, panel_counts, total, title, label_offset in panels:
    ax.bar(panel_counts.index, panel_counts.values, color=emotion_colors)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('Count', fontsize=12)
    # Reference line: the count each emotion would have under an equal split
    ax.axhline(y=total / 4, color='red', linestyle='--', alpha=0.5, label='Perfect Balance')

    for i, v in enumerate(panel_counts.values):
        # Absolute count above the bar, percentage share inside it
        ax.text(i, v + label_offset, str(v), ha='center', fontweight='bold')
        pct = (v / total) * 100
        ax.text(i, v / 2, f'{pct:.1f}%', ha='center', fontsize=10, color='white', fontweight='bold')

    ax.legend()

plt.tight_layout()
plt.show()

print("\nEmotion counts (original):")
print(counts)
print(f"\nExpected per emotion: {len(df_orig)/4}")

## Race Distribution

In [None]:
# Race distribution: original images (left) vs. all images incl. augmented (right)
race_order = ['Caucasian', 'Asian', 'African-American']
colors = ['#3498db', '#e67e22', '#2ecc71']  # same order as race_order

# reindex() instead of a [...] lookup: a race that is absent from the data
# is plotted as 0 rather than raising KeyError.
counts = df_orig['race'].value_counts().reindex(race_order, fill_value=0)
counts_all = df['race'].value_counts().reindex(race_order, fill_value=0)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, counts, total rows, title, vertical offset of the count label)
panels = [
    (axes[0], counts, len(df_orig), 'Race Distribution (Original Images)', 10),
    (axes[1], counts_all, len(df), 'Race Distribution (All Images)', 100),
]

for ax, panel_counts, total, title, label_offset in panels:
    ax.bar(panel_counts.index, panel_counts.values, color=colors)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('Count', fontsize=12)
    # Reference line: the count each race would have under an equal split
    ax.axhline(y=total / 3, color='red', linestyle='--', alpha=0.5, label='Perfect Balance')
    ax.tick_params(axis='x', rotation=45)

    for i, v in enumerate(panel_counts.values):
        # Absolute count above the bar, percentage share inside it
        ax.text(i, v + label_offset, str(v), ha='center', fontweight='bold')
        pct = (v / total) * 100
        ax.text(i, v / 2, f'{pct:.1f}%', ha='center', fontsize=10, color='white', fontweight='bold')

    ax.legend()

plt.tight_layout()
plt.show()

print("\nRace counts (original):")
print(counts)
print(f"\nExpected per race: {len(df_orig)/3}")

## Gender Distribution

In [None]:
# Gender distribution: original images (left) vs. all images incl. augmented (right)

# Fixed category order. A bare value_counts() orders bars by frequency, so the
# positional colours below could be attached to different genders in the two
# panels; pinning the order keeps male = blue and female = red everywhere.
# Note: only the genders listed here are plotted.
gender_order = ['male', 'female']
colors = ['#3498db', '#e74c3c']  # same order as gender_order

counts = df_orig['gender'].value_counts().reindex(gender_order, fill_value=0)
counts_all = df['gender'].value_counts().reindex(gender_order, fill_value=0)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, counts, total rows, title, vertical offset of the count label)
panels = [
    (axes[0], counts, len(df_orig), 'Gender Distribution (Original Images)', 10),
    (axes[1], counts_all, len(df), 'Gender Distribution (All Images)', 100),
]

for ax, panel_counts, total, title, label_offset in panels:
    ax.bar(panel_counts.index, panel_counts.values, color=colors)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('Count', fontsize=12)
    # Reference line: the count each gender would have under an equal split
    ax.axhline(y=total / 2, color='red', linestyle='--', alpha=0.5, label='Perfect Balance')

    for i, v in enumerate(panel_counts.values):
        # Absolute count above the bar, percentage share inside it
        ax.text(i, v + label_offset, str(v), ha='center', fontweight='bold')
        pct = (v / total) * 100
        ax.text(i, v / 2, f'{pct:.1f}%', ha='center', fontsize=10, color='white', fontweight='bold')

    ax.legend()

plt.tight_layout()
plt.show()

print("\nGender counts (original):")
print(counts)
print(f"\nExpected per gender: {len(df_orig)/2}")

## Cross-Tabulations (Original Images)

In [None]:
# Emotion × Race
ct = pd.crosstab(df_orig['emotion'], df_orig['race'])
print("Emotion × Race:")
print(ct)
print(f"\nExpected count per cell: {len(df_orig) / (4 * 3)}")

# Colours are looked up by race name, not by column position: crosstab sorts
# its columns alphabetically, which would otherwise give each race a different
# colour than in the Race Distribution figure.
race_colors = {'Caucasian': '#3498db', 'Asian': '#e67e22', 'African-American': '#2ecc71'}
bar_colors = [race_colors.get(race, 'gray') for race in ct.columns]

fig, ax = plt.subplots(figsize=(10, 6))
ct.plot(kind='bar', ax=ax, width=0.8, color=bar_colors)
ax.set_title('Emotion Distribution by Race (Original Images)', fontsize=14, fontweight='bold')
ax.set_xlabel('Emotion', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
# Draw the reference line BEFORE building the legend; a legend created first
# would not contain the 'Perfect Balance' entry.
ax.axhline(y=len(df_orig)/(4*3), color='red', linestyle='--', alpha=0.5, label='Perfect Balance')
ax.legend(title='Race', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Emotion × Gender
ct = pd.crosstab(df_orig['emotion'], df_orig['gender'])
print("\nEmotion × Gender:")
print(ct)
print(f"\nExpected count per cell: {len(df_orig) / (4 * 2)}")

# Colours are looked up by gender name so they do not depend on column order
gender_colors = {'female': '#e74c3c', 'male': '#3498db'}
bar_colors = [gender_colors.get(gender, 'gray') for gender in ct.columns]

fig, ax = plt.subplots(figsize=(10, 6))
ct.plot(kind='bar', ax=ax, width=0.8, color=bar_colors)
ax.set_title('Emotion Distribution by Gender (Original Images)', fontsize=14, fontweight='bold')
ax.set_xlabel('Emotion', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
# Draw the reference line BEFORE building the legend; a legend created first
# would not contain the 'Perfect Balance' entry.
ax.axhline(y=len(df_orig)/(4*2), color='red', linestyle='--', alpha=0.5, label='Perfect Balance')
ax.legend(title='Gender')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Race × Gender
ct = pd.crosstab(df_orig['race'], df_orig['gender'])
print("\nRace × Gender:")
print(ct)
print(f"\nExpected count per cell: {len(df_orig) / (3 * 2)}")

# Colours are looked up by gender name so they do not depend on column order
gender_colors = {'female': '#e74c3c', 'male': '#3498db'}
bar_colors = [gender_colors.get(gender, 'gray') for gender in ct.columns]

fig, ax = plt.subplots(figsize=(10, 6))
ct.plot(kind='bar', ax=ax, width=0.8, color=bar_colors)
ax.set_title('Race Distribution by Gender (Original Images)', fontsize=14, fontweight='bold')
ax.set_xlabel('Race', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
# Draw the reference line BEFORE building the legend; a legend created first
# would not contain the 'Perfect Balance' entry.
ax.axhline(y=len(df_orig)/(3*2), color='red', linestyle='--', alpha=0.5, label='Perfect Balance')
ax.legend(title='Gender')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## Heatmap: All 24 Strata

In [None]:
# Heatmap of all 24 strata: one row per emotion, one column per race/gender pair
emotions = ['anger', 'fear', 'calm', 'surprise']
races = ['Caucasian', 'Asian', 'African-American']
genders = ['male', 'female']

# Axis labels: columns iterate race (outer) then gender (inner)
col_labels = [f"{race}_{gender}" for race in races for gender in genders]
row_labels = list(emotions)

# Count matrix built row by row; each cell is the number of original images
# matching that exact (emotion, race, gender) combination.
data = [
    [
        int(
            (
                (df_orig['emotion'] == emo)
                & (df_orig['race'] == rc)
                & (df_orig['gender'] == gen)
            ).sum()
        )
        for rc in races
        for gen in genders
    ]
    for emo in emotions
]

# Draw the annotated heatmap on an explicit axis
fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(
    data,
    annot=True,
    fmt='d',
    cmap='YlGnBu',
    xticklabels=col_labels,
    yticklabels=row_labels,
    cbar_kws={'label': 'Count'},
    ax=ax,
)
ax.set_title('All 24 Strata: Emotion × Race × Gender (Original Images)', fontsize=14, fontweight='bold')
ax.set_xlabel('Race × Gender', fontsize=12)
ax.set_ylabel('Emotion', fontsize=12)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()
plt.show()

print(f"All cells should be: {len(df_orig) / 24}")

## Dataset Source Distribution

In [None]:
# Bar chart of how many original images come from each dataset source
counts = df_orig['dataset_source'].value_counts()
n_original = len(df_orig)

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(counts.index, counts.values, color=['#e74c3c', '#3498db', '#2ecc71'])
ax.set_title('Dataset Source Distribution (Original Images)', fontsize=14, fontweight='bold')
ax.set_xlabel('Dataset Source', fontsize=12)
ax.set_ylabel('Count', fontsize=12)

# Annotate each bar: absolute count above it, percentage share inside it
for bar_pos, n_source in enumerate(counts.values):
    share = (n_source / n_original) * 100
    ax.text(bar_pos, n_source + 20, str(n_source), ha='center', fontweight='bold')
    ax.text(bar_pos, n_source/2, f'{share:.1f}%', ha='center', fontsize=10, color='white', fontweight='bold')

fig.tight_layout()
plt.show()

print("\nDataset source counts:")
print(counts)

## Summary Statistics

In [None]:
def _print_balance(title, level_counts, n_levels, name_width):
    """Print each category's count next to the count an equal split would give.

    title       -- heading printed above the rows
    level_counts -- Series mapping category label -> number of original images
    n_levels    -- number of categories the originals are expected to split into
    name_width  -- column width used to pad the category label
    """
    print(f"\n{title}:")
    expected = len(df_orig) / n_levels  # loop-invariant, so computed once
    for level, count in level_counts.items():
        diff = count - expected
        print(f"  {level:{name_width}s}: {count:4d} (Expected: {expected:.0f}, Diff: {diff:+.0f})")


print("="*70)
print("BALANCED DATASET SUMMARY")
print("="*70)

n_augmented = int((df['augmented'] != 'original').sum())
print(f"\nTotal images: {len(df):,}")
print(f"  Original images: {len(df_orig):,}")
print(f"  Augmented images: {n_augmented:,}")

print("\n" + "-"*70)
print("PERFECT BALANCE METRICS (Original Images)")
print("-"*70)

print("\nTotal strata: 4 emotions × 3 races × 2 genders = 24")
print(f"Images per stratum: {len(df_orig) / 24}")

_print_balance("Emotion Balance", df_orig['emotion'].value_counts().sort_index(), 4, 10)
_print_balance("Race Balance", df_orig['race'].value_counts(), 3, 20)
_print_balance("Gender Balance", df_orig['gender'].value_counts(), 2, 10)

print("\n" + "-"*70)
print("AUGMENTATION BREAKDOWN")
print("-"*70)
aug_counts = df['augmented'].value_counts()

# Preferred display order for the known augmentation types
known_aug_types = ['original', 'rotation', 'dark', 'high_contrast', 'light_noise', 'blur',
                   'top_rectangle', 'top_left_diagonal', 'top_right_diagonal', 'forehead_bar', 'heavy_hair']
# Types present in the data but missing from the list above are appended rather
# than silently dropped, so the breakdown always accounts for every labelled row.
unlisted_aug_types = [t for t in aug_counts.index if t not in known_aug_types]

for aug_type in known_aug_types + unlisted_aug_types:
    if aug_type in aug_counts.index:
        count = aug_counts[aug_type]
        pct = (count / len(df)) * 100
        print(f"  {str(aug_type):20s}: {count:5d} ({pct:5.2f}%)")

print("\n" + "="*70)