In [None]:
# Automatically reload edited modules (e.g. the project's `vte.*` code) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import tqdm
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import neurite as ne
import voxelmorph as vxm

In [None]:
# Prevent the TF model from reserving the whole GPU: cap each visible GPU at a
# fixed amount of memory. Per the TF GPU guide, logical devices must be configured
# before the GPUs are initialized, so run this cell before any TF op touches a GPU
# (re-running it afterwards raises a RuntimeError).
GPU_MEMORY_LIMIT_MB = 4096
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.set_logical_device_configuration(
        gpu,
        [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)],
    )

## Reproducing the Demo (TF)
Code in this section closely follows the
[authors' demo](https://colab.research.google.com/drive/1GjpjkhKGrg5W-cvZVObBo3IoIUwaPZBU?usp=sharing).

In [None]:
# Input shapes.
in_shape = (256,) * 2
num_dim = len(in_shape)
num_label = 4  # NOTE(review): the weight-transfer section below uses 16 labels — confirm which value matches the demo
num_maps = 40


def draw_label_map():
    """Draw one random label map from warped Perlin noise.

    Draws ``num_label`` Perlin-noise channels over ``in_shape``, deforms them with a
    random Perlin displacement field via ``vxm.utils.transform``, and labels every
    pixel with the index of its largest channel (argmax over the last axis).
    The result is returned as a uint8 array.
    """
    noise = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label),
        scales=(32, 64), max_std=1,
    )
    displacement = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label, num_dim),
        scales=(16, 32, 64), max_std=16,
    )
    warped_noise = vxm.utils.transform(noise, displacement)
    return np.uint8(tf.argmax(warped_noise, axis=-1))


# Shape generation.
label_maps = [draw_label_map() for _ in tqdm.tqdm(range(num_maps))]


# Visualize shapes: `num_row` figures of `per_row` maps each.
num_row = 2
per_row = 4
for start in range(0, num_row * per_row, per_row):
    ne.plot.slices(label_maps[start:start + per_row], cmaps=['tab20c'])

In [None]:
# Show the first label map on its own; the trailing `;` suppresses the text repr.
fig, ax = plt.subplots()
ax.imshow(label_maps[0], cmap='tab20c')
ax.set_title('Label map 0');

In [None]:
# Image generation. For accurate registration, the landscape of generated warps
# and image contrasts will need to include the target distribution.
gen_arg = dict(
    in_shape=in_shape,
    in_label_list=np.unique(label_maps),
    warp_std=3,
    warp_res=(8, 16, 32),
)
# Two generators with the same settings; the distinct `id`s presumably keep their
# layer names apart inside the combined training model.
gen_model_1 = ne.models.labels_to_image(**gen_arg, id=1)
gen_model_2 = ne.models.labels_to_image(**gen_arg, id=2)


# Test repeatedly for single input. The generator is stochastic, so each call
# yields a different synthetic image from the same label map.
# `label_input` avoids shadowing the builtin `input`.
num_gen = 8
label_input = np.expand_dims(label_maps[0], axis=(0, -1))  # add batch and channel axes
slices = [gen_model_1.predict(label_input)[0] for _ in range(num_gen)]
ne.plot.slices(slices);

In [None]:
# Registration model.
# reg_model maps a (moving, fixed) image pair to (moved image, warp).
# NOTE(review): keep the construction order in this cell unchanged — Keras assigns
# default layer names / weight order by creation order, and the .h5 weight
# save/load used later presumably relies on it.
reg_model = vxm.networks.VxmDense(
    inshape=in_shape,
    # presumably downsampling factors for field integration / prediction
    # (2 => half resolution) — confirm against the voxelmorph docs
    int_resolution=2,
    svf_resolution=2,
    # (encoder features, decoder features) of the U-Net
    nb_unet_features=([256] * 4, [256] * 8),
    reg_field='warp',
)


