In [None]:
# %% [markdown]
# # Analysis of the SMI Impact in Spain (2001-2023)
# ## Notebook 1: Data Processing, Analysis, and Visualization
#
# **Author:** Marcos Lacasa-Cazcarra
#
# **Objective:** This notebook loads raw data on salaries, employees, and tax withholdings. It then processes this data to generate analytical datasets and key visualizations that explore wage inequality, the redistributive role of taxes, and the impact of the Minimum Interprofessional Salary (SMI) across different economic phases in Spain.
#
# **Key steps:**
# 1.  Load raw data from the `data/raw` directory.
# 2.  Define functions for data analysis (group creation, inequality indices).
# 3.  Generate and save processed datasets to the `data/processed` directory.
# 4.  Create and save visualizations to the `figures` directory.

# %% [code]
# --- 1. Setup and Imports ---
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter

# --- Matplotlib and Seaborn Styling ---
# Global defaults shared by every figure produced below.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("viridis")
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'axes.labelsize': 12,
    'axes.titlesize': 16,
})

# %% [markdown]
# ### 2. Data Loading and Path Definition
#
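# Define relative paths for portability and create the output directories if they don't exist. Because the paths are relative (`../data/...`, `../figures/`), they are resolved against the current working directory. The notebook therefore runs unchanged after cloning the repository, provided it is launched from a folder that sits next to `data/` and `figures/`.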

# %% [code]
# --- Define relative paths ---
# These are resolved against the current working directory, so the notebook must be
# launched from a folder that sits next to `data/` and `figures/`.
RAW_DATA_PATH = '../data/raw/'
PROCESSED_DATA_PATH = '../data/processed/'
FIGURES_PATH = '../figures/'

# --- Create output directories if they do not exist ---
os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)
os.makedirs(FIGURES_PATH, exist_ok=True)

# --- Load raw data ---
try:
    salarios_df = pd.read_csv(os.path.join(RAW_DATA_PATH, 'Salarios.csv'))
    asalariados_df = pd.read_csv(os.path.join(RAW_DATA_PATH, 'Asalariados.csv'))
    retenciones_df = pd.read_csv(os.path.join(RAW_DATA_PATH, 'Retenciones.csv'))
    print("Raw data loaded successfully.")
except FileNotFoundError as e:
    print(f"Error loading data: {e}")
    print(f"Ensure the raw CSV files are in the '{RAW_DATA_PATH}' directory.")
    # Every later cell needs these frames; stop here instead of failing later
    # with a misleading NameError.
    raise

# %% [markdown]
# ### 3. Core Analysis Functions
#
# This section contains the functions used to provide the annual SMI reference values, create salary groups, calculate inequality indices (Gini and Theil), and define the economic phases.

# %% [code]
def get_smi_data():
    """Return the annual SMI amount for each year from 2001 to 2023.

    Returns:
        dict[int, float]: year -> annual SMI, in ascending year order.
        NOTE(review): the amounts appear to be the monthly SMI times 14
        payments (e.g. 6068.30 = 433.45 * 14); confirm against the source.
    """
    annual_amounts = [
        6068.30, 6190.80, 6316.80, 6871.20, 7182.00,      # 2001-2005
        7572.60, 7988.40, 8400.00, 8736.00, 8866.20,      # 2006-2010
        8979.60, 8979.60, 9034.20, 9034.20, 9080.40,      # 2011-2015
        9172.80, 9907.80, 10302.60, 12600.00, 13300.00,   # 2016-2020
        13510.00, 14000.00, 15120.00,                     # 2021-2023
    ]
    return dict(zip(range(2001, 2024), annual_amounts))

def get_economic_phases():
    """Return the economic phases used to segment the 2001-2023 period.

    Returns:
        dict[str, list[int]]: phase label -> list of calendar years in that
        phase, in chronological order.
    """
    # (label, first year, last year) -- both bounds inclusive.
    phase_bounds = (
        ('Pre-Crisis', 2001, 2007),
        ('Financial Crisis', 2008, 2014),
        ('Recovery', 2015, 2019),
        ('Pandemic & Post-Pandemic', 2020, 2023),
    )
    return {label: list(range(first, last + 1)) for label, first, last in phase_bounds}

def create_salary_groups(salarios_df, year='2001', num_groups=5):
    """Split salary brackets into groups holding roughly equal salary mass in a base year.

    Brackets are sorted by 'franja' and accumulated greedily. A group is
    closed as soon as its running sum of `year` reaches total / num_groups.
    All remaining brackets go into the last group.

    Args:
        salarios_df: DataFrame with a 'franja' column and one column per year.
        year: Name of the base-year column used to size the groups.
        num_groups: Number of groups to create (must be >= 1).

    Returns:
        pd.DataFrame with columns 'group' (1-based), 'min_franja', 'max_franja'
        and 'franjas' (list of bracket values in the group). An empty trailing
        group is omitted, so fewer than `num_groups` rows are possible.

    Raises:
        ValueError: if num_groups < 1.
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")

    df_sorted = salarios_df.sort_values('franja')
    total_salary = df_sorted[year].sum()
    target_group_size = total_salary / num_groups

    groups = []
    current_group_franjas = []
    current_sum = 0

    # Iterate over the two columns directly rather than with iterrows(). iterrows
    # builds one Series per row and upcasts 'franja' to float when every column
    # is numeric.
    for franja, amount in zip(df_sorted['franja'], df_sorted[year]):
        current_group_franjas.append(franja)
        current_sum += amount
        if current_sum >= target_group_size and len(groups) < num_groups - 1:
            groups.append(current_group_franjas)
            current_group_franjas = []
            current_sum = 0

    # Add all remaining brackets to the last group
    groups.append(current_group_franjas)

    # One definition row per non-empty group
    groups_def = [
        {'group': i + 1, 'min_franja': min(franjas), 'max_franja': max(franjas), 'franjas': franjas}
        for i, franjas in enumerate(groups)
        if franjas
    ]
    return pd.DataFrame(groups_def)

def calculate_yearly_evolution(salarios_df, groups_df, years=None):
    """Compute each group's salary mass and share of the total for every year.

    Args:
        salarios_df: DataFrame with a 'franja' column and one column per year,
            named by the year as a string (e.g. '2001').
        groups_df: Output of `create_salary_groups` (columns 'group',
            'min_franja', 'max_franja', 'franjas').
        years: Optional iterable of years (int or str) to evaluate. Defaults
            to 2001-2023, the range used throughout this notebook.

    Returns:
        pd.DataFrame with one row per (year, group), ordered by year then by
        group, and columns 'year', 'group', 'min_franja', 'max_franja',
        'total_salary_mass' and 'percent_salary_mass' (0 when the year's total
        is not positive).
    """
    if years is None:
        years = range(2001, 2024)
    year_columns = [str(year) for year in years]

    # The rows belonging to each group do not depend on the year, so build each
    # group's boolean mask once instead of once per year.
    group_masks = [
        (group, salarios_df['franja'].isin(group['franjas']))
        for _, group in groups_df.iterrows()
    ]

    evolution_data = []
    for year in year_columns:
        year_total = salarios_df[year].sum()
        for group, mask in group_masks:
            group_sum = salarios_df.loc[mask, year].sum()
            evolution_data.append({
                'year': int(year),
                'group': group['group'],
                'min_franja': group['min_franja'],
                'max_franja': group['max_franja'],
                'total_salary_mass': group_sum,
                'percent_salary_mass': (group_sum / year_total) * 100 if year_total > 0 else 0
            })
    return pd.DataFrame(evolution_data)

def calculate_gini(salaries, workers):
    """Calculate the Gini index from grouped (bracket-level) data.

    Brackets are ordered by average salary (salary mass / workers) to build
    the Lorenz curve. The Gini index is 1 minus twice the area under that curve,
    integrated with the trapezoid rule from the origin (0, 0).

    Args:
        salaries: Array-like of total salary mass per bracket.
        workers: Array-like of worker counts per bracket (same length).

    Returns:
        float: Gini index in [0, 1]. Returns 0 for empty input, or when total
        workers or total salary is zero.
    """
    salaries = np.asarray(salaries, dtype=float)
    workers = np.asarray(workers, dtype=float)
    if salaries.size == 0:
        return 0

    # Brackets without workers get average 0 and contribute no population share.
    avg_salary_per_bracket = np.divide(salaries, workers, out=np.zeros_like(salaries), where=workers != 0)
    sorted_indices = np.argsort(avg_salary_per_bracket)

    cum_workers = np.cumsum(workers[sorted_indices])
    cum_salary = np.cumsum(salaries[sorted_indices])

    total_workers = cum_workers[-1]
    total_salary = cum_salary[-1]

    if total_workers == 0 or total_salary == 0:
        return 0

    # Anchor the Lorenz curve at (0, 0). Without this point the first
    # trapezoid is dropped and the index is overstated.
    prop_workers = np.concatenate(([0.0], cum_workers / total_workers))
    prop_salary = np.concatenate(([0.0], cum_salary / total_salary))

    gini = 1 - np.sum((prop_salary[1:] + prop_salary[:-1]) * np.diff(prop_workers))
    return gini

def calculate_theil(salaries, workers):
    """Calculate the Theil T index from grouped (bracket-level) data.

    T = sum_i (N_i / N) * (mu_i / mu) * ln(mu_i / mu), where mu_i is the
    bracket's average salary and mu the overall average. A bracket with a zero
    average (or no workers) uses ratio 1, so it contributes 0.

    Args:
        salaries: NumPy array of total salary mass per bracket.
        workers: NumPy array of worker counts per bracket (same length).

    Returns:
        float: Theil T index, or 0 when total salary or total workers is zero.
    """
    salary_total = np.sum(salaries)
    worker_total = np.sum(workers)
    if salary_total == 0 or worker_total == 0:
        return 0

    overall_mean = salary_total / worker_total
    bracket_mean = np.divide(salaries, workers, out=np.zeros_like(salaries, dtype=float), where=workers != 0)

    # Start every ratio at 1 (log 1 = 0) and fill in only where it is defined.
    relative_mean = np.ones_like(bracket_mean, dtype=float)
    defined = (bracket_mean != 0) & (overall_mean != 0)
    relative_mean[defined] = bracket_mean[defined] / overall_mean

    population_share = workers / worker_total
    return np.sum(population_share * relative_mean * np.log(relative_mean))

# %% [markdown]
# ### 4. Data Processing and Dataset Generation
#
# This section runs the analysis functions and saves the resulting datasets to the `data/processed/` directory.

# %% [code]
# --- ANALYSIS 1: EVOLUTION BY SALARY MASS GROUPS ---
# Five groups are sized on the 2001 salary mass and then tracked across every year.
print("Running Analysis 1: Evolution by Salary Mass Groups...")
salary_groups_df = create_salary_groups(salarios_df, num_groups=5, year='2001')
salary_evolution_df = calculate_yearly_evolution(salarios_df, salary_groups_df)
salary_evolution_df.to_csv(
    os.path.join(PROCESSED_DATA_PATH, 'evolucion_grupos_salariales.csv'),
    index=False,
)
print(f"Dataset saved to {PROCESSED_DATA_PATH}evolucion_grupos_salariales.csv")

# --- ANALYSIS 2: INEQUALITY INDICES ---
print("\nRunning Analysis 2: Inequality Indices...")
years = [str(y) for y in range(2001, 2024)]
smi_data = get_smi_data()

# The three raw tables are combined positionally (row i is assumed to be the
# same bracket in each). When all of them expose 'franja', verify that they
# list the same brackets in the same order before doing any arithmetic.
if all('franja' in df.columns for df in (salarios_df, asalariados_df, retenciones_df)):
    for other_name, other_df in (('Asalariados', asalariados_df), ('Retenciones', retenciones_df)):
        if not np.array_equal(salarios_df['franja'].to_numpy(), other_df['franja'].to_numpy()):
            raise ValueError(
                f"'franja' rows in {other_name}.csv do not match Salarios.csv; "
                "align the tables before computing inequality indices."
            )

inequality_data = []

for year in years:
    y = int(year)
    salaries = salarios_df[year].values
    workers = asalariados_df[year].values
    taxes = retenciones_df[year].values
    net_salaries = salaries - taxes

    total_salary = np.sum(salaries)
    total_workers = np.sum(workers)
    mean_salary = total_salary / total_workers if total_workers > 0 else 0

    smi = smi_data.get(y)
    if smi is None:
        smi_ratio = np.nan  # no SMI reference value for this year
    elif mean_salary > 0:
        smi_ratio = (smi / mean_salary) * 100
    else:
        smi_ratio = 0

    inequality_data.append({
        'year': y,
        'gini_gross': calculate_gini(salaries, workers),
        'gini_net': calculate_gini(net_salaries, workers),
        'theil_gross': calculate_theil(salaries, workers),
        'theil_net': calculate_theil(net_salaries, workers),
        'mean_salary_gross': mean_salary,
        'smi': smi,
        'smi_ratio': smi_ratio
    })

inequality_indices_df = pd.DataFrame(inequality_data)
# Redistribution effect: percentage reduction of inequality from gross to net salaries
# (pandas yields inf/NaN rather than raising if a gross index is 0).
inequality_indices_df['gini_redistribution'] = (inequality_indices_df['gini_gross'] - inequality_indices_df['gini_net']) / inequality_indices_df['gini_gross'] * 100
inequality_indices_df['theil_redistribution'] = (inequality_indices_df['theil_gross'] - inequality_indices_df['theil_net']) / inequality_indices_df['theil_gross'] * 100
inequality_indices_df.to_csv(os.path.join(PROCESSED_DATA_PATH, 'indices_desigualdad.csv'), index=False)
print(f"Dataset saved to {PROCESSED_DATA_PATH}indices_desigualdad.csv")


# --- ANALYSIS 3: ECONOMIC PHASES ---
# Summarise the yearly indices over each economic phase: mean gross Gini, mean
# SMI ratio, and the % change in gross Gini from the phase's first to last year.
print("\nRunning Analysis 3: Analysis by Economic Phases...")
phases = get_economic_phases()
phases_analysis = []

for phase_name, phase_years in phases.items():
    first_year = min(phase_years)
    last_year = max(phase_years)
    in_phase = inequality_indices_df.loc[inequality_indices_df['year'].isin(phase_years)]

    gini_first = in_phase.loc[in_phase['year'] == first_year, 'gini_gross'].iloc[0]
    gini_last = in_phase.loc[in_phase['year'] == last_year, 'gini_gross'].iloc[0]

    if gini_first > 0:
        gini_change = (gini_last - gini_first) / gini_first * 100
    else:
        gini_change = 0

    phases_analysis.append({
        'phase': phase_name,
        'start_year': first_year,
        'end_year': last_year,
        'avg_gini_gross': in_phase['gini_gross'].mean(),
        'avg_smi_ratio': in_phase['smi_ratio'].mean(),
        'gini_change_pct': gini_change,
    })

phases_analysis_df = pd.DataFrame(phases_analysis)
phases_analysis_df.to_csv(
    os.path.join(PROCESSED_DATA_PATH, 'analisis_fases_economicas.csv'),
    index=False,
)
print(f"Dataset saved to {PROCESSED_DATA_PATH}analisis_fases_economicas.csv")

print("\n--- All datasets generated successfully. ---")


# %% [markdown]
# ### 5. Visualization
#
# This section generates a key visualization from the analysis and saves it to the `figures/` directory.

# %% [code]
# --- Visualization: Evolution of Salary Mass Distribution ---
print("\nGenerating visualization...")

# Plotting function
def plot_salary_mass_evolution(evolution_df):
    """Plot each salary-mass group's share of total salary mass over time.

    The background is shaded by economic phase, with each phase's name written
    below the x axis. The figure is saved to
    FIGURES_PATH/evolucion_masa_salarial.png and then shown.

    Args:
        evolution_df: Output of `calculate_yearly_evolution`. It must contain
            the columns 'year', 'group' and 'percent_salary_mass'.
    """
    fig, ax = plt.subplots(figsize=(16, 9))

    # Define colors and economic phases
    group_colors = {1: '#1f77b4', 2: '#ff7f0e', 3: '#2ca02c', 4: '#d62728', 5: '#9467bd'}
    phases = get_economic_phases()
    phase_colors = {'Pre-Crisis': '#e1f5fe', 'Financial Crisis': '#fff1e0', 'Recovery': '#e8f5e9', 'Pandemic & Post-Pandemic': '#ffebee'}

    # Add background shading for economic phases
    for phase, years in phases.items():
        span_start, span_end = min(years), max(years) + 1
        ax.axvspan(span_start, span_end, color=phase_colors[phase], alpha=0.5, zorder=0)
        # x in data coordinates, y in axes coordinates: the label is centred on
        # the shaded span and sits just below the x axis.
        ax.text((span_start + span_end) / 2, -0.1, phase, ha='center', va='top',
                transform=ax.get_xaxis_transform(), fontsize=10)

    # Plot evolution for each group; the share column is the one produced by
    # calculate_yearly_evolution ('percent_salary_mass').
    for group_num in sorted(evolution_df['group'].unique()):
        group_data = evolution_df[evolution_df['group'] == group_num].sort_values('year')
        label = f"Group {group_num} (Quintile {group_num})"
        ax.plot(group_data['year'], group_data['percent_salary_mass'], marker='o', linewidth=2,
                color=group_colors.get(group_num), label=label)

    # Formatting the plot
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.1f}%'))
    ax.set_title('Evolution of Salary Mass Distribution by Quintile (2001-2023)', fontsize=18, pad=20)
    ax.set_ylabel('Share of Total Salary Mass (%)', fontsize=12)
    ax.set_xlabel('Year', fontsize=12)
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    ax.legend(title='Salary Mass Groups (based on 2001)', loc='upper left')
    fig.tight_layout(rect=[0, 0.05, 1, 1])  # Leave room at the bottom for the phase labels

    # Save the figure
    fig_path = os.path.join(FIGURES_PATH, 'evolucion_masa_salarial.png')
    fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Plot saved to {fig_path}")

# Generate and display the plot
plot_salary_mass_evolution(salary_evolution_df)

# %% [markdown]
# ---
# ### End of Notebook
# The processed datasets are now available in the `data/processed` folder and a sample visualization is in the `figures` folder.