In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import os
import glob
import functools

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (12,12)

from sklearn.model_selection import train_test_split
import matplotlib.image as mpimg
import pandas as pd
from PIL import Image

In [0]:
import tensorflow as tf
import tensorflow.contrib as tfcontrib
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.python.keras import backend as K  

In [4]:
!pip install SimpleITK



In [0]:
import h5py
import SimpleITK as sitk

Find training data filenames and label filenames 

In [0]:
def getTrainNLabelNames(data_folder, m, ext='*.nii.gz'):
  x_train_filenames = []
  y_train_filenames = []
  for subject_dir in sorted(glob.glob(os.path.join(data_folder,m+'_train',ext))):
      x_train_filenames.append(os.path.realpath(subject_dir))
  for subject_dir in sorted(glob.glob(os.path.join(data_folder ,m+'_train_masks',ext))):
      y_train_filenames.append(os.path.realpath(subject_dir))
  return x_train_filenames, y_train_filenames


Convert 3D data to 2D data

In [0]:
def swapLabels(labels):
    labels[labels==421]=420
    unique_label = np.unique(labels)

    new_label = range(len(unique_label))
    for i in range(len(unique_label)):
        label = unique_label[i]
        print(label)
        newl = new_label[i]
        print(newl)
        labels[labels==label] = newl
       
    print(unique_label)

    return labels

In [0]:
def RescaleIntensity(slice_im,m):
  #slice_im: numpy array
  #m: modality, ct or mr
  if m =="ct":
    slice_im[slice_im>750] = 750
    slice_im[slice_im<-750] = -750
    slice_im = slice_im/750
  elif m=="mr":
#     top_10 = np.percentile(slice_im,90)
#     above = slice_im[slice_im>top_10]
#     med = np.median(above)
#     slice_im = slice_im/med
#     slice_im[slice_im>1.] = 1.
#     slice_im = slice_im*2.-1.
    slice_im[slice_im>1500] = 1500
    slice_im = (slice_im-750)/750
  return slice_im
    

In [0]:
def np_to_tfrecords(X, Y, file_path_prefix, verbose=True):
    def _bytes_feature(value):
      """Returns a bytes_list from a string / byte."""
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

    def _float_feature(value):
      """Returns a float_list from a float / double."""
      return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    def _int64_feature(value):
      """Returns an int64_list from a bool / enum / int / uint."""
      return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
            
    if Y is not None:
        assert X.shape == Y.shape
    
    # Generate tfrecord writer
    result_tf_file = file_path_prefix + '.tfrecords'
    writer = tf.python_io.TFRecordWriter(result_tf_file)
    if verbose:
        print("Serializing example into {}".format(result_tf_file))
        
    # iterate over each sample,
    # and serialize it as ProtoBuf.
    
    d_feature = {}
    d_feature['X'] = _float_feature(X.flatten())
    if Y is not None:
        d_feature['Y'] = _int64_feature(Y.flatten())
    d_feature['shape0'] = _int64_feature([X.shape[0]])
    d_feature['shape1'] = _int64_feature([X.shape[1]])
            
    features = tf.train.Features(feature=d_feature)
    example = tf.train.Example(features=features)
    serialized = example.SerializeToString()
    writer.write(serialized)
    
    if verbose:
        print("Writing {} done!".format(result_tf_file))

