Based on: https://www.tensorflow.org/tutorials/images/segmentation

### Краткий обзор решений победителей:

#### Bigcell

Mobilenet-V2, EfficientNet B2/B4 Unet

Mobilenet-V2 with FPN

CV: training on one split

Pytorch lightning, pytorch-toolbelt (inference for WSI images), Segmentation models pytorch, albumentations

histomicstk (stain-augmentation), Digital slide archive scikit-image: color correction, tqdm, joblib, fire, omegaconf

GDAL (for WSI), openslide

settings in yaml files

BCE BDice loss

смешанный взвешенный loss BCE+BDICE, pretrain on BCE+BDICE, final on BDICE, Adam opt (0.007-0.005)

Self-written H&E augmentation StainPerturbation (from histomicstk)

Color transfer (skimage.exposure.match_histograms)

Training on crops

External public dataset (DigestPath 2019)

#### Третий глаз

Resnet101/ MobileNetv3 + DeepLabV3

Crops 1024*1024, filter crops with empty masks

Noisy data in dataset

Augmentations: Distortion, ElasticTransform, OpticalTransform. CLAHE

Проверка корреляции между исходным и аугментированным изображением (матрица Грама -> SVD, first k vectors -> correlation)

Коррекция гистограммы

epochs = 50
batch = 16
1e-3
BCE

after post training on all data:
epochs=5
lr 1e-5

#### - 

* Обучение на патчах 
* pretrain on imagenet
* IoULoss
* EfficientNet-B7
* OneCycleLR
* Adam

* Test Time Augmentations (rotate 90, H/V Flips)
* Ensemble of two models: 1024*1024, 2048*2048

#### -

DiceLoss, AdamW

Augs: Sharpen, Random flips, Random rotate, Random brightness and contrast

Final Model: ResNest50 + Unet

In [None]:
# !pip install git+https://github.com/qubvel/classification_models.git
# !pip3 install tf-models-official
# !pip install -U tensorflow-addons

In [None]:
!pip install git+https://github.com/tensorflow/examples.git

In [None]:
import numpy as np 
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.backend as K
from tensorflow_examples.models.pix2pix import pix2pix
from IPython.display import clear_output
import matplotlib.pyplot as plt
from glob import glob
import cv2
import os
import math
import tensorflow_addons as tfa
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam

In [None]:
class Config:
    """Notebook-wide settings: data locations, image geometry and training hyper-parameters."""

    # One seed for the global TF seed, the crop generator, shuffling and augmentation.
    # (The class used to assign SEED twice; only the last value, 42, was ever effective.)
    SEED = 42

    # Alternative training folder: '../input/tissuesegment-train/train/'
    base_path = '../input/tissuesegment/tissue-segment/'
    test_path = '../input/tissuesegment-test/test-public-images/'

    N_CLASSES = 1
    N_CHANNELS = 3
    IMG_SIZE = 512
    IMG_H = 512  # 768
    IMG_W = 512
    # TEST_IMG_H = 2048
    # TEST_IMG_W = 2048
    TEST_BATCH_SIZE = 1
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    BATCH_SIZE = 32
    NUMS_EXAMPLES = 220  # number of training examples (image - mask pairs)
    STEPS_PER_EPOCH = NUMS_EXAMPLES // BATCH_SIZE
    BUFFER_SIZE = 32
    EPOCHS = 300


# Seed both the global TF random state and the stateful generator that drives
# the random crops, so that runs are repeatable.
tf.random.set_seed(Config.SEED)
rand_generator = tf.random.Generator.from_seed(Config.SEED)

### DataSet

In [None]:
# Count every .jpg file in the training folder.
len(glob(Config.base_path + "*.jpg"))

In [None]:
# sorted(glob(Config.base_path + "*.jpg"))

In [None]:
# Number of .jpg files in the training folder whose name does not end in `_mask.jpg`.
sum(1 for path in glob(Config.base_path + "*.jpg") if not path.split('/')[-1].endswith('_mask.jpg'))

In [None]:
# Pair every training image with its `*_mask.jpg` annotation, keeping only the
# images whose mask file actually exists on disk.
image_mask_pairs = []
for candidate in glob(Config.base_path + "*.jpg"):
    if candidate.split('/')[-1].endswith('_mask.jpg'):
        continue  # this file is itself a mask
    mask_file = candidate.replace('.jpg', '_mask.jpg')
    if os.path.isfile(mask_file):
        image_mask_pairs.append((candidate, mask_file))

images = [image_file for image_file, _ in image_mask_pairs]
masks = [mask_file for _, mask_file in image_mask_pairs]

In [None]:
# Number of (image, mask) path pairs available for training.
len(image_mask_pairs)

In [None]:
def random_crop(image, mask, shape=(128, 128, 4)):
    """Cut the same random window out of an image and its mask.

    The mask is appended to the image along axis 2 so that a single
    ``tf.image.random_crop`` call crops both at the same location.

    Parameters
    ----------
    image : tf.Tensor
        Image tensor; channels 0..2 of the crop are returned as the image.
    mask : tf.Tensor
        Mask tensor, concatenated to ``image`` along axis 2.
    shape : tuple
        Crop size handed to ``tf.image.random_crop`` for the stacked tensor.

    Returns
    -------
    tuple
        ``(cropped_image, cropped_mask)``; the mask is channel 3 of the crop
        with a trailing axis of size one added back.
    """
    stacked = tf.concat([image, mask], axis=2)
    window = tf.image.random_crop(stacked, size=shape)
    cropped_image = window[..., :3]
    cropped_mask = window[..., 3][..., tf.newaxis]
    return cropped_image, cropped_mask

