In [51]:
import logging
# Silence the many cache-related warnings logged by wbdata's persistent cache (shelved_cache); only ERROR and above are shown
logging.getLogger("shelved_cache.persistent_cache").setLevel(logging.ERROR)
import wbdata
import pandas as pd
import datetime
import os
import numpy as np

In [52]:
class HDIData:
    """
    Wrapper around a wide HDR composite-indices CSV (one column per
    ``<indicator>_<year>`` pair) that returns tidy long-format data.
    """

    def __init__(self, filepath: str):
        """
        Initialize the HDIData object by loading a CSV file into a DataFrame.

        Args:
            filepath (str): Path to the CSV file. It is read with ISO-8859-1
                encoding, and column names are normalized to stripped,
                lower-case, underscore-separated form.
        """
        # ISO-8859-1 (Latin-1) maps every byte to a character, so decoding never fails
        self.df = pd.read_csv(filepath, encoding="ISO-8859-1")

        # Standardize column names, e.g. " Gii 2000" -> "gii_2000"
        self.df.columns = (
            self.df.columns
            .str.strip()
            .str.lower()
            .str.replace(' ', '_')
        )

    def _reshape_long(self, df, indicators: dict):
        """
        Convert wide-format ``<indicator>_<year>`` columns into long format.

        Only columns named exactly ``<key>_<digits>`` for a key of
        ``indicators`` are used. Columns that merely share the prefix (for
        example ``gii_rank_2022`` when asking for ``gii``) are ignored, so
        every returned row has a readable ``metric_name``.

        Args:
            df (pd.DataFrame): Wide frame containing ``iso3``, ``country`` and
                ``region`` columns plus indicator/year columns.
            indicators (dict): Mapping of indicator prefix -> readable name.

        Returns:
            pd.DataFrame: Columns ``iso3, country, region, value, metric, year,
            metric_name``; one row per (country row, indicator, year). Empty
            (with those columns) when no column matches.
        """
        id_vars = ['iso3', 'country', 'region']

        # Split each candidate column name once, keeping only exact
        # "<indicator>_<year>" matches. Precomputing the split avoids string
        # accessors on the melted frame, which fail when nothing matched.
        metric_of = {}
        year_of = {}
        for col in df.columns:
            if not isinstance(col, str):
                continue
            metric, sep, year = col.rpartition('_')
            if sep and metric in indicators and year.isdecimal():
                metric_of[col] = metric
                year_of[col] = int(year)
        value_vars = list(metric_of)

        # Melt to long format
        long_df = df.melt(
            id_vars=id_vars,
            value_vars=value_vars,
            var_name='metric_year',
            value_name='value'
        )

        long_df['metric'] = long_df['metric_year'].map(metric_of)
        long_df['year'] = long_df['metric_year'].map(year_of).astype(int)
        long_df = long_df.drop(columns='metric_year')

        # Add readable metric names
        long_df['metric_name'] = long_df['metric'].map(indicators)

        return long_df

    def get_data(self, indicators: dict, countries=None, start_year=None, end_year=None):
        """
        Retrieve filtered data based on an indicator dictionary, countries, and year range.

        Args:
            indicators (dict): Dictionary of indicator IDs and readable names.
            countries (list or str): Country or list of countries to filter by
                (case-insensitive match on the ``country`` column).
            start_year (int): Inclusive start year for filtering.
            end_year (int): Inclusive end year for filtering.

        Returns:
            pd.DataFrame: Long-format rows (see ``_reshape_long``) with a fresh
            0..n-1 index.
        """
        # Convert to long format using the indicators provided
        long_df = self._reshape_long(self.df, indicators)

        # Filter by countries
        if countries is not None:
            if isinstance(countries, str):
                countries = [countries]
            wanted = [c.lower() for c in countries]
            long_df = long_df[long_df['country'].str.lower().isin(wanted)]

        # Filter by year range (both bounds inclusive)
        if start_year is not None:
            long_df = long_df[long_df['year'] >= start_year]
        if end_year is not None:
            long_df = long_df[long_df['year'] <= end_year]

        return long_df.reset_index(drop=True)


