In [2]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display
from IPython.display import display, HTML

# Female population table: read the CSV, then shorten the column labels by
# removing the fixed prefix and suffix text around each age range.
female = pd.read_csv('https://ourworldindata.org/grapher/female-population-by-age-group.csv')
female.columns = (
    female.columns
    .str.replace('Population - Sex: female - Age: ', '')
    .str.replace(' - Variant: estimates', '')
)

# Male population table: same treatment with the male-specific prefix.
male = pd.read_csv('https://ourworldindata.org/grapher/male-population-by-age-group.csv')
male.columns = (
    male.columns
    .str.replace('Population - Sex: male - Age: ', '')
    .str.replace(' - Variant: estimates', '')
)

# Choices offered by the two inputs (taken from the female table):
# years newest first, countries in ascending order.
years = female['Year'].unique().tolist()
years.sort(reverse=True)
countries = female['Entity'].unique().tolist()
countries.sort()

# Type-ahead inputs; ensure_option=True restricts the value to a listed option.
year_dropdown = widgets.Combobox(
    value=None,
    options=list(map(str, years)),
    ensure_option=True,
    description='Enter Year:',
)

country_dropdown = widgets.Combobox(
    value=None,
    options=countries,
    ensure_option=True,
    description='Enter Country:',
)

# Button that triggers the plot
plot_button = widgets.Button(description='Plot Pyramid')

# Separate output areas: one for the chart, one for the data table
plot_output, data_output = widgets.Output(), widgets.Output()

# Define button click handler (and the private helpers it uses)
def _age_breakdown(frame, country, year):
    """Return the rows of *frame* for one country and year in long form.

    The 'Entity', 'Code' and 'Year' columns are dropped and every remaining
    column becomes one row holding its label ('age_range') and its 'value'.
    An empty frame with those two columns is returned when nothing matches.
    """
    rows = frame[(frame['Entity'] == country) & (frame['Year'] == year)]
    if rows.empty:
        return pd.DataFrame(columns=['age_range', 'value'])
    rows = rows.drop(columns=['Entity', 'Code', 'Year'])
    return rows.melt(var_name='age_range', value_name='value')


def _show_pyramid_message(message):
    """Clear both output areas and print *message* in the plot area."""
    # Clear without wait=True so a stale chart/table disappears immediately.
    data_output.clear_output()
    plot_output.clear_output()
    with plot_output:
        print(message)


def plot_population_pyramid(b):
    """Draw the population pyramid for the selected year and country.

    Reads the current values of ``year_dropdown`` and ``country_dropdown``,
    shows the merged male/female table in ``data_output`` and the chart in
    ``plot_output``. If a selection is missing or invalid, or there is no
    data for it, a short message is shown instead of raising.

    Args:
        b: The argument supplied by ``plot_button.on_click``; not used.
    """
    # Local import keeps this block self-contained (matplotlib is already a
    # dependency of the file via pyplot).
    from matplotlib.ticker import FuncFormatter

    selected_year = year_dropdown.value
    country = country_dropdown.value

    # An untouched Combobox holds an empty string, which int() cannot parse.
    if not selected_year or not country:
        _show_pyramid_message('Please select both a year and a country.')
        return
    try:
        year = int(selected_year)
    except ValueError:
        _show_pyramid_message(f'"{selected_year}" is not a valid year.')
        return

    male_filtered = _age_breakdown(male, country, year)
    female_filtered = _age_breakdown(female, country, year)
    if male_filtered.empty or female_filtered.empty:
        _show_pyramid_message(f'No population data for {country} in {year}.')
        return

    # One row per age range with 'value_male' and 'value_female' columns
    population = pd.merge(male_filtered, female_filtered, on='age_range',
                          suffixes=('_male', '_female'))

    plot_output.clear_output(wait=True)
    data_output.clear_output(wait=True)

    # Display dataframe in data_output
    with data_output:
        display(population)

    # Create population pyramid in plot_output
    with plot_output:
        sns.set_style("whitegrid")
        fig, ax = plt.subplots(figsize=(12, 10))

        # Males are drawn with negated values so their bars extend to the
        # left of zero; females extend to the right.
        sides = (
            ('value_male', -1, 'Male', '#4A90E2', 'right'),
            ('value_female', 1, 'Female', '#E74C3C', 'left'),
        )
        for column, sign, label, color, align in sides:
            ax.barh(population['age_range'], sign * population[column], height=0.8,
                    label=label, color=color, alpha=0.8)
            # Write each value at the outer end of its bar
            for i, value in enumerate(population[column]):
                if pd.isna(value):
                    continue  # nothing to place for a missing value
                ax.text(sign * value, i, f'{value:,}', ha=align, va='center', fontsize=9)

        # Customize the plot
        ax.set_xlabel('Population', fontsize=12, fontweight='bold')
        ax.set_ylabel('Age Range', fontsize=12, fontweight='bold')
        ax.set_title(f'Population Pyramid - {country} ({year})', fontsize=16, fontweight='bold', pad=20)

        # Symmetric axis with some headroom for the value labels. Tick
        # positions are left to matplotlib so they scale with the data
        # instead of using a fixed step.
        max_val = max(population['value_male'].max(), population['value_female'].max())
        if pd.notna(max_val) and max_val > 0:
            limit = max_val * 1.2
            ax.set_xlim(-limit, limit)
        # Show absolute values on the x-axis (the left side is negative)
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f'{abs(x):,.0f}'))

        # Add vertical line at zero
        ax.axvline(0, color='black', linewidth=0.8)

        # Add legend
        ax.legend(loc='upper right', fontsize=11)

        # Add grid
        ax.grid(axis='x', alpha=0.3)

        # Tight layout
        plt.tight_layout()
        plt.show()

# Register the click callback on the button
plot_button.on_click(plot_population_pyramid)

# Show the input controls one after another
for control in (year_dropdown, country_dropdown, plot_button):
    display(control)
# Place the chart and the data table next to each other in one row
display(widgets.HBox(children=[plot_output, data_output]))


Combobox(value='', description='Enter Year:', ensure_option=True, options=('2023', '2022', '2021', '2020', '20…

Combobox(value='', description='Enter Country:', ensure_option=True, options=('Afghanistan', 'Africa (UN)', 'A…

Button(description='Plot Pyramid', style=ButtonStyle())

HBox(children=(Output(), Output()))