In [0]:
import glob
import os

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from skimage.transform import resize

Using TensorFlow backend.


Image preprocessing, loading and saving

In [0]:
def volume_to_slices(array_3D, transpose_dim=(0, 1, 2, 3, 4)):
  ''' Splits a batch of volumes into a flat stack of 2D slices

    The array is first permuted with `transpose_dim`; the first two axes of
    the permuted array are then merged into a single leading axis.

    Args:
      array_3D(array): Array of 3D images (num_imgs, x, y, z, 1)
      transpose_dim(tuple): Axis permutation applied before flattening

    Returns:
      array: Array of 2D slices with shape (d0 * d1, d2, d3, d4), where
        (d0, d1, d2, d3, d4) is the shape after the permutation
  '''
  permuted = array_3D.transpose(transpose_dim)
  d0, d1, d2, d3, d4 = permuted.shape
  return permuted.reshape(d0 * d1, d2, d3, d4)


In [0]:
def slices_to_volume(array_2D, shape, transpose_dim=(0, 1, 2, 3, 4)):
  ''' Reassembles a flat stack of 2D slices into 3D volumes

    The slices are reshaped to `shape` first and the result is then permuted
    with `transpose_dim`, so `shape` is the layout *before* the permutation.

    Args:
      array_2D(array): Array of 2D slices (num_slices, x, y, 1)
      shape(tuple): Shape to reshape the slices into before permuting,
        e.g. (num_imgs, x, y, z, 1)
      transpose_dim(tuple): Axis permutation applied after reshaping

    Returns:
      array: Array of 3D images
  '''
  return array_2D.reshape(shape).transpose(transpose_dim)

In [0]:
def vertebrae_to_spine(masks):
  ''' Collapses per-vertebra labels into a single binary spine mask

    Args:
      masks(array): Vertebrae masks to be converted; any value > 0 counts
        as spine

    Returns:
      array: uint8 spine masks, 1 where masks > 0 and 0 elsewhere
  '''
  is_spine = masks > 0
  return is_spine.astype(np.uint8)

In [0]:
def save_predictions(data_path, predictions):
  ''' Writes the predictions array to `<data_path>/predictions.npy`

    Args:
      data_path(string): Path to the model folder
      predictions(array): Predictions to store
  '''
  target = os.path.join(data_path, 'predictions.npy')
  np.save(target, predictions)
  print('Saved predictions')


In [0]:
def create_test_imgs(data_path, dim):
  ''' Creates .npy test image data and saves it

    Reads every file matching `<data_path>/test/*` (in sorted path order)
    with nibabel, rescales intensities to [0, 255] with `normalize`, resizes
    each volume to `dim` with `resize_img`, stores the results as uint8 and
    saves them to `<data_path>/test_img_data.npy` with shape
    (num_imgs, x, y, z, 1).

    Args:
      data_path(string): Path to the data folder
      dim(tuple): Desired dimension of the images (x,y,z)
  '''

  img_files = sorted(glob.glob(os.path.join(data_path, 'test/*')))
  total = len(img_files)
  # np.zeros rather than np.ndarray: the buffer never holds uninitialised memory
  imgs = np.zeros((total, dim[0], dim[1], dim[2]), dtype=np.uint8)

  print('Creating test data')
  for i, f in enumerate(img_files, start=1):
    img = nib.load(f).get_fdata().astype(np.float32)

    img = normalize(img, 0, 255)
    img = resize_img(img, dim)

    # Round explicitly: assigning floats into a uint8 array truncates,
    # which biases every intensity downwards (e.g. 127.9 -> 127).
    imgs[i - 1] = np.clip(np.rint(img), 0, 255)

    if i % 10 == 0:
      print(f'{i}/{total} done')
  print('Finished creating test data')

  imgs = np.expand_dims(imgs, axis=4)
  np.save(os.path.join(data_path, 'test_img_data.npy'), imgs)
  print('Saved test image data')

