DATA PROJECT

MACROECONOMIC DATA

In [None]:
import pandas as pd

# =========================================================
# 1. CONFIGURATION
# =========================================================
# Base folders that hold the input CSV files.
_CSV_DIR = r"C:\Users\Utente\Documents\WIP\CSV files for py"
_GOV_DIR = r"C:\Users\Utente\Documents\WIP\GOVDATA\C+S+G+NE"

# Series key -> CSV path for the files read by load_clean_standardize.
wb_files = {
    key: f"{_CSV_DIR}\\{stem}.csv"
    for key, stem in [
        ("GDP", "GDP"),
        ("POP", "Population"),
        ("POP15_64", "POP15_64"),
        ("LBF", "LaborForce"),
        ("UNEMR", "UNEMR"),
        ("GDPDEFL", "GDPDEFL"),
    ]
}
wb_files.update(
    (key, f"{_GOV_DIR}\\{stem}_cur.csv")
    for key, stem in [
        ("IMPORT", "Import"),
        ("EXPORT", "Export"),
        ("GSAVINGS", "GrossSavings"),
        ("CONSUMPTION", "Consumption"),
    ]
)

# Extra GDP / population files, one pair per prefix (US, CHI, CA, IND).
extra_files = {
    f"{prefix}{measure}": f"{_CSV_DIR}\\{prefix}{measure}.csv"
    for prefix in ("US", "CHI", "CA", "IND")
    for measure in ("GDP", "POP")
}

# Column positions dropped by clean_wb (three after the first column, plus the last).
cols_to_remove = [1, 2, 3, -1]

# Row index labels dropped by clean_wb.
rows_to_remove = [
    1, 3, 7, 36, 49, 61, 62, 63, 64, 65,
    68, 73, 74, 95, 98, 102, 103, 104, 105, 107,
    110, 128, 134, 135, 136, 139, 140, 142, 153, 156,
    161, 170, 181, 183, 191, 197, 198, 204, 215, 217,
    218, 230, 231, 236, 238, 240, 241, 249, 259,
]

# Source country name -> short display name, applied with Series.replace.
rename_dict = {
    "Venezuela, RB": "Venezuela",
    "Turkiye": "Turkey",
    "Congo, Dem. Rep.": "Dem. Rep. Congo",
    "Congo, Rep.": "Congo",
    "Egypt, Arab Rep.": "Egypt",
    "Iran, Islamic Rep.": "Iran",
    "Korea, Dem. People's Rep.": "North Korea",
    "Korea, Rep.": "South Korea",
    "Yemen, Rep.": "Yemen",
    "Slovak Republic": "Slovakia",
    "Bosnia and Herzegovina": "Bosnia and Herz.",
    "Viet Nam": "Vietnam",
    "Lao PDR": "Laos",
    "Micronesia, Fed. Sts.": "Micronesia",
    "Equatorial Guinea": "Eq. Guinea",
    "Kyrgyz Republic": "Kyrgyzstan",
    "Central African Republic": "Central African Rep.",
    "Syrian Arab Republic": "Syria",
    "South Sudan": "S. Sudan",
    "Dominican Republic": "Dominican Rep.",
    # The original key was misspelled ("Ivore"), so the source spelling
    # "Cote d'Ivoire" was never renamed. The old key is kept for compatibility.
    "Cote d'Ivoire": "Côte d'Ivoire",
    "Cote d'Ivore": "Côte d'Ivoire",
    "Russian Federation": "Russia",
    "Bahamas, The": "Bahamas",
}

# =========================================================
# 2. HELPERS
# =========================================================
def to_numeric(df):
    """Convert every column after the first to a numeric dtype, in place.

    Values that cannot be parsed become NaN (``errors="coerce"``).

    The columns are replaced rather than written through ``df.iloc[:, 1:] = ...``:
    on pandas 2.x an ``iloc`` assignment sets values into the existing columns,
    so text columns would keep their ``object`` dtype after conversion.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose first column is a label column. It is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, for call chaining.
    """
    value_cols = df.columns[1:]
    if len(value_cols) == 0:
        # Nothing to convert (only the label column is present).
        return df
    df[value_cols] = df[value_cols].apply(pd.to_numeric, errors="coerce")
    return df

def clean_wb(df):
    """Trim a raw frame down to a country column plus numeric value columns.

    Steps, in order: drop the columns at the positions in ``cols_to_remove``,
    coerce the remaining value columns with ``to_numeric``, drop the rows whose
    index labels are in ``rows_to_remove``, renumber the rows, name the first
    column "Country Name" and apply ``rename_dict`` to its values.
    """
    unwanted = df.columns[cols_to_remove]
    trimmed = to_numeric(df.drop(columns=unwanted, errors="ignore"))
    trimmed = trimmed.drop(index=rows_to_remove, errors="ignore")
    trimmed = trimmed.reset_index(drop=True)
    trimmed = trimmed.rename(columns={trimmed.columns[0]: "Country Name"})
    trimmed["Country Name"] = trimmed["Country Name"].replace(rename_dict)
    return trimmed

def standardize_years(df, start_year=1960):
    """Return ``df`` with one float column per year on a gap-free year axis.

    The result has "Country Name" first, followed by every year from
    ``start_year`` up to the latest year found in ``df`` (plus any earlier
    years ``df`` already has), as string labels in ascending order. Years that
    are missing from ``df`` are filled with NaN.

    Columns whose label is not a number are ignored rather than being matched
    positionally against the year list, which used to raise a shape mismatch.
    Non-numeric cell values become NaN. If the same year appears more than
    once, the last occurrence is kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a "Country Name" column; other columns are year-labelled.
    start_year : int, default 1960
        First year of the output axis.

    Returns
    -------
    pandas.DataFrame
        New frame indexed like ``df``. If ``df`` has no year columns, only the
        "Country Name" column is returned.
    """
    meta = df[["Country Name"]].copy()

    is_value = df.columns != "Country Name"
    labels = df.columns[is_value]
    parsed = pd.to_numeric(labels, errors="coerce")
    is_year = ~pd.isna(parsed)
    if not is_year.any():
        return meta

    years = parsed[is_year].astype(int)

    # Select by position so the data stays paired with its parsed year.
    data = df.loc[:, is_value].loc[:, is_year].copy()
    data.columns = years
    data = data.loc[:, ~data.columns.duplicated(keep="last")]
    data = data.apply(pd.to_numeric, errors="coerce").astype("float64")

    wanted = set(range(start_year, int(years.max()) + 1)).union(years.tolist())
    panel = data.reindex(columns=sorted(wanted))
    panel.columns = panel.columns.astype(str)

    return pd.concat([meta, panel], axis=1)


def load_clean_standardize(files_dict):
    """Read, clean and year-standardize every CSV in ``files_dict``.

    Each file is read with its first four lines skipped, passed through
    ``clean_wb`` and then ``standardize_years``.

    Returns a dict with the same keys as ``files_dict`` and the resulting
    DataFrames as values.
    """
    return {
        name: standardize_years(clean_wb(pd.read_csv(path, skiprows=4)))
        for name, path in files_dict.items()
    }

# =========================================================
# 3. LOAD DATA
# =========================================================
# Load, clean and year-standardize every file listed in wb_files (keyed by series name).
datasets = load_clean_standardize(wb_files)

# POP15-64 absolute: POP multiplied by (POP15_64 / 100).
# NOTE(review): assumes the POP15_64 file holds percentages of total population
# and that its rows line up with the POP table by index — confirm.
POP15_64 = datasets["POP15_64"].copy()
POP15_64.iloc[:,1:] = datasets["POP"].iloc[:,1:] * (POP15_64.iloc[:,1:]/100)

# Load AGM (read without skiprows, unlike the wb_files), harmonize country names
# and year columns, then derive AGA = AGM * 12.
# NOTE(review): presumably AGM is a monthly figure being annualized — confirm.
AGM = pd.read_csv(r"C:\Users\Utente\Documents\WIP\CSV files for py\AGM.csv")
AGM["Country Name"] = AGM["Country Name"].replace(rename_dict)
AGM = standardize_years(to_numeric(AGM))
AGA = AGM.copy()
AGA.iloc[:,1:] *= 12

# Load COM extra files as-is (no cleaning step is applied here).
# NOTE(review): USGDP is scaled by 1,000,000 — presumably the file is in millions; confirm.
extra = {k: pd.read_csv(v) for k, v in extra_files.items()}
extra["USGDP"].iloc[:, 1:] *= 1_000_000

# -----------------------------
# Define EU country groups
# -----------------------------
# The 27 names shared by both groups; the EU28 list is these plus the United Kingdom.
eu27 = [
    "Austria", "Belgium", "Bulgaria", "Cyprus", "Czechia", "Germany",
    "Denmark", "Spain", "Estonia", "Finland", "France", "Greece",
    "Croatia", "Hungary", "Ireland", "Italy", "Lithuania", "Luxembourg",
    "Latvia", "Malta", "Netherlands", "Poland", "Portugal", "Romania",
    "Slovakia", "Slovenia", "Sweden",
]
eu28 = [*eu27, "United Kingdom"]

# -----------------------------
# Create COM datasets
# -----------------------------
# Prefixes of the extra files stacked under the country-level tables, in order.
_COM_PREFIXES = ("US", "CHI", "CA", "IND")

# GDP: country table first, then the extra GDP tables.
GDP_COM = standardize_years(
    pd.concat([datasets["GDP"], *(extra[f"{prefix}GDP"] for prefix in _COM_PREFIXES)])
)

# Population: same stacking order as GDP.
POP_COM = standardize_years(
    pd.concat([datasets["POP"], *(extra[f"{prefix}POP"] for prefix in _COM_PREFIXES)])
)

# -----------------------------
# Add EU28 and EU27 aggregates
# -----------------------------
def add_eu_rows(df, country_list, label):
    """Append one aggregate row summing the listed countries.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a "Country Name" column and value columns.
    country_list : list of str
        Names of the rows to add up. Names absent from ``df`` are skipped.
    label : str
        Value written to "Country Name" in the new row.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when none of the countries is present; otherwise a new
        frame with the aggregate row appended and the index renumbered.

    Notes
    -----
    A column in which every selected country is NaN yields NaN, not 0
    (``min_count=1``), so years without data are not reported as a zero total.
    The row is assembled by column name, so "Country Name" does not have to be
    the first column.
    """
    mask = df["Country Name"].isin(country_list)
    if not mask.any():
        # None of the countries is present.
        return df

    value_cols = df.columns.drop("Country Name")
    totals = df.loc[mask, value_cols].sum(min_count=1)

    row = totals.to_dict()
    row["Country Name"] = label
    new_row = pd.DataFrame([row], columns=df.columns)

    return pd.concat([df, new_row], ignore_index=True)

# Append the EU28 row and then the EU27 row to each combined table.
GDP_COM = add_eu_rows(add_eu_rows(GDP_COM, eu28, "EU28"), eu27, "EU27")
POP_COM = add_eu_rows(add_eu_rows(POP_COM, eu28, "EU28"), eu27, "EU27")

# -----------------------------
# GDP per capita (COM)
# -----------------------------
# Element-wise GDP / population, matched on country name and year label.
GDPCAP_COM = GDP_COM.set_index("Country Name") / POP_COM.set_index("Country Name")

# Put the year columns in numeric order, then bring the country name back as a column.
year_cols = sorted(GDPCAP_COM.columns, key=int)
GDPCAP_COM = GDPCAP_COM[year_cols].reset_index()



# GDP deflator COM
_defl_table = datasets["GDPDEFL"]

# Every name in the USGDP file gets a copy of the "United States" deflator row.
us_row = _defl_table.loc[_defl_table["Country Name"] == "United States"].iloc[:, 1:]
us_states = extra["USGDP"]["Country Name"].tolist()
us_defl_df = pd.DataFrame(
    [[state] + list(us_row.values[0]) for state in us_states],
    columns=_defl_table.columns
)

# Every name in the CAGDP file gets a copy of the "Canada" deflator row.
ca_row = _defl_table.loc[_defl_table["Country Name"] == "Canada"].iloc[:, 1:]
ca_provs = extra["CAGDP"]["Country Name"].tolist()
ca_defl_df = pd.DataFrame(
    [[prov] + list(ca_row.values[0]) for prov in ca_provs],
    columns=_defl_table.columns
)

# Country-level deflators first, then the replicated rows, renumbered.
GDPDEFL_COM = pd.concat(
    [_defl_table.copy(), us_defl_df, ca_defl_df],
    ignore_index=True
)