In [None]:
def parse_image(img_mask_path) -> dict:
    """Load an image and its annotation (mask) and return them as a dictionary.

    Parameters
    ----------
    img_mask_path : tf.Tensor
        Pair of string paths: element 0 is the image file, element 1 is the
        mask file.

    Returns
    -------
    dict
        ``'image'``: the decoded 3-channel uint8 image;
        ``'segmentation_mask'``: the decoded 1-channel uint8 mask, binarised
        to the values 0 and 1.
    """
    image = tf.io.read_file(img_mask_path[0])
    # decode_jpeg already yields uint8, so no dtype conversion is needed.
    image = tf.io.decode_jpeg(image, channels=3)

    mask = tf.io.read_file(img_mask_path[1])
    mask = tf.io.decode_jpeg(mask, channels=1)

    # Binarise at mid-range: values below 128 become 0, the rest become 1.
    mask = tf.cast(mask >= 128, tf.uint8)

    return {'image': image, 'segmentation_mask': mask}

def parse_image_test(img_path: str) -> dict:
    """Read one test image from disk and keep track of where it came from.

    Parameters
    ----------
    img_path : str
        Location of the JPEG image.

    Returns
    -------
    dict
        ``'image'``: the decoded 3-channel uint8 image;
        ``'image_path'``: the path that was passed in.
    """
    decoded = tf.io.decode_jpeg(tf.io.read_file(img_path), channels=3)
    return {
        'image': tf.image.convert_image_dtype(decoded, tf.uint8),
        'image_path': img_path,
    }

@tf.function
def normalize(input_image: tf.Tensor, input_mask: tf.Tensor) -> tuple:
    """Scale image pixel values from [0, 255] down to [0.0, 1.0].

    Parameters
    ----------
    input_image : tf.Tensor
        Image tensor; it is cast to float32 and divided by 255.
    input_mask : tf.Tensor
        Annotation tensor; returned untouched.

    Returns
    -------
    tuple
        The rescaled image and the unchanged annotation.
    """
    scaled_image = tf.cast(input_image, tf.float32) / 255.0
    return scaled_image, input_mask

@tf.function
def normalize_test(input_image: tf.Tensor):
    """Cast a test image to float32 and divide it by 255."""
    return tf.cast(input_image, tf.float32) / 255.0

