In [None]:
# Authenticate this Colab session with Google. Presumably required so that
# tfds.load can read the gs://foot_ulcer data_dir used in a later cell.
from google.colab import auth
auth.authenticate_user()

In [None]:
import tensorflow as tf

In [None]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG16
import tensorflow.keras.backend as K

In [None]:
import tensorflow_datasets as tfds

from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
# Show the installed tensorflow-datasets version (a specific version is pinned in the next cell).
tfds.__version__

In [None]:
!pip install tensorflow-datasets==4.3.0

In [None]:
# Load all splits of foot_ulcer v1.2.0 (later cells use the 'train' and 'test'
# entries) from the GCS data_dir, together with its DatasetInfo metadata.
dataset, info = tfds.load(
    name='foot_ulcer:1.2.0',
    data_dir='gs://foot_ulcer',
    with_info=True,
)

In [None]:
# Display the DatasetInfo returned by tfds.load (its split sizes are used below).
info

In [None]:
example = dataset['train'].take(1)
for sample in example:
    # One raw training example (before resizing/normalizing): image left, mask right.
    image = sample["image"]
    mask = sample["segmentation_mask"]
    for position, panel in enumerate((image, mask), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
    plt.show()

In [None]:
def normalize(input_image, input_mask):
  """Convert an (image, mask) pair to float32 model inputs/targets.

  Args:
    input_image: image tensor; assumed to hold values in [0, 255].
    input_mask: mask tensor; presumably 0 = background and 255 = ulcer
      (TODO confirm against the dataset).

  Returns:
    (image, mask): the image divided by 255, and the mask divided by 255 and
    then rounded with tf.round. For a 0..255 mask, that rounding thresholds
    at 127.5, which also snaps interpolated edge values back to 0/1.
  """
  scale = 255.0
  image = tf.cast(input_image, tf.float32) / scale
  mask = tf.round(tf.cast(input_mask, tf.float32) / scale)
  return image, mask

In [None]:
# Spatial size that every image and mask is resized to by the loaders below.
img_height, img_width = 128, 128

In [None]:
@tf.function
def load_image_train(datapoint):
  """Preprocess one training example, with a random horizontal flip.

  Args:
    datapoint: dict with 'image' and 'segmentation_mask' tensors.

  Returns:
    (image, mask) resized to (img_height, img_width), flipped left-right
    together with probability 0.5, then passed through `normalize`.
  """
  target_size = (img_height, img_width)
  input_image = tf.image.resize(datapoint['image'], target_size)
  input_mask = tf.image.resize(datapoint['segmentation_mask'], target_size)

  # Flip image and mask with the same draw so they stay pixel-aligned.
  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  return normalize(input_image, input_mask)

In [None]:
def load_image_test(datapoint):
  """Deterministically preprocess one example (no augmentation).

  Args:
    datapoint: dict with 'image' and 'segmentation_mask' tensors.

  Returns:
    (image, mask) resized to (img_height, img_width) and passed through
    `normalize`.
  """
  size = (img_height, img_width)
  resized_image = tf.image.resize(datapoint['image'], size)
  resized_mask = tf.image.resize(datapoint['segmentation_mask'], size)
  return normalize(resized_image, resized_mask)

In [None]:
# Preprocess each split. Train applies the random flip (load_image_train) and
# test is deterministic (load_image_test). Both maps run in parallel; tf.data
# keeps element order by default, so test results are unchanged.
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test = dataset['test'].map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
example = train.take(1)
for image, mask in example:
    # One preprocessed pair (resized, normalized, possibly flipped): image left, mask right.
    for position, panel in enumerate((image, mask), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
    plt.show()

In [None]:
BATCH_SIZE = 16
BUFFER_SIZE = 1000  # shuffle buffer size, in elements
TRAIN_LENGTH = info.splits['train'].num_examples
# Number of full batches per pass over the training split (the remainder is not counted).
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [None]:
def _random_flip(image, mask):
  """Mirror image and mask together, left-right, with probability 0.5."""
  flip = tf.random.uniform(()) > 0.5
  flipped_image = tf.cond(flip, lambda: tf.image.flip_left_right(image), lambda: image)
  flipped_mask = tf.cond(flip, lambda: tf.image.flip_left_right(mask), lambda: mask)
  return flipped_image, flipped_mask

# Cache only the deterministic preprocessing (resize + normalize via
# load_image_test), then apply the random flip *after* the cache.
# Caching the output of load_image_train would store each image's
# first-epoch flip, so every later epoch would replay the same augmentation.
train_dataset = (
    dataset['train']
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .map(_random_flip, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
test_dataset = test.batch(BATCH_SIZE)

In [None]:
def display(display_list):
  """Plot up to three images side by side in one 15x15-inch figure.

  Args:
    display_list: list of image tensors/arrays, shown left to right with the
      titles 'Input Image', 'True Mask', 'Predicted Mask'. More than three
      entries raises IndexError because only three titles are defined.

  NOTE(review): this name shadows the `display` helper that IPython provides
  in the notebook namespace.
  """
  titles = ['Input Image', 'True Mask', 'Predicted Mask']
  panel_count = len(display_list)
  plt.figure(figsize=(15, 15))
  for idx, item in enumerate(display_list):
    plt.subplot(1, panel_count, idx + 1)
    plt.title(titles[idx])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
    plt.axis('off')
  plt.show()

In [None]:
# Walk the first seven preprocessed training pairs. After the loop,
# sample_image / sample_mask hold the 7th (or the last, if there are fewer).
# `train` applies an unseeded random flip, so the sample can differ per run.
for image, mask in train.take(7):
  sample_image = image
  sample_mask = mask
display([sample_image, sample_mask])

In [None]:
def conv_block(input_shape, num_filters):
    """Apply two (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        input_shape: despite its name, the input *tensor* to transform.
        num_filters: number of filters for each Conv2D.

    Returns:
        Output tensor of the second ReLU (spatial size preserved by "same" padding).
    """
    features = input_shape
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

In [None]:
def decoder_block(input_shape, skip_features, num_filters):
    """One U-Net decoder step: 2x upsample, concatenate the skip, then conv_block.

    Args:
        input_shape: despite its name, the lower-resolution input *tensor*.
        skip_features: encoder tensor whose spatial size matches the upsampled input.
        num_filters: filter count for the transpose conv and the conv_block.

    Returns:
        Tensor at twice the spatial resolution of `input_shape`.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input_shape)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [None]:
def vgg16_unet(input_shape, num_classes=2):
    """Build a U-Net whose encoder is an ImageNet-pretrained VGG16.

    Args:
        input_shape: shape of one input image, (height, width, channels).
            The ImageNet weights expect 3 channels. Height and width should be
            divisible by 16 so the four 2x decoder upsamplings line up with
            the encoder skip connections.
        num_classes: number of output channels, one per class (default 2).

    Returns:
        A Keras Model mapping images to per-pixel class *logits* of shape
        (height, width, num_classes).
    """
    inputs = Input(input_shape)

    # Pretrained VGG16 backbone built directly on `inputs` (no classifier head).
    # NOTE(review): `normalize` scales images to [0, 1] instead of using
    # vgg16.preprocess_input; confirm this is intended for the ImageNet weights.
    vgg16 = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)

    # Encoder skip connections: last conv of blocks 1-4 (1, 1/2, 1/4, 1/8 resolution).
    s1 = vgg16.get_layer("block1_conv2").output
    s2 = vgg16.get_layer("block2_conv2").output
    s3 = vgg16.get_layer("block3_conv3").output
    s4 = vgg16.get_layer("block4_conv3").output

    # Bridge: last conv of block 5 (1/16 resolution), taken before its pooling.
    b1 = vgg16.get_layer("block5_conv3").output

    # Decoder: upsample 2x and fuse with the matching skip at each step.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    # Output raw logits (no activation). The model is compiled with
    # SparseCategoricalCrossentropy(from_logits=True), which applies softmax
    # itself. A sigmoid here would squash the "logits" into (0, 1) and keep the
    # loss from approaching zero. argmax-based masks are unaffected either way.
    outputs = Conv2D(num_classes, 1, padding="same")(d4)

    model = Model(inputs, outputs)
    return model

In [None]:
# Build the network for the preprocessed image size, (img_height, img_width, channels).
model = vgg16_unet(input_shape=sample_image.shape)

In [None]:
# Custom Keras metric: Dice coefficient of the argmax prediction.
def dice_coef(y_true, y_pred, smooth=1):
    """Dice coefficient between the ground-truth mask and the argmax prediction.

    Works for any batch size: the whole batch is flattened and scored as one.
    (Reshaping y_true to a single (img_height, img_width, 1) image would fail
    for the batches of BATCH_SIZE that Keras passes to metrics.)

    Args:
        y_true: ground-truth mask batch; assumed to hold 0/1 class ids, with
            one value per pixel (e.g. (batch, H, W, 1)).
        y_pred: per-pixel class scores with classes on the last axis.
        smooth: additive smoothing; if both masks are empty the score is 1.

    Returns:
        Scalar float32 Dice score for the batch.
    """
    # Predicted class id per pixel; class 1 counts as foreground.
    y_pred_labels = tf.argmax(y_pred, axis=-1)
    y_true_f = K.cast(K.flatten(y_true), 'float32')
    y_pred_f = K.cast(K.flatten(y_pred_labels), 'float32')
    intersection = K.sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / ((K.sum(y_true_f) + K.sum(y_pred_f)) + smooth)
    return dice

In [None]:
# Sparse cross-entropy against the 0/1 class-id masks. from_logits=True means
# the loss applies softmax over the model's last axis itself.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['accuracy', dice_coef],
)

In [None]:
# Draw the layer graph, annotated with tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
def create_mask(pred_mask):
  """Turn a batch of per-pixel class scores into the first image's label mask.

  Args:
    pred_mask: model output with classes on the last axis, e.g. (batch, H, W, C).

  Returns:
    The argmax class ids for batch item 0, with a trailing channel axis, e.g. (H, W, 1).
  """
  labels = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(labels, axis=-1)[0]

In [None]:
def show_predictions(dataset=None, num=1):
  """Plot input image, ground-truth mask and predicted mask.

  Args:
    dataset: optional batched dataset of (image, mask) pairs. If given, the
      first example of each of the first `num` batches is plotted. If None,
      the global `sample_image` / `sample_mask` pair is plotted instead.
    num: number of batches to take from `dataset`.
  """
  # Explicit None check instead of relying on tf.data.Dataset truthiness.
  if dataset is not None:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    # Add a batch axis of size 1 because the model expects batched input.
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

In [None]:
# Baseline: predict the sample image before fine-tuning (model.fit runs in a later cell).
show_predictions()

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
  """Keras callback that redraws the sample prediction after every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Clear the cell output, plot the sample prediction and print the epoch.

    Args:
      epoch: epoch index supplied by Keras; printed as epoch + 1.
      logs: metrics dict supplied by Keras; unused here.
    """
    clear_output(wait=True)
    show_predictions()
    print(f'\nSample Prediction after epoch {epoch + 1}\n')

In [None]:
EPOCHS = 50
VAL_SUBSPLITS = 5
# Evaluate about 1/VAL_SUBSPLITS of the test split each epoch. Clamp to at
# least one step: a test split smaller than BATCH_SIZE * VAL_SUBSPLITS would
# otherwise give validation_steps=0, leaving Keras no validation steps to run
# (and no val_* history entries for the cells below).
VALIDATION_STEPS = max(1, info.splits['test'].num_examples // BATCH_SIZE // VAL_SUBSPLITS)

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])

In [None]:
# Validation accuracy after the last epoch.
model_history.history['val_accuracy'][-1]

In [None]:
# Validation Dice score (dice_coef metric) after the last epoch.
model_history.history['val_dice_coef'][-1]

In [None]:
# Print the layer-by-layer summary of the U-Net.
model.summary()

In [None]:
# Total number of layers (VGG16 encoder layers plus the decoder and output head).
len(model.layers)