In [7]:
%%capture
import pandas as pd
import numpy as np
import plotly.express as px
%pip install wbdata
import wbdata
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go

---

## Helper Functions Section

#### Population Function Helpers

In [13]:
#helper functions below
def generateageranges(first_term, second_term, age_range, sex, countrylabel, yearstring):
    """
    Download population counts for every 5-year age bucket and return one year's rows.

    For each bucket code ("0004", "0509", ..., "7579", "80UP") an indicator id of the
    form ``f"{first_term}.{second_term}.{code}.{sex}"`` is built (presumably something
    like "SP.POP.0004.FE" — confirm against callers). All ids are fetched in a single
    ``wbdata.get_dataframe`` call.

    Parameters:
        first_term (str): First segment of the indicator id.
        second_term (str): Second segment of the indicator id.
        age_range (sequence): Accepted for interface compatibility but NOT used to
            narrow the request. All 17 buckets are always fetched, which matches the
            17 default age anchors used by ``interpprep``.
        sex (str): Final segment of the indicator id.
        countrylabel: Passed straight through as wbdata's ``country`` argument.
        yearstring (str): Label used with ``.loc`` on the date-parsed index
            (e.g. "2020").

    Returns:
        ``totaldf.loc[yearstring]``: the selected row(s), with one column per bucket
        code.
    """
    # 5-year buckets, youngest to oldest; the last one is open-ended (80 and up).
    # Built once rather than inside a per-age loop as before.
    bucket_codes = (
        "0004", "0509", "1014", "1519", "2024", "2529", "3034", "3539", "4044",
        "4549", "5054", "5559", "6064", "6569", "7074", "7579", "80UP",
    )
    # Map each full indicator id to a short column label (the bare bucket code).
    indicators = {f"{first_term}.{second_term}.{code}.{sex}": code for code in bucket_codes}
    totaldf = wbdata.get_dataframe(indicators, country=countrylabel, parse_dates=True)
    return totaldf.loc[yearstring]

def interpprep(bucketed_vals, interpolatetype = 'cubic', midpoint_selection = None):
    """
    Convert bucketed population totals into per-year values anchored at bucket ages.

    Every bucket except the last is treated as spanning 5 single years, so its total
    is divided by 5. The last (open-ended) bucket is divided by 20, i.e. it is treated
    as 20 single-year ages. This is a rough assumption for very old ages.

    Parameters:
        bucketed_vals (pd.Series | sequence of numbers): Bucket totals ordered
            youngest to oldest; the last entry is the open-ended top bucket.
            Values are accessed by position, so index labels are ignored.
        interpolatetype (str): Accepted for interface compatibility; not used here.
        midpoint_selection (list | None): Age anchor for each bucket. If None,
            a fresh list [2, 7, 12, ..., 77, 89] (17 anchors) is used.

    Returns:
        tuple: ``(midpoint_selection, midpoints)``, where ``midpoints`` is a
        pd.Series with a 0-based RangeIndex holding the per-year value for each
        bucket.
    """
    if midpoint_selection is None:
        # Build a new list on every call. It is handed back to the caller, and a
        # shared default list would carry caller-side mutations into later calls.
        midpoint_selection = [2, 7, 12, 17, 22, 27, 32, 37, 42, 47, 52, 57, 62, 67, 72, 77, 89]

    vals = pd.Series(bucketed_vals)
    # .iloc keeps access positional; plain s[-1] on a label-indexed Series uses a
    # deprecated positional fallback in pandas 2.x.
    five_year = vals.iloc[:-1] / 5
    # Not sure how much to divide the open-ended bucket by; 20 treats it as 20 ages.
    top_bucket = vals.iloc[-1] / 20
    midpoints = pd.concat([five_year, pd.Series([top_bucket])], ignore_index=True)
    return midpoint_selection, midpoints

def interpfunc(age_midpoints, pop_values, country, year, max_age=100, graph_values = False):
    """
    Linearly interpolate bucket values onto every single age from 0 to ``max_age``.

    Uses ``np.interp``. Ages below the smallest anchor take the first value, and ages
    above the largest anchor take the last value (flat extrapolation).

    Parameters:
        age_midpoints (sequence of numbers): Age anchor for each value, in any order.
        pop_values (sequence of numbers): Value at each anchor; same length as
            ``age_midpoints``.
        country (str): Used only in the optional chart title.
        year: Used only in the optional chart title.
        max_age (int): Highest age evaluated (inclusive). Default 100.
        graph_values (bool): If True, also show a Plotly line chart of the result.

    Returns:
        np.ndarray: ``max_age + 1`` interpolated values; position i is age i.

    Raises:
        ValueError: If ``age_midpoints`` and ``pop_values`` differ in shape.
    """
    bucketed_vals = np.array(pop_values)
    midpoint_selection = np.array(age_midpoints)
    if midpoint_selection.shape != bucketed_vals.shape:
        raise ValueError("age_midpoints and pop_values must have the same length")
    # np.interp does not check that its sample points are increasing and returns
    # nonsense if they are not, so sort the anchors (carrying their values) first.
    order = np.argsort(midpoint_selection, kind="stable")
    midpoint_selection = midpoint_selection[order]
    bucketed_vals = bucketed_vals[order]

    age_range = np.arange(0, max_age + 1)  # ages from 0 to max_age
    interpolated_values = np.interp(age_range, midpoint_selection, bucketed_vals)
    # optional graph
    if graph_values:
        popdf = pd.DataFrame({'Age': age_range, 'Population': interpolated_values})
        fig = px.line(popdf, x='Age', y='Population', title=f'Population Interpolation by Age for {country} in Year {year}')
        fig.show()
    return interpolated_values