@tf.function
def load_image_train(datapoint: dict) -> tuple:
    """Crop, resize, flip and normalize one training example.

    Notes
    -----
    Every geometric transformation is applied identically to the image and
    to its annotation so that they stay aligned.

    Parameters
    ----------
    datapoint : dict
        A dict with the keys ``'image'`` and ``'segmentation_mask'``.

    Returns
    -------
    tuple
        ``(image, mask)``: a float32 image scaled by 1/255 and a float32 mask,
        both resized to ``(Config.IMG_H, Config.IMG_W)``.
    """
    input_image = datapoint['image']
    input_mask = datapoint['segmentation_mask']

    # Random crop: only when the image is at least as large as the crop window.
    # The same stateless seed is used for image and mask so both are cut at the
    # same offset.
    CropToH, CropToW = 2048, 2048
    if (tf.shape(input_image)[0] >= CropToH) and (tf.shape(input_image)[1] >= CropToW):
        seed = rand_generator.uniform_full_int([2], dtype=tf.int32)
        input_image = tf.image.stateless_random_crop(input_image, (CropToH, CropToW, 3), seed)
        input_mask = tf.image.stateless_random_crop(input_mask, (CropToH, CropToW, 1), seed)

    target_size = (Config.IMG_H, Config.IMG_W)
    input_image = tf.image.resize(input_image, target_size)
    # Masks hold class labels, so they must not be interpolated: the default
    # bilinear resize would blend the labels into fractional values along
    # object borders. Nearest-neighbour keeps them unchanged; the cast restores
    # the float32 dtype that the bilinear resize used to return.
    input_mask = tf.image.resize(
        input_mask, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    input_mask = tf.cast(input_mask, tf.float32)

    # Static-shape check: only fires if the configured target is taller than wide.
    if input_image.shape[0] > input_image.shape[1]:
        input_image = tf.image.rot90(input_image)
        input_mask = tf.image.rot90(input_mask)

    # Horizontal flip with probability 0.5, applied to both tensors.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask

@tf.function
def load_image_test(datapoint: dict) -> tuple:
    """Pad a test image to a multiple of 64 pixels per side and normalize it.

    Notes
    -----
    No data augmentation is applied to test images. Padding (rather than
    resizing) keeps the original pixels untouched, so a prediction can later
    be cropped back to the source size.

    Parameters
    ----------
    datapoint : dict
        A dict with the keys ``'image'`` and ``'image_path'``.

    Returns
    -------
    tuple
        ``(image, image_path)``: the padded float32 image scaled by 1/255 and
        the path it was read from.
    """
    input_image = datapoint["image"]

    # Round height and width up to the next multiple of 64 with exact integer
    # arithmetic. (The previous `tf.shape(...)[0] != None` guard compared a
    # tensor with None, which is always true, so it has been dropped.)
    image_shape = tf.shape(input_image)
    target_h = (image_shape[0] + 63) // 64 * 64
    target_w = (image_shape[1] + 63) // 64 * 64
    input_image = tf.image.resize_with_crop_or_pad(input_image, target_h, target_w)

    input_image = normalize_test(input_image)

    return input_image, datapoint["image_path"]

In [None]:
# One dataset element per (image path, mask path) pair.
train_dataset = tf.data.Dataset.from_tensor_slices(image_mask_pairs)

In [None]:
# Decode each (image, mask) path pair, then run the training preprocessing on it.
train_dataset = (
    train_dataset
    .map(parse_image)
    .map(load_image_train, num_parallel_calls=Config.AUTOTUNE)
)

In [None]:
# Test pipeline: list the files in a fixed order, decode, pad/normalize, batch.
test_dataset = (
    tf.data.Dataset.list_files(Config.test_path+'*.jpg', shuffle=False)
    .map(parse_image_test)
    .map(load_image_test)
    .batch(Config.TEST_BATCH_SIZE)
)

### Augmentation

In [None]:
class Augment(tf.keras.layers.Layer):
    """Apply the same random horizontal flip to a batch of images and labels.

    Parameters
    ----------
    seed : int
        Seed given to both flip layers. Using one seed for both is what keeps
        the image and label flips in sync.
    """

    def __init__(self, seed=42):
        super().__init__()
        # The `seed` argument is now honoured (it used to be ignored in favour
        # of Config.SEED). Both layers share it, so they make the same random
        # changes.
        self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
        self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

    def call(self, inputs, labels):
        """Return ``(inputs, labels)`` after passing each through its flip layer."""
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels

From: https://www.kaggle.com/code/cdeotte/rotation-augmentation-gpu-tpu-0-96/notebook

Works only for square images.

In [None]:
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Compose a 3x3 matrix from rotation, shear, zoom and shift.

    The result is ``rotation @ shear @ zoom @ shift`` and is meant to be
    applied to homogeneous pixel indices.

    Parameters
    ----------
    rotation, shear : tf.Tensor
        Angles in degrees, as float32 tensors of shape ``[1]``.
    height_zoom, width_zoom : tf.Tensor
        Zoom factors of shape ``[1]``; the matrix stores their reciprocals.
    height_shift, width_shift : tf.Tensor
        Shifts of shape ``[1]``.

    Returns
    -------
    tf.Tensor
        The combined 3x3 float32 matrix.
    """
    def as_3x3(*entries):
        # Nine shape-[1] tensors, row-major, folded into a 3x3 matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # Degrees -> radians.
    rotation = math.pi * rotation / 180.
    shear = math.pi * shear / 180.

    cos_r, sin_r = tf.math.cos(rotation), tf.math.sin(rotation)
    rotation_matrix = as_3x3(cos_r, sin_r, zero,
                             -sin_r, cos_r, zero,
                             zero, zero, one)

    cos_s, sin_s = tf.math.cos(shear), tf.math.sin(shear)
    shear_matrix = as_3x3(one, sin_s, zero,
                          zero, cos_s, zero,
                          zero, zero, one)

    zoom_matrix = as_3x3(one/height_zoom, zero, zero,
                         zero, one/width_zoom, zero,
                         zero, zero, one)

    shift_matrix = as_3x3(one, zero, height_shift,
                          zero, one, width_shift,
                          zero, zero, one)

    left = K.dot(rotation_matrix, shear_matrix)
    right = K.dot(zoom_matrix, shift_matrix)
    return K.dot(left, right)

In [None]:
def transform(image, label):
    """Randomly rotate, shear, zoom and shift one image together with its mask.

    Parameters
    ----------
    image : tf.Tensor
        A single image of size [Config.IMG_H, Config.IMG_W, 3], not a batch.
    label : tf.Tensor
        The matching mask of size [Config.IMG_H, Config.IMG_W, 1].

    Returns
    -------
    tuple
        ``(image, label)`` after sampling both at the same transformed pixel
        positions.

    Notes
    -----
    The index grid mixes ``IMG_H`` and ``IMG_W``, so this is intended for
    square images only.
    """
    DIM = Config.IMG_H
    XDIM = DIM%2 #fix for size 331

    # Draw all six parameters from ONE random op. Six separate ops that share
    # the same op-level seed produce the same value inside a traced graph
    # (as in Dataset.map), which made rotation, shear, zoom and shift
    # perfectly correlated.
    noise = tf.random.normal([6], dtype='float32', seed=Config.SEED)
    rot = 15. * noise[0:1]
    shr = 5. * noise[1:2]
    h_zoom = 1.0 + noise[2:3] / 10.
    w_zoom = 1.0 + noise[3:4] / 10.
    h_shift = 16. * noise[4:5]
    w_shift = 16. * noise[5:6]

    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift)

    # LIST DESTINATION PIXEL INDICES (centred on the image middle, homogeneous)
    x = tf.repeat( tf.range(Config.IMG_H//2,-Config.IMG_W//2,-1), DIM )
    y = tf.tile( tf.range(-Config.IMG_H//2,Config.IMG_W//2),[DIM] )
    z = tf.ones([Config.IMG_H*Config.IMG_W],dtype='int32')
    idx = tf.stack( [x,y,z] )

    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so that every source index stays inside the image.
    idx2 = K.clip(idx2,-Config.IMG_H//2+XDIM+1,Config.IMG_W//2)

    # FIND ORIGIN PIXEL VALUES (same indices for image and mask)
    idx3 = tf.stack( [Config.IMG_H//2-idx2[0,], Config.IMG_W//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))

    label = tf.gather_nd(label,tf.transpose(idx3))
    label = tf.reshape(label,[Config.IMG_H,Config.IMG_W,1])

    return tf.reshape(d,[Config.IMG_H,Config.IMG_W,3]), label

In [None]:
# Apply `transform` to every (image, mask) element of the training dataset.
train_dataset = train_dataset.map(transform)

#### Prepare DataSet for training

In [None]:
# Shuffle, loop forever, batch and prefetch the training stream.
# Disabled optional steps: `.cache()` before shuffling, `.map(Augment())` after batching.
train_dataset = (
    train_dataset
    .shuffle(buffer_size=Config.BUFFER_SIZE, seed=Config.SEED)
    .repeat()
    .batch(Config.BATCH_SIZE)
    .prefetch(buffer_size=Config.AUTOTUNE)
)

...

In [None]:
def display_sample(display_list):
    """Plot a list of alternating images and masks on a two-row grid.

    Entries at even positions are titled 'Input Image', entries at odd
    positions 'True Mask'.
    """
    captions = ('Input Image', 'True Mask')
    columns = len(display_list) // 2

    plt.figure(figsize=(18, 9))
    for position, tensor in enumerate(display_list, start=1):
        plt.subplot(2, columns, position)
        plt.title(captions[(position - 1) % 2])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(tensor))
        plt.axis('off')
    plt.show()
# Grab one batch and keep it around as the sample used for previews.
for image, mask in train_dataset.take(1):
    sample_image, sample_mask = image, mask

# Show the first four images of the batch, each followed by its mask.
display_sample([
    tensor
    for idx in range(4)
    for tensor in (sample_image[idx], sample_mask[idx])
])

### Model

In [None]:
# from classification_models.tfkeras import Classifiers
# SeResNeXT, preprocess_input = Classifiers.get('seresnext50')
# base_model = SeResNeXT(include_top = False, input_shape=(None, None, 3), weights='imagenet')

In [None]:
# !git clone https://github.com/rishigami/Swin-Transformer-TF.git
# import sys
# sys.path.append('./Swin-Transformer-TF')
# from swintransformer import SwinTransformer
# model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=False, pretrained=True)

# model = tf.keras.Sequential([
#   tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*(224,224), 3]),
#   SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
#   tf.keras.layers.Dense(5, activation='softmax')
# ])

In [None]:
# img_adjust_layer = tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3])
# pretrained_model = SwinTransformer('swin_large_224', num_classes=len(CLASSES), include_top=False, pretrained=True, use_tpu=True)

# model = tf.keras.Sequential([
#     img_adjust_layer,
#     pretrained_model,
#     tf.keras.layers.Dense(len(CLASSES), activation='softmax')
# ])

In [None]:
# MobileNetV2 without its classification head serves as the encoder. The input
# height/width are left unspecified so images of different sizes can be fed.
base_model = tf.keras.applications.MobileNetV2(input_shape=[None, None, 3], include_top=False) # 1024, 1024

# Use the activations of these layers as skip connections / bottleneck.
# NOTE(review): the sizes in the trailing comments look like they were copied
# from the referenced tutorial and do not match the 512x512 training size - confirm.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model: one input, five intermediate outputs.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

# Freeze the encoder so that only the layers added on top of it are trained.
down_stack.trainable = False

In [None]:
# Decoder: four upsampling blocks from the pix2pix example, with a decreasing
# number of filters. The size comments describe the doubling at each step for
# the smallest feature map in the list above.
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

In [None]:
def unet_model(output_channels:int):
    """Assemble the U-Net from the module-level encoder and decoder.

    Parameters
    ----------
    output_channels : int
        Number of channels produced by the final transposed convolution.

    Returns
    -------
    tf.keras.Model
        Model mapping an image of unspecified height/width with 3 channels to
        a sigmoid-activated output with ``output_channels`` channels.
    """
    inputs = tf.keras.layers.Input(shape=[None, None, 3]) # 1024, 1024

    # Encoder: the last feature map is the bottleneck, the earlier ones are
    # the skip connections (visited from deepest to shallowest).
    features = down_stack(inputs)
    x = features[-1]

    # Decoder: upsample, then concatenate with the matching skip connection.
    for up, skip in zip(up_stack, features[-2::-1]):
        upsampled = up(x)
        x = tf.keras.layers.Concatenate()([upsampled, skip])

    # Final stride-2 transposed convolution with a sigmoid activation.
    head = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2,
        padding='same', activation='sigmoid')
    outputs = head(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed over all elements of the batch.

    The denominator uses sums of squares, matching the original intent of
    taking the dot product of each tensor with itself.

    Parameters
    ----------
    y_true : tensor
        Ground-truth mask; cast to the dtype of ``y_pred``.
    y_pred : tensor
        Predicted mask with the same number of elements as ``y_true``.
    smooth : number
        Added to numerator and denominator to avoid division by zero.

    Returns
    -------
    tensor
        Scalar ``(2 * sum(t * p) + smooth) / (sum(t^2) + sum(p^2) + smooth)``.
    """
    y_pred_f = K.flatten(y_pred)
    y_true_f = K.cast(K.flatten(y_true), y_pred_f.dtype)
    # Work on the flattened vectors: K.dot on the un-flattened 4-D tensors
    # is what used to make this function fail.
    intersection = K.sum(y_true_f * y_pred_f)
    denominator = K.sum(K.square(y_true_f)) + K.sum(K.square(y_pred_f))
    return (2. * intersection + smooth) / (denominator + smooth)

In [None]:
def DiceLoss(targets, inputs, smooth=1e-6):
    """Dice loss, ``1 - Dice``, computed over all elements of the batch.

    Parameters
    ----------
    targets : tensor
        Ground-truth mask.
    inputs : tensor
        Predicted mask.
    smooth : float
        Added to numerator and denominator to avoid division by zero.

    Returns
    -------
    tensor
        Scalar loss value.
    """
    flat_pred = K.flatten(inputs)
    flat_true = K.flatten(targets)

    overlap = K.sum(flat_true * flat_pred)
    total = K.sum(flat_true) + K.sum(flat_pred)
    return 1 - (2. * overlap + smooth) / (total + smooth)

In [None]:
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Smoothed Jaccard distance, scaled by ``smooth``.

    Jaccard = (|X & Y|) / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The value returned is ``(1 - jaccard) * smooth``, so it goes to 0 as the
    two masks agree. ``smooth`` is also added to numerator and denominator.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index

    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    overlap = K.sum(K.sum(K.abs(y_true * y_pred), axis=-1))
    total = K.sum(K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1))
    jaccard = (overlap + smooth) / (total - overlap + smooth)
    return smooth * (1 - jaccard)