# =========================================================
# 4. MACRO DATASET CLASS
# =========================================================
class MacroDataset:
    """Bundle of country-by-year tables with helpers for deflated ("real") series.

    Every input is a DataFrame with a "Country Name" column followed by year
    columns labelled as strings. Each table is stored indexed by country name.

    Parameters
    ----------
    gdp, pop, defl : pandas.DataFrame
        Required tables: GDP, population and the GDP deflator.
    lbf, unem, pop15_64, aga, imports, exports, gross_savings, consumption :
        pandas.DataFrame or None. Optional tables; stored as None when omitted.
    drop_first_column : bool, default True
        NOTE(review): after ``set_index("Country Name")`` the first remaining
        column is the first year, so the historical ``.iloc[:, 1:]`` drops that
        year from every table. The default keeps that behaviour; pass False to
        retain all year columns.
    """

    def __init__(self, gdp, pop, defl, lbf=None, unem=None, pop15_64=None, aga=None,
                 imports=None, exports=None, gross_savings=None, consumption=None,
                 *, drop_first_column=True):
        def prepare(frame):
            # Index by country; optionally drop the first remaining column.
            if frame is None:
                return None
            indexed = frame.set_index("Country Name")
            return indexed.iloc[:, 1:] if drop_first_column else indexed

        self.gdp = prepare(gdp)
        self.pop = prepare(pop)
        self.defl = prepare(defl)
        self.lbf = prepare(lbf)
        self.unem = prepare(unem)
        self.pop15_64 = prepare(pop15_64)
        self.aga = prepare(aga)
        self.imports = prepare(imports)
        self.exports = prepare(exports)
        self.gross_savings = prepare(gross_savings)
        self.consumption = prepare(consumption)

    def rebase_deflator(self, base_year):
        """Return the deflator scaled so that the base year equals 100.

        If ``base_year`` is not a column, the earliest later year is used; if
        there is no later year, the latest available year is used. Previously
        that last case raised ``KeyError: 'nan'``.
        """
        years = self.defl.columns.astype(int)
        if base_year in years:
            base = base_year
        else:
            later = years[years > base_year]
            base = later.min() if len(later) else years.max()
        return self.defl.div(self.defl[str(base)], axis=0) * 100

    def real_gdp(self, base_year):
        """Return GDP divided by the rebased deflator, times 100."""
        return self.gdp.div(self.rebase_deflator(base_year)) * 100

    def real_gdp_per_capita(self, base_year):
        """Return real GDP divided by population."""
        return self.real_gdp(base_year).div(self.pop)

    def employed(self):
        """Return ``lbf * (1 - unem / 100)``, or None if either table is missing."""
        if self.lbf is None or self.unem is None:
            return None
        return self.lbf * (1 - self.unem / 100)

    def real_gdp_per_employed(self, base_year):
        """Return real GDP divided by ``employed()``, or None if that is unavailable."""
        employed = self.employed()
        if employed is None:
            return None
        return self.real_gdp(base_year).div(employed)

    def real_gdp_per_pop15_64(self, base_year):
        """Return real GDP divided by ``pop15_64``, or None if that table is missing."""
        if self.pop15_64 is None:
            return None
        return self.real_gdp(base_year).div(self.pop15_64)

    def real_aga(self, base_year):
        """Return ``aga`` deflated to ``base_year``, or None if ``aga`` is missing."""
        if self.aga is None:
            return None
        return self.aga.div(self.rebase_deflator(base_year)) * 100

    def real_government_series(self, series_name, base_year):
        """Return the attribute named ``series_name`` deflated to ``base_year``.

        Returns None when that attribute is None; raises AttributeError when no
        such attribute exists.
        """
        series = getattr(self, series_name)
        if series is None:
            return None
        return series.div(self.rebase_deflator(base_year)) * 100

    def to_df(self, df):
        """Return ``df`` with its index moved back into a column."""
        return df.reset_index()

# =========================================================
# 5. BUILD DATASETS
# =========================================================
# The three required tables are shared by the two country-level bundles.
_core_tables = (datasets["GDP"], datasets["POP"], datasets["GDPDEFL"])

# Country-level bundle with every optional series that was loaded.
_wb_optional = {
    "lbf": datasets.get("LBF"),
    "unem": datasets.get("UNEMR"),
    "pop15_64": POP15_64,
    "imports": datasets.get("IMPORT"),
    "exports": datasets.get("EXPORT"),
    "gross_savings": datasets.get("GSAVINGS"),
    "consumption": datasets.get("CONSUMPTION"),
}
WB = MacroDataset(*_core_tables, **_wb_optional)

# Same core tables, carrying only the AGA series.
WB_AGA = MacroDataset(*_core_tables, aga=AGA)

# Combined tables (GDP_COM, POP_COM, GDPDEFL_COM).
COM = MacroDataset(GDP_COM, POP_COM, GDPDEFL_COM)

# =========================================================
# 6. AUTOMATIC COMPUTATION OF ALL SERIES
# =========================================================
def compute_all(WB, WB_AGA, COM, base_years=(2024,)):
    """Build the registry of named tables used by the dashboards.

    Parameters
    ----------
    WB, WB_AGA, COM : MacroDataset
        Country-level bundle, the bundle carrying ``aga``, and the combined bundle.
    base_years : iterable of int, default (2024,)
        Deflator base years; one set of REAL* entries is produced per year.
        The default is a tuple so no mutable object is shared between calls.

    Returns
    -------
    dict
        Name -> DataFrame with "Country Name" as a regular column. An entry is
        None when one of the tables it needs is missing from its bundle; ratios
        with a missing operand no longer raise.

    Notes
    -----
    "GDPCAP_COM_NOMINAL" is a copy of the module-level ``GDPCAP_COM`` table.
    """
    def df_with_country_name(df):
        # Move the country index back into a column; pass None through.
        if df is None:
            return None
        return df.reset_index().rename(columns={"index": "Country Name"})

    def ratio(numerator, denominator):
        # Element-wise division that tolerates a missing operand.
        if numerator is None or denominator is None:
            return None
        return numerator.div(denominator)

    # -----------------------------
    # Nominal WB + COM datasets
    # (GDP_NOMINAL, POP_NOMINAL, GDPDEFL_NOMINAL, GDPCAP_NOMINAL are not exported)
    # -----------------------------
    nominal = {
        "POP15_64_NOMINAL": WB.pop15_64,
        "LBF_NOMINAL": WB.lbf,
        "UNEMR_NOMINAL": WB.unem,
        "GDPPOP15_64_NOMINAL": ratio(WB.gdp, WB.pop15_64),
        "GDPLBF_NOMINAL": ratio(WB.gdp, WB.lbf),
        "GDPEMPE_NOMINAL": ratio(WB.gdp, WB.employed()),
        "AGA_NOMINAL": WB_AGA.aga,
        "IMPORT": WB.imports,
        "EXPORT": WB.exports,
        "GSAVINGS": WB.gross_savings,
        "CONSUMPTION": WB.consumption,
        "GSAVINGS_PER_CAP": ratio(WB.gross_savings, WB.pop),
        "GDP_COM_NOMINAL": COM.gdp,
        "POP_COM_NOMINAL": COM.pop,
        "GDPDEFL_COM_NOMINAL": COM.defl,
    }
    datasets_available = {
        name: df_with_country_name(frame) for name, frame in nominal.items()
    }
    datasets_available["GDPCAP_COM_NOMINAL"] = GDPCAP_COM.copy()

    # -----------------------------
    # Real series, one set per base year
    # -----------------------------
    for y in base_years:
        wb_real_gdp = WB.real_gdp(y)
        real = {
            f"REALGDP_COM_{y}": COM.real_gdp(y),
            f"REALGDPCAP_COM_{y}": COM.real_gdp_per_capita(y),
            f"REALGDPLBF_{y}": ratio(wb_real_gdp, WB.lbf),
            f"REALGDPEMPE_{y}": ratio(wb_real_gdp, WB.employed()),
            f"REALGDPPOP15_64_{y}": ratio(wb_real_gdp, WB.pop15_64),
            f"REALAGA_{y}": WB_AGA.real_aga(y),
            f"REALIMPORT_{y}": WB.real_government_series("imports", y),
            f"REALEXPORT_{y}": WB.real_government_series("exports", y),
            f"REALGSAVINGS_{y}": WB.real_government_series("gross_savings", y),
            f"REALCONSUMPTION_{y}": WB.real_government_series("consumption", y),
            f"REALGSAVINGS_PER_CAP_{y}": ratio(
                WB.real_government_series("gross_savings", y), WB.pop
            ),
        }
        datasets_available.update(
            (name, df_with_country_name(frame)) for name, frame in real.items()
        )

    return datasets_available

# Compute datasets
# Build the registry once, with 2024 as the only deflator base year; the
# dashboard and map cells further down read from this dict.
DATASETS_AVAILABLE = compute_all(WB, WB_AGA, COM, base_years=[2024])



In [None]:
import dash
from dash import dcc, html, Input, Output, State
import plotly.graph_objs as go

# =========================================================
# 1. DATASETS (from previous script)
# =========================================================
# Same dict object as DATASETS_AVAILABLE (not a copy), so the de-duplication
# below is visible through both names.
datasets_dict = DATASETS_AVAILABLE

# Remove duplicate columns if any. Registry entries can be None when a series
# is unavailable; those are left as they are instead of raising AttributeError.
for ds, df in datasets_dict.items():
    if df is not None:
        datasets_dict[ds] = df.loc[:, ~df.columns.duplicated()]

# =========================================================
# 2. DEFINE DATASET TYPES
# =========================================================
# Dataset names grouped by the type shown in the first dropdown, based on
# substrings of the registry key.
dataset_types = {
    "Nominal": [ds for ds in datasets_dict if ("REAL" not in ds) and ("1990" not in ds) and ("2024" not in ds)],
    "Real 1990": [ds for ds in datasets_dict if "1990" in ds],
    "Real 2024": [ds for ds in datasets_dict if "2024" in ds]
}

for key in dataset_types:
    dataset_types[key].sort()

# =========================================================
# 3. INITIALIZE DASH APP
# =========================================================
app = dash.Dash(__name__)
app.title = "Interactive Macro Line Chart"

# =========================================================
# 4. LAYOUT
# =========================================================
# Dataset-type selector: a single choice that cannot be cleared.
_type_selector = html.Div(
    [
        html.Label("Select Dataset Type:"),
        dcc.Dropdown(
            id="type-dropdown",
            options=[{"label": t, "value": t} for t in dataset_types],
            value="Nominal",
            clearable=False,
        ),
    ],
    style={"width": "30%", "display": "inline-block"},
)

# Dataset selector: multi-select, filled by a callback.
_dataset_selector = html.Div(
    [
        html.Label("Select Dataset(s):"),
        dcc.Dropdown(id="dataset-dropdown", options=[], value=[], multi=True),
    ],
    style={"width": "68%", "display": "inline-block", "marginLeft": "2%"},
)

# Country selector: multi-select, filled by a callback.
_country_selector = html.Div(
    [
        html.Label("Select Country(s):"),
        dcc.Dropdown(id="country-dropdown", options=[], value=[], multi=True),
    ],
    style={"width": "100%", "display": "inline-block", "marginTop": "10px"},
)

app.layout = html.Div([
    html.H1("Interactive Macro Line Chart", style={"textAlign": "center"}),
    _type_selector,
    _dataset_selector,
    _country_selector,
    dcc.Graph(id="line-chart"),
])

# =========================================================
# 5. CALLBACKS
# =========================================================
@app.callback(
    Output("dataset-dropdown", "options"),
    Output("dataset-dropdown", "value"),
    Input("type-dropdown", "value")
)
def update_dataset_options(selected_type):
    """Fill the dataset dropdown for the chosen dataset type.

    Returns the option list and a one-element default selection: the preferred
    dataset for that type when it exists, otherwise the first dataset, or an
    empty list when the type has no datasets.
    """
    names = dataset_types.get(selected_type, [])

    preferred = {
        "Nominal": "GDPCAP_COM_NOMINAL",
        "Real 2024": "REALGDPCAP_COM_2024",
        "Real 1990": "REALGDPCAP_COM_1990"
    }.get(selected_type)

    if preferred in names:
        chosen = [preferred]
    elif names:
        chosen = [names[0]]
    else:
        chosen = []

    return [{"label": name, "value": name} for name in names], chosen


@app.callback(
    Output("country-dropdown", "options"),
    Output("country-dropdown", "value"),
    Input("dataset-dropdown", "value"),
    State("country-dropdown", "value")
)
def update_country_options(selected_datasets, selected_countries):
    """Offer every country found in the selected datasets.

    Returns the option list (sorted by name) and the selection to show. An
    existing selection is kept unchanged. Otherwise the default is "Italy" when
    available, else the first country alphabetically, else an empty list.
    The old fallback indexed an unsorted list, which raised IndexError when no
    country was found and picked an arbitrary name otherwise.
    """
    if not selected_datasets:
        return [], selected_countries or ["Italy"]

    found = set()
    for ds in selected_datasets:
        df = datasets_dict.get(ds)
        if df is None:
            continue
        if "Country Name" not in df.columns:
            df = df.reset_index().rename(columns={"index": "Country Name"})
        found.update(df["Country Name"].tolist())

    # Non-string entries (e.g. NaN names) are skipped.
    countries = sorted(c for c in found if isinstance(c, str))
    options = [{"label": c, "value": c} for c in countries]

    if selected_countries:
        value = selected_countries
    elif "Italy" in countries:
        value = ["Italy"]
    else:
        value = countries[:1]

    return options, value

@app.callback(
    Output("line-chart", "figure"),
    Input("dataset-dropdown", "value"),
    Input("country-dropdown", "value")
)
def update_line_chart(selected_datasets, selected_countries):
    """Draw one line per (dataset, country) pair that has a matching row.

    With no dataset or no country selected, returns an empty figure whose
    title asks for a selection.
    """
    fig = go.Figure()
    if not (selected_datasets and selected_countries):
        fig.update_layout(title="Please select at least one dataset and one country")
        return fig

    for ds in selected_datasets:
        table = datasets_dict.get(ds)
        if table is None:
            continue
        if "Country Name" not in table.columns:
            table = table.reset_index().rename(columns={"index": "Country Name"})

        for country in selected_countries:
            match = table[table["Country Name"] == country]
            if match.empty:
                continue
            # Only the first matching row is plotted; column 0 is the name.
            trace = go.Scatter(
                x=match.columns[1:],
                y=match.iloc[0, 1:].astype(float).values,
                mode="lines+markers",
                name=f"{country} - {ds}"
            )
            fig.add_trace(trace)

    fig.update_layout(
        title="Macro Time Series",
        xaxis_title="Year",
        yaxis_title="Value",
        hovermode="x unified",
        template="plotly_white"
    )
    return fig

# =========================================================
# 6. RUN APP
# =========================================================
# Start the Dash server for the line-chart app on port 8050, in debug mode,
# when this code runs as the main module.
if __name__ == "__main__":
    app.run(debug=True, port=8050)


In [None]:
import numpy as np

# ----------------------------
# 1. EU COUNTRY GROUPS
# ----------------------------
# EU28 member names, one per line; the United Kingdom is the last entry.
eu28 = [
    "Austria",
    "Belgium",
    "Bulgaria",
    "Cyprus",
    "Czechia",
    "Germany",
    "Denmark",
    "Spain",
    "Estonia",
    "Finland",
    "France",
    "Greece",
    "Croatia",
    "Hungary",
    "Ireland",
    "Italy",
    "Lithuania",
    "Luxembourg",
    "Latvia",
    "Malta",
    "Netherlands",
    "Poland",
    "Portugal",
    "Romania",
    "Slovakia",
    "Slovenia",
    "Sweden",
    "United Kingdom",
]
# EU27 is the same list without the United Kingdom.
eu27 = [name for name in eu28 if name != "United Kingdom"]