# Model for optimization.
# Each generator turns a label map into (synthetic image, label map).
ima_1, map_1 = gen_model_1.outputs
ima_2, map_2 = gen_model_2.outputs

# Register image 1 (moving) onto image 2 (fixed), then warp label map 1 with the
# predicted field, filling out-of-view pixels with 0.
_, warp = reg_model((ima_1, ima_2))
pred = vxm.layers.SpatialTransformer(fill_value=0)((map_1, warp))

# Training graph: label-map inputs of both generators in, (target map, warped map) out.
inputs = gen_model_1.inputs + gen_model_2.inputs
out = (map_2, pred)
model = tf.keras.Model(inputs, out)


# Compilation.
# Losses are attached via add_loss, so compile() takes no `loss` argument.
# Dice term between target and warped label maps; the +1 per batch element
# presumably shifts the (negative) Dice loss so it is non-negative.
model.add_loss(vxm.losses.Dice().loss(*out) + tf.repeat(1., tf.shape(pred)[0]))
# Presumably voxelmorph's l2 gradient (smoothness) penalty on the warp.
model.add_loss(vxm.losses.Grad('l2', loss_mult=1).loss(None, warp))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

In [None]:
# Train the model (4000 epochs x 40 steps is a long run). Re-running the cell
# continues from the current weights, although the epoch counter restarts at 0.
TRAIN_EPOCHS = 4000
STEPS_PER_EPOCH = 40

gen = vxm.generators.synthmorph(
    label_maps,
    batch_size=1,
    same_subj=True,
    flip=True,
)

hist = model.fit(
    gen,
    initial_epoch=0,
    epochs=TRAIN_EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    verbose=1,
)

# Make sure the output directory exists: writing the .h5 file into a missing
# directory can fail and would throw away the whole training run.
weights_dir = Path(".") / "weights" / "keras"
weights_dir.mkdir(parents=True, exist_ok=True)
model.save_weights(str(weights_dir / "400k_original.h5"))

# Visualize loss.
plt.plot(hist.epoch, hist.history['loss'], '.-');
plt.xlabel('Epoch');
plt.ylabel('Loss');

In [None]:
# Download model weights to skip training and save time.
# !gdown -O weights.h5 1xridvtyEWgWsWJPYVrQfDCtSgbj2beRz
# model.load_weights('weights.h5')

In [None]:
# Conform test data.
def conform(x, in_shape=in_shape):
    """Prepare one raw image as registration-model input.

    The image is cast to float32, stripped of singleton axes, passed through
    ``ne.utils.minmax_norm``, resampled with ``ne.utils.zoom`` by the per-axis ratio
    ``in_shape / image.shape``, and returned with a leading batch axis and a
    trailing channel axis.
    """
    image = ne.utils.minmax_norm(np.squeeze(np.float32(x)))
    factors = [target / current for target, current in zip(in_shape, image.shape)]
    resized = ne.utils.zoom(image, zoom_factor=factors)
    return np.expand_dims(resized, axis=(0, -1))

In [None]:
# Test on MNIST: register one "6" from the test split onto another.
images, digits = tf.keras.datasets.mnist.load_data()[1]
ind = np.flatnonzero(digits == 6)
moving, fixed = conform(images[ind[233]]), conform(images[ind[199]])
moved, warp = reg_model.predict((moving, fixed))

# Plot registration: moving, fixed and moved images side by side.
panels = {'Moving': moving, 'Fixed': fixed, 'Moved': moved}
ne.plot.slices(
    slices_in=tuple(panels.values()),
    titles=tuple(panels),
    do_colorbars=True,
);

In [None]:
# Plot the two components of the predicted displacement field.
# With voxelmorph's default 'ij' indexing, channel d is the displacement along
# array axis d: channel 0 moves pixels along rows (vertical in the image) and
# channel 1 along columns (horizontal), so channel 0 is not the x-axis.
ne.plot.slices(
    slices_in=(warp[..., 0], warp[..., 1]),
    titles=('Warp (axis 0, vertical)', 'Warp (axis 1, horizontal)'),
    do_colorbars=True,
);

