# Supervised Registration for LASTEN
A U-Net is trained here to predict a dense displacement field; the Euclidean distance between the displaced and the target keypoints is used as the error.

# Import statements
The following packages are required:

In [None]:
import sys
import random
import os
import imageio
import json
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as kb
from endolas  import closs
from endolas import ccall
from endolas import LASTENSequence
from endolas import utils
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from endolas import UNet
from endolas import preprocess_input as pre_une
from matplotlib import pyplot as plt

# Checks
The TensorFlow version and GPU support are checked, and GPU memory growth is enabled.

In [None]:
print(tf.__version__)
physical_devices = tf.config.experimental.list_physical_devices('GPU')
# Allocate GPU memory on demand instead of reserving it all up front.
# TensorFlow requires the same memory-growth setting on every GPU, so apply it
# to all of them. Iterating (instead of indexing [0]) also avoids an IndexError
# on machines without a GPU.
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)
if not physical_devices:
    print("No GPU found; training will run on the CPU.")
tf.config.experimental.get_visible_devices('GPU')

# Seeding
Random seeds are fixed to make training as reproducible as possible (some GPU operations may still be non-deterministic).

In [None]:
# Reset any Keras state left over from earlier runs of the notebook.
tf.keras.backend.clear_session()

SEED = 42
# Seed every random number generator used here: NumPy, TensorFlow and Python's `random`.
for seed_fn in (np.random.seed, tf.random.set_seed, random.seed):
    seed_fn(SEED)

# Data
Create the data generators for the training and validation images.

In [None]:
# Dataset locations: moving images for each split, plus the shared fixed data.
path_train = r'Data/LASTEN/train'
path_validation = r'Data/LASTEN/validation'
path_fixed = r'Data/LASTEN/fix_768_32'

# Network input size in pixels.
width = 384
height = 384

# Keypoint grid dimensions (grid_width * grid_height keypoints per image).
grid_width = 18
grid_height = 18

batch_size = 4

# Settings shared by the training and validation sequences; only the image
# folder and the shuffling flag differ between them.
_sequence_kwargs = dict(
    batch_size=batch_size,
    width=width,
    height=height,
    grid_width=grid_width,
    grid_height=grid_height,
    preprocess_input=pre_une,
    label="keypoints",
    channel="moving+fixed",
)

train_gen = LASTENSequence(path_train, path_fixed, shuffle=True, **_sequence_kwargs)
val_gen = LASTENSequence(path_validation, path_fixed, shuffle=False, **_sequence_kwargs)

# Network
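A U-Net-based network is instantiated with Keras. Its two linear output channels are interpreted as the x- and y-components of the displacement field (a regression task, not a semantic segmentation).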

In [None]:
# Keras expects input_shape as (rows, cols, channels) = (height, width, channels).
# The two input channels are the moving and fixed images (channel="moving+fixed").
# The two linear output "classes" are later read as u_x and u_y.
# NOTE(review): assumes LASTENSequence yields batches shaped (batch, height, width, 2) - confirm.
model = UNet(filters=32, layers=4, activation='linear', classes=2, input_shape=(height, width, 2))
model.summary()

# Training Preparation
Prepare the settings for training the model.

In [None]:
save_path = r'Data/LASTEN/results'
# CSVLogger opens its file directly and fails if the results directory is missing.
os.makedirs(save_path, exist_ok=True)

logger = CSVLogger(os.path.join(save_path, "log"))
timelogger = ccall.TimeHistory(save_path)
# Snapshot the weights every 10 epochs.
# NOTE: `period` is deprecated in tf.keras in favour of `save_freq`; it is kept
# here because the meaning of an integer `save_freq` differs between TF versions.
checker = ModelCheckpoint(os.path.join(save_path, "weights.{epoch:02d}.hdf5"), period=10)
# Keep the weights with the best monitored value (ModelCheckpoint monitors 'val_loss' by default).
checker_best = ModelCheckpoint(os.path.join(save_path, "best_weights.hdf5"), save_best_only=True)
callbacks = [timelogger, logger, checker, checker_best]

# Loss ('msed') and reporting metric ('med') come from the same Euclidean keypoint loss.
# NOTE(review): the loss is built with a fixed batch_size; confirm it copes with a
# smaller final batch if the dataset size is not a multiple of batch_size.
eu_loss = closs.EuclideanLoss(batch_size=batch_size, grid_width=grid_width, grid_height=grid_height, loss_type='msed')
eu_met = closs.EuclideanLoss(batch_size=batch_size, grid_width=grid_width, grid_height=grid_height, loss_type='med')

model.compile(optimizer='adam', loss=eu_loss, metrics=[eu_met])

# Training
Run the training.

In [None]:
# Train for 100 epochs and evaluate on the validation generator after every
# epoch. The callbacks list holds TimeHistory, the CSV logger and both checkpoints.
model.fit(
    train_gen,
    validation_data=val_gen,
    validation_freq=1,
    epochs=100,
    callbacks=callbacks,
)

# Evaluation
A quick visual check of whether the predicted displacements are reasonable.

In [None]:
def plot_cube(img, x, y, val, radius=1):
    """Paint a filled square of value *val* centred at (x, y) into *img* in place.

    The square is (2 * radius + 1) pixels wide and is clipped to the image
    bounds. Markers at or beyond the border are therefore drawn partially (or
    not at all) instead of wrapping to the opposite edge through negative
    indices or raising IndexError.

    Args:
        img: 2-D NumPy array indexed as img[row, col] = img[y, x]; modified in place.
        x: Column of the square's centre.
        y: Row of the square's centre.
        val: Value written into every pixel of the square.
        radius: Half-size of the square; the default of 1 gives a 3x3 marker.
    """
    rows, cols = img.shape[:2]
    y0, y1 = max(y - radius, 0), min(y + radius + 1, rows)
    x0, x1 = max(x - radius, 0), min(x + radius + 1, cols)
    # Empty when the centre is too far outside the image. The explicit check also
    # stops a negative upper bound from being read as "count from the end".
    if y0 < y1 and x0 < x1:
        img[y0:y1, x0:x1] = val

In [None]:
X, y = val_gen[0]

y_pred = model.predict(X)

# Displacement components of the first sample in the batch. Channel 0 is added
# to keypoint x-coordinates and channel 1 to y-coordinates below.
u_x = y_pred[0, :, :, 0]
u_y = y_pred[0, :, :, 1]

# Give each plot its own figure so the second imshow does not overwrite the first.
plt.figure()
plt.imshow(u_x, cmap="gray")
plt.title("u_x")

# NOTE: `store_path` was never defined in this notebook; `save_path` is used instead.
#plt.imsave(save_path + "/u_x.png", u_x, cmap="gray")
#plt.imsave(save_path + "/u_y.png", u_y, cmap="gray")

#u_x.dump(save_path + "/u_x")
#u_y.dump(save_path + "/u_y")

# Marker image indexed [row, col] = [y, x], hence shape (height, width).
warp = np.zeros((height, width))

for index in range(grid_width * grid_height):
    # Label layout as used here: y[0, k, 0, :] are x-coordinates and y[0, k, 1, :]
    # are y-coordinates. NOTE(review): assumes index 0 of the last axis is the
    # moving keypoint and index 1 the fixed (target) keypoint - confirm.
    x_pos = int(y[0, index, 0, 0])
    y_pos = int(y[0, index, 1, 0])

    plot_cube(warp, x_pos, y_pos, 1)  # blue: original keypoint

    # Shift the keypoint by the predicted displacement sampled at its location.
    x_moved = int(round(x_pos + u_x[y_pos, x_pos]))
    y_moved = int(round(y_pos + u_y[y_pos, x_pos]))

    # Predictions can leave the image; skip those instead of failing or wrapping.
    if 0 <= x_moved < width and 0 <= y_moved < height:
        plot_cube(warp, x_moved, y_moved, 2)  # green: displaced keypoint

    x_target = int(y[0, index, 0, 1])
    y_target = int(y[0, index, 1, 1])

    plot_cube(warp, x_target, y_target, 3)  # yellow: target keypoint

plt.figure()
plt.imshow(warp)
#plt.imsave(save_path + "/warp.png", warp)