In [1]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jupyter_dash import JupyterDash
from dash import html, dcc
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import pandas as pd
import io
import base64
import layouts
import utilities
import numpy as np
import datetime
import cv2
import onnx
import onnxruntime
from keras.preprocessing import image
import json
from google.protobuf.json_format import MessageToJson
from google.protobuf.json_format import Parse
from scipy.special import softmax
import dianna
from dianna import visualization
import matplotlib.pyplot as plt
from plotly.tools import mpl_to_plotly
import plotly.graph_objects as go
import os
from onnx_tf.backend import prepare
from flask_caching import Cache

# Stylesheet applied to every page of the dashboard.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

# Directory on the server where uploaded models and test data are written
# (created if missing, reused otherwise).
folder_on_server = "app_data"
os.makedirs(folder_on_server, exist_ok=True)

# Build the Dash app; callbacks are not fired on initial page load.
app = JupyterDash(
    __name__,
    external_stylesheets=external_stylesheets,
    prevent_initial_callbacks=True,
)

In [2]:
# Debug helper: re-import the local helper modules so that edits to
# layouts.py / utilities.py are picked up without restarting the kernel.
from importlib import reload

for _helper_module in (layouts, utilities):
    reload(_helper_module)

<module 'utilities' from '/Users/giuliacrocioni/Desktop/eScience/projects_2022/DIANNA/dianna/dashboard/utilities.py'>

In [3]:
# Colour palette shared by the dashboard components.
colors = dict(
    white='#FFFFFF',
    text='#091D58',
    blue1='#063446',   # dark blue
    blue2='#0e749b',
    blue3='#15b3f0',
    blue4='#E4F3F9',   # light blue
    yellow1='#f0d515',
)


def _navbar_link_style(underline_color, text_shadow=None):
    """Return the inline CSS dict for a navbar link with the given underline colour."""
    style = {
        'text-decoration': 'underline',
        'text-decoration-color': underline_color,
        'color': colors['white'],
    }
    if text_shadow is not None:
        style['text-shadow'] = text_shadow
    style['textAlign'] = 'center'
    return style


# Link style for the page currently shown (yellow underline, subtle glow).
navbarcurrentpage = _navbar_link_style(colors['yellow1'], '0px 0px 1px rgb(251, 251, 252)')

# Link style for every other page (blue underline).
navbarotherpage = _navbar_link_style(colors['blue2'])

In [4]:
# Page layout: header, navigation bar and the upload panel.
# TODO: the method-selection bar, prediction plot and explanation plots
# (layouts.get_method_sel / get_pred / get_method_plot) and the
# dcc.Store components are currently disabled.
app.layout = html.Div(
    children=[
        layouts.get_header(),   # row 1: header with logo
        layouts.get_navbar(),   # row 2: navigation bar
        layouts.get_uploads(),  # row 3: model / data upload widgets
    ]
)

@app.callback(Output('output-model-upload', 'children'),
              Input('upload-model', 'contents'),
              State('upload-model', 'filename'))
def upload_model(contents, filename):
    """Write an uploaded ONNX model into ``folder_on_server`` and summarise it.

    Parameters
    ----------
    contents : list of str or None
        Base64 data-URI strings from the ``upload-model`` component; only the
        first entry is written to disk.
    filename : list of str or None
        Client-side file names matching ``contents``.

    Returns
    -------
    Children for ``output-model-upload``: the per-file summaries produced by
    ``utilities.parse_contents_model`` on success, otherwise an ``html.Div``
    with an error message.

    Raises
    ------
    PreventUpdate
        When nothing has been uploaded.
    """
    if not contents:
        raise PreventUpdate
    try:
        # The file name comes from the client: keep only its last path
        # component so a crafted name cannot write outside folder_on_server.
        safe_name = os.path.basename(filename[0])
        if not safe_name.lower().endswith('.onnx'):
            return html.Div(['File format error, please upload only models in .onnx format.'])

        # A data URI looks like "data:<mime>;base64,<payload>"; split only once.
        _, content_string = contents[0].split(',', 1)
        with open(os.path.join(folder_on_server, safe_name), 'wb') as f:
            f.write(base64.b64decode(content_string))

        return [
            utilities.parse_contents_model(c, n)
            for c, n in zip(contents, filename)
        ]
    except Exception as e:
        print(e)
        return html.Div(['There was an error processing this file.'])

