# Import Libraries

In [None]:
from dash import Dash, html, dcc, callback, Output, Input, State
import dash_bootstrap_components as dbc
import plotly.express as px
import plotly.graph_objects as go
from dash.exceptions import PreventUpdate

import pandas as pd
import numpy as np

import json

# Initialize App 

In [None]:
# Bootstrap theme for the dbc components (e.g. the drill-in modal) plus the
# Font Awesome icon stylesheet.
stylesheets = [dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME]
app = Dash(__name__, external_stylesheets=stylesheets)

# Load Data

In [None]:
# Athlete results, and the NOC -> region lookup table that later cells merge
# on the "NOC" column.
athlete_events_df, noc_regions_df = (
    pd.read_csv("./assets/athlete_events.csv"),
    pd.read_csv("./assets/noc_regions.csv"),
)

# Total Medal Count by Height and Weight, grouped by Sport (Heatmap Figure)

In [None]:
# Medals won per (Sport, Height, Weight) combination. Rows without a height or
# weight cannot be placed on the grid, so they are dropped. `count()` counts the
# non-null "Medal" entries, so each medal-winning athlete row contributes one
# (members of a team are each counted).
height_weight_medal_df = athlete_events_df.dropna(subset=["Height", "Weight"])
height_weight_medal_df = (
    height_weight_medal_df
    .groupby(["Sport", "Height", "Weight"])["Medal"]
    .count()
    .reset_index(name="Medal Count")
)

# Sport name -> Height x Weight grid of medal counts (rows: Height, columns:
# Weight). Height/weight pairs that do not occur in a sport are filled with 0.
sports_df_list = {
    sport_name: pd.pivot_table(
        sport, index="Height", columns="Weight", values="Medal Count"
    ).fillna(0)
    for sport_name, sport in height_weight_medal_df.groupby("Sport")
}
sport_names = list(sports_df_list)

# Sport shown on first render. Fall back to the first sport if it is missing.
default_sport = "Alpine Skiing"
if default_sport not in sports_df_list:
    default_sport = sport_names[0]

# Exactly one trace per sport, in the same order as `sport_names`, so each
# dropdown button's visibility mask lines up index-for-index with the traces.
# (Previously an extra leading Alpine Skiing trace shifted every mask by one.)
heatmap_fig = go.Figure(
    data=[
        go.Heatmap(
            x=grid.columns,
            y=grid.index,
            z=grid.values,
            name=sport_name,
            visible=(sport_name == default_sport),
            # Per-trace colorbar: heatmap traces do not use layout.coloraxis
            # unless they reference it, so a title set there is never shown.
            colorbar=dict(title="Medals Won"),
            # x holds the Weight columns and y the Height index.
            hovertemplate=(
                "<b>%{x}kg-%{y}cm</b><br>"
                "Total Medals Won: %{z}<br>"
            ),
        )
        for sport_name, grid in sports_df_list.items()
    ]
)

# One dropdown button per sport: it shows only that sport's trace and updates
# the figure title.
sports_options = [
    dict(
        label=sport_name,
        method="update",
        args=[
            {"visible": [other == sport_name for other in sport_names]},
            {"title": f"Total Medals Won by Height and Weight in {sport_name}"},
        ],
    )
    for sport_name in sport_names
]

heatmap_fig.update_layout(
    title=f"Total Medals Won by Height and Weight in {default_sport}",
    width=800,
    height=600,
    autosize=False,
    xaxis=dict(title="Weight (kg)"),
    yaxis=dict(title="Height (cm)"),
    updatemenus=[dict(
        # Highlight the button for the sport that is initially visible.
        active=sport_names.index(default_sport),
        buttons=sports_options,
    )],
)

heatmap_fig.show()

# Interactive Map (Choropleth Figure)

In [None]:
# Count each region's medals. Rows without a medal are discarded, and team
# events are collapsed so a team's win counts once for its NOC rather than
# once per team member.
medals_country_df = (
    athlete_events_df
    .dropna(subset=["Medal"])
    .drop_duplicates(subset=["NOC", "Games", "Year", "Season", "City", "Sport", "Event", "Medal"])
    # NOC -> region lookup. NOCs with no matching region get a NaN region and
    # are left out of the groupby below.
    .merge(noc_regions_df, on="NOC", how="left")
    .groupby("region")
    .size()
    .reset_index(name="Total Medals")
)