In [0]:
def create_test_masks(data_path, dim):
  ''' Creates .npy test masks data and saves it

    Reads every file matching `<data_path>/test_masks/*` (in sorted path
    order), resizes each mask to `dim` with order-0 interpolation and no
    anti-aliasing so label values are not blended, stores the results as
    uint8 and saves them to `<data_path>/test_mask_data.npy` with shape
    (num_masks, x, y, z, 1).

    Args:
      data_path(string): Path to the data folder
      dim(tuple): Desired dimension of the images (x,y,z)
  '''

  mask_files = sorted(glob.glob(os.path.join(data_path, 'test_masks/*')))
  total = len(mask_files)
  # np.zeros rather than np.ndarray: the buffer never holds uninitialised memory
  masks = np.zeros((total, dim[0], dim[1], dim[2]), dtype=np.uint8)

  print('Creating mask data')
  for i, f in enumerate(mask_files, start=1):
    mask = nib.load(f).get_fdata().astype(np.float32)

    mask = resize_img(mask, dim, anti_aliasing=False, order=0)

    # Round before the uint8 cast: the labels arrive as floats, and a value
    # stored slightly below an integer (e.g. 0.9999) would otherwise be
    # truncated to the wrong label.
    masks[i - 1] = np.rint(mask)

    if i % 10 == 0:
      print(f'{i}/{total} done')
  print('Finished creating mask data')

  masks = np.expand_dims(masks, axis=4)
  np.save(os.path.join(data_path, 'test_mask_data.npy'), masks)
  print('Saved test masks')

In [0]:
def create_train_imgs(data_path, dim):
  ''' Creates .npy train image data and saves it

    Reads every file matching `<data_path>/imgs/*` (in sorted path order)
    with nibabel, rescales intensities to [0, 255] with `normalize`, resizes
    each volume to `dim` with `resize_img`, stores the results as uint8 and
    saves them to `<data_path>/train_img_data.npy` with shape
    (num_imgs, x, y, z, 1).

    Args:
      data_path(string): Path to the data folder
      dim(tuple): Desired dimension of the images (x,y,z)
  '''

  img_files = sorted(glob.glob(os.path.join(data_path, 'imgs/*')))
  total = len(img_files)
  # np.zeros rather than np.ndarray: the buffer never holds uninitialised memory
  imgs = np.zeros((total, dim[0], dim[1], dim[2]), dtype=np.uint8)

  print('Creating training data')
  for i, f in enumerate(img_files, start=1):
    img = nib.load(f).get_fdata().astype(np.float32)

    img = normalize(img, 0, 255)
    img = resize_img(img, dim)

    # Round explicitly: assigning floats into a uint8 array truncates,
    # which biases every intensity downwards (e.g. 127.9 -> 127).
    imgs[i - 1] = np.clip(np.rint(img), 0, 255)

    if i % 10 == 0:
      print(f'{i}/{total} done')
  print('Finished creating training data')

  imgs = np.expand_dims(imgs, axis=4)
  np.save(os.path.join(data_path, 'train_img_data.npy'), imgs)
  print('Saved training image data')

In [0]:
def create_train_masks(data_path, dim):
  ''' Creates .npy train masks data and saves it

    Reads every file matching `<data_path>/masks/*` (in sorted path order),
    resizes each mask to `dim` with order-0 interpolation and no
    anti-aliasing so label values are not blended, stores the results as
    uint8 and saves them to `<data_path>/train_mask_data.npy` with shape
    (num_masks, x, y, z, 1).

    Args:
      data_path(string): Path to the data folder
      dim(tuple): Desired dimension of the images (x,y,z)
  '''

  mask_files = sorted(glob.glob(os.path.join(data_path, 'masks/*')))
  total = len(mask_files)
  # np.zeros rather than np.ndarray: the buffer never holds uninitialised memory
  masks = np.zeros((total, dim[0], dim[1], dim[2]), dtype=np.uint8)

  print('Creating mask data')
  for i, f in enumerate(mask_files, start=1):
    mask = nib.load(f).get_fdata().astype(np.float32)

    mask = resize_img(mask, dim, anti_aliasing=False, order=0)

    # Round before the uint8 cast: the labels arrive as floats, and a value
    # stored slightly below an integer (e.g. 0.9999) would otherwise be
    # truncated to the wrong label.
    masks[i - 1] = np.rint(mask)

    if i % 10 == 0:
      print(f'{i}/{total} done')
  print('Finished creating mask data')

  masks = np.expand_dims(masks, axis=4)
  np.save(os.path.join(data_path, 'train_mask_data.npy'), masks)
  print('Saved training masks')

