# In [2]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Load the OWID COVID-19 export and parse the date column into datetimes.
df = pd.read_csv("./owid-covid-data.csv")
df['date'] = pd.to_datetime(df['date'])

# `location` labels that denote continents, income groups, or other aggregates
# rather than a single country. They are dropped so they are not counted as
# countries in per-country statistics. 'Oceania' is listed alongside the other
# continents; a label that does not occur in the file is simply ignored by isin.
# NOTE(review): other aggregate labels may exist depending on the dataset
# version -- verify against df['location'].unique().
_AGGREGATE_LOCATIONS = [
    'World', 'Asia', 'Africa', 'Europe', 'Oceania',
    'North America', 'South America', 'European Union (27)',
    'High-income countries', 'Upper-middle-income countries',
    'Lower-middle-income countries', 'Low-income countries',
]
df = df[~df['location'].isin(_AGGREGATE_LOCATIONS)]

# Distinct remaining locations and the sorted calendar years present in the data.
countries = df['location'].unique()
years = sorted(df['date'].dt.year.unique())

def get_color_map(selected_countries):
    """Map each selected country to a color from Plotly's qualitative palette.

    Colors are assigned by position in ``selected_countries`` and wrap around
    to the start of the palette once it has been exhausted.

    Args:
        selected_countries: Iterable of country names.

    Returns:
        dict mapping country name to a color string.
    """
    colors = px.colors.qualitative.Plotly
    n_colors = len(colors)
    assigned = {}
    for position, name in enumerate(selected_countries):
        assigned[name] = colors[position % n_colors]
    return assigned

def update_line_plot(selected_countries, selected_metric, year_range):
    """Build a per-year line chart of a metric for the selected countries.

    Each selected country gets one line whose value for a year is the maximum
    of ``selected_metric`` over that year's rows. A dashed black line shows
    the mean of those same yearly maxima across every location in the
    module-level ``df``.

    Args:
        selected_countries: Country names (values of ``df['location']``) to
            draw as individual lines.
        selected_metric: Name of the ``df`` column to plot.
        year_range: Two-item sequence ``(lower, upper)`` of calendar years,
            both inclusive.

    Returns:
        The plotly figure containing the country lines and the average line.
        The legend is hidden and both axis titles are cleared.
    """
    lower, upper = year_range

    # Restrict to the requested years once; the country lines and the
    # all-countries average are both derived from this slice.
    row_years = df['date'].dt.year
    in_range = df[(row_years >= lower) & (row_years <= upper)].copy()
    in_range['year'] = in_range['date'].dt.year

    # One value per (location, year): the yearly maximum of the metric.
    yearly_max = (
        in_range.groupby(['location', 'year'])[selected_metric]
        .max()
        .reset_index()
    )

    metric_by_year = yearly_max[
        yearly_max['location'].isin(selected_countries)
    ].reset_index(drop=True)

    # The average is taken over the same yearly maxima as the country lines so
    # that it sits on the same scale; summing a year's daily rows instead would
    # not be comparable for cumulative columns such as total_deaths.
    avg_by_year = yearly_max.groupby('year')[selected_metric].mean().reset_index()

    color_map = get_color_map(selected_countries)

    fig = px.line(
        metric_by_year,
        x='year',
        y=selected_metric,
        color='location',
        markers=True,
        color_discrete_map=color_map,
        title='',
        labels={
            selected_metric: selected_metric.replace("_", " ").title(),
            'year': 'Year',
            'location': 'Country',
        },
        template="simple_white"
    )
    # Overlay the all-countries average as a dashed black line.
    fig.add_trace(
        go.Scatter(
            x=avg_by_year['year'],
            y=avg_by_year[selected_metric],
            mode='lines+markers',
            name='Average (All Countries)',
            line=dict(color='black', width=3, dash='dash')
        )
    )
    fig.update_layout(
        showlegend=False,
        xaxis_title=None,
        yaxis_title=None
    )
    return fig

# Notebook demo: plot yearly total deaths for Poland from 2020 through 2024
# and render the resulting figure.
year_range = [2020, 2024]
selected_metric = 'total_deaths'
selected_countries = ['Poland']
fig = update_line_plot(
    selected_countries=selected_countries,
    selected_metric=selected_metric,
    year_range=year_range,
)
fig.show()