# ----------------------------
# 2. SUBNATIONAL ROWS BY NAME
# ----------------------------
# One list per country. The helper sets below are built from these lists
# directly: the previous hard-coded slices of the combined list were off
# (US_STATES missed "Alabama", CANADA_PROVINCES missed "Alberta" and included
# "Alabama", CHINA_PROVINCES included "Alberta" and "British Columbia").
_INDIA_NAMES = [
    "Andaman and Nicobar Islands", "Andhra Pradesh", "Arunachal Pradesh", "Assam", "Bihar",
    "Chandigarh", "Chhattisgarh", "Delhi", "Goa", "Gujarat", "Haryana", "Himachal Pradesh",
    "Jammu and Kashmir", "Jharkhand", "Karnataka", "Kerala", "Madhya Pradesh", "Maharashtra",
    "Manipur", "Meghalaya", "Mizoram", "Nagaland", "Odisha", "Puducherry", "Punjab", "Rajasthan",
    "Sikkim", "Tamil Nadu", "Telangana", "Tripura", "Uttar Pradesh", "Uttarakhand", "West Bengal",
]
_CHINA_NAMES = [
    "Anhui", "Beijing", "Chongqing", "Fujian", "Gansu", "Guangdong", "Guangxi", "Guizhou", "Hainan",
    "Hebei", "Heilongjiang", "Henan", "Hubei", "Hunan", "Inner Mongolia", "Jiangsu", "Jiangxi", "Jilin",
    "Liaoning", "Ningxia", "Qinghai", "Shaanxi", "Shandong", "Shanghai", "Shanxi", "Sichuan", "Tianjin",
    "Tibet", "Xinjiang", "Yunnan", "Zhejiang",
]
_CANADA_NAMES = [
    "Alberta", "British Columbia", "Manitoba", "New Brunswick", "Newfoundland and Labrador",
    "Northwest Territories", "Nova Scotia", "Nunavut", "Ontario", "Prince Edward Island", "Quebec",
    "Saskatchewan", "Yukon",
]
_US_NAMES = [
    "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", "Connecticut", "Delaware",
    "District of Columbia", "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa",
    "Kansas", "Kentucky", "Louisiana", "Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota",
    "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada", "New Hampshire", "New Jersey",
    "New Mexico", "New York", "North Carolina", "North Dakota", "Ohio", "Oklahoma", "Oregon",
    "Pennsylvania", "Rhode Island", "South Carolina", "South Dakota", "Tennessee", "Texas", "Utah",
    "Vermont", "Virginia", "Washington", "West Virginia", "Wisconsin", "Wyoming",
]

# Combined list in the original order: India, China, Canada, United States.
# NOTE(review): rows are matched by name only, so any other row that shares a
# name with an entry here (e.g. "Georgia") is treated as subnational too.
subnational_names = _INDIA_NAMES + _CHINA_NAMES + _CANADA_NAMES + _US_NAMES

# ----------------------------
# 3. HELPER SETS
# ----------------------------
US_STATES = set(_US_NAMES)
CANADA_PROVINCES = set(_CANADA_NAMES)
CHINA_PROVINCES = set(_CHINA_NAMES)
INDIA_STATES = set(_INDIA_NAMES)

# ----------------------------
# 5. MASK FUNCTIONS
# ----------------------------
def mask_grouped(df, year):
    """Return a copy of ``df`` with the rows hidden by the "Grouped" view set to NaN.

    Always hidden: subnational rows and the individual EU27 members.
    For ``year`` up to and including 2018, "United Kingdom" and "EU27" are
    also hidden, leaving "EU28" visible. For later years "EU28" is hidden
    instead, leaving "EU27" and "United Kingdom" visible.
    Hidden rows stay in the frame; only their year columns are blanked.
    """
    masked = df.copy()
    value_cols = masked.columns.drop("Country Name")
    names = masked["Country Name"]

    hidden = names.isin(subnational_names) | names.isin(eu27)
    if int(year) <= 2018:
        hidden = hidden | names.isin(["United Kingdom", "EU27"])
    else:
        hidden = hidden | (names == "EU28")

    masked.loc[hidden, value_cols] = np.nan
    return masked


def mask_standard(df, year):
    """Return a copy of ``df`` with EU aggregates and subnational rows set to NaN.

    ``year`` is accepted for signature parity with the other masks and is not used.
    """
    masked = df.copy()
    names = masked["Country Name"]
    hidden = names.isin(["EU27", "EU28"]) | names.isin(subnational_names)
    masked.loc[hidden, masked.columns.drop("Country Name")] = np.nan
    return masked

def mask_detailed(df, year):
    """Return a copy of ``df`` for the "Detailed" view.

    The EU aggregates and the national rows for the United States, China,
    Canada and India are set to NaN; subnational rows are left visible.
    ``year`` is not used.
    """
    masked = df.copy()
    names = masked["Country Name"]
    # National totals for the countries that also appear as subnational rows.
    national_totals = ["United States", "China", "Canada", "India"]
    hidden = names.isin(["EU27", "EU28"]) | names.isin(national_totals)
    masked.loc[hidden, masked.columns.drop("Country Name")] = np.nan
    return masked

def apply_mask(df, mask_name, year):
    """Apply the mask called ``mask_name`` and return the masked copy.

    Recognized names are "Grouped", "Standard" and "Detailed"; any other value
    returns an unmodified copy of ``df``.
    """
    handlers = (
        ("Grouped", mask_grouped),
        ("Standard", mask_standard),
        ("Detailed", mask_detailed),
    )
    for name, handler in handlers:
        if mask_name == name:
            return handler(df, year)
    return df.copy()


In [None]:
import dash
from dash import dcc, html, Input, Output
import plotly.graph_objs as go

# ----------------------------
# 1. CONFIG
# ----------------------------
# Mask names offered in the MASK dropdown; apply_mask dispatches on these exact strings.
MASKS = ["Grouped", "Standard", "Detailed"]

# Shallow copy of the registry: the dict is new, the DataFrames inside are shared.
DATASETS_DICT = DATASETS_AVAILABLE.copy()

# Names that must never be offered as nominal datasets.
EXCLUDE_NOMINAL = {"GDP_NOMINAL", "POP_NOMINAL", "GDPDEFL_NOMINAL", "GDPCAP_NOMINAL"}

# A key is a nominal candidate when it contains at least one of these substrings.
_NOMINAL_TOKENS = (
    "COM", "POP15_64", "LBF", "UNEMR", "GDPPOP15_64", "GDPLBF", "GDPEMPE",
    "AGA", "IMPORT", "EXPORT", "GSAVINGS", "CONSUMPTION", "GSAVINGS_PER_CAP",
)


def split_datasets(data_dict):
    """Split the registry keys into nominal and real dataset names.

    A key is "real" when it contains "REAL". A key is "nominal" when it does
    not contain "REAL", is not in ``EXCLUDE_NOMINAL`` and contains one of the
    nominal tokens. Without the explicit "REAL" exclusion every real key also
    matched a token (e.g. "REALGDP_COM_2024" contains "COM") and showed up in
    the nominal list as well.

    Returns
    -------
    tuple of (list, list)
        ``(nominal, real)``, each sorted alphabetically.
    """
    real = sorted(k for k in data_dict if "REAL" in k)
    nominal = sorted(
        k for k in data_dict
        if "REAL" not in k
        and k not in EXCLUDE_NOMINAL
        and any(token in k for token in _NOMINAL_TOKENS)
    )
    return nominal, real

# Sorted dataset names behind the Nominal / Real toggle.
NOMINAL_DATASETS, REAL_DATASETS = split_datasets(DATASETS_DICT)

# ----------------------------
# 2. APP INIT
# ----------------------------
# NOTE: this rebinds the name `app`, which the line-chart cell also uses.
app = dash.Dash(__name__)
app.title = "Macro Interactive Dashboard"

# Colours used for the heading / chart title and for the bars.
DARK_BLUE = "#1f3b5c"
BAR_BLUE = "#2f5aa6"

# ----------------------------
# 3. LAYOUT
# ----------------------------
def _control(label, *components, style):
    """Wrap a label and its input component(s) in one styled container."""
    return html.Div([html.Label(label), *components], style=style)


# ---- CONTROLS ROW ----
_controls_row = html.Div(
    style={
        "display": "flex",
        "gap": "15px",
        "alignItems": "flex-end",
        "marginBottom": "25px"
    },
    children=[
        _control(
            "REAL / NOMINAL",
            dcc.RadioItems(
                id="real-nominal-toggle",
                options=[
                    {"label": "Nominal", "value": "Nominal"},
                    {"label": "Real", "value": "Real"}
                ],
                value="Nominal",
                inline=True
            ),
            style={"width": "180px"},
        ),
        _control("DATASET", dcc.Dropdown(id="dataset-dropdown"), style={"flex": "2"}),
        _control("YEAR", dcc.Dropdown(id="year-dropdown"), style={"width": "120px"}),
        _control(
            "MASK",
            dcc.Dropdown(
                id="mask-dropdown",
                options=[{"label": m, "value": m} for m in MASKS],
                value="Grouped"
            ),
            style={"width": "160px"},
        ),
        _control(
            "TOP / BOTTOM",
            dcc.Input(id="top-x", type="number", value=10, min=1, step=1),
            dcc.RadioItems(
                id="top-bottom",
                options=[
                    {"label": "Top", "value": "top"},
                    {"label": "Bottom", "value": "bottom"}
                ],
                value="top",
                inline=True
            ),
            style={"width": "180px"},
        ),
    ]
)

# ---- CHART ----
_bar_chart = dcc.Graph(
    id="bar-chart",
    style={
        "height": "900px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "boxShadow": "0 4px 12px rgba(0,0,0,0.08)"
    }
)

# Page: heading, controls row, chart.
app.layout = html.Div(
    style={
        "backgroundColor": "#f4f6f9",
        "padding": "20px",
        "fontFamily": "Arial"
    },
    children=[
        html.H2(
            "INTERACTIVE MACRO COMPARISON",
            style={
                "textAlign": "center",
                "color": DARK_BLUE,
                "marginBottom": "25px",
                "letterSpacing": "1px"
            }
        ),
        _controls_row,
        _bar_chart,
    ]
)

# ----------------------------
# 4. CALLBACKS
# ----------------------------
@app.callback(
    Output("dataset-dropdown", "options"),
    Output("dataset-dropdown", "value"),
    Input("real-nominal-toggle", "value")
)
def update_dataset_dropdown(real_nominal):
    """Fill the dataset dropdown for the Nominal / Real toggle.

    Returns the option list and the default selection: the preferred dataset
    for the chosen side when present, otherwise the first dataset, or None
    when the list is empty.
    """
    is_nominal = real_nominal == "Nominal"
    names = NOMINAL_DATASETS if is_nominal else REAL_DATASETS
    preferred = "GDPCAP_COM_NOMINAL" if is_nominal else "REALGDPCAP_COM_2024"

    if preferred in names:
        selected = preferred
    elif names:
        selected = names[0]
    else:
        selected = None

    return [{"label": name, "value": name} for name in names], selected


@app.callback(
    Output("year-dropdown", "options"),
    Output("year-dropdown", "value"),
    Input("dataset-dropdown", "value")
)
def update_year_options(dataset_name):
    """List the years of the chosen dataset and preselect the latest one.

    Returns ``([], None)`` when the dataset is unknown or its registry entry
    is None (a missing series), instead of failing on ``None.columns``.
    """
    df = DATASETS_DICT.get(dataset_name)
    if df is None:
        return [], None
    years = sorted([c for c in df.columns if c != "Country Name"], key=int)
    options = [{"label": y, "value": y} for y in years]
    return options, (years[-1] if years else None)

@app.callback(
    Output("bar-chart", "figure"),
    Input("dataset-dropdown", "value"),
    Input("mask-dropdown", "value"),
    Input("top-x", "value"),
    Input("top-bottom", "value"),
    Input("year-dropdown", "value"),
    Input("real-nominal-toggle", "value")
)
def update_chart(dataset_name, mask_name, top_x, top_bottom, selected_year, real_nominal):
    """Draw a horizontal bar chart ranking countries for one dataset and year.

    The dataset is masked with ``apply_mask``, rows without a value for the
    year are dropped, and the top or bottom ``top_x`` rows are shown. When
    ``selected_year`` is not a column of the dataset (for example right after
    the dataset changed), the latest year is used.

    Returns an empty figure when the dataset is unknown, its registry entry is
    None, or it has no year columns.
    """
    import logging

    source = DATASETS_DICT.get(dataset_name)
    if source is None:
        return go.Figure()

    df = source.copy()

    year_cols = sorted([c for c in df.columns if c != "Country Name"], key=int)
    if not year_cols:
        return go.Figure()
    year = selected_year if selected_year in year_cols else year_cols[-1]

    try:
        df = apply_mask(df, mask_name, year)
    except Exception:
        # Best effort: fall back to the unmasked data, but log the failure
        # instead of discarding it silently.
        logging.getLogger(__name__).exception(
            "apply_mask failed for dataset %r (mask=%r, year=%r); showing unmasked data",
            dataset_name, mask_name, year,
        )

    df = df[["Country Name", year]].dropna()
    df["Country Name"] = df["Country Name"].str.upper()

    # "bottom" ranks from the smallest value, "top" from the largest.
    df = df.sort_values(
        year,
        ascending=(top_bottom == "bottom")
    ).head(top_x)

    fig = go.Figure(go.Bar(
        x=df[year],
        y=df["Country Name"],
        orientation="h",
        marker_color=BAR_BLUE,
        text=[f"{v:,.2f}" for v in df[year]],
        textposition="inside",
        insidetextanchor="middle"
    ))

    fig.update_layout(
        title={
            "text": f"{dataset_name} ({mask_name}) [{real_nominal}] — {top_bottom.upper()} {top_x}",
            "x": 0.5,
            "font": {"size": 20, "color": DARK_BLUE}
        },
        template="plotly_white",
        # Reversed axis puts the first-ranked row at the top of the chart.
        yaxis=dict(
            autorange="reversed",
            tickfont=dict(size=14)
        ),
        xaxis=dict(
            tickfont=dict(size=12),
            gridcolor="#e6e6e6"
        ),
        margin=dict(l=160, r=40, t=80, b=40),
        height=900
    )

    return fig