In [0]:
def load_test_data(data_path):
  ''' Loads test data for segmentation

    Reads `test_img_data.npy` and `test_mask_data.npy` from `data_path`
    (the files written by `create_test_imgs` / `create_test_masks`).

    Args:
      data_path(string): Path to the folder with test data

    Returns:
      np.array: Test images
      np.array: Test masks
  '''

  test_imgs = np.load(os.path.join(data_path, 'test_img_data.npy'))
  test_masks = np.load(os.path.join(data_path, 'test_mask_data.npy'))

  # Same (images, masks) order as load_train_data.
  return test_imgs, test_masks

In [0]:
def load_train_data(data_path):
  ''' Reads the saved training arrays back from disk

    Args:
      data_path(string): Folder containing `train_img_data.npy` and
        `train_mask_data.npy`

    Returns:
      np.array: Train images
      np.array: Train masks
  '''
  def _load(name):
    return np.load(os.path.join(data_path, name))

  # Tuple elements are evaluated left to right: images first, then masks.
  return _load('train_img_data.npy'), _load('train_mask_data.npy')

In [0]:
def change_orientation(file, orientation):
  ''' Reorients the given NIfTI file to the given orientation

    Prints the original axis codes of `file` (from its affine, via nibabel),
    then runs a `Reorient` interface on the file and returns its result.

    NOTE(review): `Reorient` is not imported anywhere in this notebook; it
    presumably comes from `nipype.interfaces.image` -- confirm and add the
    import before calling this function.

    Args:
      file(string): Path to the NIfTI file to be reoriented
      orientation(string): Desired orientation (axis-code string -- TODO
        confirm the exact format Reorient expects)

    Returns:
      The object returned by `Reorient.run()`
  '''

  img = nib.load(file)
  print('Original: ')
  print(nib.aff2axcodes(img.affine))

  reorient = Reorient(orientation=orientation)
  reorient.inputs.in_file = file
  # Return the run result instead of discarding it, so callers can inspect
  # it (presumably it carries the path of the reoriented output).
  return reorient.run()


In [0]:
def crop(img, size):
  ''' Keeps the leading corner of a 3D image, i.e. img[:x, :y, :z]

    Args:
      img(array): 3D image to be cropped
      size(tuple): Size of the cropped image (x,y,z); axes shorter than the
        requested size are kept whole

    Returns:
      array: Cropped image (a view of img)
  '''
  return img[:size[0], :size[1], :size[2]]

In [0]:
def normalize(img, new_min, new_max):
  ''' Min-max rescales an image into the range [new_min, new_max]

    The rescaling is done by cv2.normalize with NORM_MINMAX into a copy of
    the input, so `img` itself is not modified.

    Args:
      img(array): Image to be normalised
      new_min(float): New minimum
      new_max(float): New maximum

    Returns:
      array: Normalised image
  '''
  result = img.copy()
  cv2.normalize(img, result, new_min, new_max, cv2.NORM_MINMAX)

  lo_in, hi_in = img.min(), img.max()
  lo_out, hi_out = result.min(), result.max()
  print(f'Normalized: {lo_in}, {hi_in} -> {lo_out}, {hi_out}')
  return result

In [0]:
def pad_with_zeros(img, size, centred=True):
  ''' Puts img into a zero-filled array of the specified size, either centred or at (0,0,0)

    Args:
      img(array): 3D image to be padded
      size(tuple): Size of the new image (x,y,z)
      centred(bool): Whether the image should be centred or start at (0,0,0)

    Returns:
      array: float64 array of shape size[:3] containing img, zeros elsewhere

    Raises:
      ValueError: if img is larger than size along any axis
  '''

  padded = np.zeros([size[0], size[1], size[2]])

  s0, s1, s2 = img.shape
  if s0 > size[0] or s1 > size[1] or s2 > size[2]:
    # Negative offsets would silently index from the end of the array.
    raise ValueError(f'Image of shape {img.shape} does not fit into {tuple(size)}')

  c0 = c1 = c2 = 0
  if centred:
    # Floor division: when the margin is odd, the extra voxel goes to the far side.
    c0 = (size[0] - s0) // 2
    c1 = (size[1] - s1) // 2
    c2 = (size[2] - s2) // 2

  padded[c0:c0 + s0, c1:c1 + s1, c2:c2 + s2] = img

  print(f'Reshaped: {img.shape} -> {padded.shape}')
  return padded

