In [1]:
#Dashboard imports
from jupyter_dash import JupyterDash
import dash
import dash_leaflet as dl
from dash import dcc
from dash import html
import plotly.express as px
from dash import dash_table as dt
from dash.dependencies import Input, Output, State
import os
import numpy as np
import pandas as pd
from bson.json_util import dumps
from bson.objectid import ObjectId
import base64
from crud import MLMongo
from cartpole import MLCartpole


DATA_FILE = 'metrics.csv'  # NOTE(review): not referenced in this section; confirm it is still needed
IMAGE_FILE = 'logo.png'

###########################
# Data Manipulation / Model
###########################

# Database credentials. They can be overridden through the environment so that
# real secrets need not live in the notebook; the defaults keep the original
# development values so existing setups continue to work.
username = os.environ.get('MONGO_USERNAME', 'aiuser')
password = os.environ.get('MONGO_PASSWORD', 'password')

# Creates the MongoDB CRUD-capable instance
mongobject = MLMongo(username, password)

# Initial table contents: every stored record. The class read method must
# support returning a cursor (or any iterable of dicts) for from_records.
df = pd.DataFrame.from_records(mongobject.read_all({}))
                               
#########################
# Dashboard Layout / View
#########################
# Pass the module's __name__ (not the literal string '__name__'); Dash/Flask use
# it to locate the app's root path and assets folder.
app = JupyterDash(__name__)

# Read the local logo image and base64-encode it so it can be embedded directly
# in the layout as a data URI via html.Img. The `with` block guarantees the
# file handle is closed once the bytes have been read.
with open(IMAGE_FILE, 'rb') as image_file:
    encoded_image = base64.b64encode(image_file.read())


def _param_row(label, input_id, value, placeholder, label_margin):
    """Return one flex row pairing a hyperparameter label with its numeric input.

    ``label_margin`` is the label's right margin. The per-row values appear to be
    hand-tuned so the input boxes line up despite labels of different widths.
    """
    return html.Div(
        className='row',
        style={'display': 'flex'},
        children=[
            html.Div(
                html.Label(label),
                className='col s12 m6',
                style={'margin-right': label_margin},
            ),
            html.Div(
                dcc.Input(id=input_id, type='number', value=value, placeholder=placeholder),
                className='col s12 m6',
            ),
        ],
    )


# (label, input id, initial value, placeholder, label right-margin) for each
# hyperparameter input, top to bottom.
_PARAM_FIELDS = [
    ('Alpha:', 'input-1', 0.01, '0<x<1, low', '75px'),
    ('Gamma:', 'input-2', 0.90, '0<x<1, high', '62px'),
    ('Epsilon Min:', 'input-3', 0.01, '0<x<1, low', '33px'),
    ('Epsilon Max:   ', 'input-4', 0.75, '0<x<1, high', '31px'),
    ('Epsilon Decay: ', 'input-5', 0.95, '0<x<1, high', '18px'),
    ('# of Episodes: ', 'input-6', 24, '10-40', '22px'),
    ('Batch Size:    ', 'input-7', 12, 'Default value: 20', '43px'),
    ('Buffer Size:   ', 'input-8', 1200, 'Default value: 1000000', '40px'),
]


