In [3]:
from dash import Dash, dcc, html
import plotly.graph_objects as go
import pandas as pd
from dash.dependencies import Input, Output
from plotly.subplots import make_subplots
import networkx as nx
import plotly.graph_objects as go
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import plotly.express as px


def _min_max_normalize(series):
    """Rescale a numeric Series linearly so its minimum is 0 and its maximum is 1.

    A column whose values are all equal has zero range; it is mapped to 0.0
    instead of the NaN that a 0/0 division would produce.
    """
    low = series.min()
    span = series.max() - low
    if span == 0:
        return pd.Series(0.0, index=series.index)
    return (series - low) / span


# Gravity-model table; the callbacks read 'Year', 'Reporter Namee',
# 'Partner Namee' and the trade/GDP columns from it.
mainviz = pd.read_csv('pt2020.csv', delimiter=';')

# Normalize data
mainviz['N_Distance (Kms)'] = _min_max_normalize(mainviz['Distance (Kms)'])
mainviz['N_Gravity Index'] = _min_max_normalize(mainviz['Gravity Index'])

# Get the list of data from the 'Partner Namee' column
partner_names = mainviz['Partner Namee'].tolist()

# Get the country list (dict.fromkeys de-duplicates while keeping a stable,
# first-seen order; a set would give an arbitrary order on every run)
unique_partner_names = list(dict.fromkeys(partner_names))

# Read the per-product table once and clean it once, so both derived frames
# see the same stripped column names and stripped 'Product Group' values.
_detailed_all = pd.read_csv('detailed.csv', delimiter=';')
_detailed_all.columns = _detailed_all.columns.str.strip()
_detailed_all['Product Group'] = _detailed_all['Product Group'].str.strip()
_is_all_products = _detailed_all['Product Group'] == 'All Products'

# Per-product rows (the 'All Products' total is dropped)
detailed = _detailed_all.loc[~_is_all_products].copy()
# Only the 'All Products' total rows
detailed2 = _detailed_all.loc[_is_all_products].copy()

df = pd.read_csv('tradeflows.csv', delimiter=';')
growthdf = pd.read_csv('growth.csv', delimiter=';')
growthdf['date'] = pd.to_datetime(growthdf['date'], format='%d/%m/%y')

# Sort the DataFrame by 'date'
growthdf = growthdf.sort_values('date')

In [4]:
# Once this cell is running, open the app at http://127.0.0.1:8050/ (Dash's default host and port).

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import pandas as pd
import plotly.graph_objects as go


# Constants
# Year range offered by the slider; both ends are inclusive.
MIN_YEAR, MAX_YEAR = 2016, 2020
# Distance between two selectable slider positions.
STEP = 1
# Year selected when the page first loads.
INITIAL_VALUE = 2016
# Slider tick labels: every year in the range, keyed by its integer value.
SLIDER_MARKS = dict(
    (year, str(year)) for year in range(MIN_YEAR, MAX_YEAR + 1)
)

# suppress_callback_exceptions=True: the callbacks below target component ids
# (e.g. 'advanced-viz', 'sankey-diagram') that are not part of the initial
# app.layout; they only exist inside index_page / page_1_layout, which are
# swapped into 'page-content' by display_page.
app = dash.Dash(__name__, suppress_callback_exceptions=True)

# Reporter countries offered by the dropdown (computed once, first-seen order).
_REPORTER_NAMES = mainviz['Reporter Namee'].unique()

# Appearance shared by the two navigation buttons; each button only adds its
# own margin and width on top of this.
_NAV_BUTTON_STYLE = {
    'height': '37px',
    'font-size': '16px',
    'border-radius': '10px',
    'background-color': 'white',
    'box-shadow': '2px 2px 4px rgba(0, 0, 0, 0.1)',
    'border': '1px solid #CCCCCC',
    'cursor': 'pointer',
}

