In [2]:
# Jonathan_Hamburg_DATS6103
# Project 1 -- Military Spending by Country from 1949 - 2016

# Set up ------
import pandas as pd
import numpy as np
import functools
# Must have bokeh 0.12.5 installed, 0.12.7 breaks
from bokeh.charts import *
from bokeh.layouts import column, gridplot, layout
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
import bokeh.palettes as bpal
import bokeh.models as bmodels
from ipywidgets import widgets
from IPython.core.display import display, HTML
from IPython.core.debugger import set_trace
from fbprophet import Prophet

In [3]:
# Inject a CSS rule that sets the notebook's `.container` element to 85% width
# (presumably so the side-by-side plots further down fit on screen).
display(HTML("<style>.container {width:85% !important;} </style>"))

In [4]:
# Configure Bokeh so that later show(...) calls render inside this notebook
output_notebook()

In [5]:
#######################################
# Download Raw data and read into Python ----
#######################################

# Workbook that holds the raw military-expenditure sheets
rawFile = "SIPRI-Milex-data-1949-2016.xlsx"

# Short dataset names; entry k here is paired with entry k of sheetInfo
dats = ['milSpend', 'milPerGDP', 'milPerCap']

# One [sheet name, number of leading rows to skip] pair per dataset
sheetInfo = [
    ['Current USD', 5],
    ['Share of GDP', 5],
    ['Per capita', 6],
]

# Raw sheets are collected in this dict, keyed by '<datName>Raw'
rawDat = dict()
def readInRaw(datName, sheetName, skipRow):
    """
    Read one sheet of `rawFile` and store it in `rawDat[datName + 'Raw']`.

    datName   -- short dataset name used to build the dictionary key
    sheetName -- name of the Excel sheet to read
    skipRow   -- value forwarded to pandas' `skiprows`

    Returns nothing; the result is stored in the module-level `rawDat`.
    """
    # The sheet is passed positionally: its keyword was renamed from
    # `sheetname` to `sheet_name` between pandas releases, while the
    # second positional slot means "sheet" in both spellings.
    rawDat[datName + 'Raw'] = pd.read_excel(rawFile, sheetName,
                                            skiprows = skipRow)
# Load every dataset: dataset k uses the sheet settings at sheetInfo[k]
for dat, datName in enumerate(dats):
    sheet = sheetInfo[dat]
    readInRaw(datName, *sheet)

In [6]:
#######################################
# Get a list of countries to region for analysis later
#######################################
tmp = rawDat[dats[0] + 'Raw']
ctryList = tmp.Country

# Need to strip out blank lines and footnotes
ctryList = ctryList[ctryList.notnull()]

# Positions (within the non-null list) of every entry mentioning 'footnote'.
# na = False keeps non-string cells from producing NaN, which np.where
# would otherwise count as a match.
footNoteRow = np.where(ctryList.str.contains('footnote', case = False,
                                             na = False))[0]

# Everything from the first footnote entry onward is dropped. When no
# footnote entry exists the list is kept whole instead of raising.
if len(footNoteRow) > 0:
    ctryList = ctryList.iloc[:int(footNoteRow[0])]

In [7]:
# Create a dimension table that maps each country to its region. In the raw
# data the list of countries is interleaved with region header rows: any row
# whose year columns are all NA is the region header for the countries
# listed below it.
# The table is built using the following steps:
#    1. Filter the raw data to the country list
#    2. Walk the rows, tagging every country with the most recent region
#       (and continent) header above it
#    3. Collapse every region whose name contains 'Asia' into one 'Asia' region

def isint(value):
    """
    Return True when `int(value)` succeeds, False otherwise.

    Needed because column labels may be strings (e.g. '1949') that still
    have to be recognised as numbers.  Values that `int()` rejects with a
    TypeError (such as None) or an OverflowError (such as infinity) are
    reported as False instead of propagating the exception.
    """
    try:
        int(value)
    except (ValueError, TypeError, OverflowError):
        return False
    return True

def getNumCols(dat):
    """Return the column labels of `dat` for which `isint` is True."""
    isNumeric = pd.Series(dat.columns.map(isint)) == True
    return(dat.columns[isNumeric])

def getNewObs(ctry, curRegion, curContinent, datCols):
    """
    Package one country with its region and continent as a pandas Series
    indexed by 'Country', 'Region' and 'Continent'.

    `datCols` is accepted but not used.
    """
    fields = dict(Country = ctry,
                  Region = curRegion,
                  Continent = curContinent)
    return(pd.Series(fields))

