In [1]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jupyter_dash import JupyterDash
from dash import html, dcc
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import pandas as pd
import io
import base64
import layouts
import numpy as np
import datetime
import cv2
import onnx
import onnxruntime
from keras.preprocessing import image
import json
from google.protobuf.json_format import MessageToJson
from google.protobuf.json_format import Parse
from scipy.special import softmax
import dianna
from dianna import visualization
import matplotlib.pyplot as plt
from plotly.tools import mpl_to_plotly
import plotly.graph_objects as go

# Build App
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = JupyterDash(__name__, external_stylesheets=external_stylesheets, prevent_initial_callbacks=True)

In [2]:
# debug
from importlib import reload
reload(layouts)

<module 'layouts' from '/Users/giuliacrocioni/Desktop/eScience/projects_2022/DIANNA/dianna/dashboard/layouts.py'>

In [3]:
colors = {
    'white': '#FFFFFF',
    'text': '#091D58',
    'blue1' : '#063446',
    'blue2' : '#0e749b',
    'blue3' : '#15b3f0',
    'blue4' : '#d0f0fc',
    'yellow1' : '#f0d515'
}

navbarcurrentpage = {
    'text-decoration' : 'underline',
    'text-decoration-color' : colors['yellow1'],
    'color' : colors['white'],
    'text-shadow': '0px 0px 1px rgb(251, 251, 252)',
    'textAlign' : 'center'
    }

navbarotherpage = {
    'text-decoration' : 'underline',
    'text-decoration-color' : colors['blue2'],
    'color' : colors['white'],
    'textAlign' : 'center'
    }

In [4]:
def blank_fig():
    fig = go.Figure(go.Scatter(x=[], y = []))
    fig.update_layout(template = None, paper_bgcolor=colors['blue4'])
    fig.update_xaxes(showgrid = False, showticklabels = False, zeroline=False)
    fig.update_yaxes(showgrid = False, showticklabels = False, zeroline=False)
    
    return fig

In [5]:
#rise = base64.b64encode(open('rise.png', 'rb').read())
#kernels = base64.b64encode(open('kernels.png', 'rb').read())

In [6]:
app.layout = html.Div([ # header with logo
    
    #Row 1 : Header
    layouts.get_header(),
    #Row 2 : Nav bar
    layouts.get_navbar(),

    layouts.get_uploads(), 

    #dcc.Store(id='image_to_run'),
    dcc.Store(id='model_to_run'), 

    html.Div([
        html.Div(['b'],
            className = 'five columns',
            style = {'color' : colors['blue4']}),
        html.Div([
            dcc.Dropdown(id = 'method_sel',
                options = [{'label': 'RISE', 'value': 'RISE'},
                           {'label': 'KernelSHAP', 'value': 'KernelSHAP'},
                           {'label': 'LIME', 'value': 'LIME'}],
                placeholder = "Select one/more methods",
                value=[],
                multi = True
            ),
            #html.Button(
            #    id='submit-val',
            #    children = html.Div(['Get explanation']), n_clicks=0,
            #    style={
            #        'width': '100%',
            #        #'height': '40px',
            #        #'lineHeight': '40px',
            #        'borderWidth': '1px',
            #        #'borderStyle': 'dashed',
            #        'borderRadius': '3px',
            #        'textAlign': 'center',
            #        'align-items': 'center',
            #        'margin': '10px',
            #        'color' : colors['white'],
            #        'background-color' : colors['blue2']}
            #    )
            ], 
            className = 'two columns')
    ], className = 'row', style = {
        'textAlign': 'center',
        'background-color' : colors['blue4'],
        'align-items': 'center'}),

    html.Div([
        html.Div(['b'],
            className = 'five columns',
            style = {'color' : colors['blue4']}),
        html.Div(id='output-state', className = 'two columns')
        ], className = 'row', style = {
        'textAlign': 'center',
        'background-color' : colors['blue4'],
        'align-items': 'center'}),

    html.Div([
        html.Div(['b'],
            className = 'five columns',
            style = {'color' : colors['blue4']}),
        html.Div([dcc.Graph(id='rise', figure = blank_fig())], className = 'two columns'),
        ],
        className="row",
        style={
            'background-color' : colors['blue4'],
            'textAlign': 'center',
            'align-items': 'center',
            'verical-align': 'center'}),

    # html.Div([

    #     html.Div(['b'],
    #         className = 'three columns',
    #         style = {'color' : colors['blue4']}),

    #     html.Div([    
    #         html.Img(
    #                 src = 'data:image/png;base64,{}'.format(rise.decode()),
    #                 height = '250 px',
    #                 width = 'auto')
    #     ],
    #     className = 'three columns',
    #     style = {
    #             'textAlign': 'center',
    #             'padding-top' : '1.3%',
    #             'padding-right' : '4%',
    #             'height' : 'auto'
    #             }),
    # html.Div([
    #     html.Img(
    #             src = 'data:image/png;base64,{}'.format(kernels.decode()),
    #             height = '250 px',
    #             width = 'auto')
    #     ],
    #     className = 'two columns',
    #     style = {
    #             'textAlign': 'center',
    #             'padding-top' : '1.3%',
    #             'padding-right' : '4%',
    #             'height' : 'auto'
    #             })

    # ],
    # className = 'row', style = {'background-color' : colors['blue4']}
    # )
    ])

