In [1]:
import numpy as np
import pandas as pd
import json
import plotly.express as px
import plotly.graph_objects as go
from matplotlib import cm
import matplotlib as mpl
import matplotlib.pyplot as plt

import dash
from dash.dependencies import Input, Output
from dash import dcc
from dash import html


# Load data and reshape it to one row per (country, year)

In [2]:
n_years = 18          # number of yearly MMR columns in the CSV (2000-2017)
first_year_col = 2    # column index of the first (2000) MMR column

df1_raw = pd.read_csv('International_maternal_mortality_country_timeseries.csv')

# Shorter display names so the highlighted axis tick labels stay compact.
short_names = {
    'United States': 'US',
    'United Kingdom': 'UK',
    'Russian Federation': 'Russia',
    'South Africa': 'S. Africa',
    'Iran (Islamic Republic of)': 'Iran',
}
df1_raw['Country'] = df1_raw['Country'].replace(short_names)
df1_raw['pct change'] = (df1_raw['2017 MMR per 100000'] - df1_raw['2000 MMR per 100000'])/df1_raw['2000 MMR per 100000']

# Take the country count from the data instead of hard-coding it. Slice exactly
# n_years year columns, so the Year labels and the mortality values are always
# read from the same columns.
n_countries = len(df1_raw)
year_cols = df1_raw.columns[first_year_col:first_year_col + n_years]
mmr = df1_raw[year_cols].to_numpy()   # shape (n_countries, n_years)

# Long format: one row per (country, year); countries in CSV order, years in column order.
df1 = pd.DataFrame()
df1['Year'] = np.tile([c[:4] for c in year_cols], n_countries)
df1['Country'] = np.repeat(df1_raw['Country'].to_numpy(), n_years)
df1['Country ID'] = np.repeat(np.arange(n_countries), n_years)
df1['mortality'] = mmr.ravel()
# Ratio to each country's first year-column (2000) value, minus 1: 0 = unchanged, -0.5 = halved.
df1['Relative change in mortality'] = (mmr / mmr[:, [0]]).ravel() - 1


def get_sorted_id(arr, i):
    """Return the 0-based position of row ``i`` when ``arr`` is sorted ascending."""
    return np.where(np.argsort(arr) == i)[0][0]


def _sorted_positions(arr):
    """Return ``r`` with ``r[i] == get_sorted_id(arr, i)`` for every row ``i``.

    Inverts a single argsort instead of re-sorting ``arr`` once per lookup.
    """
    order = np.asarray(np.argsort(arr))
    if (order < 0).any():
        # pandas' Series.argsort may mark NaN entries with -1; such rows have no rank.
        raise ValueError('cannot rank a feature that contains NaN')
    positions = np.empty(len(order), dtype=np.intp)
    positions[order] = np.arange(len(order))
    return positions


def get_sorted_ids(arr):
    """Sorted position within ``arr`` of each df1 row's country (looked up by 'Country ID')."""
    return _sorted_positions(arr)[df1['Country ID'].to_numpy()]


def get_sorted_ids_feature(feature):
    """``get_sorted_ids`` applied to the df1_raw column named ``feature``."""
    return get_sorted_ids(df1_raw[feature])


df1['Sorted by 2000 mortality'] = get_sorted_ids_feature('2000 MMR per 100000')
df1['Sorted by 2017 mortality'] = get_sorted_ids_feature('2017 MMR per 100000')
df1['Sorted by pct change in mortality'] = get_sorted_ids_feature('pct change')

df1

Unnamed: 0,Year,Country,Country ID,mortality,Relative change in mortality,Sorted by 2000 mortality,Sorted by 2017 mortality,Sorted by pct change in mortality
0,2000,Afghanistan,0,1450,0.000000,182,175,36
1,2001,Afghanistan,0,1390,-0.041379,182,175,36
2,2002,Afghanistan,0,1300,-0.103448,182,175,36
3,2003,Afghanistan,0,1240,-0.144828,182,175,36
4,2004,Afghanistan,0,1180,-0.186207,182,175,36
...,...,...,...,...,...,...,...,...
3325,2013,Zimbabwe,184,509,-0.120898,154,162,144
3326,2014,Zimbabwe,184,494,-0.146805,154,162,144
3327,2015,Zimbabwe,184,480,-0.170984,154,162,144
3328,2016,Zimbabwe,184,468,-0.191710,154,162,144


