In [None]:
# %% [markdown]
# # Analysis of the Impact of the Minimum Wage (SMI, Salario Mínimo Interprofesional) in Spain (2001-2023)
# ## Notebook 1: Data Processing and Dataset Creation
#
# **Author:** Marcos Lacasa-Cazcarra
#
# **Objective:** This notebook loads the raw data on salaries, employees, and tax withholdings from the Spanish Tax Agency. From these sources, it performs three main analyses and generates the clean, processed datasets that will be used for subsequent visualizations and modeling.
#
# **Generated Datasets (saved in `data/processed/`):**
#
# 1.  **`evolucion_grupos_salariales.csv`**: Evolution of five salary-bracket groups defined on the 2001 salary mass (consecutive brackets are grouped so that each group holds roughly one fifth of the 2001 total).
# 2.  **`indices_desigualdad.csv`**: Annual Gini and Theil indices for gross and net-of-withholding salaries, the redistributive effect of tax withholdings, and the ratio of the SMI to the mean gross salary.
# 3.  **`analisis_fases_economicas.csv`**: Key indicators aggregated over four predefined economic phases (Pre-Crisis, Financial Crisis, Recovery, Pandemic & Post-Pandemic).

# %% [code]
import pandas as pd
import numpy as np
import os

# %% [markdown]
# ### 1. Setup and Data Loading
#
# This section defines the input and output paths, creates the output directory `data/processed/` if it doesn't exist, and loads the three raw data files from the `data/raw/` directory.

# %% [code]
# --- Define file paths using relative paths for portability ---
# NOTE: relative paths are resolved against the current working directory; presumably the
# notebook is run from a folder that sits next to `data/` (e.g. `notebooks/`).
RAW_DATA_PATH = '../data/raw/'
PROCESSED_DATA_PATH = '../data/processed/'

# --- Create the processed data directory if it does not exist ---
os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)

# --- Load raw data ---
# All later cells need these three frames, so a missing file is fatal: print a hint and
# re-raise, instead of letting the notebook continue into a confusing NameError later on.
try:
    salarios_df = pd.read_csv(os.path.join(RAW_DATA_PATH, 'Salarios.csv'))
    asalariados_df = pd.read_csv(os.path.join(RAW_DATA_PATH, 'Asalariados.csv'))
    retenciones_df = pd.read_csv(os.path.join(RAW_DATA_PATH, 'Retenciones.csv'))
    print("Raw data loaded successfully.")
    print(f"Salarios.csv dimensions: {salarios_df.shape}")
    print(f"Asalariados.csv dimensions: {asalariados_df.shape}")
    print(f"Retenciones.csv dimensions: {retenciones_df.shape}")
except FileNotFoundError as e:
    print(f"Error loading data: {e}")
    print("Please ensure 'Salarios.csv', 'Asalariados.csv', and 'Retenciones.csv' are in the 'data/raw/' directory.")
    raise

# %% [markdown]
# ### 2. Analysis Functions
#
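# This section defines the helper functions used by the analyses: salary-mass grouping, the yearly evolution of those groups, the Gini and Theil inequality indices, and the annual SMI reference values.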

# %% [code]
def create_salary_groups(salarios_df, year='2001', num_groups=5):
    """Group consecutive salary brackets into groups of similar salary mass in a base year.

    Brackets are ordered by ``franja`` and accumulated greedily. A group is closed as soon
    as its summed ``year`` salary mass reaches ``total / num_groups``. Only the first
    ``num_groups - 1`` groups are closed this way; every remaining bracket goes into the
    last group. Group masses are therefore only approximately equal.

    Parameters
    ----------
    salarios_df : pandas.DataFrame
        Must contain a ``franja`` column and a column named ``year``.
    year : str, default '2001'
        Column whose values define the salary mass used for grouping.
    num_groups : int, default 5
        Maximum number of groups to create; must be >= 1.

    Returns
    -------
    pandas.DataFrame
        One row per non-empty group, with these columns:

        - ``grupo``: 1-based group number.
        - ``min_franja``, ``max_franja``: lowest and highest bracket label in the group.
        - ``franjas``: list of the bracket labels in the group.

        Fewer than ``num_groups`` rows are returned when the brackets run out.

    Raises
    ------
    ValueError
        If ``num_groups`` is smaller than 1.
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")

    df_sorted = salarios_df.sort_values('franja')
    total_salary = df_sorted[year].sum()
    target_group_size = total_salary / num_groups

    groups = []
    current_group_franjas = []
    current_sum = 0

    # Iterate the two columns with zip rather than iterrows: iterrows builds one Series per
    # row and upcasts integer bracket labels to float whenever the salary column is float.
    for franja, salary in zip(df_sorted['franja'], df_sorted[year]):
        current_group_franjas.append(franja)
        current_sum += salary
        # Close the group once it reaches its share of the mass, keeping the last group open.
        if current_sum >= target_group_size and len(groups) < num_groups - 1:
            groups.append(current_group_franjas)
            current_group_franjas = []
            current_sum = 0

    # Add all remaining brackets to the last group (may be empty if the brackets ran out).
    groups.append(current_group_franjas)

    groups_def = [
        {'grupo': i + 1, 'min_franja': min(franjas), 'max_franja': max(franjas), 'franjas': franjas}
        for i, franjas in enumerate(groups)
        if franjas
    ]
    return pd.DataFrame(groups_def)

def calculate_yearly_evolution(salarios_df, groups_df, years=None):
    """Compute, for each year and salary group, the group's salary mass and share of the total.

    Parameters
    ----------
    salarios_df : pandas.DataFrame
        Must contain a ``franja`` column and one salary column per year, named as strings.
    groups_df : pandas.DataFrame
        Group definitions as returned by ``create_salary_groups``, with columns ``grupo``,
        ``min_franja``, ``max_franja`` and ``franjas``.
    years : iterable of str or int, optional
        Years to evaluate. Defaults to 2001-2023 inclusive, the span previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        One row per (year, group), ordered by year and then by group, with these columns:

        - ``year`` (int).
        - ``grupo``, ``min_franja``, ``max_franja``: copied from ``groups_df``.
        - ``total``: the group's salary mass in that year.
        - ``percent``: the group's share of that year's total mass; 0 when the yearly
          total is not positive.
    """
    if years is None:
        years = range(2001, 2024)
    years = [str(year) for year in years]

    groups = groups_df.to_dict('records')
    # Bracket membership of each group does not change from year to year: build masks once.
    masks = [salarios_df['franja'].isin(group['franjas']) for group in groups]

    evolution_data = []
    for year in years:
        year_values = salarios_df[year]
        year_total = year_values.sum()
        for group, mask in zip(groups, masks):
            group_sum = year_values[mask].sum()
            evolution_data.append({
                'year': int(year),
                'grupo': group['grupo'],
                'min_franja': group['min_franja'],
                'max_franja': group['max_franja'],
                'total': group_sum,
                'percent': (group_sum / year_total) * 100 if year_total > 0 else 0
            })
    return pd.DataFrame(evolution_data)

def calculate_gini(salaries, workers):
    """Calculate the Gini index from grouped (bracketed) salary data.

    Brackets are ordered by average salary per worker to build the Lorenz curve. With
    cumulative worker shares X_k and salary shares Y_k, and X_0 = Y_0 = 0:

        G = 1 - sum_k (X_k - X_{k-1}) * (Y_k + Y_{k-1})

    Parameters
    ----------
    salaries : array-like
        Total salary mass of each bracket.
    workers : array-like
        Number of workers in each bracket; same length as ``salaries``.

    Returns
    -------
    float
        The Gini index. Returns 0 for empty input or when there are no workers or no salary mass.
    """
    salaries = np.asarray(salaries, dtype=float)
    workers = np.asarray(workers, dtype=float)
    if salaries.size == 0:
        return 0

    # Brackets with no workers get an average of 0 instead of dividing by zero.
    avg_salary_per_bracket = np.divide(salaries, workers, out=np.zeros_like(salaries), where=workers != 0)
    sorted_indices = np.argsort(avg_salary_per_bracket)

    salaries_sorted = salaries[sorted_indices]
    workers_sorted = workers[sorted_indices]

    cum_workers = np.cumsum(workers_sorted)
    cum_salary = np.cumsum(salaries_sorted)

    total_workers = cum_workers[-1]
    total_salary = cum_salary[-1]

    if total_workers == 0 or total_salary == 0:
        return 0

    # Prepend the Lorenz curve origin (0, 0) so the first bracket's trapezoid is counted.
    # Without it, a perfectly equal distribution would not score 0 and one bracket would score 1.
    prop_workers = np.concatenate(([0.0], cum_workers / total_workers))
    prop_salary = np.concatenate(([0.0], cum_salary / total_salary))

    # Twice the area under the Lorenz curve (trapezoidal rule), subtracted from 1.
    gini = 1 - np.sum((prop_salary[1:] + prop_salary[:-1]) * (prop_workers[1:] - prop_workers[:-1]))
    return gini

def calculate_theil(salaries, workers):
    """Calculate the Theil T index from grouped (bracketed) salary data.

    With bracket worker counts N_i, bracket mean salaries mu_i, total workers N and global
    mean salary mu:

        T = sum_i (N_i / N) * (mu_i / mu) * ln(mu_i / mu)

    Brackets whose mean salary is 0 contribute 0, the limit of x * ln(x) as x -> 0.

    Parameters
    ----------
    salaries : array-like
        Total salary mass of each bracket.
    workers : array-like
        Number of workers in each bracket; same length as ``salaries``.

    Returns
    -------
    float
        The Theil T index. Returns 0 when there are no workers or no salary mass.
        NOTE(review): a negative bracket mean (e.g. net salaries below zero) would produce NaN.
    """
    salaries = np.asarray(salaries, dtype=float)
    workers = np.asarray(workers, dtype=float)
    total_salary = np.sum(salaries)
    total_workers = np.sum(workers)

    if total_workers == 0 or total_salary == 0:
        return 0

    # Non-zero here, because both totals were checked above.
    avg_salary_global = total_salary / total_workers
    avg_salary_bracket = np.divide(salaries, workers, out=np.zeros_like(salaries), where=workers != 0)

    # Element-wise mask: brackets with a zero mean keep ratio 1, so ratio * log(ratio) == 0.
    ratio = np.divide(avg_salary_bracket, avg_salary_global,
                      out=np.ones_like(avg_salary_bracket), where=avg_salary_bracket != 0)
    theil = np.sum((workers / total_workers) * ratio * np.log(ratio))
    return theil

def get_smi_data():
    """Return the annual SMI amounts for 2001-2023 as a ``{year: amount}`` dict.

    A new dict is built on every call, so callers may modify the result freely.
    NOTE(review): the amounts appear to equal the monthly SMI times 14 payments
    (e.g. 2019: 900 x 14 = 12600); confirm against the official source.
    """
    annual_amounts = [
        6068.30, 6190.80, 6316.80, 6871.20, 7182.00,      # 2001-2005
        7572.60, 7988.40, 8400.00, 8736.00, 8866.20,      # 2006-2010
        8979.60, 8979.60, 9034.20, 9034.20, 9080.40,      # 2011-2015
        9172.80, 9907.80, 10302.60, 12600.00, 13300.00,   # 2016-2020
        13510.00, 14000.00, 15120.00,                     # 2021-2023
    ]
    return dict(zip(range(2001, 2024), annual_amounts))

# %% [markdown]
# ### 3. Processing and Dataset Generation
#
# The following cells execute the analysis functions and save the resulting datasets to the `data/processed/` directory.

# %% [code]
# --- ANALYSIS 1: EVOLUTION BY SALARY MASS GROUPS ---
# Groups are defined once from the 2001 salary mass and then tracked through every year.
print("Running Analysis 1: Evolution by Salary Mass Groups...")
grupos_salariales_df = create_salary_groups(salarios_df, '2001', 5)
evolucion_grupos_salariales_df = calculate_yearly_evolution(salarios_df, grupos_salariales_df)
evolucion_output_path = os.path.join(PROCESSED_DATA_PATH, 'evolucion_grupos_salariales.csv')
evolucion_grupos_salariales_df.to_csv(evolucion_output_path, index=False)
print("Dataset 'evolucion_grupos_salariales.csv' saved.")

# --- ANALYSIS 2: INEQUALITY INDICES ---
print("\nRunning Analysis 2: Inequality Indices...")
years_str = [str(y) for y in range(2001, 2024)]
smi_data = get_smi_data()
inequality_data = []

# The three raw tables are combined element-wise by row position (`.values` below), so they
# must list the same brackets in the same order. Check this up front instead of silently
# mixing salaries, workers and withholdings that belong to different brackets.
for other_name, other_df in (('Asalariados.csv', asalariados_df), ('Retenciones.csv', retenciones_df)):
    if len(other_df) != len(salarios_df):
        raise ValueError(
            f"{other_name} has {len(other_df)} rows but Salarios.csv has {len(salarios_df)}; "
            "the tables must list the same salary brackets."
        )
    # The bracket order can only be verified when the table carries a 'franja' column.
    if 'franja' in other_df.columns and other_df['franja'].tolist() != salarios_df['franja'].tolist():
        raise ValueError(f"{other_name} brackets ('franja') are not in the same order as Salarios.csv.")

for year in years_str:
    y = int(year)
    salaries = salarios_df[year].values
    workers = asalariados_df[year].values
    taxes = retenciones_df[year].values
    # Net salary per bracket = gross salary minus tax withholdings.
    net_salaries = salaries - taxes

    # Calculate Gini and Theil for gross and net salaries
    gini_gross = calculate_gini(salaries, workers)
    gini_net = calculate_gini(net_salaries, workers)
    theil_gross = calculate_theil(salaries, workers)
    theil_net = calculate_theil(net_salaries, workers)

    # Calculate aggregate metrics
    total_salary = np.sum(salaries)
    total_workers = np.sum(workers)
    mean_salary = total_salary / total_workers if total_workers > 0 else 0

    # SMI as a percentage of the mean gross salary. A year without an SMI entry yields NaN
    # instead of crashing on None / float.
    smi = smi_data.get(y)
    if not mean_salary > 0:
        smi_ratio = 0
    elif smi is None:
        smi_ratio = np.nan
    else:
        smi_ratio = (smi / mean_salary) * 100

    inequality_data.append({
        'year': y,
        'gini_gross': gini_gross,
        'gini_net': gini_net,
        'gini_redistribution': (gini_gross - gini_net) / gini_gross * 100 if gini_gross > 0 else 0,
        'theil_gross': theil_gross,
        'theil_net': theil_net,
        'theil_redistribution': (theil_gross - theil_net) / theil_gross * 100 if theil_gross > 0 else 0,
        'mean_salary_gross': mean_salary,
        'smi': smi,
        'smi_ratio': smi_ratio
    })

indices_desigualdad_df = pd.DataFrame(inequality_data)
indices_desigualdad_df.to_csv(os.path.join(PROCESSED_DATA_PATH, 'indices_desigualdad.csv'), index=False)
print("Dataset 'indices_desigualdad.csv' saved.")


# --- ANALYSIS 3: ECONOMIC PHASES ---
print("\nRunning Analysis 3: Economic Phases...")
# Each phase maps to the list of years it covers (range end is exclusive).
phases = {
    'Pre-Crisis': list(range(2001, 2008)),
    'Financial Crisis': list(range(2008, 2015)),
    'Recovery': list(range(2015, 2020)),
    'Pandemic & Post-Pandemic': list(range(2020, 2024))
}


def _gini_at(frame, target_year):
    """Return the gross Gini of ``target_year`` from ``frame``, or 0 if that year is absent."""
    matches = frame.loc[frame['year'] == target_year, 'gini_gross']
    return 0 if matches.empty else matches.iloc[0]


phases_analysis = []
for phase_name, phase_years in phases.items():
    subset = indices_desigualdad_df[indices_desigualdad_df['year'].isin(phase_years)]
    first_year, last_year = min(phase_years), max(phase_years)

    gini_start = _gini_at(subset, first_year)
    gini_end = _gini_at(subset, last_year)
    # Percentage change of the gross Gini between the first and last year of the phase.
    if gini_start > 0:
        gini_change = (gini_end - gini_start) / gini_start * 100
    else:
        gini_change = 0

    phases_analysis.append({
        'phase': phase_name,
        'start_year': first_year,
        'end_year': last_year,
        'avg_gini_gross': subset['gini_gross'].mean(),
        'avg_gini_redistribution': subset['gini_redistribution'].mean(),
        'avg_smi_ratio': subset['smi_ratio'].mean(),
        'gini_change': gini_change,
    })

analisis_fases_economicas_df = pd.DataFrame(phases_analysis)
fases_output_path = os.path.join(PROCESSED_DATA_PATH, 'analisis_fases_economicas.csv')
analisis_fases_economicas_df.to_csv(fases_output_path, index=False)
print("Dataset 'analisis_fases_economicas.csv' saved.")

print("\n--- Processing complete. Processed datasets are in 'data/processed/'. ---")

# %% [markdown]
# ### 4. Quick Preview of Generated Datasets

# %% [code]
# First rows of the Analysis 1 output.
print("--- Preview of evolucion_grupos_salariales.csv ---")
print(evolucion_grupos_salariales_df.head(n=5))

# %% [code]
# First rows of the Analysis 2 output.
print("\n--- Preview of indices_desigualdad.csv ---")
print(indices_desigualdad_df.head(n=5))

# %% [code]
# First rows of the Analysis 3 output (one row per economic phase).
print("\n--- Preview of analisis_fases_economicas.csv ---")
print(analisis_fases_economicas_df.head(n=5))