# Install Dependencies and Import

In [None]:
import gdown
import zipfile
import os

from PIL import ImageOps
from PIL import Image as im

from IPython.display import Image, display
from IPython.display import clear_output

import matplotlib.pyplot as plt

import numpy as np
import random

from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.utils import plot_model
import tensorflow_addons as tfa

import sys
from models import *

In [None]:
# Download and unzip training and test data (fill in the placeholders first).
url = ''
output = ''        # local path for the downloaded archive
extract_dir = ''   # directory the archive is unpacked into

# gdown.download returns the path of the file it wrote; open exactly that file
# rather than a separately hard-coded name so download and extraction can't drift apart.
archive_path = gdown.download(url, output, quiet=False)
if archive_path is None:
    raise RuntimeError(f"gdown could not download {url!r}")

with zipfile.ZipFile(archive_path, "r") as zip_ref:
    zip_ref.extractall(extract_dir)

# Prepare paths of input images and their target images (SoS maps and bm images)

In [None]:
# Training-set directories (fill in before running). Ascans uses them as:
# input_dir -> first model input, uimg_dir -> second model input,
# sos_dir -> first target, img_dir -> second target.
input_dir = ""
uimg_dir = ""
sos_dir = ""
img_dir = ""

# Every list uses the same filter (".png", hidden files excluded) so stray files
# such as macOS "._*.png" resource forks cannot shift one list relative to the
# others. Sorting pairs samples by file name, assuming corresponding files sort
# in the same order in every directory.
input_img_paths, uimg_img_paths, sos_img_paths, bm_img_paths = (
    sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".png") and not fname.startswith(".")
    )
    for directory in (input_dir, uimg_dir, sos_dir, img_dir)
)

# Ascans pairs samples by index, so all four lists must be the same length.
if len({len(input_img_paths), len(uimg_img_paths), len(sos_img_paths), len(bm_img_paths)}) != 1:
    raise ValueError(
        "Training path lists differ in length: "
        f"input={len(input_img_paths)}, uimg={len(uimg_img_paths)}, "
        f"sos={len(sos_img_paths)}, bm={len(bm_img_paths)}"
    )

print("Number of samples:", len(input_img_paths))

for input_path, uimg_path, sos_path, img_path in zip(input_img_paths[:10], uimg_img_paths[:10], sos_img_paths[:10], bm_img_paths[:10]):
    print(input_path, "|", uimg_path, "|", sos_path, "|", img_path)

In [None]:
# Test-set directories (fill in before running); same roles as the training
# directories: input / sos target / bm (img) target / uimg input.
test_input_dir = ""
test_sos_dir = ""
test_img_dir = ""
test_uimg_dir = ""

# Same ".png", non-hidden filter for every list so the index-paired lists stay
# aligned (sorted by file name).
test_input_img_paths, test_uimg_img_paths, test_sos_img_paths, test_bm_img_paths = (
    sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".png") and not fname.startswith(".")
    )
    for directory in (test_input_dir, test_uimg_dir, test_sos_dir, test_img_dir)
)

# Samples are paired by index, so all four lists must be the same length.
if len({len(test_input_img_paths), len(test_uimg_img_paths), len(test_sos_img_paths), len(test_bm_img_paths)}) != 1:
    raise ValueError(
        "Test path lists differ in length: "
        f"input={len(test_input_img_paths)}, uimg={len(test_uimg_img_paths)}, "
        f"sos={len(test_sos_img_paths)}, bm={len(test_bm_img_paths)}"
    )

print("Number of samples:", len(test_input_img_paths))

for input_path, uimg_path, sos_path, img_path in zip(test_input_img_paths[:10],test_uimg_img_paths[:10], test_sos_img_paths[:10], test_bm_img_paths[:10]):
    print(input_path, "|", uimg_path, "|", sos_path, "|", img_path)

In [None]:
# Validation-set directories (fill in before running); same roles as the
# training directories: input / sos target / bm (img) target / uimg input.
val_input_dir = ""
val_sos_dir = ""
val_img_dir = ""
val_uimg_dir = ""