# ----------------------------
# 5. RUN
# ----------------------------
# Start the Dash server for the bar-chart app on port 8051 (the line-chart app
# uses 8050), in debug mode and with the reloader turned off, when this code
# runs as the main module.
if __name__ == "__main__":
    app.run(
        debug=True,
        port=8051,
        use_reloader=False
    )


In [None]:
import pandas as pd
import plotly.express as px
from ipywidgets import widgets, interactive_output, VBox

# ---------------------------------------------------------
# 1. DATASET REGISTRY
# ---------------------------------------------------------
# Shallow copy of the registry used by the map widgets.
MAP_DATASETS = DATASETS_AVAILABLE.copy()


# ---------------------------------------------------------
# 2. Helper: available years per dataset
# ---------------------------------------------------------
def get_years(dataset_name):
    """Return the years available in ``dataset_name`` as a sorted list of ints.

    Only columns other than "Country Name" whose label consists of digits are counted.
    """
    columns = MAP_DATASETS[dataset_name].columns
    return sorted(
        int(label)
        for label in columns
        if label != "Country Name" and label.isdigit()
    )

# ---------------------------------------------------------
# 3. Map plotting function
# ---------------------------------------------------------
def plot_map(dataset_name, year):
    """Show a world choropleth of one dataset for one year.

    The year is converted to a string and used as the column to colour by.
    When that column does not exist a message is printed and nothing is drawn.
    Returns None in both cases.
    """
    frame = MAP_DATASETS[dataset_name]
    year = str(year)

    if year in frame.columns:
        choropleth_options = {
            "locations": "Country Name",
            "locationmode": "country names",
            "color": year,
            "color_continuous_scale": "Blues",
            "title": f"{dataset_name} – {year}",
            "width": 1100,
            "height": 600,
        }
        fig = px.choropleth(frame, **choropleth_options)

        fig.update_layout(
            geo={"showframe": False, "showcoastlines": True},
            coloraxis_colorbar={"title": dataset_name},
            margin={"l": 0, "r": 0, "t": 50, "b": 0},
        )

        fig.show()
    else:
        print(f"Year {year} not available for {dataset_name}")

# ---------------------------------------------------------
# 4. Widgets
# ---------------------------------------------------------
# Dataset selector: options are the registry keys in alphabetical order.
# The initial selection is "GDP_COM_NOMINAL" when registered, otherwise the
# first key alphabetically, or None when the registry is empty.
dataset_widget = widgets.Dropdown(
    options=sorted(MAP_DATASETS.keys()),
    value=(
        "GDP_COM_NOMINAL"
        if "GDP_COM_NOMINAL" in MAP_DATASETS
        else (sorted(MAP_DATASETS.keys())[0] if MAP_DATASETS else None)
    ),
    description="Dataset:"
)


# Year selector. The 1960-2024 bounds are only the initial range; they are
# overwritten by update_year_slider once it runs for the selected dataset.
year_widget = widgets.IntSlider(
    value=2024,
    min=1960,
    max=2024,
    step=1,
    description="Year:"
)

# ---------------------------------------------------------
# 5. Update year slider when dataset changes
# ---------------------------------------------------------
def update_year_slider(change=None):
    """Fit the year slider to the years available in the selected dataset.

    Used both as a traitlets observer (``change`` is the change notification
    and is ignored) and as a plain call with no argument. Does nothing when
    the registry is empty or the dataset has no year columns.
    """
    if not MAP_DATASETS:
        return
    years = get_years(dataset_widget.value)
    if not years:
        return

    first_year, last_year = years[0], years[-1]
    # The slider rejects any assignment that would leave min > max, so the
    # bound that widens the range has to be set first. Raising min before max
    # fails when the new dataset starts after the current upper bound.
    if first_year > year_widget.max:
        year_widget.max = last_year
        year_widget.min = first_year
    else:
        year_widget.min = first_year
        year_widget.max = last_year
    year_widget.value = last_year


# Re-fit the year slider whenever the dataset selection changes, and once now
# so the slider matches the initially selected dataset.
dataset_widget.observe(update_year_slider, names="value")
update_year_slider()

# ---------------------------------------------------------
# 6. Link widgets
# ---------------------------------------------------------
# plot_map is called with the current widget values as keyword arguments;
# its output is captured in `out`.
out = interactive_output(
    plot_map,
    {
        "dataset_name": dataset_widget,
        "year": year_widget
    }
)

# Stack the two controls above the rendered map.
display(VBox([dataset_widget, year_widget, out]))


In [None]:
import pandas as pd
import folium
import json
from IPython.display import display
import branca

# -----------------------------
# File paths
# -----------------------------
# Absolute paths to the GeoJSON boundary files (world plus four country-level
# subdivision files).
world_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\World.json"
us_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\US.json"
china_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\China.geojson"
india_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\india.json"
canada_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\Canada.json"

# Layer name -> file path; the keys are reused as layer names further down.
geojsons = {
    "World": world_path,
    "US": us_path,
    "China": china_path,
    "India": india_path,
    "Canada": canada_path,
}

# -----------------------------
# Choose the year to display
# -----------------------------
# Kept as a string because it is used as a column label and compared with the
# melted "Year" values.
YEAR = "2021"  # Change this to any available year

# -----------------------------
# Prepare GDP_COM for the chosen year
# -----------------------------
# Work on a copy so the blanking below does not touch GDP_COM itself.
GDP_map = GDP_COM.copy()

# Set top-level countries to NaN to force using subnational units
for country in ["United States", "Canada", "China", "India"]:
    GDP_map.loc[GDP_map["Country Name"] == country, YEAR] = pd.NA

# Convert to long format
GDP_long = GDP_map.melt(id_vars=["Country Name"], var_name="Year", value_name="Value")
# Non-numeric entries become NaN; values are then scaled by 1e9.
GDP_long["Value"] = pd.to_numeric(GDP_long["Value"], errors="coerce") / 1e9  # Billions USD

# -----------------------------
# Load GeoJSONs
# -----------------------------
# Parse every boundary file once; geo_layers maps layer name -> parsed GeoJSON.
geo_layers = {}
for name, path in geojsons.items():
    with open(path, "r", encoding="utf-8") as f:
        geo_layers[name] = json.load(f)

# -----------------------------
# Helper functions
# -----------------------------
def get_feature_name(feature):
    """Return the feature's display name, or None when no name key is present.

    The property keys are tried in a fixed priority order and the value of the
    first one found is returned as-is.
    """
    props = feature["properties"]
    candidates = ("NAME", "Name", "name", "st_nm", "ST_NM", "shapeName")
    return next((props[k] for k in candidates if k in props), None)

def get_gdp(region_name):
    """Look up the GDP value of a region for the globally selected YEAR.

    Returns the first matching "Value" entry of GDP_long, or None when the
    region/year pair is absent or its value is missing.
    """
    selector = (GDP_long["Country Name"] == region_name) & (GDP_long["Year"] == YEAR)
    matches = GDP_long.loc[selector, "Value"].values
    if len(matches) == 0:
        return None
    first = matches[0]
    return first if pd.notna(first) else None

# -----------------------------
# Map plotting
# -----------------------------
# Base map centred near the equator at a whole-world zoom level.
m = folium.Map(location=[20, 0], zoom_start=2, height=600)

# Colormap: values are clamped to max_val before being mapped to a colour.
max_val = 5000
colormap = branca.colormap.linear.YlGnBu_09.scale(0, max_val)
colormap.caption = f"GDP (Billions USD) — Scale up to {max_val:,.0f} B"
colormap.add_to(m)

# Title
title_html = f"<h3 align='center' style='font-size:20px'>Global GDP Map — {YEAR}</h3>"
m.get_root().html.add_child(folium.Element(title_html))

# World layer: write a tooltip string and a style dict into every feature's
# properties so the GeoJson layer below can read them back.
world = geo_layers["World"]
for feature in world["features"]:
    cname = get_feature_name(feature)
    val = get_gdp(cname)
    # "Missing" means None only. A genuine 0.0 is data and must get the same
    # tooltip/opacity treatment as any other value (a truthiness test would
    # label it N/A while still colouring it).
    has_val = val is not None
    color = colormap(min(val, max_val)) if has_val else "#f0f0f0"
    feature["properties"]["tooltip"] = f"{cname}: {val:,.1f} B USD" if has_val else f"{cname}: N/A"
    feature["properties"]["style"] = {
        "fillColor": color,
        "color": "black",
        "weight": 0.5,
        "fillOpacity": 0.75 if has_val else 0.4,
    }

folium.GeoJson(
    world,
    style_function=lambda f: f["properties"]["style"],
    tooltip=folium.GeoJsonTooltip(fields=["tooltip"], labels=False, sticky=True),
    name="World",
).add_to(m)

# Subregions: one GeoJson layer per country-level subdivision file.
for key in ["US", "China", "India", "Canada"]:
    geo = geo_layers[key]
    # Chinese and Indian provinces are always drawn at the dim opacity.
    dimmed = key in ["China", "India"]
    for feature in geo["features"]:
        rname = get_feature_name(feature)
        val = get_gdp(rname)
        has_val = val is not None
        color = colormap(min(val, max_val)) if has_val else "#f0f0f0"

        fill_opacity = 0.75 if (has_val and not dimmed) else 0.4

        feature["properties"]["tooltip"] = f"{rname}: {val:,.1f} B USD" if has_val else f"{rname}: N/A"
        feature["properties"]["style"] = {
            "fillColor": color,
            "color": "black",
            "weight": 0.5,
            "fillOpacity": fill_opacity,
        }

    folium.GeoJson(
        geo,
        style_function=lambda f: f["properties"]["style"],
        tooltip=folium.GeoJsonTooltip(fields=["tooltip"], labels=False, sticky=True),
        name=f"{key} Subregions",
    ).add_to(m)

display(m)


In [None]:
import pandas as pd
import folium
import json
from IPython.display import display
import branca

# -----------------------------
# File paths
# -----------------------------
# Absolute paths to the GeoJSON boundary files (world plus four country-level
# subdivision files).
world_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\World.json"
us_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\US.json"
china_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\China.geojson"
india_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\india.json"
canada_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\Canada.json"

# Layer name -> file path; the keys are reused as layer names further down.
geojsons = {
    "World": world_path,
    "US": us_path,
    "China": china_path,
    "India": india_path,
    "Canada": canada_path,
}

# -----------------------------
# Choose the year to display
# -----------------------------
# Kept as a string because it is used as a column label and compared with the
# melted "Year" values.
YEAR = "2021"  # Change to any available year in POP_COM

# -----------------------------
# Prepare POP_COM for the chosen year
# -----------------------------
# Work on a copy so the blanking below does not touch POP_COM itself.
POP_map = POP_COM.copy()

# Set top-level countries to NaN to force using subnational units
for country in ["United States", "Canada", "China", "India"]:
    POP_map.loc[POP_map["Country Name"] == country, YEAR] = pd.NA

# Convert to long format
POP_long = POP_map.melt(id_vars=["Country Name"], var_name="Year", value_name="Value")
# Non-numeric entries become NaN; values are then scaled by one million.
POP_long["Value"] = pd.to_numeric(POP_long["Value"], errors="coerce") / 1_000_000  # Millions

# -----------------------------
# Load GeoJSONs
# -----------------------------
# Parse every boundary file once; geo_layers maps layer name -> parsed GeoJSON.
geo_layers = {}
for name, path in geojsons.items():
    with open(path, "r", encoding="utf-8") as f:
        geo_layers[name] = json.load(f)

# -----------------------------
# Helper functions
# -----------------------------
def get_feature_name(feature):
    """Return the value of the first known name property of a GeoJSON feature.

    Keys are probed in priority order; None is returned when the feature has
    none of them.
    """
    properties = feature["properties"]
    for name_key in ("NAME", "Name", "name", "st_nm", "ST_NM", "shapeName"):
        try:
            return properties[name_key]
        except KeyError:
            continue
    return None

def get_population(region_name):
    """Look up the population value of a region for the globally selected YEAR.

    Returns the first matching "Value" entry of POP_long, or None when the
    region/year pair is absent or its value is missing.
    """
    same_region = POP_long["Country Name"] == region_name
    same_year = POP_long["Year"] == YEAR
    hits = POP_long[same_region & same_year]
    if hits.empty:
        return None
    value = hits["Value"].values[0]
    if pd.notna(value):
        return value
    return None

# -----------------------------
# Map plotting
# -----------------------------
# Base map centred near the equator at a whole-world zoom level.
m = folium.Map(location=[20, 0], zoom_start=2, height=600)

# Colormap: values are clamped to max_val before being mapped to a colour.
max_val = 300  # Millions, adjust based on dataset
colormap = branca.colormap.linear.YlOrRd_09.scale(0, max_val)
colormap.caption = f"Population (Millions) — Scale up to {max_val:,.0f} M"
colormap.add_to(m)

# Title
title_html = f"<h3 align='center' style='font-size:20px'>Global Population Map — {YEAR}</h3>"
m.get_root().html.add_child(folium.Element(title_html))

# World layer: write a tooltip string and a style dict into every feature's
# properties so the GeoJson layer below can read them back.
world = geo_layers["World"]
for feature in world["features"]:
    cname = get_feature_name(feature)
    val = get_population(cname)
    # "Missing" means None only. A genuine 0.0 is data and must get the same
    # tooltip/opacity treatment as any other value (a truthiness test would
    # label it N/A while still colouring it).
    has_val = val is not None
    color = colormap(min(val, max_val)) if has_val else "#f0f0f0"
    feature["properties"]["tooltip"] = f"{cname}: {val:,.1f} M" if has_val else f"{cname}: N/A"
    feature["properties"]["style"] = {
        "fillColor": color,
        "color": "black",
        "weight": 0.5,
        "fillOpacity": 0.75 if has_val else 0.4,
    }