def dice_metric(y_pred, y_true):
    """Dice score ``2 * |A*B| / (|A| + |B|)`` over the whole batch.

    The formula is symmetric in its two arguments, so the order in which the
    prediction and the ground truth are passed does not matter.

    Returns 1.0 when both masks are entirely zero (a perfect match of two
    empty masks) instead of the NaN that a plain 0/0 division would give.
    """
    intersection = K.sum(K.sum(K.abs(y_true * y_pred), axis=-1))
    union = K.sum(K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1))

    # Guard the division: swap a zero denominator for 1 and select the
    # "both empty" answer afterwards.
    non_empty = union > 0
    safe_union = tf.where(non_empty, union, tf.ones_like(union))
    return tf.where(non_empty, 2 * intersection / safe_union, tf.ones_like(union))

In [None]:
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper

      gt_sorted: 1-D float tensor of ground-truth labels, already reordered by
                 the caller (lovasz_hinge_flat passes them sorted by
                 decreasing error)
      returns:   tensor with the same length as gt_sorted
    """
    gts = tf.reduce_sum(gt_sorted)  # sum of all ground-truth labels
    # Running intersection / union as more of the sorted elements are included.
    intersection = gts - tf.cumsum(gt_sorted)
    union = gts + tf.cumsum(1. - gt_sorted)
    jaccard = 1. - intersection / union
    # Turn the cumulative values into per-element increments (first differences).
    jaccard = tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0)
    return jaccard


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
      returns: scalar loss (the mean over images when per_image is True)
    """
    if per_image:
        def treat_image(log_lab):
            # Flatten one (logits, labels) pair and compute its loss on its own.
            log, lab = log_lab
            log, lab = tf.expand_dims(log, 0), tf.expand_dims(lab, 0)
            log, lab = flatten_binary_scores(log, lab, ignore)
            return lovasz_hinge_flat(log, lab)
        # `fn_output_signature` replaces the deprecated `dtype` argument of tf.map_fn.
        losses = tf.map_fn(treat_image, (logits, labels), fn_output_signature=tf.float32)
        loss = tf.reduce_mean(losses)
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -inf and +inf)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      returns: scalar loss; 0 when there are no predictions at all
    """

    def compute_loss():
        labelsf = tf.cast(labels, logits.dtype)
        signs = 2. * labelsf - 1.  # map labels {0, 1} to signs {-1, +1}
        errors = 1. - logits * tf.stop_gradient(signs)  # hinge errors
        # Sort the errors in decreasing order and reorder the labels the same way.
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], name="descending_sort")
        gt_sorted = tf.gather(labelsf, perm)
        grad = lovasz_grad(gt_sorted)
        # Weighted sum of the positive errors; the weights are treated as constants.
        loss = tf.tensordot(tf.nn.relu(errors_sorted), tf.stop_gradient(grad), 1, name="loss_non_void")
        return loss

    # deal with the void prediction case (only void pixels)
    loss = tf.cond(tf.equal(tf.shape(logits)[0], 0),
                   lambda: tf.reduce_sum(logits) * 0.,
                   compute_loss,
                   name="loss"
                   )
    return loss


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Reshape scores and labels to 1-D (binary case) and, when `ignore` is
    given, drop every position whose label equals `ignore`.
    Returns the pair (scores, labels).
    """
    flat_scores = tf.reshape(scores, (-1,))
    flat_labels = tf.reshape(labels, (-1,))
    if ignore is not None:
        keep = tf.not_equal(flat_labels, ignore)
        flat_scores = tf.boolean_mask(flat_scores, keep, name='valid_scores')
        flat_labels = tf.boolean_mask(flat_labels, keep, name='valid_labels')
    return flat_scores, flat_labels

