In [None]:
import os
import re
import numpy as np
import pandas as pd

In [None]:
DATA_DIR = '../data/'

In [None]:
os.listdir(DATA_DIR)

### Read IMF data

In [None]:
df = pd.read_parquet(os.path.join(DATA_DIR, 'imf_dot.parq'))

### Read Coordinates lookup from Google
The data was downloaded from https://developers.google.com/public-data/docs/canonical/countries_csv

In [None]:
coord = pd.read_csv(os.path.join(DATA_DIR, 'google_country_coordinates_lookup.csv'))

### Read Country Code provided by IMF

In [None]:
country_code = pd.read_csv(os.path.join(DATA_DIR, 'DOT_03-28-2022 04-15-57-36_timeSeries', 'Metadata_DOT_03-28-2022 04-15-57-36_timeSeries.csv'))

In [None]:
country_code_dict = {}
country_code_iso3 = {}
for i, row in country_code.iloc[11:1835].iterrows():  # only row 11 - 1835 contains useful information
    if 'Country ISO 2 Code' in row['Metadata Attribute']:
        country = row['Country Name']
        code = row['Metadata Value']
        if pd.isna(code):
            print(f'No country code found for country: {country}')
            continue
        country_code_dict[country] = code
    if 'Country ISO 3 Code' in row['Metadata Attribute']:
        country = row['Country Name']
        code = row['Metadata Value']
        if pd.isna(code):
            print(f'No country ISO3 code found for country: {country}')
            continue
        country_code_iso3[country] = code

In [None]:
df['Country Code'] = df['Country Name'].map(country_code_dict)
df['Country Code ISO3'] = df['Country Name'].map(country_code_iso3)

In [None]:
df['Counterpart Country Code'] = df['Counterpart Country Name'].map(country_code_dict)
df['Counterpart Country Code ISO3'] = df['Counterpart Country Name'].map(country_code_iso3)

In [None]:
coord_dict = coord.set_index('country')[['latitude', 'longitude']].to_dict(orient='index')

In [None]:
def get_coordinates(code):
    try:
        return coord_dict[code]['latitude'], coord_dict[code]['longitude']
    except KeyError:
        return None, None

In [None]:
df['latitude'], df['longitude'] = zip(*df['Country Code'].apply(get_coordinates))

In [None]:
df['Counterpart latitude'], df['Counterpart longitude'] = zip(*df['Counterpart Country Code'].apply(get_coordinates))

### Drop NaN for records with empty Country lat long or Counterpart Country lat long
drawing arrows on map requires a pair of coordinates. Cannot be drawn if either one is missing

In [None]:
df = df.dropna(subset=['latitude', 'longitude', 'Country Code', 'Counterpart latitude', 'Counterpart longitude', 'Counterpart Country Code'])

In [None]:
df = df[df['Indicator Name']=='Trade Balance']

In [None]:
df.to_parquet(os.path.join(DATA_DIR, 'processed_for_map.snappy.parquet'))

### Plotting

In [13]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import os

In [14]:
DATA_DIR = '../data/'
df = pd.read_parquet(os.path.join(DATA_DIR, 'processed_for_map.snappy.parquet'))

In [15]:
df['Country Name'].unique()

