In [1]:
from jupyter_dash import JupyterDash
import dash
from dash import dcc
from dash import html
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import notebook
import io
import os
import base64

In [2]:
from pathlib import Path
home = str(Path.home())
# MRNet dataset root. Defaults to ~/datasets/MRNet/ and can be overridden with
# the MRNET_DIR environment variable instead of editing the notebook.
path = os.environ.get('MRNET_DIR', home + '/datasets/MRNet/')
# The loaders below build file paths by string concatenation, so guarantee a
# trailing separator regardless of how MRNET_DIR was spelled.
if not path.endswith(('/', os.sep)):
    path += '/'
train_path = path + 'train/'
valid_path = path + 'valid/'

In [3]:
def load_one_stack(case, data_path=train_path, plane='coronal'):
    """Load the MRI stack of one case for one plane.

    Reads ``<data_path>/<plane>/<case>.npy`` with ``np.load``.

    Parameters
    ----------
    case : str
        Case identifier, used as the file stem.
    data_path : str
        Split directory (e.g. ``train_path`` / ``valid_path``), with or
        without a trailing separator.
    plane : str
        Sub-directory name of the plane, e.g. ``'coronal'``.

    Returns
    -------
    np.ndarray
        The stored array; the viewer indexes it as ``[slice, :, :]``.

    Raises
    ------
    FileNotFoundError
        Propagated from ``np.load`` when the file does not exist.
    """
    # os.path.join avoids the '//' the '{}/{}/...' template produced because
    # train_path/valid_path already end with '/'.
    fpath = os.path.join(data_path, plane, '{}.npy'.format(case))
    return np.load(fpath)

def load_one_att_stack(case, data_path=train_path, plane='coronal'):
    """Load the saved attention maps of one case/plane and average them.

    Reads ``<data_path>/attention/<plane>/<case>.npy`` if it exists.

    Parameters
    ----------
    case : str
        Case identifier, used as the file stem.
    data_path : str
        Split directory, with or without a trailing separator.
    plane : str
        Plane sub-directory name, e.g. ``'coronal'``.

    Returns
    -------
    np.ndarray or None
        ``np.mean(att, axis=1)`` of the stored array, or ``None`` when no
        attention file exists for this case/plane.
    """
    fpath = os.path.join(data_path, 'attention', plane, '{}.npy'.format(case))
    if not os.path.exists(fpath):
        # Missing attention maps are reported as None instead of raising.
        return None
    att = np.load(fpath)
    # NOTE(review): presumably axis 1 is a per-slice channel/head axis, so the
    # mean leaves one 2-D map per slice -- confirm against the code that saves
    # these files.
    return np.mean(att, axis=1)

def load_stacks(case, data_path=train_path, planes=('coronal', 'sagittal', 'axial')):
    """Load the MRI stack of `case` for every plane in `planes`.

    Parameters
    ----------
    case : str
        Case identifier.
    data_path : str
        Split directory passed through to ``load_one_stack``.
    planes : iterable of str
        Planes to load; defaults to coronal, sagittal and axial.

    Returns
    -------
    dict[str, np.ndarray]
        Plane name -> stack, in the order of `planes`.
    """
    return {plane: load_one_stack(case, data_path, plane) for plane in planes}

def load_att_stacks(case, data_path=train_path, planes=('coronal', 'sagittal', 'axial')):
    """Load the averaged attention maps of `case` for every plane in `planes`.

    Parameters
    ----------
    case : str
        Case identifier.
    data_path : str
        Split directory passed through to ``load_one_att_stack``.
    planes : iterable of str
        Planes to load; defaults to coronal, sagittal and axial.

    Returns
    -------
    dict[str, np.ndarray or None]
        Plane name -> averaged attention stack. A value is ``None`` when
        that plane has no attention file; the dict itself is never ``None``.
    """
    return {plane: load_one_att_stack(case, data_path, plane) for plane in planes}

def load_cases(train=False, n=None):
    """Load MRI stacks and attention maps for the first `n` cases of a split.

    Parameters
    ----------
    train : bool
        If True, read the case list from ``train-acl.csv`` and the stacks from
        ``train_path``; otherwise use ``valid-acl.csv`` and ``valid_path``.
    n : int or None
        Number of cases (in CSV order) to load; ``None`` loads all of them.

    Returns
    -------
    cases : dict[str, dict[str, np.ndarray]]
        ``cases[case][plane]`` is the stack returned by ``load_stacks``.
    atts : dict[str, dict[str, np.ndarray]]
        ``atts[case][plane]`` is the averaged attention stack, or an all-zero
        array shaped like the MRI stack when no attention file exists.
    """
    # The old assertion rejected n=None although None is handled below.
    assert n is None or (isinstance(n, (int, np.integer)) and 0 <= n < 1250)

    split = 'train' if train else 'valid'
    # Previously stacks were always read from valid_path, even for train=True.
    data_path = train_path if train else valid_path
    case_list = pd.read_csv(path + '{}-acl.csv'.format(split), names=['case', 'label'], header=None,
                            dtype={'case': str, 'label': np.int64})['case'].tolist()
    if n is not None:
        case_list = case_list[:n]

    cases = {}
    atts = {}
    for case in notebook.tqdm(case_list, leave=False):
        stacks = load_stacks(case, data_path)
        att = load_att_stacks(case, data_path) or {}
        cases[case] = stacks
        # load_att_stacks returns a dict whose per-plane values are None when
        # the file is missing. The old `att is None` check never fired, and
        # `np.zeros(x.shape)` would have failed on a dict, so fill per plane.
        atts[case] = {
            plane: np.zeros(stack.shape) if att.get(plane) is None else att[plane]
            for plane, stack in stacks.items()
        }
    return cases, atts

