In [1]:
#from jupyter_plotly_dash import JupyterDash
#Dashboard imports
from jupyter_dash import JupyterDash
from bson.objectid import ObjectId
import dash
import dash_leaflet as dl
from dash import dcc
from dash import html
import plotly.express as px
from dash import dash_table as dt
from dash.dependencies import Input, Output, State
import os
import numpy as np
import pandas as pd
from pymongo import MongoClient
from bson.json_util import dumps
import base64
from crud import MLDash

#Cartpole imports
import random  
import gym    
from collections import deque  
from keras.models import Sequential  
from keras.layers import Dense  
from keras.optimizers import Adam  
  
from dqn_solver import DQNSolver
from scores.score_logger import ScoreLogger  
    
###########################
# Machine Learning Model
###########################
# KMP_DUPLICATE_LIB_OK is read by Intel's OpenMP runtime; TRUE lets the process
# keep running when more than one copy of that runtime gets loaded. It has to be
# set in *this* interpreter's environment: `os.system('set ...')` would only
# change a short-lived child shell. NOTE(review): the runtime may read the flag
# when it is first loaded, so setting it before the keras/numpy imports above
# would be the safest placement.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Gym environment id used by the training callback (passed to cartpole()).
ENV_NAME = "CartPole-v1"

def cartpole(envName, learnRate, gamma, explMin, explMax, explDecay, batchSize, memSize, max_runs=None):
    """Train a DQN agent on a Gym environment, logging every finished episode.

    Args:
        envName: Gym environment id passed to ``gym.make`` (e.g. ``ENV_NAME``).
        learnRate: learning rate handed to ``DQNSolver`` (and logged).
        gamma: discount factor forwarded to ``experience_replay`` (and logged).
        explMin: minimum exploration value forwarded to ``experience_replay``.
        explMax: initial exploration value handed to ``DQNSolver``.
        explDecay: exploration decay forwarded to ``experience_replay``.
        batchSize: replay batch size forwarded to ``experience_replay``.
        memSize: replay memory size handed to ``DQNSolver``.
        max_runs: optional cap on the number of episodes. ``None`` (the
            default) keeps training indefinitely, as before; any earlier stop
            is then up to ``ScoreLogger`` (not visible here — confirm whether it
            ends the process once the task is solved).

    Returns:
        True after ``max_runs`` episodes. With ``max_runs=None`` the loop has
        no exit of its own, so the function does not return normally.
    """
    env = gym.make(envName)
    score_logger = ScoreLogger(envName)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space, explMax, memSize, learnRate)
    run = 0
    try:
        while max_runs is None or run < max_runs:
            run += 1
            # Gym >= 0.26 API: reset() returns (observation, info).
            state, _ = env.reset()
            state = np.reshape(state, (1, observation_space))
            step = 0
            while True:
                step += 1
                action = dqn_solver.act(state)
                # Gym >= 0.26 API: step() returns
                # (observation, reward, terminated, truncated, info).
                # `terminated` = the environment's own end condition was reached;
                # `truncated` = the episode was cut off (e.g. by a time limit).
                state_next, reward, terminated, truncated, _ = env.step(action)
                # Only a real termination is penalised; a time-limit cut-off is not a failure.
                reward = -reward if terminated else reward
                state_next = np.reshape(state_next, (1, observation_space))
                # Store `terminated`, not `truncated`: presumably DQNSolver uses this
                # flag to stop bootstrapping, which is wrong for cut-off states — confirm.
                dqn_solver.remember(state, action, reward, state_next, terminated)
                state = state_next
                # End the episode on either signal; stepping past truncation
                # without reset() is not supported by the Gym API.
                if terminated or truncated:
                    score_logger.add_record(
                        step, run, dqn_solver.exploration_rate, gamma, learnRate,
                        memSize, batchSize, explMax, explMin, explDecay,
                    )
                    break
                dqn_solver.experience_replay(batchSize, gamma, explDecay, explMin)
    finally:
        # Release the environment even if training is interrupted.
        env.close()
    return True

###########################
# Data Manipulation / Model
###########################

# MongoDB credentials used by the MLDash data layer. They can be supplied through
# the environment so secrets need not live in source; the previous hard-coded
# values remain the fallback.
username = os.environ.get('MLDASH_USERNAME', 'aiuser')
password = os.environ.get('MLDASH_PASSWORD', 'password')

mongobject = MLDash(username, password)

