In [1]:
# import random
# import numpy as np
import pandas as pd
import dash_core_components as dcc
import dash_bootstrap_components as dbc
import dash_html_components as html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
from sklearn.model_selection import train_test_split
from pymongo import MongoClient
from xgboost import XGBRegressor
from jupyter_dash import JupyterDash
import plotly.graph_objects as go
import plotly.express as px

In [2]:
# MongoDB handle: the query helpers and callbacks in this notebook read
# their data through `db`.
url = 'mongodb://localhost:27017/'
client = MongoClient(url)
db = client.pums18

# Dash application using the Bootstrap stylesheet; `server` keeps a
# module-level reference to the app's `server` attribute.
external_stylesheets = [dbc.themes.BOOTSTRAP]
app = JupyterDash(__name__, external_stylesheets=external_stylesheets)
server = app.server

In [3]:
def get_sample_df():
    """Load every document of the ``smp`` collection into a DataFrame.

    Returns:
        pandas.DataFrame: one row per document, with the ``_id`` column
        removed. An empty collection yields an empty frame instead of
        raising.
    """
    records = list(db.smp.find({}))
    # errors="ignore": a frame built from zero documents has no "_id"
    # column, and a plain drop() would raise KeyError in that case.
    return pd.DataFrame(records).drop(columns=["_id"], errors="ignore")

In [4]:
def get_model_df(puma):
    """Return the ``lab`` documents whose LOCATION equals ``puma``.

    The result is a DataFrame with one row per matching document; the
    ``_id`` column is kept.
    """
    matching_docs = db.lab.find({"LOCATION": puma})
    rows = list(matching_docs)
    return pd.DataFrame(rows)

In [5]:
def get_options(input_column):
    """Turn an iterable of values into a list of option dicts.

    Each value becomes ``{"label": value, "value": value}``, preserving
    the input order.
    """
    options = []
    for entry in input_column:
        options.append({"label": entry, "value": entry})
    return options

In [6]:
def get_states():
    """Return option dicts for the distinct STATE values of the ``loc`` collection."""
    return get_options(db.loc.distinct("STATE"))

In [7]:
def get_sectors():
    """Return option dicts for the distinct SECTOR values of the ``ind`` collection."""
    return get_options(db.ind.distinct("SECTOR"))

In [8]:
def get_fields():
    """Return option dicts for the distinct FIELD values of the ``occ`` collection."""
    return get_options(db.occ.distinct("FIELD"))

In [9]:
def get_schooling():
    """Return option dicts for the distinct SCHOOLING values of the ``edu`` collection."""
    return get_options(db.edu.distinct("SCHOOLING"))

In [10]:
def get_columns():
    """Return option dicts for the fixed list of column names that can be plotted."""
    return get_options(["SALARY", "HOURS", "AGE", "FIELD", "SECTOR"])

In [11]:
# Sample data loaded once when this cell runs.
# NOTE(review): nothing visible in this notebook reads this module-level `df`
# (get_graph calls get_sample_df() again) — confirm whether it is still needed.
df = get_sample_df()

In [12]:
# Page layout: three tabs (Viz, Pred, Shap).
# The get_*() option helpers are called once, when this cell runs, so the
# options they supply are fixed at layout-construction time.
app.layout = html.Div([
    dbc.Tabs([
        # --- Viz tab: a column picker and the graph it drives (see get_graph).
        dbc.Tab(label="Viz",
                children=[
                    dbc.Col([
                        html.H1("Visualize data"),
                        html.Hr(),
                        dbc.Select(id="graph-select", options=get_columns()),
                    ]),
                    html.P(),
                    dbc.Row([dbc.Col([dcc.Graph(id="graph")])])
                ]),
        # --- Pred tab: input form for a salary prediction.
        dbc.Tab(label="Pred",
                children=[
                    dbc.Container([
                        html.H1("Predict salaries"),
                        html.Hr(),
                        dbc.FormGroup([
                            html.P(),
                            # Row 1: state and schooling, options read from MongoDB.
                            dbc.Row([
                                dbc.Col([
                                    dbc.Label("State"),
                                    dbc.Select(id="states",
                                               options=get_states()),
                                ]),
                                dbc.Col([
                                    dbc.Label("Schooling"),
                                    dbc.Select(id="schooling",
                                               options=get_schooling()),
                                ]),
                            ]),
                            html.P(),
                            # Row 2: free-form inputs, both defaulting to 40.
                            dbc.Row([
                                dbc.Col([
                                    dbc.Label("Age"),
                                    dbc.Input(id="age", value=40)
                                ]),
                                dbc.Col([
                                    dbc.Label("Avg Hrs/Wk"),
                                    dbc.Input(id="hrs", value=40)
                                ]),
                            ]),
                            html.P(),
                            # Row 3: field and sector.
                            # NOTE(review): value=0 is presumably not one of the
                            # options; the get_field_selection and
                            # get_sector_selection callbacks also write these
                            # values — confirm the 0 is intentional.
                            dbc.Row([
                                dbc.Col([
                                    dbc.Label("Field"),
                                    dbc.Select(id="field",
                                               options=get_fields(),
                                               value=0)
                                ]),
                                dbc.Col([
                                    dbc.Label("Sector"),
                                    dbc.Select(id="sector",
                                               options=get_sectors(),
                                               value=0)
                                ]),
                            ]),
                            html.P(),
                            # Row 4: no options declared here; the
                            # get_occupations and get_industries callbacks
                            # supply them.
                            dbc.Row([
                                dbc.Col([
                                    dbc.Label("Occupation"),
                                    dbc.Select(id="occupation")
                                ]),
                                dbc.Col([
                                    dbc.Label("Industry"),
                                    dbc.Select(id="industry")
                                ]),
                            ]),
                            html.P(),
                            # Row 5: no options declared here; the get_pumas
                            # callback supplies them.
                            dbc.Row([
                                dbc.Col([
                                    dbc.Label("Location"),
                                    dbc.Select(id="pumas"),
                                ]),
                            ]),
                            html.P(),
                            # No callback in this section takes "go-button" as an input.
                            dbc.Button("Go!", id="go-button"),
                        ]),
                    ]),
                ]),
        # --- Shap tab: heading and rule only, no further content.
        dbc.Tab(label="Shap",
                children=[
                    html.H1("Explain results"),
                    html.Hr(),
                ]),
    ]),
])

