# Validation of PyPSA-Eur model inputs focusing on the generation and flows 

## Imports and Configuration

In [None]:
import pandas as pd
import pypsa
import pycountry
import matplotlib.pyplot as plt
import numpy as np
import os

from itertools import combinations


# Year the validation refers to: Ember rows are filtered to this year and it is
# embedded in the output folder names (results/validation_<year>/...).
year = 2023
# NOTE(review): the network file name ends in "_2025" while `year` is 2023 —
# confirm this is the intended network for the 2023 validation.
# NOTE(review): inputs are read from "../results/..." but the plotting cells
# write to "results/..." (no "../") — confirm the intended working directory.
network_path = "../results/validation_2023/networks/new_base_s_39__3H_2025.nc"
ember_monthly_data_path = "../ember_data/europe_monthly_full_release_long_format.csv"
ember_yearly_data_path = "../ember_data/yearly_full_release_long_format.csv"
# Not referenced by the cells shown in this part of the notebook; presumably
# consumed by a flow-validation section further down — verify.
power_flows_data_path = "../entsoe_data/physical_energy_power_flows_2023.csv"

# Two-letter country codes kept when filtering the Ember data.
countries = ['AL', 'AT', 'BA', 'BE', 'BG', 'CH', 'CY', 'CZ', 'DE', 'DK', 'EE', 'ES',
             'FI', 'FR', 'GB', 'GR', 'HR', 'HU', 'IE', 'IT', 'LT', 'LU', 'LV', 'ME',
             'MK', 'MT', 'NL', 'NO', 'PL', 'PT', 'RO', 'RS', 'SE', 'SI', 'SK']

# Hex colour per technology label used in the plots. Labels missing from this
# mapping are drawn in grey ("#cccccc") by the plotting functions.
color_dict = {
    "Bioenergy": "#baa741",
    "Gas": "#e05b09",
    "Hard coal": "#545454",
    "Hydro": "#298c81",
    "Lignite": "#826837",
    "Nuclear": "#ff8c00",
    "Offshore wind": "#6895dd",
    "Onshore wind": "#235ebc",
    "Other fossil": "#000000",
    "Other renewables": "#e3d37d",
    "Solar": "#f9d002"
}

## Load the PyPSA network and the Ember monthly data

In [None]:
# PyPSA network read from `network_path`; its generators_t / links_t /
# storage_units_t time series are aggregated by the process_pypsa_* functions.
n = pypsa.Network(network_path)
# Ember monthly table; filtered and summed per country and technology in
# process_ember_generation_monthly.
ember_monthly = pd.read_csv(ember_monthly_data_path)

## Helper function to detect columns

In [None]:
def detect_column(columns, keywords):
    """Return the first column name that contains one of *keywords*.

    Matching is a case-insensitive substring test. Keywords are tried in the
    order given, so an earlier keyword wins even when a later keyword matches
    a column that comes first in *columns*. Returns ``None`` if nothing
    matches.
    """
    for keyword in keywords:
        hit = next(
            (name for name in columns if keyword.lower() in name.lower()),
            None,
        )
        if hit is not None:
            return hit
    return None

## Process Ember generation data (Yearly CSV)

In [None]:
def process_ember_generation_yearly():
    """Aggregate the Ember yearly CSV to generation per country and technology.

    Reads ``ember_yearly_data_path``, detects the relevant columns by keyword
    (``detect_column``), keeps the rows that belong to the configured
    ``countries`` and ``year`` with unit "TWh" and subcategory "Fuel" (each
    filter is applied only if its column was detected), and sums the values.

    Returns:
        DataFrame indexed by (ISO, Variable) with a single ``Value`` column,
        or ``None`` when the file does not exist or processing fails. Callers
        treat ``None`` as the signal to fall back to the monthly data.
    """
    print(f"Processing yearly CSV: {ember_yearly_data_path}")
    if not os.path.exists(ember_yearly_data_path):
        print(f"Yearly CSV not found at {ember_yearly_data_path}. Falling back to monthly aggregation.")
        return None

    def iso3_to_iso2(iso3):
        """Map an alpha-3 country code to alpha-2; None if it cannot be resolved."""
        try:
            return pycountry.countries.get(alpha_3=iso3).alpha_2
        except (AttributeError, LookupError, TypeError):
            # Unknown code (the lookup yields None) or a non-string cell such
            # as NaN. Deliberately narrower than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            return None

    try:
        df = pd.read_csv(ember_yearly_data_path)
        print(f"Yearly CSV loaded. Shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")

        iso_col = detect_column(df.columns, ['iso 3 code', 'iso3', 'country code'])
        variable_col = detect_column(df.columns, ['variable', 'fuel', 'technology'])
        value_col = detect_column(df.columns, ['value', 'generation', 'amount'])
        unit_col = detect_column(df.columns, ['unit'])
        subcategory_col = detect_column(df.columns, ['subcategory', 'category'])
        continent_col = detect_column(df.columns, ['continent', 'region'])
        year_col = detect_column(df.columns, ['year', 'date'])

        print(f"Detected columns: ISO={iso_col}, Variable={variable_col}, Value={value_col}, "
              f"Unit={unit_col}, Subcategory={subcategory_col}, Continent={continent_col}, Year={year_col}")

        if iso_col:
            df["ISO"] = df[iso_col].apply(iso3_to_iso2)
        else:
            # If no "ISO" column exists either, the lookup below raises
            # KeyError, which the outer handler reports before returning None.
            print("Warning: ISO column not found. Assuming 'ISO' column exists.")

        filters = [df["ISO"].isin(countries)]
        if continent_col:
            filters.append(df[continent_col] == "Europe")
        if unit_col:
            filters.append(df[unit_col] == "TWh")
        if subcategory_col:
            filters.append(df[subcategory_col] == "Fuel")
        if year_col:
            # Prefix match so that both "2023" and "2023-01-01" style values pass
            filters.append(df[year_col].astype(str).str.startswith(str(year)))

        df = df[np.logical_and.reduce(filters)]
        print(f"After filtering: Shape {df.shape}")

        # Only rename columns that were actually detected (never a None key)
        rename_map = {}
        if variable_col:
            rename_map[variable_col] = "Variable"
        if value_col:
            rename_map[value_col] = "Value"
        wanted = ["ISO", variable_col or "Variable", value_col or "Value"]
        df = df[[col for col in wanted if col in df.columns]].rename(columns=rename_map)

        df = df.groupby(["ISO", "Variable"], as_index=False)["Value"].sum()
        print(f"Processed yearly data sample:\n{df.head().to_string()}")
        return df.set_index(["ISO", "Variable"])
    except Exception as e:
        # Best effort by design: any failure is reported and signalled as None
        print(f"Error processing yearly CSV: {e}")
        return None