# Find region names by identifying the rows whose year columns are all missing
def getRegionDict(dat):
    """
    Build a Country / Region / Continent lookup table from `dat`.

    `dat` is scanned top to bottom.  A row whose numeric (year) columns are
    all null is treated as a header: its Country value becomes the current
    region.  When a header row directly follows another header row, the
    earlier of the two is taken to be the continent.  Every other row is a
    country and is recorded with the current region and continent.

    Countries that appear before any header get None for Region and
    Continent (and Continent stays None until a header pair has been seen).

    Finally every region whose name contains 'Asia' is renamed to 'Asia'.

    Returns a DataFrame with columns ['Country', 'Region', 'Continent'].
    """
    datCols = ['Country', 'Region', 'Continent']
    numCols = getNumCols(dat)

    # Rows are gathered in a list and turned into a frame once at the end
    records = []
    curRegion = None
    curContinent = None
    prevHeader = None       # raw Country value of the previous header row
    prevWasHeader = False   # True when the row just processed was a header

    for row in range(dat.shape[0]):
        filtDat = dat.iloc[row, :]
        ctry = filtDat.Country

        # If all years are missing, then region name
        if filtDat.loc[numCols].isnull().all():
            # Continents are immediately followed by another region header
            if prevWasHeader:
                curContinent = prevHeader
            curRegion = str(ctry)
            prevHeader = ctry
            prevWasHeader = True
        else:
            records.append({'Country': ctry,
                            'Region': curRegion,
                            'Continent': curContinent})
            prevWasHeader = False

    res = pd.DataFrame(records, columns = datCols)

    # To minimize number of regions, will combine all Asia regions
    # into just Asia
    asiaBool = res['Region'].str.contains('Asia', na = False)
    res.loc[asiaBool, 'Region'] = 'Asia'
    return(res)

In [8]:
#######################################
# Create region/country dimension table
#######################################
# Keep only the rows whose Country value is in the cleaned country list,
# then derive the lookup table and preview its first rows
tmp = tmp.loc[tmp['Country'].isin(ctryList)]
regionDimTbl = getRegionDict(tmp)
regionDimTbl.head()

Unnamed: 0,Country,Region,Continent
0,Algeria,North Africa,Africa
1,Libya,North Africa,Africa
2,Morocco,North Africa,Africa
3,Tunisia,North Africa,Africa
4,Angola,Sub-Saharan,Africa


In [9]:
#######################################
# Create functions to clean datasets for analysis ----
#######################################

def cleanRawData(dat, varType):
    """
    Reshape one raw sheet into long format.

    Steps:
    1) Drop every row whose numeric (year) columns are all missing
    2) Keep only 'Country' plus the numeric columns
    3) Melt so that each row is one Country / year pair, with the value
       stored in a column named `varType`
    4) Coerce 'year' and `varType` to numbers; anything unparseable
       becomes NaN

    Returns the long DataFrame with columns ['Country', 'year', varType].
    """
    yearCols = getNumCols(dat)

    # Rows with at least one non-missing year value
    allYearsMissing = dat[yearCols].isnull().all(axis = 1)
    keepCols = ['Country'] + yearCols.tolist()
    wide = dat[~allYearsMissing][keepCols]

    # Wide -> long
    res = wide.melt(id_vars = 'Country',
                    var_name = 'year',
                    value_name = varType)

    # Force numeric types on both derived columns
    for col in ('year', varType):
        res[col] = pd.to_numeric(res[col], errors = 'coerce')
    return(res)

In [10]:
#######################################
# Clean each dataset and create as series
#######################################
# One cleaned, long-format frame per dataset, keyed by '<name>Clean'.
# The loop index stays named `i` because the clean-up cell deletes it.
cleanDat = dict()
for i, datName in enumerate(dats):
    cleanDat[datName + 'Clean'] = cleanRawData(rawDat[datName + 'Raw'],
                                               datName)


In [11]:
#######################################
# Merge combine into one dataset
#######################################

# Collect the cleaned frames in a list (dictionary order) so they can be
# folded into one frame
dictList = list(cleanDat.values())

def mergeFunc(df1, df2):
    # Left join on the Country / year key
    return pd.merge(df1, df2, how = "left", on = ['Country', 'year'])

allDat = functools.reduce(mergeFunc, dictList)

In [12]:
# Clean up workspace: remove the intermediate objects from the namespace.
# Note that this makes the earlier cells a prerequisite for re-running any
# cell that reads rawDat, cleanDat or tmp.
del cleanDat, dictList, rawDat, i, tmp

In [13]:
#######################################
# Finalize dataset by creating addt'l variables
#######################################

# Fix units of military spending (multiply by one million), then derive
# GDP, population and GDP per capita from the ratios, and finally put the
# GDP share on a 0-100 scale
allDat = allDat.assign(milSpend = allDat['milSpend'] * 1000000)
allDat = allDat.assign(gdp = allDat['milSpend'] / allDat['milPerGDP'],
                       totalPop = allDat['milSpend'] / allDat['milPerCap'])