folium.GeoJson(
    world,
    style_function=lambda f: f["properties"]["style"],
    tooltip=folium.GeoJsonTooltip(fields=["tooltip"], labels=False, sticky=True),
    name="World",
).add_to(m)

# Subregions: one GeoJson layer per country-level subdivision file.
for key in ["US", "China", "India", "Canada"]:
    geo = geo_layers[key]
    # Chinese and Indian provinces are always drawn at the dim opacity.
    dimmed = key in ["China", "India"]
    for feature in geo["features"]:
        rname = get_feature_name(feature)
        val = get_population(rname)
        has_val = val is not None
        color = colormap(min(val, max_val)) if has_val else "#f0f0f0"

        fill_opacity = 0.75 if (has_val and not dimmed) else 0.4

        feature["properties"]["tooltip"] = f"{rname}: {val:,.1f} M" if has_val else f"{rname}: N/A"
        feature["properties"]["style"] = {
            "fillColor": color,
            "color": "black",
            "weight": 0.5,
            "fillOpacity": fill_opacity,
        }

    folium.GeoJson(
        geo,
        style_function=lambda f: f["properties"]["style"],
        tooltip=folium.GeoJsonTooltip(fields=["tooltip"], labels=False, sticky=True),
        name=f"{key} Subregions",
    ).add_to(m)

display(m)


In [None]:
import pandas as pd
import folium
import json
from IPython.display import display
import branca

# -----------------------------
# File paths (same as before)
# -----------------------------
# Absolute paths to the GeoJSON boundary files (world plus four country-level
# subdivision files).
world_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\World.json"
us_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\US.json"
china_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\China.geojson"
india_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\india.json"
canada_path = r"C:\Users\Utente\Documents\WIP\Geojsonmaps\Canada.json"

# Layer name -> file path; the keys are reused as layer names further down.
geojsons = {
    "World": world_path,
    "US": us_path,
    "China": china_path,
    "India": india_path,
    "Canada": canada_path,
}

# -----------------------------
# Year and dataset
# -----------------------------
# YEAR is a string because it is used as a column label and compared with the
# melted "Year" values. DATASET is a copy, so the blanking below leaves the
# registered frame untouched.
YEAR = "2021"
DATASET = DATASETS_AVAILABLE["GDPCAP_COM_NOMINAL"].copy()


# -----------------------------
# Set top-level countries to NaN to force using subnational units
# -----------------------------
for country in ["United States", "Canada", "China", "India"]:
    DATASET.loc[DATASET["Country Name"] == country, YEAR] = pd.NA

# Convert to long format
DATA_long = DATASET.melt(id_vars=["Country Name"], var_name="Year", value_name="Value")
# Non-numeric entries become NaN; values are then scaled by one thousand.
DATA_long["Value"] = pd.to_numeric(DATA_long["Value"], errors="coerce") / 1_000  # in Thousands USD

# -----------------------------
# Load GeoJSONs
# -----------------------------
# Parse every boundary file once; geo_layers maps layer name -> parsed GeoJSON.
geo_layers = {}
for name, path in geojsons.items():
    with open(path, "r", encoding="utf-8") as f:
        geo_layers[name] = json.load(f)

# -----------------------------
# Helper functions
# -----------------------------
def get_feature_name(feature):
    """Pick a feature's name from whichever known property key it carries.

    The earliest key in the priority list wins; None is returned when the
    feature carries none of them.
    """
    props = feature["properties"]
    priority = ["NAME", "Name", "name", "st_nm", "ST_NM", "shapeName"]
    present = [k for k in priority if k in props]
    return props[present[0]] if present else None

def get_value(region_name):
    """Look up a region's value in DATA_long for the globally selected YEAR.

    Returns the first matching "Value" entry, or None when the region/year
    pair is absent or its value is missing.
    """
    wanted = (DATA_long["Country Name"] == region_name) & (DATA_long["Year"] == YEAR)
    found = DATA_long.loc[wanted, "Value"]
    if found.empty or pd.isna(found.values[0]):
        return None
    return found.values[0]

# -----------------------------
# Map plotting
# -----------------------------
# Base map centred near the equator at a whole-world zoom level.
m = folium.Map(location=[20, 0], zoom_start=2, height=600)

# Colormap (scale to 0–100k USD per capita); values are clamped to max_val
# before being mapped to a colour.
max_val = 100  # Thousands USD per capita
colormap = branca.colormap.linear.YlGnBu_09.scale(0, max_val)
colormap.caption = f"GDP per Capita (Thousands USD) — Scale up to {max_val:,.0f}k"
colormap.add_to(m)

# Title
title_html = f"<h3 align='center' style='font-size:20px'>GDP per Capita Map — {YEAR}</h3>"
m.get_root().html.add_child(folium.Element(title_html))

# World layer: write a tooltip string and a style dict into every feature's
# properties so the GeoJson layer below can read them back.
world = geo_layers["World"]
for feature in world["features"]:
    cname = get_feature_name(feature)
    val = get_value(cname)
    # "Missing" means None only. A genuine 0.0 is data and must get the same
    # tooltip/opacity treatment as any other value (a truthiness test would
    # label it N/A while still colouring it).
    has_val = val is not None
    color = colormap(min(val, max_val)) if has_val else "#f0f0f0"
    feature["properties"]["tooltip"] = f"{cname}: {val:,.1f}k USD" if has_val else f"{cname}: N/A"
    feature["properties"]["style"] = {
        "fillColor": color,
        "color": "black",
        "weight": 0.5,
        "fillOpacity": 0.75 if has_val else 0.4,
    }

folium.GeoJson(
    world,
    style_function=lambda f: f["properties"]["style"],
    tooltip=folium.GeoJsonTooltip(fields=["tooltip"], labels=False, sticky=True),
    name="World",
).add_to(m)

# Subregions: one GeoJson layer per country-level subdivision file.
for key in ["US", "China", "India", "Canada"]:
    geo = geo_layers[key]
    # Chinese and Indian provinces are always drawn at the dim opacity.
    dimmed = key in ["China", "India"]
    for feature in geo["features"]:
        rname = get_feature_name(feature)
        val = get_value(rname)
        has_val = val is not None
        color = colormap(min(val, max_val)) if has_val else "#f0f0f0"

        fill_opacity = 0.75 if (has_val and not dimmed) else 0.4

        feature["properties"]["tooltip"] = f"{rname}: {val:,.1f}k USD" if has_val else f"{rname}: N/A"
        feature["properties"]["style"] = {
            "fillColor": color,
            "color": "black",
            "weight": 0.5,
            "fillOpacity": fill_opacity,
        }

    folium.GeoJson(
        geo,
        style_function=lambda f: f["properties"]["style"],
        tooltip=folium.GeoJsonTooltip(fields=["tooltip"], labels=False, sticky=True),
        name=f"{key} Subregions",
    ).add_to(m)

display(m)


GOVERNMENT DATA

In [None]:
import pandas as pd

# ----------------------------
# Function to reshape wide datasets to long
# ----------------------------
def reshape_wide(df, value_name):
    """Melt a one-column-per-year table into (Country Name, Year, value) rows.

    Only columns whose label is made up entirely of digits are treated as
    years; every other column except "Country Name" is dropped. The resulting
    "Year" column is cast to int and the values land in ``value_name``.
    """
    years = [label for label in df.columns if str(label).isdigit()]
    long_df = pd.melt(
        df,
        id_vars=["Country Name"],
        value_vars=years,
        var_name="Year",
        value_name=value_name,
    )
    return long_df.assign(Year=long_df["Year"].astype(int))

# ----------------------------
# Load and reshape IMF datasets
# ----------------------------
# Indicator name -> IMF spreadsheet holding that indicator.
imf_files = {
    "GTEXP": r"C:\Users\Utente\Documents\WIP\GOVDATA\imf-tot_exp_gdp.xls",
    "GTREV": r"C:\Users\Utente\Documents\WIP\GOVDATA\imf-rev_gdp.xls",
    "GPREXP": r"C:\Users\Utente\Documents\WIP\GOVDATA\imf-prim_exp_gdp.xls",
    "GINTDEB": r"C:\Users\Utente\Documents\WIP\GOVDATA\imf-int_pay_deb_gdp.xls",
    "GTDEB": r"C:\Users\Utente\Documents\WIP\GOVDATA\imf-gro_pub_deb_gdp.xls"
}

# Read each spreadsheet and melt it to long form; the value column is named
# after the indicator.
imf_data = {}
for name, path in imf_files.items():
    imf_data[name] = reshape_wide(pd.read_excel(path), name)

# ----------------------------
# Restrict every table to the countries present in all of them
# ----------------------------
common_countries = set.intersection(
    *(set(table["Country Name"]) for table in imf_data.values())
)

imf_data = {
    name: table.loc[table["Country Name"].isin(common_countries)].reset_index(drop=True)
    for name, table in imf_data.items()
}

# ----------------------------
# Left-join the other indicators onto total expenditure
# ----------------------------
df_imf = imf_data["GTEXP"]
for name in ("GTREV", "GPREXP", "GINTDEB", "GTDEB"):
    df_imf = df_imf.merge(imf_data[name], how="left", on=["Country Name", "Year"])

# ----------------------------
# Derived fiscal variables
# ----------------------------
# Sorting first makes the per-country diff run in year order.
df_imf = df_imf.sort_values(["Country Name", "Year"])
df_imf["GDEF"] = df_imf["GTEXP"] - df_imf["GTREV"]
# Year-on-year change in GTDEB within each country; gaps are filled with 0.
df_imf["GINCDEB"] = df_imf.groupby("Country Name")["GTDEB"].diff().fillna(0)
df_imf["NINCDEB"] = df_imf["GDEF"] - df_imf["GINCDEB"]

# ----------------------------
# Local GDP and population tables, melted the same way
# ----------------------------
gdp = reshape_wide(pd.read_excel(r"C:\Users\Utente\Documents\WIP\GOVDATA\XLSX\GDP.xlsx"), "GDP")
pop = reshape_wide(pd.read_excel(r"C:\Users\Utente\Documents\WIP\GOVDATA\XLSX\POP.xlsx"), "POP")

# ----------------------------
# Attach GDP and POP to the IMF table
# ----------------------------
df = df_imf.merge(gdp, how="left", on=["Country Name", "Year"])
df = df.merge(pop, how="left", on=["Country Name", "Year"])

# ----------------------------
# _curr = indicator * GDP / 100; _percGDP = indicator as loaded
# ----------------------------
fiscal_vars = ["GTEXP","GTREV","GPREXP","GINTDEB","GTDEB","GDEF","GINCDEB","NINCDEB"]
for col in fiscal_vars:
    df[f"{col}_curr"] = df[col] * df["GDP"] / 100
    df[f"{col}_percGDP"] = df[col]

# ----------------------------
# _curr_pc = _curr divided by POP
# ----------------------------
for col in fiscal_vars:
    df[f"{col}_curr_pc"] = df[f"{col}_curr"] / df["POP"]

# ----------------------------
# Final table: keys, then all _percGDP, all _curr, all _curr_pc columns
# ----------------------------
cols_keep = ["Country Name", "Year"] + [
    f"{col}{suffix}"
    for suffix in ("_percGDP", "_curr", "_curr_pc")
    for col in fiscal_vars
]

df_final = df[cols_keep].copy()


In [None]:
import dash
from dash import dcc, html, Input, Output
import plotly.graph_objs as go

# =========================================================
# 1. CONFIG
# =========================================================
# Indicators selectable in the dashboard and the three ways of expressing them.
datasets = ["GTEXP","GTREV","GPREXP","GINTDEB","GTDEB","GDEF","GINCDEB","NINCDEB"]
value_types = ["%GDP", "Current", "Current per capita"]

# Value type -> suffix of the matching df_final column.
suffix_map = {
    "%GDP": "_percGDP",
    "Current": "_curr",
    "Current per capita": "_curr_pc"
}

# =========================================================
# 2. DASH APP
# =========================================================
app = dash.Dash(__name__)
app.title = "Fiscal Dashboard"


def _labelled_dropdown(label, width, **dropdown_props):
    """Return a Div of the given CSS width holding a caption and a dropdown."""
    return html.Div(
        [
            html.Label(label, style={"fontWeight": "600", "fontSize": "12px"}),
            dcc.Dropdown(**dropdown_props),
        ],
        style={"width": width},
    )


# =========================================================
# 3. LAYOUT (MORE COMPACT)
# =========================================================
app.layout = html.Div(
    style={
        "width": "100%",
        "padding": "10px 18px",
        "fontFamily": "Arial, sans-serif",
        "backgroundColor": "white"
    },
    children=[
        # Heading and subtitle
        html.H3(
            "Fiscal Dashboard",
            style={"textAlign": "center", "marginBottom": "2px"}
        ),
        html.P(
            "Government finance indicators",
            style={
                "textAlign": "center",
                "color": "#6c757d",
                "fontSize": "13px",
                "marginBottom": "10px"
            }
        ),

        # Row of three controls: value type, indicators, countries
        html.Div(
            style={
                "display": "flex",
                "gap": "12px",
                "marginBottom": "10px"
            },
            children=[
                _labelled_dropdown(
                    "Value Type", "18%",
                    id="value-type",
                    options=[{"label": v, "value": v} for v in value_types],
                    value="%GDP",
                    clearable=False,
                ),
                _labelled_dropdown(
                    "Dataset(s)", "40%",
                    id="dataset",
                    options=[{"label": d, "value": d} for d in datasets],
                    value=["GDEF"],
                    multi=True,
                ),
                _labelled_dropdown(
                    "Country(s)", "40%",
                    id="countries",
                    options=[
                        {"label": c, "value": c}
                        for c in sorted(df_final["Country Name"].unique())
                    ],
                    value=["Italy"],
                    multi=True,
                ),
            ]
        ),

        # Line chart filled in by the callback
        dcc.Graph(
            id="line-chart",
            style={"height": "380px"}
        )
    ]
)