app.layout = html.Div([
    html.Div(id='hidden-div', style={'display': 'none'}),
    html.Center(html.B(html.H1('Machine Learning Dashboard'))),
    html.Center(html.Img(src='data:image/png;base64,{}'.format(encoded_image.decode()))),
    html.Center(html.B(html.H2("Developed by Lukas Mueller (2023)"))),
    html.Hr(),

    # Hyperparameter inputs, followed by the Submit button that starts a training
    # run and the div where its status is reported.
    html.Div(
        [_param_row(*field) for field in _PARAM_FIELDS]
        + [
            html.Button('Submit', id='submit-button', n_clicks=0),
            html.Div(id='output'),
        ]
    ),
    html.Hr(),
    html.Br(),

    # This row houses four radio buttons that are used to filter the data table.
    html.Div(
        className='row',
        style={'display': 'flex'},
        children=[
            dcc.RadioItems(
                id="filter-type",
                # Labels are provisioned for useful filtering of performance metrics.
                # TODO: these selections require that stored records be parsed and
                # stored according to performance metric; this may violate best practice.
                options=[
                    {'label': 'Learning Rate(alpha)', 'value': 'lr'},
                    {'label': 'Foresight(gamma)', 'value': 'exp'},
                    {'label': 'Prioritization(zeta)', 'value': 'att'},
                    {'label': 'Reset', 'value': 'reset'},
                ],
                value='reset',
            )
        ],
    ),

    html.Hr(),
    # Data table of stored training results, with native filtering, multi-column
    # sorting and paging (10 rows per page).
    dt.DataTable(
        id='datatable-id',
        columns=[
            {"name": i, "id": i, "deletable": False, "selectable": True} for i in df.columns
        ],
        data=df.to_dict('records'),
        editable=False,
        filter_action="native",
        sort_action="native",
        sort_mode="multi",
        column_selectable=False,
        row_selectable=False,
        selected_columns=[],
        selected_rows=[],
        page_action="native",
        page_current=0,
        page_size=10,

        # Left-align these columns.
        # NOTE(review): 'Date' and 'Region' are not referenced elsewhere in this
        # section; confirm they exist in the stored records.
        style_cell_conditional=[
            {'if': {'column_id': c}, 'textAlign': 'left'} for c in ['Date', 'Region']
        ],
        style_data={
            'color': 'black',
            'backgroundColor': 'white',
        },
        # Table striping makes visual tracking of a single record's data easier.
        style_data_conditional=[
            {
                'if': {'row_index': 'odd'},
                'backgroundColor': 'rgb(220, 220, 220)',
            }
        ],
        style_header={
            'backgroundColor': 'rgb(210, 210, 210)',
            'color': 'black',
            'fontWeight': 'bold',
        },
        # Fixed cell sizing keeps rows compact; long values wrap inside the cell.
        style_cell={
            'minHeight': '16px', 'height': '16px', 'maxHeight': '16px',
            'minWidth': '160px', 'width': '160px', 'maxWidth': '160px',
            'whiteSpace': 'normal',
        },
    ),

    html.Br(),
    html.Hr(),

    # Two chart containers laid out side by side: the first is meant for lr vs.
    # runs-to-solve, the second for exp vs. runs-to-solve.
    # NOTE(review): no callback in this section populates graph1-id / graph2-id.
    html.Div(
        className='row',
        style={'display': 'flex'},
        children=[
            html.Div(id='graph1-id', className='col s12 m6'),
            html.Div(id='graph2-id', className='col s12 m6'),
        ],
    ),
])
                               
                

#############################################
# Interaction Between Components / Controller
#############################################
@app.callback(Output('output', 'children'),
              [Input('submit-button', 'n_clicks')],
              [State('input-1', 'value'),
               State('input-2', 'value'),
               State('input-3', 'value'),
               State('input-4', 'value'),
               State('input-5', 'value'),
               State('input-6', 'value'),
               State('input-7', 'value'),
               State('input-8', 'value')
              ])
def update_output(n_clicks, input1, input2, input3, input4, input5, input6, input7, input8):
    """Train a CartPole agent with the submitted hyperparameters and report the result.

    The eight fields are read as ``State``, so editing them does not start a run;
    only a Submit click does. Their values are passed straight through, in order,
    to ``MLCartpole`` (per the layout labels: alpha, gamma, epsilon min, epsilon
    max, epsilon decay, number of episodes, batch size, buffer size).

    Returns:
        A status message for the ``output`` div, or ``dash.no_update`` for the
        page-load call made before Submit has been clicked.
    """
    # Dash invokes every callback once when the page loads (n_clicks == 0);
    # leave the status div as it is instead of starting a run.
    if not n_clicks:
        return dash.no_update

    params = (input1, input2, input3, input4, input5, input6, input7, input8)
    # A cleared or invalid numeric field gives no usable value; refuse to start a
    # (long) training run with missing hyperparameters.
    if any(value is None or value == '' for value in params):
        return 'Please enter a valid number in every field before submitting.'

    print("Please wait while the learning algorithm trains. Performance metrics will then be viewable.")
    cartpole_instance = MLCartpole(*params)
    solved = cartpole_instance.cartpole()

    # Strict equality with True is intentional: cartpole()'s return type is not
    # visible here, so only a value equal to True counts as success.
    if solved != True:  # noqa: E712
        message = "This version exceeded the run limit. Enter a different set of values, and try again"
        print(message)
        return message

    print("This model was successful. Congratulations! Attempting to write to database...")
    try:
        mongobject.writeDb()
    except Exception as err:  # report any driver/network failure instead of hiding it
        print("There was an issue preventing data from being written to the database. Please contact your IT Administrator, local developer, and nearest coffee house.")
        print(f"Database error: {err!r}")
        return f'Training succeeded, but the results could not be saved to the database: {err}'

    # NOTE: calling another callback function directly from here would not
    # refresh the data table. Dash only applies values returned from callbacks it
    # invokes itself, so the table must be refreshed by a callback listening to
    # this 'output' div.
    return f'Dashboard updated: {solved}'

