In [1]:
import tensorflow as tf
from tensorflow import keras
import math
import tensorflow.keras.backend as K
#import tensorflow_addons as tda
#import tensorflow_hub as hub
import tensorflow_datasets as tfds
#import efficientnet.tfkeras as efn
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tarfile
import os
import cv2
from functools import partial
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
!pip install keras-unet-collection >>/dev/null
from keras_unet_collection import models
from keras_unet_collection import losses
from keras_unet_collection import base

AUTO = tf.data.experimental.AUTOTUNE
DEVICE = "TPU"



In [2]:
if DEVICE == "TPU":
    print("connecting to TPU...")
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    print("Could not connect to TPU")
    tpu = None

if tpu:
    try:
        print("initializing  TPU ...")
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
        print("TPU initialized")
    except Exception:
        print("failed to initialize TPU")
else:
    DEVICE = "GPU"

if DEVICE != "TPU":
    print("Using default strategy for CPU and single GPU")
    strategy = tf.distribute.get_strategy()

if DEVICE == "GPU":
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))


AUTO     = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

connecting to TPU...
Running on TPU  grpc://10.0.0.2:8470
initializing  TPU ...
TPU initialized
REPLICAS: 8


In [3]:
#test_img = plt.imread("../input/isic2017-and-ph2/ISIC_2017 + PH2/ISIC_2017/trainx/ISIC_0000000.jpg")
image = cv2.imread("../input/isic2017-and-ph2/ISIC_2017 + PH2/ISIC_2017/trainx/ISIC_0000000.jpg")
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
#image = tf.image.resize(image, [512, 512])
image = cv2.imencode('.jpg', image, (cv2.IMWRITE_JPEG_QUALITY, 94))[1]#.tobytes()
print(image)

[[255]
 [216]
 [255]
 ...
 [127]
 [255]
 [217]]


In [4]:
test_img = plt.imread("../input/isic2017-and-ph2/ISIC_2017 + PH2/ISIC_2017/trainx/ISIC_0000000.jpg")
test_img = tf.convert_to_tensor(test_img)
test_img = tf.image.resize(test_img, [512, 512])/255
test_img = tf.image.convert_image_dtype(test_img, tf.uint8, saturate=True, name=None)
test_img = tf.io.encode_jpeg(test_img)

## READING TFRECORDS:

In [5]:
def get_seg_paths(data_type="train", tfrec_roots=None, img_root_paths=None):
    if data_type == "tfrecords":
        test_paths = []
        train_paths = []
        for tfrec_root in tfrec_roots:  
            test_paths += tf.io.gfile.glob(tfrec_root+'/test*.tfrec')
            train_paths += tf.io.gfile.glob(tfrec_root+'/train*.tfrec')
        test_paths = np.sort(np.array(test_paths))
        train_paths = np.sort(np.array(train_paths))
        return train_paths, test_paths
    else:
        complete_img_paths = [0]*len(img_root_paths)
        for index, img_root_path in enumerate(img_root_paths):
            if index == 0:
                complete_img_paths[index] = np.sort(np.array(tf.io.gfile.glob(img_root_path + '/*.jpg')))
                complete_img_paths[index] = np.sort(np.array(tf.io.gfile.glob(img_root_path + '/*.png')))
            else:
                complete_img_paths[index] = np.sort(np.array(tf.io.gfile.glob(img_root_path + '/*.jpg')))
                complete_img_paths[index] = np.sort(np.array(tf.io.gfile.glob(img_root_path + '/*.png')))
        return complete_img_paths
    
    
#All the code below comes from TensorFlow's docs here

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_example(feature0, feature1, feature2):
  feature = {
      'image': _bytes_feature(feature0),
      'image_name': _bytes_feature(feature1),
      'mask': _bytes_feature(feature2),
  }
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

