# Trénovanie U-Net modelu

**Autor: Bc. Ivan Vykopal**

Notebook určený pre tréning U-Net modelu pre segmentáciu jadier v Lizard datasete.

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization
from tensorflow.keras.layers import MaxPooling2D, ReLU, LeakyReLU, Activation, RandomRotation, RandomFlip, RandomZoom, RandomContrast
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Add
from tensorflow.keras.layers import concatenate
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dropout
from glob import glob
import numpy as np
import sys
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from tensorflow.keras.utils import load_img
from tensorflow.image import resize
import wandb
from wandb.keras import WandbCallback
import tensorflow.keras.backend as K
import cv2 as cv
import json
from sklearn.utils import shuffle

In [None]:
# Log in to Weights & Biases; the run itself is started in the next cell.
wandb.login()

In [None]:
# Start a W&B run in project "DP U-Net Lizard" under entity "ivanvykopal".
# `run` is displayed (run.display) and closed (run.finish) in later cells.
run = wandb.init(project="DP U-Net Lizard", entity="ivanvykopal")

In [None]:
# Hyperparameters shared by the model builder, the datasets, and the W&B run config.
config = dict(
    IMAGE_SIZE=512,           # side length of the square input patches, in pixels
    CHANNELS=3,               # RGB input
    BATCH_SIZE=4,
    EPOCHS=100,
    PADDING='same',           # padding mode for Conv2D / Conv2DTranspose
    DTYPE='float32',          # NOTE(review): not read by any code visible in this notebook
    FILTERS=16,               # filters on the top level; deeper levels use multiples (x2, x4, ...)
    INITIALIZER='he_normal',  # kernel initializer of the Conv2D layers
    KERNEL_SIZE=(3, 3),
    LEARNING_RATE=0.001,      # RMSprop learning rate
    DROPOUT=0,                # dropout rate after pooling and before up-sampled convolutions
    THRESHOLD=0.5,            # probability cut-off used to binarise predictions for display
)

In [None]:
# Record the hyperparameters above in the W&B run's config.
wandb.config.update(config)

In [None]:
class Unet():
    """Builder for the binary-segmentation network trained in this notebook.

    Despite the class name, ``create_model`` builds a nested (UNet++-style)
    U-Net: node ``convL_K`` (level ``L``, where ``L = 1`` is full resolution)
    upsamples ``conv(L+1)_(K-1)`` and concatenates it with every earlier node
    ``convL_1 .. convL_(K-1)`` on its own level. Only ``conv1_5`` feeds the
    output head. All hyperparameters are read from the module-level
    ``config`` dict when the model is built.
    """

    def __init__(self):
        # Run the base-class initialiser; a bare ``super()`` expression only
        # creates a proxy object and initialises nothing.
        super().__init__()

    def _conv_block(self, x, n_filters, n_convs, residual=False):
        """Apply ``n_convs`` x (Conv2D -> BatchNormalization -> ReLU) to ``x``.

        Args:
            x: Input Keras tensor.
            n_filters: Number of filters of every Conv2D in the block.
            n_convs: Number of Conv2D/BN/ReLU repetitions.
            residual: If true, a Conv2D + BatchNormalization projection of
                ``x`` (same kernel size as the main path) is added to the
                block output.

        Returns:
            The block's output tensor.
        """
        # Keras tensors are immutable, so the input is used directly; no
        # tf.identity copy (an extra op layer in the graph) is needed.
        out = x
        for _ in range(n_convs):
            out = Conv2D(n_filters, kernel_size=config['KERNEL_SIZE'], padding=config['PADDING'], kernel_initializer=config['INITIALIZER'])(out)
            out = BatchNormalization()(out)
            out = Activation('relu')(out)

        if residual:
            shortcut = Conv2D(n_filters, kernel_size=config['KERNEL_SIZE'], padding=config['PADDING'], kernel_initializer=config['INITIALIZER'])(x)
            shortcut = BatchNormalization()(shortcut)
            out = Add()([shortcut, out])
        return out

    def _downsample_block(self, x, n_filters, n_convs, residual=False):
        """Encoder step: conv block, then 2x2 max-pooling and dropout.

        Returns:
            ``(features, pooled)``: the conv-block output at the current
            resolution (used by the skip connections) and the pooled,
            dropped-out tensor passed to the next level.
        """
        f = self._conv_block(x, n_filters, n_convs, residual)
        p = MaxPooling2D(2)(f)
        p = Dropout(config['DROPOUT'])(p)
        return f, p

    def _upsample_block(self, x, conv_features, n_filters, n_convs, residual=False):
        """Decoder step: 2x transposed convolution, skip concatenation, conv block.

        Args:
            x: Tensor from one level deeper, upsampled by a stride-2,
                kernel-2 Conv2DTranspose.
            conv_features: List of same-level tensors concatenated (along the
                last axis) after the upsampled tensor.
            n_filters: Filters of the transposed conv and the conv block.
            n_convs: Number of Conv2D/BN/ReLU repetitions in the conv block.
            residual: Passed through to ``_conv_block``.
        """
        x = Conv2DTranspose(n_filters, 2, 2, padding=config['PADDING'])(x)
        x = concatenate([x, *conv_features])
        x = Dropout(config['DROPOUT'])(x)
        x = self._conv_block(x, n_filters, n_convs, residual)
        return x

    def create_model(self):
        """Build and return the (uncompiled) Keras model named "U-Net".

        Input: ``(IMAGE_SIZE, IMAGE_SIZE, CHANNELS)`` images. Output: one
        sigmoid channel at input resolution (a 1x1 Conv2D on ``conv1_5``).
        Levels 1-5 use ``FILTERS`` x 1, 2, 4, 8, 16 filters.
        """
        inputs = Input(shape=(config['IMAGE_SIZE'], config['IMAGE_SIZE'], config['CHANNELS']))
        # encoder
        # 1 - downsample
        conv1_1, pool1 = self._downsample_block(inputs, config['FILTERS'], 2, False)
        # 2 - downsample
        conv2_1, pool2 = self._downsample_block(pool1, config['FILTERS'] * 2, 2, False)

        # nested decoder node on level 1, fed by level 2
        conv1_2 = self._upsample_block(conv2_1, [conv1_1], config['FILTERS'], 2, False)

        # 3 - downsample
        conv3_1, pool3 = self._downsample_block(pool2, config['FILTERS'] * 4, 2, False)

        # nested decoder nodes fed by level 3
        conv2_2 = self._upsample_block(conv3_1, [conv2_1], config['FILTERS'] * 2, 2, False)
        conv1_3 = self._upsample_block(conv2_2, [conv1_1, conv1_2], config['FILTERS'], 2, False)

        # 4 - downsample
        conv4_1, pool4 = self._downsample_block(pool3, config['FILTERS'] * 8, 2, False)

        # nested decoder nodes fed by level 4
        conv3_2 = self._upsample_block(conv4_1, [conv3_1], config['FILTERS'] * 4, 2, False)
        conv2_3 = self._upsample_block(conv3_2, [conv2_1, conv2_2], config['FILTERS'] * 2, 2, False)
        conv1_4 = self._upsample_block(conv2_3, [conv1_1, conv1_2, conv1_3], config['FILTERS'], 2, False)

        # 5 - bottleneck
        conv5_1 = self._conv_block(pool4, config['FILTERS'] * 16, 2, False)

        # full-depth decoder path, with dense skips to every earlier node per level
        conv4_2 = self._upsample_block(conv5_1, [conv4_1], config['FILTERS'] * 8, 2, False)
        conv3_3 = self._upsample_block(conv4_2, [conv3_1, conv3_2], config['FILTERS'] * 4, 2, False)
        conv2_4 = self._upsample_block(conv3_3, [conv2_1, conv2_2, conv2_3], config['FILTERS'] * 2, 2, False)
        conv1_5 = self._upsample_block(conv2_4, [conv1_1, conv1_2, conv1_3, conv1_4], config['FILTERS'], 2, False)

        # outputs: per-pixel foreground probability
        outputs = Conv2D(1, 1, padding=config['PADDING'], activation="sigmoid")(conv1_5)

        unet_model = Model(inputs, outputs, name="U-Net")

        return unet_model

In [None]:
def diceCoef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient, computed per sample and averaged over the batch.

    Sums run over axes 1-3, so both tensors must be rank 4 (presumably
    ``(batch, height, width, channels)``, matching the model output).
    ``smooth`` is added to numerator and denominator so an empty prediction
    of an empty mask scores 1 instead of 0/0.

    Args:
        y_true: Ground-truth mask tensor; cast to ``y_pred``'s dtype.
        y_pred: Predicted probabilities.
        smooth: Additive smoothing constant.

    Returns:
        Scalar tensor: batch mean of (2*|A.B| + smooth) / (|A| + |B| + smooth).
    """
    # Masks may arrive as integer/bool tensors; multiplying mismatched dtypes raises.
    y_true = K.cast(y_true, y_pred.dtype)
    axes = [1, 2, 3]
    intersection = K.sum(y_true * y_pred, axis=axes)
    union = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)
    return K.mean((2. * intersection + smooth) / (union + smooth), axis=0)

In [None]:
def diceCoefLoss(y_true, y_pred):
    """Dice loss, ``1 - diceCoef(y_true, y_pred)``; 0 means perfect overlap."""
    dice = diceCoef(y_true, y_pred)
    return 1 - dice

In [None]:
# Build the (uncompiled) network defined by Unet.create_model.
model = Unet().create_model()

In [None]:
# Print the layer-by-layer architecture summary.
model.summary()

In [None]:
# RMSprop + binary cross-entropy on the single sigmoid output channel.
# To train on the Dice loss instead, swap in loss=diceCoefLoss.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=config['LEARNING_RATE']),
    loss=tf.keras.losses.BinaryCrossentropy(),
    # loss=diceCoefLoss,
    metrics=[
        'Precision',
        'Recall',
        # Default BinaryIoU: mean IoU of classes 0 and 1 on predictions thresholded at 0.5.
        tf.keras.metrics.BinaryIoU(),
        # Foreground-only IoU. MeanIoU(num_classes=2) is not usable here: it casts
        # y_pred to integer class ids, so every sigmoid probability below 1 counts
        # as class 0.
        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=config['THRESHOLD'], name='foreground_iou'),
        diceCoef,
    ],
)

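Supported augmentation specs (each entry of the `augmentations` list passed to `Dataset`; the `{"type": "normal"}` spec, i.e. no augmentation, is always prepended automatically). `|` separates the allowed values; `<float>` is a numeric parameter.
{
    "type": "flip",
    "mode": "horizontal_and_vertical" | "horizontal" | "vertical"
},
{
    "type": "rotation",
    "factor": <float>
},
{
    "type": "normal"
},
{
    "type": "zoom",
    "height_factor": <float>,
    "width_factor": <float>
},
{
    "type": "contrast",
    "factor": <float>
}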

In [None]:
class Dataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of (image patch, binary mask) pairs.

    Every patch listed in ``img_json`` is emitted once per augmentation spec.
    The identity spec ``{"type": "normal"}`` is always prepended, so with ``k``
    user-supplied specs each patch appears ``k + 1`` times. Samples are ordered
    augmentation-major (all un-augmented patches first) and are not shuffled.

    Args:
        batch_size: Samples per batch; a trailing partial batch is dropped.
        img_size: ``(height, width)`` of the returned arrays.
        directory: Directory with the image patches, named
            ``<image name without .tif>_<y_idx>_<x_idx>.tif``. Full-size masks
            are read from the same path with ``'patches'`` replaced by
            ``'masks'``, named ``<image name with .tif -> .png>``.
        img_json: Dict whose ``'images'`` list holds items with a ``'name'``
            and a ``'patches'`` list of ``[y_idx, x_idx]`` patch coordinates.
        augmentations: List of augmentation spec dicts whose ``'type'`` is one
            of ``'rotation'``, ``'flip'``, ``'zoom'``, ``'contrast'``.

    Raises:
        ValueError: If a spec has an unsupported ``'type'``.
    """

    # Augmentation types understood by ``_augment``.
    _SUPPORTED_AUGMENTATIONS = ('normal', 'rotation', 'flip', 'zoom', 'contrast')

    def __init__(self, batch_size, img_size, directory, img_json, augmentations):
        self.batch_size = batch_size
        self.img_size = img_size
        self.img_json = img_json
        self.img_paths = []
        self.directory = directory
        self.augmentations = [{"type": "normal"}] + augmentations

        # Fail fast on a misspelled type instead of silently yielding all-zero samples.
        for augmentation in self.augmentations:
            if augmentation['type'] not in self._SUPPORTED_AUGMENTATIONS:
                raise ValueError(f"Unsupported augmentation type: {augmentation['type']!r}")

        n_images = len(self.img_json['images'])
        for aug_idx, augmentation in enumerate(self.augmentations):
            for index, image in enumerate(self.img_json['images']):
                for patch in image['patches']:
                    # Entry: [y_idx, x_idx, aug_idx * n_images + image index, spec];
                    # __getitem__ recovers the image index with ``% n_images``.
                    self.img_paths.append(list(patch) + [aug_idx * n_images + index, augmentation])
        print(len(self.img_paths))

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as float32 arrays ``(x, y)``.

        ``x`` has shape ``(batch_size, *img_size, CHANNELS)`` with pixel values
        divided by 255; ``y`` has shape ``(batch_size, *img_size, 1)`` with
        values in {0, 1} (mask pixels > 128). Random augmentations are
        re-drawn on every call.
        """
        start = idx * self.batch_size
        batch_imgs = self.img_paths[start:start + self.batch_size]
        patch_size = config['IMAGE_SIZE']
        mask_dir = self.directory.replace('patches', 'masks')

        x = np.zeros((self.batch_size,) + self.img_size + (config['CHANNELS'],), dtype="float32")
        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="float32")
        # Full-size masks loaded for this batch, keyed by image index, so patches
        # of the same image do not reload the whole mask file.
        full_masks = {}
        for index, (y_idx, x_idx, file_idx, augmentation) in enumerate(batch_imgs):
            file_idx = file_idx % len(self.img_json['images'])
            mask_name = self.img_json['images'][file_idx]['name']
            file_name = mask_name.replace('.tif', '')
            img = load_img(os.path.join(self.directory, file_name + '_' + str(y_idx) + '_' + str(x_idx) + '.tif'), color_mode='rgb')
            img = np.array(img) / 255

            if file_idx not in full_masks:
                full_masks[file_idx] = load_img(os.path.join(mask_dir, mask_name.replace('.tif', '.png')), color_mode='grayscale')
            # The patch grid position times the patch size gives the crop box in the full mask.
            left, upper = x_idx * patch_size, y_idx * patch_size
            mask = full_masks[file_idx].crop((left, upper, left + patch_size, upper + patch_size))
            mask = (np.asarray(mask)[..., np.newaxis] > 128).astype('float32')

            x[index], y[index] = self._augment(img, mask, augmentation)

        return x, y

    def _augment(self, img, mask, augmentation):
        """Apply one augmentation spec to an (image, mask) pair.

        Geometric transforms (rotation, flip, zoom) run as a single layer call
        on the image and mask stacked along the channel axis, so both get the
        same random transform and stay pixel-aligned. Contrast is photometric
        and is applied to the image only, which keeps the mask binary.

        Returns:
            ``(image, mask)`` arrays with the input shapes.

        Raises:
            ValueError: If the spec's ``'type'`` is unsupported.
        """
        aug_type = augmentation['type']
        if aug_type == 'normal':
            return img, mask
        if aug_type == 'contrast':
            return np.asarray(RandomContrast(factor=augmentation['factor'])(img)), mask
        if aug_type == 'rotation':
            layer = RandomRotation(factor=augmentation['factor'], interpolation='nearest')
        elif aug_type == 'flip':
            layer = RandomFlip(mode=augmentation['mode'])
        elif aug_type == 'zoom':
            layer = RandomZoom(height_factor=augmentation['height_factor'], width_factor=augmentation['width_factor'], interpolation='nearest')
        else:
            raise ValueError(f"Unsupported augmentation type: {aug_type!r}")
        # 'nearest' interpolation keeps the stacked mask channel strictly 0/1.
        stacked = np.concatenate([img, mask], axis=-1).astype('float32')
        out = np.asarray(layer(stacked))
        return out[..., :-1], out[..., -1:]

In [None]:
# Load the training image list and split it by image (seeded shuffle) into 27
# training images and the remaining validation images; the test list is loaded
# as-is. The JSON files are read as UTF-8 (the JSON standard encoding), not
# in the platform's locale encoding, which differs on Windows.
with open('images-train.json', encoding='utf-8') as json_file:
    train_json = json.load(json_file)

files = shuffle(train_json['images'], random_state=42)
train_filenames = files[:27]
valid_filenames = files[27:]

train_json = {
    "images": train_filenames,
    "patch_size": train_json['patch_size']
}

# train_json was rebound just above, but it carries the same "patch_size".
valid_json = {
    "images": valid_filenames,
    "patch_size": train_json['patch_size']
}

with open('images-test.json', encoding='utf-8') as json_file:
    test_json = json.load(json_file)

In [None]:
# Training-time augmentation specs (Dataset adds the un-augmented copy itself):
# three flip modes, two zoom strengths and three contrast strengths.
augmentations = (
    [{"type": "flip", "mode": flip_mode}
     for flip_mode in ("horizontal_and_vertical", "vertical", "horizontal")]
    + [{"type": "zoom", "height_factor": zoom, "width_factor": zoom}
       for zoom in (0.5, 0.2)]
    + [{"type": "contrast", "factor": strength}
       for strength in (0.2, 0.5, 0.7)]
)

# Samples per training epoch: every patch of every training image, once
# un-augmented plus once per augmentation spec.
size = sum(len(image['patches']) for image in train_json['images'])
train_size = size * (len(augmentations) + 1)

In [None]:
# Train and validation patches share one directory (they are split by image in
# the JSON cell); only the training set is augmented. Paths are built with
# os.path.join so they resolve on any OS, not only with Windows separators.
train_dataset = Dataset(batch_size=config['BATCH_SIZE'], img_size=(config['IMAGE_SIZE'], config['IMAGE_SIZE']), directory=os.path.join('data', 'train', 'patches'), img_json=train_json, augmentations=augmentations)
valid_dataset = Dataset(batch_size=config['BATCH_SIZE'], img_size=(config['IMAGE_SIZE'], config['IMAGE_SIZE']), directory=os.path.join('data', 'train', 'patches'), img_json=valid_json, augmentations=[])
test_dataset = Dataset(batch_size=config['BATCH_SIZE'], img_size=(config['IMAGE_SIZE'], config['IMAGE_SIZE']), directory=os.path.join('data', 'test', 'patches'), img_json=test_json, augmentations=[])

In [None]:
# List the GPUs visible to TensorFlow (stable tf.config API instead of the
# deprecated tf.config.experimental alias).
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    print("Name:", gpu.name, "  Type:", gpu.device_type)

In [None]:
# Show the W&B run inline in the notebook output (height=720).
run.display(height=720)

In [None]:
# Train with WandbCallback logging to the W&B run. steps_per_epoch is
# train_size (patches x (augmentations + 1)) divided by the batch size.
history = model.fit(
    train_dataset,
    epochs=config['EPOCHS'],
    validation_data=valid_dataset,
    steps_per_epoch=int(train_size // config['BATCH_SIZE']),
    callbacks=[WandbCallback()],
)

In [None]:
# Close the W&B run started with wandb.init.
run.finish()

In [None]:
def display_sample(display_list):
    """Plot up to three images side by side in a single row.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask'; each item is converted for display with
    ``tf.keras.preprocessing.image.array_to_img``.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(18, 18))
    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(titles[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.show()

In [None]:
# Visual sanity check: input, ground truth and thresholded prediction for every
# image in the first 3 training batches.
for index_batch in range(3):
    # Fetch each batch once: Dataset.__getitem__ re-reads the files (and re-draws
    # random augmentations) on every call, so repeated indexing is slow and could
    # pair an image, mask and prediction from different draws.
    batch_x, batch_y = train_dataset[index_batch]
    pred_mask = model.predict(batch_x)
    for index_img in range(len(batch_x)):
        binary_pred = (np.array(pred_mask[index_img]) > config['THRESHOLD']).astype('float32')
        display_sample([batch_x[index_img], batch_y[index_img], binary_pred])

In [None]:
# Evaluate on the test patches. batch_size is not passed: a keras.utils.Sequence
# already yields whole batches, and Keras documents batch_size as not to be
# specified for Sequence inputs.
results = model.evaluate(test_dataset)

In [None]:
# Show input, ground truth and thresholded prediction for every test batch.
# The batch count comes from the dataset (counted in patches), not from the
# number of test images, which is a different unit.
for index_batch in range(len(test_dataset)):
    # Load each batch once; every Dataset.__getitem__ call re-reads the files.
    batch_x, batch_y = test_dataset[index_batch]
    pred_mask = model.predict(batch_x)
    for index_img in range(len(batch_x)):
        binary_pred = (np.array(pred_mask[index_img]) > config['THRESHOLD']).astype('float32')
        display_sample([batch_x[index_img], batch_y[index_img], binary_pred])

In [None]:
def save_predictions(dataset, files, threshold=0.25, output_dir='predicted masks'):
    """Save [input | ground truth | prediction] strips as numbered PNG files.

    For each of the first ``len(files) // BATCH_SIZE`` batches of ``dataset``,
    every sample is written to ``<output_dir>/<n>.png`` (n counting from 0):
    the RGB input, the true mask and the prediction binarised at
    ``threshold``, each framed by a 5-pixel white border.

    Args:
        dataset: Indexable returning ``(images, masks)`` batches with values in
            [0, 1], e.g. a ``Dataset``.
        files: Sequence whose length divided by BATCH_SIZE gives the number of
            batches to save.
        threshold: Probability cut-off for the predicted mask.
        output_dir: Target directory; created if missing.

    Raises:
        OSError: If OpenCV fails to write a file.
    """
    os.makedirs(output_dir, exist_ok=True)
    index = 0
    for i in range(len(files) // config['BATCH_SIZE']):
        # One load and one forward pass per batch, not per sample.
        images, masks = dataset[i]
        pred_mask = model.predict(images)
        for j in range(config['BATCH_SIZE']):
            # hconcat needs identical types/channel counts: convert the 1-channel
            # masks to 3 channels, and the RGB input to OpenCV's BGR order.
            image = cv.cvtColor(np.asarray(images[j], dtype='float32'), cv.COLOR_RGB2BGR)
            true_mask = cv.cvtColor(np.asarray(masks[j], dtype='float32'), cv.COLOR_GRAY2BGR)
            predicted = (np.array(pred_mask[j]) > threshold).astype('float32')
            predicted = cv.cvtColor(predicted, cv.COLOR_GRAY2BGR)
            # Values are still in [0, 1] here, so white is 1.0.
            panels = [cv.copyMakeBorder(panel, 5, 5, 5, 5, cv.BORDER_CONSTANT, value=[1, 1, 1])
                      for panel in (image, true_mask, predicted)]
            im_h = cv.hconcat(panels)
            im_h = np.rint(np.clip(im_h, 0.0, 1.0) * 255).astype(np.uint8)
            path = os.path.join(output_dir, str(index) + '.png')
            if not cv.imwrite(path, im_h):
                raise OSError(f"Failed to write {path}")
            index += 1