allDat = allDat.assign(gdpPerCap = allDat['gdp'] / allDat['totalPop'])
allDat = allDat.assign(milPerGDP = allDat['milPerGDP'] * 100)

# Add Dimension columns to data
allDat = allDat.merge(regionDimTbl, how = 'left', on = 'Country')

# Calculate growth and growth rate
def diffCalc(x):
    # Absolute change versus the lagged value
    return x['milSpend'] - x['milSpendLagged']

def rateCalc(x):
    # Relative change, rounded to two decimals and then scaled by 100
    return round(x['growthAmt'] / x['milSpendLagged'], 2) * 100

def growthCalc(x):
    # NaN whenever the current or the lagged value is missing
    missing = x['milSpend'].isnull() | x['milSpendLagged'].isnull()
    return np.where(missing, np.nan, diffCalc(x))

def growthRateCalc(x):
    # NaN whenever the lagged value or the growth amount is missing
    missing = x['milSpendLagged'].isnull() | x['growthAmt'].isnull()
    return np.where(missing, np.nan, rateCalc(x))

# Previous row's spending within each country (relies on the current row order)
allDat['milSpendLagged'] = allDat.groupby(['Country'])['milSpend'].shift()
allDat['growthAmt'] = growthCalc(allDat)
allDat['growthRate'] = growthRateCalc(allDat)

In [14]:
# To better visualize growth will create an index for 2001
def indexByYear(dat, indexYr, indexVar):
    """
    Add a column '<indexVar>Index' holding indexVar divided by the same
    country's indexVar value in year `indexYr`.

    dat      -- DataFrame with 'Country', 'year' and `indexVar` columns
    indexYr  -- year whose value is used as the denominator
    indexVar -- name of the column to index

    Countries without a row for `indexYr` are kept and get NaN in the new
    column.  If the index column already exists it is replaced, so the
    function can be re-run on its own output.  `dat` itself is not modified.

    Returns a new DataFrame.
    """
    newName = indexVar + 'Index'

    # Base-year value per country, already carrying the new column's name
    yrDat = dat.loc[dat['year'] == indexYr, ['Country', indexVar]]
    yrDat = yrDat.rename(columns = {indexVar: newName})

    # Drop a previous result so the merge below cannot create duplicates
    if newName in dat.columns:
        dat = dat.drop([newName], axis = 1)

    # A left merge keeps countries that have no observation in indexYr
    newDat = dat.merge(yrDat, on = 'Country', how = 'left')

    # Overwrite the merged base value with the ratio; no helper column needed
    newDat[newName] = newDat[indexVar] / newDat[newName]
    return(newDat)

# Adds 'milSpendIndex': each country's spending relative to its 2001 value
allDat = indexByYear(allDat, 2001, 'milSpend')

In [15]:
# First and last year present in the data; used below as slider bounds
# and as the boundary between actual and forecast values
startYr, endYr = allDat['year'].min(), allDat['year'].max()

In [16]:
#######################################
# Create functions to filter data
#######################################

# Functions to filter dataset to the top X countries
# in each year by a given variable
def topCountry(yr, limit = 10, byVar = 'milSpend'):
    """
    Return the `limit` countries with the largest `byVar` in year `yr`.

    The result is a one-column DataFrame whose column label is `yr`; its
    rows run from rank 0 (largest) downwards.
    """
    ranked = (allDat[allDat.year == yr]
              .sort_values(by = byVar, ascending = False)
              .reset_index())
    res = ranked['Country'].head(limit).to_frame()
    res.columns = [yr]
    return(res)

def topRangeCountry(startYr, endYr, limit = 10, byVar = 'milSpend'):
    """
    Collect the per-year top countries for every year from startYr to
    endYr (inclusive) into one long DataFrame with columns
    ['year', 'Country'].
    """
    perYear = []
    for yr in range(startYr, endYr + 1):
        perYear.append(topCountry(yr = yr, limit = limit, byVar = byVar))

    # Side-by-side table: one column per year, one row per rank
    firstYr, laterYrs = perYear[0], perYear[1:]
    topDat = firstYr.join(laterYrs)

    # Back to long format so it can be merged on year / Country
    return(topDat.melt(id_vars = None,
                       value_vars = topDat.columns,
                       var_name = 'year',
                       value_name = 'Country'))