# Custom color maps (highlighted countries are drawn more opaque)

In [3]:
def rgba_arr_to_str(r):
    """Format one RGBA row (channels in [0, 1]) as an 'rgba(R,G,B,A)' string.

    R, G and B are scaled to 0-255 and rounded half-to-even with ``np.rint``;
    the alpha channel is emitted unscaled as a float.
    """
    red, green, blue = (int(channel) for channel in np.rint(r[:3] * 255))
    alpha = np.float64(r[3])
    return f'rgba({red},{green},{blue},{alpha})'

def get_custom_rgbas(feature, countries_to_highlight, opac_low = 0.5, opac_high = 1):
    """Return one RGBA colour string per row of the global ``df1``, shaded by ``feature``.

    Colours come from matplotlib's 'Reds' colormap, normalised over the min/max
    of ``df1[feature]``. Rows whose 'Country' is in ``countries_to_highlight``
    get alpha ``opac_high``; all other rows get ``opac_low``.

    Parameters
    ----------
    feature : str
        Numeric column of df1 that drives the colour.
    countries_to_highlight : list of str
        Country names, as stored in df1['Country'], to draw more opaque.
    opac_low, opac_high : float
        Alpha for non-highlighted and highlighted rows.

    Returns
    -------
    list of str
        'rgba(r,g,b,a)' strings in df1's current row order.

    Notes
    -----
    Also writes the per-row alpha to ``df1['opac']`` (kept for compatibility).
    """
    values = df1[feature]
    norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())
    scalar_map = cm.ScalarMappable(norm=norm, cmap=plt.get_cmap('Reds'))

    # Boolean-mask assignment belongs to .loc. The .at accessor is for single
    # labels; a boolean list only worked through pandas' fallback to .loc.
    df1['opac'] = opac_low*np.ones(len(df1))
    df1.loc[df1['Country'].isin(countries_to_highlight), 'opac'] = opac_high

    # Compress the values before colouring so the palette avoids its extremes.
    # NOTE(review): v*.5 + .25*max maps into the middle of the colormap only when
    # min(v) is about 0. For a feature with a negative minimum (e.g. relative
    # change) the band shifts toward the dark end; adding .25*min would centre
    # it. Confirm which is intended.
    rgbas = scalar_map.to_rgba(values*.5 + .25*values.max())
    rgbas[:, 3] = df1['opac']

    return [rgba_arr_to_str(row) for row in rgbas]


# Create figures

In [4]:
# Interaction hint for the 3D figures; not referenced by any cell shown here.
subtitle = 'Scroll to zoom, drag to rotate, ctrl+drag to shift'

### Scroll to zoom
### Drag to rotate
### ctrl+drag to shift

In [8]:
fig1 = go.Figure()
sorted_feature = 'Sorted by 2017 mortality'   # y axis: country rank
height_feature = 'mortality'                  # z axis: MMR
time_feature = 'Year'                         # x axis
countries_to_highlight = \
['France', 'UK', 'US', 'China', 'Norway', 'India', 'Japan', 'Brazil', 'Sierra Leone', 'Afghanistan', 'Kenya', 'S. Africa', 'Iran', 'Iraq', 'Viet Nam']

# Reassigns the global df1. get_custom_rgbas colours the rows of the global df1
# in its current order, so this sort must happen before the colours are computed.
df1 = df1.sort_values(by=[sorted_feature, time_feature])

# Read each highlighted country's y position from the same column that is
# plotted on y, so the tick labels stay aligned if sorted_feature is changed.
highlight_tickvals = [df1.loc[df1['Country'] == c, sorted_feature].iloc[0] for c in countries_to_highlight]

scatter1 = go.Scatter3d(
        x=df1[time_feature],
        y=df1[sorted_feature],
        z=df1[height_feature],
        text=df1['Country'],
        mode='markers',
        marker=dict(
            size=5,
            color=get_custom_rgbas(height_feature, countries_to_highlight, opac_low = 0.7)
        ),
        hovertemplate=
        "<b>%{text}</b><br><br>" +
        "Year: %{x}<br>" +
        "MMR: %{z}<br>" +
        "<extra></extra>"
        )
