In [None]:
import os
import numpy as np
import cv2
from glob import glob
import tempfile
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from matplotlib import pyplot as plt
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and output paths configured
# further down live under /content/drive/MyDrive.
# NOTE(review): force_remount=True presumably forces a fresh mount when this
# cell is re-run -- confirm against the Colab documentation.
drive.mount('/content/drive', force_remount=True)

In [None]:
#-----------------------DEFINING FUNCTIONS FOR MODEL----------------------------
def conv_block(inputs, num_filters):
  """Run `inputs` through two Conv2D(3x3, same) -> BatchNorm -> ReLU stages.

  Args:
    inputs: Keras tensor to convolve.
    num_filters: number of filters used by both convolutions.

  Returns:
    The Keras tensor produced by the second ReLU activation.
  """
  features = inputs
  # Both stages are identical, so build them in a loop (same layer order:
  # conv, batch-norm, activation, twice).
  for _ in range(2):
    features = Conv2D(num_filters, 3, padding="same")(features)
    features = BatchNormalization()(features)
    features = Activation("relu")(features)
  return features

def encoder_block(inputs, num_filters):
  """One U-Net encoder stage: conv block, then 2x2 max-pool and dropout.

  Args:
    inputs: Keras tensor entering the stage.
    num_filters: number of filters for the stage's conv block.

  Returns:
    A pair `(skip_features, pooled)`: the conv block output before pooling
    (used later as the skip connection) and the pooled, dropout-regularised
    tensor that feeds the next stage.
  """
  skip_features = conv_block(inputs, num_filters)
  downsampled = MaxPool2D((2, 2))(skip_features)
  pooled = Dropout(0.5)(downsampled)
  return skip_features, pooled

def decoder_block(inputs, skip, num_filters):
  """One U-Net decoder stage: upsample, merge with the skip tensor, convolve.

  Args:
    inputs: Keras tensor coming from the previous (deeper) stage.
    skip: encoder tensor to concatenate with the upsampled input.
    num_filters: number of filters for the transposed conv and the conv block.

  Returns:
    The Keras tensor produced by the stage's conv block.
  """
  upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
  merged = Concatenate()([upsampled, skip])
  regularized = Dropout(0.5)(merged)
  return conv_block(regularized, num_filters)

def build_unet(input_shape):
  """Assemble the U-Net model.

  The network has four encoder stages (64, 128, 256, 512 filters), a
  1024-filter bottleneck conv block, four decoder stages (512, 256, 128, 64
  filters) each fed the matching encoder skip tensor, and a final 1x1
  convolution with a sigmoid activation producing one output channel.

  Args:
    input_shape: shape passed to the Keras `Input` layer.

  Returns:
    A Keras `Model` named "UNET".
  """
  inputs = Input(input_shape)

  # Encoder: remember each stage's pre-pool features for the skip connections.
  skip_stack = []
  current = inputs
  for filters in (64, 128, 256, 512):
    skip, current = encoder_block(current, filters)
    skip_stack.append(skip)

  # Bottleneck.
  current = conv_block(current, 1024)

  # Decoder: consume the skips deepest-first.
  for filters in (512, 256, 128, 64):
    current = decoder_block(current, skip_stack.pop(), filters)

  outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(current)

  return Model(inputs, outputs, name="UNET")

def load_data(path):
  """List the image and mask files of the train and valid splits.

  Looks for `<path>/<split>/images/*` and `<path>/<split>/masks/*` for the
  splits "train" and "valid", sorts each listing, and prints its size.

  Args:
    path: root directory of the dataset.

  Returns:
    `((train_x, train_y), (valid_x, valid_y))`, four sorted lists of paths.
  """
  def _listing(split, kind, label):
    # Sorted so that images and masks line up by file name order.
    files = sorted(glob(os.path.join(path, split, kind, "*")))
    print(f"{label} size: {len(files)}")
    return files

  train_x = _listing("train", "images", "Train Images")
  train_y = _listing("train", "masks", "Train Masks")

  valid_x = _listing("valid", "images", "Valid Images")
  valid_y = _listing("valid", "masks", "Valid Masks")

  return (train_x, train_y), (valid_x, valid_y)

def read_image(path):
  """Load an image as a normalised 224x224 single-channel array.

  Args:
    path: file path, either as `bytes` (what `tf.numpy_function` passes in)
      or as `str`.

  Returns:
    A float64 NumPy array of shape (224, 224, 1) with values in [0, 1].

  Raises:
    OSError: if OpenCV cannot read the file (missing, unreadable or not an
      image). `cv2.imread` signals this by returning None instead of raising.
  """
  if isinstance(path, bytes):
    path = path.decode()
  x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  if x is None:
    # Fail with the offending path rather than a cryptic cv2.resize assertion.
    raise OSError(f"Could not read image file (missing or unreadable): {path}")
  x = cv2.resize(x, (224, 224))
  x = x/255.0 # normalizing pixels
  x = np.expand_dims(x, axis=-1) # add the channel axis
  return x