# =========================================================
# 4. CALLBACK
# =========================================================
@app.callback(
    Output("line-chart", "figure"),
    Input("value-type", "value"),
    Input("dataset", "value"),
    Input("countries", "value")
)
def update_chart(value_type, selected_datasets, selected_countries):
    """Build the line chart for the current dropdown selections.

    One trace is drawn per (indicator, country) pair, indicators in the outer
    loop. Indicators whose column for the chosen value type is missing from
    df_final are skipped. With no indicator or no country selected, an empty
    figure carrying a prompt as its title is returned.
    """
    fig = go.Figure()

    if not (selected_datasets and selected_countries):
        fig.update_layout(title="Select at least one dataset and one country")
        return fig

    suffix = suffix_map[value_type]
    chosen = df_final[df_final["Country Name"].isin(selected_countries)]

    plottable = [ds for ds in selected_datasets if f"{ds}{suffix}" in chosen.columns]
    for ds in plottable:
        column = f"{ds}{suffix}"
        for country in selected_countries:
            rows = chosen[chosen["Country Name"] == country]
            trace = go.Scatter(
                x=rows["Year"],
                y=rows[column],
                mode="lines",
                line={"width": 2},
                name=f"{country} – {ds}",
            )
            fig.add_trace(trace)

    # Both axes share grid colour and tick font; only the title differs.
    axis_style = {"gridcolor": "rgba(0,0,0,0.07)", "tickfont": {"size": 11}}
    fig.update_layout(
        template="plotly_white",
        hovermode="x unified",
        height=380,
        margin={"l": 50, "r": 25, "t": 35, "b": 35},
        xaxis={"title": "Year", **axis_style},
        yaxis={"title": value_type, **axis_style},
        legend={"font": {"size": 10}},
    )

    return fig

# =========================================================
# 5. RUN
# =========================================================
if __name__ == "__main__":
    # Launch the fiscal dashboard with Dash debug mode enabled. No port is
    # passed here, so whatever default app.run uses applies.
    app.run(debug=True)


In [None]:
# Government spending by function; the file's second column is discarded
# right after loading.
EXPFUN = pd.read_csv(r"C:\Users\Utente\Documents\WIP\GOVDATA\government-spending-by-function.csv")
EXPFUN = EXPFUN.drop(columns=EXPFUN.columns[1])

# 2023 cross-section, with what is by now the second column removed as well.
EXPFUN_2023 = EXPFUN.loc[EXPFUN['Year'] == 2023].reset_index(drop=True)
EXPFUN_2023 = EXPFUN_2023.drop(columns=EXPFUN_2023.columns[1])

# All rows for Italy.
EXPFUN_italy = EXPFUN.loc[EXPFUN['Country Name'] == 'Italy'].reset_index(drop=True)
#EXPFUN_2023.head()

In [None]:
import pandas as pd
import plotly.express as px
from ipywidgets import interact, Dropdown

# --- Step 1: Build a function to reshape EXPFUN for plotting ---
def prepare_data(selected_countries):
    """Return EXPFUN rows for the given countries in long form.

    Each country's rows are melted to (Country, Year, Category, Value) and the
    per-country frames are concatenated in the order the countries were given.
    """
    def _long_form(country):
        subset = EXPFUN.loc[EXPFUN['Country Name'] == country]
        melted = subset.melt(
            id_vars=['Country Name', 'Year'],
            var_name='Category',
            value_name='Value',
        )
        return melted.rename(columns={'Country Name': 'Country'})

    return pd.concat([_long_form(country) for country in selected_countries])

# --- Step 2: Define color palette ---
# Twelve hex colours, passed as color_discrete_sequence to the area chart
# built by plot_area_chart.
soft_dark_colors = [
    "#4C566A", "#5E81AC", "#81A1C1", "#88C0D0",
    "#A3BE8C", "#EBCB8B", "#D08770", "#BF616A",
    "#B48EAD", "#8FBCBB", "#434C5E", "#2E3440"
]

# --- Step 3: Plotting function ---
def plot_area_chart(selected_country1, selected_country2):
    """Draw a stacked area chart of spending by category for up to two countries.

    Args:
        selected_country1: Country name, or the string "None" for no selection.
        selected_country2: Country name, or the string "None" for no selection.

    Two different countries are shown side by side (one facet each). A single
    country, or the same country chosen in both controls, is shown once.
    Prints a prompt and draws nothing when neither control holds a country.
    """
    # dict.fromkeys keeps the selection order while dropping a repeated country.
    # Without it the same rows are concatenated twice and every stacked value
    # is doubled (both dropdowns start on the same country).
    countries_to_plot = list(dict.fromkeys(
        c for c in (selected_country1, selected_country2) if c != "None"
    ))
    if not countries_to_plot:
        print("Please select at least one country.")
        return

    df_all = prepare_data(countries_to_plot)
    df_all['Year'] = df_all['Year'].astype(int)

    fig = px.area(
        df_all,
        x='Year',
        y='Value',
        color='Category',
        facet_col='Country' if len(countries_to_plot) == 2 else None,
        title=" – ".join(countries_to_plot) + " – Stacked Area Chart",
        color_discrete_sequence=soft_dark_colors,
        hover_name='Category',
        hover_data={'Value': ':.2f', 'Year': True}
    )

    fig.update_traces(line=dict(width=1), hovertemplate='%{x}<br>%{fullData.name}: %{y:.2f}')
    fig.update_layout(
        xaxis=dict(dtick=1, tickmode='linear', tickformat='d'),
        legend_title_text='Category',
        # Legend sits to the right of the plot, hence the wide right margin.
        legend=dict(traceorder='reversed', orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.02),
        plot_bgcolor='white',
        margin=dict(r=180, t=60),
        height=750
    )

    fig.show()

# --- Step 4: Interactive widget ---
# Italy first, then every other country alphabetically (Italy is not repeated).
countries = ["Italy"] + sorted(
    c for c in EXPFUN['Country Name'].unique() if c != "Italy"
)
# "None" is the sentinel plot_area_chart filters out; offering it lets the
# second control be switched off so a single country can be plotted.
interact(
    plot_area_chart,
    selected_country1=Dropdown(options=["None"] + countries, value="Italy", description='Country 1:'),
    selected_country2=Dropdown(options=["None"] + countries, value="None", description='Country 2:')
)


In [None]:
import pandas as pd
import plotly.express as px
from ipywidgets import interact, Dropdown, ToggleButtons

# --- Step 1: Build nested dictionary with absolute values ---
# Nested mapping built below: data[country][year] -> {category: absolute value}.
data = {}

# Total expenditure keyed by (country, year), built once so the loop below does
# a dict lookup instead of scanning df_final for every EXPFUN row.
# drop_duplicates keeps the first row per key, i.e. the row a positional
# "first match" lookup would have returned.
_gtexp_by_key = (
    df_final.drop_duplicates(subset=['Country Name', 'Year'])
            .set_index(['Country Name', 'Year'])['GTEXP_curr']
            .to_dict()
)

for _, row in EXPFUN.iterrows():
    country = row['Country Name']
    year = row['Year']

    try:
        total_exp = _gtexp_by_key[(country, year)]
    except KeyError:
        continue  # skip if no data for that country/year

    # Category shares are treated as percentages of total expenditure.
    fractions = row.drop(['Country Name', 'Year']).to_dict()  # assume fractions in %
    absolute_values = {cat: total_exp * val / 100 for cat, val in fractions.items()}

    data.setdefault(country, {})[year] = absolute_values



# --- Step 2: Soft dark color palette ---
# Twelve hex colours, passed as color_discrete_sequence to the charts built by
# plot_country.
soft_dark_colors = [
    "#4C566A", "#5E81AC", "#81A1C1", "#88C0D0",
    "#A3BE8C", "#EBCB8B", "#D08770", "#BF616A",
    "#B48EAD", "#8FBCBB", "#434C5E", "#2E3440"
]

# --- Step 3: Plotting function with value type toggle ---
def plot_country(selected_country1, selected_country2, selected_year, value_type):
    """Plot absolute spending by category for one or two countries.

    Args:
        selected_country1: Country name, or the string "None" for no selection.
        selected_country2: Country name, or the string "None" for no selection.
        selected_year: 'All' for an area chart over every stored year,
            otherwise a year (converted with int()) for a single-year bar chart.
        value_type: "Per Capita" divides each value by the POP_COM population
            of that country and year; any other value leaves values unchanged.

    Prints a message and draws nothing when no usable country or no data for
    the requested year is available. Returns None.
    """
    # dict.fromkeys keeps selection order while dropping a repeated country,
    # which would otherwise be concatenated twice and double every stacked value.
    requested = list(dict.fromkeys(
        c for c in (selected_country1, selected_country2) if c != "None"
    ))
    if not requested:
        print("Please select at least one country.")
        return

    # Only countries that actually made it into `data` can be plotted.
    countries_to_plot = []
    for name in requested:
        if name in data:
            countries_to_plot.append(name)
        else:
            print(f"No expenditure data available for {name}")
    if not countries_to_plot:
        return

    # Helper: convert to per capita if requested
    def apply_value_type(df, country_col='Country', year_col='Year'):
        if value_type != "Per Capita":
            return df
        for country in df[country_col].unique():
            for year in df[year_col].unique():
                rows = (df[country_col] == country) & (df[year_col] == year)
                try:
                    pop = POP_COM.loc[POP_COM['Country Name'] == country, str(int(year))].values[0]
                except (KeyError, IndexError):
                    # No population figure: blank the rows instead of leaving
                    # absolute totals inside a chart labelled per capita.
                    df.loc[rows, 'Value'] = float("nan")
                    continue
                df.loc[rows, 'Value'] /= pop
        return df

    # Layout settings shared by both chart types; the legend sits to the right
    # of the plot, hence the wide right margin.
    shared_layout = dict(
        legend_title_text='Category',
        legend=dict(traceorder='reversed', orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.02),
        plot_bgcolor='white',
        margin=dict(r=180, t=60),
        height=750
    )

    # === AREA CHART (All years) ===
    if selected_year == 'All':
        combined_df = []
        for country in countries_to_plot:
            # data[country] is {year: {category: value}}; transposing puts
            # years on the rows and categories on the columns.
            df = pd.DataFrame(data[country]).T.reset_index().rename(columns={'index': 'Year'})
            df['Year'] = df['Year'].astype(int)
            df_long = df.melt(id_vars='Year', var_name='Category', value_name='Value')
            df_long['Country'] = country
            combined_df.append(df_long)
        # ignore_index gives unique row labels for the boolean-mask updates
        # done in apply_value_type.
        df_all = apply_value_type(pd.concat(combined_df, ignore_index=True))

        fig = px.area(
            df_all,
            x='Year',
            y='Value',
            color='Category',
            facet_col='Country' if len(countries_to_plot) == 2 else None,
            title=f"{' – '.join(countries_to_plot)} – Area Chart ({value_type})",
            color_discrete_sequence=soft_dark_colors,
            hover_name='Category',
            hover_data={'Value': ':.2f', 'Year': True}
        )
        fig.update_traces(line=dict(width=1), hovertemplate='%{x}<br>%{fullData.name}: %{y:.2f}')
        fig.update_layout(
            xaxis=dict(dtick=1, tickmode='linear', tickformat='d'),
            **shared_layout
        )
        fig.show()
        return

    # === STACKED COLUMN (Single year) ===
    year = int(selected_year)
    combined_df = []
    for country in countries_to_plot:
        if year not in data[country]:
            print(f"Year {year} not available for {country}")
            continue
        df = pd.DataFrame([data[country][year]]).melt(var_name='Category', value_name='Value')
        df['Country'] = country
        df['Year'] = year
        combined_df.append(df)
    if not combined_df:
        print(f"No data found for {selected_year}")
        return
    df_all = apply_value_type(pd.concat(combined_df, ignore_index=True))

    fig = px.bar(
        df_all,
        x='Category',
        y='Value',
        color='Category',
        facet_col='Country' if len(countries_to_plot) == 2 else None,
        title=f"{' – '.join(countries_to_plot)} – Stacked Column {year} ({value_type})",
        color_discrete_sequence=soft_dark_colors,
        hover_data={'Value': ':.2f'}
    )
    fig.update_traces(
        hovertemplate='%{x}<br>%{y:.2f}',
        marker_line_color='black',
        marker_line_width=1,
        text=None
    )
    fig.update_layout(
        barmode='stack',
        xaxis_title='Category',
        yaxis_title='Value',
        **shared_layout
    )
    fig.show()

# --- Step 4: Interactive controls ---
# Countries that actually have data, Italy first when it is one of them.
# Listing Italy unconditionally would offer a choice with no data behind it.
countries = sorted(data.keys())
if "Italy" in data:
    countries = ["Italy"] + [c for c in countries if c != "Italy"]
years = sorted(EXPFUN['Year'].unique())
year_options = ['All'] + years

# "None" is the sentinel plot_country filters out; offering it lets the second
# control be switched off so a single country can be plotted.
interact(
    plot_country,
    selected_country1=Dropdown(
        options=["None"] + countries,
        value=countries[0] if countries else "None",
        description='Country 1:'
    ),
    selected_country2=Dropdown(options=["None"] + countries, value="None", description='Country 2:'),
    selected_year=Dropdown(options=year_options, description='Year:'),
    value_type=ToggleButtons(options=['Current Value', 'Per Capita'], description='Value Type:')
)


In [None]:
import pandas as pd
import numpy as np

# =========================
# 1. LOAD DATA
# =========================
# Folder holding the CSV inputs for this section.
BASE = r"C:\Users\Utente\Documents\WIP"

# Raw f-strings (rf"...") throughout: in a non-raw string "\E", "\P", "\C" and
# "\A" are invalid escape sequences, which newer Python versions warn about.
# The resulting paths are unchanged.
EXR = pd.read_csv(rf"{BASE}\EXR.csv", skiprows=4)  # first 4 lines of the file are skipped
PIT2024 = pd.read_csv(rf"{BASE}\PIT2024.csv")
PCR2024 = pd.read_csv(rf"{BASE}\PCR2024.csv")
CL2024 = pd.read_csv(rf"{BASE}\CL2024.csv")
AGM = pd.read_csv(rf"{BASE}\AGM.csv")
USCA_AGA = pd.read_csv(rf"{BASE}\USCA_AGA.csv")

