In [None]:
!pip install segmentation_models
!pip install albumentations==0.4.5
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
from PIL import Image
import albumentations as A
import glob

import skimage.transform as sk_transform
import skimage.filters as sk_filters

from skimage.measure import label, regionprops
import os

import numpy as np

import segmentation_models as sm

In [None]:
def get_aug(aug, min_area=0., min_visibility=0.):
    """Wrap a list of albumentations transforms in a bbox-aware ``A.Compose``.

    Bounding boxes are expected in COCO format ``(x_min, y_min, width, height)``
    and their labels are read from the ``category_id`` keyword of the call.

    Args:
        aug: list of albumentations transforms, applied in order.
        min_area: boxes whose area drops below this after transforms are dropped.
        min_visibility: boxes keeping less than this fraction of their area are dropped.
    """
    return A.Compose(
        aug,
        bbox_params=A.BboxParams(
            format='coco',
            min_area=min_area,
            min_visibility=min_visibility,
            label_fields=['category_id'],
        ),
    )


class DataLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding augmented ``(images, masks)`` batches.

    Each item of ``dataset`` is an ``(image_path, mask_path)`` pair. Images are
    loaded as RGB, masks are binarised (every non-zero pixel becomes 1) and cast
    to float32, then both go through an albumentations pipeline that always ends
    with ``Resize`` + ``Normalize``, so every sample comes out at ``output_size``.

    Args:
        dataset: sequence of ``(image_path, mask_path)`` pairs.
        batch_size: number of samples per batch.
        shuffle: draw a new random sample order at construction and after
            every epoch.
        output_size: size passed to ``A.Resize``; index 0 is used as the width
            and index 1 as the height.
        is_validation: use the deterministic resize/normalise-only pipeline.
        **kwargs: accepted for backward compatibility and ignored.
    """

    def __init__(self, 
                 dataset, 
                 batch_size, 
                 shuffle=True, 
                 output_size=(512,512), 
                 is_validation=False,
                 **kwargs):
        self.dataset = dataset
        self._len = len(self.dataset)
        self.indices = range(self._len)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._output_size = output_size
        self.aug = self.init_aug(is_validation)
        # Used when self.aug raises (e.g. on bbox validation): without it a
        # failing sample would come back un-resized and un-normalised.
        self._fallback_aug = A.Compose(self._resize_normalize())
        self.on_epoch_end()

    def _resize_normalize(self):
        """Final transforms shared by every pipeline: resize to ``output_size``, then normalise."""
        return [
            A.Resize(width=self._output_size[0], height=self._output_size[1], always_apply=True),
            A.Normalize(),
        ]

    def init_aug(self, is_validation):
        """Build the bbox-aware pipeline: resize/normalise only for validation, random augmentation otherwise."""
        if is_validation:
            return get_aug(self._resize_normalize(), min_visibility=0.1)
        return get_aug([
            A.RGBShift(p=0.1),
            # p=0.0 disables this whole colour/blur group; kept for experimentation.
            A.OneOf([
                A.RandomBrightnessContrast(p=0.5),
                A.HueSaturationValue(),
                A.RandomGamma(p=0.25),
                A.RandomBrightness(p=0.25),
                A.Blur(blur_limit=2, p=0.25),
            ], p=0.0),

            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.05),

            # border_mode=0 is cv2.BORDER_CONSTANT: the image border is filled with
            # `value` (presumably the dataset's mean RGB pixel - TODO confirm) and the
            # mask border with 0.
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.05, rotate_limit=15, border_mode=0, p=0.2,
                               value=(144.75479165, 137.70713403, 129.666091), mask_value=0.0),
        ] + self._resize_normalize(), min_visibility=0.1)

    def __len__(self):
        """Number of full batches; trailing samples that don't fill a batch are not covered."""
        return self._len // self.batch_size

    def __getitem__(self, index):
        """ Generate one batch of data. """
        s = index * self.batch_size % self._len
        e = s + self.batch_size
        indices = self.indices[s:e]

        return self.__data_generator(indices)

    def on_epoch_end(self):
        """ Updates indices after each epoch. """
        if self.shuffle:
            self.indices = np.random.permutation(self._len)

    def augment(self, img, mask):
        """Run one image/mask pair through ``self.aug``.

        Connected components of ``mask`` with an area of at least 100 pixels are
        passed to the bbox-aware pipeline as COCO boxes; if none qualify, a single
        whole-image box is used instead. If the pipeline raises, the error is
        printed and the pair goes through a bbox-free resize + normalise fallback,
        so the result still has ``output_size`` and the same normalisation as
        every other sample.

        Returns:
            ``(image, mask)`` after augmentation.
        """
        bboxes = []
        for region in regionprops(label(mask)):
            if region.area >= 100:
                minr, minc, maxr, maxc = region.bbox
                bboxes.append((minc, minr, maxc - minc, maxr - minr))

        if not bboxes:
            bboxes = [[0, 0, img.shape[1], img.shape[0]]]

        try:
            augmented = self.aug(image=img,
                                 masks=[mask],
                                 bboxes=bboxes,
                                 category_id=[255] * len(bboxes))
        except Exception as e:
            print(e)
            augmented = self._fallback_aug(image=img, masks=[mask])
        return augmented['image'], augmented['masks'][0]

    def __data_generator(self, indices):
        """Load, binarise and augment the samples at ``indices`` into one batch.

        Returns:
            ``(images, masks)`` stacked along a new leading batch axis; masks
            always carry a trailing channel axis. A batch with fewer than
            ``batch_size`` samples is padded with all-zero samples.
        """
        batch_images, batch_target = [], []
        for idx in indices:
            image_path, label_path = self.dataset[idx]
            with Image.open(image_path) as pil_image:
                image = np.array(pil_image.convert('RGB'))
            with Image.open(label_path) as pil_mask:
                target = np.array(pil_mask)

            ## Rescale masks from [0; 255] to [0; 1]
            target[target > 0] = 1
            target = target.astype('float32')

            image, target = self.augment(image, target)

            # The keras model requires masks as h*w*1.
            if target.ndim != 3:
                target = np.expand_dims(target, axis=-1)

            batch_images.append(image)
            batch_target.append(target)

        missing = self.batch_size - len(batch_images)
        if missing > 0:
            batch_images.extend(np.zeros_like(batch_images[0]) for _ in range(missing))
            batch_target.extend(np.zeros_like(batch_target[0]) for _ in range(missing))

        return np.stack(batch_images), np.stack(batch_target)