In [None]:
OUTPUT_CLASSES = 1

# Alternative loss / optimizer / schedule objects kept for experiments; none of
# them is passed to compile() below.
# from_logits must be False here: this loss would be fed the model output, and
# that output has already gone through a sigmoid.
loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
opt_rec_adam = tfa.optimizers.RectifiedAdam(learning_rate=1e-3)

sch_cos_dec = tf.keras.optimizers.schedules.CosineDecay(3e-4, 1000) # use as lr parameter
lr_decayed_fn = (tf.keras.optimizers.schedules.CosineDecayRestarts(1e-3, 1000))  # use as lr parameter

# if OUTPUT_CLASSES > 1:
#     loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

metrics = ['accuracy', dice_metric]

model = unet_model(output_channels=OUTPUT_CLASSES)

# Training uses the default Adam optimizer with the Dice loss.
model.compile(optimizer='adam',
              loss=DiceLoss,
              metrics=metrics)

In [None]:
# Draw the model graph, including the tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
def display(display_list):
    """Plot up to three tensors side by side: input, true mask, predicted mask."""
    captions = ['Input Image', 'True Mask', 'Predicted Mask']
    count = len(display_list)

    plt.figure(figsize=(15, 15))
    for position, tensor in enumerate(display_list):
        plt.subplot(1, count, position + 1)
        plt.title(captions[position])
        plt.imshow(tf.keras.utils.array_to_img(tensor))
        plt.axis('off')
    plt.show()

In [None]:
def create_mask(pred_mask):
    """Return the first mask of a batch of predictions.

    The previous version added a trailing axis and squeezed it away again,
    which is a no-op and only worked on NumPy arrays (tensors have no
    ``squeeze`` method). Indexing the first batch entry gives the same result
    for arrays and also accepts tensors.

    Parameters
    ----------
    pred_mask : array-like
        Batch of predicted masks, indexed by batch along the first axis.

    Returns
    -------
    array-like
        ``pred_mask[0]``.
    """
    return pred_mask[0]

