In [None]:
import os, sys

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
assert tf.__version__.startswith('2.'), 'assumes Tensorflow 2.0+'

import voxelmorph as vxm
import neurite as ne

from skimage import io
from skimage.color import rgb2gray

In [None]:
# Image sizes in pixels:
#   orig_size - resolution the ground-truth landmark coordinates are given at
#   img_size  - downsized resolution of the training folders and model input
orig_size, img_size = 2912, 512

In [None]:
def generator(init_data, batch_size=32):
    """
    Yield random training batches for the registration model, forever.

    init_data:  sequence of (moving, fixed) image pairs; every image is a 2D
                array [H, W] and all images share the same shape.
    batch_size: number of pairs drawn per batch (uniformly, with replacement).

    Each yielded item is (inputs, outputs):
      inputs:  [moving [bs, H, W, 1], fixed [bs, H, W, 1]]
      outputs: [fixed [bs, H, W, 1], zero flow [bs, H, W, 2]]
    The first output is the target the moved image is compared against; the
    second is a dummy all-zero target for the flow regularization loss.

    Raises ValueError (on the first next()) if init_data is empty or
    batch_size < 1.
    """
    if len(init_data) == 0:
        raise ValueError("init_data must contain at least one (moving, fixed) pair")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    while True:
        # batch-size random pair indices (with replacement)
        idx = np.random.randint(0, len(init_data), size=batch_size)

        # stack the selected images: [bs, H, W] -> [bs, H, W, 1]
        moving_imgs = np.asarray([init_data[i][0] for i in idx])[:, :, :, np.newaxis]
        fixed_imgs = np.asarray([init_data[i][1] for i in idx])[:, :, :, np.newaxis]

        # dummy flow target sized from the actual batch, not a global image size,
        # so non-square or differently sized images still produce matching shapes
        height, width = fixed_imgs.shape[1:3]
        zero_flow = np.zeros([batch_size, height, width, 2])

        inputs = [moving_imgs, fixed_imgs]
        outputs = [fixed_imgs, zero_flow]
        yield (inputs, outputs)

In [None]:
# Get image pairs


def _load_normalized_images(folder):
    """Read every image in `folder`, in sorted filename order, divided by 255.

    Sorting matters: os.listdir returns entries in arbitrary order, while the
    pairing step relies on the two images of a pair being adjacent by name.
    """
    # NOTE(review): /255 assumes 8-bit images, and the generator expects 2D
    # [H, W] arrays — confirm the downsized folders are stored as grayscale
    # (no rgb2gray conversion is applied here).
    return [
        io.imread(os.path.join(folder, name)) / 255
        for name in sorted(os.listdir(folder))
    ]


def _load_pairs(folder_prefix):
    """Load `<cwd>/<folder_prefix><img_size>` and pair images (0,1), (2,3), ..."""
    folder = os.path.join(os.getcwd(), folder_prefix + str(img_size))
    imgs = _load_normalized_images(folder)
    # an odd count means the "adjacent files form a pair" assumption is broken
    if len(imgs) % 2:
        raise ValueError(
            f"{folder} contains {len(imgs)} images; expected an even number (pairs)"
        )
    return list(zip(imgs[::2], imgs[1::2]))


# Unaugmented pairs, also used by the evaluation cell
original_images = _load_pairs("Images_downsized_")

# Training pairs: originals plus both augmented sets. Each folder is paired on
# its own, so one folder's contents can never shift the pairing of the next.
image_pairs = (
    original_images
    + _load_pairs("Images_close_to_gt_downsized_")
    + _load_pairs("Images_rotated_downsized_")
)

# flat list of all loaded images, kept for backward compatibility
images = [img for pair in image_pairs for img in pair]

In [None]:
# NOTE: this REPLACES the image_pairs built in the previous cell — training
# below then uses only the "close to ground truth" augmented pairs.
# Skip this cell to train on all folders instead.
images_path = os.path.join(os.getcwd(), "Images_close_to_gt_downsized_" + str(img_size))

# sorted(): os.listdir order is arbitrary, but pairing relies on a pair's two
# images being adjacent by filename
images = [
    io.imread(os.path.join(images_path, image_name)) / 255
    for image_name in sorted(os.listdir(images_path))
]
if len(images) % 2:
    raise ValueError(
        f"{images_path} contains {len(images)} images; expected an even number (pairs)"
    )
image_pairs = list(zip(images[::2], images[1::2]))

In [None]:
# Endless source of random training batches (16 pairs each) from image_pairs
batch_size_ = 16
gen = generator(
    image_pairs,
    batch_size=batch_size_,
)

In [None]:
# configure unet features
nb_features = [
    [32, 32, 32, 32],         # encoder features
    [32, 32, 32, 32, 32, 16]  # decoder features
]
# build a 2D model for img_size x img_size inputs
vxm_model = vxm.networks.VxmDense([img_size, img_size], nb_features, int_steps=0)

# losses, one per model output: MSE between moved and fixed image,
# l2 gradient penalty on the predicted flow
losses = [vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss]

# balance the two losses by a hyper-parameter (alternative to try: [1, 0.01])
loss_weights = [1, 0.05]

# training hyper-parameters
LEARNING_RATE = 1e-4
EPOCHS = 200
STEPS_PER_EPOCH = 15

# Compile the model and fit it to training data.
# Use `learning_rate`: the old `lr` alias is deprecated and rejected by newer
# TensorFlow optimizers.
vxm_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=losses,
    loss_weights=loss_weights,
)

hist = vxm_model.fit(gen, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, verbose=2)

