In [50]:
import joblib
from dash import Dash, dcc, html, Input, Output, State, callback
import dash
import shap
import matplotlib.pyplot as plt
import base64
import io
import pandas as pd
import plotly.graph_objects as go


In [18]:
# Saved SHAP explanations, one file per model, keyed by the name shown in the
# model dropdown. joblib.load unpickles these files, so only point it at
# artifacts you produced yourself.
_SHAP_VALUE_FILES = {
    'XGBoost': "./saved_values/xgboost-shapley_values",
    'Random Forest': "./saved_values/randomforest-shapley_values",
}
mapping_dict = {name: joblib.load(path) for name, path in _SHAP_VALUE_FILES.items()}


def model_name_to_explanation(value):
    """Return the pre-loaded SHAP explanation for model name `value`.

    Returns None when `value` is not one of the loaded model names
    (including when nothing is selected, i.e. `value` is None).
    """
    if value in mapping_dict:
        return mapping_dict[value]
    return None

In [19]:
def string_target_to_integer_target(value):
    """Map a target label from the target dropdown to its integer code.

    Codes follow the label order: 'very_high' -> 0, 'high' -> 1,
    'moderate' -> 2, 'low' -> 3, 'very_low' -> 4. Any other value,
    including None (nothing selected), yields None.
    """
    ordered_labels = ('very_high', 'high', 'moderate', 'low', 'very_low')
    label_to_code = {label: code for code, label in enumerate(ordered_labels)}
    return label_to_code.get(value)

In [20]:
# Model choices for the dropdown, derived from the loaded explanations so the
# dropdown and the lookup table cannot drift apart (insertion order is kept:
# 'XGBoost', 'Random Forest').
model_list = list(mapping_dict)

In [53]:
app = Dash(__name__)

# Input controls, top to bottom: model picker, target picker, the three
# slice values (start / end / step) and the button that triggers the plot.
_controls = [
    dcc.Dropdown(options=model_list, placeholder='Select Model', id='dropdown_model'),
    dcc.Dropdown(
        options=['very_high', 'high', 'moderate', 'low', 'very_low'],
        placeholder='Select Target',
        id='dropdown_target',
    ),
    dcc.Input(id='input_start', type='number', placeholder='Enter Start'),
    dcc.Input(id='input_end', type='number', placeholder='Enter End'),
    dcc.Input(id='input_steps', type='number', placeholder='Enter Steps'),
    html.Br(),
    html.Button(children='Generate Plot', id='btn_generate_plot'),
]

# Page: title, controls, then the container the callback fills with the plot.
app.layout = html.Div(
    children=[
        html.H1(children="Shapley Values Dashboard"),
        *_controls,
        html.Div(id='dd-output-container'),
    ]
)


def _parse_slice(start, end, steps):
    """Build the slice applied to the explanation from the three number inputs.

    A None value (empty input) is passed through unchanged, exactly as in
    ``start:end:steps`` slice syntax. Whole-number floats are converted to int.

    Raises
    ------
    ValueError
        If any value is not a whole number, or if ``steps`` is zero.
    """
    bounds = []
    for label, value in (('Start', start), ('End', end), ('Steps', steps)):
        if value is None:
            bounds.append(None)
        elif float(value).is_integer():
            bounds.append(int(value))
        else:
            raise ValueError(f"{label} must be a whole number, got {value!r}")
    if bounds[2] == 0:
        raise ValueError("Steps must not be zero")
    return slice(*bounds)


def _figure_to_svg_data_uri(fig):
    """Serialise ``fig`` as SVG and wrap it in a base64 ``data:`` URI for ``html.Img``."""
    with io.BytesIO() as img_buf:
        fig.savefig(img_buf, format='svg')
        encoded = base64.b64encode(img_buf.getvalue()).decode('utf-8')
    return "data:image/svg+xml;base64," + encoded


@callback(
    Output('dd-output-container', 'children'),
    # State (not Input): these values are only read when the button fires,
    # so editing them alone does not re-run the callback.
    State('dropdown_model', 'value'),
    State('dropdown_target', 'value'),
    State('input_start', 'value'),
    State('input_end', 'value'),
    State('input_steps', 'value'),
    Input('btn_generate_plot', 'n_clicks')
)
def generate_shapley_plots(model, target, start, end, steps, n_clicks):
    """Render a SHAP beeswarm plot for the selected model and target as an inline SVG.

    Parameters
    ----------
    model : str or None
        Model name from the model dropdown; resolved via ``model_name_to_explanation``.
    target : str or None
        Target label from the target dropdown; resolved via
        ``string_target_to_integer_target``.
    start, end, steps : number or None
        Slice applied to the second axis of the explanation
        (``explanation[:, start:end:steps, target]``). None means "no bound",
        as in ordinary slice syntax.
        NOTE(review): presumably the explanation is (samples, features, outputs),
        which would make this a feature slice — confirm.
    n_clicks : int or None
        Button click count; None on the initial call before any click.

    Returns
    -------
    ``dash.no_update`` before the first click; None (clears the output) when the
    model or target is missing or unknown; an ``html.Div`` with an error message
    for an invalid slice; otherwise an ``html.Img`` holding the plot.
    """
    if n_clicks is None:
        return dash.no_update  # Initial call, button not clicked yet: leave output as is

    explanation = model_name_to_explanation(model)
    target = string_target_to_integer_target(target)
    if explanation is None or target is None:
        return None  # Nothing (valid) selected: clear the output

    try:
        index_slice = _parse_slice(start, end, steps)
    except ValueError as err:
        # Report bad input to the user instead of crashing the callback.
        return html.Div(str(err))

    # beeswarm(show=False) draws into pyplot's current figure, so start a fresh one.
    # NOTE(review): pyplot keeps global state and is not thread-safe; if the server
    # handles concurrent requests, consider a lock and a non-interactive backend (Agg).
    fig, _ = plt.subplots()
    try:
        shap.plots.beeswarm(explanation[:, index_slice, target], show=False)
        img_src = _figure_to_svg_data_uri(fig)
    finally:
        plt.close(fig)  # Always release the figure, even if slicing or plotting raised
    return html.Img(src=img_src)


def _serve_dashboard():
    """Start the Dash server for this app with debug mode enabled."""
    app.run(debug=True)


if __name__ == '__main__':
    _serve_dashboard()