# Initial table contents: every stored record. read_all() must return an iterable
# of documents (e.g. a pymongo cursor) that DataFrame.from_records can consume;
# an empty query ({}) reads everything.
df = pd.DataFrame.from_records(mongobject.read_all({}))
                               
#########################
# Dashboard Layout / View
#########################
# Pass the module's real __name__ (not the literal string '__name__') so Dash/Flask
# can resolve the application's root path and assets folder correctly.
app = JupyterDash(__name__)

# Read the local logo and base64-encode it so the layout can embed it as a
# data URI in an html.Img component instead of serving a static file.
# The `with` block closes the file handle as soon as the bytes are read.
image_filename = 'logo.png'
with open(image_filename, 'rb') as image_file:
    encoded_image = base64.b64encode(image_file.read())



def _param_row(label, input_id, value, placeholder, margin_px):
    """Build one flex row: a hyperparameter label followed by its numeric input.

    ``margin_px`` is the label's right margin; the per-row values are hand-tuned
    so the inputs line up despite labels of different widths.
    """
    return html.Div(
        className='row',
        style={'display': 'flex'},
        children=[
            html.Div(
                html.Label(label),
                className='col s12 m6',
                style={'margin-right': f'{margin_px}px'},
            ),
            html.Div(
                dcc.Input(id=input_id, type='number', value=value, placeholder=placeholder),
                className='col s12 m6',
            ),
        ],
    )


# One entry per training hyperparameter:
# (label, input id, default value, placeholder, label right margin in px).
# The ids input-1 ... input-7 are read, in this order, as State by the
# Submit-button callback.
_PARAM_ROWS = [
    ('Alpha:', 'input-1', 0.001, "0<x<1, low", 75),
    ('Gamma:', 'input-2', 0.95, "0<x<1, high", 62),
    ('Epsilon Min:', 'input-3', 0.01, "0<x<1, low", 33),
    ('Epsilon Max:   ', 'input-4', 1.0, "0<x<=1, high", 31),
    ('Epsilon Decay: ', 'input-5', 0.95, "0<x<1, high", 18),
    ('Batch Size:    ', 'input-6', 20, "Default value: 20", 43),
    ('Memory Size:   ', 'input-7', 1000000, "Default value: 1000000", 25),
]

app.layout = html.Div([
    html.Div(id='hidden-div', style={'display': 'none'}),
    html.Center(html.B(html.H1('Machine Learning Dashboard'))),
    # Logo embedded as a base64 data URI.
    html.Center(html.Img(src='data:image/png;base64,{}'.format(encoded_image.decode()))),
    html.Center(html.B(html.H2("Developed by Lukas Mueller (2023)"))),
    html.Hr(),

    # Hyperparameter form: one labelled numeric input per row, followed by the
    # Submit button and the div where the training callback writes its status.
    html.Div(
        [_param_row(*row) for row in _PARAM_ROWS]
        + [
            html.Button('Submit', id='submit-button', n_clicks=0),
            html.Div(id='output'),
        ]
    ),
    html.Hr(),
    html.Br(),

    # Radio buttons that choose which subset of the stored results the table shows.
    html.Div(className='row',
        style={'display': 'flex'},
        children=[
            dcc.RadioItems(
                id="filter-type",
                # Options for filtering by performance metric.
                # TODO: these selections require stored records to be parsed and
                # stored per performance metric, which may violate best practice.
                options=[
                    {'label': 'Learning Rate', 'value': 'lr'},
                    {'label': 'Exploration', 'value': 'exp'},
                    {'label': 'Experience Replay', 'value': 'rep'},
                    {'label': 'Reset', 'value': 'reset'},
                ],
                value='reset',
            )
        ]
    ),

    html.Hr(),
    # Results table, initially populated with every stored record.
    dt.DataTable(
        id='datatable-id',
        columns=[
            {"name": i, "id": i, "deletable": False, "selectable": True} for i in df.columns
        ],

        # Mostly options that enable the table's built-in (native) interactivity.
        data=df.to_dict('records'),
        editable=False,
        filter_action="native",
        sort_action="native",
        sort_mode="multi",
        # NOTE(review): with column selection disabled no header checkboxes are
        # rendered, so the selected-column highlight callback only ever receives
        # an empty selection.
        column_selectable=False,
        row_selectable=False,
        selected_columns=[],
        selected_rows=[],
        page_action="native",
        page_current=0,
        page_size=10,

        # Simple styling to make the table easier to read.
        # Left-align 'Date' and 'Region' if such columns exist. NOTE(review):
        # neither name appears among the fields this dashboard queries.
        style_cell_conditional=[
            {
                'if': {'column_id': c},
                'textAlign': 'left'
            } for c in ['Date', 'Region']
        ],
        style_data={
            'color': 'black',
            'backgroundColor': 'white'
        },
        # Row striping makes it easier to follow one record's values across a row.
        style_data_conditional=[
            {
                'if': {'row_index': 'odd'},
                'backgroundColor': 'rgb(220, 220, 220)',
            }
        ],
        style_header={
            'backgroundColor': 'rgb(210, 210, 210)',
            'color': 'black',
            'fontWeight': 'bold'
        },
        # Fixed cell sizing; 'whiteSpace: normal' lets long values wrap.
        style_cell={
            'minHeight': '16px', 'height': '16px', 'maxHeight': '16px',
            'minWidth': '160px', 'width': '160px', 'maxWidth': '160px',
            'whiteSpace': 'normal'
        },
    ),

    html.Br(),
    html.Hr(),

    # Two chart containers laid out side by side. No callback in this cell
    # currently fills them.
    html.Div(className='row',
        style={'display': 'flex'},
        children=[
            # Intended for the first chart (learning rate vs. runs to solve).
            html.Div(
                id='graph1-id',
                className='col s12 m6',
            ),
            # Intended for the second chart (exploration vs. runs to solve).
            html.Div(
                id='graph2-id',
                className='col s12 m6',
            )
        ])
])
                               
                