def filtToTopCountries(startYr, endYr, limit = 10, byVar = 'milSpend'):
    """
    Return the rows of allDat whose (year, Country) pair is among the
    per-year top countries between startYr and endYr.
    """
    topPairs = topRangeCountry(startYr = startYr, endYr = endYr,
                               limit = limit, byVar = byVar)
    return(allDat.merge(topPairs, how = 'inner', on = ['year', 'Country']))

In [17]:
# Function to filter dataset by a year range and
# the top averages of a variable over that year range

# Function to find average value of a variable over a range
def findAggValOverRange(startYr, endYr, avgVar, groupVar):
    """
    Average `avgVar` per `groupVar` over the years startYr..endYr.

    A group is only included when its number of non-null `avgVar` values,
    divided by the number of years in the range, is greater than 0.6.

    Returns a Series of means indexed by the `groupVar` values.
    """
    yrLen = endYr - startYr + 1
    tmp = allDat[(allDat['year'] >= startYr) & (allDat['year'] <= endYr)]

    # Will delete off those groups that don't have data for at least
    # 60% of the year range (make sure data isn't skewed)
    coverage = tmp.groupby(groupVar)[avgVar].count() / yrLen
    okayGroups = coverage[coverage > .6].index.tolist()

    # The filter must use the grouping column itself; filtering on
    # 'Country' would discard everything for any other groupVar
    kept = tmp[tmp[groupVar].isin(okayGroups)]

    # Find average values
    res = kept.groupby(groupVar)[avgVar].mean()
    return(res)

# Function will take X countries with the highest average val
def findTopCtriesAggOverRange(startYr, endYr, avgVar, groupVar, limit):
    """
    Return, as a list, the `limit` group labels with the largest average
    `avgVar` over startYr..endYr, largest first.
    """
    averages = findAggValOverRange(startYr, endYr, avgVar, groupVar)
    ranked = averages.sort_values(ascending = False)
    return(ranked.head(limit).index.tolist())

# Function will filter dataset down to the top groups (countries by default)
def filtDatToTopCtriesAggOverRange(startYr, endYr, avgVar, 
                                   groupVar = 'Country', limit = 5):
    """
    Return the rows of allDat that fall in startYr..endYr and belong to
    one of the `limit` groups with the largest average `avgVar`.

    groupVar -- column the ranking is grouped by; the row filter uses the
                same column, so groupings other than 'Country' work too.
    """
    topGroups = findTopCtriesAggOverRange(startYr, endYr, avgVar,
                                          groupVar, limit)
    keep = (allDat[groupVar].isin(topGroups) &
            (allDat['year'] >= startYr) &
            (allDat['year'] <= endYr))
    return(allDat[keep])

