### Requirements
- keras >= 2.2.0 or tensorflow >= 1.13
- segmentation-models==1.0.*
- albumentations==0.3.0

# Loading dataset

In [16]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import glob

import cv2
import tensorflow
import keras
import numpy as np
import matplotlib.pyplot as plt

import albumentations as A

import segmentation_models as sm

import tensorflow as tf
# Report how many GPUs TensorFlow can see.
tf.test.gpu_device_name()
_visible_gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(_visible_gpus))

Num GPUs Available:  1


# Set paths

In [17]:
DATA_DIR = '/data/'

# Output root for trained models, derived from DATA_DIR.
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(os.path.join(DATA_DIR, 'models'), exist_ok=True)

# 1: print file paths while loading samples (debugging file loading); 0: silent
FLAG_DEBUG_LOADING = 0

# 1: read only the second brightfield folder (Image_BF2_*) and replicate that
#    plane into all three input channels; 0: stack BF1, BF2, BF3 as channels.
FLAG_SINGLE_Z = 1

# dtype loaded images are converted to: "uint8" or "float32"
imtype = "uint8"
# imtype = "float32"

# Per-class weights for the Dice loss, same order as class_labels.
class_weights = np.array([0, 1, .5])
class_labels = ['background', 'nucleus', 'cytoplasm']

# Suffix appended to the mask folder names; tag is used in model folder names.
dir_tag = ''
tag = 'fullval'


def _bf_image_dirs(split):
    """Return the three brightfield image folders [BF1, BF2, BF3] for a data split."""
    return [os.path.join(DATA_DIR, 'Image_BF%d_%s' % (plane, split)) for plane in (1, 2, 3)]


x_train_dir = _bf_image_dirs('train')
y_train_dir = os.path.join(DATA_DIR, 'Mask_3class_train' + dir_tag)

x_valid_dir = _bf_image_dirs('dev')
y_valid_dir = os.path.join(DATA_DIR, 'Mask_3class_dev' + dir_tag)

x_test_dir = _bf_image_dirs('test')
y_test_dir = os.path.join(DATA_DIR, 'Mask_3class_test' + dir_tag)

# Dataloader and utility functions 

In [18]:
# helper function for data visualization
def visualize(**images):
    """Plot images side by side in one row.

    Each keyword argument is one panel. The keyword becomes the panel title,
    with underscores replaced by spaces and then title-cased
    (e.g. ``gt_mask`` -> "Gt Mask"). Axis ticks are hidden.
    """
    n = len(images)
    plt.figure(figsize=(20, 20))
    for i, (name, image) in enumerate(images.items()):
        # subplot grid dimensions must be integers (a float column count such
        # as np.ceil(n) is rejected by current matplotlib)
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()
    
# helper function for data visualization
def denormalize(x):
    """Scale image to range 0..1 for correct plot.

    Intensities are mapped linearly so that the 2nd percentile becomes 0 and
    the 98th percentile becomes 1. The result is then clipped to [0, 1].

    Args:
        x: image array (any shape); array-likes are converted with np.asarray.

    Returns:
        Float array of the same shape with values in [0, 1]. If the 2nd and
        98th percentiles are equal (e.g. a constant image), returns all zeros
        instead of NaN.
    """
    x = np.asarray(x)
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    span = x_max - x_min
    if span == 0:
        # avoid 0/0 -> NaN (and a RuntimeWarning) for flat images
        return np.zeros_like(x, dtype=float)
    return np.clip((x - x_min) / span, 0, 1)
    