# Same ".png", non-hidden filter for every list so the index-paired lists stay
# aligned (sorted by file name).
val_input_img_paths, val_uimg_img_paths, val_sos_img_paths, val_bm_img_paths = (
    sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".png") and not fname.startswith(".")
    )
    for directory in (val_input_dir, val_uimg_dir, val_sos_dir, val_img_dir)
)

# Samples are paired by index, so all four lists must be the same length.
if len({len(val_input_img_paths), len(val_uimg_img_paths), len(val_sos_img_paths), len(val_bm_img_paths)}) != 1:
    raise ValueError(
        "Validation path lists differ in length: "
        f"input={len(val_input_img_paths)}, uimg={len(val_uimg_img_paths)}, "
        f"sos={len(val_sos_img_paths)}, bm={len(val_bm_img_paths)}"
    )

print("Number of samples:", len(val_input_img_paths))

for input_path, uimg_path, sos_path, img_path in zip(val_input_img_paths[:10],val_uimg_img_paths[:10], val_sos_img_paths[:10], val_bm_img_paths[:10]):
    print(input_path, "|", uimg_path, "|", sos_path, "|", img_path)

# What does one sample (its two inputs and two targets) look like?

In [None]:
# Preview training sample #7 from each of the four aligned path lists.
# ImageOps.autocontrast stretches each image's intensity range so that
# low-contrast content is visible; the data fed to the model is unaffected.

# First model input (input_img_paths).
ascan = ImageOps.autocontrast(load_img(input_img_paths[7]))
display(ascan)

# Second model input (uimg_img_paths).
uimg = ImageOps.autocontrast(load_img(uimg_img_paths[7]))
display(uimg)

# First training target: the corresponding SoS map.
sos = ImageOps.autocontrast(load_img(sos_img_paths[7]))
display(sos)

# Second training target: the corresponding bm image.
img = ImageOps.autocontrast(load_img(bm_img_paths[7]))
display(img)

# Prepare `Sequence` class to load & vectorize batches of data

In [None]:
img_size = (256, 256)
batch_size = 12