In [None]:
# Test on OASIS-1 (2-D tutorial slices loaded through neurite).
# A dedicated name avoids clobbering the MNIST `images` loaded earlier.
oasis_images = ne.py.data.load_dataset('2D-OASIS-TUTORIAL')
moving = conform(oasis_images[2])
fixed = conform(oasis_images[7])
moved, warp = reg_model.predict((moving, fixed))

# Plot registration
ne.plot.slices(
    slices_in=(moving, fixed, moved),
    titles=('Moving', 'Fixed', 'Moved'),
    do_colorbars=True,
);

## Transfer Weights TF -> Torch

In [None]:
import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import neurite as ne
import voxelmorph as vxm


In [None]:
import vte.experiments.voxel_morph.model.synthmorph as models
import vte.experiments.voxel_morph.model.synthmorph_new as new
import vte.experiments.voxel_morph.datamodule.synth as datamodule
import vte.experiments.voxel_morph.utils as utils
from vte.experiments.voxel_morph.synthmorph_utils import(
    conform as torch_conform, post_predict, image_to_numpy,\
    invert_grayscale, overlay_images,\
    plot_array_row, superimpose_circles,\
    convert_to_single_rgb, rotate
)
from pathlib import Path
import urllib.request
from PIL import Image
from matplotlib import pyplot as plt
from cv2 import resize
import numpy as np 
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl
import torchinfo
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity

import kornia.metrics as metrics

In [None]:
# Define a fresh Torch VxmDense whose U-Net features match the Keras model
# (4 x 256 encoder, 8 x 256 decoder). int_downsize=2 / unet_half_res=True presumably
# mirror Keras' int_resolution=2 / svf_resolution=2 — confirm.
vol_size = (256,) * 2
unet_enc_nf, unet_dec_nf = [256] * 4, [256] * 8
int_steps, int_downsize = 7, 2
bidir = False

torch_vxmdense = new.VxmDense(
    inshape=vol_size,
    nb_unet_features=[unet_enc_nf, unet_dec_nf],
    int_steps=int_steps,
    int_downsize=int_downsize,
    bidir=bidir,
    unet_half_res=True,
)
# To start from an already converted checkpoint instead, load it with
# torch_vxmdense.load_state_dict(torch.load(<path to .pth>)).

# Reference state dict: its key order is the target order for the weight transfer.
torch_weights = torch_vxmdense.state_dict()

In [None]:
# Define fresh Keras model
# This section is just a copy of the demo to define the Keras model
# Only reg_model's weights are transferred below; the label map and generators are
# rebuilt so that `model` presumably has the structure the .h5 weights were saved from.
# NOTE(review): keep the construction order unchanged — Keras names layers and orders
# weights by creation order, which .h5 loading presumably relies on.

# Label maps
# A single map is enough here: in this section label_maps only feeds the
# generators' label list.
# NOTE(review): num_label = 16 here vs 4 in the first section — confirm this does not
# matter for loading the weights (presumably the generators hold none).
in_shape = (256,) * 2
num_dim = len(in_shape)
num_label = 16
num_maps = 1
label_maps = []
for _ in range(num_maps):
    # Draw image and warp.
    im = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label),
        scales=(32, 64), max_std=1,
    )
    warp = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label, num_dim),
        scales=(16, 32, 64), max_std=16,
    )

    # Transform and create label map.
    im = vxm.utils.transform(im, warp)
    lab = tf.argmax(im, axis=-1)
    label_maps.append(np.uint8(lab))

# Image generator
gen_arg = dict(
    in_shape=in_shape,
    in_label_list=np.unique(label_maps),
    warp_std=3,
    warp_res=(8, 16, 32),
)
gen_model_1 = ne.models.labels_to_image(**gen_arg, id=1)
gen_model_2 = ne.models.labels_to_image(**gen_arg, id=2)