array(['Argentina', 'Afghanistan, Islamic Rep. of', 'Algeria', 'Angola',
       'Anguilla', 'American Samoa', 'Albania', 'Antigua and Barbuda',
       'Australia', 'Bangladesh', 'Armenia, Rep. of', 'Belarus, Rep. of',
       'Bahamas, The', 'Bahrain, Kingdom of', 'Azerbaijan, Rep. of',
       'Barbados', 'Austria', 'Aruba, Kingdom of the Netherlands',
       'Belgium', 'Bhutan', 'Benin', 'Botswana', 'Brazil', 'Belize',
       'Bolivia', 'Bermuda', 'Bosnia and Herzegovina', 'Burkina Faso',
       'Central African Rep.', 'Chad', 'Burundi', 'Bulgaria', 'Canada',
       'Cambodia', 'Cabo Verde', 'Cameroon', 'Brunei Darussalam',
       'Costa Rica', 'China, P.R.: Hong Kong', 'Congo, Dem. Rep. of the',
       'China, P.R.: Mainland', 'China, P.R.: Macao', 'Colombia',
       'Congo, Rep. of', 'Chile', 'Comoros, Union of the',
       'Croatia, Rep. of', 'Djibouti', "Côte d'Ivoire", 'Denmark',
       'Czech Rep.', 'Cyprus', 'Cuba', 'Dominica', 'El Salvador',
       'Egypt, Arab Rep. of', 'Domin

In [111]:
chosen_country = 'China, P.R.: Mainland'
chosen_start_year = 1948
chosen_end_year = 2020
chosen_top_n = 5

line_max_width = 10
line_min_width = 1

In [112]:
country_data = df[df['Country Name'] == chosen_country]

year_columns = [str(year) for year in range(chosen_start_year, chosen_end_year+1)]
country_data['total'] = country_data[year_columns].sum(axis=1)

top_balance = country_data[country_data['Indicator Name']=='Trade Balance'].sort_values(by='total', ascending=False)



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [113]:
country_data = df[df['Country Name'] == chosen_country]

year_columns = [str(year) for year in range(chosen_start_year, chosen_end_year+1)]
country_data['total'] = country_data[year_columns].sum(axis=1)

pos_balance = country_data[country_data['total']>=0].sort_values(by='total', ascending=False)
width = (pos_balance['total'] / pos_balance['total'].max()).to_numpy()
pos_balance['width'] = (width - np.min(width)) / (np.max(width) - np.min(width)) * (line_max_width - line_min_width) + line_min_width

neg_balance = country_data[country_data['total']<0].sort_values(by='total', ascending=True)
width = (neg_balance['total'] / neg_balance['total'].max()).to_numpy()
neg_balance['width'] = (width - np.min(width)) / (np.max(width) - np.min(width)) * (line_max_width - line_min_width) + line_min_width

data = pos_balance.head(chosen_top_n).append(neg_balance.head(chosen_top_n))



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.



In [140]:
fig = go.FigureWidget()

chosen_country_code = data['Country Code ISO3'].unique().tolist()
chosen_country_lat = data['latitude'].iloc[0]
chosen_country_lon = data['longitude'].iloc[0]
chosen_country_color = 'rgb(255, 232, 84)'
# counterpart_country_color = 'rgb(235, 124, 124)'
counterpart_country_color = 'rgb(230, 230, 230)'

# add chosen country
chosen_country_map = go.Choropleth(
    locations=chosen_country_code,
    locationmode='ISO-3',
    z = [1],
    colorscale = [[0, chosen_country_color], [1, chosen_country_color]],
    marker_line_color='white',
    marker_line_width=2,
    colorbar=None,
    showscale=False,
    hovertemplate = f'<b>{chosen_country}</b>'+ f'<extra>{chosen_country_code[0]}</extra>',
    hoverlabel_bgcolor=chosen_country_color
)
fig.add_traces(chosen_country_map)

# add counterpart countries
counterpart_country_map = go.Choropleth(
    locations=data['Counterpart Country Code ISO3'],
    z=[1 for i in range(0, data['Counterpart Country Code ISO3'].shape[0])],
    locationmode='ISO-3',
    colorscale = [[0, counterpart_country_color], [1, counterpart_country_color]],
    text=data['Counterpart Country Name'],
    hovertext=data['Counterpart Country Code ISO3'],
    marker_line_color='white',
    marker_line_width=2,
    autocolorscale=False,
    showscale=False,
    hovertemplate = '<b>%{text}</b><br><extra>%{hovertext}</extra>',
    hoverlabel_bgcolor=counterpart_country_color
)
fig.add_trace(counterpart_country_map)

for i, row in data.iterrows():
    
    # add arrows
    fig.add_trace(
        go.Scattergeo(
            locationmode = 'ISO-3',
            lon = [row['longitude'], row['Counterpart longitude']],
            lat = [row['latitude'], row['Counterpart latitude']],
            mode = 'lines',
            line = dict(
                width = row['width'],
                color = 'rgb(45,237,28)' if row['total']>=0 else 'rgb(254,2,1)'
            ),
            hovertemplate = f'Counterpart Country: {row["Counterpart Country Name"]}<br>' + row['Indicator Name'] + f': {row["total"]}<extra></extra>',
        )
    )

fig.update_traces(showlegend=False)
fig.update_layout(
    autosize=False,
    margin=dict(
        l=0,
        r=0,
        b=5,
        t=50,
        pad=0,
        autoexpand=False
    ),
    width=900,
    height=600,
    hoverlabel_align = 'right',
    title_text = f"Top {chosen_top_n} & Bottom {chosen_top_n} Trade Balances of {chosen_country} from {chosen_start_year} to {chosen_end_year}",
    template='plotly_dark',
)

# fig.update_traces(, selector=dict(type='choropleth'))

fig.update_geos(
#     showcountries=True,
#     countrycolor='grey',
#     countrywidth=2,
    landcolor='rgb(51, 51, 51)',
    projection_type="orthographic",
    center=dict(lon=chosen_country_lon, lat=chosen_country_lat),
    projection_rotation=dict(lon=chosen_country_lon, lat=chosen_country_lat, roll=0)
)

# create our callback function
def update_point(trace, points, selector):
    print(trace)

fig.data[0].on_click(update_point)
fig.data[1].on_click(update_point)

fig

FigureWidget({
    'data': [{'colorscale': [[0, 'rgb(255, 232, 84)'], [1, 'rgb(255, 232, 84)']],
             …

Choropleth({
    'colorscale': [[0, 'rgb(255, 232, 84)'], [1, 'rgb(255, 232, 84)']],
    'hoverlabel': {'bgcolor': 'rgb(255, 232, 84)'},
    'hovertemplate': '<b>China, P.R.: Mainland</b><extra>CHN</extra>',
    'locationmode': 'ISO-3',
    'locations': [CHN],
    'marker': {'line': {'color': 'white', 'width': 2}},
    'showlegend': False,
    'showscale': False,
    'uid': '698d88c8-c00a-4d1f-affd-4c8b2a50872a',
    'z': [1]
})
Choropleth({
    'autocolorscale': False,
    'colorscale': [[0, 'rgb(230, 230, 230)'], [1, 'rgb(230, 230, 230)']],
    'hoverlabel': {'bgcolor': 'rgb(230, 230, 230)'},
    'hovertemplate': '<b>%{text}</b><br><extra>%{hovertext}</extra>',
    'hovertext': array(['HKG', 'USA', 'NLD', 'GBR', 'IND', 'TWN', 'KOR', 'AUS', 'JPN', 'BRA'],
                       dtype=object),
    'locationmode': 'ISO-3',
    'locations': array(['HKG', 'USA', 'NLD', 'GBR', 'IND', 'TWN', 'KOR', 'AUS', 'JPN', 'BRA'],
                       dtype=object),
    'marker': {'line': {'color': 