In [46]:
import pandas as pd
import numpy as np
from bokeh.layouts import gridplot, layout, widgetbox, row
from bokeh.plotting import figure, show, ColumnDataSource, output_server, output_notebook
from bokeh.models import NumeralTickFormatter, LabelSet, Label, LinearAxis, Range1d, HoverTool, VBox
from bokeh.models.widgets import MultiSelect, CheckboxGroup
from bokeh.io import show, curdoc
from bokeh.resources import CDN
from bokeh.embed import file_html, autoload_server, components
from bokeh.client import push_session

# Load the input tables from paths relative to the working directory.
# `dataset` is filtered by its 'state' column in the plotting functions below;
# `us_avg_price` and `us_avg_sales` supply the 'year'/'mean' columns drawn as
# the "national average" lines.
dataset = pd.read_csv('data/dataset.csv')
us_avg_price = pd.read_csv('data/us_avg_price.csv')
us_avg_sales = pd.read_csv('data/us_avg_sales.csv')


def plot_state_by_year():
    """Build the two-axis line chart of minimum price and sales per year.

    Reads the module-level widgets instead of taking arguments:
    ``multi_select.value`` lists the states to draw, and index 0 in
    ``checkbox_group.active`` switches the national-average overlays on.
    Prices are drawn against the default y range; sales are drawn against
    the extra "sales" range, which gets its own axis on the right.

    Returns:
        The configured bokeh figure.
    """
    show_average = 0 in checkbox_group.active
    selected_states = multi_select.value
    selected_rows = dataset[dataset['state'].isin(selected_states)]

    # Setting the params for the first figure (created with no tools).
    fig = figure(plot_width=800, plot_height=500, tools=[], y_range=[1.00, 3.00],
                 title="Cigarette Prices in the US from 1963 to 1992")

    # Setting the second y axis range name and range
    fig.extra_y_ranges = {"sales": Range1d(start=40, end=180)}

    # Styling shared by every per-state line of each kind.
    state_price_style = dict(color='orange', legend='min price by state',
                             line_width=0.5, line_alpha=0.5)
    state_sales_style = dict(color='gray', legend='sales by state',
                             line_width=0.5, line_alpha=0.75,
                             y_range_name="sales")

    # Glyphs are added in a fixed order: average price, state prices,
    # average sales, state sales.
    if show_average:
        fig.line(us_avg_price['year'], us_avg_price['mean'],
                 color='orange', line_width=4, line_dash="dashed",
                 legend="min price national average")

    for state in selected_states:
        state_rows = selected_rows[selected_rows['state'] == state]
        fig.line(state_rows['year'], state_rows['adjusted_min_price'],
                 **state_price_style)

    if show_average:
        fig.line(us_avg_sales['year'], us_avg_sales['mean'],
                 color='gray', line_width=4, y_range_name="sales",
                 line_dash="dashed", legend="sales national average")

    for state in selected_states:
        state_rows = selected_rows[selected_rows['state'] == state]
        fig.line(state_rows['year'], state_rows['sales'], **state_sales_style)

    # The price label is assigned before the right-hand axis is added; keep
    # this ordering so the statements act on the same axes as before.
    fig.yaxis.axis_label = 'minimum price of cigarette pack (adjusted to 2016 USD)'
    fig.add_layout(
        LinearAxis(y_range_name="sales",
                   axis_label='cigarette sales in packs per capita'),
        'right')
    fig.xaxis.axis_label = 'year'

    # Cosmetics: muted grid and axes, dollar formatting on the first y axis.
    fig.legend.location = "bottom_right"
    fig.grid.grid_line_alpha = 0.3
    fig.axis.axis_line_color = 'lightgray'
    fig.axis.minor_tick_line_color = 'lightgray'
    fig.axis.major_tick_line_color = 'lightgray'
    fig.yaxis[0].formatter = NumeralTickFormatter(format="$0.00")
    fig.axis.major_label_text_color = 'gray'

    return fig