# classes for data loading and preprocessing
class Dataset:
    """Brightfield segmentation dataset.

    For each sample, reads the brightfield image plane(s) and the matching
    label mask. The mask is converted to one float channel per requested
    class. Optional augmentation and preprocessing are then applied.

    Args:
        images_dir (list of str): three image folders, one per brightfield
            plane (index 0, 1, 2). Only ``images_dir[1]`` is read when
            ``FLAG_SINGLE_Z == 1``.
        masks_dir (str): folder with label masks. Pixel value ``k`` marks
            class ``CLASSES[k]``.
        classes (list of str, optional): class names to extract as mask
            channels. Defaults to all ``CLASSES``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, crop, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        n_images (int, optional): keep at most this many samples, taking the
            first ones in sorted file-name order. ``None`` keeps all.

    A sample is a ``*.tif`` base name present in every folder that
    ``__getitem__`` reads.
    """

    CLASSES = class_labels

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
            n_images=None
    ):
        # Base names of the .tif files in images_dir[1] and in the mask folder.
        self.ids1 = [os.path.basename(x) for x in glob.glob(images_dir[1] + '/*.tif')]
        self.ids3 = [os.path.basename(x) for x in glob.glob(masks_dir + '/*.tif')]
        common = set(self.ids1) & set(self.ids3)
        if FLAG_SINGLE_Z == 0:
            # __getitem__ reads all three planes in this mode, so require them all.
            for extra_dir in (images_dir[0], images_dir[2]):
                common &= {os.path.basename(x) for x in glob.glob(extra_dir + '/*.tif')}
        # Sorted so the selected subset is the same on every run
        # (iteration order of a set of str is not stable across processes).
        self.ids = sorted(common)

        # Optionally limit the number of training examples; None keeps all.
        if n_images is not None:
            self.ids = self.ids[:n_images]

        self.images_fps0 = [os.path.join(images_dir[0], image_id) for image_id in self.ids]
        self.images_fps1 = [os.path.join(images_dir[1], image_id) for image_id in self.ids]
        self.images_fps2 = [os.path.join(images_dir[2], image_id) for image_id in self.ids]

        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert class names to the integer pixel values used in the masks
        if classes is None:
            classes = self.CLASSES
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _imread(path):
        """Read ``path`` as a single-channel image, keeping its bit depth.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (cv2.imread
                returns None instead of raising).
        """
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)
        if img is None:
            raise FileNotFoundError('Could not read image: ' + path)
        return img

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``.

        ``image`` has 3 channels (stacked on the last axis). ``mask`` has one
        float channel per entry of ``class_values``. Both are returned after
        the optional augmentation and preprocessing.
        """
        if FLAG_DEBUG_LOADING == 1:
            print(self.images_fps1[i])

        # read the brightfield plane(s) and stack them as 3 channels
        if FLAG_SINGLE_Z == 0:
            image0 = self._imread(self.images_fps0[i])
            image1 = self._imread(self.images_fps1[i])
            image2 = self._imread(self.images_fps2[i])
        else:
            # replicate the single plane into all three channels
            image1 = self._imread(self.images_fps1[i])
            image0 = image1
            image2 = image1
        image = np.stack((image0, image1, image2), axis=2)

        if imtype == "uint8":
            # NOTE(review): the 255/65535 scale assumes 16-bit source images — confirm
            image = cv2.convertScaleAbs(image, alpha=(255. / 65535.))
        if imtype == "float32":
            image = image.astype("float32")

        # Resize image
        if FLAG_RESIZE == 1:
            image = cv2.resize(image, (imsize, imsize))

        if FLAG_DEBUG_LOADING == 1:
            print(self.masks_fps[i])
        mask = self._imread(self.masks_fps[i])

        # Resize mask using nearest neighbor so label values stay integral
        if FLAG_RESIZE == 1:
            mask = cv2.resize(mask, dsize=(imsize, imsize), interpolation=cv2.INTER_NEAREST)

        # one channel per requested class: 1.0 where the mask equals its value
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # No separate background channel is added: 'background' is one of the
        # labelled classes in the masks.

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of samples."""
        return len(self.ids)
    
    
