In [None]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
# Read the processed scalar survey exports: wave_1.csv is loaded as the 2020
# wave and wave_5.csv as the 2023 wave.
wave_csv_paths = {
    'wave_2020': '../data/processed/scalar/wave_1.csv',
    'wave_2023': '../data/processed/scalar/wave_5.csv',
}
wave_2020 = pd.read_csv(wave_csv_paths['wave_2020'])
wave_2023 = pd.read_csv(wave_csv_paths['wave_2023'])

In [None]:
# Category orderings and label mappings shared by both survey waves.
# List order is the display order used when the columns are made categorical.

# Income brackets exactly as they appear in the raw data.
income_levels = [
    'Less than 26130 Euro',
    'Between 26131 and 42785 Euro',
    'Between 42786 and 66935 Euro',
    'Between 66936 and 102540 Euro',
    'More than 102540 Euro',
    'Prefer not to say',
]

# Short labels for the income brackets, in the same order as `income_levels`.
income_levels_simplified_cats = [
    '<26,000',
    '26,000-42,000',
    '42,000-67,000',
    '67,000-102,000',
    '>102,000',
    'Prefer not to say',
]

# Raw bracket -> short label, paired position by position.
income_levels_simplified = dict(zip(income_levels, income_levels_simplified_cats))

savings_amounts = [
    "Little to no savings",
    "Half month's wages",
    "1 month's wages",
    "1.5 month's wages",
    "2 month's wages",
    "3 month's wages",
    "4+ month's wages",
    "Don't know",
    "Prefer not to say",
]

education_level = [
    'Less than secondary education',
    'Secondary education',
    'College/tertiary education',
    'Postgraduate',
]

age = [
    '16-24',
    '25-34',
    '35-44',
    '45-54',
    '55-64',
    '65+',
]

# Gender uses a different category set per wave: the 2023 list adds a
# 'Prefer not to say' option on top of the 2020 one.
gender_wave2020 = ['Female', 'Male']
gender_wave2023 = gender_wave2020 + ['Prefer not to say']

In [None]:
# Helper function to categorize columns
def categorize_column(df, col, categories):
    """Convert ``df[col]`` in place to an ordered categorical.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to modify; the column is overwritten in place.
    col : str
        Name of the column to convert.
    categories : list
        Allowed values, listed in the order they should sort and plot.

    Returns
    -------
    None

    Warns
    -----
    UserWarning
        If the column holds non-missing values that are absent from
        ``categories``. ``pd.Categorical`` turns such values into NaN, and
        normalized value counts then silently exclude them, so a label typo
        would otherwise shift every percentage without any visible error.
    """
    import warnings

    values = df[col]
    unmatched = values[values.notna() & ~values.isin(categories)]
    if not unmatched.empty:
        warnings.warn(
            f"Column {col!r}: {len(unmatched)} value(s) not in the category list "
            f"will become NaN: {sorted(unmatched.astype(str).unique())}",
            stacklevel=2,
        )
    df[col] = pd.Categorical(values, categories=categories, ordered=True)

# Prepare both waves with the same steps. Each wave stores its demographics
# under different column names, so every frame is paired with its own
# column -> category-list mapping (Wave 2020 first, then Wave 2023).
wave_column_categories = [
    (wave_2020, {
        'savings_amount': savings_amounts,
        'education': education_level,
        'age': age,
        'gender': gender_wave2020,
    }),
    (wave_2023, {
        'savings_amount': savings_amounts,
        'education_level_agg': education_level,
        'age_binned': age,
        'gender_agg': gender_wave2023,
    }),
]

for wave_df, column_categories in wave_column_categories:
    # Derive the short income labels from the raw brackets before categorizing.
    wave_df['income_level_simplified'] = wave_df['income_level'].replace(income_levels_simplified)
    categorize_column(wave_df, 'income_level_simplified', income_levels_simplified_cats)

    for column, categories in column_categories.items():
        categorize_column(wave_df, column, categories)

In [None]:
# Figure 1: compare the demographic make-up of the two survey waves, with one
# panel of horizontal grouped bars per demographic variable.
fig, axes = plt.subplots(3, 2, figsize=(12, 12))

# The same demographic lives under a different column name in each wave; the
# three lists below are aligned position by position.
columns_wave_2020 = ['gender', 'age', 'education', 'income_level_simplified', 'savings_amount']
columns_wave_2023 = ['gender_agg', 'age_binned', 'education_level_agg', 'income_level_simplified', 'savings_amount']
category_names = ['Gender', 'Age', 'Education', 'Income level', 'Savings amount']

colors = ['#1f77b4', '#ff7f0e']  # Wave 2020, Wave 2023
bar_height = 0.35

# axes.flat walks the grid row by row, so panel i lands on axes[i // 2, i % 2].
# zip stops after the five demographics and leaves the sixth axes untouched.
panels = zip(axes.flat, columns_wave_2020, columns_wave_2023, category_names)
for i, (ax, col1, col5, panel_label) in enumerate(panels):
    # Union of both waves' categories, first-seen order, so the bars line up
    # even when one wave has a category the other lacks. Non-categorical
    # columns fall back to their observed unique values.
    all_cats = list(dict.fromkeys(
        level
        for series in (wave_2020[col1], wave_2023[col5])
        for level in (series.cat.categories if hasattr(series, "cat") else series.unique())
    ))

    # Share of respondents per category; categories absent from a wave get 0.
    vc1 = wave_2020[col1].value_counts(normalize=True, sort=False).reindex(all_cats, fill_value=0)
    vc5 = wave_2023[col5].value_counts(normalize=True, sort=False).reindex(all_cats, fill_value=0)

    y = range(len(all_cats))

    # Each wave gets one bar per category, shifted half a bar height to either
    # side of the shared tick position.
    wave_series = [
        (vc1, -bar_height / 2, colors[0], 'Wave 2020'),
        (vc5, bar_height / 2, colors[1], 'Wave 2023'),
    ]
    for shares, offset, color, label in wave_series:
        percentages = shares.values * 100
        bars = ax.barh(
            [yy + offset for yy in y], percentages, height=bar_height, color=color, label=label
        )

        # Add value annotations just past the end of each non-empty bar
        for bar, value in zip(bars, percentages):
            if value > 0:
                ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height()/2,
                        f'{value:.1f}%', ha='left', va='center', fontsize=9, color=color)

    ax.set_xlabel('Respondents (%)')
    ax.set_ylabel(panel_label)
    ax.set_xlim(0, 100)
    ax.set_yticks(y)
    ax.set_yticklabels(all_cats)

    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Add legend only to the first subplot
    if i == 0:
        ax.legend(loc='upper right', fontsize=10)

# Remove the unused last subplot (bottom right)
fig.delaxes(axes[2, 1])

fig.tight_layout()

# Save the figure in four formats: raster PNG plus vector PDF, SVG and EPS.
# `fmt` (not `format`) avoids shadowing the builtin; saving through `fig`
# avoids depending on pyplot's notion of the current figure.
for fmt in ['png', 'pdf', 'svg', 'eps']:
    fig.savefig(f'../figures/fig1.{fmt}', format=fmt, dpi=300, bbox_inches='tight')