@app.callback(Output('graph_test', 'figure'),
              Input('upload-image', 'contents'),
              State('upload-image', 'filename'))
def upload_image(contents, filename):
    """Write an uploaded ``.npz`` test set to disk and preview one of its images.

    Parameters
    ----------
    contents : list of str or None
        Base64 data-URI strings from the ``upload-image`` component; only the
        first entry is used.
    filename : list of str or None
        Client-side file names matching ``contents``.

    Returns
    -------
    plotly.graph_objects.Figure or None
        A 300x300 grayscale heatmap of test sample 3 from the archive's
        ``X_test`` array (reshaped to 28x28x1 and scaled by 1/255), or ``None``
        when the file is not an ``.npz`` archive or cannot be processed.

    Raises
    ------
    PreventUpdate
        When nothing has been uploaded.
    """
    if not contents:
        raise PreventUpdate
    try:
        # The file name comes from the client: keep only its last path
        # component so a crafted name cannot write outside folder_on_server.
        safe_name = os.path.basename(filename[0])
        if not safe_name.lower().endswith('.npz'):
            # TODO: tell the user to upload only images in .npz format.
            return None

        # A data URI looks like "data:<mime>;base64,<payload>"; split only once.
        _, content_string = contents[0].split(',', 1)
        data_path = os.path.join(folder_on_server, safe_name)
        with open(data_path, 'wb') as f:
            f.write(base64.b64decode(content_string))

        # Close the archive as soon as the array is copied out (astype copies).
        with np.load(data_path) as data:
            x_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 255

        # Sample 3 is also the one the explanation callback runs on.
        sample = x_test[3]

        fig = go.Figure()
        fig.add_trace(go.Heatmap(z=sample[:, :, 0], colorscale='gray', showscale=False))
        fig.update_layout(width=300, height=300, paper_bgcolor=colors['blue4'])
        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False)
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False)
        return fig
    except Exception as e:
        # Report instead of silently swallowing; the graph is cleared.
        print(e)
        return None

# def run_model(data, model):
#     # get ONNX predictions
#     sess = onnxruntime.InferenceSession(model.SerializeToString())
#     input_name = sess.get_inputs()[0].name
#     output_name = sess.get_outputs()[0].name
    
#     onnx_input = {input_name: data}
#     pred_onnx = sess.run([output_name], onnx_input)
    
#     return softmax(pred_onnx[0], axis=1)

def _determine_vmax(max_data_value):
    vmax = 1
    if max_data_value > 255:
        vmax = None
    elif max_data_value > 1:
        vmax = 255
    return vmax

def _overlay_heatmaps(fig, background, relevance, row, col):
    """Add a grayscale ``background`` and a semi-transparent ``relevance`` map to one subplot.

    Both traces go into subplot cell (``row``, ``col``) of ``fig`` (1-based, as
    in ``make_subplots``); the background is added first so the relevance map
    is drawn on top of it.
    """
    fig.add_trace(go.Heatmap(z=background, colorscale='gray', showscale=False), row, col)
    fig.add_trace(go.Heatmap(z=relevance, colorscale='Bluered',
                             showscale=False, opacity=0.7), row, col)


