# In[1]:
# import jupyter notebook version of dash framework
from jupyter_dash import JupyterDash as Dash
# import dash components
from dash import Input, Output, State, html, dcc
# Import warnings to ignore warnings
import warnings
warnings.filterwarnings('ignore')
import os
from io import BytesIO
import numpy as np
from PIL import Image
import base64
# import TensorFlow (the model is loaded via tf.keras.models.load_model)
import tensorflow as tf
# Import visualization
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
# Import Explainer 
from lime import lime_image


# Trained classifier. For now it is loaded from a local path, resolved
# relative to the current working directory.
model = tf.keras.models.load_model(
    "../models/op_model1_aug.keras"
)

# External CSS applied to the whole dashboard
external_stylesheets = [
    'https://codepen.io/chriddyp/pen/bWLwgP.css',
]

# Dash application (JupyterDash, imported under the name Dash)
app = Dash(
    __name__,
    external_stylesheets=external_stylesheets,
)

# Markdown shown at the top of the dashboard: a title and a short
# description of the classes the model can predict.
markdown_1 = html.Div(dcc.Markdown("""
    # Brain Tumor Lesion Assessment
"""))

markdown_2 = html.Div(dcc.Markdown("""
    The model integrated into this dashboard can predict different brain lesions.
    The three supported brain lesion classifications are Meningioma, Pituitary, and Glioma tumors.
    The fourth possible prediction is a No Tumor classification.
"""))

# Drag-and-drop upload widget. Because multiple=True, its `contents` and
# `filename` properties are lists; both callbacks that read this component
# iterate over them, so the flag must stay enabled.
upload_img = dcc.Upload(
    id='upload-image',
    children=html.Div([
        'Drag and Drop or ',
        html.A('Select File'),
    ]),
    style={
        'width': '100%',
        'height': '60px',
        'lineHeight': '60px',
        'borderWidth': '1px',
        'borderStyle': 'dashed',
        'borderRadius': '5px',
        'textAlign': 'center',
        'margin': '10px',
    },
    multiple=True,  # Allow several files per upload (list-valued contents)
)

# Page layout: title, description, upload area with previews,
# prediction text, then the LIME explanation figure.
app.layout = html.Div(
    children=[
        markdown_1,
        markdown_2,
        # Upload widget plus the container the preview callback fills
        html.Div([upload_img, html.Div(id='output-image-upload')]),
        # Filled with the predicted label text
        html.Div(id='prediction-output'),
        # Filled with the LIME mask and importance heatmap figure
        html.Div(id='lime-container'),
    ]
)


############################################################################################################################
# Build the preview shown for one uploaded file
def parse_contents(contents, filename):
    """Return a preview Div for one upload: filename heading, image, rule.

    ``contents`` is the string supplied by ``dcc.Upload``; ``html.Img``
    accepts that same base64 data string directly as its ``src``.
    """
    heading = html.H5(filename)
    image = html.Img(src=contents)
    return html.Div([heading, image, html.Hr()])
# Callback: refresh the image previews whenever new files are uploaded
@app.callback(
    Output('output-image-upload', 'children'),
    Input('upload-image', 'contents'),
    State('upload-image', 'filename'),
)
def update_output(list_of_contents, list_of_names):
    """Render one preview per uploaded file.

    Triggered by changes to the upload contents; filenames are read as State.
    Returns a list of preview Divs, or None when nothing has been uploaded.
    """
    if list_of_contents is None:
        return None
    previews = []
    for content, name in zip(list_of_contents, list_of_names):
        previews.append(parse_contents(content, name))
    return previews

#############################################################################################################################

# Class labels indexed by the model's output unit (same order as the
# original index -> label mapping).
_CLASS_LABELS = ('glioma', 'meningioma', 'no_tumor', 'pituitary')


def _decode_upload(content):
    """Decode one uploaded data string into a normalized RGB image array.

    Returns an array of shape (128, 128, 3) with values in [0, 1].
    Raises ValueError / OSError if the payload is not a decodable image.
    """
    # The payload follows the first comma ("<header>,<base64 data>").
    _, content_string = content.split(',', 1)
    decoded_image = base64.b64decode(content_string)
    with Image.open(BytesIO(decoded_image)) as raw:
        # Convert to RGB and resize to the model's expected input dimensions.
        resized = raw.convert('RGB').resize((128, 128))
    return np.array(resized) / 255.0