def make_scatterplot():
    """Build the sales-vs-price scatter plot for the selected states.

    Reads the module-level widgets instead of taking arguments:
    ``multi_select.value`` lists the states to include, and index 1 in
    ``checkbox_group.active`` requests the least-squares regression line.
    Rows of ``dataset`` containing any missing value are dropped first.

    The regression line is skipped when fewer than two rows remain (for
    example when no state is selected), because ``np.polyfit`` raises on
    empty input and a single point does not determine a line.

    Returns:
        The configured bokeh figure.
    """
    hover = HoverTool(
            tooltips=[
                ('state', '@state'),
                ('sales', '@sales'),
                ('year', '@year'),
                ('min price', '@adjusted_min_price')
            ]
        )
    tools = [hover]
    state_list = multi_select.value

    df = dataset[dataset['state'].isin(state_list)]
    df = df.dropna()

    source = ColumnDataSource(df)

    sc = figure(title="Cigarette Sales vs. Price of Cigarette in the US from 1963 to 1992",
                width=800, height=600, tools=tools)

    sc.xaxis.axis_label = 'minimum price of cigarette pack (adjusted to 2016 USD)'
    sc.yaxis.axis_label = 'cigarette sales in packs per capita'

    sc.circle('adjusted_min_price', 'sales',
              size=6, fill_alpha=0.5,
              line_alpha=0, color='#1f78b4',
              legend='average consumption states',
              source=source)

    # Fit only when the line is requested and there is enough data; this keeps
    # the widget callback from failing when the selection is emptied.
    if 1 in checkbox_group.active and len(df) >= 2:
        slope, intercept = np.polyfit(df['adjusted_min_price'], df['sales'], 1)
        # Evaluate the fitted line over a fixed price span of 1.20 to 3.00.
        x = np.linspace(1.20, 3.00, 10)
        y = intercept + slope * x
        sc.line(x, y, line_color="lightgray", line_width=4,
                legend='line of best fit average consumption states')

    sc.grid.grid_line_alpha = 0
    sc.axis.axis_line_color = 'lightgray'
    sc.axis.minor_tick_line_color = 'lightgray'
    sc.axis.major_tick_line_color = 'lightgray'
    sc.axis.major_label_text_color = 'gray'

    sc.xaxis[0].formatter = NumeralTickFormatter(format="$0.00")

    return sc

def update_plots(attr, old, new):
    """Widget change callback: rebuild both figures and swap them into ``b``.

    ``attr``, ``old`` and ``new`` are accepted to match the ``on_change``
    callback signature but are not used; the figure builders read the
    widgets directly.
    """
    rebuilt_figures = [plot_state_by_year(), make_scatterplot()]
    b.children = rebuilt_figures

def update_plot_1(attr, old, new):
    """Change callback that rebuilds only the by-year figure.

    Replaces the first child of ``b`` and leaves the others alone. The
    ``attr``/``old``/``new`` arguments are unused. No widget in the visible
    code registers this callback.
    """
    refreshed_figure = plot_state_by_year()
    b.children[0] = refreshed_figure


# State codes offered in the "States:" multi-select, in alphabetical order.
states = [
    'AL', 'AR', 'AZ', 'CO', 'CT', 'DC', 'DE',
    'GA', 'HI', 'IA', 'ID', 'IL', 'IN', 'KS',
    'KY', 'LA', 'MA', 'MD', 'ME', 'MI', 'MN',
    'MO', 'MS', 'MT', 'NC', 'NE', 'NH', 'NM',
    'NV', 'NY', 'OH', 'OK', 'OR', 'PA', 'RI',
    'SC', 'SD', 'TN', 'TX', 'UT', 'VA', 'VT',
]

# State picker; starts with only 'VT' selected.
multi_select = MultiSelect(title="States:", value=['VT'],
                           options=states)

# Index 0 toggles the national-average lines in plot_state_by_year(); index 1
# toggles the regression line in make_scatterplot(). Both start enabled.
checkbox_group = CheckboxGroup(
        labels=["show US average", "show regression line"], active=[0, 1])

# Any change to either widget rebuilds both figures.
multi_select.on_change('value', update_plots)

checkbox_group.on_change('active', update_plots)

controls = widgetbox(multi_select, checkbox_group)

# `b` holds the two figures; update_plots() replaces its children in place.
# Both builders read the widgets above, so the widgets must exist first.
b = VBox(plot_state_by_year(), make_scatterplot())
# Final layout: controls on the left, figures on the right. Later cells
# reference this object by the name `l`.
l = row(controls, b)

# NOTE(review): add_root is commented out, so nothing in this cell attaches
# `l` to curdoc() before the session is pushed - confirm this is intended.
#curdoc().add_root(l)
session = push_session(curdoc())
script = autoload_server(l, session_id=session.id)


In [None]:
# Standalone embedding: split the layout into the script and div fragments
# returned by components(). This rebinds `script` from the earlier cell.
script, div = components(l)
# The original cell ended with
#     return render_template('graph.html', script=script, div=div)
# A bare `return` outside a function is a SyntaxError, and `render_template`
# is never imported in this notebook. That line presumably belongs inside a
# web-framework view function (TODO confirm); pass `script` and `div` to the
# template from there.

In [41]:
# Render the layout as one HTML document string using CDN-hosted resources.
# NOTE(review): the third positional argument looks like a filename, but
# file_html presumably treats it as the page title - confirm against the
# installed bokeh version.
html = file_html(l, CDN, "cig.html")

In [42]:
# Write the rendered page to disk. The context manager closes the file even
# if the write raises, which the manual open/write/close sequence did not.
with open("cig.html", "w") as text_file:
    text_file.write(html)

In [2]:
# Configure bokeh to render show() output inline in the notebook.
output_notebook()