@app.callback(
    Output('output-state', 'children'),
    Output('graph', 'figure'),
    Input("method_sel", "value"),
    State("upload-model", "filename"),
    State("upload-image", "filename"),
)
def update_multi_options(sel_methods, fn_m, fn_i):
    """Predict the class of test sample 3 and plot the selected explanations.

    Parameters
    ----------
    sel_methods : list of str or None
        Values of the ``method_sel`` component. "RISE" and "KernelSHAP" are
        matched explicitly; any other value is treated as LIME.
    fn_m : list of str or None
        File name(s) from ``upload-model``; the first one is loaded from
        ``folder_on_server`` as an ONNX model.
    fn_i : list of str or None
        File name(s) from ``upload-image``; the first one is loaded from
        ``folder_on_server`` as an ``.npz`` archive containing ``X_test``.

    Returns
    -------
    tuple
        (``html.Div`` with the predicted class or an error message,
        a plotly figure with one row per class and one column per method —
        or ``None`` when anything fails).

    Raises
    ------
    PreventUpdate
        When no method selection has been made yet.
    """
    if sel_methods is None:
        raise PreventUpdate

    # Both a model and a data file must have been uploaded.
    if not fn_m or not fn_i:
        return html.Div(['Either model format is not correct or no model/data file was uploaded. '
                         'Please upload models in .onnx format and data in .npz format.']), None

    try:
        class_name = ['digit 0', 'digit 1']
        n_rows = len(class_name)
        labels = list(range(n_rows))
        sample_index = 3  # same sample the image-upload preview shows
        axis_labels = ('height', 'width', 'channels')

        # Strip any client-supplied directories from the uploaded file names.
        data_path = os.path.join(folder_on_server, os.path.basename(fn_i[0]))
        onnx_model_path = os.path.join(folder_on_server, os.path.basename(fn_m[0]))

        # Close the .npz archive as soon as the array is copied out (astype copies).
        with np.load(data_path) as data:
            X_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 255
        sample = X_test[sample_index]
        background = sample[:, :, 0]

        # Convert the ONNX model to TensorFlow once and reuse it for both the
        # output-node name and the prediction.
        tf_rep = prepare(onnx.load(onnx_model_path))
        output_node = tf_rep.outputs[0]
        predictions = tf_rep.run(sample[None, ...])[f'{output_node}']
        pred_class = class_name[np.argmax(predictions)]
        print("The predicted class is:", pred_class)

        fig = make_subplots(rows=n_rows, cols=3, subplot_titles=("RISE", "KernelShap", "LIME"))
        for i in range(n_rows):
            fig.update_yaxes(title_text=class_name[i], row=i + 1, col=1)

        # The explainer arguments do not depend on the class row, so each
        # method is run once and its per-label output is reused for every row
        # (previously every explainer was re-run for each row).
        for m in sel_methods:
            if m == "RISE":
                relevances_rise = dianna.explain_image(
                    onnx_model_path, sample, method="RISE",
                    labels=labels,
                    n_masks=5000, feature_res=8, p_keep=.1,
                    axis_labels=axis_labels)
                for i in range(n_rows):
                    _overlay_heatmaps(fig, background, relevances_rise[i], i + 1, 1)

            elif m == "KernelSHAP":
                shap_values, segments_slic = dianna.explain_image(
                    onnx_model_path, sample,
                    method="KernelSHAP", nsamples=1000,
                    background=0, n_segments=200, sigma=0,
                    axis_labels=axis_labels)
                for i in range(n_rows):
                    _overlay_heatmaps(
                        fig, background,
                        utilities.fill_segmentation(shap_values[i][0], segments_slic),
                        i + 1, 2)

            else:
                # Any other selection value is treated as LIME.
                relevances_lime = dianna.explain_image(
                    onnx_model_path, sample * 256, 'LIME',
                    axis_labels=axis_labels,
                    random_state=2,
                    labels=labels,
                    preprocess_function=utilities.preprocess_function)
                for i in range(n_rows):
                    _overlay_heatmaps(fig, background, relevances_lime[i], i + 1, 3)

        fig.update_layout(
            width=650,
            height=500,
            paper_bgcolor=colors['blue4'])

        return html.Div(['The predicted class is: ' + pred_class], style={
            'fontSize': 14,
            'margin-top': '20px',
            'margin-right': '40px'
            }), fig

    except Exception as e:
        print(e)
        return html.Div(['There was an error running the model. Check either the test image or the model.']), None

# Start the dashboard server on port 8050; in 'external' mode the URL is
# printed in the cell output instead of rendering the app inline.
app.run_server(port=8050, mode='external')

Dash app running on http://127.0.0.1:8050/
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



2022-04-01 11:30:22.434940: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-04-01 11:30:22.435080: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2022-04-01 11:30:24.792381: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-04-01 11:30:24.792437: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


The predicted class is: digit 0


Explaining: 100%|██████████| 50/50 [00:00<00:00, 59.79it/s]
Explaining: 100%|██████████| 50/50 [00:00<00:00, 62.16it/s]
