In [1]:
# Plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Dash&Flask
from jupyter_dash import JupyterDash
from dash import html, dcc
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
from flask_caching import Cache
# Onnx
import onnx
import onnxruntime
from onnx_tf.backend import prepare
# Others
import os
import base64
import layouts
import utilities
import numpy as np
import dianna
import warnings
warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf

In [2]:
# Server-side folder where uploaded models and images are written.
folder_on_server = "app_data"
os.makedirs(folder_on_server, exist_ok=True)

# Dash app embedded in the notebook kernel. prevent_initial_callbacks=True keeps
# callbacks from firing on the initial page load.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = JupyterDash(
    __name__,
    external_stylesheets=external_stylesheets,
    prevent_initial_callbacks=True,
)

# Filesystem-backed cache (directory ./cache) for the expensive explanation
# results; it is wiped every time this cell runs so earlier sessions' results
# are never reused.
cache_config = {
    'CACHE_TYPE': 'filesystem',
    'CACHE_DIR': 'cache',
}
cache = Cache(app.server, config=cache_config)
cache.clear()

# Class labels used when the model has exactly two outputs.
class_name_mnist = ['digit 0', 'digit 1']

In [3]:
# Debug only: re-import the local helper modules so edits to layouts.py and
# utilities.py take effect without restarting the kernel.
from importlib import reload

layouts = reload(layouts)
utilities = reload(utilities)
utilities

<module 'utilities' from '/Users/giuliacrocioni/Desktop/eScience/projects_2022/DIANNA/dianna/dashboard/utilities.py'>

In [4]:
# Page layout: header, navigation bar and upload widgets, plus a hidden
# dcc.Store ('signal') that compute_value writes once explanations are cached.
app.layout = html.Div(
    children=[
        layouts.get_header(),    # row 1: header
        layouts.get_navbar(),    # row 2: nav bar
        layouts.get_uploads(),   # upload widgets (presumably 'upload-model' / 'upload-image')
        dcc.Store(id='signal'),  # hidden signal value
    ]
)

@app.callback(Output('output-model-upload', 'children'),
              Input('upload-model', 'contents'),
              State('upload-model', 'filename'))
def upload_model(contents, filename):
    """Save an uploaded ONNX model into ``folder_on_server`` and summarize it.

    Parameters
    ----------
    contents : list of str or None
        Encoded file contents from the ``upload-model`` component, one entry
        per uploaded file (``"<header>,<base64 payload>"``).
    filename : list of str or None
        Client-side file names, parallel to ``contents``.

    Returns
    -------
    Children for ``output-model-upload``: one summary per uploaded file (built
    by ``utilities.parse_contents_model``) or an error message.

    Raises
    ------
    PreventUpdate
        When nothing has been uploaded.
    """
    if not contents or not filename:
        raise PreventUpdate

    # The name comes from the client: keep only its base name so the upload can
    # never be written outside folder_on_server (e.g. "../../evil.onnx").
    model_name = os.path.basename(filename[0])

    # Check the real extension rather than a substring ("x.onnx.zip" is not a model).
    if not model_name.lower().endswith('.onnx'):
        return html.Div([
            html.P('File format error!'),
            html.Br(),
            html.P('Please upload only models in .onnx format.')
            ])

    try:
        # Only the base64 payload after the first comma is the file content.
        _, content_string = contents[0].split(',', 1)

        with open(os.path.join(folder_on_server, model_name), 'wb') as f:
            f.write(base64.b64decode(content_string))

        return [utilities.parse_contents_model(c, n)
                for c, n in zip(contents, filename)]
    except Exception as e:
        print(e)
        return html.Div(['There was an error processing this file.'])

@app.callback(Output('graph_test', 'figure'),
              Input('upload-image', 'contents'),
              State('upload-image', 'filename'))