In [None]:
# Input images (any extension) and binary PNG masks; they are paired by sort order.
base_dir = "/kaggle/input/supervisely/processed"
ims_dir, labels_dir = (os.path.join(base_dir, sub) for sub in ("imgs", "labels"))
img_files = sorted(glob.glob(os.path.join(ims_dir, "*.*")))
mask_files = sorted(glob.glob(os.path.join(labels_dir, "*.png")))

In [None]:
from sklearn import model_selection

In [None]:
# Images and masks are paired purely by sorted order. zip() would silently
# truncate to the shorter list and shift every later pair, so require a 1:1 match.
assert img_files, "no input images found"
assert len(img_files) == len(mask_files), (
    f"found {len(img_files)} images but {len(mask_files)} masks")
dataset = list(zip(img_files, mask_files))
train_dataset, test_dataset = model_selection.train_test_split(dataset, test_size=0.2, random_state=0)
# A bare expression mid-cell is discarded, so report the split sizes explicitly.
print(f"train: {len(train_dataset)}, test: {len(test_dataset)}")
train_data_loader = DataLoader(dataset=train_dataset, batch_size=16, output_size=(256, 256), shuffle=True)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=16, is_validation=True, output_size=(256, 256), shuffle=False)

In [None]:
# Inspect one (image_path, mask_path) pair from the held-out split.
test_dataset[0]

