# Nice plots and data retrieval with pandas

Loads functions to:
- plot customizable time series with plotly
- define color schemes (including The Economist's palette)
- read FRED data with pandas_datareader

In [10]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas_datareader.data as web
import datetime as dt
import numpy as np
import warnings
warnings.filterwarnings('ignore')

In [11]:
# load color schemes
#%run ../../../aid/Color_Scheeme.ipynb

### Color schemes (inlined from Color_Scheeme.ipynb)

In [12]:
# Personal palette: eleven rgb() strings followed by five named CSS colours.
my = [
    'rgb(41, 58, 143)',
    'rgb(11, 102, 189)',
    'rgb(69, 144, 185)',
    'rgb(255,102,204)',
    'rgb(118,42,131)',
    'rgb(165,0,38)',
    'rgb(215,48,39)',
    'rgb(244,109,67)',
    'rgb(102,189,99)',
    'rgb(26,152,80)',
    'rgb(0,104,55)',
] + ['red', 'orange', 'magenta', 'chartreuse', 'darksalmon']

In [13]:
# Economist palette, http://pattern-library.economist.com/color.html
# Each hex code is paired with a human-readable label; only the codes are kept.
# '#C3CBF9' is labelled "my blue" (presumably a personal addition, not an
# Economist colour).
economist = [code for code, _label in (
    ('#91b8bd', 'mid green'),
    ('#9ae5de', 'bright turquoise'),
    ('#acc8d4', 'turquoise'),
    ('#C3CBF9', 'my blue'),
    ('#d4dddd', 'light gray'),
    ('#244747', 'dark green'),
    ('#336666', 'green'),
    ('#8abbd0', 'mid blue'),
    ('#efe8d1', 'beige'),
    ('#e3120b', 'red'),
    ('#4a4a4a', 'dark gray'),
)]

In [14]:
def adjust_lightness(color, amount=0.85):
    import matplotlib.colors as mc
    import colorsys
    try:
        c = mc.cnames[color]
    except:
        c = color
    c = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return 'rgb' + str(colorsys.hls_to_rgb(c[0], max(0, min(1, amount * c[1])), c[2]))

In [15]:
# Economist palette with each colour's lightness scaled by 0.85.
col_new = list(map(lambda c: adjust_lightness(c, amount=0.85), economist))
col_new

['rgb(0.4458333333333333, 0.6423106060606059, 0.6675)',
 'rgb(0.424750656167979, 0.8519160104986876, 0.8120472440944881)',
 'rgb(0.5257142857142858, 0.6857142857142855, 0.7542857142857142)',
 'rgb(0.5272727272727272, 0.59030303030303, 0.9527272727272728)',
 'rgb(0.6891341991341993, 0.7541991341991341, 0.7541991341991341)',
 'rgb(0.12000000000000002, 0.23666666666666664, 0.23666666666666666)',
 'rgb(0.17000000000000004, 0.3399999999999999, 0.33999999999999997)',
 'rgb(0.3959756097560976, 0.6489430894308944, 0.7573577235772357)',
 'rgb(0.8692473118279569, 0.812043010752688, 0.6240860215053763)',
 'rgb(0.7566666666666666, 0.059999999999999984, 0.036666666666666625)',
 'rgb(0.24666666666666667, 0.24666666666666667, 0.24666666666666667)']

In [16]:
# Composite palette assembled from the palettes above.
col_scheeme = (
    my[:5]            # first five entries of `my`
    + col_new[:5]     # lightness-scaled economist colours 0-4
    + economist[5:7]  # economist colours 5-6, unmodified
    + col_new[7:9]    # lightness-scaled economist colours 7-8
    + my[5:]          # remaining entries of `my`
)
# NOTE(review): the original expression also included `economist[9:9]`, which is
# always empty, so economist colours 9-10 are never used. Confirm whether
# `economist[9:]` was intended.

In [17]:
# color scheme: plotly express's qualitative "Set2" palette.
# Not referenced by the functions defined in this notebook section.
color = px.colors.qualitative.Set2

### Get and organize data from FRED

In [18]:
# Earliest date requested from FRED; series that begin later just start later.
start = dt.datetime(year=1960, month=1, day=1)
# end = dt.datetime(2020, 5, 1)  # default end is today

In [19]:
def print_fred(key, value, start=start):
    """Download FRED series ``key`` and tag every row with the label ``value``.

    Despite its name, nothing is printed. The single data column is renamed to
    ``'value'`` and a ``'type'`` column holding ``value`` is appended. This lets
    frames for several series be stacked with ``pd.concat(..., axis=0)`` and
    still be told apart.
    """
    frame = web.DataReader(key, "fred", start).set_axis(['value'], axis=1)
    frame['type'] = value  # constant label identifying this series after stacking
    return frame

In [20]:
def make_hrizontal_tbl(dic, start=start):
    '''Download every FRED series in ``dic`` and return them side by side.

    Parameters
    ----------
    dic : dict
        Keys are FRED series IDs; the values (display names) are not used here.
    start : datetime-like
        First date requested from FRED (defaults to the module-level ``start``).

    Returns
    -------
    pandas.DataFrame
        Wide table of the raw values: the frames returned by ``DataReader`` are
        concatenated column-wise with an outer join on the date index, so dates
        missing from some series appear as NaN.
    '''
    frames = [web.DataReader(key, "fred", start) for key in dic]
    return pd.concat(frames, join='outer', axis=1)  # horizontal, as go.Figure expects

In [21]:
def make_pct_tbl(dic, start=start):
    '''Fetch the FRED series in ``dic`` and return each one's share of the per-date total.

    The series are joined on date (outer join) and divided row-wise by their
    row sum, so on every date the shares of the series sum to 1. The result is
    melted to long format with columns ``DATE``, ``type`` and ``value``, where
    ``type`` holds the display name from ``dic`` and is ready for plotly express.
    '''
    wide = pd.concat([web.DataReader(series_id, "fred", start) for series_id in dic],
                     join='outer', axis=1)
    shares = wide.div(wide.sum(axis=1), axis=0)
    long_df = shares.reset_index().melt(id_vars=['DATE'], var_name='type')
    long_df['type'] = long_df['type'].map(dic)  # series ID -> display name
    return long_df

In [22]:
def make_long_tbl(dic, start=start):
    '''Fetch the FRED series in ``dic`` and stack them into one long DataFrame.

    Parameters
    ----------
    dic : dict
        Maps FRED series IDs to display names; each name is stored in the
        ``type`` column of that series' rows.
    start : datetime-like
        First date requested from FRED. Previously this argument was ignored
        and the module-level default was always used.

    Returns
    -------
    pandas.DataFrame
        Columns ``DATE`` (from the reset index), ``value`` and ``type``,
        ready for plotly express.
    '''
    frames = [print_fred(key, value, start) for key, value in dic.items()]
    # vertical stacking; use pd.concat(..., axis=1) instead for a wide table (go.Figure)
    return pd.concat(frames, axis=0).reset_index()

In [23]:
def line_plot(title, dic, start=start):
    """Show one line per FRED series in ``dic`` (plotly express, 600x400).

    Data come from ``make_long_tbl(dic, start)``. Lines are coloured by the
    ``type`` column (display names), both axis titles are blanked, and a grey
    "Source: US Census Bureau" note is placed under the plot. The figure is
    displayed with ``fig.show()``; nothing is returned.
    """
    long_df = make_long_tbl(dic, start)

    source_note = dict(xref='paper', yref='paper', x=0, y=-0.1,
                       xanchor='left', yanchor='top',
                       text='Source: US Census Bureau',
                       font=dict(family='Arial', size=12, color='rgb(150,150,150)'),
                       showarrow=False)

    fig = px.line(long_df, x="DATE", y="value", color='type', title=title,
                  template="plotly_white",
                  color_discrete_sequence=["red", '#C3CBF9', "lightgray"],
                  labels={'value': '', 'DATE': ''},
                  height=400, width=600)
    fig.update_layout(annotations=[source_note])
    fig.show()

In [24]:
def line_plot_PCT(title, dic, start=start):
    """Show each FRED series' share of the per-date total as lines.

    Data come from ``make_pct_tbl(dic, start)`` (long format with ``DATE``,
    ``type`` and ``value`` columns). Fixes in this version:

    * ``y`` now names the ``value`` column. It used to be ``"key"``, a column
      that does not exist, so plotly raised.
    * ``labels`` is now a merged dict. It used to be ``{...}.update(dic)``,
      which evaluates to ``None``.

    A grey "Source: US Census Bureau" note is placed under the plot, and the
    figure is displayed with ``fig.show()``; nothing is returned.
    """
    df = make_pct_tbl(dic, start)
    fig = px.line(df, x="DATE", y="value", color='type', title=title,
                  template="plotly_white",
                  color_discrete_sequence=["lightgray", '#C3CBF9', "red"],
                  # blank axis titles; entries from dic take precedence, as intended
                  labels={'value': '', 'DATE': '', **dic},
                  height=400, width=600)

    fig.update_layout(annotations=[dict(
        xref='paper', yref='paper', x=0, y=-0.1,
        xanchor='left', yanchor='top',
        text='Source: US Census Bureau',
        font=dict(family='Arial', size=12, color='rgb(150,150,150)'),
        showarrow=False)])

    fig.show()

In [25]:
def area_plot(title, dic, start=start):
    """Show the FRED series in ``dic`` as stacked shares of their per-date total.

    Data come from ``make_pct_tbl(dic, start)``. Areas are grouped and
    coloured by the ``type`` column, axis titles are blanked, and a grey
    "Source: US Census Bureau" note is placed under the plot. The figure is
    displayed with ``fig.show()``; nothing is returned.
    """
    shares = make_pct_tbl(dic, start)

    source_note = dict(xref='paper', yref='paper', x=0, y=-0.1,
                       xanchor='left', yanchor='top',
                       text='Source: US Census Bureau',
                       font=dict(family='Arial', size=12, color='rgb(150,150,150)'),
                       showarrow=False)

    fig = px.area(shares, x="DATE", y="value", line_group="type", color='type',
                  title=title,
                  template="plotly_white",
                  color_discrete_sequence=["lightgray", '#C3CBF9', "red"],
                  labels={'value': '', 'DATE': ''},
                  height=400, width=600)
    fig.update_layout(annotations=[source_note])
    fig.show()

In [26]:
def bar_plot(title, dic, start=start):
    """Show the FRED series in ``dic`` as bars per date, coloured by series.

    Data come from ``make_long_tbl(dic, start)`` (all series are kept). Axis
    titles are blanked and a grey "Source: US Census Bureau" note is placed
    under the plot. The figure is displayed with ``fig.show()``; nothing is
    returned.
    """
    long_df = make_long_tbl(dic, start)

    source_note = dict(xref='paper', yref='paper', x=0, y=-0.1,
                       xanchor='left', yanchor='top',
                       text='Source: US Census Bureau',
                       font=dict(family='Arial', size=12, color='rgb(150,150,150)'),
                       showarrow=False)

    fig = px.bar(long_df, x="DATE", y="value", color='type', title=title,
                 template="plotly_white",
                 color_discrete_sequence=["lightgray", '#C3CBF9', "red"],
                 labels={'value': '', 'DATE': ''},
                 height=400, width=600)
    fig.update_layout(annotations=[source_note])
    fig.show()

In [27]:
def dict_from_FRED_description(s):
    '''Map FRED descriptions like 'blabla [x89y87t]' to {'x89y87t': 'blabla'}.

    The resulting dict is passed to the plot functions so legends show readable
    names instead of series IDs.

    Parameters
    ----------
    s : str or list of str
        One or more descriptions of the form '<name> [<series id>]'.

    Returns
    -------
    dict
        Series ID -> name. The phrase 'for the United States' is removed and
        leftover surrounding whitespace is stripped (previously a trailing
        space remained). Strings that do not match the pattern yield NaN
        entries.
    '''
    s = pd.Series(s)
    # raw strings: '\s' and '\[' are regex escapes, not Python string escapes
    names = (s.str.extract(r'(.*)\s\[')[0]
             .str.replace('for the United States', '', regex=False)
             .str.strip())
    names.index = s.str.extract(r'\[(.*)\]')[0]
    return dict(names)

In [28]:
def text_source(source='US Census Bureau', fig=None):
    '''Add a grey "Source: <source>" note below a plotly figure.

    Parameters
    ----------
    source : str
        Text appended to 'Source: '.
    fig : figure, optional
        Figure to annotate. When omitted, falls back to a global variable
        named ``fig``; the original version could only use that global and
        raised NameError when it did not exist.

    Returns
    -------
    Whatever ``fig.update_layout`` returns (plotly returns the figure itself).

    Raises
    ------
    NameError
        If no figure is passed and no global ``fig`` exists.
    '''
    if fig is None:
        # backward compatibility with callers relying on a global `fig`
        try:
            fig = globals()['fig']
        except KeyError:
            raise NameError("text_source() needs a figure: pass fig=<figure> "
                            "or define a global named 'fig'") from None

    annotation = dict(xref='paper', yref='paper', x=0, y=-0.1,
                      xanchor='left', yanchor='top',
                      text='Source: ' + source,
                      font=dict(family='Arial', size=12, color='rgb(150,150,150)'),
                      showarrow=False)
    return fig.update_layout(annotations=[annotation])

### Plot function

This function labels time-series lines with annotations, adjusts the position of the last value,
and shows toggle buttons for NBER recessions.

In [29]:
# Empty module-level placeholders so that plot_nice's `dic=dic, df=df`
# defaults can be evaluated when the function is defined.
dic = {}
df = pd.DataFrame()

In [30]:
# Manual label-position adjustments: extra vertical offsets (y-axis units, added
# on top of `vertical_label_gutter`) for the right-hand series labels of
# specific FRED series IDs. Series not listed get no extra offset.
_PLOT_NICE_LABEL_OFFSETS = {
    'NYXRCNSA': -5, 'DAXRNSA': -5,
    'SFXRCNSA': 5, 'SFXRNSA': 5, 'LVXRLTNSA': 5, 'SFXRLTNSA': 5, 'SEXRLTNSA': 5,
    'SEXRNSA': -3.5,
    'LXXRHTNSA': 7, 'NYXRHTNSA': 7,
}


def _format_endpoint(value):
    """Format an endpoint value as a whole number, using a 'k' suffix from 1000 up."""
    if value < 1000:
        return '{:.0f}'.format(value)
    return '{:.0f}k'.format(value / 1000)


def plot_nice(title='', dic=dic, df=df, margin=dict(autoexpand=False, l=50, r=80, t=90, b=70),
              vertical_label_gutter=60, show_series_label=True, log=False,
              show_endpoints=True,
              source='U.S. Census Bureau', height=600, width=1020, y_units='',
              colors=col_scheeme, line_wid=1.5, recessions=False,
              function=False, fsize=20, fx=.3, fy=.8, label_offsets=None):
    """Plot the columns of ``df`` listed in ``dic`` as directly-labelled lines.

    Each series is drawn as a line, with its name written at the right edge.
    Its first and last values are printed at the left and right edges. A
    title, a source note, a y-axis unit label and optional free text are
    added as annotations.

    Parameters
    ----------
    title : str
        Text placed above the plot.
    dic : dict
        Maps column names of ``df`` (e.g. FRED series IDs) to display labels.
    df : pandas.DataFrame
        Table indexed by the x values (dates); must contain every key of ``dic``.
    margin : dict
        Passed to ``fig.update_layout(margin=...)``.
    vertical_label_gutter : float
        Amount (y-axis units) added to each series' last value to place its name.
    show_series_label : bool
        Whether to write each series' name at the right edge.
    log : bool
        Use a logarithmic y-axis.
    show_endpoints : bool
        Whether to print first/last values (their annotations get empty text otherwise).
    source : str
        Text appended to 'Source: ' below the plot.
    height, width : int
        Figure size passed to ``fig.update_layout``.
    y_units : str
        Text placed at the top left of the plot.
    colors : sequence
        Line colours, assigned in ``dic`` order and reused cyclically when
        there are more series than colours.
    line_wid : float
        Line width.
    recessions : bool
        Fetch FRED's ``USRECQM`` series and add "-"/"+" buttons that
        hide/show shaded recession periods. This branch also calls ``fig.show()``.
    function : str or False
        Optional text drawn over the plot at paper coordinates (``fx``, ``fy``)
        with font size ``fsize``.
    label_offsets : dict, optional
        Per-key extra vertical offset for the series names; defaults to
        ``_PLOT_NICE_LABEL_OFFSETS``. Keys that are absent get no offset.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    if label_offsets is None:
        label_offsets = _PLOT_NICE_LABEL_OFFSETS

    # one colour per series, in dic order (cycled if colors is shorter than dic)
    dic_colors = {key: colors[i % len(colors)] for i, key in enumerate(dic)}

    # all series share the DataFrame index as x values
    x_data = df.index

    fig = go.Figure()

    for key, value in dic.items():
        series = df[key]
        fig.add_trace(go.Scatter(x=x_data, y=series, mode='lines',
                                 name=value,
                                 line=dict(color=dic_colors[key], width=line_wid),
                                 connectgaps=True))

        # endpoints: tiny markers on the first and last observation
        fig.add_trace(go.Scatter(x=[x_data[0], x_data[-1]],
                                 y=[series.iloc[0], series.iloc[-1]],
                                 mode='markers',
                                 marker=dict(color=dic_colors[key], size=1)))

    fig.update_layout(
        xaxis=dict(showline=True, showgrid=True, showticklabels=True,
                   linecolor='rgb(204, 204, 204)', linewidth=2, ticks='outside',
                   tickfont=dict(family='Arial', size=12, color='rgb(82, 82, 82)')),
        yaxis=dict(showgrid=True, zeroline=False, showline=True, showticklabels=True,
                   gridcolor='rgb(204, 204, 204)'),
        autosize=False,
        margin=margin,
        showlegend=False,
        plot_bgcolor='white'
    )

    annotations = []

    for key in dic:
        # positional access (.iloc): plain series[-1] is deprecated for
        # positional use and fails outright on integer indexes
        first = df[key].iloc[0]
        last = df[key].iloc[-1]

        # series name at the right edge, nudged by the per-key offset
        if show_series_label:
            annotations.append(dict(xref='paper', x=0.95,
                                    y=last + vertical_label_gutter + label_offsets.get(key, 0),
                                    xanchor='left', yanchor='middle',
                                    text=dic[key],
                                    font=dict(family='Arial', size=15, color=dic_colors[key]),
                                    showarrow=False))

        # first value at the left edge of the plot
        annotations.append(dict(xref='paper', x=0.05, y=first,
                                xanchor='right', yanchor='top',
                                text=_format_endpoint(first) if show_endpoints else '',
                                font=dict(family='Arial', size=12),
                                showarrow=False))

        # last value at the right edge of the plot
        annotations.append(dict(xref='paper', x=0.95, y=last,
                                xanchor='left', yanchor='top',
                                text=_format_endpoint(last) if show_endpoints else '',
                                font=dict(family='Arial', size=12),
                                showarrow=False))

    # Title
    annotations.append(dict(xref='paper', yref='paper', x=0.0, y=1.05,
                            xanchor='left', yanchor='bottom',
                            text=title,
                            font=dict(family='Arial', size=22, color='rgb(37,37,37)'),
                            showarrow=False))
    # Source (y=-0.20 is the position used for short graphs)
    annotations.append(dict(xref='paper', yref='paper', x=0, y=-0.20,
                            xanchor='left', yanchor='top',
                            text='Source: ' + source,
                            font=dict(family='Arial', size=12, color='rgb(150,150,150)'),
                            showarrow=False))

    # units on y-axis
    annotations.append(dict(xref="paper", yref="paper", x=0.0, y=1.06,
                            font=dict(family='Arial', size=12, color='black'),
                            showarrow=False, text=y_units))

    # function or text over graph
    if function:
        annotations.append(dict(xref="paper", yref="paper", x=fx, y=fy,
                                font=dict(family='Arial', size=fsize, color='black'),
                                showarrow=False, text=function))

    fig.update_layout(annotations=annotations, height=height, width=width)

    if log:
        fig.update_layout(yaxis=dict(type='log'))

    if recessions:
        # NBER recession indicator from FRED (rows equal to 1 are treated as
        # recession periods), starting at the plot's first date
        rec_dic = dict_from_FRED_description(['NBER Recessions [USRECQM]'])
        rec = make_hrizontal_tbl(rec_dic, start=min(df.index))

        # a recession starts where the flag turns 0->1 and ends at its last 1
        flag = rec.USRECQM
        in_rec = flag == 1
        starts = rec[in_rec & (flag.shift() != 1)].index.to_list()
        ends = rec[in_rec & (flag.shift(-1) != 1)].index.to_list()

        shapes = [dict(type="rect",
                       xref="x",      # x in data (date) coordinates
                       yref="paper",  # y in plot-paper coordinates [0, 1]
                       x0=x0, y0=0, x1=x1, y1=0.97,
                       fillcolor="lightgray",
                       line=dict(color="white", width=.1),
                       layer="below",
                       line_width=0)
                  for x0, x1 in zip(starts, ends)]

        # "-" clears the shaded regions, "+" draws them
        fig.update_layout(updatemenus=[dict(
            type="buttons",
            direction="right",
            active=0,
            x=1.1,  # position for short graphs
            y=1.2,
            buttons=[
                dict(label="-", method="relayout", args=["shapes", []]),
                dict(label="+", method="relayout", args=["shapes", shapes]),
            ],
        )])

        # NOTE(review): fig.show() only runs when recessions=True (unchanged);
        # otherwise the caller displays the returned figure.
        fig.show()

    return fig