# Registration model.
# Maps a (moving, fixed) image pair to (moved image, warp).
reg_model = vxm.networks.VxmDense(
    inshape=in_shape,
    int_resolution=2,
    svf_resolution=2,
    nb_unet_features=([256] * 4, [256] * 8),
    reg_field='warp',
)

# Model for optimization.
ima_1, map_1 = gen_model_1.outputs
ima_2, map_2 = gen_model_2.outputs

# Warp label map 1 with the field that registers image 1 onto image 2.
_, warp = reg_model((ima_1, ima_2))
pred = vxm.layers.SpatialTransformer(fill_value=0)((map_1, warp))

inputs = gen_model_1.inputs + gen_model_2.inputs
out = (map_2, pred)
model = tf.keras.Model(inputs, out)

# Compilation.
# Losses are attached via add_loss; the +1 per batch element presumably shifts the
# (negative) Dice loss to be non-negative.
model.add_loss(vxm.losses.Dice().loss(*out) + tf.repeat(1., tf.shape(pred)[0]))
model.add_loss(vxm.losses.Grad('l2', loss_mult=1).loss(None, warp))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

In [None]:
# Load Keras pretrained weights
# Alternatively, download weights from https://drive.google.com/uc?id=1xridvtyEWgWsWJPYVrQfDCtSgbj2beRz
# !gdown -O weights.h5 1xridvtyEWgWsWJPYVrQfDCtSgbj2beRz
model.load_weights('weights.h5')

# Extract weights from the registration model only, keyed by variable name and
# kept in model order: value = (numpy array, dtype, shape).
keras_vxmdense = reg_model
keras_weights = {w.name: (w.numpy(), w.dtype, w.shape) for w in keras_vxmdense.weights}

# The transfer matches Keras and Torch weights by position, so a name collision
# (e.g. a Keras version reporting unscoped names like 'kernel') would silently
# drop entries from this dict and misalign everything after it.
assert len(keras_weights) == len(keras_vxmdense.weights), (
    f'{len(keras_vxmdense.weights) - len(keras_weights)} Keras weight name(s) collided'
)

In [None]:
# Only get kernel weights. Conv kernels are multi-dimensional (Keras lays them out as
# (..., in_channels, out_channels)) while biases are 1-D, so select by array rank
# instead of searching the variable name for 'bias', which depends on Keras naming.
keras_weights_keys = list(keras_weights.keys())
keras_kernels = [name for name in keras_weights_keys if keras_weights[name][0].ndim > 1]
keras_kernels

In [None]:
new_weights = {}

# Transfer the weights. The layers are assumed to appear in the same order in both
# frameworks, so the i-th Keras weight is mapped to the i-th Torch state-dict entry.
# zip() would silently stop at the shorter sequence, so check the counts first.
assert len(keras_weights) == len(torch_weights), (
    f'{len(keras_weights)} Keras weights vs {len(torch_weights)} Torch state-dict entries'
)
kernel_names = set(keras_kernels)
for k, t in zip(keras_weights.keys(), torch_weights.keys()):
    value = keras_weights[k][0]
    if k in kernel_names:
        # Keras conv kernel (..., in, out) -> Torch (out, in, ...).
        value = np.moveaxis(value, [-1, -2], [0, 1])
    # torch.tensor copies (moveaxis returns a non-contiguous view) and yields
    # float32 like the legacy torch.Tensor(...) constructor did.
    new_weights[t] = torch.tensor(value, dtype=torch.float32)

# Strict loading also rejects any shape mismatch left after the conversion.
torch_vxmdense.load_state_dict(new_weights)

In [None]:
# Persist the converted authors' weights as a Torch state dict.
torch.save(obj=torch_vxmdense.state_dict(), f='authors.pth')

## Debugging the weight transfer and the Torch layers