# =========================
# 2. AGA_COM2024 PROCESSING
# =========================

# Convert to numeric

# Every column after the first is coerced to numbers (unparseable cells become
# NaN), then gaps are filled from the column to the left within each row.
num = AGM.iloc[:, 1:].apply(pd.to_numeric, errors="coerce")
num = num.ffill(axis=1)

# Scale by 12 exactly once. The multiplication used to be applied a second
# time on the already-scaled frame, inflating every value by a factor of 144.
# NOTE(review): the factor 12 presumably converts monthly to annual figures — confirm.
AGM.iloc[:, 1:] = num * 12

# First column plus the last column of the scaled table.
AGA_2024 = AGM.iloc[:, [0, -1]].copy()

# Same two columns from the US/Canada table, skipping its first row, relabelled
# to match so the two frames can be stacked.
USCA_sel = USCA_AGA.iloc[1:, [0, -1]].copy()
USCA_sel.columns = AGA_2024.columns

AGA_COM_2024 = pd.concat([AGA_2024, USCA_sel], ignore_index=True)

# =========================
# 3. CLEAN PIT & PCR
# =========================
def clean_basic(df):
    """Return a copy of *df* without its second column, relabelled 0..n-1.

    The column is removed by its label, and the remaining columns are
    renamed to consecutive integers starting at 0. The input frame is not
    modified.
    """
    second_label = df.columns[1]
    trimmed = df.drop(columns=second_label)
    trimmed.columns = range(len(trimmed.columns))
    return trimmed

# Apply the same clean-up (drop second column, integer column labels) to both tables.
PIT2024, PCR2024 = map(clean_basic, (PIT2024, PCR2024))

# =========================
# 4. EXR PROCESSING
# =========================

# Strip the columns and rows listed in cols_to_remove / rows_to_remove
# (both lists come from the configuration section at the top of the notebook).
EXR = EXR.drop(columns=EXR.columns[cols_to_remove]).drop(index=rows_to_remove)

# Keep the first and the last column and relabel them 0 and 1 so they can be
# addressed by integer label below.
EXR2024 = EXR.iloc[:, [0, -1]].copy()
EXR2024.columns = [0, 1]

# Remove Indonesia rows
EXR2024 = EXR2024[EXR2024[0] != "Indonesia"].reset_index(drop=True)

# Make every row appear twice in a row (row k ends up at positions 2k and 2k+1).
EXR2024 = EXR2024.loc[EXR2024.index.repeat(2)].reset_index(drop=True)

# Values of the two parent countries, reused for their subregions.
us_val = EXR2024.loc[EXR2024[0] == "United States", 1].iat[0]
ca_val = EXR2024.loc[EXR2024[0] == "Canada", 1].iat[0]

# Subregion names come from the tail of PIT2024: the last 26 rows are taken as
# Canadian provinces and the 102 rows before them as US states.
us_states = PIT2024.iloc[-128:-26, 0].reset_index(drop=True)
ca_provinces = PIT2024.iloc[-26:, 0].reset_index(drop=True)

# One row per subregion, each carrying its parent country's value.
us_rows = pd.DataFrame({0: us_states, 1: us_val})
ca_rows = pd.DataFrame({0: ca_provinces, 1: ca_val})

# Append subregions to EXR2024
EXR2024 = pd.concat([EXR2024, us_rows, ca_rows], ignore_index=True)

# =========================
# Add 21 identical columns to EXR2024 (same as last column)
# =========================
# Each new column is named after the column count at the time it is added
# ("2", "3", ...), and copies whatever is currently the last column.
for i in range(21):
    EXR2024[str(EXR2024.shape[1])] = EXR2024.iloc[:, -1]

# =========================
# 5. PIT ÷ EXR (ODD ROWS ONLY, row-wise safe)
# =========================
# Work on a copy so PIT2024 is only replaced once every row has been processed.
result = PIT2024.copy()
# True for rows whose index label is odd.
odd_idx = result.index % 2 == 1

# NOTE(review): `i` is an index *label* but is used as a *position* with
# .iloc below; this assumes PIT2024 and EXR2024 both have a default 0..n-1
# index and that their rows line up one-to-one - confirm.
for i in result.index[odd_idx]:
    # Extract row values as float numpy arrays (skip first col)
    pit_row = result.iloc[i, 1:].to_numpy(dtype=float, copy=True)
    exr_row = EXR2024.iloc[i, 1:].to_numpy(dtype=float, copy=True)
    
    # Boolean mask for non-NaN in PIT row
    # (indexing exr_row with it requires both rows to have the same length)
    mask = ~np.isnan(pit_row)
    
    # Divide only where PIT has data; NaN cells are left untouched
    pit_row[mask] = pit_row[mask] / exr_row[mask]
    
    # Write the converted values back into the same row
    result.iloc[i, 1:1+len(pit_row)] = pit_row

PIT2024 = result

# =========================
# 6. BUILD COUNTRY DICTIONARY
# =========================
# Filled further down with one top-level entry per dataset
# ("PIT2024", "PCR2024", "AGA_COM_2024", "CL2024").
country_dict = {}

def process_PIT_PCR(df, name):
    """Turn a table of paired rows into ``{name: {key: {...}}}``.

    Rows are read two at a time starting from the top (rows 0-1, 2-3, ...).
    A pair is kept only when both rows carry the same value in the first
    column; that value becomes the key. The remaining cells of each row are
    stored as lists, with missing cells replaced by ``None``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose first column holds the key and whose other columns hold
        the values.
    name : str
        Top-level key of the result. ``"PIT2024"`` stores the two rows as
        ``"even_row"`` / ``"odd_row"``; any other name stores them as
        ``"Employer_contribution"`` / ``"Employee_contribution"``.

    Returns
    -------
    dict
        A single-entry dictionary ``{name: {key: {label: list, label: list}}}``.
    """
    if name == "PIT2024":
        first_label, second_label = "even_row", "odd_row"
    else:
        first_label, second_label = "Employer_contribution", "Employee_contribution"

    def cells(row):
        # Everything after the key column, with missing values mapped to None.
        return row.where(pd.notna, None).tolist()

    entries = {}
    for top in range(0, len(df) - 1, 2):
        key = df.iloc[top, 0]
        if key == df.iloc[top + 1, 0]:
            entries[key] = {
                first_label: cells(df.iloc[top, 1:]),
                second_label: cells(df.iloc[top + 1, 1:]),
            }
    return {name: entries}


def process_AGA(df):
    """Build the ``"AGA_COM_2024"`` section of the country dictionary.

    Each row of *df* contributes one entry keyed by its first cell; the
    remaining cells are stored under ``"Average Gross"`` as a list with
    missing values replaced by ``None``. Rows whose value cells are all
    missing are skipped.

    Returns
    -------
    dict
        ``{"AGA_COM_2024": {key: {"Average Gross": list}}}``.
    """
    records = {}
    for _, row in df.iterrows():
        values = row.iloc[1:]
        if not values.notna().any():
            continue  # nothing usable in this row
        records[row.iloc[0]] = {"Average Gross": values.where(pd.notna, None).tolist()}
    return {"AGA_COM_2024": records}


def process_CL(df):
    """Build the ``"CL2024"`` section of the country dictionary.

    For every row with a non-missing first cell, the second cell is stored
    as ``"Average Monthly Grocery"`` and the third as
    ``"Average Monthly Rent"``. Each is converted to ``float``, or stored
    as ``None`` when the cell is missing.

    Returns
    -------
    dict
        ``{"CL2024": {key: {"Average Monthly Grocery": float | None,
        "Average Monthly Rent": float | None}}}``.
    """
    def as_float(cell):
        return float(cell) if pd.notna(cell) else None

    living_costs = {}
    for _, row in df.iterrows():
        place = row.iloc[0]
        if not pd.notna(place):
            continue  # rows without a name cannot be keyed
        living_costs[place] = {
            "Average Monthly Grocery": as_float(row.iloc[1]),
            "Average Monthly Rent": as_float(row.iloc[2]),
        }
    return {"CL2024": living_costs}


# Each helper returns a single-key dictionary, so the four sections are
# merged into country_dict without overwriting one another.
country_dict.update({
    **process_PIT_PCR(PIT2024, "PIT2024"),
    **process_PIT_PCR(PCR2024, "PCR2024"),
    **process_AGA(AGA_COM_2024),
    **process_CL(CL2024),
})


In [None]:
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import math
import numpy as np
import matplotlib.pyplot as plt

# ---------------------------
# Assumes country_dict already exists with keys:
# "PIT2024", "PCR2024", "AGA_COM_2024", "CL2024"
# ---------------------------

# Helper to extract first numeric from a list-like
def first_numeric(lst):
    """Return the first entry of *lst* that converts to a real number.

    Entries that are ``None``, that ``float()`` cannot convert (empty or
    non-numeric strings, unsupported types), or that convert to NaN
    (including the string ``"nan"`` and NaN values that are not ``float``
    instances) are skipped.

    Parameters
    ----------
    lst : list-like or None
        Candidate values, scanned in order.

    Returns
    -------
    float
        The first usable value, or ``0.0`` when *lst* is empty/``None`` or
        contains nothing usable.
    """
    if not lst:
        return 0.0
    for x in lst:
        if x is None:
            continue
        try:
            val = float(x)
        except (TypeError, ValueError, OverflowError):
            # Not convertible - keep looking instead of failing.
            continue
        # Check NaN after conversion so every NaN flavour is caught.
        if not math.isnan(val):
            return val
    return 0.0

# --- Prepare country/state lists ---
# The split relies on the insertion order of country_dict["PIT2024"]:
# the last 13 keys are treated as Canadian provinces and the 50 keys before
# them as US states; everything else is offered as a country.
all_keys = list(country_dict["PIT2024"])
canadian_provinces = all_keys[-13:]
us_states = all_keys[-(50 + 13):-13]
countries = sorted(set(all_keys).difference(us_states, canadian_provinces))

# --- Net salary calculation ---
def calculate_net_salary(TI, thresholds, coefficients):
    """Return income net of a progressive (marginal) bracket tax.

    ``thresholds[i]`` is the upper bound of bracket ``i`` and
    ``coefficients[i]`` is the rate charged on the slice of income that
    falls inside that bracket (the first bracket starts at 0). Income above
    the last threshold is charged at the last available rate. Each slice is
    taxed at its own rate only, so net income never drops when taxable
    income crosses a threshold.

    Parameters
    ----------
    TI : float
        Taxable income.
    thresholds : list-like or None
        Bracket upper bounds in ascending order. ``None``, ``""`` and NaN
        entries are ignored.
    coefficients : list-like or None
        Tax rates as fractions (``0.23`` for 23%), aligned with
        *thresholds*. ``None``, ``""`` and NaN entries are ignored; a
        bracket without a matching rate is charged 0%.

    Returns
    -------
    float
        ``TI`` minus the tax due. ``TI`` is returned unchanged when no
        usable threshold is supplied.
    """
    def _numbers(values):
        # Convert to floats, dropping blanks (None, "", NaN).
        numbers = []
        for value in values or []:
            if value is None or (isinstance(value, str) and value == ""):
                continue
            number = float(value)
            if not math.isnan(number):
                numbers.append(number)
        return numbers

    bounds = _numbers(thresholds)
    rates = _numbers(coefficients)

    if not bounds:
        return TI

    tax = 0.0    # tax owed on the brackets that were filled completely
    lower = 0.0  # lower bound of the bracket currently being examined
    for idx, upper in enumerate(bounds):
        rate = rates[idx] if idx < len(rates) else 0.0
        if TI <= upper:
            # Income ends inside this bracket: tax only the part above `lower`.
            return TI - (tax + rate * (TI - lower))
        tax += rate * (upper - lower)
        lower = upper

    # Income exceeds every threshold: the remainder is taxed at the last rate.
    top_rate = rates[-1] if rates else 0.0
    return TI - (tax + top_rate * (TI - lower))

# --- UI setup ---
# Shared state for the tool below:
#   country_rows  - one (country dropdown, state dropdown, remove button) tuple per row
#   country_vbox  - container showing the widgets of every row
#   output        - area where results are rendered
#   sort_controls - container for the sort widgets created on each calculation
country_rows = list()
country_vbox = widgets.VBox()
output = widgets.Output()
sort_controls = widgets.HBox(children=[])

# Create country row with remove button
def create_country_row(default_country=None):
    """Build the widgets for one selectable country row.

    The state/province dropdown starts hidden and is only updated when the
    country value *changes*: picking "United States" or "Canada" fills it
    with the matching (sorted) subregions and shows it, any other choice
    hides it. The remove button deletes this row from ``country_rows`` and
    refreshes ``country_vbox``.

    Parameters
    ----------
    default_country : optional
        Passed to the country dropdown as its initial ``value``.

    Returns
    -------
    tuple
        ``(country_dropdown, state_dropdown, remove_button)``.
    """
    country_dd = widgets.Dropdown(options=countries, description="Country:", value=default_country)
    remove_btn = widgets.Button(description="❌", layout=widgets.Layout(width="30px"))
    state_dd = widgets.Dropdown(description="State/Province:")
    state_dd.layout.display = "none"
    row = (country_dd, state_dd, remove_btn)

    def on_country_change(change):
        picked = change["new"]
        if picked == "United States":
            subregions = us_states
        elif picked == "Canada":
            subregions = canadian_provinces
        else:
            state_dd.layout.display = "none"
            return
        state_dd.options = sorted(subregions)
        state_dd.layout.display = "block"

    def remove_row(b):
        country_rows.remove(row)
        country_vbox.children = [widget for entry in country_rows for widget in entry]

    country_dd.observe(on_country_change, names="value")
    remove_btn.on_click(remove_row)

    return row

# --- Default selected countries ---
# One row is created per default country, then the container is filled with
# the widgets of every row, in row order.
default_countries = ["Italy", "Germany", "France", "United Kingdom","Spain"]
country_rows.extend(create_country_row(name) for name in default_countries)

country_vbox.children = [widget for row in country_rows for widget in row]