In [None]:
OUTPUT_CHANNELS = 1
IMG_SIZE = 256  # must match the DataLoader output_size and unet_model's Input shape
base_model = tf.keras.applications.MobileNetV2(input_shape=[IMG_SIZE, IMG_SIZE, 3], include_top=False, weights="imagenet")

# Encoder activations reused as U-Net skip connections. Spatial sizes are for a
# 256x256 input (MobileNetV2 halves resolution at Conv1 and in blocks 1, 3, 6, 13).
layer_names = [
    'block_1_expand_relu',   # 128x128
    'block_3_expand_relu',   # 64x64
    'block_6_expand_relu',   # 32x32
    'block_13_expand_relu',  # 16x16
    'block_16_project',      # 8x8
]
layers = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

# Freeze the ImageNet-pretrained encoder weights.
down_stack.trainable = False

In [None]:
# Decoder blocks; each doubles the spatial resolution (sizes for a 256x256 input).
up_stack = [
    pix2pix.upsample(512, 3),  # 8x8 -> 16x16
    pix2pix.upsample(256, 3),  # 16x16 -> 32x32
    pix2pix.upsample(128, 3),  # 32x32 -> 64x64
    pix2pix.upsample(64, 3),   # 64x64 -> 128x128
]

In [None]:
def unet_model(output_channels):
  """Build a U-Net from the module-level ``down_stack`` encoder and ``up_stack`` decoder.

  Args:
    output_channels: number of channels of the predicted mask.

  Returns:
    A ``tf.keras.Model`` mapping a 256x256x3 image to a
    256x256x``output_channels`` map of sigmoid activations.
  """
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  # Downsampling through the model. The deepest activation feeds the decoder;
  # the others become skip connections, consumed from deepest to shallowest.
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  # Final stride-2 transposed conv restores full resolution: 128x128 -> 256x256.
  # Layer creation order is unchanged so saved h5 weights still load by topology.
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')

  x = last(x)
  x = tf.keras.layers.Activation('sigmoid')(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
WEIGHTS_PATH = "/kaggle/input/segmentation/segmentation/models/best_model.h5"
model = unet_model(output_channels=OUTPUT_CHANNELS)
model.load_weights(WEIGHTS_PATH)

In [None]:
# First validation batch: DataLoader items are (images, masks) tuples.
batch = test_data_loader[0]
image_batch, mask_batch = batch
len(batch)

In [None]:
# Predict the whole first batch in a single forward pass.
res = model.predict(image_batch, batch_size=image_batch.shape[0])
res.shape

In [None]:
def denormalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
  """Invert ``A.Normalize()`` and return a displayable uint8 image.

  Args:
    img: normalised channels-last image, as produced by ``DataLoader``.
    mean, std: per-channel statistics of the normalisation; the defaults are
      albumentations' ``Normalize`` defaults (with max pixel value 255).

  Returns:
    ``uint8`` array of the same shape. Values are rounded to the nearest integer
    and clipped to [0, 255] so out-of-range floats don't wrap around on the cast.
  """
  restored = (np.asarray(img) * np.asarray(std) + np.asarray(mean)) * 255.0
  return np.clip(np.rint(restored), 0, 255).astype('uint8')

In [None]:
# Prediction / input image / ground-truth mask for every sample of the first batch.
# Iterate over image_batch (the array indexed below) instead of a hard-coded 16.
for i in range(len(image_batch)):
  fig, axes = plt.subplots(1, 3, figsize=(8, 8))
  axes[0].imshow(res[i].squeeze())
  axes[0].set_title("prediction")
  axes[1].imshow(denormalize(image_batch[i]))
  axes[1].set_title("image")
  axes[2].imshow(mask_batch[i].squeeze())
  axes[2].set_title("ground truth")

In [None]:
import tqdm

In [None]:
# Predictions for every full test batch, in loader order
# (len(test_data_loader) * batch_size images); replaces the single-batch `res`.
res = model.predict(x=test_data_loader, steps=len(test_data_loader), verbose=1)

In [None]:
dice_loss = sm.losses.DiceLoss()
binary_focal_loss = sm.losses.BinaryFocalLoss()
# Same combination as DiceLoss() + 1 * BinaryFocalLoss(), reusing the instances above.
total_loss = dice_loss + (1 * binary_focal_loss)
losses = {
          "dice": dice_loss,
          "bin focal" : binary_focal_loss,
          "total" : total_loss 
}

MASKS_IDX = 1  # DataLoader items are (images, masks)
# Must match the loader that produced `res`; later cells also read `batch_size`.
batch_size = test_data_loader.batch_size

# One row per test image: the losses in `losses` order (dice, bin focal, total).
# Each batch is generated once and reused for all of its images, instead of
# re-loading and re-augmenting the whole batch for every single image.
all_losses = []
cached_batch_idx, gt_batch = None, None
for idx in tqdm.tqdm(range(len(res))):
  batch_idx, img_idx = divmod(idx, batch_size)
  if batch_idx != cached_batch_idx:
    gt_batch = test_data_loader[batch_idx][MASKS_IDX]
    cached_batch_idx = batch_idx

  gt = gt_batch[img_idx]
  pred = res[idx]
  all_losses.append([loss_func(gt, pred) for loss_func in losses.values()])


In [None]:
# Convert each [dice, focal, total] row of loss tensors to plain Python floats.
new_all_losses = [[float(row[k]) for k in range(3)] for row in all_losses]

In [None]:
import pandas as pd

In [None]:
# Per-image losses as a table, persisted to the Kaggle working directory.
LOSS_COLUMNS = ["dice", "focal", "total"]
losses_data = pd.DataFrame(data=new_all_losses, columns=LOSS_COLUMNS)
losses_data.to_csv("/kaggle/working/losses.csv")
losses_data

In [None]:
# Distribution of each per-image loss (one histogram per column).
losses_data.hist(bins=100, figsize=(10, 10));

In [None]:
# Summary statistics of the per-image losses.
losses_data.describe()

In [None]:
# Rank test images by total loss, lowest (best) first; row labels stay the image indices.
sorted_losses = losses_data.sort_values("total")

In [None]:
# Pick 10 test images evenly spaced along the loss ranking (best -> worst)
# and keep their row labels for plotting.
names = []
for position in np.linspace(start=0, stop=len(sorted_losses), num=10, endpoint=False):
  row = sorted_losses.iloc[int(position)]
  print(row.name, row.total)
  names.append(row.name)

In [None]:
# denormalize() is already defined earlier in this notebook; the identical
# redefinition that used to live here (and silently shadowed it) was removed.

In [None]:
# Input image / ground-truth mask / prediction for the ranked samples above.
for idx in names:
  fig, axes = plt.subplots(1, 3)
  fig.suptitle(f"test sample {idx}")

  batch_idx, img_idx = divmod(idx, batch_size)
  # Generate the batch once and take both the image and the mask from it.
  batch_images, batch_masks = test_data_loader[batch_idx]

  panels = (
      ("image", denormalize(batch_images[img_idx]).squeeze()),
      ("ground truth", batch_masks[img_idx].squeeze()),
      ("prediction", res[idx].squeeze()),
  )
  for ax, (title, picture) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)

In [None]:
# The test image with the highest total loss; same panel order as the cell above.
worst = sorted_losses.iloc[-1]
idx = worst.name
batch_idx, img_idx = divmod(idx, batch_size)
# Generate the batch once and take both the image and the mask from it.
worst_images, worst_masks = test_data_loader[batch_idx]

fig, axes = plt.subplots(1, 3)
fig.suptitle(f"worst test sample {idx} (total loss {worst.total:.4f})")

reverse = denormalize(worst_images[img_idx])
axes[0].imshow(reverse.squeeze())
axes[0].set_title("image")
axes[1].imshow(worst_masks[img_idx].squeeze())
axes[1].set_title("ground truth")
axes[2].imshow(res[idx].squeeze())
axes[2].set_title("prediction")