In [None]:
# Keras sub-model that ends at the last layer of the registration network; slicing
# `custom_keras_layers` differently lets intermediate outputs be inspected.
custom_keras_layers = list(keras_vxmdense.layers)
last_layer_output = custom_keras_layers[-1].output
custom_keras_model = Model(inputs=keras_vxmdense.inputs, outputs=last_layer_output)

In [None]:
# Conform test data.
# NOTE(review): duplicate of the earlier `conform`, presumably kept so this section
# can run in a fresh kernel; consider moving it into a shared module.
def conform(x, in_shape=in_shape):
    """Resize and normalize one image for the registration model.

    Steps: cast to float32, drop singleton axes, apply ``ne.utils.minmax_norm``,
    resample with ``ne.utils.zoom`` by the per-axis ratio ``in_shape / shape``,
    then add a leading batch axis and a trailing channel axis.
    """
    img = np.squeeze(np.float32(x))
    img = ne.utils.minmax_norm(img)
    ratios = []
    for target_len, current_len in zip(in_shape, img.shape):
        ratios.append(target_len / current_len)
    img = ne.utils.zoom(img, zoom_factor=ratios)
    return np.expand_dims(img, axis=(0, -1))

In [None]:
# Pick two "5"s from the MNIST test split as the (moving, fixed) pair.
images, digits = tf.keras.datasets.mnist.load_data()[1]
ind = np.flatnonzero(digits == 5)
moving, fixed = (conform(images[ind[j]]) for j in (223, 199))

In [None]:
# Chain the Torch VxmDense submodules in forward order for step-by-step debugging.
# The Sequential wraps the *same* submodule objects as torch_vxmdense (no copies).
unet_torch = torch_vxmdense.unet_model
flow_torch = torch_vxmdense.flow
vecint_torch = torch_vxmdense.integrate
rescale_torch = torch_vxmdense.fullsize
spatial_torch = torch_vxmdense.transformer
# NOTE(review): nn.Sequential passes each module only the previous module's output.
# If `transformer` needs (image, flow) like voxelmorph's SpatialTransformer, running
# this chain end-to-end would fail — confirm before calling it.
custom_torch_model = nn.Sequential(
    unet_torch, 
    flow_torch, 
    vecint_torch,
    rescale_torch,
    spatial_torch,
)
# Module.cuda() works in place, so this also moves these submodules of
# torch_vxmdense to the GPU. The later cells that call torch_vxmdense with CUDA
# inputs depend on this side effect.
custom_torch_model = custom_torch_model.cuda()

In [None]:
def preprocess_torch(x, device='cuda'):
    """Convert a channels-last NumPy batch into a channels-first Torch tensor.

    Args:
        x: NumPy array laid out as (N, H, W, C).
        device: Torch device to place the tensor on. Defaults to 'cuda' (the
            previous hard-coded behavior); pass 'cpu' to run without a GPU.

    Returns:
        Tensor of shape (N, C, H, W) on ``device`` with the same dtype as ``x``.
    """
    return torch.from_numpy(x).to(device).permute(0, 3, 1, 2)

In [None]:
# Conform the moving/fixed images with the Torch-side helper and turn them into
# (N, C, H, W) tensors, moving first.
torch_moving, torch_fixed = (
    preprocess_torch(torch_conform(x=image, in_shape=(256, 256)))
    for image in (moving, fixed)
)

# Channel-wise concatenation, presumably the U-Net's two-channel input.
torch_unet_input = torch.cat([torch_moving, torch_fixed], dim=1)

### SSIM of TF vs Torch

In [None]:
# Window size passed to kornia.metrics.ssim when comparing TF and Torch results.
ssim_win_size = 13

In [None]:
# Run the Keras registration network and convert both outputs (moved image, flow)
# from channels-last (N, H, W, C) to Torch's channels-first (N, C, H, W).
keras_output = keras_vxmdense.predict((moving, fixed))
keras_source, keras_flow = (
    array.transpose(0, 3, 1, 2) for array in keras_output
)