In [18]:
#######################################
# Function to build side by side line graphs
# that have some interactivity with hovering.
# This can be used with a Jupyter interact statement
# to build dynamic graphs
#######################################
def graphPlots(yr = (1960, 1970), 
               xCol = None, 
               yCol1 = None,
               yCol2 = None, 
               colorMapCol = None,
               xlabel = None,
               ylabel1 = None,
               ylabel2 = None,
               title1 = None,
               title2 = None,
               yaxisLog = False,
               twoPlts = True,
               onePltPlt = 'plt1',
               removeCountry = 'tmp',
               filtByAgg = False,
               aggBy = None):
    """
    Draw one or two hover-enabled point-and-line plots plus a legend panel.

    yr            -- (start, end) years handed to the data filter
    xCol          -- column plotted on the x axis of every plot
    yCol1, yCol2  -- columns plotted on the y axis of plot 1 / plot 2
    colorMapCol   -- column whose distinct values each get a color and a line
    xlabel, ylabel1, ylabel2 -- axis labels
    title1, title2 -- title prefixes; ' from <start> to <end>' is appended
    yaxisLog      -- True puts the y axis of the plots on a log scale
    twoPlts       -- True draws both plots side by side with a shared x range
    onePltPlt     -- when twoPlts is not True: 'plt1' draws plot 1, any
                     other value draws plot 2
    removeCountry -- comma separated country names to leave out; the
                     default 'tmp' means "leave nothing out"
    filtByAgg     -- False: data comes from filtToTopCountries(yr[0], yr[1]);
                     otherwise from filtDatToTopCtriesAggOverRange with
                     avgVar = aggBy and limit = 10
    aggBy         -- ranking column used when filtByAgg is set

    Returns nothing; the layout is rendered with show().
    """
    # Which plot(s) to draw
    drawBoth = twoPlts == True
    plt1Plt = drawBoth or onePltPlt == 'plt1'
    plt2Plt = drawBoth or onePltPlt != 'plt1'

    #########################################
    # Get data source object
    if filtByAgg == False:
        source = filtToTopCountries(yr[0], yr[1])
    else:
        source = filtDatToTopCtriesAggOverRange(startYr = yr[0],
                                                endYr = yr[1],
                                                avgVar = aggBy,
                                                limit = 10)

    ## delete off observations where both y1 and y2 are null
    source = source[~(source[yCol1].isnull() & source[yCol2].isnull())]

    ## delete off countries that are passed in as 'remove country'
    if removeCountry != 'tmp':
        delList = removeCountry.split(',')
        # Also match the names with surrounding blanks removed, so that
        # 'Kuwait, Qatar' excludes both countries
        delList = delList + [name.strip() for name in delList]
        source = source[~source.Country.isin(delList)]
    datSource = bmodels.ColumnDataSource(source)

    #########################################
    # Tooltip objects
    tips = [
            ("Country", "@Country"),
            ("Military Spending", "@milSpend{0.00 a}"),
            ("% of GDP", "%@milPerGDP{%0.2f}"),
            ("Military Per Cap", "@milPerCap{0.00 a}"),
            ("GDP Per Cap", "@gdpPerCap{0.00 a}"),        
    ]

    #########################################
    # Figure options and Bokeh tools
    pWidth = 500 if drawBoth else 800
    plotOps = dict(width = pWidth, plot_height = 450)
    if yaxisLog == True:
        plotOps['y_axis_type'] = 'log'
    TOOLS = 'box_select box_zoom reset'.split()

    #########################################
    # Create a mapping object for color
    # The groups are computed once; their order defines both the color
    # factors and the order of the lines drawn below
    groups = list(source.groupby(colorMapCol))
    mapVals = [key for key, _ in groups]
    numVals = len(mapVals)

    # Category20 only has entries for 3 to 20 colors, so for fewer than 3
    # values the 3-color entry is truncated instead of looked up directly
    if numVals <= 20:
        pal = list(bpal.Category20[max(numVals, 3)][:numVals])
    else:
        pal = bpal.viridis(numVals)

    mapper = bmodels.CategoricalColorMapper(
        factors = mapVals,
        palette = pal
    )
    colorSpec = {'field': colorMapCol, 'transform': mapper}
    lineX = [group[xCol] for _, group in groups]

    def _linePlot(titleText, yCol, ylabel):
        # Build one points-plus-lines figure for the given y column
        plt = figure(
            title = '{} from {} to {}'.format(titleText, str(yr[0]), str(yr[1])),
            tools = [bmodels.HoverTool(tooltips = tips), *TOOLS],
            **plotOps)
        plt.xaxis.axis_label = xlabel
        plt.yaxis.axis_label = ylabel
        plt.circle(xCol, yCol, source = datSource, color = colorSpec)

        ## Add lines to points
        lineY = [group[yCol] for _, group in groups]
        plt.multi_line(xs = lineX, ys = lineY, line_color = pal)

        ## Update tickmarks
        plt.xaxis.minor_tick_line_color = None
        return plt

    ############################################
    # Military Spending Plot
    if plt1Plt:
        plt1 = _linePlot(title1, yCol1, ylabel1)

    ############################################
    # GDP Plot
    if plt2Plt:
        plt2 = _linePlot(title2, yCol2, ylabel2)
        # Share the x range when both plots are shown
        if drawBoth:
            plt2.x_range = plt1.x_range

    ############################################
    # Dummy plot to create a legend
    # Both ranges are pinned to 1000..1000 (presumably so that only the
    # legend, not the glyphs, is visible in this panel)
    p0 = figure(tools=[], logo=None, width = 150,
               x_range = bmodels.Range1d(1000, 1000),
               y_range = bmodels.Range1d(1000, 1000),
               toolbar_location = None)
    p0.circle(xCol, yCol1,source = datSource,
              color = colorSpec,
              legend = colorMapCol)
    p0.outline_line_alpha = 0
    p0.legend.location = "top_left"
    p0.renderers.visible = False

    if drawBoth:
        p = gridplot([[plt1, plt2, p0]])
    elif plt1Plt:
        p = gridplot([[plt1, p0]])
    else:
        p = gridplot([[plt2, p0]])

    show(p)

In [19]:
#######################################
# Military spending vs GDP per year
#######################################
# Only the year range and the single-plot selector are passed as live
# widgets here; every value in the dict below is wrapped in widgets.fixed
widgets.interact(
    graphPlots,
    yr = widgets.IntRangeSlider(value = [1960, 1970],
                                min = startYr,
                                max = endYr,
                                step = 1,
                                description = 'Year Range:'),
    onePltPlt = ['plt1', 'plt2'],
    **{name: widgets.fixed(val) for name, val in dict(
        xCol = 'year',
        yCol1 = 'milSpend',
        yCol2 = 'gdp',
        colorMapCol = 'Country',
        yaxisLog = True,
        xlabel = 'Year',
        ylabel1 = 'Military Spending in USD',
        ylabel2 = 'GDP in USD',
        title1 = 'Military Spending',
        title2 = 'GDP',
        filtByAgg = False,
        aggBy = 'blah').items()});