fig1.add_trace(scatter1)

# The title and y-axis title describe sorted_feature; update them together with it.
fig1.update_layout(
    template="plotly_dark",
    title='Countries by 2017 mortality, MMR:= maternal mortality rate per 100,000', 
    scene_camera = dict(eye=dict(x=45, y=0, z=15)),
    scene = dict(
        aspectmode = "manual",
        aspectratio = dict( x = 15, y = 60, z = 15),
        xaxis = dict(title=time_feature),
        yaxis = dict(
            title='Countries sorted by 2017 MMR', 
            tickvals = highlight_tickvals, 
            ticktext=countries_to_highlight
        ),
        zaxis = dict(
            title='MMR', 
            tickvals=[0, 500, 1000, 1500, 2000, 2500], 
            ticktext=['0', '500', '1000', '1500', '2000', '2500']
        )
    )
)

In [10]:
countries_to_highlight = \
['France', 'UK', 'US', 'China', 'India', 'Japan', 'Brazil', 'Afghanistan', 'S. Africa', 'Viet Nam', 'Belarus', 'Iran', 'Iraq']
change_sort_feature = 'Sorted by pct change in mortality'   # y axis: country rank
change_feature = 'Relative change in mortality'             # z axis (as a fraction)

# Reassigns the global df1. get_custom_rgbas colours the rows of the global df1
# in its current order, so this sort must happen before the colours are computed.
df1 = df1.sort_values(by=[change_sort_feature, 'Year'])

# Markers and line segments share one colour list.
change_colors = get_custom_rgbas(change_feature, countries_to_highlight)

# Read each highlighted country's y position from the same column that is
# plotted on y, so the tick labels cannot drift from the data.
highlight_tickvals = [df1.loc[df1['Country'] == c, change_sort_feature].iloc[0] for c in countries_to_highlight]

fig2 = go.Figure()
# NOTE(review): this is a single trace with mode 'markers+lines' over rows sorted
# by (country rank, Year), so plotly also draws a segment from each country's
# last year to the next country's first year. One trace per country (or None
# gaps between countries) would remove those connectors. Confirm whether they
# are intended.
scatter2 = go.Scatter3d(
        x=df1['Year'],
        y=df1[change_sort_feature],
        z=df1[change_feature]*100,
        text=df1['Country'],
        mode='markers+lines',
        marker=dict(
            size=5,
            color=change_colors
        ),
        line=dict(
            width=5,
            color=change_colors
        ),
        hovertemplate=
        "<b>%{text}</b><br><br>" +
        "Year: %{x}<br>" +
        "MMR change: %{z:.2f}%<br>" +
        "<extra></extra>"
)

fig2.add_trace(scatter2)

fig2.update_layout(
    template="plotly_dark",
    title='Countries by change in mortality, MMR change:= % change in maternal mortality rate from 2000', 
    scene_camera = dict(eye=dict(x=60, y=0, z=15)),
    scene = dict(
        aspectmode = "manual",
        aspectratio = dict( x = 25, y = 60, z = 15),
        xaxis = dict(title='Year'),
        yaxis = dict(
            title='Countries sorted by MMR change', 
            tickvals = highlight_tickvals,  
            ticktext=countries_to_highlight
        ),
        zaxis = dict(
            title='MMR change', 
            tickvals=[-50,0,50], 
            ticktext=['-50%','0%','+50%']
        )
    )
)


# Launch local web app
## Requires `fig1` and `fig2` above

In [18]:
app = dash.Dash()

# Two interactive 3D figures stacked vertically; requires fig1 and fig2 from the cells above.
app.layout = html.Div([
    dcc.Graph(figure=fig1, id='fig1'),
    dcc.Graph(figure=fig2, id='fig2')
])

# Newer Dash versions provide `app.run`; `run_server` is deprecated and was removed
# in Dash 3. Fall back to run_server only on versions that predate `run`.
run_app = app.run if hasattr(app, 'run') else app.run_server
# use_reloader=False: the reloader would restart the process, which does not work inside a notebook kernel.
run_app(debug=True, use_reloader=False)

Dash is running on http://127.0.0.1:8050/

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: on