In [3]:
def make_scatterplot(state_list=None):
    """Build a sales-vs-price scatter plot with a regression line.

    NOTE: this redefines ``make_scatterplot`` from the earlier cell. Unlike
    that version it does not read ``multi_select`` or ``checkbox_group``, so
    once this cell has run, any callback that calls ``make_scatterplot()``
    gets this widget-independent figure.

    Args:
        state_list: State codes to include. Defaults to ``['AL']``, the value
            that was previously hard-coded in this cell.

    Returns:
        The configured bokeh figure. The regression line is omitted when
        fewer than two complete rows match, because ``np.polyfit`` raises on
        empty input and a single point does not determine a line.
    """
    if state_list is None:
        state_list = ['AL']

    hover = HoverTool(
            tooltips=[
                ('state', '@state'),
                ('sales', '@sales'),
                ('year', '@year'),
                ('min price', '@adjusted_min_price')
            ]
        )
    tools = [hover]

    # Keep only the requested states and drop rows with any missing value.
    df = dataset[dataset['state'].isin(state_list)]
    df = df.dropna()

    source = ColumnDataSource(df)

    sc = figure(title="Cigarette Sales vs. Price of Cigarette in the US from 1963 to 1992",
                width=800, height=600, tools=tools)

    sc.xaxis.axis_label = 'minimum price of cigarette pack (adjusted to 2016 USD)'
    sc.yaxis.axis_label = 'cigarette sales in packs per capita'

    sc.circle('adjusted_min_price', 'sales',
              size=6, fill_alpha=0.5,
              line_alpha=0, color='#1f78b4',
              source=source,
              legend='average consumption states')

    if len(df) >= 2:
        slope, intercept = np.polyfit(df['adjusted_min_price'], df['sales'], 1)
        # Evaluate the fitted line over a fixed price span of 1.20 to 3.00.
        x = np.linspace(1.20, 3.00, 10)
        y = intercept + slope * x
        sc.line(x, y, line_color="lightgray", line_width=4,
                legend='line of best fit average consumption states')

    sc.grid.grid_line_alpha = 0
    sc.axis.axis_line_color = 'lightgray'
    sc.axis.minor_tick_line_color = 'lightgray'
    sc.axis.major_tick_line_color = 'lightgray'
    sc.axis.major_label_text_color = 'gray'

    sc.xaxis[0].formatter = NumeralTickFormatter(format="$0.00")

    return sc

In [4]:
# Render the scatter plot built by the make_scatterplot defined in the
# preceding cell.
show(make_scatterplot())

In [23]:
# Preview the first rows only instead of displaying the whole frame.
dataset.head(10)

Unnamed: 0.1,Unnamed: 0,observation,fips,year,price,pop,pop16,cpi,ndi,sales,pimin,state,adjusted_price,adjusted_min_price,adjusted_ndi,ndi_by_100
0,0,1,1,1963,28.6,3383.0,2236.5,30.6,1558.304530,93.9,26.1,AL,2.245343,2.049072,12234.014608,122.340146
1,1,2,1,1964,29.8,3431.0,2276.7,31.0,1684.073202,95.4,27.5,AL,2.309365,2.131126,13050.806770,130.508068
2,2,3,1,1965,29.8,3486.0,2327.5,31.5,1809.841875,98.5,28.9,AL,2.272709,2.204070,13802.830880,138.028309
3,3,4,1,1966,31.5,3524.0,2369.7,32.4,1915.160357,96.4,29.5,AL,2.335628,2.187334,14200.322950,142.003229
4,4,5,1,1967,31.6,3533.0,2393.7,33.4,2023.546368,95.5,29.6,AL,2.272891,2.129038,14554.751054,145.547511
5,5,6,1,1968,35.6,3522.0,2405.2,34.8,2202.485536,88.4,32.0,AL,2.457587,2.209067,15204.491818,152.044918
6,6,7,1,1969,36.6,3531.0,2411.9,36.7,2377.334666,90.1,32.8,AL,2.395814,2.147068,15561.890216,155.618902
7,7,8,1,1970,39.6,3444.0,2394.6,38.8,2591.039159,89.8,34.3,AL,2.451893,2.123736,16042.806274,160.428063
8,8,9,1,1971,42.7,3481.0,2443.5,40.5,2785.315971,95.4,35.8,AL,2.532859,2.123568,16521.806605,165.218066
9,9,10,1,1972,42.3,3511.0,2484.7,41.8,3034.808297,101.1,37.4,AL,2.431096,2.149480,17441.870957,174.418710


In [37]:
# Re-creates `checkbox_group` with the same labels and defaults as the earlier
# cell. This rebinding replaces the module-level name that the plot builders
# read, and no on_change callback is attached to the new widget in this cell.
checkbox_group = CheckboxGroup(
        labels=["show US average", "show regression line"], active=[0, 1])


In [39]:
# Inspect which checkbox indices are currently active.
checkbox_group.active

[0, 1]