In [None]:
# Compare the Torch port's moved image with the Keras one using kornia's SSIM
# (1.0 means identical). This is inference only: eval() matches Keras' predict(),
# and no_grad() avoids building an autograd graph.
torch_vxmdense.eval()
with torch.no_grad():
    torch_output = torch_vxmdense(torch_moving, torch_fixed)['y_source']
    keras_reference = torch.tensor(keras_source, device=torch_output.device)
    ssim_torch = torch.mean(
        metrics.ssim(torch_output, keras_reference, ssim_win_size)
    ).item()
print(f'ssim_torch = {ssim_torch}')

In [None]:
# Shape of the Torch moved image, (N, C, H, W).
torch_output.size()

In [None]:
# Cross-check the kornia result with scikit-image's SSIM on the same image pair.
keras_output = keras_source.squeeze()
torch_vxmdense.eval()
with torch.no_grad():
    torch_output = torch_vxmdense(torch_moving, torch_fixed)['y_source']
torch_output = torch_output.permute(0, 2, 3, 1).cpu().numpy()
torch_output = torch_output.squeeze()
print(keras_output.shape, torch_output.shape)
# For this single-image, single-channel batch both arrays are 2-D after squeeze(),
# so no channel axis is given: the old `multichannel=True` treated the image
# columns as channels (and the argument has been removed from scikit-image).
# The inputs were min-max normalized by conform(), so the range is stated
# explicitly instead of being inferred from the float dtype.
ssim_skimage = structural_similarity(torch_output, keras_output, data_range=1.0)
ssim_skimage

### Debugging the TF → Torch function ports

#### `interpn()`: neurite (TF) vs. Torch port (solved)

In [None]:
# Random inputs for comparing the neurite (TF) and Torch interpn implementations,
# drawn from a seeded generator so the comparison is reproducible.
# NOTE(review): standard-normal coordinates only probe the border region of the
# 128 x 128 grid; consider also sampling locations inside [0, 127].
rng = np.random.default_rng(0)
vol = rng.standard_normal((128, 128, 2))
loc = rng.standard_normal((128, 128, 2))

In [None]:
# Neurite (TF) reference: linear interpolation of `vol` at `loc` with
# fill_value=None, plus a leading batch axis.
keras_vol, keras_loc = tf.constant(vol), tf.constant(loc)
keras_interp = ne.utils.interpn(keras_vol, keras_loc, 'linear', None)[tf.newaxis, ...]
keras_interp.shape

In [None]:
# Torch port of interpn on the same inputs and settings, plus a leading batch axis.
torch_vol, torch_loc = torch.from_numpy(vol), torch.from_numpy(loc)
torch_interp = utils.interpn(torch_vol, torch_loc, 'linear', None).unsqueeze(0)
torch_interp.shape

In [None]:
# Compare the two interpn results. assert_allclose also checks that the shapes
# match (a bare `==` would broadcast), and it tolerates last-bit float differences
# between frameworks. The displayed value is the worst absolute deviation.
torch_interp_np = torch_interp.numpy()
keras_interp_np = keras_interp.numpy()
np.testing.assert_allclose(torch_interp_np, keras_interp_np)
np.abs(torch_interp_np - keras_interp_np).max()



In [None]:
# Move both interpolation results from (1, H, W, C) to (1, C, H, W) for kornia.
# NOTE(review): this rebinds both names, so running the cell twice permutes twice;
# re-run the two interpn cells first.
torch_interp = torch.permute(torch_interp, (0, 3, 1, 2))
keras_interp = torch.permute(torch.tensor(keras_interp.numpy()), (0, 3, 1, 2))

In [None]:
# Mean SSIM between the two interpolation results (1.0 for identical inputs).
ssim_map = metrics.ssim(torch_interp, keras_interp, ssim_win_size)
interp_ssim = ssim_map.mean()
interp_ssim