# Page shell: title, reporter dropdown + navigation buttons, the swappable
# 'page-content' container, the year slider and the footer.
app.layout = html.Div([
    html.Div([
        html.H1('Visualizing Global Commerce through Gravitational Principles', style={'textAlign': 'center', 'fontSize': '40px', 'fontFamily': 'Arial', 'color':'#373659'})
    ]),
    # Holds the current pathname; display_page reads it to pick the page.
    dcc.Location(id='url', refresh=False),

    html.Div([
        html.Div([
            dcc.Dropdown(
                id='dropdown-input',
                options=[
                    {'label': name, 'value': name} for name in _REPORTER_NAMES
                ],
                value=_REPORTER_NAMES[0],
                placeholder='Select a Country',
                # The original dict listed 'box-shadow' twice; the first value
                # ('px 2px 4px ...') was malformed and silently overwritten by
                # the second, so only the effective value is kept.
                style={
                    'width': '155px',
                    'height': '37px',
                    'border-radius': '10px',
                    'background-color': 'white',
                    'box-shadow': '2px 2px 4px rgba(0, 0, 0, 0.05)',
                    'cursor': 'pointer',
                    'margin-right': '20px',
                    'text-align': 'center',
                    'font-family': 'sans-serif',
                    'padding': '0px 0px',
                }
            ),
        ]),
        html.Div([
            html.Button('Gravity Model', id='index-button',
                        style={**_NAV_BUTTON_STYLE, 'margin-right': '10px', 'width': '150px'}),
            html.Button('Country Summary', id='page-1-button',
                        style={**_NAV_BUTTON_STYLE, 'margin-left': '10px', 'width': '190px'}),
        ]),
    ], style={'display': 'flex', 'justify-content': 'center'}),

    # Filled by display_page with either index_page or page_1_layout.
    html.Div(id='page-content'),
    html.Div([
        dcc.Slider(
            id='year-slider',
            min=MIN_YEAR,
            max=MAX_YEAR,
            step=STEP,
            value=INITIAL_VALUE,
            marks=SLIDER_MARKS,
            included=False
        )
    ], style={'width': '30%', 'margin': '0 auto', 'margin-bottom': '50px'}),
    # Footer with course and author credits.
    html.Div([
        html.Div([
            html.P('Masters in Engineering and Data Science | Advanced Data Visualization | 2023/2024'),
            html.P('José Namora Dias #2023164985 | Pedro Arsénio Costa #2020242456'),
        ], style={'fontSize':'75%','font-family': 'sans-serif', 'line-height': '0.25', 'color': 'white', 'text-align':'center'}),
    ], style={'margin':'auto', 'background-color':'#373659', 'padding': '2.5px'}),
])


################## LANDING PAGE ##################
# (graph id, container height) for the three charts stacked in the right-hand
# column: trade-balance donut, exports treemap, imports treemap.
_SIDE_PANELS = (
    ('donut', '380px'),
    ('treemap1', '340px'),
    ('treemap2', '320px'),
)

# Left column (70% wide): the gravity-model network graph.
_network_column = html.Div(
    [
        dcc.Graph(
            id='advanced-viz',
            config={'responsive': True},
            style={'height': '1200px'},
        ),
    ],
    style={'width': '70%', 'display': 'inline-block', 'vertical-align': 'top'},
)

# Right column (30% wide): one fixed-height container per side chart.
_detail_column = html.Div(
    [
        html.Div(
            [dcc.Graph(id=graph_id, config={'autosizable': True})],
            style={'height': panel_height},
        )
        for graph_id, panel_height in _SIDE_PANELS
    ],
    style={'width': '30%', 'display': 'inline-block', 'vertical-align': 'top'},
)

# Landing page: network graph next to the per-country detail charts.
index_page = html.Div([_network_column, _detail_column])