# Load the first 120 cases listed in the validation CSV (train=False).
cases, atts = load_cases(train=False, n=120)

  0%|          | 0/120 [00:00<?, ?it/s]

In [4]:
# Viewer settings: the case to display, and whether to blend its attention maps
# over the MRI slices.
case, attention = '1130', True

In [5]:
# Number of slices per plane for every loaded case; bounds the viewer sliders.
# The old `for case in cases:` loop overwrote the `case` chosen in the previous
# cell, so the viewer silently showed the last loaded case. Comprehension
# variables do not leak, so `case` is preserved here.
slice_nums = {
    case_id: {plane: stacks[plane].shape[0] for plane in ('coronal', 'sagittal', 'axial')}
    for case_id, stacks in cases.items()
}
assert case in slice_nums, 'selected case {!r} was not loaded'.format(case)

In [6]:
app = dash.Dash()

# One slice selector per plane (ids double as plane names for the callback),
# each bounded by that plane's slice count, followed by the image that the
# callback fills through its `src` attribute.
app.layout = html.Div(children=[
    *[
        dcc.Slider(
            id=plane,
            min=0,
            max=slice_nums[case][plane] - 1,
            step=1,
            value=1,
        )
        for plane in ('coronal', 'sagittal', 'axial')
    ],
    html.Img(id='example'),
])

def _draw_plane(ax, plane, slice_idx):
    """Draw one slice of `plane` for the selected `case` onto `ax`.

    The index is clamped into the stack's range. If `attention` is on and a
    non-blank attention map was loaded, it is resized to the slice's size and
    blended on top.
    """
    stack = cases[case][plane]
    # Clamp so an out-of-range or missing slider value cannot raise IndexError.
    slice_idx = int(min(max(slice_idx or 0, 0), stack.shape[0] - 1))
    mri_slice = stack[slice_idx, :, :]
    ax.imshow(mri_slice, 'gray')
    if attention:
        heatmaps = atts[case].get(plane)
        # Skip missing maps and the all-zero placeholders, which would only
        # tint the slice with a flat colour.
        # NOTE(review): assumes the attention stack has one map per MRI slice.
        if heatmaps is not None and np.any(heatmaps[slice_idx]):
            height, width = mri_slice.shape[:2]
            # cv2.resize takes dsize as (width, height).
            overlay = cv2.resize(heatmaps[slice_idx, :, :], (width, height),
                                 interpolation=cv2.INTER_CUBIC)
            ax.imshow(overlay, cmap=plt.cm.viridis, alpha=.5)
    ax.set_title(f'MRI slice {slice_idx} on {plane} plane')


@app.callback(
    dash.dependencies.Output('example', 'src'),  # src attribute of the <img>
    [dash.dependencies.Input('coronal', 'value'),
     dash.dependencies.Input('sagittal', 'value'),
     dash.dependencies.Input('axial', 'value')]
)
def update_figure(input, sagittal_value=None, axial_value=None):
    """Render the coronal, sagittal and axial slices as a PNG data URI.

    Parameters
    ----------
    input : int
        Coronal slice index (parameter name kept for compatibility).
    sagittal_value, axial_value : int, optional
        Sagittal / axial slice indices from their own sliders. When omitted
        they fall back to `input`, the previous single-slider behaviour.
        Previously the coronal index was used for all three planes, which
        ignored two sliders and could index past shorter stacks.

    Returns
    -------
    str
        ``data:image/png;base64,...`` for the ``html.Img`` ``src`` attribute.
    """
    if sagittal_value is None:
        sagittal_value = input
    if axial_value is None:
        axial_value = input

    buf = io.BytesIO()  # in-memory PNG file
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, plane, value in zip(axes, ('coronal', 'sagittal', 'axial'),
                                    (input, sagittal_value, axial_value)):
            _draw_plane(ax, plane, value)
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)  # release the figure even if drawing fails
    data = base64.b64encode(buf.getvalue()).decode("utf8")  # encode for the HTML src
    return "data:image/png;base64,{}".format(data)

In [None]:
# Start the Dash development server. The log below shows it serving on
# http://127.0.0.1:8050/. The cell does not finish while the server runs
# (hence `In [None]`); interrupt the kernel (Ctrl+C) to stop it.
# NOTE(review): JupyterDash is imported in the first cell but `app` is a plain
# dash.Dash -- confirm whether an inline/non-blocking JupyterDash app was intended.
app.run_server()

Dash is running on http://127.0.0.1:8050/

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:8050 (Press CTRL+C to quit)
127.0.0.1 - - [23/Apr/2022 13:31:37] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:31:37] "GET /_dash-layout HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:31:37] "GET /_dash-dependencies HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:31:37] "[36mGET /_dash-component-suites/dash/dcc/async-slider.js HTTP/1.1[0m" 304 -
127.0.0.1 - - [23/Apr/2022 13:31:37] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:31:40] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:31:41] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:31:51] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:32:17] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [23/Apr/2022 13:32:20] "POST /_dash-update-component HTTP/1.1" 200 -