In [53]:
def load_indicators(indicators: dict, countries=None, years=None, folder_path='P5_Indicator'):
    """
    Load environmental/social indicators from CSV files and filter by country/year.

    For every key ``VAR`` in ``indicators`` the file
    ``<folder_path>/VAR_ind_na.csv`` is read. Its ``var.ind.<year>`` columns
    (matched after lower-casing the column names) are melted into one row per
    country and year.

    Parameters
    ----------
    indicators : dict
        Dictionary mapping variable abbreviations to readable names, e.g.:
        {"BCA": "Biodiversity Conservation Area", "BER": "Biodiversity Expenditure Ratio"}.
    countries : str | list[str] | None
        Country or list of countries to keep (case-insensitive match on ``country``).
    years : int | list[int] | set[int] | range | tuple(int, int) | None
        Single year, collection of years, or inclusive range ``(start, end)``.
        A tuple is always interpreted as a range, never as a collection.
    folder_path : str
        Path to the folder containing CSV files.

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame with columns
        ['iso', 'country', 'value', 'year', 'variable', 'variable_name'].
        An empty frame with these columns is returned when ``indicators`` is empty.

    Raises
    ------
    FileNotFoundError
        If the CSV file for an indicator does not exist.
    """
    columns = ['iso', 'country', 'value', 'year', 'variable', 'variable_name']

    # pd.concat([]) raises ValueError, so handle "no indicators" explicitly.
    if not indicators:
        return pd.DataFrame(columns=columns)

    all_dfs = []

    for var, var_name in indicators.items():
        filename = os.path.join(folder_path, f"{var}_ind_na.csv")
        if not os.path.exists(filename):
            raise FileNotFoundError(f"File {filename} not found.")

        df = pd.read_csv(filename)

        # Standardize columns
        df.columns = df.columns.str.strip().str.lower().str.replace(' ', '_')

        # Identify year columns for this variable (e.g., bca.ind.1990)
        prefix = var.lower() + '.ind.'
        year_cols = [col for col in df.columns if col.startswith(prefix)]

        # Melt wide -> long
        long_df = df.melt(
            id_vars=['iso', 'country'],
            value_vars=year_cols,
            var_name='metric_year',
            value_name='value'
        )

        # The year is the last dot-separated part of the column name
        long_df['year'] = long_df['metric_year'].str.split('.').str[-1].astype(int)
        long_df['variable'] = var
        long_df['variable_name'] = var_name
        long_df = long_df.drop(columns='metric_year')

        all_dfs.append(long_df)

    # Combine multiple indicators
    result = pd.concat(all_dfs, ignore_index=True)

    # Filter by countries
    if countries is not None:
        if isinstance(countries, str):
            countries = [countries]
        result = result[result['country'].str.lower().isin([c.lower() for c in countries])]

    # Filter by years
    if years is not None:
        if isinstance(years, tuple):  # inclusive (start, end) range
            result = result[(result['year'] >= years[0]) & (result['year'] <= years[1])]
        elif isinstance(years, (list, set, frozenset, range)):  # explicit collection
            result = result[result['year'].isin(years)]
        else:  # single year
            result = result[result['year'] == years]

    return result.reset_index(drop=True)