In [0]:
def data_preprocess(modality,data_folder,view, data_folder_out):
  train_img_path = []
  train_mask_path = []
  train_weights = []
  for m in modality:
    imgVol_fn, mask_fn = getTrainNLabelNames(data_folder, m)
    print("number of training data %d" % len(imgVol_fn))
    assert len(imgVol_fn) == len(mask_fn)

    for i in range(0,len(imgVol_fn)):
      img_path = imgVol_fn[i]
      mask_path = mask_fn[i]
      imgVol = sitk.GetArrayFromImage(sitk.ReadImage(img_path))  # numpy array
      imgVol = RescaleIntensity(imgVol, m)
      #imgVol = HistogramEqualization(imgVol)
      maskVol = sitk.GetArrayFromImage(sitk.ReadImage(mask_path))  # numpy array
      maskVol = swapLabels(maskVol)
      if m =="mr":
        imgVol = np.moveaxis(imgVol,0,-1)
        maskVol = np.moveaxis(maskVol,0,-1)
      print("number of image slices in this view %d" % imgVol.shape[view])
      for sid in range(imgVol.shape[view]):
        out_im_path = os.path.join(data_folder_out, m+'_train', m+'_train'+str(i)+'_'+str(sid))
        out_msk_path = os.path.join(data_folder_out, m+'_train_masks',  m+'_train_mask'+str(i)+'_'+str(sid))
        slice_im = np.moveaxis(imgVol,view,0)[sid,:,:]
        slice_msk = np.moveaxis(maskVol,view,0)[sid,:,:]
        #slice_im = HistogramEqualization(slice_im)
        #sitk.WriteImage(sitk.Cast(sitk.RescaleIntensity(sitk.GetImageFromArray(slice_im.astype(np.uint16))), sitk.sitkUInt8),out_im_path+'.png')
        #sitk.WriteImage(sitk.Cast(sitk.RescaleIntensity(sitk.GetImageFromArray((slice_msk).astype(np.uint16))), sitk.sitkUInt8), out_msk_path+'.png')
        #np.save(out_im_path+'.npy',RescaleIntensity(slice_im,m))
        #np.save(out_msk_path+'.npy',slice_msk)
        np_to_tfrecords(slice_im.astype(np.float32),slice_msk.astype(np.int64), out_im_path, verbose=True)
        train_img_path.append(out_im_path)
        train_mask_path.append(out_msk_path)
    #train_weights+=list(np.ones(num).astype(int))
  return train_img_path, train_mask_path

In [0]:
np.random.seed(10)

In [0]:
def sample(iterable, n):
    """
    Returns @param n random items from @param iterable.
    """
    if n == 0:
      return []
    reservoir = []
    factor = int(np.ceil(float(n)/float(len(iterable))))
    for i in range(0,factor-1):
      iterable+=iterable
    for t, item in enumerate(iterable):
        if t < n:
            reservoir.append(item)
        else:
            m = np.random.randint(0,t)
            if m < n:
                reservoir[m] = item
    return reservoir

In [13]:
modality = ["ct", "mr"]
data_folder = '/content/gdrive/My Drive/ImageData/MMWHS'
view = 1
data_folder_out = '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal'
print("Making dir...")
try:
  os.mkdir(data_folder_out)
except Exception as e: print(e)
for m in modality:
  try:
    os.mkdir(os.path.join(data_folder_out, m+'_train'))
    os.mkdir(os.path.join(data_folder_out, m+'_train_masks'))
  except Exception as e: print(e)

overwrite = False
if overwrite:
  _, _  = data_preprocess(modality,data_folder,view, data_folder_out)

x_train_filenames = []
x_weights = []
filenames = [None]*len(modality)
nums = np.zeros(len(modality))
for i, m in enumerate(modality):
  filenames[i], _ = getTrainNLabelNames(data_folder_out, m, ext='*.tfrecords')
  nums[i] = len(filenames[i])
  x_train_filenames+=filenames[i]
  print(nums)

#nums = np.max(nums) - nums
#for i , _ in enumerate(modality):
#  index = sample(range(len(filenames[i])), nums[i])
#  x_train_filenames+=[filenames[i][j] for j in index]
    
print(len(x_train_filenames))


Making dir...
[Errno 17] File exists: '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal'
[Errno 17] File exists: '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/ct_train'
[Errno 17] File exists: '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train'
[10240.     0.]
[10240.  7124.]
17364


In [14]:

#x_train_filenames = [i+'.tfrecords' for i in x_train_filenames]
#y_train_filenames = [i+'.tfrecords' for i in y_train_filenames]
#x_train_filenames, x_val_filenames, y_train_filenames, y_val_filenames = \
#                    train_test_split(x_train_filenames, y_train_filenames, test_size=0.2, random_state=42)
#x_train_filenames, x_val_filenames, x_train_weights,_ = train_test_split(x_train_filenames, x_weights, test_size=0.2, random_state=42)
x_train_filenames, x_val_filenames = train_test_split(x_train_filenames, test_size=0.2, random_state=42)
num_train_examples = len(x_train_filenames)
num_val_examples = len(x_val_filenames)

print("Number of training examples: {}".format(num_train_examples))
print("Number of validation examples: {}".format(num_val_examples))

x_train_filenames[:10]
#y_train_filenames[:10]


Number of training examples: 13891
Number of validation examples: 3473


