In [None]:
import os
import glob

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm.notebook import tqdm
import deepblink as pink
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models as sm
import skimage.io
import skimage.util
import tensorflow as tf

In [None]:
# General CPU/GPU setting.
# NOTE(review): numpy and tensorflow are imported in the first cell, so the
# BLAS/OpenMP thread-count variables below may be read too late to have any
# effect; move them before those imports if the limit actually matters.
for thread_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[thread_var] = "10"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to CUDA
os.nice(19)  # raise this process's niceness by 19 (lowest priority)

# Tensorflow CPU setting: cap threads used inside a single op and across ops.
tf.config.threading.set_intra_op_parallelism_threads(4)
tf.config.threading.set_inter_op_parallelism_threads(4)
# NOTE(review): `gpus` is not referenced by any later visible cell.
gpus = tf.config.list_physical_devices("GPU")

# Generate data for labelling

# Data preparation and splitting

In [None]:
# Raw TIFF inputs and their label masks; pairs are matched by sorted filename.
RAW_DIR = "../raw/jess_flies"
images = sorted(glob.glob(f"{RAW_DIR}/images_hr38_seg/*.tif"))
labels = sorted(glob.glob(f"{RAW_DIR}/labels_hr38_seg/*.tif"))

In [None]:
# Load every image/label pair into memory. Pairs are matched purely by sorted
# file order, so fail loudly when the two folders are out of sync instead of
# letting zip() silently drop the unmatched tail.
assert images, "no images found"
assert len(images) == len(labels), f"{len(images)} images vs {len(labels)} labels"

data_x = []
data_y = []
for img_path, label_path in tqdm(zip(images, labels), total=len(images)):
    assert pink.io.basename(img_path) == pink.io.basename(label_path), (
        f"unmatched pair: {img_path} / {label_path}"
    )
    data_x.append(skimage.io.imread(img_path))
    data_y.append(skimage.io.imread(label_path))

# np.stack raises a clear error if the images do not all share one shape
# (np.array could otherwise produce a ragged object array).
data_x = np.stack(data_x)
data_y = np.stack(data_y)

In [None]:
# Reproducible shuffle: seed NumPy's legacy global RNG, then draw a random
# permutation of sample indices (same draws as arange + shuffle).
np.random.seed(42)
index = np.random.permutation(len(data_x))
data_x, data_y = data_x[index], data_y[index]

# Hold out the first `split` fraction of the shuffled samples for validation.
split = 0.2
n_split = int(len(data_x) * split)
x_valid, x_train = data_x[:n_split], data_x[n_split:]
y_valid, y_train = data_y[:n_split], data_y[n_split:]

In [None]:
# Save dataset: persist the shuffled train/validation split so later runs can
# start from the "Data loading" section. np.savez_compressed does not create
# missing parent directories, so ensure the target folder exists first.
dataset_path = "../datasets/jess_flies/20210922_seg_hr38_kay_large.npz"
os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
np.savez_compressed(
    dataset_path,
    x_train=x_train,
    y_train=y_train,
    x_valid=x_valid,
    y_valid=y_valid,
)

# Data loading

In [None]:
# Load the saved split. np.load on an .npz returns an NpzFile that keeps the
# archive open; the context manager closes it once the four arrays have been
# read into memory (indexing an NpzFile returns a regular ndarray).
with np.load("../datasets/jess_flies/20210922_seg_hr38_kay_large.npz") as data:
    x_train = data["x_train"]
    y_train = data["y_train"]
    x_valid = data["x_valid"]
    y_valid = data["y_valid"]

In [None]:
# Add axes to mimic the RGB input used by pretrained models: the image array
# (presumably single-channel) is replicated into a new trailing axis of size 3,
# and each mask gets a trailing singleton channel axis.
x_train = np.repeat(x_train[..., np.newaxis], 3, axis=-1)
x_valid = np.repeat(x_valid[..., np.newaxis], 3, axis=-1)
y_train = y_train[..., np.newaxis]
y_valid = y_valid[..., np.newaxis]

In [None]:
# Encoder backbone for the U-Net; apply the input preprocessing that
# segmentation_models associates with this backbone (presumably the scaling its
# pretrained weights expect).
BACKBONE = "inceptionv3"
preprocess_input = sm.get_preprocessing(BACKBONE)
x_train, x_valid = preprocess_input(x_train), preprocess_input(x_valid)

In [None]:
# Sanity check: one random training pair (top row) and one random validation
# pair (bottom row) — image channel 0 on the left, label on the right.
idx1 = np.random.randint(0, len(x_train))
idx2 = np.random.randint(0, len(x_valid))

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
rows = ((x_train, y_train, idx1), (x_valid, y_valid, idx2))
for row, (imgs, masks, sample) in enumerate(rows):
    ax[row, 0].imshow(imgs[sample, ..., 0])
    ax[row, 1].imshow(masks[sample].squeeze())
plt.tight_layout()
plt.show()

In [None]:
seed = 42
batch_size = 4

# Random geometric augmentation shared by images and masks. The image and mask
# generators use identical settings and are flowed with the same seed, so the
# Keras iterators draw the same transforms for each image/mask batch.
# ImageDataGenerator.fit() is intentionally not called: it only computes
# statistics for featurewise_center / featurewise_std_normalization /
# zca_whitening, none of which are enabled here.
augmentation = dict(
    rotation_range=90,
    width_shift_range=0.3,
    height_shift_range=0.3,
    shear_range=0.5,
    zoom_range=0.3,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode="reflect",
)


