In [1]:
import io

import pandas as pd
from jupyter_dash import JupyterDash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from dash.exceptions import PreventUpdate
import dash_table

import stock_data_downloader
import nn
import graphs

# External CSS handed to the app constructor.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

# The page is assembled from the named sections below; app.layout at the end
# lists them in display order.

# Ticker text box framed by two horizontal rules, plus the "Go" button whose
# clicks trigger the dataset callback.
_ticker_controls = [
    html.Hr(style={"width": "35%", "float": "left"}),
    dcc.Input(
        id="ticker-selector",
        type="text",
        value="MSFT",
        style={"height": "50px"},
    ),
    html.Hr(style={"width": "45%", "float": "right"}),
    html.Button(id="go-button-state", n_clicks=0, children="Go"),
]

# Container filled by the technical-indicators callback.
_indicator_panel = html.Div(
    id="technical-indicators",
    style={"float": "left", "width": "24%"},
)

# Price chart wrapped in a loading indicator.
_price_panel = html.Div(
    children=dcc.Loading(
        id="price-loading",
        type="cube",
        children=[dcc.Graph(id="price")],
    ),
    style={"display": "inline-block", "width": "75%"},
)

# Candlestick chart wrapped in a loading indicator.
_candle_panel = html.Div(
    children=dcc.Loading(
        id="candle-loading",
        type="cube",
        children=[dcc.Graph(id="candle")],
    ),
    style={"float": "right", "width": "99%"},
)

# Small table holding the prompt text and the button that starts training.
_train_prompt_row = html.Tr(
    html.Th("Click to train model and make predictions:")
)
_train_button_row = html.Tr(
    [html.Th(html.Button("Figures", id="memory-button"))]
)
_train_controls = html.Table([html.Thead([_train_prompt_row, _train_button_row])])

# Loss curve and prediction chart share a single loading indicator.
_model_panel = dcc.Loading(
    id="model-loading",
    type="cube",
    children=[
        html.Div(
            children=dcc.Graph(id="loss_plot"),
            style={"float": "left", "width": "24%"},
        ),
        html.Div(
            children=dcc.Graph(id="predictions"),
            style={"display": "inline-block", "height": 675, "width": "75%"},
        ),
    ],
)

app.layout = html.Div([
    *_ticker_controls,
    dcc.Store(id="dataset"),   # serialized price data shared between callbacks
    _indicator_panel,
    _price_panel,
    _candle_panel,
    dcc.Store(id="memory"),    # click counter for the training button
    _train_controls,
    _model_panel,
])

@app.callback(Output('dataset', 'data'),
              Output('memory-button', 'n_clicks'),
              Input('go-button-state', 'n_clicks'),
              State('ticker-selector', 'value'))
def get_dataset(n_clicks, ticker):
    """Download data for the entered ticker and store it as JSON.

    Args:
        n_clicks: Click count of the "Go" button; only used as the trigger.
        ticker: Current text of the ticker input box.

    Returns:
        A 2-tuple: the data frame serialized with ``orient='split'`` and ISO
        dates (written to the ``dataset`` store), and ``-1`` written to the
        training button's ``n_clicks``. The ``on_click`` callback treats -1
        as a request to reset its click counter.

    Raises:
        PreventUpdate: If the ticker box is empty or only whitespace, so a
            blank submission leaves the current dataset untouched instead of
            hitting the downloader with an empty symbol.
    """
    symbol = (ticker or '').strip()
    if not symbol:
        raise PreventUpdate

    df = stock_data_downloader.fetch_data(symbol)
    df = stock_data_downloader.calculate_technical_indicators(df)
    return (df.to_json(date_format='iso', orient='split'), -1)

@app.callback(Output('technical-indicators', 'children'), Input('dataset', 'data'))
def update_table(stock_data):
    """Build the technical-indicator table from the stored dataset.

    Args:
        stock_data: JSON string from the ``dataset`` store, in the
            ``orient='split'`` format.

    Returns:
        Whatever ``graphs.ti_table`` produces for the latest row, extended
        with a ``'Prev. Close'`` entry taken from the second-to-last row.

    Raises:
        PreventUpdate: If the store is empty or holds fewer than two rows
            (a previous close cannot be computed).
    """
    if not stock_data:
        raise PreventUpdate

    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(stock_data), orient='split')
    if len(df) < 2:
        raise PreventUpdate

    latest_columns = ['Open', 'High', 'Low', 'Volume', 'RSI_EMA', 'RSI_SMA']
    # Copy the row before adding to it so the assignment below does not write
    # into a slice of df (which triggers SettingWithCopyWarning).
    today = df[latest_columns].iloc[-1].copy()
    today['Prev. Close'] = df['Close'].iloc[-2]
    return graphs.ti_table(today)

@app.callback(Output('price', 'figure'),
              Output('candle', 'figure'),
              Input('dataset', 'data'))