################## ADVANCED VIZ ##################
@app.callback(
    Output('advanced-viz', 'figure'),
    [Input('year-slider', 'value'),
     Input('dropdown-input', 'value')]
)
def update_chart(year, search_term):
    """Draw the trade gravity network for the selected year and reporter.

    Nodes are countries sized by GDP per capita; every edge joins a reporter
    to one partner and is drawn with a width scaled from that pair's
    normalized gravity index.

    Args:
        year: Value of the year slider; only rows of ``mainviz`` whose
            ``Year`` equals it are plotted.
        search_term: Reporter picked in the dropdown. A falsy value keeps
            every reporter.

    Returns:
        A ``go.Figure`` holding one line trace per edge followed by a single
        marker trace for the nodes. When no row matches the selection, the
        figure carries only the layout (title, hidden axes).
    """
    # NOTE(review): a later callback in this file is also named update_chart.
    # Both stay registered because @app.callback captures the function object,
    # but the module-level name ends up bound to the later definition.
    layout = go.Layout(showlegend=True,
                       title=dict(text=f"Portugal Trade Gravity Model for {year}", x=0.5, ),
                       hovermode='closest',
                       hoverlabel=dict(
                           bgcolor="white",
                           font_size=16,
                       ),
                       xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                       yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                       plot_bgcolor='rgba(0,0,0,0)',
                       paper_bgcolor='rgba(0,0,0,0)',
                       legend=dict(
                           x=1,
                           y=1,
                           traceorder="normal",
                           font=dict(
                               family="sans-serif",
                               size=12,
                               color="black"
                           ),
                           bgcolor=None,
                           bordercolor=None,
                           borderwidth=0
                       ))

    filtered_data = mainviz[mainviz['Year'] == year]
    if search_term:
        # Exact match: the dropdown value is a full reporter name, and
        # str.contains would treat it as a regex and also match superstrings.
        filtered_data = filtered_data[filtered_data['Reporter Namee'] == search_term]

    # Nothing to draw: .iloc[0] and MinMaxScaler below would raise on an
    # empty selection, so return an empty but titled figure instead.
    if filtered_data.empty:
        return go.Figure(layout=layout)

    G = nx.from_pandas_edgelist(filtered_data, 'Reporter Namee', 'Partner Namee', edge_attr=['N_Gravity Index', 'Gravity Index', 'Imports', 'Exports', 'Distance (Kms)', 'GDP per capita Importer'], create_using=nx.Graph())

    # Copy the per-partner figures onto the nodes. Only the rows of the current
    # selection are used; iterating over all of mainviz would leave each node
    # with the values of whichever year happens to come last in the file.
    node_attribute_columns = ('Distance (Kms)', 'Imports', 'Exports', 'GDP per capita Importer')
    for _, row in filtered_data.iterrows():
        partner_name = row['Partner Namee']
        if partner_name in G.nodes:
            for column in node_attribute_columns:
                G.nodes[partner_name][column] = row[column]

    # Hand-placed coordinates for the countries that have one.
    pos = {
        'Portugal': (0, 0),
        'Spain': (1, -1.5),
        'France': (1.3, 2.5),
        'Brazil': (-6, -4),
        'Italy': (2.1, -0.5),
        'Poland': (4, 3),
        'United Kingdom': (0.1, 5.5),
        'United States': (-6, 0),
        'Angola': (0, -5),
        'Germany': (3, 1),
        'Netherlands': (2.5, 3),
        'Belgium': (1, 4),
    }

    nodes = list(G.nodes())
    for node in nodes:
        pos.setdefault(node, (0, 0))  # Default position

    # GDP per capita per node: partners come from 'GDP per capita Importer',
    # Portugal from the dedicated column (read from the first selected row).
    partner_gdp = filtered_data.groupby('Partner Namee')['GDP per capita Importer'].mean().to_dict()
    gdp_per_capita_portugal = filtered_data['GDP per capita Portugal'].iloc[0]
    node_gdp = [gdp_per_capita_portugal if node == 'Portugal' else partner_gdp.get(node, 0)
                for node in nodes]

    # Normalize the node sizes to a suitable marker diameter range
    node_sizes = MinMaxScaler((20, 135)).fit_transform(np.array(node_gdp).reshape(-1, 1)).ravel()

    # Hover payload per node: [imports, exports, distance, GDP per capita].
    # The GDP entry reuses node_gdp so Portugal shows the same value that
    # drives its marker size rather than 0.
    node_info = [[G.nodes[node].get('Imports', 0),
                  G.nodes[node].get('Exports', 0),
                  G.nodes[node].get('Distance (Kms)', 0),
                  gdp] for node, gdp in zip(nodes, node_gdp)]

    node_trace = go.Scatter(x=[pos[node][0] for node in nodes],
                            y=[pos[node][1] for node in nodes],
                            mode='markers+text', text=nodes, textposition='top center',
                            marker=dict(size=node_sizes, sizemode='diameter', opacity=1, color='#373659'),
                            customdata=node_info,
                            name='Country',
                            showlegend=False,
                            hovertemplate='Node: %{text}<br>' +
                                          'GDP per capita($): %{customdata[3]:,.0f}<br>' +
                                          'Imports($): %{customdata[0]:,.0f}<br>' +
                                          'Exports($): %{customdata[1]:,.0f}<br>' +
                                          'Distance(Km): %{customdata[2]:,.0f}<extra></extra>')

    # One line trace per edge so that every edge can have its own width.
    edges = list(G.edges(data=True))
    edge_traces = []
    if edges:
        # Normalize the edge widths to a suitable range
        raw_widths = np.array([attrs['N_Gravity Index'] for _, _, attrs in edges]).reshape(-1, 1)
        edge_widths = MinMaxScaler((1, 10)).fit_transform(raw_widths).ravel()
        min_width = edge_widths.min()
        max_width = edge_widths.max()

        for (source, target, attrs), scaled_width in zip(edges, edge_widths):
            width = float(scaled_width)
            x0, y0 = pos[source]
            x1, y1 = pos[target]
            edge_traces.append(go.Scatter(
                x=[x0, x1], y=[y0, y1],
                line=dict(width=width, color='#888'), mode='lines',
                # Show the real gravity index in the hover, not the 1-10 line
                # width that is derived from it for drawing.
                hovertemplate='Gravity Index: {}<br>Imports: {}<br>Exports: {}<br>Distance: {}<br>GDP per capita Importer: {}<extra></extra>'.format(
                    attrs['Gravity Index'], attrs['Imports'], attrs['Exports'],
                    attrs['Distance (Kms)'], attrs['GDP per capita Importer']),
                # Show legend only for min and max widths
                showlegend=width in (min_width, max_width),
                name='Gravitational Force: {}'.format(round(width))))

    return go.Figure(data=[*edge_traces, node_trace], layout=layout)