In [20]:
#######################################
# Per Capita Comparison
#######################################
# yaxisLog and removeCountry are passed as plain values (not wrapped in
# widgets.fixed), as are the year range and the plot selector
widgets.interact(
    graphPlots,
    yr = widgets.IntRangeSlider(value = [1988, endYr],
                                min = 1988,
                                max = endYr,
                                step = 1,
                                description = 'Year Range:'),
    yaxisLog = False,
    onePltPlt = ['plt1', 'plt2'],
    removeCountry = 'Kuwait',
    **{name: widgets.fixed(val) for name, val in dict(
        xCol = 'year',
        yCol1 = 'milPerCap',
        yCol2 = 'gdpPerCap',
        colorMapCol = 'Country',
        xlabel = 'Year',
        ylabel1 = 'Military Spending Per Capita in USD',
        ylabel2 = 'GDP Per Capita in USD',
        title1 = 'Military Spending Per Capita',
        title2 = 'GDP Per Capita',
        filtByAgg = False,
        aggBy = 'blah').items()});

In [21]:
#######################################
# Fastest growing countries by total amount
# and by Growth Rate
#######################################

# Single-plot view of the 2001-indexed spending; the ranking variable is
# chosen with the 'Aggregate By:' dropdown
widgets.interact(
    graphPlots,
    yr = widgets.IntRangeSlider(value = [2000, endYr],
                                min = startYr,
                                max = endYr,
                                step = 1,
                                description = 'Year Range:',
                                continuous_update = False),
    filtByAgg = True,
    aggBy = widgets.Dropdown(options = ['growthAmt', 'growthRate'],
                             value = 'growthAmt',
                             description = 'Aggregate By:'),
    onePltPlt = ['plt1', 'plt2'],
    **{name: widgets.fixed(val) for name, val in dict(
        xCol = 'year',
        yCol1 = 'milSpendIndex',
        yCol2 = 'milSpendIndex',
        colorMapCol = 'Country',
        xlabel = 'Year',
        ylabel1 = 'Military Spending Growth',
        ylabel2 = 'Military Spending Growth',
        title1 = 'Military Spending Growth Indexed to 2001',
        title2 = 'Military Spending Growth Indexed to 2001',
        twoPlts = False).items()});

In [22]:
#######################################
# Military spending by % of GDP
#######################################

def milSpendByPerGDP(yr, xcol, ycol, title,
                 xlabel, ylabel, yaxisLog = False):
    """
    Scatter plot of `ycol` against `xcol` for a single year, colored by
    Region, with a separate legend panel.

    yr             -- year of allDat to plot
    xcol, ycol     -- columns for the (log-scaled) x axis and the y axis
    title          -- title prefix; ' in <yr>' is appended
    xlabel, ylabel -- axis labels
    yaxisLog       -- True puts the y axis on a log scale as well

    Returns nothing; the layout is rendered with show().
    """
    source = allDat[allDat['year'] == yr]
    datSource = bmodels.ColumnDataSource(source)
    hover = bmodels.HoverTool(
        tooltips = [
            ("Country", "@Country"),
            ("Military Spending", "@milSpend{0.00 a}"),
            ("% of GDP", "%@milPerGDP{%0.2f}")
        ]
    )

    colMapperCol = 'Region'

    # Rows without a region are left out of the factor list, since a
    # missing value is not a usable category label
    mapVals = source[colMapperCol].dropna().unique().tolist()
    numVals = len(mapVals)

    # Category20 only has entries for 3 to 20 colors: truncate the
    # 3-color entry for smaller counts, fall back to viridis above 20
    if numVals <= 20:
        pal = list(bpal.Category20[max(numVals, 3)][:numVals])
    else:
        pal = bpal.viridis(numVals)

    mapper = bmodels.CategoricalColorMapper(
        factors = mapVals,
        palette = pal
    )
    colorSpec = {'field': colMapperCol, 'transform': mapper}

    plotOps = dict(plot_width = 900, plot_height = 500,
                   x_axis_type = 'log')

    if yaxisLog == True:
        plotOps['y_axis_type'] = 'log'

    TOOLS = 'box_select box_zoom reset'.split()

    plt1 = figure(
        title = '{} in {}'.format(title, str(yr)),
        tools = [hover, *TOOLS],
        **plotOps)    

    plt1.xaxis.axis_label = xlabel
    plt1.yaxis.axis_label = ylabel

    plt1.circle(xcol, ycol, 
             source = datSource,
             color = colorSpec,
             size = 7)

    ############################################
    # Dummy plot to create a legend
    # Both ranges are pinned to 1000..1000 (presumably so that only the
    # legend, not the glyphs, is visible in this panel)
    p0 = figure(tools=[], logo=None, width = 300,
               x_range = bmodels.Range1d(1000, 1000),
               y_range = bmodels.Range1d(1000, 1000),
               toolbar_location = None)
    p0.circle(xcol, ycol,source = datSource,
              color = colorSpec,
              legend = colMapperCol)
    p0.outline_line_alpha = 0
    p0.legend.location = "top_left"
    p0.renderers.visible = False

    p = gridplot([[plt1, p0]])

    show(p)