def show_predictions(dataset=train_dataset, num=1):
    """Plot input, true mask and predicted mask for the first item of `num` batches.

    Parameters
    ----------
    dataset : tf.data.Dataset
        Source of ``(image, mask)`` batches; the default is bound to
        ``train_dataset`` as it exists when this function is defined.
    num : int
        Number of batches to take from ``dataset``.
    """
    for image, mask in dataset.take(num):
        predicted = model.predict(image)
        display([image[0], mask[0], create_mask(predicted)])

In [None]:
# Plot a prediction of the model in its current state for one training batch.
show_predictions()

### Fit

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Show a sample prediction at the end of every 50th epoch (epoch index 0, 50, ...)."""

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 50 != 0:
            return
        clear_output(wait=True)
        show_predictions()
        print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
    
# Optional callbacks. NOTE(review): none of them is passed to model.fit() below,
# and the two early-stopping callbacks watch `val_loss` (the Keras default for
# the first one) although fit() is given no validation data - confirm before use.
tensorboard_callback = tf.keras.callbacks.TensorBoard('/logdir', histogram_freq=1)
# Keeps only the weights of the epoch with the lowest training loss.
save_model_checkpoint = tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='loss', verbose=1, save_best_only=True, save_weights_only=True)
early_stopping_callback = tf.keras.callbacks.EarlyStopping(patience=10, verbose=1)
eas = EarlyStopping(monitor='val_loss', patience=8, min_delta=1e-5, verbose=1, mode='min', baseline=None, restore_best_weights=True)

In [None]:
# import math
# from keras.callbacks import Callback
# from keras import backend as K


# class CosineAnnealingScheduler(Callback):
#     """Cosine annealing scheduler.
#     """

#     def __init__(self, T_max, eta_max, eta_min=0, verbose=0):
#         super(CosineAnnealingScheduler, self).__init__()
#         self.T_max = T_max
#         self.eta_max = eta_max
#         self.eta_min = eta_min
#         self.verbose = verbose

#     def on_epoch_begin(self, epoch, logs=None):
#         if not hasattr(self.model.optimizer, 'lr'):
#             raise ValueError('Optimizer must have a "lr" attribute.')
#         lr = self.eta_min + (self.eta_max - self.eta_min) * (1 + math.cos(math.pi * epoch / self.T_max)) / 2
#         K.set_value(self.model.optimizer.lr, lr)
#         if self.verbose > 0:
#             print('\nEpoch %05d: CosineAnnealingScheduler setting learning '
#                   'rate to %s.' % (epoch + 1, lr))

#     def on_epoch_end(self, epoch, logs=None):
#         logs = logs or {}
#         logs['lr'] = K.get_value(self.model.optimizer.lr)



# callbacks = [
#     CosineAnnealingScheduler(T_max=100, eta_max=1e-2, eta_min=1e-4)
# ]


# class SGDRScheduler(tf.keras.callbacks.Callback):
#     '''Cosine annealing learning rate scheduler with periodic restarts.
#     # Usage
#         ```python
#             schedule = SGDRScheduler(min_lr=1e-5,
#                                      max_lr=1e-2,
#                                      steps_per_epoch=np.ceil(epoch_size/batch_size),
#                                      lr_decay=0.9,
#                                      cycle_length=5,
#                                      mult_factor=1.5)
#             model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
#         ```
#     # Arguments
#         min_lr: The lower bound of the learning rate range for the experiment.
#         max_lr: The upper bound of the learning rate range for the experiment.
#         steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`. 
#         lr_decay: Reduce the max_lr after the completion of each cycle.
#                   Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8.
#         cycle_length: Initial number of epochs in a cycle.
#         mult_factor: Scale epochs_to_restart after each full cycle completion.
#     # References
#         Blog post: jeremyjordan.me/nn-learning-rate
#         Original paper: http://arxiv.org/abs/1608.03983
#     '''
#     def __init__(self,
#                  min_lr,
#                  max_lr,
#                  steps_per_epoch,
#                  lr_decay=1,
#                  cycle_length=10,
#                  mult_factor=2):

#         self.min_lr = min_lr
#         self.max_lr = max_lr
#         self.lr_decay = lr_decay

#         self.batch_since_restart = 0
#         self.next_restart = cycle_length

#         self.steps_per_epoch = steps_per_epoch

#         self.cycle_length = cycle_length
#         self.mult_factor = mult_factor

#         self.history = {}

#     def clr(self):
#         '''Calculate the learning rate.'''
#         fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
#         lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
#         return lr

#     def on_train_begin(self, logs={}):
#         '''Initialize the learning rate to the minimum value at the start of training.'''
#         logs = logs or {}
#         K.set_value(self.model.optimizer.lr, self.max_lr)

#     def on_batch_end(self, batch, logs={}):
#         '''Record previous batch statistics and update the learning rate.'''
#         logs = logs or {}
#         self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
#         for k, v in logs.items():
#             self.history.setdefault(k, []).append(v)

#         self.batch_since_restart += 1
#         K.set_value(self.model.optimizer.lr, self.clr())

#     def on_epoch_end(self, epoch, logs={}):
#         '''Check for end of current cycle, apply restarts when necessary.'''
#         if epoch + 1 == self.next_restart:
#             self.batch_since_restart = 0
#             self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
#             self.next_restart += self.cycle_length
#             self.max_lr *= self.lr_decay
#             self.best_weights = self.model.get_weights()

#     def on_train_end(self, logs={}):
#         '''Set weights to the values from the end of the most recent cycle for best performance.'''
#         self.model.set_weights(self.best_weights)



# lr_sched = SGDRScheduler(min_lr=1e-5,
#                          max_lr=1e-2,
#                          steps_per_epoch=np.ceil(Config.NUMS_EXAMPLES/Config.BATCH_SIZE), # epoch_size = len(train)
#                          lr_decay=0.85,
#                          mult_factor=1.5)