def update_graph(stock_data):
    """Redraw the price and candlestick charts from the stored dataset.

    Args:
        stock_data: JSON string from the ``dataset`` store, in the
            ``orient='split'`` format.

    Returns:
        A 2-tuple ``(price_figure, candlestick_figure)`` built by
        ``graphs.plot_price`` and ``graphs.plot_candlesticks``.

    Raises:
        PreventUpdate: If the store is empty, so the charts keep their
            current content instead of failing to parse.
    """
    if not stock_data:
        raise PreventUpdate

    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(stock_data), orient='split')
    return (graphs.plot_price(df), graphs.plot_candlesticks(df))

@app.callback(Output('memory', 'data'),
              [Input('memory-button', 'n_clicks')],
              [State('memory', 'data')])
def on_click(n_clicks, data):
    """Maintain the click counter kept in the ``memory`` store.

    Args:
        n_clicks: Click count of the training button. ``None`` means there is
            nothing to record; ``-1`` is the reset value written by the
            dataset callback.
        data: Current contents of the ``memory`` store, or ``None``.

    Returns:
        ``{'clicks': 0}`` on reset, otherwise the store dict with its
        ``'clicks'`` entry incremented by one.

    Raises:
        PreventUpdate: When ``n_clicks`` is ``None``, so the store is not
            rewritten for nothing.
    """
    if n_clicks is None:
        raise PreventUpdate

    # A new dataset was loaded: start counting from zero again.
    if n_clicks == -1:
        return {'clicks': 0}

    # Fall back to a zeroed counter when the store has no data yet.
    counter = data if data else {'clicks': 0}
    counter['clicks'] += 1
    return counter

def _blank_figure():
    """Return an empty figure with a transparent background, no grid or zero
    lines, and fully transparent tick labels."""
    transparent = 'rgba(0,0,0,0)'
    axis_style = dict(showgrid=False, zeroline=False,
                      tickfont=dict(color=transparent))
    fig = go.Figure()
    fig.update_layout(plot_bgcolor=transparent, paper_bgcolor=transparent,
                      yaxis=dict(axis_style), xaxis=dict(axis_style))
    return fig


@app.callback(Output('loss_plot', 'figure'),
              Output('predictions', 'figure'),
              [Input('memory', 'modified_timestamp')],
              [State('memory', 'data')],
              [State('dataset', 'data')])
def on_data(ts, data, stock_data):
    """Train the model and plot its loss and predictions once requested.

    Fires whenever the ``memory`` store's ``modified_timestamp`` changes.

    Args:
        ts: Modification timestamp of the ``memory`` store; ``None`` when the
            store has never been written.
        data: Contents of the ``memory`` store (expects a ``'clicks'`` key).
        stock_data: JSON string from the ``dataset`` store, in the
            ``orient='split'`` format.

    Returns:
        A 2-tuple ``(loss_figure, predictions_figure)``. Both are blank
        placeholder figures unless the click counter is positive and a
        dataset is available.
    """
    clicks = (data or {}).get('clicks', 0)
    # Nothing to train on, or training not requested yet: show placeholders.
    if ts is None or clicks <= 0 or not stock_data:
        return (_blank_figure(), _blank_figure())

    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated.
    df = pd.read_json(io.StringIO(stock_data), orient='split')

    # NOTE(review): the shapes and meaning of the values returned by the nn
    # helpers are defined in the nn module, not here; names follow its return order.
    X_test, y_test, X_train, y_train, test_scaler = nn.split_dataset(df.to_numpy())
    model = nn.make_model(X_train)
    hist = model.fit(X_train, y_train, epochs=3, batch_size=60,
                     validation_data=(X_test, y_test))
    _yhat, yhat_df, y_test_df = nn.get_yhat(model, X_test, y_test, test_scaler, df)

    # Per-epoch training and validation loss.
    loss_df = pd.DataFrame(data={'train': hist.history['loss'],
                                 'test': hist.history['val_loss']})
    fig1 = px.line(loss_df, x=loss_df.index, y=['train', 'test'],
                   labels={'index': 'epoch'},
                   color_discrete_map={'train': '#0000FF', 'test': '#FFA500'})

    # First column of each frame: true values versus model predictions.
    fig2 = go.Figure(data=[go.Scatter(x=y_test_df.index,
                                      y=y_test_df[y_test_df.columns[0]],
                                      name='true', line={'color': '#0000FF'})])
    fig2.add_trace(trace=go.Scatter(x=yhat_df.index,
                                    y=yhat_df[yhat_df.columns[0]],
                                    name='pred', line={'color': ' #cb4335 '}))

    return (fig1, fig2)


# Start the Dash server with default settings; the recorded cell output shows
# it serving at http://127.0.0.1:8050/.
app.run_server()

Dash app running on http://127.0.0.1:8050/
[*********************100%***********************]  1 of 1 completed
Epoch 1/3
Epoch 2/3
Epoch 3/3
