In [11]:
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
import dash
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import notebook
import os
import matplotlib.pyplot as plt
import imageio
from dash_slicer import VolumeSlicer
import cv2

In [12]:
# Case ID selected for inspection by the overlay-export and Dash cells below.
case = "1187"

In [13]:
# MRNet dataset locations under the user's home directory. Every path string
# ends with '/', which the loaders below rely on when appending sub-paths.
home = str(Path.home())
path = f'{home}/datasets/MRNet/'
train_path = f'{path}train/'
valid_path = f'{path}valid/'
# Stylesheet used by the Dash app.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

In [14]:
def load_one_stack(case, data_path=train_path, plane='coronal'):
    """Load the MRI stack of a single case for a single plane.

    Reads ``<data_path>/<plane>/<case>.npy`` with ``np.load`` and returns the
    array as stored on disk.
    """
    stack_file = f'{data_path}/{plane}/{case}.npy'
    return np.load(stack_file)

def load_one_att_stack(case, data_path=train_path, plane='coronal'):
    """Load the attention array of one case/plane and average it over axis 1.

    The file looked up is ``<data_path>attention/<plane>/<case>.npy``. No
    separator is inserted after ``data_path``, so it must end with '/'
    (``train_path`` and ``valid_path`` do).

    Returns:
        ``np.mean(att, axis=1)`` of the loaded array, or ``None`` when the
        attention file does not exist.
    """
    fpath = '{}attention/{}/{}.npy'.format(data_path, plane, case)
    if not os.path.exists(fpath):
        # Missing attention maps are reported as None; callers must handle it.
        return None
    att = np.load(fpath)
    return np.mean(att, axis=1)

def load_stacks(case, data_path=train_path):
    """Load the coronal, sagittal and axial MRI stacks of one case.

    Returns a dict mapping each plane name (in that order) to
    ``load_one_stack(case, data_path, plane)``.
    """
    return {plane: load_one_stack(case, data_path, plane)
            for plane in ('coronal', 'sagittal', 'axial')}

def load_att_stacks(case, data_path=train_path):
    """Load the per-plane attention heatmaps of one case.

    Returns a dict mapping 'coronal', 'sagittal' and 'axial' (in that order)
    to ``load_one_att_stack(case, data_path, plane)``; a value is ``None``
    when that plane's attention file is missing.
    """
    return {plane: load_one_att_stack(case, data_path, plane)
            for plane in ('coronal', 'sagittal', 'axial')}

def load_cases(train=False, n=None):
    """Load MRI stacks and attention heatmaps for the first ``n`` cases of a split.

    Case IDs come from ``train-acl.csv`` (``train=True``) or ``valid-acl.csv``
    (``train=False``) under ``path``. The matching volumes are loaded from
    ``train_path`` or ``valid_path`` respectively.

    Args:
        train: Load the training split instead of the validation split.
        n: Keep only the first ``n`` case IDs from the CSV; ``None`` keeps all.

    Returns:
        ``(cases, atts)``: two dicts keyed by case ID. ``cases[id]`` maps each
        plane name to its MRI stack. ``atts[id]`` maps each plane to its
        attention heatmap, or to an all-zero array shaped like that plane's
        MRI stack when no attention file exists.
    """
    # The default n=None used to fail this assertion; None now means "all".
    assert n is None or (isinstance(n, int) and n < 1250)

    csv_name = 'train-acl.csv' if train else 'valid-acl.csv'
    # Load volumes from the split the case IDs came from. Previously they were
    # always read from valid_path, even for train=True.
    data_path = train_path if train else valid_path
    case_list = pd.read_csv(path + csv_name, names=['case', 'label'], header=None,
                            dtype={'case': str, 'label': np.int64})['case'].tolist()
    if n is not None:
        case_list = case_list[:n]

    cases = {}
    atts = {}
    for case_id in notebook.tqdm(case_list, leave=False):
        stacks = load_stacks(case_id, data_path)
        heatmaps = load_att_stacks(case_id, data_path)
        # load_one_att_stack returns None for a missing attention file.
        # Substitute zeros shaped like the MRI stack so downstream per-slice
        # overlay code can still index into it.
        for plane in list(heatmaps):
            if heatmaps[plane] is None:
                heatmaps[plane] = np.zeros(stacks[plane].shape)
        # The former per-slice colormap/blend loop was removed: its result was
        # discarded, so it only cost time and memory.
        cases[case_id] = stacks
        atts[case_id] = heatmaps
    return cases, atts

# Load the first 120 validation cases together with their attention heatmaps.
cases, atts = load_cases(n=120)