# World map coloured by medal total. Regions are matched to map countries by
# name. The colour range is fixed to 0-500, so larger totals share the top colour.
choropleth_fig = px.choropleth(
    medals_country_df,
    locations="region",
    locationmode="country names",
    color="Total Medals",
    color_continuous_scale="Viridis",
    range_color=[0, 500],
    hover_name="region",
    hover_data={"region": False},  # region is already the hover title
    projection="natural earth",
    title="Total Medals Won by Country",
)

choropleth_fig.show()

# Country Data Drill-In

In [None]:
# Per-region medal breakdown for the drill-in dialog. This uses the same
# filtering as the choropleth: keep medal-winning rows, collapse team events so
# each team win counts once for its NOC, then map NOCs to regions (NOCs with no
# region are dropped by the groupby).
medals_distribution_df = (
    athlete_events_df
    .dropna(subset=["Medal"])
    .drop_duplicates(subset=["NOC", "Games", "Year", "Season", "City", "Sport", "Event", "Medal"])
    .merge(noc_regions_df, on="NOC", how="left")
    .groupby(["region", "Medal"])
    .size()
    .reset_index(name="Medal Count")
)

# Podium order for the bars (groupby order would be alphabetical:
# Bronze, Gold, Silver), plus a conventional colour for each medal.
medal_order = ["Gold", "Silver", "Bronze"]
medal_colors = {"Gold": "#D4AF37", "Silver": "#C0C0C0", "Bronze": "#CD7F32"}

# Region name -> that region's rows of medals_distribution_df.
noc_df_list = {region: noc for region, noc in medals_distribution_df.groupby("region")}

# Region name -> bar chart of its Gold/Silver/Bronze counts. Medal types a
# region never won are shown as 0 instead of being omitted.
medals_fig_list = {}
for key, noc in noc_df_list.items():
    counts = noc.set_index("Medal")["Medal Count"].reindex(medal_order, fill_value=0)
    medals_fig_list[key] = go.Figure(
        data=go.Bar(
            x=counts.index,
            y=counts.values,
            marker_color=[medal_colors[medal] for medal in counts.index],
        )
    )
    medals_fig_list[key].update_layout(
        xaxis_title="Medal",
        yaxis_title="Medals Won",
    )

# Dash Layout

In [None]:
# A click on a country both fires a click event and selects that country.
choropleth_fig.update_layout(clickmode="event+select")

# Drill-in dialog. It starts closed, with a placeholder title and the USA chart;
# the update_modal callback replaces both when a country is clicked.
drill_in_modal = dbc.Modal(
    [
        dbc.ModalHeader(dbc.ModalTitle(["Placeholder"], id="modal-text")),
        dcc.Graph(id="drill-in", figure=medals_fig_list["USA"]),
    ],
    id="modal-sm",
    size="sm",
    is_open=False,
)

app.layout = html.Div([
    dcc.Graph(id="choropleth", figure=choropleth_fig),
    dcc.Graph(id="heatmap", figure=heatmap_fig),
    drill_in_modal,
])

@callback(
    [
        Output("modal-text", "children"),
        Output("modal-sm", "is_open"),
        Output("drill-in", "figure"),
    ],
    [Input("choropleth", "clickData")],
)
def update_modal(clickData, is_open=None, figure=None):
    """Open the drill-in dialog for the country clicked on the choropleth.

    Args:
        clickData: ``clickData`` of the choropleth graph. It is ``None`` until
            the first click, and only the first clicked point is used.
        is_open: Unused. Kept (with a default) so direct callers passing three
            arguments still work. The dialog state is no longer read from the
            client.
        figure: Unused. Kept for the same reason. Dropping the State avoids
            sending the whole current figure to the server on every click.

    Returns:
        Tuple of (dialog title, ``True`` to open the dialog, the region's
        medal bar chart from ``medals_fig_list``).

    Raises:
        PreventUpdate: if there is no click, the click has no points, or the
            clicked region has no precomputed chart.
    """
    points = (clickData or {}).get("points") or []
    if not points:
        raise PreventUpdate

    # `location` is the choropleth's `locations` value, i.e. the region name,
    # which is also how medals_fig_list is keyed. Use it for both the title and
    # the lookup.
    region = points[0].get("location")
    region_fig = medals_fig_list.get(region)
    if region_fig is None:
        raise PreventUpdate

    # Always open rather than toggle: `not is_open` would close the dialog if a
    # click ever arrived while it was already open.
    # NOTE(review): Dash may not re-fire this callback when the same country
    # is clicked twice in a row (clickData unchanged) - confirm if that matters.
    return region, True, region_fig

# Run Dash App

In [None]:
# `app.run` replaces the deprecated `app.run_server`, which Dash 3 removes.
app.run(debug=True)