#############################################
# Interaction Between Components / Controller
#############################################
@app.callback(Output('output', 'children'),
              [Input('submit-button', 'n_clicks')],
              [State('input-1', 'value'),
               State('input-2', 'value'),
               State('input-3', 'value'),
               State('input-4', 'value'),
               State('input-5', 'value'),
               State('input-6', 'value'),
               State('input-7', 'value'),
              ])
def update_output(n_clicks, input1, input2, input3, input4, input5, input6, input7):
    """Train the CartPole agent with the form's hyperparameters on Submit.

    The seven State values are, in order: alpha (learning rate), gamma,
    epsilon min, epsilon max, epsilon decay, batch size and memory size —
    the same order as cartpole()'s parameters after the environment name.

    Returns:
        Text for the ``output`` div: empty before the first click, a hint if a
        field is blank, otherwise the training result.

    NOTE(review): training runs synchronously inside this request, so the
    callback blocks until cartpole() returns.
    """
    # n_clicks starts at 0 (and may be None); nothing to do before a real click.
    if not n_clicks:
        return ''
    # A cleared or invalid numeric dcc.Input reports None; catching it here gives
    # the user a clear message instead of a crash deep inside training.
    if any(value is None for value in (input1, input2, input3, input4, input5, input6, input7)):
        return 'Please fill in every hyperparameter with a number before submitting.'
    print("Please wait while the learning algorithm trains. Performance metrics will then be viewable.")
    # Batch and memory sizes are counts; the numeric input may deliver floats.
    done = cartpole(ENV_NAME, input1, input2, input3, input4, input5, int(input6), int(input7))
    return f'Training Complete: {done}'

# Radio filter callback: each option narrows the table to the learning-algorithm
# parameters relevant to it; 'reset' shows every stored field.
@app.callback(
    [Output('datatable-id', 'data'), Output('datatable-id', 'columns')],
    [Input('filter-type', 'value')])
def update_dashboard(filter_type):
    """Rebuild the results table for the selected radio filter.

    Args:
        filter_type: value of the ``filter-type`` radio items
            ('lr', 'exp', 'rep' or 'reset').

    Returns:
        ``(data, columns)`` for the DataTable, or ``dash.no_update`` for both
        outputs when the value is not recognised.
    """
    # Fields shown per filter (None = every stored field). The three metric views
    # currently list the same fields; they are kept separate so each can be
    # specialised independently.
    shown_fields = {
        # learning-rate parameters and the corresponding time to solve
        'lr': ['alpha', 'epsilon', 'gamma', 'runs', 'time'],
        # exploration parameters and the corresponding time to solve
        'exp': ['alpha', 'epsilon', 'gamma', 'runs', 'time'],
        # experience-replay view
        'rep': ['alpha', 'epsilon', 'gamma', 'runs', 'time'],
        # reset: no narrowing
        'reset': None,
    }
    if filter_type not in shown_fields:
        # A two-output callback must return a value per output; leave both as-is.
        return dash.no_update, dash.no_update

    # read_all's argument appears to be a MongoDB query filter ({} = everything).
    # Passing {"alpha": "alpha", ...} there would only match documents whose fields
    # literally equal those strings, so read everything and pick columns locally.
    df = pd.DataFrame.from_records(mongobject.read_all({}))
    fields = shown_fields[filter_type]
    if fields is not None:
        # Keep only the requested fields that actually exist in the stored records.
        df = df[[field for field in fields if field in df.columns]]

    # Column definitions and row records in the shapes DataTable expects.
    columns = [{"name": i, "id": i, "deletable": False, "selectable": True} for i in df.columns]
    data = df.to_dict('records')
    return (data, columns)