# class ValLossDisplay(tf.keras.callbacks.Callback):
#     def __init__(self, loss_index='val_loss'):
#         self.best_loss = None
#         self.loss_index = loss_index
#         super().__init__()
    
#     def on_epoch_end(self, epoch, logs=None):
#         if self.best_loss is None or self.best_loss > logs[self.loss_index]:
#             self.best_loss = logs[self.loss_index]
#         print("\r",'Epoch: %d - Loss: %.6f (best: %.6f)' % (epoch, logs[self.loss_index], self.best_loss), ' '*5, end='')
    
#     def on_train_end(self, logs=None):
#         print("\r",'Loss: %.6f' % (self.best_loss), ' '*40)

In [None]:
# Start from previously saved weights instead of training from scratch.
model.load_weights('../input/tf-unet-model-100-epochs-768-1024-weights/tf_unet_model_100_epochs_768_1024_weights.h5')

In [None]:
# VAL_SUBSPLITS = 5
# VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

# Train without a validation set. The dataset repeats forever, so
# steps_per_epoch defines how many batches make up one epoch.
model_history = model.fit(train_dataset, epochs=Config.EPOCHS,
                          steps_per_epoch=Config.STEPS_PER_EPOCH,
#                           validation_steps=VALIDATION_STEPS,
#                           validation_data=test_batches,
                          callbacks=[DisplayCallback()])

In [None]:
# model.save('tf_unet_model_100_epochs_768_1024')
# Persist only the weights of the trained model.
model.save_weights('tf_unet_model_300_epoch_512_fullds_weights.h5')

In [None]:
# Show the first test image, titled with its file name.
for image, path in test_dataset.take(1):
    plt.imshow(image[0])
    # `path` holds byte strings; keep only the part after the last '/'.
    title = path[0].numpy().decode('utf-8').split('/')[-1]
    plt.title(title)

In [None]:
# Show the first test image and print its shape after preprocessing.
for image, path in test_dataset:
    plt.imshow(image[0])
    print(image[0].shape)
    break

### Predict

In [None]:
# preds = model.predict(test_dataset)
# Predict each test image in two halves (split along the height axis) and
# stitch the two predicted masks back together.
# NOTE(review): presumably done to limit memory use on large images - confirm.
preds = []
for image, path in test_dataset:
    # `w` is computed but only the height split is used below.
    h, w = image.shape[1] //2, image.shape[2] //2
    split_1 = image[: ,:h ,: , :]  # top half
    split_2 = image[: ,h: ,: , :]  # bottom half
    
    pred_mask_split_1 = model.predict_on_batch(split_1)
    pred_mask_split_2 = model.predict_on_batch(split_2)
    
    # Re-assemble along the height axis.
    pred_mask_full = tf.concat([pred_mask_split_1, pred_mask_split_2], axis=1)
    print(image.shape, split_1.shape, pred_mask_full.shape)
    preds.append(pred_mask_full)

In [None]:
# h, w = image.shape[1] //2, image.shape[2] //2 # TensorShape([1, 6080, 6912, 3])
# split_1 = image[: ,: ,:w , :] # TensorShape([1, 6080, 3456, 3])
# split_2 = image[: ,: ,w: , :] # TensorShape([1, 6080, 3456, 3])

# tf.concat([split_1, split_2], axis=2).shape # TensorShape([1, 6080, 6912, 3])

In [None]:
# Preview the first predicted mask as an 8-bit image.
pred_ = tf.squeeze(preds[0], axis=0)
# uint8 (0..255) is the proper 8-bit image type; int8 only covered 0..127.
pred_ = tf.image.convert_image_dtype(pred_, dtype=tf.uint8)
pred_ = pred_.numpy()
# pred_[pred_ < 128] = 0
# pred_[pred_ >= 128] = 255
# cv2.imwrite('pred.jpg',pred_)
plt.imshow(pred_)

In [None]:
# Collect the bare file name (text after the last '/') of every test image,
# in the order in which the dataset yields them.
file_names = []
for image, path in test_dataset:
    file_names += [raw.decode('utf-8').split('/')[-1] for raw in path.numpy()]

In [None]:
# Read original files size
# Map each test image's file name to its (height, width) as loaded by OpenCV.
test_files = glob('../input/tissuesegment-test/test-public-images/*.jpg')
test_img_sizes = {}
for img_path in test_files:
    height, width = cv2.imread(img_path).shape[:2]
    test_img_sizes[img_path.split('/')[-1]] = (height, width)

In [None]:
# Restore original size
# Every prediction is cropped/padded to the height and width of its source
# image, binarised with Otsu's method and written to ./predicted_masks/ under
# the source file name.
try:
    os.mkdir('predicted_masks')
except FileExistsError:
    print('dir already exist')

# preds and file_names are paired by position.
# NOTE(review): assumes test_dataset yields the same order on every pass and
# one image per batch (tf.squeeze(axis=0) below requires the latter) - confirm.
for pred_mask, file_name in zip(preds, file_names):
    target_h, target_w = test_img_sizes[file_name]

    # Scale straight to uint8 (0..255), the dtype OpenCV expects. The earlier
    # int8 conversion topped out at 127 and relied on a later astype("uint8")
    # reinterpretation, which would wrap any negative value.
    # NOTE(review): assumes the model emits floats in [0, 1] - confirm.
    pred_mask = tf.image.convert_image_dtype(pred_mask, dtype=tf.uint8)

    # Crop/pad (no interpolation) to the original image size.
    pred_mask = tf.image.resize_with_crop_or_pad(pred_mask, target_h, target_w)
    pred_mask = tf.squeeze(pred_mask, axis=0).numpy()

    # With THRESH_OTSU the threshold is derived from the histogram; the 128
    # passed here is ignored by OpenCV.
    _, pred_mask = cv2.threshold(
        pred_mask, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
    )
    cv2.imwrite(f'./predicted_masks/{file_name}', pred_mask)