def parse_contents_image(contents, filename):
    return html.Div([
        html.H5(filename + ' loaded'),
        # HTML images accept base64 encoded strings in the same format
        # that is supplied by the upload
        html.Img(src=contents, height = '160 px', width = 'auto')
    ])

def parse_contents_model(contents, filename):
    return html.Div([
        html.H5(filename + ' loaded')
    ])

@app.callback(Output('output-model-upload', 'children'),
              Output('model_to_run', 'data'),
              Input('upload-model', 'contents'),
              State('upload-model', 'filename'))
def upload_model(contents, filename):
    if contents is not None:
        try:
            if 'onnx' in filename[0]:

                content_type, content_string = contents[0].split(',')
                decoded = base64.b64decode(content_string)
                model_bytes = io.BytesIO(decoded)
                model = onnx.load(model_bytes)
                s = MessageToJson(model)
                onnx_json = json.loads(s)
                onnx_str = json.dumps(onnx_json)

                children = [
                    parse_contents_model(c, n) for c, n in
                    zip(contents, filename)]

                return children, onnx_str
            else:
                return html.Div(['File format error, please upload only models in .onnx format.']), None
        except Exception as e:
            print(e)
            return html.Div(['There was an error processing this file.']), None
    else:
        raise PreventUpdate

@app.callback(Output('output-image-upload', 'children'),
              Input('upload-image', 'contents'),
              State('upload-image', 'filename'))
def upload_image(contents, filename):
    if contents is not None:
        print(filename)
        try:
            if 'jpg' in filename[0]:

                print(type(contents))
                print(len(contents))

                content_type, content_string = contents[0].split(',')
                print(content_type)
                binary = base64.b64decode(content_string)
                image = np.asarray(bytearray(binary), dtype="uint8")
                #image = cv2.imdecode(image, cv2.IMREAD_COLOR)
                image = cv2.imdecode(image, cv2.IMREAD_GRAYSCALE)
                print(type(image))
                print(image.shape)

                children = [
                    parse_contents_image(c, n) for c, n in
                    zip(contents, filename)]

                return children
            else:
                return html.Div(['File format error, please upload only images in .jpg format.'])
        except Exception:
            return html.Div(['There was an error processing this file.'])

def run_model(data, model):
    # get ONNX predictions
    sess = onnxruntime.InferenceSession(model.SerializeToString())
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    
    onnx_input = {input_name: data}
    pred_onnx = sess.run([output_name], onnx_input)
    
    return softmax(pred_onnx[0], axis=1)

def _determine_vmax(max_data_value):
    vmax = 1
    if max_data_value > 255:
        vmax = None
    elif max_data_value > 1:
        vmax = 255
    return vmax

###########