In [54]:
def fetch_wbdata(indicators, countries="all", start_year=None, end_year=None):
    """
    Fetch World Bank data.

    Parameters:
        indicators (dict): Mapping from indicator code to descriptive name,
                           e.g. {'NY.GDP.MKTP.CD': 'GDP', 'SP.POP.TOTL': 'Population'}
        countries (list or str): ISO2 country codes like ['US', 'CN'], or 'all'
        start_year (int): Inclusive start year (optional)
        end_year (int): Inclusive end year (optional)
            When both years are given the range is sent to the API. When only
            one is given, the unrestricted series is fetched and then trimmed
            to that bound.

    Returns:
        pd.DataFrame: DataFrame with columns ['Country', 'Year', ...indicators...]
    """
    # Import the module locally. The notebook later assigns a DataFrame to the
    # global name `wbdata`, which would otherwise make every subsequent call
    # fail with AttributeError.
    import wbdata as wb_api

    # Handle date range (the API call needs both ends)
    date_range = None
    if start_year is not None and end_year is not None:
        date_range = (
            datetime.datetime(start_year, 1, 1),
            datetime.datetime(end_year, 12, 31),
        )

    # Fetch from World Bank
    df = wb_api.get_dataframe(
        indicators,
        country=countries,
        date=date_range,
        freq='Y',
        parse_dates=True
    )

    # Reset and clean DataFrame
    df = df.reset_index().rename(columns={"country": "Country", "date": "Year"})

    # Convert Year to integer if parsed as datetime
    if pd.api.types.is_datetime64_any_dtype(df["Year"]):
        df["Year"] = df["Year"].dt.year

    # Apply any requested bound locally so a lone start_year/end_year is honored
    if start_year is not None or end_year is not None:
        year_num = pd.to_numeric(df["Year"], errors="coerce")
        keep = pd.Series(True, index=df.index)
        if start_year is not None:
            keep &= year_num >= start_year
        if end_year is not None:
            keep &= year_num <= end_year
        df = df[keep].reset_index(drop=True)

    # Reorder columns: Country, Year, then indicators
    indicator_columns = list(indicators.values())
    cols = ["Country", "Year"] + indicator_columns
    df = df[[c for c in cols if c in df.columns]]

    return df

In [55]:
# World Bank indicator codes -> readable column names.
indicators_wb = {
    "2.1_SHARE.TOTAL.RE.IN.TFEC": "renewable energy share",
    "SH.H2O.BASW.ZS": "basic drinking water",
}

# Yale indicator abbreviations (read from <ABBR>_ind_na.csv) -> readable names.
# NOTE(review): "Polution" is misspelled, but these labels become column names
# in the merged data, so the text is kept verbatim here.
indicators_yale = {
    "SPI": "Species protection index",
    "SDA": "Air Polution SO2 trend",
    "HPE": "Ambient PM2.5 from human resources",
    "MHP": "Marine Habitat Protection",
}

# HDR column prefixes (as in "<prefix>_<year>") -> readable names.
indicators_hdi = {
    "gii": "GII",
    "pr_f": "Shares of seats in the parliament",
    "lfpr_f": "Labour force participation",
    "se_f": "population with secondary education",
}


In [56]:
# NOTE(review): this rebinds `wbdata`, which the first cell imported as the
# wbdata module. Afterwards, code that reaches the module through the global
# name (e.g. `wbdata.get_dataframe`) sees this DataFrame instead. The merge
# cell reads this variable too, so rename both together if fixing it.
wbdata = fetch_wbdata(
    indicators_wb,
    countries="all",
    start_year=2000,
    end_year=2001,
)
print(wbdata)

                         Country  Year  renewable energy share  \
0                    Afghanistan  2000               54.243126   
1                    Afghanistan  2001               54.055055   
2    Africa Eastern and Southern  2000                     NaN   
3    Africa Eastern and Southern  2001                     NaN   
4     Africa Western and Central  2000                     NaN   
..                           ...   ...                     ...   
613                  Yemen, Rep.  2001                1.070389   
614                       Zambia  2000               89.990741   
615                       Zambia  2001               89.809403   
616                     Zimbabwe  2000               69.258777   
617                     Zimbabwe  2001               71.530273   

     basic drinking water  
0               27.441856  
1               27.473580  
2               41.801456  
3               42.481960  
4               50.673598  
..                    ...  
613        

In [57]:
# Load the HDR composite-indices export and keep the selected indicators
# for 2000-2001 across all countries (countries defaults to None).
hdi = HDIData("HDR25_Composite_indices_complete_time_series.csv")
hdi_data = hdi.get_data(indicators_hdi, start_year=2000, end_year=2001)
print(hdi_data)

           iso3                          country region      value  metric  \