# Evaluated when the cell runs. The result is None if the yearly CSV is missing
# or could not be processed.
ember_generation_yearly = process_ember_generation_yearly()

## Process Ember generation data (Monthly CSV) 

In [None]:
def process_ember_generation_monthly():
    """Sum the Ember monthly data to generation per country and technology.

    Works on the already loaded ``ember_monthly`` frame: keeps European rows
    of the configured ``countries`` whose ``Date`` starts with ``year``, with
    unit "TWh" and subcategory "Fuel", then sums ``Value`` over the months.

    Returns:
        DataFrame indexed by (ISO, Variable) with a single ``Value`` column.
    """
    print(f"Processing monthly CSV: {ember_monthly_data_path}")
    df = ember_monthly[ember_monthly["Continent"] == "Europe"].copy()
    print(f"Monthly CSV loaded. Shape: {df.shape}")

    def iso3_to_iso2(iso3):
        """Map an alpha-3 country code to alpha-2; None if it cannot be resolved."""
        try:
            return pycountry.countries.get(alpha_3=iso3).alpha_2
        except (AttributeError, LookupError, TypeError):
            # Unknown code (the lookup yields None) or a non-string cell such
            # as NaN. Deliberately narrower than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            return None

    df["ISO"] = df["ISO 3 code"].apply(iso3_to_iso2)
    keep = (
        df["ISO"].isin(countries)
        # na=False: rows without a date are dropped instead of breaking the mask
        & df["Date"].str.startswith(str(year), na=False)
        & (df["Unit"] == "TWh")
        & (df["Subcategory"] == "Fuel")
    )
    df = df.loc[keep, ["ISO", "Date", "Variable", "Value", "Unit"]]

    df_yearly = df.groupby(["ISO", "Variable"], as_index=False)["Value"].sum()
    # The unit column only appears in the printed sample; it is dropped from
    # the returned frame.
    df_yearly["Unit"] = "TWh"
    print(f"Processed monthly aggregated data sample:\n{df_yearly.head().to_string()}")
    return df_yearly.set_index(["ISO", "Variable"]).drop(["Unit"], axis=1)

# Evaluated when the cell runs; also passed to compare_ember_processing later.
ember_generation_monthly = process_ember_generation_monthly()

## Wrapper function to choose processing method

In [None]:
def process_ember_generation(use_yearly=False):
    """Return Ember generation per (ISO, Variable) from the chosen source.

    With ``use_yearly`` set, the yearly CSV is tried first; if that yields
    ``None`` a notice is printed and the monthly aggregation is used instead.
    Without it, the monthly aggregation is used directly.
    """
    yearly = process_ember_generation_yearly() if use_yearly else None
    if yearly is not None:
        return yearly
    if use_yearly:
        print("Falling back to monthly aggregation due to yearly CSV issues.")
    return process_ember_generation_monthly()

# Ember reference used by the plots and deviation tables; use_yearly=False
# selects the monthly aggregation.
ember_generation = process_ember_generation(use_yearly=False)

## Process PyPSA-Eur model generation data