class Ascans(keras.utils.Sequence):
    """Keras ``Sequence`` yielding aligned batches for the two-input/two-output model.

    Item ``idx`` is ``([x1, x2], [y1, y2])`` with ``B = batch_size`` and
    ``(H, W) = img_size``:

    * ``x1``: ``(B, H, W, 1)`` uint8, grayscale images from ``input_img_paths``
    * ``x2``: ``(B, H, W, 3)`` uint8, RGB images from ``input_uimg_paths``
    * ``y1``: ``(B, H, W, 1)`` uint8, grayscale images from ``sos_img_paths``
    * ``y2``: ``(B, H, W, 1)`` uint8, grayscale images from ``bm_img_paths``

    The four path lists must be aligned by index (entry ``i`` of each list
    belongs to the same sample). Only complete batches are served: samples past
    the last full batch of the shortest list are skipped. Pixel values are not
    rescaled.
    """

    def __init__(self, batch_size, img_size, input_img_paths, input_uimg_paths, sos_img_paths, bm_img_paths):
        # Let the Keras base class set up its own state (Keras 3's PyDataset,
        # which keras.utils.Sequence aliases there, expects this call).
        super().__init__()
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.input_uimg_paths = input_uimg_paths
        self.sos_img_paths = sos_img_paths
        self.bm_img_paths = bm_img_paths

    def __len__(self):
        """Return the number of complete batches that every path list can supply."""
        # Using the shortest list prevents batches whose inputs or targets are
        # silently zero-filled when the lists differ in length.
        n_samples = min(
            len(self.input_img_paths),
            len(self.input_uimg_paths),
            len(self.sos_img_paths),
            len(self.bm_img_paths),
        )
        return n_samples // self.batch_size

    def _load_batch(self, paths, channels, color_mode):
        """Load *paths*, resized to ``img_size``, into a zero-initialised ``(B, H, W, channels)`` uint8 array."""
        batch = np.zeros((self.batch_size,) + self.img_size + (channels,), dtype="uint8")
        for j, path in enumerate(paths):
            img = load_img(path, target_size=self.img_size, color_mode=color_mode)
            # Grayscale images load as (H, W); the reshape adds the channel axis.
            batch[j] = np.asarray(img).reshape(self.img_size + (channels,))
        return batch

    def __getitem__(self, idx):
        """Return the tuple ``([x1, x2], [y1, y2])`` of inputs and targets for batch #idx."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        x1 = self._load_batch(self.input_img_paths[start:stop], 1, "grayscale")
        x2 = self._load_batch(self.input_uimg_paths[start:stop], 3, "rgb")
        y1 = self._load_batch(self.sos_img_paths[start:stop], 1, "grayscale")
        y2 = self._load_batch(self.bm_img_paths[start:stop], 1, "grayscale")
        return [x1, x2], [y1, y2]

# Multi-task U-Net model

In [None]:
# Build model. Clearing the Keras session first resets global state (e.g.
# auto-generated layer names), so re-running this cell starts clean.
keras.backend.clear_session()
model = unet_mtl_mimo_deeper(
    img_size,
    1,  # NOTE(review): presumably channels of the first input (x1 is 1-channel) - confirm in models.py
    3,  # NOTE(review): presumably channels of the second input (x2 is RGB) - confirm in models.py
)
model.summary()


# Assemble the training and validation splits (validation data come from their own directory)

In [None]:
# Training uses every file listed from the training directories; validation
# data come from their own directories, listed earlier.
train_input_img_paths = input_img_paths
train_input_uimg_paths = uimg_img_paths
train_sos_img_paths = sos_img_paths
train_bm_img_paths = bm_img_paths
# The other val_* lists already carry the names used below; only the uimg
# list needs an alias matching the naming of the training lists.
val_input_uimg_paths = val_uimg_img_paths

# Instantiate data Sequences for each split
train_gen = Ascans(batch_size, img_size, train_input_img_paths, train_input_uimg_paths, train_sos_img_paths, train_bm_img_paths)
val_gen = Ascans(batch_size, img_size, val_input_img_paths, val_input_uimg_paths, val_sos_img_paths, val_bm_img_paths)

# Train the model

In [None]:
# Configure the model for training.
# NOTE(review): PlotProgress comes from models.py; presumably it tracks/plots
# the 'loss' history (its .logs are saved at the end) - confirm there.
plot_progress = PlotProgress(entity='loss')
# AdamW with weight_decay=0 applies no decoupled weight decay.
optimizer = tfa.optimizers.AdamW(learning_rate = 5e-4, weight_decay = 0)

# Both output heads use MSE and contribute equally to the total loss.
model.compile(
    optimizer=optimizer,
    loss={'sos_output': 'mse', 'img_output': 'mse'},
    loss_weights={'sos_output': 0.5, 'img_output': 0.5},
)

# Uncomment to resume from the checkpoint saved on the best training loss.
#model.load_weights('sos_recon.h5')

callbacks = [
    # Keep the weights with the lowest validation loss...
    keras.callbacks.ModelCheckpoint("sos_recon_val.h5", monitor="val_loss", save_best_only=True),
    # ...and, separately, the weights with the lowest training loss.
    keras.callbacks.ModelCheckpoint("sos_recon.h5", monitor="loss", save_best_only=True),
    plot_progress,
]

In [None]:
# Train for a fixed number of epochs; val_gen is evaluated at the end of each
# epoch and the callbacks (checkpoints, progress tracking) run every epoch.
epochs = 100
model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=epochs,
    callbacks=callbacks,
)

# Visualize predictions

In [None]:
# Generate predictions and metrics with the weights that had the lowest
# validation loss (checkpointed as 'sos_recon_val.h5' during training).
model.load_weights('sos_recon_val.h5')

# The evaluation generators get their own names so the training generators
# (train_gen / val_gen, batch_size samples per batch, full training set) are
# not overwritten; re-running the training cell afterwards stays correct.
# Ascans serves only complete batches, so with batch size 10 up to 9 trailing
# samples of a split are not predicted.
n_train_eval = 500  # size of the training subset that is predicted/evaluated

val_eval_gen = Ascans(10, img_size, val_input_img_paths, val_input_uimg_paths, val_sos_img_paths, val_bm_img_paths)
val_preds = model.predict(val_eval_gen)
val_results = model.evaluate(val_eval_gen)

train_eval_gen = Ascans(
    10,
    img_size,
    train_input_img_paths[:n_train_eval],
    train_input_uimg_paths[:n_train_eval],
    train_sos_img_paths[:n_train_eval],
    train_bm_img_paths[:n_train_eval],
)
train_preds = model.predict(train_eval_gen)
train_results = model.evaluate(train_eval_gen)


In [None]:
# Predict and evaluate the test split with the currently loaded weights.
# Batches hold 10 samples; Ascans serves only complete batches.
test_gen = Ascans(
    batch_size=10,
    img_size=img_size,
    input_img_paths=test_input_img_paths,
    input_uimg_paths=test_uimg_img_paths,
    sos_img_paths=test_sos_img_paths,
    bm_img_paths=test_bm_img_paths,
)
test_preds = model.predict(test_gen)
test_results = model.evaluate(test_gen)


In [None]:
print(train_preds[0].shape)


def _load_truth(path_lists, preds):
    """Load grayscale ground truth aligned with one split's predictions.

    ``path_lists`` is the (input, uimg, sos, bm) tuple of path lists for the
    split; ``preds`` is that split's ``model.predict`` output (indexable, first
    element an array of shape ``(N, H, W, ...)``). Returns a float64 array of
    shape ``(len(path_lists), N, H, W)``: the first ``N`` images of each list,
    loaded as grayscale and resized to ``(H, W)`` when needed.
    """
    # Size from this split's own predictions (not a fixed count or another
    # split's shape) so truth and predictions line up sample-for-sample.
    n_samples, height, width = preds[0].shape[:3]
    truth = np.zeros((len(path_lists), n_samples, height, width))
    for k, paths in enumerate(path_lists):
        for i, path in enumerate(paths[:n_samples]):
            img = load_img(path, color_mode='grayscale', target_size=(height, width))
            truth[k, i] = np.asarray(img)
    return truth


train_truth = _load_truth(
    (train_input_img_paths, train_input_uimg_paths, train_sos_img_paths, train_bm_img_paths),
    train_preds,
)
val_truth = _load_truth(
    (val_input_img_paths, val_input_uimg_paths, val_sos_img_paths, val_bm_img_paths),
    val_preds,
)
test_truth = _load_truth(
    (test_input_img_paths, test_uimg_img_paths, test_sos_img_paths, test_bm_img_paths),
    test_preds,
)
          

In [None]:
from scipy.io import savemat

# Save predictions, evaluation results, ground truth and the file lists that
# produced them to a MATLAB .mat file (set save_path before running).
save_path = ''
save_var = {
    "val_preds": val_preds,
    "train_preds": train_preds,
    "test_preds": test_preds,
    "val_results": val_results,
    "train_results": train_results,
    "test_results": test_results,
    "batch_size": batch_size,
    "train_truth": train_truth,
    "val_truth": val_truth,
    "test_truth": test_truth,
    "train_input_img_paths": train_input_img_paths,
    "train_input_uimg_paths": train_input_uimg_paths,
    "val_input_img_paths": val_input_img_paths,
    # Must hold the test *input* paths, matching the train/val keys above
    # (it previously stored the test uimg paths by mistake).
    "test_input_img_paths": test_input_img_paths,
    # uimg path lists for val/test, mirroring train_input_uimg_paths.
    "val_input_uimg_paths": val_input_uimg_paths,
    "test_input_uimg_paths": test_uimg_img_paths,
}
save_var["plot_progress"] = plot_progress.logs
savemat(save_path, save_var)