#     show(plt1)

In [23]:
# The year slider moves in steps of 5; all other arguments are fixed
widgets.interact(
    milSpendByPerGDP,
    yr = widgets.IntSlider(value = 1950,
                           min = startYr,
                           max = endYr,
                           step = 5,
                           description = 'Year:'),
    **{name: widgets.fixed(val) for name, val in dict(
        xcol = 'milSpend',
        ycol = 'milPerGDP',
        title = 'Military Spending in USD as a % of GDP',
        xlabel = 'Log of Military Spending in USD',
        ylabel = '% of GDP').items()});

In [24]:
#######################################
# Now after looking at the past, will look at some
# stats for the future and current environment
#######################################

# For this section, will only use those countries that are
# in the top 12 for 2016 military spending
top2016Dat = filtToTopCountries(2016, 2016, limit = 12, byVar = 'milSpend')
top2016Ctries = top2016Dat['Country'].tolist()

In [25]:
#######################################
# Clean dataset to run forecasts on it
####################################### 
# Rows for the selected countries (with year above 1000), plus a datetime
# column `ds` built from the year; only ds / Country / milSpend are kept
tsDat = allDat.loc[allDat['Country'].isin(top2016Ctries) &
                   (allDat['year'] > 1000)]
tsDat = tsDat.assign(ds = pd.to_datetime(tsDat['year'], format = '%Y'))
tsDat = tsDat.loc[:, ['ds', 'Country', 'milSpend']]

In [26]:
#######################################
# Function to run forecasts on a given country for 
# a given number of years into the future
####################################### 
def forecastData(ctry, forecastRange = 5):
    """
    Fit a Prophet model to one country's spending history in tsDat and
    predict over the history plus `forecastRange` further periods at
    frequency 'AS'.

    Returns a DataFrame with columns ds, yhat, yhat_lower and yhat_upper.
    """
    # Prophet expects the target column to be called 'y'
    history = tsDat[tsDat['Country'] == ctry].rename(columns = {'milSpend': 'y'})

    # Model with a 95% interval width and no weekly / yearly seasonality
    model = Prophet(interval_width=0.95, weekly_seasonality=False, yearly_seasonality=False)
    model.fit(history)

    # Dates to predict, then the forecasts themselves
    horizon = model.make_future_dataframe(periods=forecastRange, freq='AS')
    keepCols = ['ds', 'yhat', 'yhat_lower', 'yhat_upper']
    return(model.predict(horizon)[keepCols])

In [27]:
#######################################
# Run forecasts for each country and 
# recombine into one dataset
####################################### 
# One forecast frame per country, each tagged with its country name,
# stacked into a single frame
forecasts = []
for ctry in tsDat['Country'].unique():
    tmp = forecastData(ctry).assign(Country = ctry)
    forecasts.append(tmp)

allForecasts = pd.concat(forecasts)
del forecasts

In [28]:
#######################################
# Combine forecasts to original data and 
# create dataset for graphing
####################################### 

# Merge forcast data and original data 
# Outer-join actuals and forecasts on (ds, Country), drop the interval
# bounds and return to a flat index
tsDatIndexed = tsDat.set_index(['ds', 'Country'])
allForecastsIndexed = allForecasts.set_index(['ds', 'Country'])
viz_df = (tsDatIndexed
          .join(allForecastsIndexed, how = 'outer')
          .drop(['yhat_lower', 'yhat_upper'], axis = 1)
          .reset_index())

# Clean up merged data for plotting
viz_df = viz_df.assign(year = viz_df['ds'].dt.year)

def ifelseFunc(x):
    # Forecast value after the last observed year, actual value otherwise
    return np.where(x['year'] > endYr, x['yhat'], x['milSpend'])

viz_df = viz_df.assign(allSpend = ifelseFunc(viz_df))