def _predict_proba(batch):
    """Run the model on a batch without printing Keras progress bars."""
    return model.predict(batch, verbose=0)


def _render_lime_figure(img):
    """Explain the model's top class for ``img`` with LIME; return an html.Img.

    The figure has two panels: the positively contributing superpixels with
    their boundaries, and a per-superpixel weight heatmap. The PNG is also
    written to 'diagnostic.png' in the working directory.
    """
    explainer = lime_image.LimeImageExplainer(random_state=42)
    explanation = explainer.explain_instance(
        image=img,
        classifier_fn=_predict_proba,
        top_labels=4,
        num_samples=2000,
        hide_color=0,
        random_seed=42,
    )

    # Visualize the explanation for the top predicted label
    top_label = explanation.top_labels[0]
    temp, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=True,
        num_features=5,
        hide_rest=True,
        min_weight=0.1,
    )

    # Map each superpixel id to its LIME weight to build the heatmap
    weights = dict(explanation.local_exp[top_label])
    heatmap = np.vectorize(lambda seg: weights.get(seg, 0.0))(explanation.segments)
    # Symmetric colour limits so a dominant negative weight is not clipped
    limit = np.abs(heatmap).max()

    fig, (ax_mask, ax_heat) = plt.subplots(1, 2, figsize=(12, 6), facecolor='white')
    try:
        # img is already in [0, 1]; show it as-is (no [-1, 1] -> [0, 1] rescale)
        ax_mask.imshow(mark_boundaries(temp, mask))
        ax_mask.set_title("Concerning Area", fontsize=20)
        ax_mask.axis("off")

        ax_heat.imshow(heatmap, cmap='RdBu', vmin=-limit, vmax=limit)
        ax_heat.set_title("Blue = More Important; Red = Less Important", fontsize=20)
        ax_heat.axis("off")

        fig.tight_layout()

        # Save a PNG copy to disk, and another into memory for the page
        fig.savefig('diagnostic.png')
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release the figure so repeated callbacks don't accumulate in pyplot
        plt.close(fig)

    fig_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return html.Img(src=f'data:image/png;base64,{fig_base64}', style={'width': '100%', 'height': '600px'})


# Define the callback to preprocess the image and make predictions
@app.callback(
    [Output('prediction-output', 'children'),
     Output('lime-container', 'children')],
    [Input('upload-image', 'contents')]
)
def update_prediction_output(contents):
    """Predict the lesion class of the uploaded image(s) and explain it.

    Args:
        contents: list of uploaded data strings from ``dcc.Upload``, or None.

    Returns:
        (statement, figure) for the two Outputs. ``statement`` is the label
        of the most confident prediction across all uploads. ``figure`` is
        the LIME explanation for that same image. Returns (None, None)
        before any upload, and ("No prediction available", []) if no upload
        could be decoded.
    """
    if contents is None:
        # A multi-output callback must return one value per Output
        return None, None

    best_label = None
    best_value = 0.0
    best_img = None
    for content in contents:
        try:
            img = _decode_upload(content)
        except (ValueError, OSError):
            # Not a decodable image (bad base64 or unsupported format); skip it
            continue

        prediction = _predict_proba(np.expand_dims(img, axis=0))
        max_index = int(np.argmax(prediction))
        confidence = prediction[0, max_index]
        # Keep the label AND image of the most confident prediction
        if best_img is None or confidence > best_value:
            best_label = _CLASS_LABELS[max_index]
            best_value = confidence
            best_img = img

    if best_img is None:
        return "No prediction available", []

    # LIME is expensive, so only explain the image whose label is reported
    statement = f"Prediction: {best_label.capitalize()}"
    return statement, _render_lime_figure(best_img)

##############################################################################################################################
if __name__ == '__main__':
    # Serve the dashboard inline in the notebook output cell
    app.run_server(
        mode='inline',
        host='localhost',
        port=5000,
    )