In [13]:
@app.callback(Output('graph-select', 'value'),
              [Input('graph-select', 'options')])
def get_plotting_selection(available_options):
    """Set the graph dropdown's value to the value of its first option."""
    first_option = available_options[0]
    return first_option['value']


@app.callback(Output('states', 'value'), [Input('states', 'options')])
def get_state_selection(available_options):
    """Default the ``states`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``states`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('pumas', 'options'), [Input('states', 'value')])
def get_pumas(state):
    """Build the ``pumas`` dropdown options for the selected state.

    Args:
        state: value currently selected in the ``states`` Select.

    Returns:
        list[dict]: option dicts built from the LOCATION field of the
        ``loc`` documents whose STATE equals ``state``; an empty list when
        no document matches.
    """
    records = list(db.loc.find({"STATE": state}))
    if not records:
        # A DataFrame built from zero documents has no LOCATION column, so
        # the attribute access below would raise AttributeError.
        return []
    locations = pd.DataFrame(records)
    return get_options(list(locations.LOCATION.values))


@app.callback(Output('pumas', 'value'), [Input('pumas', 'options')])
def get_puma_selection(available_options):
    """Default the ``pumas`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``pumas`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('schooling', 'value'), [Input('schooling', 'options')])
def get_schooling_selection(available_options):
    """Default the ``schooling`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``schooling`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('sector', 'value'), [Input('sector', 'options')])
def get_sector_selection(available_options):
    """Default the ``sector`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``sector`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('field', 'value'), [Input('field', 'options')])
def get_field_selection(available_options):
    """Default the ``field`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``field`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('industry', 'options'), [Input('sector', 'value')])
def get_industries(sector):
    """Build the ``industry`` dropdown options for the selected sector.

    Args:
        sector: value currently selected in the ``sector`` Select.

    Returns:
        list[dict]: option dicts built from the INDUSTRY field of the
        ``ind`` documents whose SECTOR equals ``sector``; an empty list
        when no document matches.
    """
    records = list(db.ind.find({"SECTOR": sector}))
    if not records:
        # A DataFrame built from zero documents has no INDUSTRY column, so
        # the attribute access below would raise AttributeError.
        return []
    matches = pd.DataFrame(records)
    return get_options(list(matches.INDUSTRY.values))


@app.callback(Output('occupation', 'options'), [Input('field', 'value')])
def get_occupations(field):
    """Build the ``occupation`` dropdown options for the selected field.

    Args:
        field: value currently selected in the ``field`` Select.

    Returns:
        list[dict]: option dicts built from the OCCUPATION field of the
        ``occ`` documents whose FIELD equals ``field``; an empty list when
        no document matches.
    """
    records = list(db.occ.find({"FIELD": field}))
    if not records:
        # A DataFrame built from zero documents has no OCCUPATION column,
        # so the attribute access below would raise AttributeError.
        return []
    matches = pd.DataFrame(records)
    return get_options(list(matches.OCCUPATION.values))


@app.callback(Output('occupation', 'value'), [Input('occupation', 'options')])
def get_occupation_selection(available_options):
    """Default the ``occupation`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``occupation`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('industry', 'value'), [Input('industry', 'options')])
def get_industry_selection(available_options):
    """Default the ``industry`` dropdown to its first option.

    Args:
        available_options: the option dicts currently set on the
            ``industry`` Select; may be empty or None.

    Returns:
        The ``value`` of the first option, or None when there are no
        options to choose from.
    """
    # Without this guard an empty or missing option list raises
    # IndexError/TypeError inside the callback.
    if not available_options:
        return None
    return available_options[0]['value']


@app.callback(Output('graph', 'figure'), [Input('graph-select', 'value')])
def get_graph(column):
    """Return a histogram of ``column`` drawn from a freshly queried sample frame."""
    sample = get_sample_df()
    return px.histogram(sample, x=column, template="simple_white")

In [14]:
# Start the Dash app's server on port 8051.
app.run_server(port=8051)

Dash app running on http://127.0.0.1:8051/
