https://github.com/karolzak/keras-unet/blob/master/notebooks/kz-whale-tails.ipynb

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import glob
import os
import sys
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import ModelCheckpoint


In [None]:
# Root of the tile set. Band tiles are looked up in its "B*" sub-directories
# and truth masks in its "truth" sub-directory (see read_fname).
DATA_DIR = Path("./data/tiles/ls78")

# Divisors used when rescaling pixel values: MAX_X for the band tiles and
# MAX_Y for the truth masks (see read_fname / preprocess_transform).
MAX_X = 255
MAX_Y = 65535  # TODO: Tilegen outputs rgb image, not grayscale

# Sorted so the channel order of the stacked bands is the same on every run.
BAND_DIRS = sorted(DATA_DIR.glob("B*"))
TRUTH_DIR = DATA_DIR / "truth"

display(BAND_DIRS)

In [None]:
# Path.glob yields entries in arbitrary order. Sorting keeps the sample order,
# and therefore the seeded train/validation split made from it, reproducible.
truth_paths = sorted(TRUTH_DIR.glob("*.png"))

In [None]:
def read_fname(fname):
    """Load one tile as a ``(mchannel, truth)`` pair of uint8 arrays.

    ``fname`` is opened from ``TRUTH_DIR`` and from every directory listed in
    ``BAND_DIRS``.

    Returns:
        mchannel: the per-band images joined with ``np.dstack`` in
            ``BAND_DIRS`` order, cast to uint8.
        truth: the truth image divided by ``MAX_Y / MAX_X``, cast to uint8.
    """
    # Divide by MAX_Y but multiply by MAX_X, because preprocess_transform
    # divides by MAX_X again later on.
    truth_scale = MAX_Y / MAX_X
    truth_img = np.array(Image.open(TRUTH_DIR / fname))
    truth = (truth_img / truth_scale).astype('uint8')

    band_imgs = []
    for band_dir in BAND_DIRS:
        band_imgs.append(np.array(Image.open(band_dir / fname)))

    mchannel = np.dstack(np.asarray(band_imgs)).astype('uint8')
    return mchannel, truth

def preprocess_transform(ds):
    """Divide ``ds`` by ``MAX_X`` and return the result as float32."""
    scaled = ds / MAX_X
    return scaled.astype('float32')

def rgb_transform(ds):
    """Pick channels 1-3 of a 4-D array and reverse their order.

    The last axis of ``ds`` is the channel axis. The result holds channels
    3, 2, 1 (in that order) of the input.
    """
    # Bands 1, 2, 3, 4, 5, 7 -> rgb
    selected = ds[:, :, :, 1:4]
    return selected[:, :, :, ::-1]

In [None]:
# Sanity check: load the first tile and show its file name and array shapes.
first_name = truth_paths[0].name
display(first_name)

x0, y0 = read_fname(first_name)

for sample in (x0, y0):
    display(sample.shape)


In [None]:
# Read every tile up front, then split the (image, mask) pairs into two arrays.
samples = [read_fname(truth_path.name) for truth_path in truth_paths]

img_np = np.asarray([image for image, _ in samples])
msk_np = np.asarray([mask for _, mask in samples])

# The intermediate list of pairs is no longer needed once the arrays exist.
del samples

# Give the masks an explicit trailing channel axis of length 1.
n_tiles, height, width = msk_np.shape[0], msk_np.shape[1], msk_np.shape[2]
msk_np = msk_np.reshape(n_tiles, height, width, 1)

display(f'{img_np.shape=}')
display(f'{msk_np.shape=}')

In [None]:
from keras_unet.utils import plot_imgs

# Value ranges of the raw arrays, shown in the order:
# image min, image max, mask min, mask max.
for arr in (img_np, msk_np):
    display(arr.min())
    display(arr.max())

