In [4]:
from jupyter_dash import JupyterDash
import dash
from dash import dcc
from dash import html
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import notebook
import io
import os
import base64

In [5]:
from pathlib import Path
home = str(Path.home())
path = home + '/datasets/MRNet/'
train_path = path + 'train/'
valid_path = path + 'valid/'

In [6]:
def load_one_stack(case, data_path=train_path, plane='coronal'):
    fpath = '{}/{}/{}.npy'.format(data_path, plane, case)
    return np.load(fpath)

def load_one_att_stack(case, data_path=train_path, plane='coronal'):
    fpath = '{}attention/{}/{}.npy'.format(data_path, plane, case)
    if os.path.exists(fpath):
        att = np.load(fpath)
        heatmap = np.mean(att, axis=(1))
        return heatmap
    else:
        None

def load_stacks(case, data_path=train_path):
    x = {}
    planes = ['coronal', 'sagittal', 'axial']
    for i, plane in enumerate(planes):
        x[plane] = load_one_stack(case, data_path, plane)
    return x

def load_att_stacks(case, data_path=train_path):
    x = {}
    planes = ['coronal', 'sagittal', 'axial']
    for i, plane in enumerate(planes):
        x[plane] = load_one_att_stack(case, data_path, plane)
    return x

def load_cases(train=False, n=None):
    assert (type(n) == int) and (n < 1250)
    if train:
        case_list = pd.read_csv(path + 'train-acl.csv', names=['case', 'label'], header=None,
                               dtype={'case': str, 'label': np.int64})['case'].tolist()        
    else:
        case_list = pd.read_csv(path + 'valid-acl.csv', names=['case', 'label'], header=None,
                               dtype={'case': str, 'label': np.int64})['case'].tolist()        
    cases = {}
    atts = {}
    
    if n is not None:
        case_list = case_list[:n]
        
    for case in notebook.tqdm(case_list, leave=False):
        x = load_stacks(case, valid_path)
        att = load_att_stacks(case, valid_path)
        cases[case] = x
        if att is None:
            atts[case] = np.zeros(x.shape)
        else:
            atts[case] = att
    return cases, atts

cases, atts = load_cases(n=120)

  0%|          | 0/120 [00:00<?, ?it/s]

In [7]:
case = '1130'
attention = True

In [8]:
slice_nums = {}
for case in cases:
    slice_nums[case] = {}
    for plane in ['coronal', 'sagittal', 'axial']:
        slice_nums[case][plane] = cases[case][plane].shape[0]

In [9]:
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

app.layout = html.Div(children=[
    html.Div([
        html.P(children='Coronal'),
        dcc.Slider(
            id='coronal',
            min=0,
            max=slice_nums[case]['coronal'] - 1,
            step=1,
            value=round((slice_nums[case]['coronal'] - 1) / 2, 0),
    )], style={'width': '40%', 'padding': '0px 0px 0px 0px'}),
    
    html.Div([
        html.P(children='Sagittal'),
        dcc.Slider(
            id='sagittal',
            min=0,
            max=slice_nums[case]['sagittal'] - 1,
            step=1,
            value=round((slice_nums[case]['sagittal'] - 1) / 2, 0),
    )], style={'width': '40%', 'padding': '0px 0px 0px 0px'}),
    
    html.Div([
        html.P(children='Axial'),
        dcc.Slider(
            id='axial',
            min=0,
            max=slice_nums[case]['axial'] - 1,
            step=1,
            value=round((slice_nums[case]['axial'] - 1) / 2, 0),
    )], style={'width': '40%', 'padding': '0px 0px 0px 0px'}),
    html.Div(html.Img(id='image', style={'width': '80%', 'padding': '0px 0px 0px 0px'})),
    html.Div([
    daq.ToggleSwitch(
        id='my-toggle-switch',
        value=False
    ),])

    html.Div(html.Img(id='image', style={'width': '80%', 'padding': '0px 0px 0px 0px'})),
])
@app.callback(
    dash.dependencies.Output('image', 'src'), # src attribute
    [dash.dependencies.Input('coronal', 'value'),
    dash.dependencies.Input('sagittal', 'value'),
    dash.dependencies.Input('axial', 'value'),
    dash.dependencies.Input('my-toggle-switch', 'value')]
)


def update_figures(coronal, sagittal, axial, toggle):
    
    buf = io.BytesIO()
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15,5))
    ax1.imshow(cases[case]['coronal'][coronal, :, :], 'gray')
    if toggle:
        ax1.imshow(cv2.resize(atts[case]['coronal'][coronal, :, :], (256, 256), interpolation=cv2.INTER_CUBIC), cmap=plt.cm.viridis, alpha=.5)
    ax1.set_title(f'MRI slice {coronal} on coronal plane')
    ax2.imshow(cases[case]['sagittal'][sagittal, :, :], 'gray')
    if toggle:
        ax2.imshow(cv2.resize(atts[case]['sagittal'][sagittal, :, :], (256, 256), interpolation=cv2.INTER_CUBIC), cmap=plt.cm.viridis, alpha=.5)
    ax2.set_title(f'MRI slice {sagittal} on sagittal plane')
    ax3.imshow(cases[case]['axial'][axial, :, :], 'gray')
    if toggle:
    dash.dependencies.Input('axial', 'value')]
)
def update_figures(coronal, sagittal, axial):

    buf = io.BytesIO() 
    
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15,5))

    ax1.imshow(cases[case]['coronal'][coronal, :, :], 'gray')
    if attention:
        ax1.imshow(cv2.resize(atts[case]['coronal'][coronal, :, :], (256, 256), interpolation=cv2.INTER_CUBIC), cmap=plt.cm.viridis, alpha=.5)
    ax1.set_title(f'MRI slice {coronal} on coronal plane')

    ax2.imshow(cases[case]['sagittal'][sagittal, :, :], 'gray')
    if attention:
        ax2.imshow(cv2.resize(atts[case]['sagittal'][sagittal, :, :], (256, 256), interpolation=cv2.INTER_CUBIC), cmap=plt.cm.viridis, alpha=.5)
    ax2.set_title(f'MRI slice {sagittal} on sagittal plane')

    ax3.imshow(cases[case]['axial'][axial, :, :], 'gray')
    if attention:
        ax3.imshow(cv2.resize(atts[case]['axial'][axial, :, :], (256, 256), interpolation=cv2.INTER_CUBIC), cmap=plt.cm.viridis, alpha=.5)
    ax3.set_title(f'MRI slice {axial} on axial plane')
    plt.savefig(buf, format = "png") # save to the above file object
    plt.close()
    data = base64.b64encode(buf.getbuffer()).decode("utf8") # encode to html elements
    return "data:image/png;base64,{}".format(data)




SyntaxError: invalid syntax (2392980172.py, line 43)

In [7]:
app.run_server()
app.run_server(port=8050)

Dash app running on http://127.0.0.1:8050/