In [0]:
def resize_img(img, size, anti_aliasing=True, order=1):
  ''' Resamples an image to the given size with skimage, keeping its value range

    Args:
      img(array): Image to be resized
      size(tuple): Dimensions of the new image
      anti_aliasing(bool): Whether to smooth before downsampling
      order(int): Spline interpolation order (0 = nearest neighbour)

    Returns:
      array: Resized image
  '''
  options = dict(preserve_range=True, anti_aliasing=anti_aliasing, order=order)
  out = resize(img, size, **options)

  print(f'Reshaped: {img.shape} -> {out.shape}')
  return out


Visualising

In [0]:
def mask_images(imgs, masks):
  ''' Masks imgs with provided masks

  Voxels of `imgs` whose mask value is zero/False are set to 0; all other
  voxels are kept. Works for boolean masks and for integer masks such as the
  uint8 masks produced by `vertebrae_to_spine`.

  Args:
    imgs (array): Images to be masked
    masks (array): Segmentation masks to be applied (same shape as imgs)

  Returns:
    array: Segmented images (a copy; imgs is not modified)
  '''

  masked = imgs.copy()
  # np.logical_not instead of `~`: on integer masks `~` is a bitwise NOT
  # (e.g. ~1 == 254 for uint8), which turns the mask into integer indices
  # rather than a boolean selection.
  masked[np.logical_not(masks)] = 0

  return masked

In [0]:
def plot_loss(data_path):
  ''' Reads a training log CSV and plots train and validation loss per epoch

  Args:
      data_path (string): Path to the log file; it must contain the columns
        'epoch', 'loss' and 'val_loss' (presumably a Keras CSVLogger log --
        confirm)
  '''

  df = pd.read_csv(data_path)

  # Draw on a fresh figure so the curves cannot land on whatever axes happen
  # to be current (e.g. the last subplot of an unshown slice figure).
  fig, ax = plt.subplots()
  ax.plot(df['epoch'], df['loss'], label='train_loss')
  ax.plot(df['epoch'], df['val_loss'], label='val_loss')
  ax.set(xlabel='epoch', ylabel='loss', title='Training and validation loss')
  ax.legend()
  plt.show()

In [0]:
def plot_predictions(index, predicted_masks, orig_imgs, masked_imgs):
  ''' Shows centre slices of the prediction, the original image and the masked image at `index`

  Args:
      index (int): Index of the arrays to be plotted
      predicted_masks (array): Predicted masks (num_instances, x, y, z, 1)
      orig_imgs (array): Original images (num_instances, x, y, z, 1)
      masked_imgs (array): Masked (segmented) images (num_instances, x, y, z, 1)
  '''

  if not index < len(predicted_masks):
    print("Index out of range")
    return

  for volumes in (predicted_masks, orig_imgs, masked_imgs):
    # Select the instance, then drop its trailing singleton channel axis.
    plot_slices(np.squeeze(volumes[index], axis=3))

In [0]:
def plot_slices(img, cmap="gray"):
  ''' Plots the centre slices along the x, y and z axes of a 3D image

  Args:
      img (array): 3D image to be plotted
      cmap (string): Colormap for plotting
  '''

  centre = [n // 2 for n in img.shape]

  slices = [
    img[centre[0], :, :],
    img[:, centre[1], :],
    img[:, :, centre[2]],
  ]
  show_slices(slices, cmap)

  # plt.title would only label the last subplot; suptitle labels the figure.
  plt.suptitle("Centre slices")

In [0]:
def show_slices(slices, cmap):
  ''' Displays a row of 2D image slices side by side

  Args:
    slices (sequence of arrays): 2D slices to be plotted, one subplot each
    cmap (string): Colormap for plotting
  '''

  # squeeze=False keeps `axes` 2D even for a single slice; otherwise
  # plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
  fig, axes = plt.subplots(1, len(slices), squeeze=False)
  for ax, image_slice in zip(axes[0], slices):
    # Transposed and drawn with origin="lower", so the slice's first array
    # axis runs along the horizontal axis.
    ax.imshow(image_slice.T, cmap=cmap, origin="lower")

  fig.tight_layout(pad=3.0)