#### Helper Functions for the In-Depth Visualization Function (`construct_dataframe`)

In [9]:
def generate_graphs(lifedf, variable_labels):
    """
    Draw one Plotly time-series chart per statistic, with one line per country.

    Parameters:
        lifedf (DataFrame): Statistics table. ``lifedf[variable]`` must yield a
            frame whose columns are countries (presumably two-level columns
            ``(variable, country)`` — confirm against the caller).
        variable_labels (dict): Maps column names in ``lifedf`` to chart labels.

    Statistics missing from ``lifedf`` are reported with a printed message and
    skipped. Each chart is displayed with ``fig.show()``; nothing is returned.
    """
    available = lifedf.columns
    for column_name, pretty_label in variable_labels.items():
        # Skip statistics the frame does not contain.
        if column_name not in available:
            print(f"Skipping {column_name}: Not found in DataFrame")
            continue

        stat_frame = lifedf[column_name]
        figure = go.Figure()
        for country_name in stat_frame.columns:
            # Drop missing years so each line only connects observed points.
            series = stat_frame[country_name].dropna()
            figure.add_trace(
                go.Scatter(x=series.index, y=series.values, mode='lines+markers', name=country_name)
            )

        figure.update_layout(
            title=f"{pretty_label} Over Time",
            xaxis_title="Year",
            yaxis_title="Value",
            legend_title="Country",
        )
        figure.show()


In [10]:
def overlay_population(country_name, df):
    """
    Plot the 'Total Female' and 'Total Male' series for one country on a shared chart.

    Parameters:
    - country_name (str): Value looked up in the column level named 'country'.
    - df (pd.DataFrame): Population data with MultiIndex columns. The columns must
      have a level named 'country' and be addressable as ('Total Female', country)
      and ('Total Male', country).

    Returns:
    - None. The figure is displayed with ``fig.show()`` rather than returned.

    If the country, or either of its two series, is missing, a message is printed
    and nothing is plotted.
    """
    # Check if the country exists in the subcolumns
    if country_name not in df.columns.get_level_values('country'):
        print(f"Country '{country_name}' not found in the dataset.")
        return

    female_key = ('Total Female', country_name)
    male_key = ('Total Male', country_name)
    # The country can exist under other series without having both of these;
    # report that instead of failing with a KeyError on the lookup below.
    missing = [key[0] for key in (female_key, male_key) if key not in df.columns]
    if missing:
        print(f"Missing {', '.join(missing)} data for country '{country_name}'.")
        return

    female = df[female_key]
    male = df[male_key]

    fig = go.Figure()
    # Female dashed blue, male solid red, so the overlapping lines stay distinguishable.
    fig.add_trace(go.Scatter(x=female.index, y=female, mode='lines+markers', name='Female', line=dict(dash='dash', color='blue')))
    fig.add_trace(go.Scatter(x=male.index, y=male, mode='lines+markers', name='Male', line=dict(dash='solid', color='red')))

    fig.update_layout(title=f"Total Female and Male Population Over Time ({country_name})",
                      xaxis_title="Year",
                      yaxis_title="Population")
    fig.show()

#### Search Function to Find a Region Acronym

This optional helper makes region lookup easier. Its input is a country name or any part of one (matched case-insensitively). It prints every matching country name together with its acronym (the World Bank ID).

In [14]:
def acronymfinder(country):
    """
    Print every country whose name contains ``country`` (case-insensitive), with its ID.

    Country names and IDs come from ``wbdata.get_countries()``. Each match is printed
    as ``"<name>: <id>"``.

    Parameters:
        country (str): Substring to look for in country names.

    Returns:
        None if at least one country matched (results are printed).
        str: 'no matching countries, please try a different input' if nothing matched.

    Example:
        acronymfinder("United")
        # United Arab Emirates: ARE
        # United Kingdom: GBR
        # United States: USA

        acronymfinder("XYZ")
        # -> 'no matching countries, please try a different input'
    """
    needle = country.lower()
    # Name -> ID lookup; iteration order follows wbdata's listing.
    name_to_id = {entry['name']: entry['id'] for entry in wbdata.get_countries()}
    hits = [name for name in name_to_id if needle in name.lower()]
    for name in hits:
        print(f"{name}: {name_to_id[name]}")
    if not hits:
        return 'no matching countries, please try a different input'


In [12]:
# Example: list every country whose name contains "united" (case-insensitive).
acronymfinder(country='united')

United Arab Emirates: ARE
United Kingdom: GBR
United States: USA