In [29]:
#######################################
# Function to build line graph with vertical line
# separating actual vs forecasted data.
# This can be used with a Jupyter interact statement
# to build dynamic graphs
#######################################
def graphForecasts(xCol = None, yCol1 = None, xlabel = None,
                   ylabel1 = None, title1 = None, yaxisLog = False,
                   colorMapCol = 'Country', removeCountry = 'tmp'):
    """
    Plot `yCol1` against `xCol` from viz_df as points and lines, one color
    per `colorMapCol` value, with a dashed vertical line at endYr and a
    separate legend panel.

    xCol, yCol1     -- columns of viz_df for the x and y axis
    xlabel, ylabel1 -- axis labels
    title1          -- plot title
    yaxisLog        -- True puts the y axis on a log scale
    colorMapCol     -- column whose distinct values get a color and a line
    removeCountry   -- comma separated country names to leave out; the
                       default 'tmp' means "leave nothing out"

    Returns nothing; the layout is rendered with show().
    """
    #########################################
    # Get data source object
    source = viz_df
    ## delete off countries that are passed in as 'remove country'
    if removeCountry != 'tmp':
        delList = removeCountry.split(',')
        # Also match the names with surrounding blanks removed
        delList = delList + [name.strip() for name in delList]
        source = source[~source.Country.isin(delList)]
    datSource = bmodels.ColumnDataSource(source)

    #########################################
    # Tooltip objects
    # "Military Spending" reports the column that is actually plotted;
    # the model's own value stays available as a separate entry
    tips = [
            ("Country", "@Country"),
            ("Military Spending", "@{}{{0.00 a}}".format(yCol1)),
            ("Model Estimate", "@yhat{0.00 a}"),
    ]

    hover1 = bmodels.HoverTool(tooltips = tips)

    #########################################
    # Figure options and Bokeh tools
    pWidth = 1000
    plotOps = dict(width = pWidth, plot_height = 500)

    if yaxisLog == True:
        plotOps['y_axis_type'] = 'log'

    TOOLS = 'pan box_select box_zoom reset'.split()

    #########################################
    # Create a mapping object for color
    # The groups are computed once; their order defines both the color
    # factors and the order of the lines drawn below
    groups = list(source.groupby(colorMapCol))
    mapVals = [key for key, _ in groups]
    numVals = len(mapVals)

    # Category20 only has entries for 3 to 20 colors, so for fewer than 3
    # values the 3-color entry is truncated instead of looked up directly
    if numVals <= 20:
        pal = list(bpal.Category20[max(numVals, 3)][:numVals])
    else:
        pal = bpal.viridis(numVals)

    mapper = bmodels.CategoricalColorMapper(
        factors = mapVals,
        palette = pal
    )
    colorSpec = {'field': colorMapCol, 'transform': mapper}

    ############################################
    # Military Spending Plot
    plt1 = figure(
        title = '{}'.format(title1),
        tools = [hover1, *TOOLS],
        **plotOps)

    plt1.xaxis.axis_label = xlabel
    plt1.yaxis.axis_label = ylabel1

    plt1.circle(xCol, yCol1,
             source = datSource,
             color = colorSpec)

    ## Add lines to points
    mspLineX = [group[xCol] for _, group in groups]
    mspLineY = [group[yCol1] for _, group in groups]
    plt1.multi_line(xs = mspLineX, ys = mspLineY, line_color = pal)

    ## Add vertical line to signify start of forecasts
    forecastStart = bmodels.Span(location=endYr,
                         dimension='height', line_color='green',
                         line_dash='dashed', line_width=3)

    plt1.add_layout(forecastStart)

    ## Update tickmarks
    plt1.xaxis.minor_tick_line_color = None

    ############################################
    # Dummy plot to create a legend
    # Both ranges are pinned to 1000..1000 (presumably so that only the
    # legend, not the glyphs, is visible in this panel)
    p0 = figure(tools=[], logo=None, width = 150,
               x_range = bmodels.Range1d(1000, 1000),
               y_range = bmodels.Range1d(1000, 1000),
               toolbar_location = None)
    p0.circle(xCol, yCol1,source = datSource,
              color = colorSpec,
              legend = colorMapCol)
    p0.outline_line_alpha = 0
    p0.legend.location = "top_left"
    p0.renderers.visible = False

    p = gridplot([[plt1, p0]])

    show(p)

In [31]:
#######################################
# Military Spending forecasts
#######################################
# Every argument listed here is fixed; the remaining parameters of
# graphForecasts keep their defaults
widgets.interact(
    graphForecasts,
    **{name: widgets.fixed(val) for name, val in dict(
        xCol = 'year',
        yCol1 = 'allSpend',
        colorMapCol = 'Country',
        xlabel = 'Year',
        ylabel1 = 'Military Spending in USD',
        title1 = 'Forecasts of Military Spending').items()});