In [None]:
# Trained one-class U-Net weights to evaluate.
weights_path = '/root/data/models/erko/segmentation/0905_balanced_oneclass_fg_46.h5'
# Network input as (rows, cols, channels).
input_rows, input_cols, input_channels = 512, 512, 3
input_shape = (input_rows, input_cols, input_channels)

### #1 LOAD MODEL

In [None]:
import os

import numpy as np
from keras.optimizers import Adam
from keras.callbacks import Callback, LearningRateScheduler
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
from PIL import Image

from unet import get_unet, jaccard_coef_int, jaccard_coef_loss

import matplotlib.pyplot as plt

In [None]:
# Expose only GPU 0 to CUDA; this has to happen before the model is built below.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
# Build the one-class U-Net at the configured resolution, then restore the trained weights.
n_rows, n_cols = input_shape[0], input_shape[1]
model = get_unet(3, n_rows, n_cols, classes=1)
model.load_weights(weights_path)

### #2 LOAD DATA

In [None]:
import glob
import cv2
import numpy as np
import random
import math
import json
SEED = 448
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
import shutil
from collections import Counter

In [None]:
# Show the kernel's working directory (the data paths used in this notebook are absolute).
!pwd

In [None]:
# List of mask paths making up the balanced dataset; the generator derives each raw
# image path from its mask path. `with` closes the file handle after parsing.
with open('/root/thomas/github/cv_research/thomas/full_pipeline/erko/balanced_images.json') as images_file:
    images = json.load(images_file)

In [None]:
# Reproducible train/validation split. A copy is shuffled so that re-running this cell
# gives the same split; shuffling `images` in place would re-shuffle an already
# shuffled list on every re-run.
train_fraction = 0.8
random.seed(SEED)
shuffled_images = list(images)
random.shuffle(shuffled_images)
cutoff = int(len(shuffled_images) * train_fraction)
train = shuffled_images[:cutoff]
val = shuffled_images[cutoff:]