In [None]:
def process_pypsa_generation():
    """Aggregate PyPSA generator and storage-unit dispatch per country.

    Dispatch series of ``n.generators_t.p`` and ``n.storage_units_t.p`` are
    weighted with the snapshot weightings, summed over all snapshots, divided
    by 1e6 and grouped by country (first two characters of the bus name) and
    by Ember technology label. Carriers without a mapping keep their own name.

    Returns:
        DataFrame indexed by (ISO, Variable) with a single ``Value`` column.
    """
    carriers_by_variable = {
        "Bioenergy": ("biomass", "Bioenergy"),
        "Gas": ("gas", "Gas", "CCGT", "OCGT"),
        "Hard coal": ("coal", "Hard coal"),
        "Lignite": ("lignite", "Lignite"),
        "Hydro": ("hydro", "Hydro", "PHS", "ror"),
        "Nuclear": ("Nuclear", "nuclear"),
        "Offshore wind": ("offwind-ac", "offwind-dc", "offwind-float", "Offshore wind"),
        "Onshore wind": ("onwind", "Onshore wind"),
        "Other fossil": ("oil", "Other fossil"),
        "Other renewables": ("geothermal", "Other renewables"),
        "Solar": ("solar", "solar-hsat", "Solar"),
    }
    pypsa_to_ember = {
        carrier: variable
        for variable, group in carriers_by_variable.items()
        for carrier in group
    }

    def _energy_per_country(static, dispatch, weights, id_name):
        """Sum weighted dispatch per (country, Ember label) for one component."""
        meta = static[["bus", "carrier"]].copy()
        meta.loc[:, "country"] = meta["bus"].str[:2]
        energy = dispatch.T.multiply(weights).T.sum(axis=0) / 1e6  # MWh to TWh
        energy.index.name = id_name
        # The unnamed series becomes column 0 after reset_index
        energy = energy.reset_index().rename(columns={0: "Value"})
        energy = energy.merge(meta, left_on=id_name, right_index=True)
        by_carrier = energy.groupby(["country", "carrier"], as_index=False)["Value"].sum()
        by_carrier["Ember_Variable"] = (
            by_carrier["carrier"].map(pypsa_to_ember).fillna(by_carrier["carrier"])
        )
        by_variable = by_carrier.groupby(["country", "Ember_Variable"], as_index=False)["Value"].sum()
        return by_variable.rename(columns={"country": "ISO", "Ember_Variable": "Variable"})

    from_generators = _energy_per_country(
        n.generators, n.generators_t.p, n.snapshot_weightings.generators, "generator"
    )
    from_storage = _energy_per_country(
        n.storage_units, n.storage_units_t.p, n.snapshot_weightings.stores, "storage_unit"
    )

    combined = pd.concat([from_generators, from_storage], ignore_index=True)
    combined = combined.groupby(["ISO", "Variable"], as_index=False)["Value"].sum()
    print(f"PyPSA generation sample:\n{combined.head().to_string()}")
    return combined.set_index(["ISO", "Variable"])

# NOTE(review): the comparison cells visible in this notebook use
# pypsa_generation_sector rather than this table — confirm it is still needed.
pypsa_generation = process_pypsa_generation()

## Process PyPSA-Eur generation data (sector-coupled network)