# The division by 255 is for plotting only; img_np and msk_np are not modified.
plot_imgs(
    org_imgs=rgb_transform(img_np) / 255,
    mask_imgs=msk_np / 255,
    nm_img_to_plot=5,
    figsize=6,
)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 30% of the tiles for validation; the fixed random_state keeps the
# split repeatable for a given sample order.
x_train, x_val, y_train, y_val = train_test_split(
    img_np,
    msk_np,
    test_size=0.30,
    random_state=0,
)

for label, split in (
    ("x_train", x_train),
    ("y_train", y_train),
    ("x_val", x_val),
    ("y_val", y_val),
):
    print(f"{label}: ", split.shape)

# Shape of a single training sample (the leading batch axis is dropped).
input_shape = x_train[0].shape
display(f'{input_shape=}')

In [None]:
# Scale the validation arrays with the same transform the training generator
# applies, then wrap them as TensorFlow tensors.
# NOTE: x_val / y_val are overwritten, so running this cell a second time
# without re-running the split would rescale them twice.
x_val, y_val = (preprocess_transform(arr) for arr in (x_val, y_val))

x_val_tf, y_val_tf = map(tf.convert_to_tensor, (x_val, y_val))

In [None]:
from keras_unet.utils import get_augmented

# Augmentation settings handed to get_augmented as data_gen_args.
# brightness_range is left out on purpose: it only works on 1 or 3 channel
# images.
# NOTE(review): if get_augmented applies these settings to the masks as well
# as the images, channel_shift_range would also perturb mask values — confirm.
augmentation_args = {
    'preprocessing_function': preprocess_transform,
    'rotation_range': 360.,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 40,
    'zoom_range': 0.2,
    'channel_shift_range': 10.0,
    'horizontal_flip': True,
    'vertical_flip': True,
    'fill_mode': 'reflect',
}

train_gen = get_augmented(
    x_train,
    y_train,
    batch_size=2,
    data_gen_args=augmentation_args,
)

# Drop the notebook-level names for the training arrays.
del x_train
del y_train

In [None]:
from keras_unet.utils import plot_imgs

# Pull a single batch from the generator to inspect what it produces.
sample_batch = next(train_gen)
xx, yy = sample_batch
print(xx.shape, yy.shape)

plot_imgs(
    org_imgs=rgb_transform(xx),
    mask_imgs=yy,
    nm_img_to_plot=2,
    figsize=6,
)

In [None]:
# https://towardsdatascience.com/biomedical-image-segmentation-unet-991d075a3a4b

class conv_block_nested(nn.Module):
    """Two 3x3 convolutions, each followed by batch normalisation and ReLU.

    Args:
        in_ch: number of channels of the input tensor.
        mid_ch: number of channels produced by the first convolution.
        out_ch: number of channels produced by the second convolution.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=True)

        # A single ReLU module is shared by both stages.
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU twice and return the result."""
        hidden = self.activation(self.bn1(self.conv1(x)))
        return self.activation(self.bn2(self.conv2(hidden)))

class Nested_UNet(nn.Module):
    """Nested U-Net with dense skip connections, built from conv_block_nested.

    The encoder column (``conv*_0``) halves the resolution four times with
    max pooling. Every other node ``conv{i}_{j}`` receives all earlier nodes
    of its own level concatenated with the bilinearly upsampled node from the
    level below.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels of the output. The final 1x1 convolution
            is returned as-is (no activation is applied).
        n1: channel width of the first level; each deeper level doubles it,
            giving ``[n1, 2*n1, 4*n1, 8*n1, 16*n1]``. Defaults to 512, the
            value that used to be hard-coded.
            NOTE(review): with the default, the deepest block is 8192
            channels wide, which makes the model extremely large — confirm
            this is intended or pass a smaller ``n1``.

    The input height and width should be divisible by 16, otherwise the
    pooled and upsampled feature maps may not line up for concatenation.
    """

    def __init__(self, in_ch=3, out_ch=1, n1=512):
        super(Nested_UNet, self).__init__()

        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column.
        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])

        # Nested columns: node (i, j) takes j same-level feature maps plus
        # one upsampled map from level i + 1.
        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])

        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])

        # 1x1 convolution mapping the last full-resolution node to out_ch.
        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)

    def forward(self, x):
        """Run the network on a channels-first batch and return ``final``'s output."""

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        output = self.final(x0_4)
        return output