['/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/ct_train/ct_train10_311.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/ct_train/ct_train10_188.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train/mr_train0_115.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/ct_train/ct_train18_315.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train/mr_train5_112.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/ct_train/ct_train9_483.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train/mr_train1_61.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train/mr_train6_86.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train/mr_train13_217.tfrecords',
 '/content/gdrive/My Drive/ImageData/MMWHS/2d_multiclass-coronal/mr_train/mr_train3_121.tfrecords'

# Visualize
Let's take a look at some of the examples of different images in our dataset. 

In [0]:
# display_num = 5

# r_choices = np.random.choice(num_train_examples, display_num)

# plt.figure(figsize=(10, 15))
# for i in range(0, display_num * 2, 2):
#   img_num = r_choices[i // 2]
#   x_pathname = x_train_filenames[img_num]
#   y_pathname = y_train_filenames[img_num]
  
#   plt.subplot(display_num, 2, i + 1)
#   plt.imshow(mpimg.imread(x_pathname))
  
#   plt.title("Original Image")
  
#   example_labels = mpimg.imread(y_pathname)
#   label_vals = np.unique(example_labels)
#   print(label_vals)
  
#   plt.subplot(display_num, 2, i + 2)
#   plt.imshow(example_labels)
#   plt.title("Masked Image")  
  
# plt.suptitle("Examples of Images and their Masks")
# plt.show()

# Set up 

In [0]:
img_shape = (256, 256, 1)
num_class = 8
batch_size = 8
epochs = 300

# Build our input pipeline with `tf.data`



## Shifting the image

In [0]:
def shift_img(output_img, label_img, width_shift_range, height_shift_range):
  """This fn will perform the horizontal or vertical shift"""
  if width_shift_range or height_shift_range:
      if width_shift_range:
        width_shift_range = tf.random_uniform([], 
                                              -width_shift_range * img_shape[1],
                                              width_shift_range * img_shape[1])
      if height_shift_range:
        height_shift_range = tf.random_uniform([],
                                               -height_shift_range * img_shape[0],
                                               height_shift_range * img_shape[0])
      # Translate both 
      output_img = tfcontrib.image.translate(output_img,
                                             [width_shift_range, height_shift_range])
      label_img = tfcontrib.image.translate(label_img,
                                             [width_shift_range, height_shift_range])
  return output_img, label_img

## Flipping the image randomly 

In [0]:
def flip_img(horizontal_flip, tr_img, label_img):
  if horizontal_flip:
    flip_prob = tf.random_uniform([], 0.0, 1.0)
    tr_img, label_img = tf.cond(tf.less(flip_prob, 0.5),
                                lambda: (tf.image.flip_left_right(tr_img), tf.image.flip_left_right(label_img)),
                                lambda: (tr_img, label_img))
  return tr_img, label_img

##Scale/shift the image intensity randomly

In [0]:
def changeIntensity_img(tr_img, label_img, changeIntensity=False):
  if changeIntensity:
    scale = tf.random_uniform([], 0.9, 1.1)
    shift = tf.random_uniform([], -0.1, 0.1)
    tr_img = tr_img*scale+shift
    
  return tr_img, label_img
  

## Assembling our transformations into our augment function


In [0]:
def _augment(img,
             label_img,
             resize=None,  # Resize the image to some size e.g. [256, 256]
             horizontal_flip=False,  # Random left right flip,
             changeIntensity=False,
             width_shift_range=0,  # Randomly translate the image horizontally
             height_shift_range=0):  # Randomly translate the image vertically 
  if resize is not None:
    # Resize both images
    label_img = tf.image.resize_images(label_img, resize, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True)
    img = tf.image.resize_images(img, resize, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True)
  
  
  img, label_img = flip_img(horizontal_flip, img, label_img)
  img, label_img = shift_img(img, label_img, width_shift_range, height_shift_range)
  img, label_img = changeIntensity_img(img, label_img,changeIntensity )
  return img, label_img

## Processing each pathname

In [0]:
def _parse_function(example_proto):
  features = {"X": tf.VarLenFeature(tf.float32),
              "Y": tf.VarLenFeature(tf.int64),
              "shape0": tf.FixedLenFeature((), tf.int64),
              "shape1": tf.FixedLenFeature((), tf.int64)}
  parsed_features = tf.parse_single_example(example_proto, features)
  img = tf.sparse_tensor_to_dense(parsed_features["X"])
  height = tf.cast(parsed_features["shape0"], tf.int32)
  width = tf.cast(parsed_features["shape1"], tf.int32)
  print(img,parsed_features)
  label = tf.sparse_tensor_to_dense(parsed_features["Y"])
  img = tf.reshape(img, tf.stack([height, width,1]))
  label = tf.reshape(label, tf.stack([height, width,1]) )
  label = tf.cast(label, tf.int32)
  return img, label

In [0]:
def _process_pathnames(fname):
  # We map this function onto each pathname pair  
  dataset_str = tf.read_file(fname)
  dataset = tf.data.TFRecordDataset(dataset_str)
  parsed_features = dataset.map(_parse_function)
  
  iterator = dataset.make_one_shot_iterator()
  #data = iterator.get_next()
  
  img = tf.sparse_tensor_to_dense(parsed_features["X"])
  height = tf.cast(parsed_features["shape0"], tf.int32)
  width = tf.cast(parsed_features["shape1"], tf.int32)
  label = tf.sparse_tensor_to_dense(parsed_features["Y"])
  img = tf.reshape(img, tf.stack([height, width]))
  label = tf.reshape(label, tf.stack([height, width]) )

  print(img, label)
  return img, label

##Assemble Dataset

In [0]:
def get_baseline_dataset(filenames, preproc_fn=functools.partial(_augment),
                         threads=5, 
                         batch_size=batch_size,
                         shuffle=True):           
  num_x = len(filenames)
  # Create a dataset from the filenames and labels
  files = tf.data.Dataset.from_tensor_slices(filenames)
  print(files)
  
  dataset = files.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=threads))
  # Map our preprocessing function to every element in our dataset, taking
  # advantage of multithreading
  dataset = dataset.map(_parse_function, num_parallel_calls=threads)
  # dataset = dataset.map(_process_pathnames, num_parallel_calls=threads)
  if preproc_fn.keywords is not None and 'resize' not in preproc_fn.keywords:
    assert batch_size == 1, "Batching images must be of the same size"

  dataset = dataset.map(preproc_fn, num_parallel_calls=threads)
  print(num_x)
  if shuffle:
    dataset = dataset.shuffle(int(num_x/2))
  
  
  # It's necessary to repeat our data for all epochs 
  dataset = dataset.repeat().batch(batch_size)
  dataset = dataset.prefetch(buffer_size=batch_size)

  return dataset