def upload_image(contents, filename):
    """Save an uploaded JPEG image into ``folder_on_server`` and preview it.

    Parameters
    ----------
    contents : list of str or None
        Encoded file contents from the ``upload-image`` component
        (``"<header>,<base64 payload>"``); only the first file is used.
    filename : list of str or None
        Client-side file names, parallel to ``contents``.

    Returns
    -------
    A Plotly figure: the uploaded image, or ``utilities.blank_fig`` carrying an
    error text when the format is wrong or processing fails.

    Raises
    ------
    PreventUpdate
        When nothing has been uploaded (same convention as ``upload_model``).
    """
    if not contents or not filename:
        raise PreventUpdate

    # Client-supplied name: strip any directory part so the file stays inside
    # folder_on_server.
    image_name = os.path.basename(filename[0])

    # Check the actual extension instead of a 'jpg' substring anywhere in the name.
    if not image_name.lower().endswith(('.jpg', '.jpeg')):
        return utilities.blank_fig(
            text='File format error! <br><br>Please upload only images in .jpg format.')

    try:
        _, content_string = contents[0].split(',', 1)

        data_path = os.path.join(folder_on_server, image_name)
        with open(data_path, 'wb') as f:
            f.write(base64.b64decode(content_string))

        image = utilities.open_image(data_path)

        fig = go.Figure()

        if image.shape[2] < 3:  # single channel -> draw as grayscale heatmap
            fig.add_trace(go.Heatmap(z=image[:, :, 0], colorscale='gray', showscale=False))
        # NOTE(review): images with 3+ channels get no trace, so only the title and
        # background are shown. TODO: render RGB (e.g. go.Image) once the value
        # range returned by utilities.open_image is confirmed.

        fig.update_layout(
            width=300,
            height=300,
            title=f"{filename[0]} uploaded",
            title_x=0.5,
            title_font_color=layouts.colors['blue1'])

        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False)
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False)

        fig.layout.paper_bgcolor = layouts.colors['blue4']

        return fig

    except Exception as e:
        print(e)
        return utilities.blank_fig(
                text='There was an error processing this file.')

# Expensive explanation computations are memoized with flask_caching. The cache
# set up in the configuration cell is filesystem-backed (CACHE_DIR='cache'), so
# results live on disk until cache.clear() runs when that cell is re-executed.
@cache.memoize()
def global_store(method_sel, model_path, image_test):
    """Compute (and cache) DIANNA relevances for one explanation method.

    Parameters
    ----------
    method_sel : str
        "RISE", "KernelSHAP"; any other value is explained with LIME.
    model_path : str
        Path of the ONNX model file on the server.
    image_test : numpy.ndarray
        Image passed to ``dianna.explain_image`` with axes
        ('height', 'width', 'channels').

    Returns
    -------
    Whatever ``dianna.explain_image`` returns for the method. For RISE and LIME
    the caller indexes it per class; for KernelSHAP the caller unpacks it as
    ``(shap_values, segments_slic)``.

    NOTE(review): the memoize key is derived from the arguments, so re-uploading
    a *different* model under the same file name can return stale results, and
    very large arrays may be summarized in their repr. Confirm against the
    Flask-Caching key generation before relying on the cache across uploads.
    """
    # Explain both classes; this dashboard assumes a two-class model.
    labels = list(range(2))
    axis_labels = ('height', 'width', 'channels')

    if method_sel == "RISE":
        return dianna.explain_image(
            model_path, image_test, method=method_sel,
            labels=labels,
            n_masks=5000, feature_res=8, p_keep=.1,
            axis_labels=axis_labels)

    if method_sel == "KernelSHAP":
        return dianna.explain_image(
            model_path, image_test,
            method=method_sel, nsamples=1000,
            background=0, n_segments=200, sigma=0,
            axis_labels=axis_labels)

    # Fallback: LIME. NOTE(review): the input is scaled by 256 (not 255) —
    # confirm this matches utilities.preprocess_function.
    return dianna.explain_image(
        model_path, image_test * 256, 'LIME',
        axis_labels=axis_labels,
        random_state=2,
        labels=labels,
        preprocess_function=utilities.preprocess_function)

@app.callback(
    Output('signal', 'data'),
    [Input('method_sel', 'value'),
    State("upload-model", "filename"),
    State("upload-image", "filename"),
    ])
def compute_value(method_sel, fn_m, fn_i):
    """Pre-compute the cached explanations for every selected method.

    Parameters
    ----------
    method_sel : list of str or None
        Selected explanation methods.
    fn_m, fn_i : list of str or None
        File names of the uploaded model and image.

    Returns
    -------
    ``method_sel`` unchanged; writing it to the ``signal`` store triggers
    ``update_multi_options`` once the explanations are cached.

    Raises
    ------
    PreventUpdate
        When no method selection exists yet.
    """
    if method_sel is None:
        raise PreventUpdate

    # Nothing to pre-compute without a selection or without both uploads. Still
    # emit the signal so the downstream callback can report what is missing
    # (previously fn_i[0] raised TypeError here).
    if not method_sel or not fn_m or not fn_i:
        return method_sel

    # Model and image are the same for every method: resolve them once.
    image_test = utilities.open_image(
        os.path.join(folder_on_server, os.path.basename(fn_i[0])))
    model_path = os.path.join(folder_on_server, os.path.basename(fn_m[0]))

    for m in method_sel:
        global_store(m, model_path, image_test)
    return method_sel

def _add_overlay(fig, image, relevance, row, col):
    """Draw ``image``'s first channel in grayscale at (row, col) with a semi-transparent relevance map on top."""
    fig.add_trace(go.Heatmap(
        z=image[:, :, 0], colorscale='gray', showscale=False), row=row, col=col)
    fig.add_trace(go.Heatmap(
        z=relevance, colorscale='Bluered',
        showscale=False, opacity=0.7), row=row, col=col)


