# Training the UNET

## Imports

In [1]:
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, Multiply
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger

## Seeding

In [2]:
# Seed every RNG the notebook uses from a single constant.
# NOTE: PYTHONHASHSEED is read by the interpreter only at start-up, so setting it
# here does not change hash randomisation in this running kernel; it only reaches
# child processes started afterwards.
SEED = 42
os.environ["PYTHONHASHSEED"] = str(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Hyperparameters

In [3]:
# Every image and mask is resized to this (height, width).
height, width = 512, 512

# Optimisation settings.
batch_size = 8   # samples per training step
lr = 1e-4        # initial Adam learning rate
epochs = 100     # upper bound on training epochs

## Path

In [4]:
dataset_path = "dataset"

# Training artefacts (best checkpoint and per-epoch CSV log) live under files_dir.
files_dir = "files"
model_file, log_file = (os.path.join(files_dir, name) for name in ("unet.h5", "log.csv"))

## Creating the Output Folder

In [5]:
def create_dir(path):
    """Create directory ``path``, including missing parents, if it does not exist.

    ``exist_ok=True`` makes the call idempotent without a separate existence
    check, so there is no window between "check" and "create". If ``path``
    exists but is not a directory, ``FileExistsError`` is raised instead of
    being silently ignored.

    Args:
        path: Directory path to create.
    """
    os.makedirs(path, exist_ok=True)

In [6]:
# Ensure the folder holding the model checkpoint and the CSV log exists before training.
create_dir(files_dir)

## Building the Model: ResNet50 Encoder + CBAM + UNET Decoder

### CBAM (Convolutional Block Attention Module)

In [7]:
def channel_attention_module(x, ratio=8):
    """CBAM channel attention: rescale each channel of ``x`` by a learned gate.

    Global average- and global max-pooled channel descriptors pass through the
    same two-layer bottleneck MLP (shared weights). Their sum is squashed with a
    sigmoid and multiplied back onto ``x``.

    Args:
        x: 4-D feature map with channels on the last axis.
        ratio: Reduction factor for the hidden bottleneck layer.

    Returns:
        Tensor with the same shape as ``x``.
    """
    channel = x.shape[-1]
    # Keep at least one hidden unit so inputs with fewer than `ratio` channels
    # do not get a zero-width bottleneck.
    hidden_units = max(1, channel // ratio)

    # Shared MLP: the same two Dense layers process both pooled descriptors.
    l1 = Dense(hidden_units, activation="relu", use_bias=False)
    l2 = Dense(channel, use_bias=False)

    x1 = GlobalAveragePooling2D()(x)
    x1 = l1(x1)
    x1 = l2(x1)

    x2 = GlobalMaxPooling2D()(x)
    x2 = l1(x2)
    x2 = l2(x2)

    feats = x1 + x2
    feats = Activation("sigmoid")(feats)

    # The (batch, channel) gate is broadcast over the spatial axes of x.
    feats = Multiply()([x, feats])
    return feats

In [8]:
def spatial_attention_module(x):
    """CBAM spatial attention: gate every spatial position of ``x``.

    The channel-wise mean and max maps are stacked into a 2-channel descriptor
    and convolved with a single 7x7 sigmoid filter. The resulting one-channel
    mask is multiplied onto ``x``.
    """
    avg_map = tf.reduce_mean(x, axis=-1, keepdims=True)
    max_map = tf.reduce_max(x, axis=-1, keepdims=True)

    descriptor = Concatenate()([avg_map, max_map])
    mask = Conv2D(1, kernel_size=7, padding="same", activation="sigmoid")(descriptor)

    return Multiply()([x, mask])

In [9]:
def cbam(x, ratio=8):
    """Apply CBAM: channel attention followed by spatial attention.

    Args:
        x: 4-D feature map with channels on the last axis.
        ratio: Bottleneck reduction factor passed to ``channel_attention_module``.
            The default of 8 matches that function's own default.

    Returns:
        Attention-refined tensor with the same shape as ``x``.
    """
    x = channel_attention_module(x, ratio=ratio)
    x = spatial_attention_module(x)
    return x

### Conv Block

In [10]:
def conv_block(inputs, num_filters):
    """Run two (3x3 conv -> batch norm -> ReLU) stages, then refine with CBAM.

    Args:
        inputs: Input feature map.
        num_filters: Output channels of both convolutions.

    Returns:
        Feature map with ``num_filters`` channels.
    """
    out = inputs
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return cbam(out)

### Encoder Block

In [11]:
def encoder_block(inputs, num_filters):
    """Apply a conv block, then 2x2 max-pooling.

    Returns ``(skip, pooled)``. ``skip`` holds the pre-pooling features, used for
    a decoder skip connection. ``pooled`` is the downsampled input to the next
    stage.
    """
    skip = conv_block(inputs, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(skip)
    return skip, pooled

### Decoder Block

In [12]:
def decoder_block(inputs, skip, num_filters):
    """Upsample ``inputs`` 2x, concatenate with ``skip``, and refine with a conv block.

    Args:
        inputs: Lower-resolution feature map from the previous stage.
        skip: Encoder feature map at the upsampled resolution.
        num_filters: Channels of the transposed conv and the following conv block.
    """
    upsampled = Conv2DTranspose(num_filters, kernel_size=(2, 2), strides=2, padding="same")(inputs)
    merged = Concatenate()([upsampled, skip])
    return conv_block(merged, num_filters)

### UNET

In [13]:
def build_unet(input_shape):
    """Build a U-Net whose encoder is an ImageNet-pretrained ResNet50.

    The decoder takes ResNet50's ``conv4_block6_out`` features back to the input
    resolution in four 2x upsampling steps. Each step concatenates an encoder
    skip connection and ends in a CBAM-refined ``conv_block``.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
            NOTE(review): each decoder step doubles the spatial size and must
            match its skip tensor, so height/width presumably need to be
            divisible by 16. Confirm before using other resolutions.

    Returns:
        ``Model`` named "UNET" that outputs a single-channel sigmoid map at the
        input resolution.
    """
    # Inputs
    inputs = Input(input_shape)

    # ResNet50 encoder, built on top of `inputs` with ImageNet weights.
    resnet50 = ResNet50(include_top=False, weights="imagenet", input_tensor=inputs)

    # Skip connections, from full resolution down to 1/8.
    # `inputs` is used directly instead of get_layer("input_1"). Keras numbers
    # auto-named Input layers per session, so calling this function again in the
    # same kernel creates "input_2" and that name lookup would fail.
    s1 = inputs                                          # 1/1
    s2 = resnet50.get_layer("conv1_relu").output         # 1/2
    s3 = resnet50.get_layer("conv2_block3_out").output   # 1/4
    s4 = resnet50.get_layer("conv3_block4_out").output   # 1/8

    # Bridge (1/16).
    b1 = resnet50.get_layer("conv4_block6_out").output

    # Decoder: each block upsamples 2x and fuses the matching skip tensor.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    # 1x1 conv to a single channel; sigmoid maps each pixel into (0, 1).
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="UNET")
    return model

## Dataset Pipeline 

### Loading the Training and Validation Datasets

In [14]:
def load_data(path):
    """List the image and mask file paths for the train and valid splits.

    Expects the layout ``<path>/{train,valid}/{images,masks}/*``. Each list is
    sorted, so the i-th image is paired with the i-th mask. This assumes image
    and mask filenames sort in the same order.

    Args:
        path: Dataset root directory.

    Returns:
        ``((train_x, train_y), (valid_x, valid_y))``, each a sorted list of paths.

    Raises:
        ValueError: if a split has a different number of images and masks.
            Failing here gives a clear message instead of a later tf.data error.
    """
    def _list_files(split, kind):
        # Sorted glob of every file in <path>/<split>/<kind>/.
        return sorted(glob(os.path.join(path, split, kind, "*")))

    splits = []
    for split in ("train", "valid"):
        images = _list_files(split, "images")
        masks = _list_files(split, "masks")
        if len(images) != len(masks):
            raise ValueError(
                f"{split}: found {len(images)} images but {len(masks)} masks under {path!r}"
            )
        splits.append((images, masks))

    return splits[0], splits[1]

### Reading Images

In [15]:
def read_image(path):
    """Read an image file, resize it to the model resolution and scale it to [0, 1].

    Args:
        path: File path as ``bytes`` (what ``tf.numpy_function`` passes in) or ``str``.

    Returns:
        ``float64`` array of shape ``(height, width, 3)`` with OpenCV's BGR
        channel order.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    if isinstance(path, bytes):
        path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure by returning None rather than raising.
    if x is None:
        raise OSError(f"cv2.imread could not read image: {path}")
    # cv2.resize takes the target size as (width, height).
    x = cv2.resize(x, (width, height))
    # NOTE(review): the ImageNet ResNet50 weights expect
    # tf.keras.applications.resnet50.preprocess_input rather than /255 scaling;
    # confirm this choice is intentional.
    x = x / 255.0
    return x

### Reading Masks

In [16]:
def read_mask(path):
    """Read a mask as grayscale, resize it to the model resolution and scale it to [0, 1].

    Args:
        path: File path as ``bytes`` (what ``tf.numpy_function`` passes in) or ``str``.

    Returns:
        ``float64`` array of shape ``(height, width, 1)``.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    if isinstance(path, bytes):
        path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising.
    if x is None:
        raise OSError(f"cv2.imread could not read mask: {path}")
    # NOTE(review): cv2.resize's default interpolation is bilinear, which creates
    # fractional values along edges. If the masks are meant to stay binary,
    # consider interpolation=cv2.INTER_NEAREST.
    x = cv2.resize(x, (width, height))
    x = x / 255.0
    # Add a trailing channel axis to match the model's single-channel output.
    x = np.expand_dims(x, axis=-1)
    return x

### tf.data Pipeline

In [17]:
def tf_parse(x, y):
    """Map an (image path, mask path) pair of string tensors to float64 tensors.

    The OpenCV readers run as plain Python through ``tf.numpy_function``. Because
    its outputs have no static shape, the shapes are declared explicitly afterwards.
    """
    def _load_pair(image_path, mask_path):
        return read_image(image_path), read_mask(mask_path)

    image, mask = tf.numpy_function(_load_pair, [x, y], [tf.float64, tf.float64])
    image.set_shape([height, width, 3])
    mask.set_shape([height, width, 1])
    return image, mask

In [18]:
def tf_dataset(x, y, batch=8, shuffle=False):
    """Build a batched, prefetched tf.data pipeline of (image, mask) pairs.

    Args:
        x: Sequence of image file paths.
        y: Sequence of mask file paths, aligned with ``x``.
        batch: Batch size.
        shuffle: If True, reshuffle the path pairs every epoch before decoding.
            Defaults to False, which keeps the input order. Pass True for training data.

    Returns:
        ``tf.data.Dataset`` yielding ``(images, masks)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        # Shuffling the path strings before decoding keeps a full-size buffer cheap.
        # tf.data requires a positive buffer size, hence the max(1, ...).
        dataset = dataset.shuffle(buffer_size=max(1, len(x)), reshuffle_each_iteration=True)
    dataset = dataset.map(tf_parse, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

## Training

In [19]:
(train_x, train_y), (valid_x, valid_y) = load_data(dataset_path)
# Sanity check: every split should have as many masks as images.
for split_name, (images, masks) in (("Train", (train_x, train_y)), ("Valid", (valid_x, valid_y))):
    print(f"{split_name}: {len(images)} - {len(masks)}")

Train: 4544 - 4544
Valid: 567 - 567


In [20]:
# Build the decode -> batch -> prefetch pipelines for both splits.
# NOTE(review): neither split is shuffled, so every training epoch visits the
# batches in sorted-filename order. Consider shuffling the training set per epoch.
train_dataset, valid_dataset = (
    tf_dataset(paths_x, paths_y, batch=batch_size)
    for paths_x, paths_y in ((train_x, train_y), (valid_x, valid_y))
)

In [21]:
# Pull a single validation batch to confirm the pipeline's output shapes.
for x, y in valid_dataset.take(1):
    print(x.shape, y.shape)

(8, 512, 512, 3) (8, 512, 512, 1)


In [22]:
# Three-channel input at the configured resolution.
input_shape = (height, width, 3)
model = build_unet(input_shape=input_shape)

In [23]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

Model: "UNET"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 518, 518, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 256, 256, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 256, 256, 64) 256         conv1_conv[0][0]                 
_______________________________________________________________________________________________

### Loss Function

In [24]:
def dice_loss(y_true, y_pred, smooth=1e-15):
    """Soft Dice loss, ``1 - Dice``, computed over the whole batch at once.

    ``Dice = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``

    The sums run over every element in the batch. The loss is not computed per
    sample and then averaged.

    Args:
        y_true: Ground-truth mask tensor; it is cast to ``y_pred``'s dtype.
        y_pred: Predicted values, same shape as ``y_true``.
        smooth: Small constant that avoids 0/0 when both masks are empty.

    Returns:
        Scalar loss tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # tf_parse yields float64 masks, while the model output is presumably float32.
    # Casting explicitly keeps the loss usable outside model.fit as well.
    y_true = tf.cast(y_true, y_pred.dtype)
    # Reducing over all axes equals flattening first, without creating new
    # Flatten layers on every call.
    intersection = tf.reduce_sum(y_true * y_pred)
    dice = (2.0 * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)
    return 1.0 - dice

In [25]:
# Adam at the configured learning rate, optimising the Dice loss.
# NOTE(review): "acc" presumably resolves to per-pixel binary accuracy for this
# single-channel sigmoid output. It can look high on masks dominated by background.
opt = tf.keras.optimizers.Adam(learning_rate=lr)
model.compile(optimizer=opt, loss=dice_loss, metrics=["acc"])

In [26]:
callbacks = [
    # Save the model each time val_loss reaches a new minimum.
    ModelCheckpoint(model_file, monitor="val_loss", verbose=1, save_best_only=True),
    # Cut the learning rate 10x after 4 epochs without val_loss improvement.
    ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=4),
    # Write per-epoch metrics to the CSV log.
    CSVLogger(log_file),
    # Stop after 20 epochs without improvement. The in-memory weights stay at the
    # last epoch; the best weights are in the checkpoint file.
    EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=False),
]