In [None]:
def process_pypsa_generation_sector(vres_carriers=None, conventional_carriers=None):
    """Aggregate electricity generation of a sector-coupled PyPSA network.

    Three contributions are summed per country and Ember technology label:

    1. generators whose carrier is in ``vres_carriers`` (``generators_t.p``);
    2. links that start at a bus hosting a generator with a carrier in
       ``conventional_carriers`` and end at an AC bus. Their ``links_t.p1`` is
       negated so that power delivered to ``bus1`` counts as positive
       generation; the country is taken from ``bus1``;
    3. storage units (``storage_units_t.p``, summed as-is).

    Values are weighted with the snapshot weightings, summed over all
    snapshots and divided by 1e6. The country is the first two characters of
    the bus name. Carriers without a mapping keep their own name.

    Args:
        vres_carriers: generator carriers counted directly; defaults to the
            wind, solar and run-of-river carriers listed below.
        conventional_carriers: fuel carriers whose generator buses feed the
            power-plant links; defaults to the list below.

    Returns:
        DataFrame indexed by (ISO, Variable) with a single ``Value`` column.
    """
    if vres_carriers is None:
        vres_carriers = ['offwind-dc', 'offwind-ac', 'solar', 'solar-hsat',
                         'offwind-float', 'onwind', 'ror', 'solar rooftop']
    if conventional_carriers is None:
        conventional_carriers = ['gas', 'coal', 'uranium', 'lignite', 'biomass', 'oil',
                                 'solid biomass', 'unsustainable solid biomass']

    pypsa_to_ember = {
        "Bioenergy": "Bioenergy", "urban central solid biomass CHP": "Bioenergy",
        "urban central solid biomass CHP CC": "Bioenergy",
        "gas": "Gas", "Gas": "Gas", "CCGT": "Gas", "OCGT": "Gas",
        "urban central gas CHP": "Gas", "urban central gas CHP CC": "Gas",
        "coal": "Hard coal", "Hard coal": "Hard coal", "urban central coal CHP": "Hard coal",
        "lignite": "Lignite", "Lignite": "Lignite", "urban central lignite CHP": "Lignite",
        "hydro": "Hydro", "Hydro": "Hydro", "PHS": "Hydro", "ror": "Hydro",
        "Nuclear": "Nuclear", "nuclear": "Nuclear",
        "offwind-ac": "Offshore wind", "offwind-dc": "Offshore wind",
        "offwind-float": "Offshore wind", "Offshore wind": "Offshore wind",
        "onwind": "Onshore wind", "Onshore wind": "Onshore wind",
        "oil": "Other fossil", "Other fossil": "Other fossil",
        "geothermal": "Other renewables", "Other renewables": "Other renewables",
        "solar": "Solar", "solar-hsat": "Solar", "Solar": "Solar", "solar rooftop": "Solar",
    }

    def _energy_per_country(static, bus_col, dispatch, weights, id_name):
        """Sum weighted dispatch per (country, Ember label) for one component."""
        meta = static[[bus_col, "carrier"]].copy()
        meta["country"] = meta[bus_col].str[:2]
        energy = dispatch.multiply(weights, axis=0).sum(axis=0) / 1e6  # MWh to TWh
        energy = energy.rename("Value").rename_axis(id_name).reset_index()
        energy = energy.merge(meta, left_on=id_name, right_index=True)
        by_carrier = energy.groupby(["country", "carrier"], as_index=False)["Value"].sum()
        by_carrier["Ember_Variable"] = (
            by_carrier["carrier"].map(pypsa_to_ember).fillna(by_carrier["carrier"])
        )
        by_variable = by_carrier.groupby(["country", "Ember_Variable"], as_index=False)["Value"].sum()
        return by_variable.rename(columns={"country": "ISO", "Ember_Variable": "Variable"})

    weights = n.snapshot_weightings.generators

    # 1) variable renewables modelled as generators
    vres = n.generators.index[n.generators["carrier"].isin(vres_carriers)]
    from_vres = _energy_per_country(
        n.generators, "bus", n.generators_t.p[vres], weights, "generator"
    )

    # 2) thermal generation and biomass, modelled as links from a fuel bus to
    #    an AC bus
    fuel_buses = n.generators.loc[
        n.generators["carrier"].isin(conventional_carriers), "bus"
    ].unique()
    ac_buses = n.buses.index[n.buses["carrier"] == "AC"]
    gen_links = n.links.index[
        n.links["bus0"].isin(fuel_buses) & n.links["bus1"].isin(ac_buses)
    ]
    from_links = _energy_per_country(
        n.links, "bus1", -n.links_t.p1[gen_links], weights, "links"
    )

    # 3) storage units
    from_storage = _energy_per_country(
        n.storage_units, "bus", n.storage_units_t.p,
        n.snapshot_weightings.stores, "storage_unit",
    )

    gen_and_sto = pd.concat([from_vres, from_links, from_storage], ignore_index=True)
    gen_and_sto = gen_and_sto.groupby(["ISO", "Variable"], as_index=False)["Value"].sum()
    print(f"PyPSA generation sample:\n{gen_and_sto.head().to_string()}")
    return gen_and_sto.set_index(["ISO", "Variable"])

# PyPSA-side table that is compared against Ember in the plots and deviation
# tables of the following cells.
pypsa_generation_sector = process_pypsa_generation_sector()

## Compare Ember processing methods

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

def compare_ember_processing(df_yearly, df_monthly, country_iso="DE", year=None):
    """Plot yearly-CSV vs. monthly-aggregated Ember generation for one country.

    Args:
        df_yearly: frame with a ``Value`` column, either indexed by
            (ISO, Variable) or carrying ``ISO`` / ``Variable`` as columns;
            ``None`` is treated as "no data".
        df_monthly: same layout, built from the monthly data.
        country_iso: two-letter code of the country to compare.
        year: used for the plot title and the output folder; when ``None``
            the title reads "Unknown Year" and the folder is
            ``results/validation_unknown``.

    The figure is written to
    ``results/validation_<year>/plots/ember_comparison_<iso>.png``. Nothing is
    plotted when neither frame has positive values for the country.
    """
    print(f"Comparing Ember processing for {country_iso}")

    def get_country_data(df, iso):
        """Return the positive generation per technology for *iso*."""
        no_data = pd.Series(dtype=float)
        if df is None:
            return no_data
        if not isinstance(df.index, pd.MultiIndex):
            # Accept the flat layout as well by moving the keys into the index
            if not {"ISO", "Variable"}.issubset(df.columns):
                print(f"Cannot locate ISO/Variable in frame with columns {df.columns.tolist()}.")
                return no_data
            df = df.set_index(["ISO", "Variable"])
        if "Value" not in df.columns:
            print(f"No 'Value' column in frame with columns {df.columns.tolist()}.")
            return no_data

        in_country = df.index.get_level_values(0) == iso
        if not in_country.any():
            print(f"ISO '{iso}' not found in index.")
            return no_data

        # Series indexed by the plain technology name, so tick labels and
        # lookups use e.g. "Gas" rather than a ("Value", "Gas") tuple
        per_tech = df.loc[in_country, "Value"].groupby(level=1).sum()
        return per_tech[per_tech > 0]

    yearly_data = get_country_data(df_yearly, country_iso)
    monthly_data = get_country_data(df_monthly, country_iso)

    if yearly_data.empty and monthly_data.empty:
        print(f"No data found for {country_iso} in either dataframe.")
        return

    # Sorted so that the bar order is identical from run to run
    techs = sorted(set(yearly_data.index) | set(monthly_data.index))
    yearly_values = [yearly_data.get(tech, 0) for tech in techs]
    monthly_values = [monthly_data.get(tech, 0) for tech in techs]

    print(f"Yearly data for {country_iso}:\n{yearly_data.to_string()}")
    print(f"Monthly aggregated data for {country_iso}:\n{monthly_data.to_string()}")

    # Note: this switches the global matplotlib style, so figures created
    # afterwards use it as well.
    plt.style.use('ggplot')
    fig, ax = plt.subplots(figsize=(10, 6))
    bar_width = 0.35
    x = np.arange(len(techs))
    ax.bar(x - bar_width/2, yearly_values, bar_width, label='Yearly CSV', color='skyblue')
    ax.bar(x + bar_width/2, monthly_values, bar_width, label='Monthly Aggregated', color='salmon')
    ax.set_xlabel('Technology')
    ax.set_ylabel('Generation (TWh)')
    title_year = f"{year}" if year is not None else "Unknown Year"
    ax.set_title(f'Ember Generation Comparison for {country_iso} (TWh, {title_year})')
    ax.set_xticks(x)
    ax.set_xticklabels(techs, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, axis='y')
    plt.tight_layout()

    output_year = year if year is not None else "unknown"
    # One file per country, so comparing another country does not overwrite
    # the previous figure
    output_path = f"results/validation_{output_year}/plots/ember_comparison_{country_iso.lower()}.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()