#Highlight any column that is selected via checkboxes displayed in header cells
@app.callback(
    Output('datatable-id', 'style_data_conditional'),
    [Input('datatable-id', 'selected_columns')]
)
def update_style_c(selected_columns):
    return [{
        'if': { 'column_id': i },
        'background_color': '#D2F3FF'
    } for i in selected_columns]


if __name__ == '__main__':
    # Serve the dashboard on port 8890; mode='inline' asks JupyterDash to render
    # it inside the notebook's output cell.
    app.run_server(port=8890, mode='inline')

Dash is running on http://127.0.0.1:8890/



Please wait while the learning algorithm trains. Performance metrics will then be viewable.


Run: 1, exploration: 0.5688000922764596, score: 31
Scores: (min: 31, avg: 31, max: 31)



Run: 2, exploration: 0.34056162628811465, score: 11
Scores: (min: 11, avg: 21, max: 31)





Run: 3, exploration: 0.17482461472379698, score: 14
Scores: (min: 11, avg: 18.666666666666668, max: 31)



Run: 4, exploration: 0.11598222130000553, score: 9
Scores: (min: 9, avg: 16.25, max: 31)



Run: 5, exploration: 0.07309772651287748, score: 10
Scores: (min: 9, avg: 15, max: 31)



Run: 6, exploration: 0.04849452524942309, score: 9
Scores: (min: 9, avg: 14, max: 31)



Run: 7, exploration: 0.030563645913324056, score: 10
Scores: (min: 9, avg: 13.428571428571429, max: 31)



Run: 8, exploration: 0.019262719795904448, score: 10
Scores: (min: 9, avg: 13, max: 31)



Run: 9, exploration: 0.012140317781059323, score: 10
Scores: (min: 9, avg: 12.666666666666666, max: 31)





Run: 10, exploration: 0.01, score: 9
Scores: (min: 9, avg: 12.3, max: 31)



Run: 11, exploration: 0.01, score: 9
Scores: (min: 9, avg: 12, max: 31)



Run: 12, exploration: 0.01, score: 9
Scores: (min: 9, avg: 11.75, max: 31)



Run: 13, exploration: 0.01, score: 9
Scores: (min: 9, avg: 11.538461538461538, max: 31)



Run: 14, exploration: 0.01, score: 10
Scores: (min: 9, avg: 11.428571428571429, max: 31)



Run: 15, exploration: 0.01, score: 9
Scores: (min: 9, avg: 11.266666666666667, max: 31)



Run: 16, exploration: 0.01, score: 13
Scores: (min: 9, avg: 11.375, max: 31)





Run: 17, exploration: 0.01, score: 16
Scores: (min: 9, avg: 11.647058823529411, max: 31)







Run: 18, exploration: 0.01, score: 23
Scores: (min: 9, avg: 12.277777777777779, max: 31)





Run: 19, exploration: 0.01, score: 19
Scores: (min: 9, avg: 12.631578947368421, max: 31)





Run: 20, exploration: 0.01, score: 17
Scores: (min: 9, avg: 12.85, max: 31)















Run: 21, exploration: 0.01, score: 53
Scores: (min: 9, avg: 14.761904761904763, max: 53)



Run: 22, exploration: 0.01, score: 9
Scores: (min: 9, avg: 14.5, max: 53)









Run: 23, exploration: 0.01, score: 32
Scores: (min: 9, avg: 15.26086956521739, max: 53)















Run: 24, exploration: 0.01, score: 59
Scores: (min: 9, avg: 17.083333333333332, max: 59)











Run: 25, exploration: 0.01, score: 39
Scores: (min: 9, avg: 17.96, max: 59)























