In [None]:
import os
import string
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from ng_lca_toolbox import excel2dfs

# Suppress every warning for the rest of the session. Note that this also hides
# deprecation warnings emitted by the libraries used in the cells below.
warnings.filterwarnings("ignore")

# Date string embedded in the data folder name and in the weekly import file name.
date = "2024-04-16"
path_to_ng_data = Path(f"./data/ng-supply-demand-raw/{date}-Gas_Tracker/")

# Flag checked before the ./figs output directory is created in the Sankey section.
SAVE_FIGS = False

# NG Supply

In [None]:
# https://www.bruegel.org/dataset/european-natural-gas-imports
# Weekly import volumes indexed by week; columns are addressed below as
# "<source>_<year>".
df_imports = (
    pd.read_excel(path_to_ng_data / f"country_data_{date}.xlsx")
    .set_index("week")
    .dropna(axis=0, how="all")
)
numeric_columns = df_imports.select_dtypes(include=["number"]).columns
df_imports[numeric_columns] = df_imports[numeric_columns] / 1000  # transform to BCM

plot_years = ["2024", "2023", "2022", "2021"]
# x position, just right of the last plotted week, where the 2021-vs-2023
# difference is marked on the cumulative panel.
marker_x = 54

for cat in ["EU", "LNG", "Russia", "Norway", "Algeria", "UK", "Azerbaijan"]:
    year_columns = [f"{cat}_{year}" for year in plot_years]
    cumulative = df_imports[year_columns].cumsum()
    # Highest cumulative value reached in each of the two compared years.
    cum_2023 = cumulative[f"{cat}_2023"].max()
    cum_2021 = cumulative[f"{cat}_2021"].max()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: weekly imports, one line per year.
    df_imports[year_columns].plot(
        xlim=(0, 52),
        title=f"{cat} imports",
        grid=True,
        ax=ax1,
    )
    ax1.set_ylabel("BCM")
    ax1.set_xlabel("Week")
    ax1.legend(loc="upper center", ncol=4)

    # Right panel: cumulative imports, highlighting the final values.
    cumulative.plot(
        xlim=(0, 55),
        title=f"Cumulative {cat} imports",
        grid=True,
        ax=ax2,
    )

    # Annotate the difference between the 2023 and 2021 cumulative totals.
    diff = cum_2023 - cum_2021
    ax2.annotate(
        f"{diff:.1f} BCM",
        xy=(marker_x, cum_2023),
        xytext=(56, cum_2023 + cumulative.max().max() * 0.1),
        arrowprops=dict(facecolor="red", arrowstyle="->"),
        color="red",
        fontsize=12,
    )
    # Red bar spanning the two totals. `vlines` takes ymin/ymax in data
    # coordinates; `axvline` would interpret them as 0-1 axes fractions and
    # place the bar outside the visible area.
    ax2.vlines(x=marker_x, ymin=cum_2023, ymax=cum_2021, colors="red", linewidths=2.0)
    # End markers on both totals.
    ax2.plot(marker_x, cum_2023, "rx")
    ax2.plot(marker_x, cum_2021, "rx")

    ax2.set_ylabel("BCM")
    ax2.set_xlabel("Week")
    ax2.get_legend().remove()

    # Replace the two per-axes legends by one shared, figure-level legend.
    handles = ax1.get_legend().legend_handles
    leg = fig.legend(
        handles,
        plot_years,
        ncol=4,
        loc="lower center",
        borderaxespad=0,
        frameon=True,
    )
    ax1.get_legend().remove()

    leg.get_frame().set_linewidth(0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Calendar-year import totals per source (EU aggregate excluded), plus a "total" row.
yearly_totals = {}
for year in (2021, 2022, 2023):
    suffix = f"_{year}"
    year_columns = [col for col in df_imports.columns if col.endswith(suffix)]
    column_sums = df_imports[year_columns].sum().to_dict()

    per_source = {}
    for column, total in column_sums.items():
        if column.startswith("EU_"):
            continue
        per_source[column.split("_")[0]] = total
    yearly_totals[year] = per_source

df_imports_yearly = pd.DataFrame(yearly_totals)
df_imports_yearly.loc["total"] = df_imports_yearly.sum()
df_imports_yearly

In [None]:
# Import totals per source over windows that run from week 25 of one year to
# week 25 of the next, labelled "Jun<year>-Jun<year+1>".
# NOTE(review): `.loc` slices by label and is inclusive on both ends, so if the
# index holds integer week numbers, week 25 is taken from both calendar years
# (53 weeks in total) — confirm this is intended.
df_imports_yearly_mid = {}
for year in range(2021, 2024):
    next_year_columns = [
        col for col in df_imports.columns if col.endswith(f"_{year + 1}")
    ]
    this_year_columns = [col for col in df_imports.columns if col.endswith(f"_{year}")]

    window_sums = df_imports[next_year_columns].loc[1:25].sum().to_dict()
    window_sums.update(df_imports[this_year_columns].loc[25:].sum().to_dict())

    # dict.fromkeys de-duplicates while keeping column order, so the row order
    # of the resulting table is the same on every run (a set is not ordered).
    sources = [
        name
        for name in dict.fromkeys(col.split("_")[0] for col in window_sums)
        if name != "EU"
    ]
    # Match on the exact source name so a source whose name is a prefix of
    # another one cannot absorb the other's columns.
    df_imports_yearly_mid[f"Jun{year}-Jun{year+1}"] = {
        source: sum(
            total for col, total in window_sums.items() if col.split("_")[0] == source
        )
        for source in sources
    }

df_imports_yearly_mid = pd.DataFrame(df_imports_yearly_mid)
df_imports_yearly_mid.loc["total"] = df_imports_yearly_mid.sum()
df_imports_yearly_mid

### LNG

In [None]:
# Monthly LNG data: drop the unnamed leading column, index by month, scale by 1/1000.
df_LNG = (
    pd.read_excel(path_to_ng_data / "LNG plot data 2024-04-09.xlsx")
    .drop(columns="Unnamed: 0")
    .set_index("dates")
)
df_LNG.index = pd.to_datetime(df_LNG.index, format="%m/%Y")
df_LNG = df_LNG / 1000  # transform to BCM

# Calendar-year sums per column, with an extra "total" column across all columns.
df_LNG_year = df_LNG.groupby(df_LNG.index.year).sum()
df_LNG_year["total"] = df_LNG_year.sum(axis=1)
df_LNG_year.loc[2021:2024].T

In [None]:
# Same monthly LNG file as above, loaded again for the June-to-May aggregation.
df_LNG_mid = (
    pd.read_excel(path_to_ng_data / "LNG plot data 2024-04-09.xlsx")
    .drop(columns="Unnamed: 0")
    .set_index("dates")
)
df_LNG_mid.index = pd.to_datetime(df_LNG_mid.index, format="%m/%Y")
df_LNG_mid = df_LNG_mid / 1000  # transform to BCM

# Sum each window from 1 June of `year` through 31 May of `year + 1`.
df_LNG_mid_year = pd.DataFrame(
    {
        f"Jun{year}-Jun{year+1}": df_LNG_mid.loc[
            f"{year}-06-01":f"{year+1}-05-31"
        ].sum()
        for year in range(2021, 2024)
    }
)
df_LNG_mid_year.loc["total"] = df_LNG_mid_year.sum(axis=0)
df_LNG_mid_year

# NG demand

In [None]:
# https://ec.europa.eu/eurostat/databrowser/view/nrg_cb_gasm__custom_10929344/default/table?lang=en
path_to_ng_demand_data = Path("./data/ng-supply-demand-raw/eurostat-demand-ng/")

EU27_LABEL = "European Union - 27 countries (from 2020)"

df_demand = pd.read_excel(
    path_to_ng_demand_data / "2024-04-18-nrg_cb_gasm__custom_10949402_spreadsheet.xlsx",
    skiprows=8,
    sheet_name="Sheet 1",
    index_col=[0, 1],
    parse_dates=True,
)

# clean data: turn every non-numeric placeholder into NaN in a single pass.
# The placeholders are ":", "not available", "provisional" and any cell that
# consists of exactly one lowercase ASCII letter (presumably flag columns).
# np.nan is used as the replacement because the meaning of `value=None` in
# DataFrame.replace has differed between pandas versions.
placeholders = [":", "not available", "provisional", *string.ascii_lowercase]
df_demand = df_demand.replace(placeholders, np.nan)

df_demand = df_demand.dropna(axis=1, how="all").dropna(axis=0, how="all")
df_demand = df_demand.astype(float) / 1_000  # transform to BCM
# change columns type to datetime
df_demand.columns = pd.to_datetime(df_demand.columns)
df_demand.index.names = ["Location", "Category"]

# select only EU27 countries one by one from Location
df_demand = df_demand.loc[
    [
        EU27_LABEL,
        "Austria",
        "Belgium",
        "Bulgaria",
        "Croatia",
        "Cyprus",
        "Czechia",
        "Denmark",
        "Estonia",
        "Finland",
        "France",
        "Germany",
        "Greece",
        "Hungary",
        "Ireland",
        "Italy",
        "Latvia",
        "Lithuania",
        "Luxembourg",
        "Malta",
        "Netherlands",
        "Poland",
        "Portugal",
        "Romania",
        "Slovakia",
        "Slovenia",
        "Spain",
        "Sweden",
    ]
]

# aggregate sum by year (columns). Grouping the transposed frame replaces the
# deprecated `groupby(..., axis=1)` form; the result is computed once and reused.
df_demand_year = df_demand.T.groupby(df_demand.columns.year).sum().T
# EU27 aggregate row as provided in the spreadsheet.
df_demand_year_eu27 = df_demand_year.loc[EU27_LABEL]
# Individual member states only.
df_demand_year_countries = df_demand_year.drop(index=EU27_LABEL)
# EU27 totals recomputed by summing the member states per category.
df_demand_year_eu27_calc = df_demand_year_countries.groupby(
    df_demand_year_countries.index.get_level_values(1)
).sum()
five_year_avg = df_demand_year_eu27_calc.loc[:, 2017:2021].mean(axis=1)
one_year_avg = df_demand_year_eu27_calc.loc[:, 2021:2021].mean(axis=1)

# One bar chart per category (drawn in reverse listing order) of the yearly
# totals summed over the member states, with the 2021 value and the 2017-2021
# mean as horizontal reference lines.
loc = "European Union - 27 countries (from 2020)"
balance_categories = [
    "Indigenous production",
    "Imports",
    "Exports",
    "Stock changes - as defined in MOS GAS",
    "Inland consumption - observed",
]
for cat in balance_categories[::-1]:
    five_year_avg_cat = five_year_avg.loc[cat]
    one_year_avg_cat = one_year_avg.loc[cat]
    yearly_values = df_demand_year_eu27_calc.loc[cat, 2017:2023]

    fig, ax = plt.subplots()
    yearly_values.plot(
        kind="bar",
        title=f"{loc} | {cat}",
        grid=True,
        ax=ax,
    )
    ax.set_ylabel("BCM")
    ax.set_xlabel("Year")

    reference_lines = (
        (one_year_avg_cat, "red", "2021"),
        (five_year_avg_cat, "green", "mean 2017-2021"),
    )
    for level, line_color, line_label in reference_lines:
        ax.axhline(y=level, color=line_color, linestyle="--", label=line_label)

    # double-sided arrow from the 2022 and 2023 bars to the 2021 reference line
    for i_year, year in enumerate([2022, 2023]):
        bar_x = 5 + i_year  # bar positions of 2022 and 2023 within 2017:2023
        year_value = df_demand_year_eu27_calc.loc[cat, year].max()
        diff = year_value - one_year_avg_cat
        ax.annotate(
            "",
            xy=(bar_x, year_value),
            xytext=(bar_x, one_year_avg_cat),
            arrowprops={"arrowstyle": "<->", "color": "magenta"},
            horizontalalignment="center",
        )
        # difference label in a white rounded box, halfway along the arrow
        ax.text(
            5.1 + i_year,
            one_year_avg_cat + diff * 0.5,
            f"{diff:.1f} bcm",
            fontsize=8,
            bbox={"facecolor": "white", "edgecolor": "black", "boxstyle": "round,pad=0.2"},
        )

    plt.tight_layout()
    plt.show()

In [None]:
# Time series of the EU27 aggregate rows, one line per category on shared axes.
fig, ax = plt.subplots(figsize=(15, 5))
categories = [
    "Indigenous production",
    "Imports",
    "Exports",
    "Stock changes - as defined in MOS GAS",
    "Inland consumption - observed",
]

# Select the EU27 rows once and drop the periods that have no data at all.
eu27_monthly = df_demand.loc["European Union - 27 countries (from 2020)"].dropna(
    axis=1, how="all"
)
for category in categories:
    eu27_monthly.loc[category].plot(grid=True, ax=ax, label=category)

# legend outside the plot
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ax.set_ylabel("BCM")
ax.set_xlabel("Year")
plt.tight_layout()

In [None]:
# Yearly category totals up to 2023, with a colour gradient computed per row.
demand_through_2023 = df_demand_year_eu27_calc.loc[:, :2023]
demand_through_2023.style.background_gradient(cmap="RdYlGn", axis=1)

In [None]:
# Years as rows, categories as columns, plus a derived "Net imports" column
# (Imports minus Exports).
df_demand_year_eu27_calc_T = df_demand_year_eu27_calc.loc[:, :2023].T.assign(
    **{"Net imports": lambda frame: frame["Imports"] - frame["Exports"]}
)
balance_columns = [
    "Net imports",
    "Indigenous production",
    "Inland consumption - observed",
    "Stock changes - as defined in MOS GAS",
]
df_demand_year_eu27_calc_T[balance_columns]

In [None]:
# Difference between the EU27 aggregate rows of the spreadsheet and the totals
# obtained by summing the member states, shown with a per-row colour gradient.
eu27_reported_minus_summed = df_demand_year_eu27 - df_demand_year_eu27_calc
eu27_reported_minus_summed.style.background_gradient(cmap="RdYlGn", axis=1)

# Sankey

## Setup calculations

In [None]:
# Load the supply and demand tables used by the Sankey section. Note that
# `df_demand` is rebound here and no longer refers to the frame built in the
# "NG demand" section.
df_supply, df_demand = excel2dfs()
# Keep only the rows that have a non-null index label.
df_demand = df_demand.loc[df_demand.index.notna()]

# Selects the two-line node label format and the staggered x positions below.
LABEL_UNDERNEATH = True

In [None]:
# Node labels for the Sankey diagram. The lists are concatenated, in the order
# supply -> alternatives -> processes -> products -> dummy, to number the nodes.

# Energy carriers / levers that feed the processes; "NG" is the first entry and
# is positioned separately from the rest further below.
alternative_list = [
    "NG",
    "Coal",
    "Oil",
    "Nuclear",
    "Hydro",
    "Solar",
    "Wind",
    "Biomass",
    "Lignite",
    "Electricity",
    "Efficiency",
    "Savings",
    "Weather",
]

# Natural-gas sources. "NG storage " carries a trailing space so that it is a
# different label from the "NG storage" entry in `processes`.
supply_list = [
    "Russia",
    "LNG",
    "Norway",
    "Algeria",
    "Other sources",
    "Endogenous production",
    "NG storage ",
]

# Consuming processes.
processes = [
    "Electricity production",
    "CHP",
    "Industrial heating",
    "Individual heating",
    "Other usage",
    "NG storage",
]

# Final products. "Electricity " carries a trailing space so that it is a
# different label from the "Electricity" entry in `alternative_list`.
# products = ["Electricity ", "Steam/furnace", "Individual heating ", "Loss"]
products = ["Electricity ", "Steam/furnace", "Individual heat", "Loss"]
# Extra pair of nodes connected by one fixed-size link in every scenario.
dummy = ["Dummy1", "Dummy2"]


def generate_scenario_dict(
    df_supply, df_demand, supply_sce, demand_sce, bcm_to_twh=9.77
):
    """Build the Sankey flow dictionaries for one supply/demand scenario pair.

    Parameters
    ----------
    df_supply : pandas.DataFrame
        Indexed by supply code (``NG_RU``, ``NG_NO``, ``NG_DZ``, ``NG_DE``,
        ``NG_NL``, ``NG_RO``, ``ENDOG1``..``ENDOG3``, ``TOT_NG`` and codes
        containing ``LNG``); must have a ``f"{supply_sce}_amount"`` column.
    df_demand : pandas.DataFrame
        Indexed by demand code; must have ``f"{demand_sce}_amount"`` and
        ``f"{demand_sce}_output"`` columns and the rows ``STORAGE`` and
        ``NG_ENERG_DEMAND``. Rows are assigned to a sector by the suffix of
        their code (``POWER``, ``CHP``, ``INDUSTRY``, ``HOUSEHOLDS``).
    supply_sce, demand_sce : str
        Column prefixes selecting the scenario in each table.
    bcm_to_twh : float, default 9.77
        Factor applied to every supply amount and to the ``STORAGE``,
        ``TOT_NG`` and ``NG_ENERG_DEMAND`` amounts (presumably BCM -> TWh;
        confirm against the data source). Sector amounts and outputs are used
        unscaled.

    Returns
    -------
    (sce_dict_proc, supply_dict_proc) : tuple of dict
        ``sce_dict_proc`` maps each target node to ``{source: value}``;
        ``supply_dict_proc`` maps each supply node to its scaled amount.
    """
    output_col = f"{demand_sce}_output"
    amount_col = f"{demand_sce}_amount"
    supply_col = f"{supply_sce}_amount"

    # na=False keeps the masks strictly boolean even if an index label is missing.
    codes = df_demand.index.str
    power_index = codes.endswith("POWER", na=False)
    chp_index = codes.endswith("CHP", na=False)
    hi_index = codes.endswith("INDUSTRY", na=False)
    hh_index = codes.endswith("HOUSEHOLDS", na=False)
    chp_heat_index = codes.endswith("HEAT_CHP", na=False)
    # CHP rows that are not heat rows count towards electricity.
    chp_power_index = chp_index & ~codes.contains("HEAT_CHP", na=False)

    outputs = df_demand[output_col]
    inputs = df_demand[amount_col]

    power_dict = outputs.loc[power_index].to_dict()
    chp_dict = outputs.loc[chp_index].to_dict()
    hi_dict = outputs.loc[hi_index].to_dict()
    hh_dict = outputs.loc[hh_index].to_dict()

    power_dict_input = inputs.loc[power_index].to_dict()
    chp_dict_input = inputs.loc[chp_index].to_dict()
    hi_dict_input = inputs.loc[hi_index].to_dict()
    hh_dict_input = inputs.loc[hh_index].to_dict()

    storage_raw = df_demand.loc["STORAGE", amount_col]
    total_supply_raw = df_supply.loc["TOT_NG", supply_col]

    supply_dict = {k: v * bcm_to_twh for k, v in df_supply[supply_col].items()}

    supply_dict_proc = {
        "Russia": supply_dict["NG_RU"],
        "LNG": sum(val for k, val in supply_dict.items() if "LNG" in k),
        "Norway": supply_dict["NG_NO"],
        "Algeria": supply_dict["NG_DZ"],
        "Endogenous production": sum(
            supply_dict[k]
            for k in ["NG_DE", "NG_NL", "NG_RO"] + ["ENDOG1", "ENDOG2", "ENDOG3"]
        ),
        # Storage enters the supply side with the opposite sign.
        "NG storage ": -storage_raw * bcm_to_twh,
    }
    # Remainder of total supply not attributed to a named source. The storage
    # term is included in the sum and added back, so it cancels out here.
    supply_dict_proc["Other sources"] = (
        total_supply_raw * bcm_to_twh
        - sum(supply_dict_proc.values())
        + supply_dict_proc["NG storage "]
    )

    sce_dict_proc = {
        "Electricity production": power_dict_input,
        "CHP": chp_dict_input,
        "Industrial heating": hi_dict_input,
        "Individual heating": hh_dict_input,
        "NG storage": {"NG": storage_raw * bcm_to_twh},
        # Supply left over after storage and the energy demand row.
        "Other usage": {
            "NG": (
                total_supply_raw
                - storage_raw
                - df_demand.loc["NG_ENERG_DEMAND", amount_col]
            )
            * bcm_to_twh
        },
        "Electricity ": {
            "Electricity production": sum(power_dict.values()),
            "CHP": outputs.loc[chp_power_index].sum(),
        },
        "Steam/furnace": {
            "CHP": outputs.loc[chp_heat_index].sum(),
            "Industrial heating": sum(hi_dict.values()),
        },
        "Individual heat": {"Individual heating": sum(hh_dict.values())},
        # Loss per process = total input minus total output.
        "Loss": {
            "Electricity production": sum(power_dict_input.values())
            - sum(power_dict.values()),
            "CHP": sum(chp_dict_input.values()) - sum(chp_dict.values()),
            "Industrial heating": sum(hi_dict_input.values()) - sum(hi_dict.values()),
            "Individual heating": sum(hh_dict_input.values()) - sum(hh_dict.values()),
        },
        # Fixed-size link between the two dummy nodes.
        "Dummy1": {"Dummy2": 50 * bcm_to_twh},
    }

    return sce_dict_proc, supply_dict_proc


# First build of the three scenario dictionaries; the same names are assigned
# again in the "Sankey plots" cell after the STORAGE row has been set.
base_dict_proc, base_supply_dict_proc = generate_scenario_dict(
    df_supply, df_demand, supply_sce="y2021_mid", demand_sce="base"
)
repower_dict_proc, repower_supply_dict_proc = generate_scenario_dict(
    df_supply, df_demand, supply_sce="alternative_mid", demand_sce="repower"
)
y2022_dict_proc, y2022_supply_dict_proc = generate_scenario_dict(
    df_supply, df_demand, supply_sce="y2022_mid", demand_sce="y2022"
)

## Nodes

In [None]:
# Number every node label in the order supply -> alternatives -> processes ->
# products -> dummy.
all_labels = supply_list + alternative_list + processes + products + dummy
nodes = {label: position for position, label in enumerate(all_labels)}

# Hex colour per node label. Later entries override earlier ones for the same key.
hex_colors = {
    **{label: "#a6cee3" for label in supply_list},  # NG
    **{label: "#D3D3D3" for label in products},  # Final products
    "NG": "#a6cee3",
    "NG storage": "#a6cee3",
    "NG storage ": "#a6cee3",
    "Coal": "#b15928",
    "Oil": "#000000",
    "Nuclear": "#6a3d9a",
    "Hydro": "#8da0cb",
    "Solar": "#FF00FF",
    "Wind": "#e31a1c",
    "Biomass": "#808000",
    "Lignite": "#964B00",
    "Electricity": "#ff7f00",
    "Efficiency": "#fc8d62",
    "Savings": "#cab2d6",
    "Weather": "#8da0cb",
    "Electricity production": "#fdbf6f",
    "CHP": "#b2df8a",
    "Industrial heating": "#33a02c",
    "Individual heating": "#fb9a99",
    "Other usage": "#D3D3D3",
    "Dummy1": "#D3D3D3",
    "Dummy2": "#D3D3D3",
}

# Convert "#rrggbb" into an "rgba(r, g, b, 0.75)" string (75 % opacity).
colors = {}
for label, hex_code in hex_colors.items():
    digits = hex_code.lstrip("#")
    red, green, blue = (int(digits[start : start + 2], 16) for start in (0, 2, 4))
    colors[label] = f"rgba({red}, {green}, {blue}, 0.75)"

# Every node must have exactly one colour and vice versa.
assert len(nodes) == len(colors)
assert nodes.keys() == colors.keys()

In [None]:
# Horizontal node positions: supply on the left, alternatives next, then
# processes, with the final products on the right.
x_dict = {k: 0.01 for k in supply_list}
if LABEL_UNDERNEATH:
    # Stagger the alternatives over two x positions (first and odd entries at
    # 0.3, the others at 0.4).
    x_dict.update(
        {
            k: 0.3 if (i % 2 == 1 or i == 0) else 0.4
            for i, k in enumerate(alternative_list)
        }
    )
else:
    x_dict.update({k: 0.3 for k in alternative_list})
x_dict.update({k: 0.60 for k in processes})
x_dict.update({k: 0.99 for k in products})
x_dict.update({"Dummy1": 0.3, "Dummy2": 0.55})

# Vertical node positions, evenly spread within each column.
y_dict = dict(zip(supply_list, np.linspace(0.05, 0.95, len(supply_list))))
y_dict["NG"] = 0.05
# "Biomass" and "Lignite" get no y position. They are filtered into a local
# list instead of being removed from `alternative_list`, so the shared list is
# left intact and this cell (and the Nodes cell) can be re-run safely.
# The first entry ("NG") is skipped because it is positioned just above.
spread_alternatives = [
    name for name in alternative_list if name not in ("Biomass", "Lignite")
][1:]
y_dict.update(
    zip(spread_alternatives, np.linspace(0.25, 0.95, len(spread_alternatives)))
)
y_dict.update(zip(processes, np.linspace(0.05, 0.95, len(processes))))
y_dict.update(zip(products, np.linspace(0.05, 0.95, len(products))))
# y > 1 for the dummy pair (presumably to push it below the visible area — confirm).
y_dict.update({"Dummy1": 1.3, "Dummy2": 1.3})

In [None]:
# Node table indexed by label; the "index" column keeps each node's number.
df_nodes = (
    pd.DataFrame({"node": list(nodes)})
    .reset_index()
    .set_index("node")
)
df_nodes["color"] = df_nodes.index.map(colors)
# Align the position dictionaries on the node labels; labels without an entry get NaN.
df_nodes["x"] = pd.Series(x_dict)
df_nodes["y"] = pd.Series(y_dict)

## Edges

In [None]:
def refactor_node_name(name):
    """Turn a row code into a node label.

    Keeps the part before the first underscore, capitalises it (first letter
    upper-case, the rest lower-case) and restores the acronyms NG, LNG and CHP.
    """
    label = name.split("_", 1)[0].capitalize()
    for lowered, acronym in (("Ng", "NG"), ("Lng", "LNG"), ("Chp", "CHP")):
        label = label.replace(lowered, acronym)
    return label


def create_df_edges(supply_dict, alternative_dict):
    """Build the Sankey edge table for one scenario.

    Every entry of ``supply_dict`` becomes a link into the "NG" node. Every
    ``{target: {source: value}}`` entry of ``alternative_dict`` (except a
    "supply" key, which is skipped) becomes a link from the source to the
    target. Source names are normalised with ``refactor_node_name``; node
    numbers and link colours are looked up in the module-level ``nodes`` and
    ``df_nodes``.
    """
    records = []
    for supplier, amount in supply_dict.items():
        records.append(
            {
                "source_name": refactor_node_name(supplier),
                "target_name": "NG",
                "value": amount,
            }
        )

    for target, inflows in alternative_dict.items():
        if target == "supply":
            continue
        for origin, amount in inflows.items():
            records.append(
                {
                    "source_name": refactor_node_name(origin),
                    "target_name": target,
                    "value": amount,
                }
            )

    df_edges = pd.DataFrame(records)
    source_names = df_edges["source_name"]
    target_names = df_edges["target_name"]
    df_edges["source"] = source_names.map(nodes)
    df_edges["target"] = target_names.map(nodes)
    # Links take the colour of the node they start from.
    df_edges["color"] = source_names.map(df_nodes["color"])
    df_edges["label"] = source_names + " -> " + target_names + " in TWh"

    return df_edges

## Sankey plots

In [None]:
# plot sankey diagram with df_edges and df_node
import plotly.graph_objects as go

# Create the output folder for figures only when saving is enabled.
figs_dir = Path("./figs")
if SAVE_FIGS and not figs_dir.exists():
    figs_dir.mkdir()

def plot_sankey(df_nodes, df_edges):
    """Return a plotly Sankey figure for the given node and edge tables.

    Labels and colours are taken from every row of ``df_nodes``, whereas the
    x/y positions are passed only for nodes that take part in at least one
    strictly positive link, ordered by their "index" column. Links with a
    value of exactly zero are left out.
    """
    positive_links = df_edges.loc[df_edges["value"] > 0]
    active_labels = set(positive_links["source_name"].unique()) | set(
        positive_links["target_name"].unique()
    )
    active_nodes = df_nodes.loc[list(active_labels)].sort_values("index").index.tolist()
    nonzero_edges = df_edges[df_edges["value"] != 0]

    node_spec = {
        "pad": 100,
        "thickness": 20,
        "line": {"color": "black", "width": 0.75},
        "label": df_nodes.index,
        "color": df_nodes["color"],
        "x": df_nodes["x"].loc[active_nodes],
        "y": df_nodes["y"].loc[active_nodes],
    }
    link_spec = {
        "source": nonzero_edges["source"],
        "target": nonzero_edges["target"],
        "value": nonzero_edges["value"],
        "color": nonzero_edges["color"],
        "label": nonzero_edges["label"],
    }
    fig = go.Figure(data=[go.Sankey(node=node_spec, link=link_spec)])

    # Axes are hidden; only fonts, background and margins are styled.
    hidden_axis = {"showgrid": False, "zeroline": False, "visible": False}
    fig.update_layout(
        font_size=16,
        font_color="black",
        font_family="calibri",
        paper_bgcolor="white",
        plot_bgcolor="rgba(0,0,0,0)",
        hovermode="x",
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        margin_t=90,
        margin_b=60,
        margin_l=30,
        margin_r=30,
        title_pad_t=0,
        title_pad_b=0,
        title_pad_l=30,
        title_pad_r=30,
    )

    return fig


def dfs_with_amount(df_nodes, df_edges):
    """Return copies of both tables with the node amounts written into the labels.

    A node's amount is the larger of its total outflow and total inflow; when
    that is exactly zero and the smaller one is negative, the negative total is
    used instead. With ``LABEL_UNDERNEATH`` the amount goes on a second line,
    and nodes coloured "rgba(166, 206, 227, 0.75)" additionally show the
    amount divided by 9.77 in bcm. The inputs themselves are not modified.
    """
    relabel = {}

    for node in df_nodes.index:
        outflow = df_edges.loc[df_edges["source_name"] == node, "value"].sum()
        inflow = df_edges.loc[df_edges["target_name"] == node, "value"].sum()

        amount = max(outflow, inflow)
        if amount == 0 and min(outflow, inflow) < 0:
            amount = min(outflow, inflow)

        if not LABEL_UNDERNEATH:
            relabel[node] = f"{node} [<i>{amount:.0f} TWh</i>]"
            continue

        label = f"<b>{node}</b><br>[<i>{amount:.0f} TWh</i>]"
        if df_nodes.loc[node, "color"] == "rgba(166, 206, 227, 0.75)":
            label = label.replace("]", f", <i>{amount/9.77:.0f} bcm</i>]")
        relabel[node] = label

    labelled_nodes = df_nodes.copy()
    labelled_edges = df_edges.copy()

    labelled_nodes.index = labelled_nodes.index.map(relabel)
    labelled_edges["source_name"] = labelled_edges["source_name"].map(relabel)
    labelled_edges["target_name"] = labelled_edges["target_name"].map(relabel)

    return labelled_nodes, labelled_edges


# Scenario switches: pick the "_mid" variants of the 2021 / 2022 supply columns.
is_mid_2021 = True
is_mid_2022 = True
mid2021 = "_mid" if is_mid_2021 else ""
mid2022 = "_mid" if is_mid_2022 else ""

# Set to False to zero out every storage amount below.
use_storage = True

# Hard-coded storage amounts per period, each the difference of two values
# (presumably storage levels — confirm against the data source).
ng_storage = {
    "y2021": (57.804 - 80.801) * use_storage,
    "y2021_mid": (60.213 - 47.481) * use_storage,
    "y2022": (95.945 - 90.512) * use_storage,
    "y2022_mid": (82.344 - 60.213) * use_storage,
}  # https://www.bruegel.org/dataset/european-natural-gas-imports

# Write the storage amount into the STORAGE row of every demand scenario.
storage_2021 = ng_storage["y2021" + mid2021]
for scenario in ("base", "repower", "coal", "clean"):
    df_demand.loc["STORAGE", f"{scenario}_amount"] = storage_2021
df_demand.loc["STORAGE", "y2022_amount"] = ng_storage["y2022" + mid2022]

# Rebuild the scenario dictionaries now that the STORAGE row is set.
base_dict_proc, base_supply_dict_proc = generate_scenario_dict(
    df_supply, df_demand, ("y2021" + mid2021) if is_mid_2021 else "base", "base"
)

repower_dict_proc, repower_supply_dict_proc = generate_scenario_dict(
    df_supply, df_demand, "alternative" + mid2021, "repower"
)

y2022_dict_proc, y2022_supply_dict_proc = generate_scenario_dict(
    df_supply, df_demand, "y2022" + mid2022, "y2022"
)

df_edges_base = create_df_edges(base_supply_dict_proc, base_dict_proc)
df_edges_repower = create_df_edges(repower_supply_dict_proc, repower_dict_proc)
df_edges_y2022 = create_df_edges(y2022_supply_dict_proc, y2022_dict_proc)

df_nodes_base, df_edges_base = dfs_with_amount(df_nodes, df_edges_base)
df_nodes_repower, df_edges_repower = dfs_with_amount(df_nodes, df_edges_repower)
df_nodes_y2022, df_edges_y2022 = dfs_with_amount(df_nodes, df_edges_y2022)

# File-name suffix telling whether any storage amount is non-zero.
with_storage = df_demand.loc["STORAGE", :].dropna().any() != 0
storage_str = "_withStorage" if with_storage else ""


def _show_sankey(df_nodes_sce, df_edges_sce, file_stem):
    """Plot one scenario, save it as SVG when SAVE_FIGS is set, and display it."""
    sankey_fig = plot_sankey(df_nodes_sce, df_edges_sce)
    if SAVE_FIGS:
        # Only touch the file system when saving was requested; the folder is
        # created here so this step does not depend on an earlier cell.
        os.makedirs("./figs", exist_ok=True)
        sankey_fig.write_image(
            "./figs/" + file_stem + storage_str + ".svg", width=800, height=600
        )
    sankey_fig.show()
    return sankey_fig


fig = _show_sankey(df_nodes_base, df_edges_base, "base_scenario_sankey")
fig = _show_sankey(df_nodes_repower, df_edges_repower, "repower_scenario_sankey")

# deal with negative values: reverse the direction of negative links (except
# those touching an NG storage node) and make their value positive.
negative_rows = (df_edges_y2022["value"] < 0) & ~(
    df_edges_y2022["source_name"].str.startswith("<b>NG storage")
    | df_edges_y2022["target_name"].str.startswith("<b>NG storage")
)

df_edges_y2022.loc[
    negative_rows, ["source_name", "target_name", "source", "target"]
] = df_edges_y2022.loc[
    negative_rows, ["target_name", "source_name", "target", "source"]
].values
df_edges_y2022.loc[negative_rows, "value"] *= -1.0

fig = _show_sankey(df_nodes_y2022, df_edges_y2022, "y2022_scenario_sankey")