def read_mask(path):
  """Load a segmentation mask as a normalised 224x224 single-channel array.

  The mask is resized with nearest-neighbour interpolation so that resizing
  does not create intermediate grey levels along object borders: every output
  pixel keeps a value that was present in the mask file.

  Args:
    path: file path, either as `bytes` (what `tf.numpy_function` passes in)
      or as `str`.

  Returns:
    A float64 NumPy array of shape (224, 224, 1) with values in [0, 1].

  Raises:
    OSError: if OpenCV cannot read the file (missing, unreadable or not an
      image). `cv2.imread` signals this by returning None instead of raising.
  """
  if isinstance(path, bytes):
    path = path.decode()
  x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  if x is None:
    # Fail with the offending path rather than a cryptic cv2.resize assertion.
    raise OSError(f"Could not read mask file (missing or unreadable): {path}")
  # Labels must not be blended: the default bilinear resize would turn the
  # 0/255 border into fractional values.
  x = cv2.resize(x, (224, 224), interpolation=cv2.INTER_NEAREST)
  x = x/255.0 # normalizing mask
  x = np.expand_dims(x, axis=-1) # add the channel axis
  return x

def tf_parse(x, y):
  """tf.data map function: turn an (image path, mask path) pair into tensors.

  The files are read in Python via `read_image` / `read_mask`, wrapped in
  `tf.numpy_function` so they can run inside a `tf.data` pipeline.

  Args:
    x: scalar string tensor holding the image path.
    y: scalar string tensor holding the mask path.

  Returns:
    `(image, mask)`, two float64 tensors with static shape (224, 224, 1).
  """
  def _parse(x, y):
    x = read_image(x)
    y = read_mask(y)
    return x, y

  x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
  # The static shape has to be declared explicitly after tf.numpy_function.
  # read_image / read_mask always resize to 224x224, so use that size directly
  # instead of the `height` / `width` globals that are only assigned later,
  # inside the training loop.
  x.set_shape([224, 224, 1]) # 1 for grayscale
  y.set_shape([224, 224, 1]) # 1 for grayscale

  return x, y