def create_seg_tfrecords(tfrecord_type="train", SIZE=500, tfrec_roots=None, img_root_paths=None):
    image_paths, mask_paths = get_seg_paths(data_type=tfrecord_type,
                                            tfrec_roots=tfrec_roots,
                                            img_root_paths=img_root_paths)
    folder_name = f'{tfrecord_type}_tfrecords'
    path = os.path.join(os.getcwd(), folder_name)
    os.makedirs(path, exist_ok=True)
    path_zip = zip(image_paths, mask_paths)
    tfrecord_nums = image_paths.size // 500 + 1
    for tfrecord_counter in range(tfrecord_nums):
        tfrecord_size = min(SIZE, image_paths.size-tfrecord_counter*SIZE)
        print('\nCreating {tfrecord_type}_{tfrecord_counter}.tfrec......'.format(
                            tfrecord_type=tfrecord_type,tfrecord_counter=tfrecord_counter))
        with tf.io.TFRecordWriter(os.path.join(path, f'{tfrecord_type}{tfrecord_counter}.tfrec')) as writer:
            for k in range(tfrecord_size):
                # processing image
                image = cv2.imread(image_paths[tfrecord_size * tfrecord_counter + k])
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                image = tf.convert_to_tensor(image)
                image = tf.image.resize(image, [512, 512])/255
                image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True, name=None)
                image = tf.io.encode_jpeg(image)

                # processing mask
                mask = cv2.imread(mask_paths[tfrecord_size * tfrecord_counter + k])
                mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
                mask = tf.convert_to_tensor(mask)
                mask = tf.image.resize(mask, [512, 512])/255
                mask = tf.image.convert_image_dtype(mask, tf.uint8, saturate=True, name=None)
                mask = tf.io.encode_jpeg(mask)

                # extracting image name
                image_name = os.path.split(image_paths[tfrecord_size*tfrecord_counter+k])[1]
                image_name = image_name.split('.')[0]
                
                # writing the example
                example = serialize_example(image, str.encode(image_name), mask)
                writer.write(example)
                if k%100==0: print(k,', ',end='')

def read_seg_tfrecord(example, labeled=True, return_image_names=False):
    if labeled:
        tfrec_format = {
            'image'                        : tf.io.FixedLenFeature([], tf.string),
            'image_name'                   : tf.io.FixedLenFeature([], tf.string),
            'mask'                       : tf.io.FixedLenFeature([], tf.string)
        }      
    else:
        tfrec_format = {
            'image'                        : tf.io.FixedLenFeature([], tf.string),
            'image_name'                   : tf.io.FixedLenFeature([], tf.string),
        }
        
    example = tf.io.parse_single_example(example, tfrec_format)
    if labeled:
        return ({"seg_input": example['image']}, example['mask'])
    else:
        return ({"seg_input": example['image']},
                example['image_name'] if return_image_names else 0)


In [6]:
def prepare_image(images, dim=384, mask=False, mask_channels=3): 
    if mask:
        img = images
        img = tf.image.decode_jpeg(img, channels=mask_channels)
    else:
        img = images['seg_input']
        img = tf.image.decode_jpeg(img, channels=3)
        
    channels = img.shape.as_list()[-1]
    img = tf.cast(img, tf.float32) / 255.0
    
    img = tf.reshape(img, [dim, dim, channels])
    if mask:
        return img
    else:
        images['seg_input'] = img
        return images