In [None]:
# Plot the total loss and each per-output loss, one labelled panel each.
# Per-output history keys are derived from the model's output layer names,
# which may differ between voxelmorph versions, so collect every key ending
# in "loss" instead of hard-coding them.
loss_keys = [key for key in hist.history if key.endswith("loss")]

# squeeze=False keeps `axes` 2D even if only one loss is present
fig, axes = plt.subplots(1, len(loss_keys), squeeze=False)
for ax, key in zip(axes[0], loss_keys):
    ax.plot(hist.epoch, hist.history[key], '.-')
    ax.set_title(key)
    # label every panel (plt.xlabel/ylabel would only label the last one)
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
fig.tight_layout()
plt.show()

In [None]:
# Sanity check on one random training pair: show the moving, fixed and moved
# images plus the first channel of the predicted flow, then the flow field.
gen_test = generator(image_pairs, batch_size=1)
(test_input, _) = next(gen_test)
test_pred = vxm_model.predict(test_input)

titles = ['moving', 'fixed', 'moved', 'flow']
images = [batch[0, ..., 0] for batch in (*test_input, *test_pred)]
ne.plot.slices(images, titles=titles, cmaps=['gray'], do_colorbars=True)

ne.plot.flow([test_pred[1].squeeze()], width=8)

In [None]:
# EVALUATION
#
# For every original (non-augmented) pair: register it with the trained model,
# then compare landmark distances before and after applying the predicted flow.


def _landmark_errors(flow, gt_file, scale):
    """Per-landmark errors of `flow` against one ground-truth file.

    flow:    predicted displacement field [H, W, 2] (squeezed model output).
    gt_file: path to a text file with one landmark pair per line, four
             whitespace-separated numbers at the original resolution.
    scale:   factor mapping original-resolution coordinates to flow resolution.

    Returns (relative_errors, abs_error_rows):
      relative_errors - error_pred / error_label per landmark (NaN if the
                        landmark pair has zero initial distance)
      abs_error_rows  - dicts with "Original Error" and "New Error"
    """
    height, width = flow.shape[:2]
    relative_errors = []
    abs_error_rows = []
    with open(gt_file, "r") as f:
        for line in f:
            # split() tolerates a missing final newline, \r\n and repeated
            # spaces (line[:-1] would chop a digit off an unterminated last line)
            vals = line.split()
            if not vals:
                continue  # skip blank lines
            # scale ground truth values to lower image size
            vals_adjusted = [float(val) * scale for val in vals]

            # NOTE(review): vals[0] indexes the flow's first (row) axis and
            # vals[1] its second (column) axis, and the shift is added to
            # (vals[0], vals[1]) as-is — confirm this matches the ground-truth
            # coordinate order (x/y vs row/col) and the flow's direction.
            # Clip so a coordinate that rounds up to the image size stays in bounds.
            row = min(max(round(vals_adjusted[0]), 0), height - 1)
            col = min(max(round(vals_adjusted[1]), 0), width - 1)
            shift = flow[row, col]

            origin = np.asarray(vals_adjusted[0:2])
            goal = np.asarray(vals_adjusted[2:4])
            predict = origin + shift

            # errors as euclidean (l2) distance to the goal point
            error_label = np.linalg.norm(origin - goal)
            error_pred = np.linalg.norm(predict - goal)

            # Relative error to initial distance
            # >1 : moved point is further from the goal than the initial point
            # =1 : same distance (no change)
            # <1 : closer; the closer to 0 the better
            # Undefined (NaN) when the initial distance is 0.
            relative_errors.append(error_pred / error_label if error_label > 0 else np.nan)
            abs_error_rows.append({"Original Error": error_label, "New Error": error_pred})
    return relative_errors, abs_error_rows


# Path to ground truths. Sorted so the i-th file lines up with the i-th
# original pair (os.listdir order is arbitrary).
gt_path = os.path.join(os.getcwd(), "Ground Truth")
gt_names = sorted(os.listdir(gt_path))
if len(gt_names) != len(original_images):
    raise ValueError(
        f"{len(gt_names)} ground-truth files but {len(original_images)} original image pairs"
    )

all_mean_rel_errors = []
all_abs_errors_dfs = {}
for gt, (img1, img2) in zip(gt_names, original_images):
    # single-pair batch: [H, W] -> [1, H, W, 1]
    moving_imgs = np.asarray(img1)[np.newaxis, :, :, np.newaxis]
    fixed_imgs = np.asarray(img2)[np.newaxis, :, :, np.newaxis]

    # Predict that image pair
    pred = vxm_model.predict([moving_imgs, fixed_imgs])

    # visualize
    images_ = [img[0, :, :, 0] for img in [moving_imgs, fixed_imgs] + pred]
    titles = ['moving', 'fixed', 'moved', 'flow']
    ne.plot.slices(images_, titles=titles, cmaps=['gray'], do_colorbars=True)

    relative_errors, abs_errors_dics = _landmark_errors(
        pred[1].squeeze(), os.path.join(gt_path, gt), img_size / orig_size
    )

    # mean relative error for this pair, ignoring undefined (NaN) landmarks
    mean_rel_error = np.nanmean(np.asarray(relative_errors, dtype=float))
    all_mean_rel_errors.append(mean_rel_error)
    print(gt, "Mean relative error:", mean_rel_error)
    df = pd.DataFrame(abs_errors_dics)
    all_abs_errors_dfs[gt] = df
    display(df)


# Show mean errors for each pair, and the average mean error over all pairs
print(all_mean_rel_errors)
print(np.mean(np.asarray(all_mean_rel_errors)))

In [None]:
# Uncomment to save the trained model to ./models
#vxm_model.save(os.path.join(os.getcwd(), "models"))