@app.callback(
    Output('output-state', 'children'),
    Output('graph', 'figure'),
    Input("signal", "data"),
    State("upload-model", "filename"),
    State("upload-image", "filename"),
)
def update_multi_options(sel_methods, fn_m, fn_i):
    """Run the model on the uploaded image and plot the cached explanations.

    Parameters
    ----------
    sel_methods : list of str or None
        Methods forwarded through the ``signal`` store by ``compute_value``.
    fn_m, fn_i : list of str or None
        File names of the uploaded model and image.

    Returns
    -------
    (children, figure): a message with the predicted class and a grid with one
    row per class and one column per method (RISE, KernelShap, LIME), or an
    error message with ``utilities.blank_fig``.

    Raises
    ------
    PreventUpdate
        When no signal has been emitted yet.
    """
    if sel_methods is None:
        raise PreventUpdate

    # Both uploads are required. The former `(fn_m and fn_i) is not None` test
    # let an empty filename list through to an IndexError.
    if not fn_m or not fn_i:
        return (html.Div(['Missing either model or image.']),
                utilities.blank_fig(text='Missing either model or image.'))

    try:
        data_path = os.path.join(folder_on_server, os.path.basename(fn_i[0]))
        X_test = utilities.open_image(data_path)

        onnx_model_path = os.path.join(folder_on_server, os.path.basename(fn_m[0]))
        onnx_model = onnx.load(onnx_model_path)
        # Convert the ONNX graph once and reuse it for both the output-node name
        # and inference (it used to be converted twice).
        tf_rep = prepare(onnx_model, gen_tensor_dict=True)
        output_node = tf_rep.outputs[0]
        predictions = tf_rep.run(X_test[None, ...])[f'{output_node}']

        n_outputs = len(predictions[0])
        # global_store explains exactly two labels; other models previously hit
        # an undefined `class_name` (NameError) reported as a generic error.
        if n_outputs != len(class_name_mnist):
            message = (f'Unsupported model: expected {len(class_name_mnist)} '
                       f'output classes, got {n_outputs}.')
            return html.Div([message]), utilities.blank_fig(text=message)

        class_name = list(class_name_mnist)
        pred_class = class_name[int(np.argmax(predictions[0]))]
        n_rows = len(class_name)

        fig = make_subplots(rows=n_rows, cols=3, subplot_titles=("RISE", "KernelShap", "LIME"))

        for row, label in enumerate(class_name, start=1):
            fig.update_yaxes(title_text=label, row=row, col=1)

        for m in sel_methods:
            # One cached lookup per method instead of one per row.
            explanation = global_store(m, onnx_model_path, X_test)

            if m == "RISE":
                col, maps = 1, explanation
            elif m == "KernelSHAP":
                shap_values, segments_slic = explanation
                col = 2
                maps = [utilities.fill_segmentation(shap_values[i][0], segments_slic)
                        for i in range(n_rows)]
            else:  # global_store explains any other method with LIME
                col, maps = 3, explanation

            for i in range(n_rows):
                _add_overlay(fig, X_test, maps[i], row=i + 1, col=col)

        fig.update_layout(
            width=650,
            height=500,
            paper_bgcolor=layouts.colors['blue4'])

        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False)
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False)

        return html.Div(['The predicted class is: ' + pred_class], style={
            'fontSize': 14,
            'margin-top': '20px',
            'margin-right': '40px'
            }), fig

    except Exception as e:
        print(e)
        return (html.Div(['There was an error running the model. Check either the test image or the model.']),
                utilities.blank_fig(text='There was an error running the model.'))

# Serve the dashboard from the notebook kernel; in 'external' mode JupyterDash
# prints the URL to open (http://127.0.0.1:8050/) instead of embedding the app.
app.run_server(port=8050, mode='external')

Dash app running on http://127.0.0.1:8050/


Explaining: 100%|██████████| 50/50 [00:00<00:00, 64.25it/s]
2022-04-04 19:40:44.397488: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-04-04 19:40:44.397611: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



2022-04-04 19:40:46.579636: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-04-04 19:40:46.579699: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:40:54.522721: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
  0%|          | 0/1 [00:00<?, ?it/s]2022-04-04 19:40:54.723249: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:40:55.372941: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
100%|██████████| 1/1 [00:01<00:00,  1.43s/it]




2022-04-04 19:40:56.366216: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
100%|██████████| 5000/5000 [00:02<00:00, 1837.04it/s]




2022-04-04 19:41:05.434760: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:41:18.054361: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:41:23.102612: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:41:28.236256: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:41:28.902914: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:41:29.640999: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-04-04 19:41:37.074125: I tensorflow/core/grappler/optimizers/cust