## Set up train and validation datasets
Note that we apply image augmentation to our training dataset but not our validation dataset. 

In [0]:
tr_cfg = {
    'resize': [img_shape[0], img_shape[1]],
    'horizontal_flip': True,
    'changeIntensity': True,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1
}
tr_preprocessing_fn = functools.partial(_augment, **tr_cfg)

In [0]:
val_cfg = {
    'resize': [img_shape[0], img_shape[1]],
}
val_preprocessing_fn = functools.partial(_augment, **val_cfg)

In [26]:
train_ds = get_baseline_dataset(x_train_filenames, preproc_fn=tr_preprocessing_fn,
                                batch_size=batch_size)
val_ds = get_baseline_dataset(x_train_filenames, preproc_fn=val_preprocessing_fn,
                              batch_size=batch_size)

<DatasetV1Adapter shapes: (), types: tf.string>
Tensor("SparseToDense:0", shape=(?,), dtype=float32) {'X': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f98f73c3fd0>, 'Y': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f98f73c3da0>, 'shape0': <tf.Tensor 'ParseSingleExample/ParseSingleExample:6' shape=() dtype=int64>, 'shape1': <tf.Tensor 'ParseSingleExample/ParseSingleExample:7' shape=() dtype=int64>}
13891
<DatasetV1Adapter shapes: (), types: tf.string>
Tensor("SparseToDense:0", shape=(?,), dtype=float32) {'X': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f98f94c9828>, 'Y': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f98f94c9c88>, 'shape0': <tf.Tensor 'ParseSingleExample/ParseSingleExample:6' shape=() dtype=int64>, 'shape1': <tf.Tensor 'ParseSingleExample/ParseSingleExample:7' shape=() dtype=int64>}
13891


## Let's see if our image augmentor data pipeline is producing expected results

In [0]:
# temp_ds = get_baseline_dataset(x_train_filenames, 
#                                preproc_fn=tr_preprocessing_fn,
#                                batch_size=5,
#                                shuffle=True)
# # Let's examine some of these augmented images
# data_aug_iter = temp_ds.make_one_shot_iterator()
# next_element = data_aug_iter.get_next()
# with tf.Session() as sess: 
#   batch_of_imgs, label = sess.run(next_element)