# --- Add Country button ---
def add_country(b):
    """Append a new row with no preselected country and refresh the container.

    ``b`` is the clicked button; it is not used.
    """
    country_rows.append(create_country_row())
    country_vbox.children = [widget for row in country_rows for widget in row]


add_button = widgets.Button(description="+ Add Country")
add_button.on_click(add_country)

# --- Salary mode selection ---
gross_option = widgets.ToggleButtons(
    options=["Enter Gross Salary", "Use Average Gross", "Enter Corporate Cost"],
    description="Salary Mode:",
)
gross_widget = widgets.FloatText(value=0.0, description="Gross Salary:")
cc_widget = widgets.FloatText(value=0.0, description="Corporate Cost:")
# The corporate-cost box stays hidden until its mode is selected.
cc_widget.layout.display = "none"


def toggle_salary_input(change):
    """Show only the input box that matches the newly selected salary mode.

    "Use Average Gross" hides both boxes, "Enter Corporate Cost" shows only
    the corporate-cost box, and any other mode shows only the gross box.
    """
    mode = change["new"]
    show_cc = mode == "Enter Corporate Cost"
    show_gross = mode not in ("Use Average Gross", "Enter Corporate Cost")
    gross_widget.layout.display = "block" if show_gross else "none"
    cc_widget.layout.display = "block" if show_cc else "none"


gross_option.observe(toggle_salary_input, names="value")

# --- Years and compound interest ---
# Projection horizon in years (1-200, default 40).
years_widget = widgets.BoundedIntText(
    value=40, min=1, max=200, step=1, description='Years:'
)
# Annual growth of the pension fund, entered as a percentage.
compound_widget = widgets.FloatText(
    value=0.0,
    description="PF Annual %",
    layout=widgets.Layout(width='200px'),
)
calc_button = widgets.Button(
    description="Calculate Salary", button_style="primary"
)

# --- Helper for PIT formatting ---
def format_pit_table(percent_row, bracket_row):
    """Format tax rates and brackets as a two-row display table.

    The two inputs are paired position by position; a pair is dropped when
    either value is ``""`` or ``None``, and pairing stops at the shorter
    input.

    Parameters
    ----------
    percent_row : list-like
        Rates as fractions; shown as percentages with one decimal.
    bracket_row : list-like
        Bracket amounts; shown with thousands separators and two decimals.

    Returns
    -------
    pandas.DataFrame
        Rows ``"Percentage"`` and ``"Bracket"``, columns ``"Tier 1"`` ..
        ``"Tier n"``, all cells strings. An empty DataFrame when no usable
        pair exists.
    """
    tiers = [
        (rate, bound)
        for rate, bound in zip(percent_row, bracket_row)
        if rate not in ["", None] and bound not in ["", None]
    ]
    if not tiers:
        return pd.DataFrame()

    # One column per tier, holding the formatted rate above the formatted bracket.
    return pd.DataFrame(
        {
            f"Tier {pos}": [f"{float(rate)*100:.1f}%", f"{float(bound):,.2f}"]
            for pos, (rate, bound) in enumerate(tiers, start=1)
        },
        index=["Percentage", "Bracket"],
    )

# --- Main calculation ---
def _pension_pot(per_year, years, rate):
    """Return the fund value after `years` yearly contributions of `per_year`.

    With a positive `rate` the pot grows by `rate` each year before that
    year's contribution is added; otherwise contributions are simply summed.
    """
    if rate <= 0:
        return per_year * years
    pot = 0.0
    for _ in range(years):
        pot = pot * (1 + rate) + per_year
    return pot


def _style_salary_table(sorted_df):
    """Return the results table as a styled pandas Styler (index hidden)."""
    fmt_map = {c: "{:,.2f}" for c in sorted_df.columns if c != "Country"}
    return (
        sorted_df.style
        .format(fmt_map)
        .set_table_styles([
            {"selector": "table", "props": [
                ("margin-left", "auto"), ("margin-right", "auto"),
                ("border-collapse", "collapse"),
                ("border", "1px solid #ccc"),
                ("border-radius", "12px"),
                ("background-color", "#f9f9f9"),
                ("width", "95%")
            ]},
            {"selector": "th", "props": [
                ("text-align", "center"),
                ("background-color", "#2a3f5f"),
                ("color", "white"),
                ("padding", "8px"),
                ("font-size", "14px")
            ]},
            {"selector": "td", "props": [
                ("text-align", "center"),
                ("padding", "6px"),
                ("font-size", "13px")
            ]},
            {"selector": "tr:hover", "props": [
                ("background-color", "#eef4fb")
            ]}
        ])
        .hide(axis="index")
    )


def _plot_salary_breakdown(sorted_df):
    """Draw one stacked horizontal bar per row, splitting corporate cost into parts.

    Segments, left to right: net salary, taxes (taxable income - net),
    employee contribution (gross - taxable income) and employer contribution
    (corporate cost - gross). Positive segments are labelled with their
    amount and their share of corporate cost.
    """
    plt.figure(figsize=(12, 7))
    countries_plot = sorted_df["Country"].values
    net_salary = sorted_df["Net Salary"].values
    taxes = sorted_df["Taxable Income"].values - net_salary
    ee_contrib = sorted_df["Gross Salary"].values - sorted_df["Taxable Income"].values
    employer_contrib = sorted_df["Corporate Cost"].values - sorted_df["Gross Salary"].values
    total_cc = sorted_df["Corporate Cost"].values
    segments = [
        ("Net Salary", net_salary, "#8172B3"),
        ("Taxes", taxes, "#C44E52"),
        ("Employee Contribution", ee_contrib, "#55A868"),
        ("Employer Contribution", employer_contrib, "#4C72B0"),
    ]
    # Running left edge of the next segment for every bar.
    left = np.zeros(len(sorted_df))
    for label, values, color in segments:
        bars = plt.barh(countries_plot, values, left=left, color=color, label=label)
        for i, bar in enumerate(bars):
            val = values[i]
            if val > 0:
                pct = (val / total_cc[i]) * 100
                plt.text(
                    left[i] + val / 2,
                    bar.get_y() + bar.get_height() / 2,
                    f"{val:,.0f}\n({pct:.1f}%)",
                    ha="center",
                    va="center",
                    fontsize=10,
                    color="white",
                    fontweight="bold"
                )
        left += values
    plt.xlabel("Amount")
    plt.title("Salary Comparison (Stacked – % of Corporate Cost)")
    plt.legend(loc="lower right")
    plt.tight_layout()
    # First table row at the top of the chart.
    plt.gca().invert_yaxis()

    plt.show()


def calculate_all(b):
    """Compute and render the salary/tax comparison for every selected row.

    Click handler of the "Calculate Salary" button (``b`` is the clicked
    button and is not used). For each row in ``country_rows`` it derives
    gross salary and corporate cost from the chosen salary mode, applies
    employee contributions and the country PIT (plus the state/province PIT
    when a subregion is selected), subtracts groceries and half the rent
    from the monthly net, and projects savings and the pension fund over
    the chosen number of years. Rows with no country selected are skipped,
    as are rows without average-gross data in "Use Average Gross" mode.

    Everything is rendered into ``output``: a sortable results table, an
    accordion with the PIT tiers of each country/subregion, and a stacked
    bar chart. Changing the sort widgets re-renders the whole output.
    """
    title_html = "<h2 style='text-align:center; color:#2a3f5f;'>INCOME TAX COMPARISON</h2>"

    with output:
        clear_output()
        mode = gross_option.value
        entered_gross = gross_widget.value
        entered_cc = cc_widget.value
        years = int(years_widget.value)
        compound_pct = float(compound_widget.value) / 100.0
        data = []
        pit_accordion_children = []
        pit_titles = []

        def add_pit_panel(title, rates, brackets):
            # Adds one accordion panel, unless there is no tier to show.
            table = format_pit_table(rates, brackets)
            if table.empty:
                return
            panel = widgets.Output()
            with panel:
                display(HTML(f"<b>PIT Structure for {title}</b>"))
                display(table)
            pit_accordion_children.append(panel)
            pit_titles.append(title)

        display(HTML(title_html))

        for cw, sw, _ in country_rows:
            country = cw.value
            if not country:
                continue

            # A subregion counts only while its dropdown is visible and set.
            has_subregion = bool(
                country in ["United States", "Canada"]
                and sw.layout.display == "block"
                and sw.value
            )
            sp_name = sw.value if has_subregion else ""
            region_name = sp_name or country

            # Contribution rates are looked up at country level.
            pcr_country = country_dict.get("PCR2024", {}).get(country, {})
            emp_contrib = pcr_country.get("Employer_contribution", [])
            ee_contrib = pcr_country.get("Employee_contribution", [])
            total_emp_rate = sum(float(a) for a in emp_contrib if a not in ["", None]) if emp_contrib else 0.0

            # determine Gross and Corporate Cost
            if mode == "Use Average Gross":
                aaw_data = country_dict.get("AGA_COM_2024", {}).get(region_name, {})
                avg_gross_list = aaw_data.get("Average Gross", []) if aaw_data else []
                if not avg_gross_list:
                    continue
                G = first_numeric(avg_gross_list)
                CC = G * (1 + total_emp_rate)
            elif mode == "Enter Corporate Cost":
                CC = float(entered_cc)
                G = CC / (1 + total_emp_rate) if (1 + total_emp_rate) != 0 else CC
            else:
                G = float(entered_gross)
                CC = G * (1 + total_emp_rate)

            # Country PIT: "even_row" holds the rates, "odd_row" the brackets.
            pit_data_country = country_dict.get("PIT2024", {}).get(country, {})
            thresholds_country = pit_data_country.get("odd_row", [])
            coeffs_country = pit_data_country.get("even_row", [])

            EE = sum(G * float(a) for a in ee_contrib if a not in ["", None]) if ee_contrib else 0.0
            TI = G - EE
            N = calculate_net_salary(TI, thresholds_country, coeffs_country)

            # state/province PIT, applied on top of the country-level result
            thresholds_sp, coeffs_sp = [], []
            if sp_name:
                sp_data_pit = country_dict.get("PIT2024", {}).get(sp_name, {})
                thresholds_sp = sp_data_pit.get("odd_row", [])
                coeffs_sp = sp_data_pit.get("even_row", [])
                if thresholds_sp or coeffs_sp:
                    N = calculate_net_salary(N, thresholds_sp, coeffs_sp)

            # CL (rent, groceries). process_CL stores None for missing cells,
            # so None is treated like a missing region (0.0) instead of
            # failing in the arithmetic below.
            cl_region = country_dict.get("CL2024", {}).get(region_name, {})
            avg_groc = cl_region.get("Average Monthly Grocery") or 0.0
            avg_rent = (cl_region.get("Average Monthly Rent") or 0.0) / 2

            monthly_net = N / 12.0
            monthly_savings = monthly_net - avg_groc - avg_rent

            # Pension fund: region-level rates when available, else the
            # country's; only the first rate of each list is used.
            pcr_region = country_dict.get("PCR2024", {}).get(region_name, {})
            emp_list = pcr_region.get("Employer_contribution", []) if pcr_region else emp_contrib
            ee_list = pcr_region.get("Employee_contribution", []) if pcr_region else ee_contrib
            pension_per_year = G * (first_numeric(emp_list) + first_numeric(ee_list))

            pension_simple = pension_per_year * years
            pension_with_int = _pension_pot(pension_per_year, years, compound_pct)
            savings_over_years = monthly_savings * 12.0 * years

            name = country + (f", {sp_name}" if sp_name else "")
            data.append({
                "Country": name,
                "Corporate Cost": CC,
                "Gross Salary": G,
                "Taxable Income": TI,
                "Net Salary": N,
                "Monthly Net": monthly_net,
                "Average Half Rent": avg_rent,
                "Average Grocery": avg_groc,
                "Monthly Savings": monthly_savings,
                f"Savings ({years}y)": savings_over_years,
                f"PensionFund ({years}y)": pension_simple,
                f"PensionFundWithInt ({years}y)": pension_with_int
            })

            # PIT Accordion: rates first, brackets second, for the country
            # and for the subregion alike.
            add_pit_panel(country, coeffs_country, thresholds_country)
            if sp_name:
                add_pit_panel(sp_name, coeffs_sp, thresholds_sp)

        if not data:
            print("No valid data to display.")
            return

        df_main = pd.DataFrame(data)
        value_columns = [c for c in df_main.columns if c != "Country"]
        for col in value_columns:
            df_main[col] = pd.to_numeric(df_main[col], errors='coerce')

        # Sorting UI (rebuilt on every calculation; defaults to Net Salary)
        sort_dd = widgets.Dropdown(options=value_columns, value="Net Salary", description="Sort by:")
        asc_toggle = widgets.ToggleButtons(options=["Descending", "Ascending"], description="Order:")
        sort_controls.children = [sort_dd, asc_toggle]

        def update_sort(change=None):
            # Re-renders the whole output using the current sort settings.
            with output:
                clear_output(wait=True)
                display(HTML(title_html))
                display(sort_controls)

                sorted_df = df_main.sort_values(
                    by=sort_dd.value,
                    ascending=asc_toggle.value == "Ascending",
                    na_position='last',
                )

                display(HTML("<h3 style='text-align:center; color:#2a3f5f;'>Income & Net Salary</h3>"))
                display(_style_salary_table(sorted_df))

                if pit_accordion_children:
                    pit_acc = widgets.Accordion(children=pit_accordion_children)
                    for idx, title in enumerate(pit_titles):
                        pit_acc.set_title(idx, title)
                    display(pit_acc)

                # --- Stacked Bar Chart with numeric labels ---
                _plot_salary_breakdown(sorted_df)

        sort_dd.observe(update_sort, names="value")
        asc_toggle.observe(update_sort, names="value")
        update_sort()

# --- Display UI ---
# Hook up the calculation, then show the interface top to bottom:
# title, country rows, option widgets, input widgets, button, results area.
calc_button.on_click(calculate_all)
display(
    HTML("<h1 style='color:#2a3f5f;'>Salary / Tax Comparison Tool</h1>"),
    country_vbox,
    widgets.HBox([gross_option, years_widget, compound_widget]),
    widgets.HBox([gross_widget, cc_widget, add_button]),
    calc_button,
    output,
)
