In [None]:
import os
import time
from glob import glob

import cv2
import dask as d
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import test, device
# from tensorflow import keras 
from tensorflow.keras import backend, Input, Model, layers
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.utils import Sequence

In [None]:
from fl_tissue_model_tools import defs

In [None]:
# Root folder of the dataset; the cells below read "<root>/images/*.jpg" and
# "<root>/annotations/trimaps/*.png". Set the OXFORD_PETS_DATA_ROOT environment
# variable to use another copy instead of editing this cell. The fallback is the
# location that used to be hard-coded here.
data_root_path = os.environ.get("OXFORD_PETS_DATA_ROOT", "D:/oxford_pets_data") # Carson
# data_root_path = "./" # Mitchell

In [None]:
# Collect the JPEG images and the trimap PNGs. Back-slashes returned by glob on
# Windows are rewritten to "/" and each list is sorted, so the two lists can be
# compared index by index further down.
img_paths = [fn.replace("\\", "/") for fn in glob(f"{data_root_path}/images/*.jpg")]
img_paths.sort()
# The "[!._]" prefix leaves out files whose name starts with "." or "_".
label_paths = [fn.replace("\\", "/") for fn in glob(f"{data_root_path}/annotations/trimaps/[!._]*.png")]
label_paths.sort()

# Helper functions

In [None]:
def map2bin(lab, fg_vals, bg_vals, fg=1, bg=0):
    """Collapse a multi-valued label array into foreground / background values.

    Args:
        lab: NumPy array of label values. It is not modified.
        fg_vals: Label values that are rewritten to `fg`.
        bg_vals: Label values that are rewritten to `bg`.
        fg: Replacement written where `lab` is in `fg_vals`.
        bg: Replacement written where `lab` is in `bg_vals`.

    Returns:
        A copy of `lab` with the replacements applied. Entries that match
        neither list keep their value; entries that match both end up as `bg`.
    """
    binary = lab.copy()
    # Membership is always tested against the untouched input, so writing `fg`
    # cannot change which entries are treated as background afterwards.
    for targets, replacement in ((fg_vals, fg), (bg_vals, bg)):
        binary[np.isin(lab, targets)] = replacement
    return binary