#   # Running next element in our graph will produce a batch of images
#   plt.figure(figsize=(20, 20))

#   plt.subplot(2, 5, 1)
#   plt.imshow(batch_of_imgs[0,:,:,0])
#   plt.subplot(2, 5, 6)
#   plt.imshow(label[0, :, :,0])
  
#   plt.subplot(2, 5, 2)
#   plt.imshow(batch_of_imgs[1, :, :,0])
#   plt.subplot(2, 5, 7)
#   plt.imshow(label[1, :, :,0])
  
#   plt.subplot(2, 5, 3)
#   plt.imshow(batch_of_imgs[2, :, :,0])
#   plt.subplot(2, 5, 8)
#   plt.imshow(label[2, :, :,0])
  
#   plt.subplot(2, 5, 4)
#   plt.imshow(batch_of_imgs[3, :, :,0])
#   plt.subplot(2, 5, 9)
#   plt.imshow(label[3, :, :,0])
  
#   plt.subplot(2, 5, 5)
#   plt.imshow(batch_of_imgs[4, :, :,0])
#   plt.subplot(2, 5, 10)
#   plt.imshow(label[4, :, :,0])
#   plt.show()

In [0]:
# #sanity checks
# print(np.max(batch_of_imgs), np.min(batch_of_imgs))
# print(batch_of_imgs.shape)
# print(label.shape)
# print(np.unique(label))

# Build the model

In [0]:
def conv_block(input_tensor, num_filters):
  encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
  encoder = layers.BatchNormalization()(encoder)
  encoder = layers.Activation('relu')(encoder)
  encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
  encoder = layers.BatchNormalization()(encoder)
  encoder = layers.Activation('relu')(encoder)
  return encoder

def encoder_block(input_tensor, num_filters):
  encoder = conv_block(input_tensor, num_filters)
  encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
  
  return encoder_pool, encoder

def decoder_block(input_tensor, concat_tensor, num_filters):
  decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
  decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
  decoder = layers.BatchNormalization()(decoder)
  decoder = layers.Activation('relu')(decoder)
  decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
  decoder = layers.BatchNormalization()(decoder)
  decoder = layers.Activation('relu')(decoder)
  decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
  decoder = layers.BatchNormalization()(decoder)
  decoder = layers.Activation('relu')(decoder)
  return decoder

In [0]:
inputs = layers.Input(shape=img_shape)
# 256

encoder0_pool, encoder0 = encoder_block(inputs, 32)
# 128

encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
# 64

encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
# 32

encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
# 16

encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)
# 8

center = conv_block(encoder4_pool, 1024)
# center

decoder4 = decoder_block(center, encoder4, 512)
# 16

decoder3 = decoder_block(decoder4, encoder3, 256)
# 32

decoder2 = decoder_block(decoder3, encoder2, 128)
# 64

decoder1 = decoder_block(decoder2, encoder1, 64)
# 128

decoder0 = decoder_block(decoder1, encoder0, 32)
# 256

outputs = layers.Conv2D(num_class, (1, 1), activation='softmax', data_format="channels_last")(decoder0)

In [0]:
model = models.Model(inputs=[inputs], outputs=[outputs])

In [0]:
def dice_coeff(y_true, y_pred):
    smooth = 1.
    # Flatten
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score

In [0]:
from tensorflow.python.keras.utils import to_categorical
def dice_loss(y_true, y_pred):
    y_true_one_hot = tf.one_hot(tf.cast(y_true,tf.int32), num_class)
    loss = 0.
    weights = [1.,1.,1.,1.,1.,1.,1.,1.]
    for i in range(num_class):
      loss += weights[i]*(1 - dice_coeff(y_true_one_hot[:,:,:,:,i], y_pred[:,:,:,i]))
    return loss

In [0]:
def bce_dice_loss(y_true, y_pred):
    loss = losses.sparse_categorical_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

In [0]:
def la_loss(y_true, y_pred):
    y_true_one_hot = tf.one_hot(tf.cast(y_true,tf.int32), num_class)
    loss = 1 - dice_coeff(y_true_one_hot[:,:,:,:,2], y_pred[:,:,:,2])
    return loss