################## Donut ##################
@app.callback(
    Output('donut', 'figure'),
    [Input('advanced-viz', 'clickData'),
     Input('year-slider', 'value')]
)
def update_chart(clickData, slider_value):
    """Build the export/import donut for the clicked country and selected year.

    Args:
        clickData: ``clickData`` of the 'advanced-viz' graph, or ``None`` when
            nothing has been clicked yet.
        slider_value: Year selected on the slider.

    Returns:
        A ``go.Figure`` with a two-slice pie (hole 0.4) of the summed
        'Export (US$ Thousand)' and 'Import (US$ Thousand)' columns of the
        matching ``detailed2`` rows.
    """
    # NOTE(review): this reuses the name of the 'advanced-viz' callback defined
    # earlier in the file; both remain registered, but the name is shadowed.

    # Only the node trace carries a 'text' label. A click that lands on an
    # edge trace has no 'text', so read it defensively and fall back to
    # 'World' instead of raising KeyError.
    points = clickData.get('points') if clickData else None
    country = (points[0].get('text') if points else None) or 'World'

    year = slider_value
    df_filtered3 = detailed2[(detailed2['Year'] == year) & (detailed2['Partner Name'] == country)]
    fig = go.Figure(data=[go.Pie(labels=['Export (US$ Thousand)', 'Import (US$ Thousand)'],
                                 values=[df_filtered3['Export (US$ Thousand)'].sum(), df_filtered3['Import (US$ Thousand)'].sum()],
                                 hole=.4,
                                 marker_colors=['rgba(144,238,144,0.9)', 'rgba(240,128,128,0.9)'])])
    fig.update_layout(title_text=f"{year} Trade Balance for {country}")
    return fig


################## Treemaps ##################
def _clicked_country(clickData, default='World'):
    """Return the label of the clicked node, or ``default``.

    ``default`` is used when nothing has been clicked or when the clicked
    point has no 'text' (edge traces carry none), which previously raised
    KeyError.
    """
    points = clickData.get('points') if clickData else None
    return (points[0].get('text') if points else None) or default


def _format_trade_value(value):
    """Format one treemap value as a 'MM $' label (value divided by 1,000,000)."""
    # NOTE(review): the source columns are named '(US$ Thousand)', so dividing
    # by 1e6 may not yield millions, and values below 1 get fewer decimals than
    # larger ones -- confirm the intended unit and precision. Kept as-is.
    scaled = value / 1000000
    if scaled < 1:
        return '{:.1f} MM $'.format(scaled)
    return '{:.2f} MM $'.format(scaled)