# Pass the configured year explicitly: the function's own default is None,
# which labels the figure "Unknown Year" and writes it to
# results/validation_unknown/.
compare_ember_processing(ember_generation_yearly, ember_generation_monthly, country_iso="DE", year=year)

## Plotting functions

In [None]:
def plot_country_generation_mix_donut_subplots(ember_generation_yearly, country_isos, color_dict=None):
    """Draw one donut chart of the generation mix per country.

    Args:
        ember_generation_yearly: frame indexed by (ISO, Variable) with a
            single value column.
        country_isos: countries to plot, two per row. Up to six countries use
            a 3x2 grid; further rows are added for longer lists.
        color_dict: optional technology -> colour mapping; technologies that
            are not listed are drawn in grey. If empty or ``None``, colours
            are taken from the Set2 palette, one fixed colour per technology.

    Each wedge is labelled with its rounded value. The legend lists every
    technology drawn in any of the pies. The figure is saved to
    ``results/validation_<year>/plots/donut_subplots.png`` and shown.
    """
    from matplotlib.patches import Patch  # only needed for the legend proxies

    n_cols = 2
    # Ceiling division; never fewer than three rows
    n_rows = max(3, -(-len(country_isos) // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(7, 10 * n_rows / 3))
    axes = axes.flatten()

    pivot_df = ember_generation_yearly.unstack(level=1).fillna(0)
    pivot_df.columns = pivot_df.columns.get_level_values(1)

    # One colour per technology for the whole figure, so that the same
    # technology looks the same in every pie and in the legend
    if color_dict:
        tech_colors = {tech: color_dict.get(tech, "#cccccc") for tech in pivot_df.columns}
    else:
        palette = plt.cm.Set2.colors
        tech_colors = {tech: palette[i % len(palette)] for i, tech in enumerate(pivot_df.columns)}

    plotted_techs = set()

    for idx, country_iso in enumerate(country_isos):
        ax = axes[idx]
        if country_iso not in pivot_df.index:
            ax.axis('off')
            ax.set_title(f"{country_iso} not found")
            continue

        data = pivot_df.loc[country_iso]
        data = data[data > 0]
        plotted_techs.update(data.index)

        wedges, _ = ax.pie(
            data.values,
            labels=None,
            startangle=90,
            colors=[tech_colors[tech] for tech in data.index],
            wedgeprops=dict(width=0.7),
            autopct=None
        )

        # Write the rounded value at the angular centre of each wedge
        for wedge, value in zip(wedges, data.values):
            angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
            ax.text(0.7 * np.cos(angle), 0.7 * np.sin(angle), f"{int(round(value))}",
                    ha='center', va='center', fontsize=10, color='white', fontweight='bold')

        centre_circle = plt.Circle((0, 0), 0.25, color='white', fc='white', linewidth=0)
        ax.add_artist(centre_circle)
        ax.text(0, 0, country_iso, ha='center', va='center', fontsize=18, fontweight='bold')

    # Hide grid cells that did not receive a country
    for spare in axes[len(country_isos):]:
        spare.axis('off')

    legend_labels = [tech for tech in pivot_df.columns if tech in plotted_techs]
    if legend_labels:
        fig.legend(
            [Patch(facecolor=tech_colors[tech]) for tech in legend_labels],
            legend_labels,
            loc='upper center',
            bbox_to_anchor=(0.5, 1.05),
            ncol=4,
            fontsize=10,
            frameon=True
        )

    fig.suptitle(f"Yearly Electricity Generation by Technology\n(TWh, Ember {year})", fontsize=16, weight='bold', y=1.12)
    plt.tight_layout(rect=[0, 0, 1, 0.98])

    output_path = f"results/validation_{year}/plots/donut_subplots.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.show()

# Ember generation mix for six countries, coloured with the shared colour map.
plot_country_generation_mix_donut_subplots(
    ember_generation, 
    ["DE", "NL", "IT", "PL", "CZ", "GR"], 
    color_dict=color_dict
)

In [None]:
def plot_country_generation_mix_donut_comparison(df1, df2, country_isos, color_dict=None, df1_label="Ember", df2_label="PyPSA"):
    """Draw the generation mix of two datasets side by side, one row per country.

    Args:
        df1, df2: frames indexed by (ISO, Variable) with a single value
            column; a frame with a plain index is used as given (countries as
            rows, technologies as columns).
        country_isos: countries to plot; nothing is drawn for an empty list.
        color_dict: optional technology -> colour mapping; technologies that
            are not listed are drawn in grey. If empty or ``None``, colours
            are taken from the Set2 palette, one fixed colour per technology.
        df1_label, df2_label: column headings shown above the first row.

    Each wedge is labelled with its rounded value and the donut centre shows
    the rounded total. The legend lists every technology drawn in any pie.
    The figure is saved to
    ``results/validation_<year>/plots/donut_comparison.png`` and shown.
    """
    from matplotlib.patches import Patch  # only needed for the legend proxies

    n_countries = len(country_isos)
    if n_countries == 0:
        print("No countries given; nothing to plot.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single country
    fig, axes = plt.subplots(n_countries, 2, figsize=(6, 3 * n_countries), squeeze=False)
    plt.subplots_adjust(wspace=0.05)

    def pivot(df):
        """Return a country x technology table."""
        if isinstance(df.index, pd.MultiIndex):
            p = df.unstack(level=1).fillna(0)
            p.columns = p.columns.get_level_values(1)
        else:
            p = df.copy()
        return p

    pivot1 = pivot(df1)
    pivot2 = pivot(df2)

    # One colour per technology for the whole figure, so that both datasets
    # and the legend agree
    all_techs = list(dict.fromkeys([*pivot1.columns, *pivot2.columns]))
    if color_dict:
        tech_colors = {tech: color_dict.get(tech, "#cccccc") for tech in all_techs}
    else:
        palette = plt.cm.Set2.colors
        tech_colors = {tech: palette[i % len(palette)] for i, tech in enumerate(all_techs)}

    plotted_techs = set()

    for idx, country_iso in enumerate(country_isos):
        for j, (pivot_df, label) in enumerate(zip([pivot1, pivot2], [df1_label, df2_label])):
            ax = axes[idx, j]
            if country_iso not in pivot_df.index:
                ax.axis('off')
                ax.set_title(f"{country_iso} not found")
                continue

            data = pivot_df.loc[country_iso]
            data = data[data > 0]
            plotted_techs.update(data.index)

            wedges, _ = ax.pie(
                data.values,
                labels=None,
                startangle=90,
                colors=[tech_colors[tech] for tech in data.index],
                wedgeprops=dict(width=0.7),
                autopct=None
            )

            # Write the rounded value at the angular centre of each wedge
            for wedge, value in zip(wedges, data.values):
                angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
                ax.text(0.7 * np.cos(angle), 0.7 * np.sin(angle), f"{int(round(value))}",
                        ha='center', va='center', fontsize=10, color='white', fontweight='bold')

            centre_circle = plt.Circle((0, 0), 0.25, color='white', fc='white', linewidth=0)
            ax.add_artist(centre_circle)
            total = int(round(data.sum()))
            ax.text(0, 0, f"{total}", ha='center', va='center', fontsize=14, fontweight='bold')
            # The dataset label is only shown above the first row
            ax.set_title(f"{label}\n{country_iso}" if idx == 0 else country_iso, fontsize=14, fontweight='bold')

    legend_labels = [tech for tech in all_techs if tech in plotted_techs]
    if legend_labels:
        fig.legend(
            [Patch(facecolor=tech_colors[tech]) for tech in legend_labels],
            legend_labels,
            loc='upper center',
            bbox_to_anchor=(0.5, 1.02),
            ncol=4,
            fontsize=10,
            frameon=True
        )
    fig.suptitle("Yearly Electricity Generation [TWh]", fontsize=16, weight='bold', y=1.05)
    plt.tight_layout(rect=[0, 0, 1, 0.98])

    output_path = f"results/validation_{year}/plots/donut_comparison.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.show()

# Left column: Ember reference; right column: PyPSA sector-coupled result.
plot_country_generation_mix_donut_comparison(
    ember_generation, 
    pypsa_generation_sector, 
    ["DE", "NL", "IT", "PL", "CZ", "GR"], 
    color_dict=color_dict, 
    df1_label="Ember", 
    df2_label="PyPSA"
)

In [None]:
# Inputs (both indexed by (ISO, Variable) with a single "Value" column):
# - ember_generation         (from process_ember_generation)
# - pypsa_generation_sector  (from process_pypsa_generation_sector)
# Countries kept in the exported tables and in the plot
focus_countries = ["CZ", "DE", "GR", "IT", "NL", "PL"]
# Only (country, technology) pairs present in BOTH datasets are compared;
# pairs that exist in just one of them are dropped here.
common_index = ember_generation.index.intersection(pypsa_generation_sector.index)
ember_aligned = ember_generation.loc[common_index].copy()
pypsa_aligned = pypsa_generation_sector.loc[common_index].copy()
# Absolute difference PyPSA - Ember (positive if PyPSA overestimates)
diff_df = pd.DataFrame(index=common_index)
diff_df['diff'] = pypsa_aligned['Value'] - ember_aligned['Value']
# Ember totals used as normalisation bases. "Europe" here means the sum over
# all rows of ember_generation, i.e. over the configured country list.
total_europe_ember = ember_generation['Value'].sum()
country_totals_ember = ember_generation.groupby('ISO')['Value'].sum()
# Ember generation per technology and country (denominator of the
# per-technology deviation)
tech_country_ember = ember_generation.copy()
# Deviation relative to the same technology in the same country:
# (PyPSA - Ember) / Ember_tech_country
# Negative values mean PyPSA underestimates the Ember value.
diff_df['rel_dev_tech_pct'] = (diff_df['diff'] / tech_country_ember.loc[common_index, 'Value']) * 100
# Same quantity without sign ("how far off", regardless of direction)
diff_df['abs_rel_error_pct'] = np.abs(diff_df['diff'] / tech_country_ember.loc[common_index, 'Value']) * 100
# Deviation relative to the country's total Ember generation:
# (PyPSA - Ember) / total_country_ember
# The per-country totals are mapped onto the (ISO, Variable) index first.
diff_df['country_total'] = diff_df.index.get_level_values('ISO').map(country_totals_ember)
diff_df['rel_dev_country_pct'] = (diff_df['diff'] / diff_df['country_total']) * 100
# Deviation relative to the total Ember generation of all configured countries:
# (PyPSA - Ember) / total_europe_ember
diff_df['rel_dev_europe_pct'] = (diff_df['diff'] / total_europe_ember) * 100
# Drop the helper column, then discard rows where a division by zero produced
# +/-inf or NaN (any NaN in a row removes the whole row).
diff_df = diff_df.drop(columns=['country_total'])
diff_df = diff_df.replace([np.inf, -np.inf], np.nan).dropna()
# Restrict the per-technology table to the focus countries
focus_index = diff_df.index.get_level_values('ISO').isin(focus_countries)
diff_df = diff_df.loc[focus_index]
# Total generation deviation by country (exported to CSV and printed below;
# not used by the plot or the summary table)
country_totals_pypsa = pypsa_generation_sector.groupby('ISO')['Value'].sum()
common_countries = country_totals_ember.index.intersection(country_totals_pypsa.index)
country_diff = pd.DataFrame(index=common_countries)
country_diff['diff_total'] = country_totals_pypsa.loc[common_countries] - country_totals_ember.loc[common_countries]
# (PyPSA_total - Ember_total) / Ember_total_country
country_diff['rel_dev_country_total_pct'] = (country_diff['diff_total'] / country_totals_ember.loc[common_countries]) * 100
# (PyPSA_total - Ember_total) / total_europe_ember
country_diff['rel_dev_europe_total_pct'] = (country_diff['diff_total'] / total_europe_ember) * 100
country_diff = country_diff.loc[country_diff.index.isin(focus_countries)]
# Export both tables and print a short sample of each
output_dir = f"results/validation_{year}/deviations/"
os.makedirs(output_dir, exist_ok=True)
diff_df.to_csv(os.path.join(output_dir, "per_tech_deviations_focus.csv"))
country_diff.to_csv(os.path.join(output_dir, "per_country_total_deviations_focus.csv"))
print("Per-technology deviations sample (focus countries):")
print(diff_df.head().to_string())
print("Per-country total deviations sample (focus countries):")
print(country_diff.head().to_string())
def plot_per_country_tech_deviations(df, focus_countries, color_dict, year, output_dir):
    """Bar chart of the relative deviation per technology, one panel per country.

    Args:
        df: frame with an ``ISO`` index level and a ``rel_dev_tech_pct`` column.
        focus_countries: countries to plot, two panels per row.
        color_dict: technology -> colour mapping; unknown technologies are grey.
        year: shown in the figure title.
        output_dir: the figure is saved below this folder as
            ``plots/rel_dev_per_tech_per_country_focus.png``.

    Panels of countries without usable data show a "No data" note instead.
    """
    def _mark_no_data(axis, iso):
        """Replace the panel content with a 'No data' note."""
        axis.text(0.5, 0.5, f"No data for {iso}", ha='center', va='center')
        axis.axis('off')

    plt.style.use('ggplot')
    n_countries = len(focus_countries)
    n_rows = (n_countries + 1) // 2  # two panels per row
    fig, grid = plt.subplots(n_rows, 2, figsize=(14, 4 * n_rows), sharey=True)
    panels = grid.flatten()

    for panel, iso in zip(panels, focus_countries):
        try:
            # Slice the rows of this country; raises KeyError if it is absent
            per_tech = df.xs(iso, level='ISO')
            has_values = (not per_tech.empty) and per_tech['rel_dev_tech_pct'].notna().any()
            if not has_values:
                _mark_no_data(panel, iso)
                continue

            # Plot only technologies with a defined deviation
            keep = per_tech['rel_dev_tech_pct'].notna()
            labels = per_tech.index[keep]
            heights = per_tech['rel_dev_tech_pct'][keep]
            positions = np.arange(len(heights))
            bar_colors = [color_dict.get(label, '#cccccc') for label in labels]
            panel.bar(positions, heights, color=bar_colors, alpha=0.7)
            panel.set_title(iso)
            panel.set_ylabel('Relative Deviation [%]')
            panel.set_xticks(positions)
            panel.set_xticklabels(labels, rotation=45, ha='right')
            panel.grid(True, axis='y')
        except KeyError:
            _mark_no_data(panel, iso)

    # Hide the spare panel of an odd-sized grid
    for spare in panels[n_countries:]:
        spare.axis('off')

    fig.suptitle(f"Relative Deviation per Technology per Focus Country ({year})", y=1.02, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.98])

    target = os.path.join(output_dir, "plots/rel_dev_per_tech_per_country_focus.png")
    os.makedirs(os.path.dirname(target), exist_ok=True)
    plt.savefig(target, bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()
# Plot the per-technology deviations of the focus countries; the figure is
# stored in the "plots" subfolder of output_dir.
plot_per_country_tech_deviations(diff_df, focus_countries, color_dict, year, output_dir)
# ---------------------------
# Summary Table for Key Carriers
# ---------------------------
def generate_summary_table(df, key_carriers, focus_countries, year):
    """Build and print a carrier x country table of relative deviations.

    Args:
        df: frame indexed by (ISO, Variable) with a ``rel_dev_tech_pct`` column.
        key_carriers: carriers shown as rows, in the order given.
        focus_countries: countries shown as columns, in the order given.
        year: only used in the printed heading.

    Returns:
        DataFrame of strings. Each cell holds the deviation rounded to zero
        decimals followed by "%" ("-%" when there is no value); absolute
        deviations above 50 are wrapped in ``**`` (markdown bold). An extra
        column counts, per carrier, the countries above 50 out of those with
        data; an extra bottom row does the same per country.
    """
    count_col = 'n. countries with abs dev > 50%'
    count_row = 'n. carriers with abs dev > 50%'

    full_multi = pd.MultiIndex.from_product([focus_countries, key_carriers], names=['ISO', 'Variable'])
    table_df = (
        df['rel_dev_tech_pct']
        .reindex(full_multi)
        .unstack(level='ISO')
        .round(0)
        # unstack sorts both axes alphabetically; restore the caller's order
        # so that the per-row counts below belong to the rows they sit in
        .reindex(index=key_carriers, columns=focus_countries)
        .rename_axis(index='Variable', columns='ISO')
    )

    def _format_cell(value):
        """Render one deviation; values above 50 in magnitude are bold."""
        if pd.isna(value):
            return '-%'
        text = f"{value}%"
        return f"**{text}**" if abs(value) > 50 else text

    def _share_above_threshold(values):
        """Return '<count above 50>/<count with data>', or '-' without data."""
        n_with_data = int(values.notna().sum())
        if n_with_data == 0:
            return '-'
        return f"{int((values.abs() > 50).sum())}/{n_with_data}"

    # Column-wise Series.map instead of the deprecated DataFrame.applymap
    table_highlight = table_df.apply(lambda column: column.map(_format_cell))

    # Per carrier: countries with abs dev > 50% / countries with data
    table_highlight[count_col] = [_share_above_threshold(row) for _, row in table_df.iterrows()]

    # Per country: carriers with abs dev > 50% / carriers with data.
    # The count column itself gets an empty cell in this row.
    col_counts = [_share_above_threshold(col) for _, col in table_df.items()]
    bottom_df = pd.DataFrame([col_counts + ['']], index=[count_row], columns=table_highlight.columns)

    table_final = pd.concat([table_highlight, bottom_df])

    print(f"\nDeviations from reported ones, {year}, for key carriers and focus countries")
    print("(% difference between TWh in PyPSA-Eur model run and Ember reported TWh) "
          "(negative means model run underestimates reported values), "
          "absolute deviations > 50% highlighted in bold")
    try:
        rendered = table_final.to_markdown()
    except ImportError:
        # to_markdown needs the optional 'tabulate' package
        rendered = table_final.to_string()
    print(rendered)

    return table_final
# Carriers shown as rows of the summary table
key_carriers = ["Gas", "Hard coal", "Lignite", "Nuclear", "Offshore wind", "Onshore wind", "Solar"]
# Generate and print the table
summary_table = generate_summary_table(diff_df, key_carriers, focus_countries, year)