def standardize_image(x):
    """Scale one image to zero mean / unit std; a constant image is only centred.

    The original lambda divided by ``x.std()`` unconditionally, which yields NaN
    for constant (e.g. empty / padded) images.
    """
    centred = x - x.mean()
    std = x.std()
    return centred / std if std > 0 else centred


def binarize_mask(x):
    """Map every non-zero label value to 1, keeping the input dtype."""
    return np.where(x > 0, 1, 0).astype(x.dtype)


# Datagenerator for images: augmented for training; validation images are only
# standardized so that validation metrics reflect the real, un-augmented data.
image_data_generator = ImageDataGenerator(
    **augmentation, preprocessing_function=standardize_image
)
valid_image_data_generator = ImageDataGenerator(
    preprocessing_function=standardize_image
)
train_img_generator = image_data_generator.flow(
    x_train, seed=seed, batch_size=batch_size
)
valid_img_generator = valid_image_data_generator.flow(
    x_valid, seed=seed, batch_size=batch_size
)

# Datagenerator for masks: nearest-neighbour interpolation (order 0) keeps
# augmented masks free of interpolated edge values, which the > 0 threshold in
# binarize_mask would otherwise turn into foreground (dilating every object).
mask_data_generator = ImageDataGenerator(
    **augmentation, interpolation_order=0, preprocessing_function=binarize_mask
)
valid_mask_data_generator = ImageDataGenerator(
    preprocessing_function=binarize_mask
)
train_mask_generator = mask_data_generator.flow(
    y_train, seed=seed, batch_size=batch_size
)
valid_mask_generator = valid_mask_data_generator.flow(
    y_valid, seed=seed, batch_size=batch_size
)

In [None]:
def my_data_generator(image_generator, mask_generator):
    """Yield ``(image_batch, mask_batch)`` tuples by stepping both generators in lockstep.

    Stops as soon as either input is exhausted.
    """
    yield from zip(image_generator, mask_generator)


# Pair each image stream with its mask stream for model.fit.
train_generator, valid_generator = (
    my_data_generator(train_img_generator, train_mask_generator),
    my_data_generator(valid_img_generator, valid_mask_generator),
)

In [None]:
# Sanity check with data generator: pull one training batch (note this consumes
# it from train_generator) and show a random sample's image channel 0 next to
# its mask to confirm the two streams stay aligned.
image, mask = next(train_generator)
# Index within the actual batch: a batch can hold fewer than batch_size samples
# (e.g. when the dataset is smaller than one batch).
idx = np.random.randint(0, len(image))

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image[idx, ..., 0])
ax[0].set_title("Generated image (channel 0)")
ax[1].imshow(mask[idx, ..., 0])
ax[1].set_title("Generated mask")
plt.show()

# Model definition and training

In [None]:
# Define model: a U-Net on BACKBONE with a single sigmoid output channel,
# trained with segmentation_models' BCE + Jaccard loss.
tf.keras.backend.clear_session()  # drop state left over from earlier runs of this cell
model = sm.Unet(BACKBONE, classes=1, activation="sigmoid")

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss=sm.losses.bce_jaccard_loss,
    metrics=[sm.metrics.iou_score, "accuracy"],
)
model.summary()

In [None]:
# Fit the model for a fixed number of generator steps per epoch (the Keras
# iterators never end on their own). ModelCheckpoint with save_best_only writes
# the weights with the best monitored validation metric to disk.
# NOTE(review): `model` in memory keeps the final-epoch weights, not the best
# checkpoint; reload the .h5 file if later cells should use the best one.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "./20210922_inceptionv3_aug.h5", save_best_only=True
)
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=50,
    steps_per_epoch=32,
    validation_steps=32,
    callbacks=[checkpoint],
)

In [None]:
# Plot training vs. validation loss (left) and IoU (right) for every epoch.
metrics_log = history.history
epochs = range(1, len(metrics_log["loss"]) + 1)

# (train key, val key, train label, val label, title, y-axis label) per panel.
panels = (
    ("loss", "val_loss", "Training loss", "Validation loss",
     "Training and validation loss", "Loss"),
    ("iou_score", "val_iou_score", "Training IOU", "Validation IOU",
     "Training and validation IOU", "IOU"),
)

fig, ax = plt.subplots(1, 2, figsize=(15, 7))
for axis, (train_key, val_key, train_label, val_label, title, ylabel) in zip(ax, panels):
    axis.plot(epochs, metrics_log[train_key], "y", label=train_label)
    axis.plot(epochs, metrics_log[val_key], "r", label=val_label)
    axis.set_title(title)
    axis.set_xlabel("Epochs")
    axis.set_ylabel(ylabel)
    axis.legend()
plt.show()

In [None]:
# View example predictions on one validation batch: image channel 0, ground
# truth mask and model output side by side.
# NOTE: `model` holds the final-epoch weights, which are not necessarily those
# of the best checkpoint written during fitting.
image, mask = next(valid_generator)
n_show = min(4, len(image))  # a batch may hold fewer than 4 samples
predictions = model.predict(image[:n_show])  # one batched call instead of one per sample

fig, ax = plt.subplots(n_show, 3, figsize=(16, 5 * n_show), squeeze=False)
for i in range(n_show):
    panels = (
        ("Validation Image", image[i, ..., 0]),
        ("Validation Label", mask[i, ..., 0]),
        ("Prediction on validation image", predictions[i].squeeze()),
    )
    for axis, (title, panel) in zip(ax[i], panels):
        axis.set_title(title)
        axis.imshow(panel)
        axis.set_axis_off()
plt.tight_layout()
plt.show()