@app.callback(
    Output('output-state', 'children'),
    Output('rise', 'figure'),
    Input("method_sel", "value"),
    State('model_to_run', 'data')
)
def update_multi_options(sel_methods, onnx_str):
    if not sel_methods:
        raise PreventUpdate
    # Make sure that the set values are in the option list, else they will disappear
    # from the shown select list, but still part of the `value`.

    else:

        if onnx_str is not None:

            try:
                data = np.load('../tutorials/data/binary-mnist.npz')
                X_test = data['X_test'].astype(np.float32).reshape([-1, 1, 28, 28])

                model = Parse(onnx_str, onnx.ModelProto())

                pred_onnx = run_model(X_test, model)

                class_name = ['digit 0', 'digit 1']
                pred_class = class_name[np.argmax(pred_onnx[3])]

                print("The predicted class is:", pred_class)

                relevances = dianna.explain_image(lambda x: run_model(x, model), X_test[3], method="RISE",
                                                labels=[i for i in range(2)],
                                                n_masks=5000, feature_res=8, p_keep=.1,
                                                axis_labels=('channels','height','width'))

                fig = make_subplots(rows=2, cols=1)

                fig.add_trace(go.Heatmap(
                                    z=X_test[3][0], colorscale='gray', showscale=False), 1, 1)

                fig.add_trace(go.Heatmap(
                                    z=relevances[0], colorscale='Bluered', showscale=False, opacity=0.5), 1, 1)

                fig.add_trace(go.Heatmap(
                                    z=X_test[3][0], colorscale='gray', showscale=False), 2, 1)

                fig.add_trace(go.Heatmap(
                                    z=relevances[1], colorscale='Bluered', showscale=False, opacity=0.5), 2, 1)

                fig.update_yaxes(title_text=class_name[0], row=1, col=1)
                fig.update_yaxes(title_text=class_name[1], row=2, col=1)

                fig.update_layout(width=300, height=500, title_text="RISE", title_x=0.5, paper_bgcolor=colors['blue4'])

                return html.Div(['The predicted class is: ' + pred_class], style = {'fontSize': 14}), fig

            except Exception as e:
                print(e)
                return html.Div(['There was an error running the model. Check either the test image or the model.']), None
        else:
            return html.Div(['Either model format is not correct or no model was uploaded. Please upload models in .onnx format.']), None

###########

#@app.callback(Output('output-state', 'children'),
#              Output('rise', 'figure'),
#              Input('submit-val', 'n_clicks'),
#              State('model_to_run', 'data'))
# def update_output(n_clicks, onnx_str):

#     if n_clicks > 0 and onnx_str is not None:

#         try:
#             data = np.load('./data/binary-mnist.npz')
#             X_test = data['X_test'].astype(np.float32).reshape([-1, 1, 28, 28])

#             model = Parse(onnx_str, onnx.ModelProto())

#             pred_onnx = run_model(X_test, model)

#             class_name = ['digit 0', 'digit 1']
#             pred_class = class_name[np.argmax(pred_onnx[3])]

#             print("The predicted class is:", pred_class)

#             relevances = dianna.explain_image(lambda x: run_model(x, model), X_test[3], method="RISE",
#                                             labels=[i for i in range(2)],
#                                             n_masks=5000, feature_res=8, p_keep=.1,
#                                             axis_labels=('channels','height','width'))

#             fig = make_subplots(rows=2, cols=1)

#             fig.add_trace(go.Heatmap(
#                                 z=X_test[3][0], colorscale='gray', showscale=False), 1, 1)

#             fig.add_trace(go.Heatmap(
#                                 z=relevances[0], colorscale='Bluered', showscale=False, opacity=0.5), 1, 1)

#             fig.add_trace(go.Heatmap(
#                                 z=X_test[3][0], colorscale='gray', showscale=False), 2, 1)

#             fig.add_trace(go.Heatmap(
#                                 z=relevances[1], colorscale='Bluered', showscale=False, opacity=0.5), 2, 1)

#             fig.update_yaxes(title_text=class_name[0], row=1, col=1)
#             fig.update_yaxes(title_text=class_name[1], row=2, col=1)

#             fig.update_layout(width=300, height=500, title_text="RISE", title_x=0.5)

#             return html.Div([html.H5('The predicted class is: ' + pred_class + '\nExplainations:')]), fig

#         except Exception as e:
#             print(e)
#             return html.Div(['There was an error running the model. Check either the test image or the model.']), None
    
#     elif n_clicks > 0 and onnx_str is None:
#         return html.Div(['Either model format is not correct or no model was uploaded. Please upload models in .onnx format.']), None
    
#     else:
#         return None, None


app.run_server(mode='external', port=8051)

Dash app running on http://127.0.0.1:8051/
The predicted class is: digit 0


Explaining: 100%|██████████| 50/50 [00:00<00:00, 147.04it/s]


['digit0.jpg']
<class 'list'>
1
data:image/jpeg;base64
<class 'numpy.ndarray'>
(28, 28)