# Columns shown when one of the parameter filters ('lr', 'exp', 'att') is selected.
# TODO(review): all three filters currently share this column set; tailor them
# (e.g. alpha-only vs. gamma-only views) once the stored record schema is settled.
_PARAM_COLUMNS = ["alpha", "epsilon", "gamma", "runs", "time"]


# Radio filter callback, with different options to filter for learning algorithm
# params. The Reset filter removes other filters to display all data.
@app.callback(
    [Output('datatable-id', 'data'), Output('datatable-id', 'columns')],
    [Input('filter-type', 'value'), Input('output', 'children')])
def update_dashboard(filter_type, _training_status=None):
    """Return the data table's rows and column definitions for the chosen filter.

    Args:
        filter_type: value of the 'filter-type' radio items. Either 'lr', 'exp'
            or 'att' (parameter/result columns only), or 'reset' (every stored
            field).
        _training_status: contents of the 'output' status div. The value is
            unused; it is an Input only so the table re-queries the database
            whenever a training run reports back.

    Returns:
        ``(data, columns)`` for the DataTable, or ``(dash.no_update, dash.no_update)``
        for an unrecognized filter value.
    """
    if filter_type not in ('lr', 'exp', 'att', 'reset'):
        return dash.no_update, dash.no_update

    # read_all({}) fetches every stored record, and column selection is done in
    # pandas. read_all's argument is used as a query here, so a dict such as
    # {"alpha": "alpha"} would only match documents whose field literally equals
    # that string.
    records = pd.DataFrame.from_records(mongobject.read_all({}))
    if filter_type != 'reset':
        # Keep only the columns that actually exist, so an empty collection or an
        # older record layout does not raise KeyError.
        records = records[[col for col in _PARAM_COLUMNS if col in records.columns]]

    columns = [{"name": col, "id": col, "deletable": False, "selectable": True}
               for col in records.columns]
    return records.to_dict('records'), columns

# Zebra-striping rule matching the DataTable's initial style_data_conditional.
# It is repeated here because this callback replaces that entire property, and
# Dash also runs the callback on page load.
_STRIPE_RULE = {
    'if': {'row_index': 'odd'},
    'backgroundColor': 'rgb(220, 220, 220)',
}


# Highlight any column that is selected via checkboxes displayed in header cells
@app.callback(
    Output('datatable-id', 'style_data_conditional'),
    [Input('datatable-id', 'selected_columns')]
)
def update_style_c(selected_columns):
    """Return the table's conditional styles: row striping plus selected-column highlights.

    Args:
        selected_columns: ids of the currently selected columns; may be None/empty.

    Returns:
        A list of style rules. The striping rule comes first, so the highlight
        rules that follow it take precedence on overlapping cells.
    """
    highlights = [
        {'if': {'column_id': col}, 'background_color': '#D2F3FF'}
        for col in (selected_columns or [])
    ]
    return [_STRIPE_RULE] + highlights


if __name__ == '__main__':
    # Serve the dashboard inline in the notebook output cell, on local port 8890.
    app.run_server(port=8890, mode='inline')

Dash is running on http://127.0.0.1:8890/



Please wait while the learning algorithm trains. Performance metrics will then be viewable.
Run: 1, exploration: 0.25, score: 10
Scores: (min: 10, avg: 10, max: 10)

Run: 2, exploration: 0.25, score: 16
Scores: (min: 10, avg: 13, max: 16)

Run: 3, exploration: 0.25, score: 11
Scores: (min: 10, avg: 12.333333333333334, max: 16)

Run: 4, exploration: 0.25, score: 21
Scores: (min: 10, avg: 14.5, max: 21)

Run: 5, exploration: 0.25, score: 11
Scores: (min: 10, avg: 13.8, max: 21)

Run: 6, exploration: 0.25, score: 18
Scores: (min: 10, avg: 14.5, max: 21)

Run: 7, exploration: 0.25, score: 26
Scores: (min: 10, avg: 16.142857142857142, max: 26)

Run: 8, exploration: 0.25, score: 18
Scores: (min: 10, avg: 16.375, max: 26)

Run: 9, exploration: 0.25, score: 20
Scores: (min: 10, avg: 16.77777777777778, max: 26)

Run: 10, exploration: 0.25, score: 26
Scores: (min: 10, avg: 17.7, max: 26)

This model was successful. Congratulations!
There was an issue preventing data from being written to the dat