In [None]:
# Archive everything under ./predicted_masks into predicted_masks.zip.
!zip -r predicted_masks.zip ./predicted_masks/*

In [None]:
# with tf.device("gpu:0"):
#     pass

Try post-processing the predicted masks (morphological closing).

In [None]:
# Create the output directory for post-processed masks; -p makes the cell
# safe to re-run when the directory already exists.
!mkdir -p predicted_masks_post

In [None]:
# plt.figure(figsize=(18, 18))
# Apply a 5x5 morphological closing to every predicted mask and save the
# result under the same file name in predicted_masks_post/.
closing_kernel = np.ones((5, 5), np.uint8)  # loop-invariant, build once

# Process whatever masks were actually written instead of assuming exactly
# 15 files named 1.jpg .. 15.jpg.
for mask_path in sorted(glob('./predicted_masks/*.jpg')):
    # Read as single-channel directly. cv2.COLOR_RGB2GRAY is a cvtColor code,
    # not an imread flag; passing it to imread loaded a 3-channel image that
    # then needed a separate grayscale conversion.
    img = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # imread signals failure by returning None rather than raising.
        raise OSError(f'cannot read mask image: {mask_path}')

    closing = cv2.morphologyEx(img, cv2.MORPH_CLOSE, closing_kernel)
    cv2.imwrite(f'predicted_masks_post/{os.path.basename(mask_path)}', closing)

In [None]:
# Archive the post-processed masks, then delete both mask directories.
!zip -r predicted_masks_post.zip ./predicted_masks_post/*
!rm -r ./predicted_masks_post
!rm -r ./predicted_masks

In [None]:
# plt.figure(figsize=(18, 18))
# img = cv2.imread('./predicted_masks/7.jpg', cv2.COLOR_RGB2GRAY)
# img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# # Get rid of JPG artifacts
# img = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY)[1]

# # Create structuring elements
# horizontal_size = 50
# vertical_size = 50
# horizontalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (horizontal_size, 5))
# verticalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (5, vertical_size))

# # Morphological opening
# mask1 = cv2.morphologyEx(img, cv2.MORPH_OPEN, horizontalStructure)
# mask2 = cv2.morphologyEx(img, cv2.MORPH_OPEN, verticalStructure)

# plt.imshow(mask1)

In [None]:
# plt.figure(figsize=(18, 18))
# img = cv2.imread('./predicted_masks/7.jpg')
# imgray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# ret, thresh = cv2.threshold(imgray, 127, 255, 0)
# contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
# c = cv2.drawContours(img, contours, -1, (0,255,0), thickness=cv2.FILLED)

# plt.imshow(c)

In [None]:
# plt.figure(figsize=(18, 18))
# img = cv2.imread('./predicted_masks/7.jpg')
# img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
# img = cv2.medianBlur(img, 21)
# plt.imshow(img)

In [None]:
# plt.figure(figsize=(18, 18))
# img = cv2.imread('./predicted_masks/7.jpg')
# # plt.imshow(img)
# # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# # blur = cv2.GaussianBlur(gray, (7,7), 0)
# # ret, BW = cv2.threshold(gray, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
# # plt.imshow(BW)
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# thresh = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
# kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5,10))
# close = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=1)

# plt.imshow(close)

In [None]:
# ## plt.figure(figsize=(18, 18))
# bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, \
#                                 cv2.THRESH_BINARY, 15, -2)
# # plt.imshow(bw)
# horizontal = np.copy(bw)
# vertical = np.copy(bw)

# # Specify size on horizontal axis
# cols = horizontal.shape[1]
# horizontal_size = cols // 30
# # Create structure element for extracting horizontal lines through morphology operations
# horizontalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (horizontal_size, 1))
# # Apply morphology operations
# horizontal = cv2.erode(horizontal, horizontalStructure)
# horizontal = cv2.dilate(horizontal, horizontalStructure)

# # Specify size on vertical axis
# rows = vertical.shape[0]
# verticalsize = rows // 30
# # Create structure element for extracting vertical lines through morphology operations
# verticalStructure = cv2.getStructuringElement(cv2.MORPH_RECT, (1, verticalsize))
# # Apply morphology operations
# vertical = cv2.erode(vertical, verticalStructure)
# vertical = cv2.dilate(vertical, verticalStructure)

# # Inverse vertical image
# vertical = cv2.bitwise_not(vertical)
# '''
# Extract edges and smooth image according to the logic
# 1. extract edges
# 2. dilate(edges)
# 3. src.copyTo(smooth)
# 4. blur smooth img
# 5. smooth.copyTo(src, edges)
# '''
# # Step 1
# edges = cv2.adaptiveThreshold(vertical, 255, cv2.ADAPTIVE_THRESH_MEAN_C, \
#                             cv2.THRESH_BINARY, 3, -2)
# # Step 2
# kernel = np.ones((2, 2), np.uint8)
# edges = cv2.dilate(edges, kernel)
# # Step 3
# smooth = np.copy(vertical)
# # Step 4
# smooth = cv2.blur(smooth, (2, 2))
# # Step 5
# (rows, cols) = np.where(edges != 0)
# vertical[rows, cols] = smooth[rows, cols]
# # Show final result
# # plt.imshow(vertical)

In [None]:
# plt.figure(figsize=(18, 18))
# kernel = np.ones((5,5),np.uint8)
# erosion = cv2.erode(img,kernel,iterations = 1)
# dilation = cv2.dilate(img,kernel,iterations = 1)
# opening = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
# closing = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
# gradient = cv2.morphologyEx(img, cv2.MORPH_GRADIENT, kernel)
# plt.imshow(gradient)