In [7]:
%%writefile tfm_app.py

import sys
path = "../"
sys.path.append(path)
import streamlit as st
import SessionState
import os
import time
import numpy as np
import nibabel as nib
import tensorflow as tf
from deepbrain import Extractor
from scipy import ndimage
from tensorflow.keras.models import load_model 
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.metrics import BinaryAccuracy
from Notebooks.aux_functions.aux_functions_cnn import *
from skimage.transform import resize
import plotly.graph_objects as go


def read_nifti_file(file):
    """
    Load a NIfTI image from disk and return its voxel data.

    Parameters
    ----------
    file : str
        Path of the ``.nii`` / ``.nii.gz`` file, passed straight to ``nib.load``.

    Returns
    -------
    numpy.ndarray
        The float array produced by ``get_fdata()``. No axes are reordered;
        when the second and third dimensions have the same size a diagnostic
        message is printed, but the data is still returned unchanged.
    """
    data = nib.load(file).get_fdata()

    # Purely a diagnostic: this layout is flagged, not corrected.
    if data.shape[1] == data.shape[2]:
        print(f"{file} has a shape incompatible")

    return data


def remove_skull(volume):
    """
    Keep only voxels the deepbrain extractor classifies as brain tissue.

    Parameters
    ----------
    volume : numpy.ndarray
        MRI volume handed to ``Extractor.run`` (presumably the 3D array from
        ``read_nifti_file`` -- confirm for other callers).

    Returns
    -------
    numpy.ndarray
        A new float32 array: voxels whose brain probability is > 0.5 keep
        their intensity, every other voxel is 0. The input array is left
        untouched.
    """
    # NOTE(review): the extractor is constructed on every call, which
    # presumably reloads its network; hoist it if this becomes a bottleneck.
    ext = Extractor()

    # Per-voxel probability of being brain tissue
    prob = ext.run(volume)

    # Brain mask: probability strictly above 0.5
    mask = prob > 0.5

    # np.where builds a fresh array instead of writing zeros into the
    # caller's volume (the previous boolean-index assignment mutated it).
    return np.where(mask, volume, 0).astype("float32")


def normalize(volume, new_min=0.0, new_max=1.0):
    """
    Linearly rescale the volume's intensities into ``[new_min, new_max]``.

    Parameters
    ----------
    volume : array_like
        Input intensities.
    new_min, new_max : float, optional
        Bounds of the target range (default ``[0.0, 1.0]``).

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``volume``. The smallest input
        maps to ``new_min`` and the largest to ``new_max``. A constant volume
        (max == min) has no range to stretch, so it is mapped to ``new_min``
        everywhere instead of producing NaNs from a 0/0 division.
    """
    volume = np.asarray(volume)
    i_min = np.amin(volume)
    i_max = np.amax(volume)
    intensity_range = i_max - i_min

    # Guard the degenerate case, e.g. an all-zero volume after skull stripping.
    if intensity_range == 0:
        return np.full(volume.shape, new_min, dtype="float32")

    volume_nor = (volume - i_min) * (new_max - new_min) / intensity_range + new_min
    return volume_nor.astype("float32")


def cut_volume(volume):
    """
    Crop a 3D volume to a fixed region chosen by its first-axis size.

    Supported inputs:
      * first axis 256 -> ``volume[20:220, 30:, :]``
      * first axis 192 -> ``volume[20:180, 20:180, :]``

    Returns
    -------
    numpy.ndarray
        The cropped slice of ``volume`` (the last axis is never cropped).

    Raises
    ------
    ValueError
        If the first axis is neither 256 nor 192, instead of failing with an
        unrelated UnboundLocalError.
    """
    if volume.shape[0] == 256:
        return volume[20:220, 30:, :]

    if volume.shape[0] == 192:
        return volume[20:180, 20:180, :]

    raise ValueError(
        f"cut_volume: unsupported first-axis size {volume.shape[0]}; "
        "expected 256 or 192"
    )


def resize_volume(volume, desired_height=180, desired_width=180, desired_depth=110):
    """
    Resample a 3D volume to ``(desired_height, desired_width, desired_depth)``.

    All three axes are rescaled (not only the z-axis) using linear
    interpolation (``ndimage.zoom`` with ``order=1``). No rotation is applied.

    Parameters
    ----------
    volume : numpy.ndarray
        Input volume; only its first three axes are resized.
    desired_height, desired_width, desired_depth : int, optional
        Target size of axes 0, 1 and 2 (defaults 180 x 180 x 110).

    Returns
    -------
    numpy.ndarray
        The resampled volume.
    """
    # Per-axis zoom factor = target size / current size
    zoom_factors = (
        desired_height / volume.shape[0],
        desired_width / volume.shape[1],
        desired_depth / volume.shape[2],
    )

    return ndimage.zoom(volume, zoom_factors, order=1)

def resize_volume_2(volume):
    """
    Resample a 3D volume onto the fixed 55 x 65 x 40 grid.

    Every axis (not only z) is rescaled with first-order (linear)
    interpolation through ``ndimage.zoom``; no rotation is applied.
    """
    target_shape = (55, 65, 40)
    source_shape = (volume.shape[0], volume.shape[1], volume.shape[2])

    # Zoom factor per axis, expressed as the reciprocal of the
    # source/target ratio.
    factors = tuple(
        1 / (src / dst) for src, dst in zip(source_shape, target_shape)
    )

    return ndimage.zoom(volume, factors, order=1)
    