In [34]:
from tensorflow.python.keras.optimizers import Adam
adam = Adam(lr=0.02, beta_1=0.9, beta_2=0.999, epsilon=None, decay=1e-6, amsgrad=False)
model.compile(optimizer=adam, loss=bce_dice_loss, metrics=[dice_loss])

model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 256, 256, 1)  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 32) 320         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_v1 (BatchNo (None, 256, 256, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 256, 256, 32) 0           batch_normalization_v1[0][0]     
__________________________________________________________________________________________________
conv2d_1 (

In [0]:
save_model_path = '/content/gdrive/My Drive/DeepLearning/2DUNet/Logs/weights_multi-all-coronal.hdf5'
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_model_path, monitor='val_dice_loss', save_best_only=True, verbose=1)

In [38]:
# Alternatively, load the weights directly: model.load_weights(save_model_path)
try:
  model = models.load_model(save_model_path, custom_objects={'bce_dice_loss': bce_dice_loss, 'dice_loss': dice_loss})
except:
  print("model not loaded")
  pass

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.


In [0]:
history = model.fit(train_ds, 
                   steps_per_epoch=int(np.ceil(num_train_examples / float(batch_size))),
                   epochs=epochs,
                   validation_data=val_ds,
                   validation_steps=int(np.ceil(num_val_examples / float(batch_size))),
                   callbacks=[cp])

Epoch 1/300
Epoch 00001: val_dice_loss improved from inf to 1.99270, saving model to /content/gdrive/My Drive/DeepLearning/2DUNet/Logs/weights_multi-all-coronal.hdf5
Epoch 2/300
Epoch 00002: val_dice_loss improved from 1.99270 to 1.59678, saving model to /content/gdrive/My Drive/DeepLearning/2DUNet/Logs/weights_multi-all-coronal.hdf5
Epoch 3/300
Epoch 00003: val_dice_loss did not improve from 1.59678
Epoch 4/300
Epoch 00004: val_dice_loss improved from 1.59678 to 1.45218, saving model to /content/gdrive/My Drive/DeepLearning/2DUNet/Logs/weights_multi-all-coronal.hdf5
Epoch 5/300
Epoch 00005: val_dice_loss did not improve from 1.45218
Epoch 6/300
Epoch 00006: val_dice_loss did not improve from 1.45218
Epoch 7/300
Epoch 00007: val_dice_loss improved from 1.45218 to 1.38524, saving model to /content/gdrive/My Drive/DeepLearning/2DUNet/Logs/weights_multi-all-coronal.hdf5
Epoch 8/300
Epoch 00008: val_dice_loss improved from 1.38524 to 1.37358, saving model to /content/gdrive/My Drive/DeepLe

In [0]:
dice = history.history['dice_loss']
val_dice = history.history['val_dice_loss']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, dice, label='Training Dice Loss')
plt.plot(epochs_range, val_dice, label='Validation Dice Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Dice Loss')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

In [0]:
# Let's visualize some of the outputs 
data_aug_iter = val_ds.make_one_shot_iterator()
next_element = data_aug_iter.get_next()

batch_of_imgs, label = tf.keras.backend.get_session().run(next_element)
predicted_label = np.argmax(model.predict(batch_of_imgs),axis=-1)
# Running next element in our graph will produce a batch of images
plt.figure(figsize=(10, 20))
for i in range(5):
  img = batch_of_imgs[i]
  print(np.unique(label[i]))
  print(np.unique(predicted_label[i]))
  plt.subplot(5, 3, 3 * i + 1)
  plt.imshow(img[:,:,0])
  plt.title("Input image")
  
  plt.subplot(5, 3, 3 * i + 2)
  plt.imshow(label[i, :, :, 0])
  plt.title("Actual Mask")
  plt.subplot(5, 3, 3 * i + 3)
  plt.imshow(predicted_label[i, :, :])
  plt.title("Predicted Mask")
plt.suptitle("Examples of Input Image, Label, and Prediction")
plt.show()

In [0]:
debug_label = model.predict(batch_of_imgs)
debug_label_bk = debug_label[2,:,:,0]
debug_label_la = debug_label[2,:,:,2]
print(debug_label_bk.shape)
print(debug_label_la.shape)
plt.subplot(1,2,1)
plt.imshow(debug_label_bk)
plt.subplot(1,2,2)
plt.imshow(debug_label_la)
plt.show()