def _product_treemap(clickData, year, value_column, palette, title_template):
    """Build a one-level treemap of ``value_column`` per product group.

    Rows of ``detailed`` are filtered to the clicked country and ``year``,
    sorted ascending by ``value_column`` and coloured by cycling ``palette``.
    ``title_template`` is formatted with the keyword ``country``.
    """
    country = _clicked_country(clickData)
    rows = detailed[(detailed['Partner Name'] == country) & (detailed['Year'] == year)]
    rows = rows.sort_values(value_column)

    colors = [palette[i % len(palette)] for i in range(len(rows))]
    fig = go.Figure(data=[go.Treemap(
        labels=rows['Product Group'],
        values=rows[value_column],
        text=rows[value_column].apply(_format_trade_value),
        parents=[""] * len(rows),  # every product group sits directly under the root
        marker=dict(colors=colors),
    )])
    fig.update_layout(title_text=title_template.format(country=country), title_x=0.5, title_y=0.83)
    return fig


#EXPORTS
@app.callback(
    Output('treemap1', 'figure'),
    [Input('advanced-viz', 'clickData'),
     Input('year-slider', 'value')]
)
def update_exports_chart(clickData, year):
    """Treemap of 'Export (US$ Thousand)' by product group for the clicked country and year."""
    pastel_palette = ['#f29191', '#ffcb77', '#373659', '#07a0c3', '#a7f1a7']
    return _product_treemap(clickData, year, 'Export (US$ Thousand)', pastel_palette,
                            "Exports from Portugal to {country}")

#IMPORTS
@app.callback(
    Output('treemap2', 'figure'),
    [Input('advanced-viz', 'clickData'),
     Input('year-slider', 'value')]
)
def update_imports_chart(clickData, year):
    """Treemap of 'Import (US$ Thousand)' by product group for the clicked country and year."""
    pastel_palette = ['#07a0c3', '#a7f1a7', '#ffcb77', '#373659', '#f29191']
    return _product_treemap(clickData, year, 'Import (US$ Thousand)', pastel_palette,
                            "Imports of Portugal from {country} ")



################## PAGE 1 ##################
# "Country Summary" page: the Sankey diagram followed by the trade-balance /
# growth-rate figure; both graphs are filled in by their callbacks.
page_1_layout = html.Div(
    children=[
        dcc.Graph(id=graph_id)
        for graph_id in ('sankey-diagram', 'growth-rate-graph')
    ]
)

@app.callback(Output('url', 'pathname'),
              [Input('page-1-button', 'n_clicks'),
               Input('index-button', 'n_clicks')])
def navigate(n1, n2):
    """Translate a navigation-button click into a new URL pathname.

    Args:
        n1: Click count of 'page-1-button' (unused; the trigger is read from
            the callback context).
        n2: Click count of 'index-button' (unused, same reason).

    Returns:
        '/page-1' when 'page-1-button' fired, '/' when 'index-button' fired,
        and ``dash.no_update`` when no button triggered the call.
    """
    ctx = dash.callback_context
    if not ctx.triggered:
        # Initial call on page load: leave the pathname alone. Returning '/'
        # here would overwrite a directly opened or refreshed '/page-1'.
        return dash.no_update

    button_id = ctx.triggered[0]['prop_id'].split('.')[0]
    return '/page-1' if button_id == 'page-1-button' else '/'

@app.callback(Output('page-content', 'children'),
              [Input('url', 'pathname')])
def display_page(pathname):
    """Return the layout for ``pathname``: page 1 for '/page-1', otherwise the landing page."""
    routes = {'/page-1': page_1_layout}
    # Any pathname without a route (including '/') shows the landing page.
    return routes.get(pathname, index_page)