class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches.

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: integer number of images in batch.
        shuffle: Boolean, if `True` shuffle the sample order each epoch.

    ``len()`` counts only full batches. The last ``len(dataset) % batch_size``
    samples of each epoch's order are not served.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch ``i`` as ``[images, masks]``, each stacked on a new axis 0.

        Raises IndexError when the batch runs past the end of the sample order.
        """
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        # Map batch positions through self.indexes so the per-epoch
        # permutation from on_epoch_end actually changes batch contents.
        data = [self.dataset[self.indexes[j]] for j in range(start, stop)]

        # transpose list of (image, mask) pairs into [images, masks]
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        return batch

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)
            
            
### Augmentations
# def round_clip_0_1(x, **kwargs):
#     return x.round().clip(0, 1)

# training-time augmentation: random flips/transpose plus a random square crop
def get_training_augmentation():
    """Build the training augmentation pipeline.

    Steps, in order:
      * horizontal flip, vertical flip and transpose, each with p=0.5;
      * PadIfNeeded up to ``imsize`` x ``imsize`` with border_mode=0
        (cv2.BORDER_CONSTANT), always applied;
      * a random ``imsize`` x ``imsize`` crop, always applied.

    ``imsize`` is the module-level global, read when this function is called.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Transpose(p=0.5),
    ]
    # A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0) is disabled.
    size_fixing = [
        A.PadIfNeeded(min_height=imsize, min_width=imsize, always_apply=True, border_mode=0),
        A.RandomCrop(height=imsize, width=imsize, always_apply=True),
    ]
    return A.Compose(geometric + size_fixing)


def get_validation_augmentation():
    """Build the validation pipeline: pad to at least ``imsize_test`` x ``imsize_test``.

    Only ``A.PadIfNeeded`` with its default border mode is applied (no crop).
    Inputs already at least that large pass through unchanged. ``imsize_test``
    is the module-level global, read when this function is called.

    NOTE(review): training pads with border_mode=0, while this uses
    PadIfNeeded's default border mode. Confirm this is intended for the masks.
    """
    padding = A.PadIfNeeded(imsize_test, imsize_test)
    return A.Compose([padding])

def get_preprocessing(preprocessing_fn):
    """Wrap a normalization function as an albumentations pipeline.

    Args:
        preprocessing_fn (callable): data normalization function (can be
            specific to each pretrained backbone). It is passed as the
            ``image`` function of ``A.Lambda``; no mask function is given.
    Return:
        transform: albumentations.Compose
    """
    image_step = A.Lambda(image=preprocessing_fn)
    return A.Compose([image_step])

### Define loop parameters

In [19]:
### loops
# Training-set sizes to run
n_images_loop = np.array([1500])

# Epoch budget: base_epochs at base_n_images training images, scaled inversely
# with the training-set size (truncated to int).
base_epochs = 40
base_n_images = 1500
epochs_loop = (base_epochs * base_n_images / n_images_loop).astype('int')

# backbone_loop = ['efficientnetb4','vgg16']
# imsize_loop = [320, 640]
# batch_loop = [8, 4]

# Parallel lists: run k uses backbone_loop[k] with imsize_loop[k] inputs and
# batch_loop[k] as its intended batch size.
backbone_loop = ['vgg16', 'efficientnetb4']
imsize_loop = [640, 320]
batch_loop = [4, 8]
if not len(backbone_loop) == len(imsize_loop) == len(batch_loop):
    raise ValueError('backbone_loop, imsize_loop and batch_loop must have the same length')

### other stuff
# 0: train on random crops of size imsize; 1: resize whole images to imsize
FLAG_RESIZE = 0
LR = 0.0001
# validation images are padded to at least imsize_test x imsize_test
imsize_test = 2176
n_images_dev = 150
n_images_test = 150




# Segmentation model training

In [20]:
from datetime import datetime