def tf_dataset(x, y, batch=8):
  """Build a batched, prefetched `tf.data.Dataset` of (image, mask) pairs.

  Args:
    x: sequence of image file paths.
    y: sequence of mask file paths, aligned with `x`.
    batch: batch size.

  Returns:
    A dataset that parses each path pair with `tf_parse`, batches the
    results, and prefetches with `tf.data.AUTOTUNE`.
  """
  return (
      tf.data.Dataset.from_tensor_slices((x, y))
      .map(tf_parse, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(batch)
      .prefetch(tf.data.AUTOTUNE)
  )

#------------------------ADDRESSING CLASS IMBALANCE-----------------------------
def add_sample_weights(image, label):
  """Attach a per-pixel sample-weight map that up-weights the rare class.

  Intended for `dataset.map(...)`: the returned triple is the
  `(inputs, targets, sample_weights)` form accepted by `Model.fit`.

  Args:
    image: image tensor, passed through unchanged.
    label: mask tensor with values in [0, 1]; pixels >= 0.5 are treated as
      class 1 (foreground), the rest as class 0 (background).

  Returns:
    `(image, label, sample_weights)` where `sample_weights` has the same shape
    as `label` and holds the normalised weight of each pixel's class.
  """
  # pos, neg, total values from running calculating_weights file
  pos, neg, total = 7933164, 146608916, 154542080

  # scaling by total/2 helps keep the loss to a similar magnitude
  # the sum of the weights of all examples stays the same
  weight_for_0 = (1 / neg) * (total / 2.0)
  # NOTE(review): the "- 4" offset is kept from the original; its rationale is
  # not documented here (it looks like a manual damping of the positive weight).
  weight_for_1 = (1 / pos) * (total / 2.0) - 4

  print('Weight for class 0: {:.2f}'.format(weight_for_0))
  print('Weight for class 1: {:.2f}'.format(weight_for_1))

  class_weight = tf.constant([weight_for_0, weight_for_1])
  class_weight = class_weight/tf.reduce_sum(class_weight)

  # Create an image of 'sample_weights' by using the class of each pixel as an
  # index into 'class_weight'. Threshold instead of casting the float label
  # directly: a plain int cast truncates, so any foreground pixel whose value
  # is slightly below 1.0 would silently receive the background weight.
  class_index = tf.cast(label >= 0.5, tf.int32)
  sample_weights = tf.gather(class_weight, indices=class_index)

  return image, label, sample_weights

In [None]:

#---------------------INITIALIZATION FOR TRAINING MODEL-------------------------
# Base name for the run; the training loop uses it to build the checkpoint
# (.h5) and CSV log file names.
unet_name = "your-model-name-here"
# Dataset root; load_data expects train/ and valid/ sub-folders, each holding
# images/ and masks/.
dataset_path = "/content/drive/MyDrive/Images"
files_dir = "/content/drive/MyDrive/Files" # where to save model, and training log

# if you want to try multiple epochs and batch sizes, add to the following arrays
# (the training loop runs once per (epochs, batch_size) combination)
epochs = [100]
batch_size = [8]
# Sorted lists of file paths for each split; also prints the size of each list.
(train_x, train_y), (valid_x, valid_y) = load_data(dataset_path)

# Fix the random seeds for reproducibility.
# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up,
# so setting it here may not affect the already-running kernel -- confirm.
os.environ["PYTHONHASHSEED"] = str(42)
np.random.seed(42)
tf.random.set_seed(42)

# Metrics reported during training/validation: confusion-matrix counts plus
# accuracy, precision, recall, and the areas under the ROC and PR curves.
METRICS = [
  tf.keras.metrics.TruePositives(name='tp'),
  tf.keras.metrics.FalsePositives(name='fp'),
  tf.keras.metrics.TrueNegatives(name='tn'),
  tf.keras.metrics.FalseNegatives(name='fn'),
  tf.keras.metrics.BinaryAccuracy(name='accuracy'),
  tf.keras.metrics.Precision(name='precision'),
  tf.keras.metrics.Recall(name='recall'),
  tf.keras.metrics.AUC(name='auc'),
  tf.keras.metrics.AUC(name='prc', curve='PR') # precision-recall curve
]


In [None]:
#---------------TRAINING FOR DIFFERENT EPOCHS AND BATCH SIZES-------------------
# Train one fresh model per (epochs, batch size) combination.
for e in epochs:
  for sz in batch_size:
    #-------------------------INITIALIZING FILES ETC----------------------------
    tf.keras.backend.clear_session()
    # Assigned at module level: used for input_shape below and visible to
    # helper functions defined in other cells.
    height, width = 224, 224

    print(f"\n\n* * * * * Epochs: {e} | Batch Size: {sz} * * * * *\n\n")

    # With a single configuration the file names stay "<unet_name>.h5" and
    # "log-<unet_name>.csv". When several configurations are swept, tag each
    # run so a later run does not overwrite an earlier run's checkpoint and log.
    if len(epochs) * len(batch_size) > 1:
      run_name = f"{unet_name}-e{e}-b{sz}"
    else:
      run_name = unet_name
    model_file = os.path.join(files_dir, "models", f"{run_name}.h5")
    log_file = os.path.join(files_dir, "logs", f"log-{run_name}.csv")

    # Make sure the output folders exist before any callback writes to them.
    os.makedirs(os.path.dirname(model_file), exist_ok=True)
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    train_dataset = tf_dataset(train_x, train_y, batch=sz)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=sz)

    input_shape = (height, width, 1)

    model = build_unet(input_shape)
    # if you want to further train a pretrained model,
    # uncomment the line below (no need to use build_unet),
    # and add your model name & path
    # model = tf.keras.models.load_model('/content/drive/MyDrive/Files/models/model.h5', compile=False)

    #-----------------------------TRAINING MODEL--------------------------------
    opt = tf.keras.optimizers.Adam()
    model.compile(loss = 'binary_crossentropy', optimizer=opt, metrics=METRICS)

    callbacks=[
        # save_best_only keeps the best model seen so far according to the
        # callback's monitored quantity (left at its default here).
        ModelCheckpoint(model_file, verbose=1, save_best_only=True),
        CSVLogger(log_file)
    ]

    with tf.device(tf.test.gpu_device_name()):
      # `workers` is intentionally not passed: Keras only applies it to
      # generator / Sequence inputs (not tf.data datasets), and Keras 3's
      # fit() no longer accepts the argument.
      model.fit(
            train_dataset.map(add_sample_weights), # adds per-pixel class weights
            validation_data=valid_dataset,
            epochs=e,
            callbacks=callbacks,
            verbose=2
      )