In [7]:
@tf.function
def transform_grid_mark(image, inv_mat, image_shape):
    h, w, c = image_shape
    
    cx, cy = w//2, h//2

    new_xs = tf.repeat( tf.range(-cx, cx, 1), h)
    new_ys = tf.tile( tf.range(-cy, cy, 1), [w])
    new_zs = tf.ones([h*w], dtype=tf.int32)

    old_coords = tf.matmul(inv_mat, tf.cast(tf.stack([new_xs, new_ys, new_zs]), tf.float32))
    old_coords_x, old_coords_y = tf.round(old_coords[0, :] + tf.cast(w, tf.float32)//2.), tf.round(old_coords[1, :] + tf.cast(h, tf.float32)//2.)
    old_coords_x = tf.cast(old_coords_x, tf.int32)
    old_coords_y = tf.cast(old_coords_y, tf.int32)    

    clip_mask_x = tf.logical_or(old_coords_x<0, old_coords_x>w-1)
    clip_mask_y = tf.logical_or(old_coords_y<0, old_coords_y>h-1)
    clip_mask = tf.logical_or(clip_mask_x, clip_mask_y)

    old_coords_x = tf.boolean_mask(old_coords_x, tf.logical_not(clip_mask))
    old_coords_y = tf.boolean_mask(old_coords_y, tf.logical_not(clip_mask))
    new_coords_x = tf.boolean_mask(new_xs+cx, tf.logical_not(clip_mask))
    new_coords_y = tf.boolean_mask(new_ys+cy, tf.logical_not(clip_mask))

    old_coords = tf.cast(tf.stack([old_coords_y, old_coords_x]), tf.int32)
    new_coords = tf.cast(tf.stack([new_coords_y, new_coords_x]), tf.int64)
    rotated_image_values = tf.gather_nd(image, tf.transpose(old_coords))
    rotated_image_channel = list()
    for i in range(c):
        vals = rotated_image_values[:,i]
        sparse_channel = tf.SparseTensor(tf.transpose(new_coords), vals, [h, w])
        rotated_image_channel.append(tf.sparse.to_dense(sparse_channel, default_value=0, validate_indices=False))

    return tf.transpose(tf.stack(rotated_image_channel), [1,2,0])


@tf.function
def random_rotate(image, angle, image_shape):
    def get_rotation_mat_inv(angle):
          #transform to radian
        angle = math.pi * angle / 180

        cos_val = tf.math.cos(angle)
        sin_val = tf.math.sin(angle)
        one = tf.constant([1], tf.float32)
        zero = tf.constant([0], tf.float32)

        rot_mat_inv = tf.concat([cos_val, sin_val, zero,
                                     -sin_val, cos_val, zero,
                                     zero, zero, one], axis=0)
        rot_mat_inv = tf.reshape(rot_mat_inv, [3,3])

        return rot_mat_inv
    angle = float(angle) * tf.random.normal([1],dtype='float32')
    rot_mat_inv = get_rotation_mat_inv(angle)
    return transform_grid_mark(image, rot_mat_inv, image_shape)


@tf.function
def get_grid_mask(DIM=384):
    h = tf.constant(DIM, dtype=tf.float32)
    w = tf.constant(DIM, dtype=tf.float32)
    
    image_height, image_width = (h, w)

    # CHANGE THESE PARAMETER
    d1 = int(DIM / 6)
    d2 = int(DIM / 4)
    rotate_angle = 45
    ratio = 0.4 # this is delete ratio, so keep ratio = 1 - delete

    hh = tf.math.ceil(tf.math.sqrt(h*h+w*w))
    hh = tf.cast(hh, tf.int32)
    hh = hh+1 if hh%2==1 else hh
    d = tf.random.uniform(shape=[], minval=d1, maxval=d2, dtype=tf.int32)
    l = tf.cast(tf.cast(d,tf.float32)*ratio+0.5, tf.int32)

    st_h = tf.random.uniform(shape=[], minval=0, maxval=d, dtype=tf.int32)
    st_w = tf.random.uniform(shape=[], minval=0, maxval=d, dtype=tf.int32)

    y_ranges = tf.range(-1 * d + st_h, -1 * d + st_h + l)
    x_ranges = tf.range(-1 * d + st_w, -1 * d + st_w + l)

    for i in range(0, hh//d+1):
        s1 = i * d + st_h
        s2 = i * d + st_w
        y_ranges = tf.concat([y_ranges, tf.range(s1,s1+l)], axis=0)
        x_ranges = tf.concat([x_ranges, tf.range(s2,s2+l)], axis=0)

    x_clip_mask = tf.logical_or(x_ranges <0 , x_ranges > hh-1)
    y_clip_mask = tf.logical_or(y_ranges <0 , y_ranges > hh-1)
    clip_mask = tf.logical_or(x_clip_mask, y_clip_mask)

    x_ranges = tf.boolean_mask(x_ranges, tf.logical_not(clip_mask))
    y_ranges = tf.boolean_mask(y_ranges, tf.logical_not(clip_mask))

    hh_ranges = tf.tile(tf.range(0,hh), [tf.cast(tf.reduce_sum(tf.ones_like(x_ranges)), tf.int32)])
    x_ranges = tf.repeat(x_ranges, hh)
    y_ranges = tf.repeat(y_ranges, hh)

    y_hh_indices = tf.transpose(tf.stack([y_ranges, hh_ranges]))
    x_hh_indices = tf.transpose(tf.stack([hh_ranges, x_ranges]))

    y_mask_sparse = tf.SparseTensor(tf.cast(y_hh_indices, tf.int64),  tf.zeros_like(y_ranges), [hh, hh])
    y_mask = tf.sparse.to_dense(y_mask_sparse, 1, False)

    x_mask_sparse = tf.SparseTensor(tf.cast(x_hh_indices, tf.int64), tf.zeros_like(x_ranges), [hh, hh])
    x_mask = tf.sparse.to_dense(x_mask_sparse, 1, False)

    mask = tf.expand_dims( tf.clip_by_value(x_mask + y_mask, 0, 1), axis=-1)

    mask = random_rotate(mask, rotate_angle, [hh, hh, 1])
    mask = tf.image.crop_to_bounding_box(mask, (hh-tf.cast(h, tf.int32))//2, (hh-tf.cast(w, tf.int32))//2, tf.cast(image_height, tf.int32), tf.cast(image_width, tf.int32))

    return mask


@tf.function
def apply_grid_mask(image, mask, DIM=384):
    #mask = grid_mask(DIM=DIM)
    mask = tf.concat([mask, mask, mask], axis=-1)
    return image * tf.cast(mask, 'float32')



def augmenter(image, mask, dim=384, grid_mask=True, grid_mask_aug=True):
    """
    position_change_func_list = [
        tf.image.stateless_random_flip_left_right,
        tf.image.stateless_random_flip_up_down,
    ]

    color_change_func_list =[
        partial(tf.image.stateless_random_brightness, max_delta=0.95),
        partial(tf.image.stateless_random_contrast, upper=0.5, lower=0.1),
        partial(tf.image.stateless_random_hue, max_delta=0.3),
        partial(tf.image.stateless_random_saturation, upper=0.6, lower=0.1), 
    ] 
    
    #print(image["seg_input"].shape)
    img_cum_mask_array = tf.stack([image["seg_input"], mask], 3)
    print(img_cum_mask_array.shape)
    #print(asdad)
    for i, aug_func in enumerate(position_change_func_list):
        rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
        aug_img = aug_func(image["seg_input"], seed=rand_seed)
        aug_mask = aug_func(mask, seed=rand_seed)
        augmented_img_cum_mask = tf.stack([aug_img, aug_mask], 3)
        if i == 0:
            img_cum_mask_array = tf.stack([img_cum_mask_array,
                                           augmented_img_cum_mask], 0)
        else:
            img_cum_mask_array = tf.concat([img_cum_mask_array,
                                            tf.expand_dims(augmented_img_cum_mask, 0)], 0)
        print(img_cum_mask_array.shape)
    
    for i, aug_func in enumerate(color_change_func_list):
        rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
        aug_img = aug_func(image["seg_input"], seed=rand_seed)
        augmented_img_cum_mask = tf.stack([aug_img, mask], 3)
        img_cum_mask_array = tf.concat([img_cum_mask_array,
                                        tf.expand_dims(augmented_img_cum_mask, 0)], 0)
        print(img_cum_mask_array.shape)
            
    augmented_ds = tf.data.Dataset.from_tensor_slices(img_cum_mask_array)
    """

    position_change_func_list = [
        tf.image.stateless_random_flip_left_right,
        tf.image.stateless_random_flip_up_down,
    ]
    
    color_change_func_list =[
        partial(tf.image.stateless_random_brightness, max_delta=0.95),
        partial(tf.image.stateless_random_contrast, upper=0.5, lower=0.1),
        partial(tf.image.stateless_random_hue, max_delta=0.3),
        partial(tf.image.stateless_random_saturation, upper=0.6, lower=0.1), 
    ] 
    
    img_cum_mask_array = tf.stack([image["seg_input"], mask], 3)

    for i, aug_func in enumerate(position_change_func_list):
        rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
        aug_img = aug_func(image["seg_input"], seed=rand_seed)
        aug_mask = aug_func(mask, seed=rand_seed)
        augmented_img_cum_mask = tf.stack([aug_img, aug_mask], 3)
        if i == 0:
            img_cum_mask_array = tf.stack([img_cum_mask_array,
                                           augmented_img_cum_mask], 0)
        else:
            img_cum_mask_array = tf.concat([img_cum_mask_array,
                                            tf.expand_dims(augmented_img_cum_mask, 0)], 0)
    
    for i, aug_func in enumerate(color_change_func_list):
        rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
        aug_img = aug_func(image["seg_input"], seed=rand_seed)
        augmented_img_cum_mask = tf.stack([aug_img, mask], 3)
        img_cum_mask_array = tf.concat([img_cum_mask_array,
                                        tf.expand_dims(augmented_img_cum_mask, 0)], 0)
    
    if grid_mask:
        grid_mask = get_grid_mask(DIM=dim)
        grid_masked_img = apply_grid_mask(image["seg_input"], grid_mask)
        grid_masked_mask = apply_grid_mask(mask, grid_mask)
        augmented_img_cum_mask = tf.stack([grid_masked_img, grid_masked_mask], 3)
        img_cum_mask_array = tf.concat([img_cum_mask_array,
                                        tf.expand_dims(augmented_img_cum_mask, 0)], 0)
        if grid_mask_aug:
            """
            all_aug_func_list = position_change_func_list + color_change_func_list
            for i, aug_func in enumerate(all_aug_func_list):
                rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
                grid_mask_aug_img = aug_func(grid_masked_img, seed=rand_seed)
                aug_img_array = tf.concat([aug_img_array, tf.expand_dims(aug_img, 0)], 0)
            """
            for i, aug_func in enumerate(position_change_func_list):
                rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
                aug_img = aug_func(grid_masked_img, seed=rand_seed)
                aug_mask = aug_func(grid_masked_mask, seed=rand_seed)
                augmented_img_cum_mask = tf.stack([aug_img, aug_mask], 3)
                img_cum_mask_array = tf.concat([img_cum_mask_array,
                                                tf.expand_dims(augmented_img_cum_mask, 0)], 0)

            for i, aug_func in enumerate(color_change_func_list):
                rand_seed = tf.random.uniform(shape=[2], minval=-10**5, maxval=10**5, dtype=tf.int64)
                aug_img = aug_func(grid_masked_img, seed=rand_seed)
                augmented_img_cum_mask = tf.stack([aug_img, grid_masked_mask], 3)
                img_cum_mask_array = tf.concat([img_cum_mask_array,
                                                tf.expand_dims(augmented_img_cum_mask, 0)], 0)
        
    augmented_ds = tf.data.Dataset.from_tensor_slices(img_cum_mask_array)
    
    def input_handler(img_cum_mask):
        return {"seg_input": img_cum_mask[:, :, :, 0]}, img_cum_mask[:, :, :, 1]
    
    augmented_ds = augmented_ds.map(lambda img_cum_mask: input_handler(img_cum_mask),
                                    num_parallel_calls=AUTO)
    return augmented_ds

In [8]:
def get_seg_dataset(files, shuffle = False, repeat = False, labeled=True, augment=True,
                    return_image_names=False, batch_size=32, dim=384, mask_channels=3,
                    grid_mask=True, grid_mask_aug=True):
    
    read_labeled_unlabeled = partial(read_seg_tfrecord, labeled=labeled)

    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)
    ds = ds.cache()
    
    if repeat:
        ds = ds.repeat()
    
    if shuffle: 
        ds = ds.shuffle(1024*8)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)
    
    ds = ds.map(read_labeled_unlabeled, num_parallel_calls=AUTO)
    ds = ds.map(lambda image, mask_or_image_name: (prepare_image(image, dim=dim),
                                                   mask_or_image_name),
                num_parallel_calls=AUTO)
    # for image dataset
    if labeled: 
        ds = ds.map(lambda image, mask: (image, prepare_image(mask, dim=dim,
                                                              mask=True,
                                                              mask_channels=mask_channels)),
                    num_parallel_calls=AUTO)
        
        if augment: # I am not using Test Time Augmentation, this is for Train only.
            aug_func = partial(augmenter, dim=dim, grid_mask=grid_mask,
                               grid_mask_aug=grid_mask_aug)
            ds = ds.flat_map(lambda image, mask: aug_func(image, mask))
            ds = ds.shuffle(1024*8)
            
    ds = ds.batch(batch_size * REPLICAS, drop_remainder=True)
    ds = ds.prefetch(AUTO)
    return ds

In [9]:
#IMG_ROOT_PATHS = ["../input/isic2017-and-ph2/ISIC_2017 + PH2/ISIC_2017/trainx",
#                  "../input/isic2017-and-ph2/ISIC_2017 + PH2/ISIC_2017/trainy"]

#image_paths, mask_paths = get_seg_paths("train", img_root_paths=IMG_ROOT_PATHS)

#image_paths

In [10]:
#create_seg_tfrecords(tfrecord_type="train", SIZE=500, tfrec_roots=None, img_root_paths=IMG_ROOT_PATHS)

In [11]:
TFREC_ROOT_PATHS = [KaggleDatasets().get_gcs_path("isic2018andph2384x384tfrecords"),
                    KaggleDatasets().get_gcs_path("isic2017384x384tfrecords")]

train_paths_segs, test_paths_segs = get_seg_paths("tfrecords", tfrec_roots=TFREC_ROOT_PATHS)
valid_paths_segs = train_paths_segs[-1]
train_paths_segs = train_paths_segs[:-1]

BATCH_SIZE = 16
MASK_CHANNELS = 3
train_dataset_seg = get_seg_dataset(train_paths_segs, batch_size=BATCH_SIZE,# augment=False,
                                    labeled=True, mask_channels=MASK_CHANNELS,
                                    grid_mask_aug=False)
valid_dataset_seg = get_seg_dataset(valid_paths_segs, batch_size=BATCH_SIZE,
                                    labeled=True, mask_channels=MASK_CHANNELS,
                                    grid_mask_aug=False)

In [23]:
test_dataset_seg = get_seg_dataset(test_paths_segs, batch_size=BATCH_SIZE,
                                   labeled=True, mask_channels=MASK_CHANNELS,
                                   grid_mask=False, grid_mask_aug=False)

In [13]:
#for item in train_dataset_seg.take(1):
#    print(item)

## METRICS:

In [14]:
def iou(y_true, y_pred, smooth = 100):
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.square(y_true), axis = -1) + K.sum(K.square(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return jac


def dice_coe(y_true, y_pred, smooth = 1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def precision(y_true, y_pred):
    '''Calculates the precision, a metric for multi-label classification of
    how many selected items are relevant.
    '''
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision


def recall(y_true, y_pred):
    '''Calculates the recall, a metric for multi-label classification of
    how many relevant items are selected.
    '''
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall


def accuracy(y_true, y_pred):
    '''Calculates the mean accuracy rate across all predictions for binary
    classification problems.
    '''
    return K.mean(K.equal(y_true, K.round(y_pred)))

## BUILDING AND TRAINING THE MODEL:

In [15]:
help(base.unet_3plus_2d_base)

Help on function unet_3plus_2d_base in module keras_unet_collection._model_unet_3plus_2d:

unet_3plus_2d_base(input_tensor, filter_num_down, filter_num_skip, filter_num_aggregate, stack_num_down=2, stack_num_up=1, activation='ReLU', batch_norm=False, pool=True, unpool=True, backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus')
    The base of UNET 3+ with an optional ImagNet-trained backbone.
    
    unet_3plus_2d_base(input_tensor, filter_num_down, filter_num_skip, filter_num_aggregate, 
                       stack_num_down=2, stack_num_up=1, activation='ReLU', batch_norm=False, pool=True, unpool=True, 
                       backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus')
                  
    ----------
    Huang, H., Lin, L., Tong, R., Hu, H., Zhang, Q., Iwamoto, Y., Han, X., Chen, Y.W. and Wu, J., 2020. 
    UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation. 
    In

In [16]:
def get_lr_callback(batch_size=8, REPLICAS=1):
    lr_start   = 0.00005
    lr_max     = 0.0000125 * REPLICAS * batch_size
    lr_min     = 0.00001
    lr_ramp_ep = 5
    lr_sus_ep  = 0
    lr_decay   = 0.8
   
    def lrfn(epoch):
        if epoch < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
            
        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max
            
        else:
            lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min
            
        return lr

    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)
    return lr_callback

def binary_jaccard_index(y_true, y_pred, smooth=100):
    y_true = K.round(y_true)
    y_pred = K.round(y_pred)
    intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
    #print(intersection)
    union = K.sum(K.abs(y_true) + K.abs(y_pred), axis=[1, 2, 3])
    #print(union)
    #print(tf.clip_by_value(union - intersection, K.epsilon(), None))
    iou = intersection / K.clip(union - intersection, K.epsilon(), None)
    return iou

def jaccard_distance(y_true, y_pred, smooth=100):
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.square(y_true), axis = -1) + K.sum(K.square(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac)
      
DIM = 384 

if DEVICE=='TPU':
        if tpu: tf.tpu.experimental.initialize_tpu_system(tpu)
        

activation = 'ReLU'
filter_num = [32, 64, 128, 256, 512, 1024]
stack_num_down = 2
stack_num_up = 1
recur_num=2

with strategy.scope():
    input_layer = keras.layers.Input(shape=(DIM, DIM, 3), name="seg_input")
    unet_base = base.r2_unet_2d_base(input_layer, filter_num=filter_num, stack_num_down=stack_num_down,
                                     stack_num_up=stack_num_up, recur_num=recur_num, activation="ReLU",
                                     batch_norm=True, pool="max", unpool="nearest", name="res_unet_base")
    unet_output = keras.layers.Conv2D(MASK_CHANNELS, (1, 1), activation="sigmoid")(unet_base)
    unet_model = keras.Model(inputs=[input_layer], outputs=[unet_output])
    unet_model.compile(optimizer=keras.optimizers.Nadam(0.0001),
                       loss=[losses.focal_tversky],
                       metrics=[dice_coe])
    
#unet_model.summary()

In [17]:
#keras.utils.plot_model(unet_model)

In [18]:
history = unet_model.fit(
    train_dataset_seg,
    epochs=50,
    validation_data=valid_dataset_seg,
    callbacks=[#get_lr_callback(BATCH_SIZE, REPLICAS),
               keras.callbacks.ModelCheckpoint("r2_unet.h5", monitor='val_loss',
                                               verbose=2, save_best_only=True,
                                               save_weights_only=True, mode='min',
                                               save_freq='epoch'),
               keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)],
    #steps_per_epoch=count_data_items(train_paths_segs)/32//REPLICAS,
    verbose=2,
)

Epoch 1/50
281/281 - 567s - loss: 0.3565 - dice_coe: 0.7382 - val_loss: 0.4886 - val_dice_coe: 0.6125

Epoch 00001: val_loss improved from inf to 0.48855, saving model to r2_unet.h5
Epoch 2/50
281/281 - 523s - loss: 0.2063 - dice_coe: 0.8762 - val_loss: 0.2993 - val_dice_coe: 0.7964

Epoch 00002: val_loss improved from 0.48855 to 0.29932, saving model to r2_unet.h5
Epoch 3/50
281/281 - 520s - loss: 0.1758 - dice_coe: 0.9002 - val_loss: 0.3087 - val_dice_coe: 0.7866

Epoch 00003: val_loss did not improve from 0.29932
Epoch 4/50
281/281 - 524s - loss: 0.1584 - dice_coe: 0.9131 - val_loss: 0.2560 - val_dice_coe: 0.8335

Epoch 00004: val_loss improved from 0.29932 to 0.25602, saving model to r2_unet.h5
Epoch 5/50
281/281 - 524s - loss: 0.1449 - dice_coe: 0.9230 - val_loss: 0.2181 - val_dice_coe: 0.8662

Epoch 00005: val_loss improved from 0.25602 to 0.21813, saving model to r2_unet.h5
Epoch 6/50
281/281 - 521s - loss: 0.1329 - dice_coe: 0.9315 - val_loss: 0.1902 - val_dice_coe: 0.8880

Epo

In [None]:
weight_dir = os.path.join(os.getcwd(), "weights")
os.makedirs(weight_dir, exist_ok=True)
unet_model.save_weights(os.path.join(weight_dir, "r2_unet_weights.h5"))

In [19]:
unet_model.load_weights("./r2_unet.h5")

In [20]:
import pickle
pickle.dump(history.history, open("r2_unet_history.pkl", "wb"))

In [25]:
weights_list = ["../input/r2-unet-weight/gm_aug_r2_unet_extreme.h5",
                "../input/r2-unet-weight/r2_unet_extreme.h5",
                "../input/r2-unet-weight/r2_unet_better.h5",
                "../input/r2-unet-weight/r2_unet.h5",]
for weights in weights_list:
    unet_model.load_weights(weights)
    unet_model.evaluate(test_dataset_seg)



In [None]:
train_preds = unet_model.predict(train_dataset_seg)

In [None]:
plt.imshow(train_preds[0])

In [None]:
plt.imshow(preds[10])

In [None]:
for item in test_dataset_seg.take(1):
    images = item[0]["seg_input"]
    image_names = item[1]

In [None]:
plt.imshow(images[20])

In [None]:
plt.imshow(preds[20])

In [None]:
## If the predictions are blurry, then use this function for better thresholding
def enhance_preds(img_data, threshold=0.5, dim_x=384, dim_y=384, channels=3):
    
    if len(img_data.shape) == 3:
        real_img_data_dims = 3
        img_data = tf.expand_dims(img_data, axis=0)
    else:
        real_img_data_dims = 4
        
    preds = unet_model.predict(img_data)
    plt.imshow(preds[0])
    if real_img_data_dims == 3:
        preds = preds.flatten()
        for i in range(len(preds)):
            if preds[i] > 0.5:
                preds[i] = 1
            else:
                preds[i] = 0 
        return tf.reshape(preds, [dim_x, dim_y, channels])
    else:
        for i in range(preds.shape[0]):
            pred_img = preds[i].flatten() 
            for j in range(len(pred_img)):
                if pred_img[j] > 0.5:
                    pred_img[j] = 1
                else:
                    pred_img[j] = 0
            preds[i] = tf.reshape(pred_img, [dim_x, dim_y, channels])
        return preds