Data source: Our World in Data — https://ourworldindata.org/grapher/children-per-woman-fertility-rate-vs-level-of-prosperity

In [None]:
import math
from pathlib import Path

import pandas as pd

In [None]:
# Read the downloaded chart data and keep only the rows where both plotted
# measures are present. `dropna` returns a new frame, so no explicit copy is
# needed before filtering.
raw_csv_path = (
    "../data/raw/children-per-woman-fertility-rate-vs-level-of-prosperity.csv"
)
required_columns = ["GDP per capita", "Births per woman"]
df = pd.read_csv(raw_csv_path).dropna(subset=required_columns)

In [None]:
# Extremes of both measures over every row of the cleaned frame (all years);
# the plotting cells below use them to pin the axis ranges.
gdp_series = df["GDP per capita"]
births_series = df["Births per woman"]

gdp_min, gdp_max = gdp_series.min(), gdp_series.max()
births_min, births_max = births_series.min(), births_series.max()

# Display the four values as the cell output.
(gdp_max, gdp_min, births_max, births_min)

In [None]:
import plotly.express as px

# Snapshot of the most recent year that is present in the cleaned data.
latest_year = df["Year"].max()
df_latest = df.loc[df["Year"].eq(latest_year)]

# Axis ranges are padded by 10% around the extremes of the full dataset
# (not just the latest year); the y-axis always starts at zero.
x_range = [gdp_min * 0.9, gdp_max * 1.1]
y_range = [0, births_max * 1.1]
chart_title = f"GDP per Capita vs Births per Woman ({latest_year})"

# One bubble per entity: area encodes population, colour encodes region.
fig = px.scatter(
    data_frame=df_latest,
    x="GDP per capita",
    y="Births per woman",
    size="Population",
    size_max=60,
    color="World region according to OWID",
    hover_name="Entity",
    log_x=True,
    range_x=x_range,
    range_y=y_range,
    title=chart_title,
)
fig.update_layout(height=650)
fig.show()

In [11]:
import plotly.graph_objects as go

# Countries whose historical paths are overlaid on the bubble chart.
countries = ["France", "Israel", "Saudi Arabia", "South Korea"]


def _axis_style(label):
    """Return the layout settings shared by both axes, titled with `label`."""
    return {
        "title": {"text": label, "font": {"size": 14}},
        "gridcolor": "rgba(200, 200, 200, 0.3)",
        "showline": True,
        "linewidth": 1,
        "linecolor": "rgba(150, 150, 150, 0.5)",
    }


# Background layer: a bubble chart rebuilt directly from `df_latest` (rather
# than reusing `fig`), with the same encodings and axis ranges but no title,
# since the title is set in the layout below.
fig_country = px.scatter(
    data_frame=df_latest,
    x="GDP per capita",
    y="Births per woman",
    size="Population",
    size_max=60,
    color="World region according to OWID",
    hover_name="Entity",
    log_x=True,
    range_x=[gdp_min * 0.9, gdp_max * 1.1],
    range_y=[0, births_max * 1.1],
)

# Presentation: centred title, light grid on a white canvas, and a boxed
# legend placed just outside the right edge of the plot area (the wide right
# margin leaves room for it).
fig_country.update_layout(
    title={
        "text": "GDP per Capita vs Births per Woman — Country Trajectories",
        "font": {"size": 22, "family": "Inter, Arial, sans-serif"},
        "x": 0.5,
        "xanchor": "center",
    },
    xaxis=_axis_style("GDP per capita (USD, log scale)"),
    yaxis=_axis_style("Births per woman"),
    width=1200,
    height=700,
    plot_bgcolor="white",
    paper_bgcolor="white",
    legend={
        "title": {"text": "Region", "font": {"size": 13}},
        "font": {"size": 11},
        "bgcolor": "rgba(255, 255, 255, 0.85)",
        "bordercolor": "rgba(200, 200, 200, 0.5)",
        "borderwidth": 1,
        "xanchor": "left",
        "x": 1.02,
        "yanchor": "top",
        "y": 0.98,
        "itemsizing": "constant",
    },
    margin={"l": 60, "r": 180, "t": 80, "b": 60},
)

# Fade the bubbles. This runs before any trajectory trace is added, so only
# the background layer is affected.
fig_country.update_traces(marker_opacity=0.25)

# Overlay one line+marker trajectory per country on `fig_country` and label
# each trajectory's first and last year.

# One colour per entry in `countries`, matched by position.
trajectory_colors = ["#e63946", "#2a9d8f", "#e9c46a", "#f4a261"]

# zip() stops at the shorter sequence, so a missing colour would silently drop
# a country from the chart; fail loudly instead.
assert len(trajectory_colors) >= len(countries), (
    "each country in `countries` needs its own trajectory colour"
)

for country, color in zip(countries, trajectory_colors):
    # Sort chronologically so the line is drawn oldest -> newest and the
    # first/last rows really are the start and end of the trajectory.
    country_df = (
        df[df["Entity"] == country]
        .dropna(subset=["GDP per capita", "Births per woman"])
        .sort_values("Year")
    )

    # A misspelled or missing entity would otherwise add an empty trace and
    # then raise IndexError on `.iloc[0]` below.
    if country_df.empty:
        print(f"No GDP/fertility rows found for {country!r}; skipping its trajectory.")
        continue

    fig_country.add_trace(
        go.Scatter(
            x=country_df["GDP per capita"],
            y=country_df["Births per woman"],
            mode="lines+markers",
            name=country,
            legendgroup="trajectories",
            legendgrouptitle=dict(text="Country Trajectories", font=dict(size=13)),
            marker=dict(size=5, color=color),
            line=dict(width=2, color=color),
            # Year and population ride along per point for the hover text.
            customdata=country_df[["Year", "Population"]].values,
            hovertemplate=(
                f"<b>{country}</b><br>"
                "Year: %{customdata[0]}<br>"
                "GDP: %{x:,.0f}<br>"
                "Births: %{y:.2f}<br>"
                "Population: %{customdata[1]:,.0f}<br>"
                "<extra></extra>"
            ),
        )
    )

    # Start and end year labels. With a single data point the start and end
    # coincide, so only one label is drawn.
    endpoints = [country_df.iloc[0]]
    if len(country_df) > 1:
        endpoints.append(country_df.iloc[-1])

    for row in endpoints:
        gdp = row["GDP per capita"]
        if gdp <= 0:
            # Non-positive values have no position on a log axis.
            continue
        fig_country.add_annotation(
            # The x-axis is logarithmic (the base figure uses log_x=True), and
            # plotly positions annotations on log axes in log10 units, so the
            # raw GDP value must be converted or the label lands off-chart.
            x=math.log10(gdp),
            y=row["Births per woman"],
            text=str(int(row["Year"])),
            showarrow=False,
            font=dict(size=9, color=color),
            xshift=12,
            yshift=8,
            xref="x",
            yref="y",
        )

# Save a 2x-scaled PNG of the trajectory chart, then render it inline.
# The target folder is created first so the export does not fail on a fresh
# checkout where `plots/` does not exist yet.
output_path = Path("plots/gdp-demographic.png")
output_path.parent.mkdir(parents=True, exist_ok=True)

fig_country.write_image(file=str(output_path), format="png", scale=2)
fig_country.show()