def augment(img, rot, hflip, vflip, expand_dims=True):
    """Apply optional flips followed by a rotation to a single image.

    Args:
        img: Image array whose first two axes are (height, width).
        rot: Rotation angle in degrees (OpenCV convention: positive values
            rotate counter-clockwise).
        hflip: If truthy, mirror the image left-right.
        vflip: If truthy, mirror the image top-bottom.
        expand_dims: If True, append a new axis at position 2 to the
            transformed image. The rotation is applied either way.

    Returns:
        The transformed image with the same height and width as the input.
        NOTE(review): OpenCV appears to return single-channel results as 2-D
        arrays, which is why `expand_dims=True` is used for (H, W, 1) inputs
        -- confirm before passing multi-channel images.
    """
    height, width = img.shape[:2]
    # Horizontal flip
    if hflip:
        img = cv2.flip(img, 1)
    # Vertical flip
    if vflip:
        img = cv2.flip(img, 0)
    # Rotation. Both the centre and the output size are given to OpenCV in
    # (x, y) / (width, height) order, which differs from NumPy's (rows, cols).
    # NOTE(review): (width // 2, height // 2) is half a pixel off the geometric
    # centre for even sizes, so right-angle rotations shift the content by one
    # pixel -- consider ((width - 1) / 2, (height - 1) / 2).
    # NOTE(review): warpAffine interpolates linearly by default; for label
    # masks and angles that are not multiples of 90 this may blend labels.
    rot_mat = cv2.getRotationMatrix2D((width // 2, height // 2), rot, 1.0)
    img = cv2.warpAffine(img, rot_mat, (width, height))

    if expand_dims:
        img = np.expand_dims(img, 2)

    return img

# Validate images match labels in order & count

In [None]:
def get_img_id(img_path):
    """Return the file name of `img_path` up to (not including) its first ".".

    The path is split on "/" only, e.g. "a/b/cat_1.jpg" -> "cat_1".
    """
    file_name = img_path.rsplit("/", 1)[-1]
    stem, _, _ = file_name.partition(".")
    return stem

In [None]:
# Every image must be paired with the label of the same id at the same index.
# The counts are compared first so that a missing file is reported as such
# instead of surfacing as an IndexError in the pairwise comparison.
assert len(img_paths) == len(label_paths), (
    f"{len(img_paths)} images but {len(label_paths)} labels"
)
mismatched_pairs = [
    (img_p, lab_p)
    for img_p, lab_p in zip(img_paths, label_paths)
    if get_img_id(img_p) != get_img_id(lab_p)
]
assert not mismatched_pairs, f"first image/label mismatch: {mismatched_pairs[0]}"

In [None]:
# The two lists must have the same number of entries; show that number.
assert len(label_paths) == len(img_paths)
print(len(img_paths))

# Constants

In [None]:
# --- Reproducibility ---
rand_seed = 12345
# Seeded generator that is handed to the shuffle and to the data generators.
rs = np.random.RandomState(seed=rand_seed)

# --- Data loading ---
batch_size = 32
# Passed to load_img as target_size.
img_size = (128, 128)
# img_size = (160, 160)

# --- Labels / model output ---
# num_classes = 2
n_outputs = 1
# For collapsing mask into binary range (see map2bin)
fg_vals, bg_vals = [1, 3], [2]

# --- Checkpointing ---
cp_filepath = "oxford_pets_segmentation_best_weights.h5"

# Examine data

In [None]:
# Index into img_paths / label_paths of the sample shown in the preview cells below.
preview_idx = 10

In [None]:
# Preview one input image: resized to `img_size`, converted to grayscale and
# given a trailing channel axis.
img = load_img(
    img_paths[preview_idx],
    target_size=img_size,
    color_mode="grayscale",
    interpolation="lanczos",
)
img = np.array(img)[:, :, np.newaxis]
plt.imshow(img, cmap="gray")
plt.show()

In [None]:
# Preview the matching label: nearest-neighbour resizing keeps the original
# label values, which map2bin then collapses to foreground / background.
mask = load_img(
    label_paths[preview_idx],
    target_size=img_size,
    color_mode="grayscale",
    interpolation="nearest",
)
mask = np.array(mask)[:, :, np.newaxis]
mask = map2bin(mask, fg_vals, bg_vals)
plt.imshow(mask, cmap="gray")
plt.show()

In [None]:
# Distinct values present in the preview mask after map2bin.
np.unique(mask)

# Data pipeline

In [None]:
class OxfordPetsSequence(Sequence):
    """Keras `Sequence` that yields batches of (grayscale image, binary mask).

    Pass `augmentation_function=None` (the default) to serve the data without
    augmentation, e.g. for validation or test data.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths, random_state, fg_vals, bg_vals, augmentation_function=None):
        """Store the configuration; no image is read until a batch is requested.

        Args:
            batch_size: Number of samples per batch.
            img_size: Target size handed to `load_img` for every image.
            input_img_paths: Paths of the input images.
            target_img_paths: Paths of the label images, in the same order.
            random_state: `np.random.RandomState` used to draw the per-sample
                augmentation parameters.
            fg_vals: Label values that `map2bin` rewrites to foreground.
            bg_vals: Label values that `map2bin` rewrites to background.
            augmentation_function: Callable `(img, rot, hflip, vflip)` applied
                to every image and mask, or None for no augmentation.

        Raises:
            ValueError: If the two path lists differ in length.
        """
        if len(input_img_paths) != len(target_img_paths):
            raise ValueError(
                f"got {len(input_img_paths)} input images but {len(target_img_paths)} target images"
            )
        self.batch_size = batch_size
        # Normalised to a tuple because it is concatenated with other tuples
        # when the batch arrays are allocated.
        self.img_size = tuple(img_size)
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths
        self.rs: np.random.RandomState = random_state
        self.fg_vals = fg_vals
        self.bg_vals = bg_vals
        self.augmentation_function = augmentation_function

    def __len__(self):
        """Number of batches, counting a final partial batch.

        Rounding up (instead of down) makes sure the last
        `len(paths) % batch_size` samples are served as well, so that e.g.
        `model.predict` returns one prediction per sample.
        """
        return (len(self.target_img_paths) + self.batch_size - 1) // self.batch_size

    def _load_inputs(self, paths):
        """Load `paths` as grayscale images into a float32 array (n, H, W, 1).

        Pixel values are stored as loaded; no rescaling is applied here.
        """
        x = np.zeros((len(paths),) + self.img_size + (1,), dtype=np.float32)
        for j, path in enumerate(paths):
            # Ensure best quality downsampling (interpolation methods overview: https://stackoverflow.com/a/44083113)
            img = load_img(path, target_size=self.img_size, color_mode="grayscale", interpolation="lanczos")
            x[j] = np.expand_dims(img, 2)  # add the channel axis
        return x

    def _load_targets(self, paths):
        """Load `paths` as label masks into a uint8 array (n, H, W, 1)."""
        y = np.zeros((len(paths),) + self.img_size + (1,), dtype=np.uint8)
        for j, path in enumerate(paths):
            # Use interpolation="nearest" so resizing cannot invent label values
            img = load_img(path, target_size=self.img_size, color_mode="grayscale", interpolation="nearest")
            img = np.expand_dims(img, 2)  # add the channel axis
            # Collapse the label values to foreground / background
            y[j] = map2bin(img, self.fg_vals, self.bg_vals)
        return y

    def __getitem__(self, idx):
        """Returns the batch (input, target) at index `idx`"""
        # Image index, offset by batch
        start = idx * self.batch_size
        stop = start + self.batch_size

        batch_input_img_paths = self.input_img_paths[start:stop]
        batch_target_img_paths = self.target_img_paths[start:stop]

        # Both loaders are submitted in one dask call so they may run concurrently.
        x, y = d.compute((
            d.delayed(self._load_inputs)(batch_input_img_paths),
            d.delayed(self._load_targets)(batch_target_img_paths),
        ))[0]

        if self.augmentation_function is not None:
            m = len(x)
            # Cannot parallelize (random state ensures reproducibility)
            rots = self.rs.choice([0, 90, 180, 270], size=m)
            hflips = self.rs.choice([True, False], size=m)
            vflips = self.rs.choice([True, False], size=m)

            def aug_imgs(imgs):
                # Sample i of the images and of the masks receives the same
                # parameters, which keeps each mask aligned with its image.
                return np.array([self.augmentation_function(imgs[i], rots[i], hflips[i], vflips[i]) for i in range(m)])

            x, y = d.compute((d.delayed(aug_imgs)(x), d.delayed(aug_imgs)(y)))[0]

        return x, y

# Data generator demo

In [None]:
# The demo gets its own generator seeded from `rand_seed`. Drawing from the
# shared `rs` here would make the later shuffle / augmentation stream depend on
# whether (and how often) this cell has been run.
demo_rs = np.random.RandomState(rand_seed)
pets_demo = OxfordPetsSequence(batch_size, img_size, img_paths, label_paths, demo_rs, fg_vals, bg_vals, augment)
# Time how long it takes to produce one augmented batch.
start = time.time()
X, y = pets_demo[1]
stop = time.time()
print(stop - start)
# Show the second sample of the batch next to its mask.
plt.imshow(X[1][:,:,0], cmap='gray')
plt.show()
plt.imshow(y[1][:,:,0], cmap='gray')
plt.show()

# Build model

In [None]:
def get_oxford_pets_model(img_size, num_classes):
    """Build an encoder/decoder segmentation network for single-channel images.

    Args:
        img_size: (height, width) of the input; a channel axis of size 1 is
            appended. Resolution is halved four times and doubled four times,
            so both dimensions should be multiples of 16 for the residual
            additions to line up.
        num_classes: Number of output channels. Every channel goes through a
            sigmoid, so with `num_classes=1` the output is a per-pixel
            foreground probability.

    Returns:
        An uncompiled Keras `Model`.
    """
    inputs = Input(shape=tuple(img_size) + (1,))

    ### Downsampling the inputs ###

    # compute an initial set of 32 features using convolutional layers
    # also, downsample the image using strided convolutions
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    # Kept aside so it can be added back as a residual
    previous_block_activation = x

    # hidden layers using separable convolutions
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Strided 1x1 convolution brings the residual to the new resolution and width
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(previous_block_activation)

        x = layers.add([x, residual])

        previous_block_activation = x

    ### upsampling ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Upsample the residual, then match its channel count with a 1x1 convolution
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])
        previous_block_activation = x

    # add a per-pixel classification layer; the width comes from the
    # `num_classes` argument rather than from a notebook-level global
    outputs = layers.Conv2D(num_classes, 3, activation="sigmoid", padding="same")(x)

    # define the model
    model = Model(inputs, outputs)

    return model

# Set up data generators

In [None]:
# Validation and test each get 20% of the samples (rounded down).
n_val = n_test = int(len(img_paths) * 0.2)
# One index per sample, shuffled in place with the seeded generator.
data_idx = np.array(range(len(img_paths)))
rs.shuffle(data_idx)

In [None]:
# Apply the shuffled order before splitting. Slicing the name-sorted lists
# directly would ignore `data_idx` and put whole alphabetical ranges of files
# into a single split.
shuffled_img_paths = [img_paths[i] for i in data_idx]
shuffled_label_paths = [label_paths[i] for i in data_idx]

# Positive slice bounds stay correct when n_val or n_test is 0, whereas a
# "-0" bound would select everything or nothing.
n_train = len(shuffled_img_paths) - n_val - n_test
val_end = n_train + n_val

train_img_paths = shuffled_img_paths[:n_train]
train_label_paths = shuffled_label_paths[:n_train]

val_img_paths = shuffled_img_paths[n_train:val_end]
val_label_paths = shuffled_label_paths[n_train:val_end]

test_img_paths = shuffled_img_paths[val_end:]
test_label_paths = shuffled_label_paths[val_end:]

In [None]:
# Only the training data is augmented. Validation and test batches are served
# unmodified, so the validation loss is computed on the same inputs every epoch
# and does not consume random numbers from the generator shared with training.
train_gen = OxfordPetsSequence(batch_size, img_size, train_img_paths, train_label_paths, rs, fg_vals, bg_vals, augment)
val_gen = OxfordPetsSequence(batch_size, img_size, val_img_paths, val_label_paths, rs, fg_vals, bg_vals)
test_gen = OxfordPetsSequence(batch_size, img_size, test_img_paths, test_label_paths, rs, fg_vals, bg_vals)

In [None]:
### create the model ###
# Reset Keras' global state first so that re-running this cell starts clean.
backend.clear_session()
model = get_oxford_pets_model(img_size=img_size, num_classes=n_outputs)
model.summary()

In [None]:
# `device(...)` only has an effect when entered as a `with` context manager, so
# the bare call that used to be here did nothing. TensorFlow places ops on a
# visible GPU by default; report what it found so that a CPU fallback is
# noticed before training starts.
gpu_name = test.gpu_device_name()
print(f"GPU device: {gpu_name}" if gpu_name else "No GPU detected - training will run on the CPU")

In [None]:
# Name of the GPU device reported by TensorFlow (shown as the cell output).
test.gpu_device_name()

In [None]:
### train the model ###
num_epochs = 50
# Write weights only, and only when the monitored quantity improves.
callbacks = [ModelCheckpoint(filepath=cp_filepath, save_weights_only=True, save_best_only=True)]
# Single sigmoid output -> binary cross-entropy.
model.compile(loss="binary_crossentropy", optimizer="rmsprop")
h = model.fit(train_gen, epochs=num_epochs, validation_data=val_gen, callbacks=callbacks)

# Load best weights

In [None]:
# Restore the weights written to cp_filepath by the ModelCheckpoint callback.
model.load_weights(cp_filepath)

In [None]:
# Generate predictions for the batches served by test_gen
# (test_gen is built without an augmentation function).
test_preds = model.predict(test_gen)

In [None]:
# Test sample to inspect.
pred_idx = 934
# Binarise a copy of its prediction at 0.5: values below become 0, every
# remaining positive value becomes 1.
pred = np.array(test_preds[pred_idx], copy=True)
pred[pred < 0.5] = 0
pred[pred > 0] = 1
# Despite its name, `true` holds the grayscale *input image* of that sample
# (loaded from test_img_paths), not its label mask.
true = load_img(
    test_img_paths[pred_idx],
    target_size=img_size,
    color_mode="grayscale",
    interpolation="lanczos",
)
true = np.expand_dims(np.array(true), 2)

In [None]:
# Show the input image and then the binarised prediction, both on the same
# gray scale; the 0/1 prediction is stretched to defs.GS_MAX so that
# foreground is drawn at full intensity.
for panel in (true, pred * defs.GS_MAX):
    plt.imshow(panel, cmap="gray", vmin=defs.GS_MIN, vmax=defs.GS_MAX)
    plt.show()