In [None]:
# Batch size shared by the generator and the prediction loop.
batch_size = 8
# Whole batches per pass over each split; images past the last whole batch are never drawn.
steps_per_epoch, steps_per_epoch_val = (len(split) // batch_size for split in (train, val))

In [None]:
def flip_axis(x, axis):
    """Return ``x`` (converted to an ndarray) reversed along ``axis``.

    The generator uses it with ``axis=1`` to flip an image and its mask horizontally.
    """
    arr = np.asarray(x)
    index = [slice(None)] * arr.ndim
    index[axis] = slice(None, None, -1)
    return arr[tuple(index)]

In [None]:
def _foreground_from_mask(mask_img, mask_path):
    """Return a boolean (rows, cols) array marking the pink label pixels of an RGB mask.

    Raises:
        ValueError: if ``mask_path`` ends in neither 'png' nor 'jpg'.
    """
    if mask_path.endswith('png'):
        pink = (255, 105, 180)
    elif mask_path.endswith('jpg'):
        # The .jpg masks use blue == 179 (presumably a lossy-encoding artefact) - TODO confirm.
        pink = (255, 105, 179)
    else:
        # Fail loudly instead of hitting an unbound name or reusing another sample's mask.
        raise ValueError('Unsupported mask file (expected .png or .jpg): {}'.format(mask_path))
    red, green, blue = mask_img[:, :, 0], mask_img[:, :, 1], mask_img[:, :, 2]
    return (red == pink[0]) & (green == pink[1]) & (blue == pink[2])


def _load_sample(mask_path, input_shape):
    """Load the (image, target) pair for one mask path at the network resolution.

    The raw image path is the mask path with its '.semantic.png' / '.semantic.jpg'
    suffix stripped.

    Returns:
        xb: uint8 RGB array of shape (input_shape[0], input_shape[1], 3).
        y0: float array of shape (input_shape[0], input_shape[1]); 1 on label pixels, 0 elsewhere.
    """
    img_path = mask_path.replace('.semantic.png', '').replace('.semantic.jpg', '')
    # PIL sizes are (width, height) while numpy shapes are (rows, cols).
    size = (input_shape[1], input_shape[0])
    with Image.open(img_path) as img:
        xb = np.array(img.convert('RGB').resize(size))
    # NEAREST keeps the exact label colours. Pillow >= 7 otherwise resizes with BICUBIC,
    # which blends colours at object borders so they no longer match the label colour.
    with Image.open(mask_path) as mask:
        mask_img = np.array(mask.convert('RGB').resize(size, Image.NEAREST))
    y0 = np.zeros((input_shape[0], input_shape[1]))
    y0[_foreground_from_mask(mask_img, mask_path)] = 1
    return xb, y0


def generator(images, steps_per_epoch, BATCH_SIZE, input_shape, augment=True):
    """Yield ``(x_batch, y_batch)`` forever, walking ``images`` one batch at a time.

    Args:
        images: list of colour-coded mask paths ('*.semantic.png' / '*.semantic.jpg').
        steps_per_epoch: number of batches after which the index wraps back to the start;
            must satisfy ``steps_per_epoch * BATCH_SIZE <= len(images)``.
        BATCH_SIZE: samples per batch.
        input_shape: (rows, cols, channels) of the network input.
        augment: when True (default), apply a random horizontal flip to image and mask and
            the imgaug blur/sharpen/contrast pipeline to the images; pass False for
            deterministic batches.

    Yields:
        x_batch: uint8 array (BATCH_SIZE, rows, cols, channels).
        y_batch: uint8 array (BATCH_SIZE, rows, cols, 1) with values in {0, 1}.
    """
    seq = iaa.Sequential([iaa.Sometimes(0.7, iaa.GaussianBlur(sigma=(0, 2.0))),
                          iaa.Sharpen(alpha=(0, 0.1), lightness=(0.7, 1.3)),
                          iaa.ContrastNormalization((0.5, 1.2))],
                         random_order=True)
    rows, cols, channels = input_shape[0], input_shape[1], input_shape[2]
    i = 0
    while True:
        x_batch = np.empty((BATCH_SIZE, rows, cols, channels), dtype=np.uint8)
        y_batch = np.empty((BATCH_SIZE, rows, cols, 1), dtype=np.uint8)
        for ind, j in enumerate(range(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)):
            xb, y0 = _load_sample(images[j], input_shape)
            # Random horizontal flip, applied identically to the image and its target.
            if augment and np.random.random() > 0.5:
                xb = flip_axis(xb, 1)
                y0 = flip_axis(y0, 1)
            x_batch[ind, ...] = xb
            y_batch[ind, ..., 0] = y0
        # Photometric augmentation of the images only; targets are left untouched.
        if augment:
            x_batch = seq.augment_images(x_batch)
        i += 1
        if i >= steps_per_epoch:
            i = 0
        yield x_batch, y_batch

In [None]:
# Validation batches drawn from the held-out split.
# NOTE(review): the generator applies its random flip and blur/sharpen/contrast
# augmentation here too, so the evaluated images are augmented.
val_generator = generator(
    val,
    steps_per_epoch=steps_per_epoch_val,
    BATCH_SIZE=batch_size,
    input_shape=input_shape,
)

### #3 Predictions on the validation set

In [None]:
import scipy
import cv2
from skimage.measure import label
from matplotlib.patches import Rectangle
from skimage.feature import peak_local_max
from skimage.morphology import watershed
from scipy import ndimage

In [None]:
# 3x3 all-ones structuring element used for the erosions below.
kernel = np.ones(shape=(3, 3), dtype=float)

In [None]:
# Post-processing settings for the validation loop.
PRED_THRESHOLD = 0.5  # assumes Ypred holds per-pixel scores in [0, 1] - TODO confirm against get_unet
MIN_BLOB_AREA = 1000  # predicted components with fewer pixels than this are ignored


def count_gt_objects(gt_mask, kernel):
    """Count connected foreground components of a {0, 1} ground-truth mask after erosion.

    The erosion with ``kernel`` presumably aims to split objects joined by thin bridges.
    """
    # return_num counts labels directly; `len(np.unique(...)) - 1` undercounts when
    # there is no background pixel at all.
    _, num = label(cv2.erode(gt_mask, kernel), return_num=True)
    return num


def postprocess_prediction(pred_map, kernel, threshold=PRED_THRESHOLD, min_area=MIN_BLOB_AREA):
    """Binarise one predicted map and keep its large connected components.

    Args:
        pred_map: 2-D array of per-pixel predictions for one image.
        kernel: structuring element used to erode the binary prediction.
        threshold: pixels strictly above this value count as foreground.
        min_area: components with fewer pixels are discarded.

    Returns:
        mask: float array shaped like ``pred_map``; 1 on kept components, 0 elsewhere.
        bboxes: list of ``[xmin, ymin, width, height]`` boxes, one per kept component.
    """
    # Threshold explicitly: skimage's `label` casts a float map to integers, which
    # effectively treated only saturated (1.0) pixels as foreground.
    binary = (pred_map > threshold).astype(np.uint8)
    components = label(cv2.erode(binary, kernel))
    mask = np.zeros(pred_map.shape[:2])
    bboxes = []
    for lab in range(1, components.max() + 1):
        blob = components == lab
        if np.count_nonzero(blob) < min_area:
            continue
        mask += blob
        ys, xs = np.nonzero(blob)
        bboxes.append([xs.min(), ys.min(), xs.max() - xs.min(), ys.max() - ys.min()])
    return mask, bboxes


def plot_prediction(image, gt_mask, pred_map, pp_mask, bboxes, gt_count):
    """Show raw image, ground truth, raw prediction and post-processed prediction side by side."""
    fig, ax = plt.subplots(1, 4, figsize=(20, 10))
    ax[0].imshow(image)
    ax[0].set_title('raw image')
    ax[1].imshow(image)
    ax[1].imshow(gt_mask, alpha=0.3)
    ax[1].set_title('ground truth: {} fish'.format(gt_count))
    ax[2].imshow(image)
    ax[2].imshow(pred_map, alpha=0.3)
    ax[2].set_title('predictions')
    ax[3].imshow(image)
    ax[3].imshow(pp_mask, alpha=0.3)
    ax[3].set_title('pp predictions: {} fish'.format(len(bboxes)))
    for xmin, ymin, width, height in bboxes:
        ax[3].add_patch(Rectangle((xmin, ymin), width, height, linewidth=2,
                                  edgecolor='r', facecolor='none', linestyle='--'))
    plt.show()
    # Release the figure: one is created per validation image.
    plt.close(fig)


ratio = []  # predicted / ground-truth object count, one entry per image that has both
for _ in range(steps_per_epoch_val):
    X, Y = next(val_generator)
    Ypred = model.predict_on_batch(X)
    for i in range(len(X)):
        gt_count = count_gt_objects(Y[i, ...], kernel)
        # `mask` stays a global: the research cells below inspect the last image's mask.
        mask, bboxes = postprocess_prediction(Ypred[i, ..., 0], kernel)
        plot_prediction(X[i, ...], Y[i, ..., 0], Ypred[i, ..., 0], mask, bboxes, gt_count)
        # gt_count > 0 avoids a ZeroDivisionError on images with predictions but no ground truth.
        # NOTE(review): images with ground truth but no predictions are skipped, which biases the mean upward.
        if bboxes and gt_count > 0:
            ratio.append(float(len(bboxes)) / gt_count)

In [None]:
# Number of validation batches the loop above iterated over.
steps_per_epoch_val

In [None]:
# Average predicted/ground-truth count ratio over the images that recorded one.
# An empty list is handled explicitly instead of letting np.mean warn and return nan.
if ratio:
    mean_ratio = np.mean(ratio)
    print('Mean ratio {}'.format(mean_ratio))
else:
    mean_ratio = float('nan')
    print('Mean ratio undefined: no image recorded a ratio')

### #4 Research: distance-transform watershed on the post-processed mask

In [None]:
def merge_close_markers(markers, thresh=50):
    """Zero out marker pixels that have a later marker pixel within ``thresh`` pixels.

    Nonzero pixels are taken in row-major order. For every pair at Euclidean distance
    in ``(0, thresh]``, the earlier pixel is set to 0. ``markers`` is modified in place
    and also returned.

    Args:
        markers: 2-D label image (0 = no marker), e.g. the output of ``ndimage.label``.
        thresh: maximum pixel distance at which two marker pixels count as close.

    Returns:
        The same ``markers`` array.
    """
    # `import scipy` alone does not guarantee that scipy.spatial is loaded.
    from scipy.spatial import distance_matrix

    rows, cols = np.nonzero(markers)
    if rows.size < 2:
        return markers
    points = np.column_stack([rows, cols])
    # Upper triangle only, so each pair is seen once with the earlier pixel as row index.
    dist = np.triu(distance_matrix(points, points))
    # Honour the caller's threshold (a constant 10 used to be hard-coded here).
    earlier = np.nonzero((dist > 0) & (dist <= thresh))[0]
    markers[rows[earlier], cols[earlier]] = 0
    return markers

In [None]:
# Post-processed prediction mask of the last image shown by the validation loop.
plt.imshow(mask)
plt.title('post-processed prediction mask (last validation image)');

In [None]:
# Euclidean distance from every nonzero pixel of `mask` to the nearest zero pixel.
D = ndimage.distance_transform_edt(input=mask)

In [None]:
plt.imshow(D)
plt.title('Euclidean distance transform of the mask');

In [None]:
# Boolean image of distance-map peaks (min_distance=10) inside the mask.
# `indices=False` was removed in scikit-image 0.20, so the boolean image is built from
# the returned coordinates. `labels` is documented as an integer array, hence the cast.
peak_coords = peak_local_max(D, min_distance=10, labels=mask.astype(int))
localMax = np.zeros(D.shape, dtype=bool)
localMax[tuple(peak_coords.T)] = True

In [None]:
plt.imshow(localMax)
plt.title('local maxima of the distance transform');

In [None]:
# One integer id per 8-connected group of local maxima, then drop marker pixels
# that sit close to another one via merge_close_markers.
markers, n_markers = ndimage.label(localMax, structure=np.ones((3, 3)))
markers = merge_close_markers(markers)

In [None]:
plt.imshow(markers)
plt.title('watershed markers after merge_close_markers');

In [None]:
# Flood the negated distance map from the markers, restricted to the nonzero pixels of `mask`.
labels = watershed(-D, markers=markers, mask=mask)

In [None]:
plt.imshow(labels)
plt.title('watershed segmentation');