0           AFG                      Afghanistan     SA        NaN     gii   
1           ALB                          Albania    ECA   0.291000     gii   
2           DZA                          Algeria     AS   0.610000     gii   
3           AND                          Andorra    NaN        NaN     gii   
4           AGO                           Angola    SSA        NaN     gii   
...         ...                              ...    ...        ...     ...   
1643    ZZG.ECA          Europe and Central Asia    NaN  45.216158  lfpr_f   
1644    ZZH.LAC  Latin America and the Caribbean    NaN  49.451366  lfpr_f   
1645     ZZI.SA                       South Asia    NaN  36.036085  lfpr_f   
1646    ZZJ.SSA               Sub-Saharan Africa    NaN  61.999893  lfpr_f   
1647  ZZK.WORLD                            World    NaN  51.505877  lfpr_f   

      year                 metric_name  
0     2000            

In [58]:
# Yale indicators for every country over the inclusive range 2000-2001
# (countries defaults to None).
epi_data = load_indicators(indicators_yale, years=(2000, 2001))
print(epi_data)

      iso                    country  value  year variable  \
0     AFG                Afghanistan    0.3  2000      SPI   
1     ALB                    Albania   14.0  2000      SPI   
2     DZA                    Algeria   34.1  2000      SPI   
3     AND                    Andorra    NaN  2000      SPI   
4     AGO                     Angola   25.2  2000      SPI   
...   ...                        ...    ...   ...      ...   
1755  WLF  Wallis and Futuna Islands    NaN  2001      MHP   
1756  ESH             Western Sahara    0.0  2001      MHP   
1757  YEM                      Yemen    0.1  2001      MHP   
1758  ZMB                     Zambia    NaN  2001      MHP   
1759  ZWE                   Zimbabwe    NaN  2001      MHP   

                  variable_name  
0      Species protection index  
1      Species protection index  
2      Species protection index  
3      Species protection index  
4      Species protection index  
...                         ...  
1755  Marine Habi

In [59]:
# Align the World Bank key columns with the HDI/EPI frames (mutates `wbdata` in place).
wbdata.rename(columns={'Country': 'country', 'Year': 'year'}, inplace=True)

# Long -> wide: one row per (country, year), one column per readable indicator.
# pivot_table averages duplicate entries (its default aggfunc is 'mean').
keys = ['country', 'year']
hdi_data_pivot = hdi_data.pivot_table(index=keys, columns='metric_name', values='value').reset_index()
epi_data_pivot = epi_data.pivot_table(index=keys, columns='variable_name', values='value').reset_index()

# NOTE(review): sources are joined on country *names*, which differ between
# them (e.g. "Yemen, Rep." in the World Bank output vs "Yemen" in the EPI
# output). The same country can therefore be split across rows; joining on
# ISO3 codes would be more robust.
merged = wbdata.merge(hdi_data_pivot, on=keys, how='outer')
merged = merged.merge(epi_data_pivot, on=keys, how='outer')

merged = merged.rename(columns={
    'renewable energy share': 'Renewable_Energy_Share',
    'basic drinking water': 'Basic_Drinking_Water',
})

for summary in (merged.head(), merged.columns, merged.shape):
    print(summary)
print(np.array(merged["country"]))

                       country  year  Renewable_Energy_Share  \
0                  Afghanistan  2000               54.243126   
1                  Afghanistan  2001               54.055055   
2  Africa Eastern and Southern  2000                     NaN   
3  Africa Eastern and Southern  2001                     NaN   
4   Africa Western and Central  2000                     NaN   

   Basic_Drinking_Water  GII  Labour force participation  \
0             27.441856  NaN                         NaN   
1             27.473580  NaN                         NaN   
2             41.801456  NaN                         NaN   
3             42.481960  NaN                         NaN   
4             50.673598  NaN                         NaN   

   Shares of seats in the parliament  population with secondary education  \
0                                NaN                             1.761277   
1                                NaN                             1.901709   
2                      