In [None]:
import torch.optim as optim
from itertools import islice

# The generator is consumed with steps_per_epoch by model.fit further down,
# i.e. it is expected to yield batches indefinitely. Each epoch is therefore
# capped to a fixed number of batches; otherwise the first epoch never ends.
STEPS_PER_EPOCH = 2000
LOG_EVERY = 2000  # report the mean loss every LOG_EVERY mini-batches

net = Nested_UNet(in_ch=6)

# The network emits a single logit per pixel (out_ch defaults to 1) and the
# masks are floating-point values, so this is a binary problem:
# BCEWithLogitsLoss applies the sigmoid itself and expects targets with the
# same shape as the outputs.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(islice(train_gen, STEPS_PER_EPOCH)):
        # The generator yields channels-last numpy batches (N, H, W, C);
        # torch convolutions expect channels-first (N, C, H, W). The masks
        # need the same move so they match the shape of the network output.
        inputs = torch.from_numpy(np.moveaxis(inputs, 3, 1)).float()
        labels = torch.from_numpy(np.moveaxis(labels, 3, 1)).float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')

In [None]:
# NOTE(review): no cell visible in this notebook assigns `model`. The Keras
# model must be built before this cell, or it fails on a fresh kernel — confirm.
# summary() prints the table itself and returns None, so it is called directly
# rather than wrapped in display().
model.summary()

In [None]:
# File the checkpoint callback writes to; the same name is used later to
# reload the saved weights.
model_filename = 'model_trainchpt.h5'

# Save only when the monitored validation loss improves.
checkpoint_options = dict(
    verbose=1,
    monitor='val_loss',
    save_best_only=True,
)
callback_checkpoint = ModelCheckpoint(model_filename, **checkpoint_options)


In [None]:
from tensorflow.keras.optimizers import Adam, SGD
from keras_unet.metrics import iou, iou_thresholded
from keras_unet.losses import jaccard_distance

def closs(y_true, y_pred):
    """Combined loss: three times the Jaccard distance plus binary cross-entropy."""
    jaccard_term = 3 * jaccard_distance(y_true, y_pred)
    bce_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return jaccard_term + bce_term

# Adam with ten times the 0.001 base learning rate.
adam_optimizer = Adam(learning_rate=0.001 * 10)

# Loss alternatives that were tried here: jaccard_distance and closs.
model.compile(
    optimizer=adam_optimizer,
    loss='binary_crossentropy',
    metrics=[iou, iou_thresholded],
)


In [None]:
# Training schedule and validation settings for the Keras fit call. The
# checkpoint callback runs during training.
fit_options = dict(
    steps_per_epoch=1000,
    epochs=500,
    validation_data=(x_val_tf, y_val_tf),
    validation_batch_size=2,
    callbacks=[callback_checkpoint],
)

history = model.fit(train_gen, **fit_options)


In [None]:
from keras_unet.utils import plot_segm_history

# Visualise the training history object returned by model.fit.
plot_segm_history(history)

In [None]:
# Reload the weights from the file written by callback_checkpoint, which was
# configured with save_best_only=True on val_loss.
model.load_weights(model_filename)

In [None]:
# Run the model on the validation images (already passed through
# preprocess_transform), two samples per batch.
y_pred = model.predict(x_val, batch_size=2)

In [None]:
from keras_unet.utils import plot_imgs

# Show how many validation samples there are, then plot the first ones with
# their truth masks and the predicted masks.
display(len(x_val))
plot_imgs(org_imgs=rgb_transform(x_val), mask_imgs=y_val, pred_imgs=y_pred, nm_img_to_plot=10, figsize=4)