# Number of slices per plane for every loaded case.
# The loop variable is deliberately not named `case`: at notebook level that
# would overwrite the selected case ID that the export and Dash cells read.
slice_nums = {
    case_id: {plane: cases[case_id][plane].shape[0]
              for plane in ['coronal', 'sagittal', 'axial']}
    for case_id in cases
}

  0%|          | 0/120 [00:00<?, ?it/s]

In [22]:
# Blend the attention heatmap (JET colormap) 50/50 onto every axial slice of
# the selected case and save one PNG per slice.
out_dir = path + 'output_2/axial/'
# cv2.imwrite silently returns False if the directory does not exist.
os.makedirs(out_dir, exist_ok=True)
for i in range(cases[case]['axial'].shape[0]):
    img = cases[case]['axial'][i]
    img = np.array(np.stack((img, img, img), axis=2), dtype='uint8')
    att = atts[case]['axial'][i]
    att_range = np.max(att) - np.min(att)
    # A constant slice (e.g. the all-zero placeholder) would divide by zero.
    if att_range > 0:
        att = (att - np.min(att)) / att_range * 255
    else:
        att = np.zeros_like(att)
    att = np.array(att, dtype='uint8')
    att = cv2.applyColorMap(att, cv2.COLORMAP_JET)
    # cv2.resize takes (width, height); match the MRI slice size.
    att = cv2.resize(att, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
    # applyColorMap yields BGR and cv2.imwrite expects BGR. Converting to RGB
    # here swapped red and blue in the saved images.
    merge = cv2.addWeighted(att, 0.5, img, 0.5, 0)
    cv2.imwrite(out_dir + '{}.png'.format(i), merge)


In [121]:
# Same 50/50 JET attention overlay as the previous export, written to
# output/axial/.
# NOTE(review): `mri_axial` and `att_axial` are not defined anywhere in this
# notebook, so this cell fails on a fresh kernel. Define them first (presumably
# an axial MRI stack and its per-slice attention maps). TODO confirm intent.
out_dir = path + 'output/axial/'
os.makedirs(out_dir, exist_ok=True)
for i in range(mri_axial.shape[0]):
    img = mri_axial[i]
    img = np.array(np.stack((img, img, img), axis=2), dtype='uint8')
    att = att_axial[i]
    att_range = np.max(att) - np.min(att)
    # Avoid NaNs from dividing by zero on a constant attention slice.
    if att_range > 0:
        att = (att - np.min(att)) / att_range * 255
    else:
        att = np.zeros_like(att)
    att = np.array(att, dtype='uint8')
    att = cv2.applyColorMap(att, cv2.COLORMAP_JET)
    att = cv2.resize(att, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
    # Keep BGR order for cv2.imwrite. The undefined `output` dst argument was
    # dropped; addWeighted returns the blended image directly.
    merge = cv2.addWeighted(att, 0.5, img, 0.5, 0)
    cv2.imwrite(out_dir + '{}.png'.format(i), merge)

In [5]:
# Build the Dash app, the shared tab styles, one VolumeSlicer per plane and
# the top-level layout.
# suppress_callback_exceptions=True: the callbacks defined below target
# components (sub-tabs, slicer graphs/sliders) that only exist after a tab has
# been rendered. Without it, Dash reports IDs missing from the initial layout
# as callback errors when running with debug=True.
app = JupyterDash(__name__, external_stylesheets=external_stylesheets,
                  suppress_callback_exceptions=True)

tabs_styles = {
    'height': '44px'
}
tab_style = {
    'borderBottom': '1px solid #d6d6d6',
    'padding': '6px',
    'fontWeight': 'bold'
}

tab_selected_style = {
    'borderTop': '1px solid #d6d6d6',
    'borderBottom': '1px solid #d6d6d6',
    'backgroundColor': '#119DFF',
    'color': 'white',
    'padding': '6px'
}

# One slicer per plane for the notebook-level `case` (set to '1187' near the
# top). Make sure no earlier cell has rebound `case` before running this.
raw_slicer_axial = VolumeSlicer(app, cases[case]['axial'])
raw_slicer_coronal = VolumeSlicer(app, cases[case]['coronal'])
raw_slicer_sagittal = VolumeSlicer(app, cases[case]['sagittal'])

app.layout = html.Div([
    dcc.Tabs(id="tabs-styled-with-inline", value='tab-1', children=[
        dcc.Tab(label='Raw', value='tab-1', style=tab_style, selected_style=tab_selected_style),
        dcc.Tab(label='Explore', value='tab-2', style=tab_style, selected_style=tab_selected_style),
        dcc.Tab(label='Define', value='tab-3', style=tab_style, selected_style=tab_selected_style),
        dcc.Tab(label='Report', value='tab-4', style=tab_style, selected_style=tab_selected_style),
        dcc.Tab(label='Export', value='tab-5', style=tab_style, selected_style=tab_selected_style),
    ], style=tabs_styles),
    html.Div(id='tabs-content-inline', style={'padding': '5px 0px 5px 0px'})
])

@app.callback(Output('tabs-content-inline', 'children'),
              Input('tabs-styled-with-inline', 'value'))
def render_content(tab):
    """Render the body of the selected top-level tab.

    'tab-1' (Raw) and 'tab-2' (Explore) each return a nested dcc.Tabs plus an
    empty container that another callback fills. 'tab-3' to 'tab-5' are
    placeholders. Any other value yields None.
    """
    def styled_tab(label, value):
        # All tabs in this app share the same normal/selected styling.
        return dcc.Tab(label=label, value=value, style=tab_style, selected_style=tab_selected_style)

    def tabs_with_panel(tabs_id, children, panel_id, panel_width):
        # A sub-tab strip (first tab selected) above a centered content panel.
        return html.Div([
            dcc.Tabs(id=tabs_id, value='tab-1', children=children, style=tabs_styles),
            html.Div(id=panel_id,
                     style={'padding': '10px 10px 10px 10px', 'width': panel_width, 'margin': 'auto'}),
        ])

    if tab == 'tab-1':
        plane_tabs = [
            styled_tab('Coronal', 'tab-1'),
            styled_tab('Sagittal', 'tab-2'),
            styled_tab('Axial', 'tab-3'),
        ]
        return tabs_with_panel("subtab-1-styled-with-inline", plane_tabs,
                               'subtab-1-content-inline', '50%')
    if tab == 'tab-2':
        finding_tabs = [
            styled_tab('ACL Tear', 'tab-1'),
            styled_tab('Meniscus Tear', 'tab-2'),
            styled_tab('Anomalies', 'tab-3'),
        ]
        return tabs_with_panel("subtab-2-styled-with-inline", finding_tabs,
                               'subtab-2-content-inline', '90%')

    placeholders = (
        ('tab-3', 'Tab content 3'),
        ('tab-4', 'Tab content 4'),
        ('tab-5', 'Tab content 5'),
    )
    for tab_value, title in placeholders:
        if tab == tab_value:
            return html.Div([html.H3(title)])
    return None

@app.callback(Output('subtab-1-content-inline', 'children'),
              Input('subtab-1-styled-with-inline', 'value'))
def render_content(tab):
    """Show the raw-volume slicer (graph, slider, stores) for the chosen plane.

    'tab-1' -> coronal, 'tab-2' -> sagittal, 'tab-3' -> axial; any other value
    returns None.

    NOTE(review): this reuses the name of the previous callback function. The
    decorator registers each callback with Dash, so both still fire, but only
    the last definition stays bound to `render_content`.
    """
    plane_slicers = (
        ('tab-1', raw_slicer_coronal),
        ('tab-2', raw_slicer_sagittal),
        ('tab-3', raw_slicer_axial),
    )
    for tab_value, slicer in plane_slicers:
        if tab == tab_value:
            return html.Div([slicer.graph, slicer.slider, *slicer.stores])
    return None

@app.callback(Output('subtab-2-content-inline', 'children'),
              Input('subtab-2-styled-with-inline', 'value'))
def render_content(tab):
    """Render the Explore sub-tab content.

    'tab-1' (ACL Tear) shows the coronal, sagittal and axial slicers side by
    side, one third each. 'tab-2' (Meniscus Tear) shows only the sagittal
    slicer. 'tab-3' (Anomalies) shows a 'loading...' heading. Any other value
    returns None.
    """
    def slicer_parts(slicer):
        # The components VolumeSlicer exposes for embedding in a layout.
        return [slicer.graph, slicer.slider, *slicer.stores]

    if tab == 'tab-1':
        third = {'width': '33%', 'display': 'inline-block'}
        return html.Div([
            html.Div(slicer_parts(raw_slicer_coronal), style=dict(third)),
            html.Div(slicer_parts(raw_slicer_sagittal), style=dict(third)),
            html.Div(slicer_parts(raw_slicer_axial), style={**third, 'float': 'right'}),
        ])
    if tab == 'tab-2':
        return html.Div(slicer_parts(raw_slicer_sagittal))
    if tab == 'tab-3':
        return html.Div([html.H3('loading...')])
    return None

In [6]:
# Serve the JupyterDash app on port 8050 with debug=True, but with component
# prop-type checking disabled.
app.run_server(debug=True, dev_tools_props_check=False, port=8050)

Dash app running on http://127.0.0.1:8050/