################## SANKEY DIAGRAM ##################
@app.callback(
    Output('sankey-diagram', 'figure'),
    [Input('year-slider', 'value')]
)
def update_figure(selected_year):
    """Build the Sankey diagram of the trade flows recorded for ``selected_year``.

    Each row of ``df`` holds two links side by side ('source'/'target'/'value'/
    'type1' and 'source2'/'target2'/'value2'/'type2'); they are stacked into
    one link table, node labels are numbered, and links typed 'Exports' are
    drawn green while all others are drawn red.
    """
    link_columns = ['source', 'target', 'value', 'type']
    year_rows = df[df['year'] == selected_year]

    # First link set, renamed to the common column names.
    first_links = year_rows[['source', 'target', 'value', 'type1']].copy()
    first_links.columns = link_columns
    first_links.loc[first_links['source'] == 'Portugal Exports', 'type'] = 'Exports'

    # Second link set, renamed the same way.
    second_links = year_rows[['source2', 'target2', 'value2', 'type2']].copy()
    second_links.columns = link_columns
    second_links.loc[second_links['source'] == 'Portugal Imports', 'type'] = 'Imports'

    links = pd.concat([first_links, second_links]).reset_index(drop=True)

    # Sankey links refer to nodes by position, so give each distinct label
    # (sources first, then targets) its index in `labels`.
    labels = pd.concat([links['source'], links['target']]).unique()
    label_index = {}
    for position, label in enumerate(labels):
        label_index[label] = position

    links['source'] = links['source'].map(label_index)
    links['target'] = links['target'].map(label_index)

    export_color = 'rgba(144,238,144,0.5)'
    other_color = 'rgba(240,128,128,0.5)'
    link_colors = []
    for flow_type in links['type'].tolist():
        link_colors.append(export_color if flow_type == 'Exports' else other_color)

    # Create the Sankey diagram
    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color="gray",
        ),
        link=dict(
            source=links['source'].values,
            target=links['target'].values,
            value=links['value'].values,
            color=link_colors,
        ),
    )
    fig = go.Figure(data=[sankey])
    fig.update_layout(title_text=f"Portugal Trade Flows for {selected_year}", title_x=0.5)
    return fig

################## TRADE BALANCE ##################
@app.callback(
    Output('growth-rate-graph', 'figure'),
    [Input('year-slider', 'value')]
)
def update_figure(selected_year):
    """Build the trade-balance and growth-rate panels for ``selected_year``.

    Args:
        selected_year: Year chosen on the slider; rows of ``growthdf`` whose
            'date' falls in that year are plotted.

    Returns:
        A two-column subplot figure: overlaid bars of 'X' (exports) and
        negated 'M' (imports) on the left, and the 'Growth Rate' line with a
        zero baseline on the right.
    """
    # NOTE(review): the Sankey callback earlier in the file is also named
    # update_figure; both remain registered, but the name is shadowed.

    # Derive the year mask locally rather than writing a 'year' column into
    # the module-level growthdf: callbacks should not modify shared state.
    in_selected_year = growthdf['date'].dt.year == selected_year
    filtered_growthdf = growthdf[in_selected_year].copy()
    dates = filtered_growthdf['date']

    trace1 = go.Bar(x=dates, y=filtered_growthdf['X'], name='Exports', marker=dict(color='mediumseagreen'))
    # Imports are negated so their bars extend below the axis.
    trace2 = go.Bar(x=dates, y=-filtered_growthdf['M'], name='Imports', marker=dict(color='salmon'))
    # Flat line at zero drawn over the growth-rate panel.
    trace3 = go.Scatter(x=dates, y=[0] * len(dates), mode='lines', name='Baseline', line=dict(color='white'), showlegend=False)
    trace4 = go.Scatter(x=dates, y=filtered_growthdf['Growth Rate'], mode='lines+markers+text', name='Growth Rate',
                        line=dict(color='#F29191'), fill='tozeroy', text=filtered_growthdf['Growth Rate'].round(2),
                        textposition='top center', textfont=dict(size=8),
                        texttemplate='%{text:.2f}%', hovertemplate='%{y:.2f}%')  # Include percentage sign

    fig = make_subplots(rows=1, cols=2, subplot_titles=('Trade Balance', 'Growth Rate'))

    fig.add_trace(trace1, row=1, col=1)
    fig.add_trace(trace2, row=1, col=1)
    fig.add_trace(trace3, row=1, col=2)
    fig.add_trace(trace4, row=1, col=2)

    fig.update_xaxes(title_text='Date', row=1, col=1)
    fig.update_yaxes(title_text='Values', row=1, col=1)
    fig.update_xaxes(title_text='Date', row=1, col=2)
    fig.update_yaxes(title_text='Growth Rate (%)', row=1, col=2)

    fig.update_layout(
        barmode='overlay',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    return fig



if __name__ == '__main__':
    # Prefer app.run: run_server is the legacy entry point and is not
    # available in recent Dash releases. Fall back to it only when app.run
    # does not exist (older Dash). The `or` keeps run_server from being
    # looked up unless it is actually needed.
    start_server = getattr(app, 'run', None) or app.run_server
    start_server(debug=True)