In [27]:
# Train. The callbacks handle checkpointing, LR decay, CSV logging and early stopping.
# NOTE(review): the returned History is not kept. The per-epoch metrics are still
# available in the CSV log.
model.fit(
    x=train_dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=callbacks,
)

Epoch 1/100

Epoch 00001: val_loss improved from inf to 0.27252, saving model to files/unet.h5




Epoch 2/100

Epoch 00002: val_loss improved from 0.27252 to 0.12124, saving model to files/unet.h5
Epoch 3/100

Epoch 00003: val_loss improved from 0.12124 to 0.10461, saving model to files/unet.h5
Epoch 4/100

Epoch 00004: val_loss improved from 0.10461 to 0.09152, saving model to files/unet.h5
Epoch 5/100

Epoch 00005: val_loss improved from 0.09152 to 0.08338, saving model to files/unet.h5
Epoch 6/100

Epoch 00006: val_loss improved from 0.08338 to 0.07783, saving model to files/unet.h5
Epoch 7/100

Epoch 00007: val_loss did not improve from 0.07783
Epoch 8/100

Epoch 00008: val_loss improved from 0.07783 to 0.07161, saving model to files/unet.h5
Epoch 9/100

Epoch 00009: val_loss did not improve from 0.07161
Epoch 10/100

Epoch 00010: val_loss improved from 0.07161 to 0.06957, saving model to files/unet.h5
Epoch 11/100

Epoch 00011: val_loss did not improve from 0.06957
Epoch 12/100

Epoch 00012: val_loss did not improve from 0.06957
Epoch 13/100
  9/568 [..........................

KeyboardInterrupt: 