# Automated Segmentation of Medical Images Using a Convolutional Neural Network


## Objective
To automatically segment selected regions of interest (ROIs) in medical imaging data using a trained neural network.


## Neural Network Architecture
The selected architecture is a Multiscale Pyramid 2D Convolutional Neural Network (Dourthe et al. (2021) [1]), which was chosen based on its reported ability to accurately extract contextual and morphological information from medical images at various scales.


## How to Use

### Requirements
Create a new Python 3.9 environment and install the packages listed in requirements.txt within it.
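
Before opening the notebook, it can help to confirm that the environment contains the modules the notebook imports. The snippet below is a minimal sketch and not part of the original notebook: `missing_packages` is a hypothetical helper, and the module names are taken from the import cell of this notebook (`cv2` is the import name of OpenCV).

```python
import importlib.util
import sys


def missing_packages(module_names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Top-level modules imported by the notebook (taken from its import cell)
required = ['matplotlib', 'pandas', 'numpy', 'cv2', 'torch',
            'dicompylercore', 'jupyterthemes', 'tkinter']

print(f'Python version: {sys.version_info.major}.{sys.version_info.minor}')
print('Missing modules:', missing_packages(required) or 'none')
```

Any name reported as missing points to a package that still needs to be installed in the environment.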

### Data Management
Combine all the axial DICOM images that you wish to segment in a single directory.
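
If the images are spread over several folders, they can be copied into one directory with a few lines of Python. This is only a sketch under stated assumptions: `gather_dicoms` is a hypothetical helper (it is not provided by the notebook or by pytoolbox), it assumes the files use the `.dcm` extension, and it prefixes each copy with the name of its parent folder so that identically named files from different series do not overwrite each other.

```python
import shutil
from pathlib import Path


def gather_dicoms(source_dirs, target_dir):
    """Copy every .dcm file found under source_dirs into target_dir.

    Each copy is prefixed with the name of its parent folder so that two
    series using the same filenames do not overwrite each other.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    copied = []
    for source in source_dirs:
        # sorted() materializes the file list before any copy takes place
        for dcm in sorted(Path(source).rglob('*.dcm')):
            destination = target / f'{dcm.parent.name}_{dcm.name}'
            shutil.copy2(dcm, destination)
            copied.append(destination)
    return copied


# Example call with hypothetical folder names:
# gather_dicoms(['scans/patient_01', 'scans/patient_02'], 'data/to_segment')
```

The original files are left untouched, so the combined directory can be deleted and rebuilt at any time.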

### Notebook Setup
In the Settings section:
- Edit the different paths and filenames under the DIRECTORIES & FILENAMES section
- Edit the list of labels under the DATASET INFORMATION section
- Choose the right input shape under the MODEL PARAMETERS section
- If you want to calculate the cross-sectional area of each region of interest on every segmented image, edit the parameters under the SCANNING INFORMATION section
- Every other parameter can be left at its original value

### Run the Notebook
Once the Settings have been edited, run each cell of the notebook, select the directory containing the images to be segmented when prompted, and wait until the predictions are generated.
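
The order in which the slices are processed determines the frame order of the animation and the "slice number" rows of the cross-sectional area table. `os.listdir` does not guarantee any particular order, and plain alphabetical sorting places `img10` before `img2`. The sketch below shows one possible way to obtain a natural order; `natural_sort_key` is a hypothetical helper, and the approach assumes that the filenames contain the slice index, which may not hold for every dataset.

```python
import re


def natural_sort_key(filename):
    """Split a filename into text and integer chunks so that 'img10' sorts after 'img2'."""
    return [int(chunk) if chunk.isdigit() else chunk.lower()
            for chunk in re.split(r'(\d+)', filename)]


# Hypothetical filenames, for illustration only
files = ['img10.dcm', 'img2.dcm', 'notes.txt', 'img1.dcm']
dcms = sorted((f for f in files if f.lower().endswith('.dcm')), key=natural_sort_key)
print(dcms)  # ['img1.dcm', 'img2.dcm', 'img10.dcm']
```

When the filenames carry no slice index, the slice position stored in the DICOM header would be a more reliable sorting criterion.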


## References
[1] Dourthe B, Shaikh N, S AP, Fels S, Brown SHM, Wilson DR, Street J, Oxland TR. Automated Segmentation of spinal Muscles from Upright Open MRI Using a Multi-Scale Pyramid 2D Convolutional Neural Network. Spine (Phila Pa 1976). 2021 Dec 15. doi: 10.1097/BRS.0000000000004308. PMID: 34919072. https://pubmed.ncbi.nlm.nih.gov/34919072/

___
# Libraries Import

In [1]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Computation time monitoring
from time import time

# Data visualization
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import clear_output, HTML
from jupyterthemes import jtplot

# Data processing
import os
import pandas as pd
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from dicompylercore import dicomparser
from torch.utils.data import DataLoader
from tkinter import Tk
from tkinter.filedialog import askdirectory

# Pytoolbox import
from pytoolbox.utils import *
from pytoolbox.dataset import *
from pytoolbox.network import *
from pytoolbox.loss import *

print('Libraries successfully imported')

Libraries successfully imported


___
# Settings

In [2]:
###########################
# DIRECTORIES & FILENAMES #
###########################

# Define path towards data directory
main_dir = 'C:/Users/username/project_name/data'

# Define path towards directory where the trained model was saved
model_path = 'trained_models'

# Define model filename
model_filename = 'autoseg_mri_model_500.pt'

# Define path towards directory where predictions (i.e. segmented images) will be saved
predictions_path = 'predictions'
# Create the corresponding directory if it does not exist yet
os.makedirs(os.path.abspath(predictions_path), exist_ok=True)

#######################
# DATASET INFORMATION #
#######################

# Define list of labels
#   NOTE: the number of labels in this list should match the number of labels that
#   the model was trained to segment (+1 for background)
labels_list = ['background/other',
               'vertebral body',
               'right psoas major',
               'left psoas major',
               'right multifidus - erector spinae',
               'left multifidus - erector spinae',
               'subcutaneous fat']

####################
# MODEL PARAMETERS #
####################

# Define input shape (i.e. number of pixels along x- or y-axis)
#   NOTE: Only one value is needed as input images are either squares,
#   or will be resized to squares within the data loader
input_shape = 256

# Define dropout rate (same as the one used during training)
dropout_rate = 0.3

##################################################################
# SCANNING INFORMATION (used to calculate cross-sectional areas) #
##################################################################

# Define field of view (in mm)
fov = 240

# Define matrix size
matrix_size = (224, 192)

###############################
# DATA VISUALIZATION SETTINGS #
###############################

# Dark mode
dark_mode = True

# Define Jupyter theme based on dark mode
# list available themes
# onedork | grade3 | oceans16 | chesterish | monokai | solarizedl | solarizedd
if dark_mode:
    jtplot.style(theme='chesterish')
else:
    jtplot.style(theme='grade3')

print('Settings successfully defined')

Settings successfully defined


___
# Generate Segmentation Predictions

In [5]:
#####################################
# LOAD and INITIALIZE TRAINED MODEL #
#####################################

# Define number of labels
num_labels = len(labels_list) - 1

# Re-create instance of the multi-scale pyramid model
model = MultiScalePyramid(num_labels=num_labels, input_shape=input_shape, training=False)

# Assign device to 'cuda' if available, or to 'cpu' otherwise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Implements data parallelism at the module level
model = torch.nn.DataParallel(model).to(device)

# Load trained model parameters
model_state = torch.load(os.path.join(model_path, model_filename).replace('\\', '/'), map_location=device)
epochs = model_state['epoch']
state_dict = model_state['state_dict']

# Apply trained model parameters to model
model.load_state_dict(state_dict)

# Make sure NO backpropagation happens and the model DOES NOT train
for parameter in model.parameters():
    parameter.requires_grad = False

# Specify that model is in evaluation mode
model.eval()

############################
# CHOOSE TESTING DIRECTORY #
############################

# Initialize dialog box
root = Tk()
root.update()
# Ask user to select folder
test_dicoms_path = askdirectory(initialdir=main_dir, title='Select Folder') # shows dialog box and return the path
# Close dialog box
root.destroy()

# Generate list of files within selected folder
files_list = os.listdir(test_dicoms_path)

# Generate sorted list of DICOMs from selected folder
#   NOTE: os.listdir returns files in arbitrary order; sorting makes the slice order reproducible
dcms_list = sorted(file for file in files_list if file.lower().endswith('.dcm'))

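# Define where the predictions will be saved (one sub-directory per selected folder)
#   NOTE: predictions_path is defined in the Settings cell; re-run that cell before
#   re-running this one to avoid nesting sub-directories
predictions_path = os.path.join(predictions_path, os.path.basename(test_dicoms_path)).replace('\\', '/')
os.makedirs(predictions_path, exist_ok=True)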

########################
# GENERATE PREDICTIONS #
########################

# Initialize figure
fig = plt.figure(figsize=(8, 8))
plot = []

# Initialize animation status print statement
ani_status = 0
print(f'Predictions rendering in progress: {ani_status:3.0f}%', end='\r')

# Initialize results arrays
images = []
predictions = []

# Loop through every slice located in selected testing scan, along with their respective labels
for i, dcm_file in enumerate(dcms_list, start=1):      

    # Load and pre-process DICOM (i.e. convert current DICOM slice image to numpy array)
    img = preprocess_dicom(os.path.join(test_dicoms_path, dcm_file), dtype=np.float32)

    # Resize image if its shape differs from the input shape (both dimensions are checked)
    if img.shape != (input_shape, input_shape):
        img = cv2.resize(img, (input_shape, input_shape), interpolation=cv2.INTER_CUBIC)
    images.append(img)

    # Initialize list of predictions
    predicted_labs = []

    # Ensure the model is not updated when generating predictions
    with torch.no_grad():

        # Convert images from numpy.ndarray to torch.FloatTensor format
        img_tensor = torch.FloatTensor(img).to(device).unsqueeze(dim=0).unsqueeze(dim=0)

        # Pass sample through model (only available output during testing -> stage2_output)
        # and squeeze singleton dimensions (e.g. the batch dimension)
        stage2_output = model(img_tensor).squeeze()

        # Append output to list of predictions
        predicted_labs.append(stage2_output.cpu().detach().numpy())

        # To prevent memory limitations, each prediction is deleted after being appended
        # to the list of predictions as an array
        del stage2_output

    # Convert predicted labels from list of numpy.ndarray to torch.FloatTensor format
    predicted_labs = torch.from_numpy(np.array(predicted_labs, dtype=np.float32))

    # Use bilinear interpolation to match the size of the pre-processed image
    #   NOTE: F.interpolate replaces the deprecated F.upsample
    predicted_labs = F.interpolate(predicted_labs, size=img.shape, mode='bilinear', align_corners=False).squeeze().numpy()
    # Assign each pixel to the label with the highest score
    predicted_labs = np.argmax(predicted_labs, axis=0).astype(np.uint8)

    # Save shapes of original and predicted labels
    predicted_labs_shape = predicted_labs.shape

    # Update animation status
    ani_status = i*100/len(dcms_list)
    print(f'Predictions rendering in progress: {ani_status:3.0f}%', end='\r')
    # Display MR image frame
    disp_img = plt.imshow(img, cmap='gray')
    # Display segmentation masks
    disp_lab = plt.imshow(predicted_labs, cmap='cubehelix', alpha=0.4)
    plt.axis('off')
    plot.append([disp_img, disp_lab])

    # Append current prediction to predictions list
    predictions.append(predicted_labs)

    # To prevent memory limitations, each prediction is deleted after being exported
    del predicted_labs

# Convert images and predictions to arrays
images = np.array(images)
predictions = np.array(predictions)

# Save predictions in jpg format (one file per slice)
for prediction, dcm_file in zip(predictions, dcms_list):
    plt.imsave(os.path.join(predictions_path, dcm_file[:-4]  + '.jpg'), prediction, cmap='gray')

# Finalize animation
plt.title('Predicted Segmented Images', fontsize=20)
plt.close()
print(f'Predictions rendering in progress: {100:3.0f}%')

# Generate animation
ani = animation.ArtistAnimation(fig, plot, interval=600, blit=True, repeat_delay=1000)
HTML(ani.to_html5_video())

Predictions rendering in progress: 100%
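
In the cell above, the output obtained for each slice is a stack of per-class score maps, and the label map is produced by taking the index of the highest score at every pixel. The toy example below illustrates that step with NumPy; the scores are made-up values, not model outputs.

```python
import numpy as np

# Made-up scores for 3 classes on a 2x2 image, shape (num_classes, height, width)
scores = np.array([
    [[0.90, 0.2], [0.1, 0.3]],  # class 0 (background)
    [[0.05, 0.7], [0.2, 0.3]],  # class 1
    [[0.05, 0.1], [0.7, 0.4]],  # class 2
])

# The predicted label of each pixel is the index of its highest-scoring class
label_map = np.argmax(scores, axis=0).astype(np.uint8)
print(label_map.tolist())  # [[0, 1], [2, 2]]
```

Because every pixel receives exactly one label, the regions of interest in a label map never overlap.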


___
# Calculate Cross-Sectional Area

In [None]:
# Define pixel size (in mm) based on field of view and matrix size
#   NOTE: this assumes square pixels and uses the first matrix dimension only
pix_size = fov/matrix_size[0]

# Initialize empty list to store CSA values
csa_list = []

# Define slice list
slice_list = ['slice number ' + str(slice_num) for slice_num in range(1, len(predictions)+1)]

# Loop through every slice of the predictions
for prediction in predictions:
    
    # Loop through every ROI
    for i in range(1, num_labels+1):
        
        # Calculate number of pixels contained within ROI
        num_pix = np.count_nonzero(prediction == i)
        
        # Convert number of pixels to mm2 using pixel size
        csa = num_pix*pix_size**2
        
        # Append CSA value to list
        csa_list.append(csa)
        
# Convert list to array and reshape
csa = np.array(csa_list).reshape((predictions.shape[0], num_labels))

# Calculate mean and standard deviation
csa_mean = np.round(np.nanmean(csa, axis=0), 3)
csa_std = np.round(np.nanstd(csa, axis=0), 3)
# Combine results (per-slice values, mean and standard deviation) into a dataframe
csa = np.vstack([csa, csa_mean, csa_std])
slice_list.append('mean')
slice_list.append('std')
df_results = pd.DataFrame(csa, index=slice_list, columns=labels_list[1:])

# Save CSA results into csv file
df_results.to_csv(predictions_path + '/csa_results.csv')

print('\n\nLABEL-SPECIFIC CROSS-SECTIONAL AREA RESULTS (in mm2) ON TESTING DATA\n')
df_results
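
The cross-sectional area of a region is its pixel count multiplied by the area of one pixel. The sketch below reproduces that arithmetic on a toy label map. `cross_sectional_areas` is a hypothetical helper and the 0.5 mm pixel dimensions are illustrative values. The pixel dimensions passed in must describe the grid on which the label map is defined, so if a slice was resized before segmentation, the spacing of the resized grid is the one that applies (an assumption worth checking against the acquisition protocol).

```python
import numpy as np


def cross_sectional_areas(label_map, num_labels, pixel_height_mm, pixel_width_mm):
    """Return the area (in mm^2) covered by labels 1..num_labels in a 2D label map."""
    pixel_area = pixel_height_mm * pixel_width_mm
    # Count the pixels of every label at once; index 0 is the background
    counts = np.bincount(label_map.ravel(), minlength=num_labels + 1)
    return counts[1:num_labels + 1] * pixel_area


# Toy 4x4 label map with two ROIs and hypothetical 0.5 mm x 0.5 mm pixels
toy_map = np.array([[0, 1, 1, 0],
                    [0, 1, 1, 0],
                    [2, 2, 2, 0],
                    [0, 0, 0, 0]], dtype=np.uint8)
areas = cross_sectional_areas(toy_map, num_labels=2, pixel_height_mm=0.5, pixel_width_mm=0.5)
print(areas.tolist())  # [1.0, 0.75]
```

Label 1 covers 4 pixels and label 2 covers 3 pixels, each pixel measuring 0.25 mm2, which gives 1.0 mm2 and 0.75 mm2 respectively.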