def process_scan(file):
    """
    Run the preprocessing pipeline on a single NIfTI file.

    Steps, in order:
      1. load the voxel data (``read_nifti_file``),
      2. zero out non-brain voxels (``remove_skull``),
      3. resample with ``resize_volume``'s default target shape,
      4. rescale intensities with ``normalize``.

    Cropping via ``cut_volume`` is currently not part of the pipeline.
    """
    raw = read_nifti_file(file)
    brain_only = remove_skull(raw)
    resampled = resize_volume(brain_only)
    return normalize(resampled)


def load_cnn(model_name):
    """
    Load ``<path>Results/<model_name>.h5`` and compile it.

    The custom ``f1`` metric is registered through ``custom_objects`` so the
    saved model can be deserialized. The model is then compiled with Adam
    (learning_rate=0.001, decay=1e-6), binary cross-entropy loss and the
    BinaryAccuracy + f1 metrics.

    NOTE(review): newer tf.keras optimizers may reject the ``decay``
    argument; confirm against the installed TensorFlow version. Compiling is
    presumably only needed for evaluate/fit, not for ``predict``.
    """
    model_file = "".join((path, "Results/", model_name, ".h5"))
    network = load_model(model_file, custom_objects={'f1': f1})

    network.compile(
        optimizer=Adam(learning_rate=0.001, decay=1e-6),
        loss="binary_crossentropy",
        metrics=[BinaryAccuracy(), f1],
    )
    return network

# Load the trained 3D CNN at module level; `model` is used by the __main__ block.
# NOTE(review): Streamlit normally re-executes the whole script on every widget
# interaction, so this reloads the model each time -- consider caching it.
model_name = "3d_model_v4"
model = load_cnn(model_name)

# Sidebar navigation
menu = ['Home', 'About']
choice = st.sidebar.selectbox('Menu', menu)

# Page header for the selected tab
if choice == 'About':
    st.title('About')
    st.subheader('About the App')
    st.text('The objective of this application is to...')
elif choice == 'Home':
    st.title('Alzheimer detection')
    st.subheader('Upload your MRI')

if __name__ == '__main__':

    # Home tab: upload an MRI, run the CNN and optionally render the volume.
    if choice == 'Home':

        # Upload file. The uploader filters only on extension, so any ".gz"
        # archive gets through; non-NIfTI names are rejected below.
        mri_file = st.file_uploader("Upload your MRI file. App supports Nifti (.nii) and zipped Nifti (.nii.gz) files.",
                                    type = ['nii','gz'])

        prediction = None

        if mri_file is not None:

            # Check if file is correct
            if "nii" not in mri_file.name:
                st.text("Please choose a Nifti file")
            else:
                # Save the upload to disk so nibabel can load it by path.
                # basename() strips any directory part of the client-supplied
                # name so the write stays inside ./data/; makedirs avoids a
                # FileNotFoundError when the folder does not exist yet.
                data_dir = "./data/"
                os.makedirs(data_dir, exist_ok=True)
                mri_path = data_dir + os.path.basename(mri_file.name)
                with open(mri_path, "wb") as out_file:
                    # getvalue() returns the whole upload regardless of the
                    # current stream position, without an extra bytearray copy.
                    out_file.write(mri_file.getvalue())

                session_state = SessionState.get(checkboxed=False)

                # After the first prediction, `checkboxed` keeps this section
                # active on later reruns (e.g. when "Show results" is toggled).
                if st.button('Predict class') or session_state.checkboxed:

                    st.subheader('Class detection')

                    # Preprocess, crop to the region fed to the model and add
                    # a leading batch axis.
                    volume = process_scan(mri_path)
                    volume = volume[10:120, 30:160, 15:95]
                    volume = np.reshape(volume, (1,) + volume.shape)
                    prediction = model.predict(volume)

                    # NOTE(review): assumes a single sigmoid output per sample,
                    # i.e. `prediction` holds exactly one value -- confirm.
                    if prediction > 0.5:
                        st.text("RESULT: Model has predicted that the MRI contains signs of Alzheimer")
                    else:
                        st.text("RESULT: Model has predicted that the MRI is cognitively normal")

                    session_state.checkboxed = True

                    if st.checkbox("Show results"):
                        # Downsample the batch-free volume to the 55 x 65 x 40
                        # grid that the coordinate mesh below is built for.
                        resized_volume = resize_volume_2(volume[0, :, :, :])

                        X, Y, Z = np.mgrid[0:55:55j, 0:65:65j, 0:40:40j]
                        fig = go.Figure(data=go.Volume(
                            x=X.flatten(),
                            y=Y.flatten(),
                            z=Z.flatten(),
                            value=resized_volume.flatten(),
                            isomin=0.2,
                            isomax=0.54,
                            colorscale = "jet",
                            opacity=0.1, # needs to be small to see through all surfaces
                            surface_count=17, # needs to be a large number for good volume rendering
                            ))
                        st.plotly_chart(fig)

Overwriting tfm_app.py