# Train one U-Net per (backbone, training-set size) combination.
for ind, BACKBONE in enumerate(backbone_loop):
    imsize = imsize_loop[ind]
    # Batch size paired with this backbone. Assigned here so the cell does not
    # rely on a BATCH_SIZE left over in the kernel from elsewhere.
    BATCH_SIZE = batch_loop[ind]

    for ind2, n_images_train in enumerate(n_images_loop):
        EPOCHS = epochs_loop[ind2]

        # Drop the previous run's model/graph so repeated runs in one
        # session do not keep accumulating state.
        keras.backend.clear_session()

        preprocess_input = sm.get_preprocessing(BACKBONE)

        # define network parameters: one output channel per class; sigmoid
        # for the single-class (binary) case, softmax otherwise
        n_classes = len(class_labels)
        activation = 'sigmoid' if n_classes == 1 else 'softmax'

        # create model
        model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)

        # define optimizer
        optim = keras.optimizers.Adam(LR)

        # Dice loss weighted per class by class_weights, plus the focal loss
        # (weight 1). Binary vs categorical focal loss follows n_classes.
        dice_loss = sm.losses.DiceLoss(class_weights)
        focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
        total_loss = dice_loss + (1 * focal_loss)

        metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5), 'acc']

        # compile keras model with defined optimizer, loss and metrics
        model.compile(optim, total_loss, metrics)

        # current date and time (also part of the output folder name)
        now = datetime.now()
        print(str(now))

        # Dataset for train images
        train_dataset = Dataset(
            x_train_dir,
            y_train_dir,
            classes=class_labels,
            augmentation=get_training_augmentation(),
            preprocessing=get_preprocessing(preprocess_input),
            n_images=n_images_train
        )

        # Dataset for validation images
        valid_dataset = Dataset(
            x_valid_dir,
            y_valid_dir,
            classes=class_labels,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocess_input),
            n_images=n_images_dev
        )
        valid_dataloader = Dataloder(valid_dataset, batch_size=1, shuffle=False)

        train_dataloader = Dataloder(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

        # Load one training batch once (each load reads and augments images)
        # and check shapes for errors.
        x_train_batch, y_train_batch = train_dataloader[0]
        expected_x = (BATCH_SIZE, imsize, imsize, 3)
        expected_y = (BATCH_SIZE, imsize, imsize, n_classes)
        if x_train_batch.shape != expected_x or y_train_batch.shape != expected_y:
            raise ValueError(
                'Unexpected training batch shapes %s / %s, expected %s / %s'
                % (x_train_batch.shape, y_train_batch.shape, expected_x, expected_y))

        print('Size of training set: ' + str(len(train_dataset)))
        print('Size of training minibatch: ' + str(x_train_batch.shape))
        print('Size of training minibatch mask: ' + str(y_train_batch.shape))
        print('')
        x_valid_batch, y_valid_batch = valid_dataloader[0]
        print('Size of dev set: ' + str(len(valid_dataset)))
        print('Size of dev minibatch: ' + str(x_valid_batch.shape))
        print('Size of dev minibatch mask: ' + str(y_valid_batch.shape))

        run_name = '_'.join(str(part) for part in (
            'models', tag, BACKBONE, n_images_train, imsize, BATCH_SIZE, EPOCHS, now))
        model_save_foldername = os.path.join(DATA_DIR, 'models', run_name)
        model_save_filename = os.path.join(model_save_foldername, 'best_model_weights.h5')
        os.makedirs(model_save_foldername, exist_ok=True)

        # Checkpointing follows val_loss; the LR schedule follows training loss.
        # keras.callbacks.CSVLogger is exported by keras 2.2 and later
        # (keras.callbacks.callbacks only exists from keras 2.3).
        callbacks = [
            keras.callbacks.ModelCheckpoint(model_save_filename, monitor='val_loss', save_weights_only=True, save_best_only=True, mode='min', period=1),
            keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=LR/16, verbose=1),
            keras.callbacks.CSVLogger(os.path.join(model_save_foldername, 'traininglog'), separator=',', append=False)
        ]

        # train model
        history = model.fit_generator(
            train_dataloader,
            steps_per_epoch=len(train_dataloader),
            epochs=EPOCHS,
            callbacks=callbacks,
            validation_data=valid_dataloader,
            validation_steps=len(valid_dataloader),
            workers=16,
            use_multiprocessing=False,
        )

2020-03-09 21:16:35.334652
Size of training set: 1500
Size of training minibatch: (4, 640, 640, 3)
Size of training minibatch mask: (4, 640, 640, 3)

Size of dev set: 150
Size of dev minibatch: (1, 2176, 2176, 3)
Size of dev minibatch mask: (1, 2176, 2176, 3)
Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40

Epoch 00034: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Epoch 35/40
Epoch 36/40


Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40
2020-03-09 23:10:08.517890
Size of training set: 1500
Size of training minibatch: (8, 320, 320, 3)
Size of training minibatch mask: (8, 320, 320, 3)

Size of dev set: 150
Size of dev minibatch: (1, 2176, 2176, 3)
Size of dev minibatch mask: (1, 2176, 2176, 3)
Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40


Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40

Epoch 00035: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40
