# TF to Torch Conversion

## Purpose
- Weight transfer
- Torch reproducibility
- Torch debugging

In [1]:
# Reload imported modules before each cell executes, so that edits to the
# local `synthmorph` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import torch
import tensorflow as tf
import gdown
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import neurite as ne
import voxelmorph as vxm

from torch import nn
from tensorflow.keras.models import Model
from skimage.metrics import structural_similarity

# local code
from synthmorph import layers, networks, datamodule as dm, utils

# Torch device used throughout the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu' # note: only gpu has been tested so far

# `force=True` keeps this cell re-runnable: without it, a second execution in
# the same kernel raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)

2024-01-12 01:01:16.815534: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-12 01:01:16.876039: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Let TensorFlow grow its GPU allocation on demand instead of
# taking the whole GPU memory up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
for physical_gpu in gpus:
    tf.config.experimental.set_memory_growth(physical_gpu, enable=True)

2024-01-12 01:01:23.061035: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-01-12 01:01:23.067815: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-01-12 01:01:23.068191: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysf

## Deformable Registration


### Transfer Weights TF -> Torch

In [None]:
# Define fresh Keras model, only for loading original author's weights
# This section is just a copy of the original demo to define the Keras model
# NOTE: keep the construction order below unchanged; the weight transfer later
# in this notebook pairs Keras and Torch weights purely by position.

# Label maps
in_shape = (256,) * 2    # 2D input shape
num_dim = len(in_shape)
num_label = 16
num_maps = 1
label_maps = []
for _ in range(num_maps):
    # Draw image and warp.
    # `im` has one channel per label; `warp` has `num_dim` components per label channel.
    im = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label),
        scales=(32, 64), max_std=1,
    )
    warp = ne.utils.augment.draw_perlin(
        out_shape=(*in_shape, num_label, num_dim),
        scales=(16, 32, 64), max_std=16,
    )

    # Transform and create label map.
    # The label at each pixel is the index of the largest warped channel.
    im = vxm.utils.transform(im, warp)
    lab = tf.argmax(im, axis=-1)
    label_maps.append(np.uint8(lab))

# Image generator
# Both generators share the same arguments and differ only in `id`.
gen_arg = dict(
    in_shape=in_shape,
    in_label_list=np.unique(label_maps),    # labels actually present in the drawn maps
    warp_std=3,
    warp_res=(8, 16, 32),
)
gen_model_1 = ne.models.labels_to_image(**gen_arg, id=1)
gen_model_2 = ne.models.labels_to_image(**gen_arg, id=2)

# Registration model.
# This is the network whose weights are transferred to Torch.
reg_model = vxm.networks.VxmDense(
    inshape=in_shape,
    int_resolution=2,
    svf_resolution=2,
    nb_unet_features=([256] * 4, [256] * 8),
    reg_field='warp',
)

# Model for optimization.
# Each generator yields an image and its label map.
ima_1, map_1 = gen_model_1.outputs
ima_2, map_2 = gen_model_2.outputs

# NOTE: `warp` is rebound here from the Perlin field above to the registration output.
_, warp = reg_model((ima_1, ima_2))
# Warp the first label map with the predicted field.
pred = vxm.layers.SpatialTransformer(fill_value=0)((map_1, warp))

# The combined model is what the pretrained weight file is loaded into.
inputs = gen_model_1.inputs + gen_model_2.inputs
out = (map_2, pred)
model = tf.keras.Model(inputs, out)

In [None]:
# Load Keras pretrained weights.
# The download is skipped when the file already exists, so re-running the
# notebook does not fetch the weights again.
weights_path = Path('weights.h5')
if not weights_path.exists():
    gdown.download('https://drive.google.com/uc?id=1xridvtyEWgWsWJPYVrQfDCtSgbj2beRz', str(weights_path))
model.load_weights(str(weights_path))

# Extract weights from the registration model only.
# Each entry maps the weight name to (values, dtype, shape).
keras_vxmdense = reg_model
keras_weights = {w.name: (w.numpy(), w.dtype, w.shape) for w in keras_vxmdense.weights}

In [None]:
# Keep only the kernel weights, i.e. every weight whose name does not contain 'bias'
keras_weights_keys = list(keras_weights)
keras_kernels = [name for name in keras_weights_keys if 'bias' not in name]

In [None]:
# Define fresh Torch model; its state dict supplies the keys that the
# transferred weights are stored under.
vol_size = (256, 256)
unet_enc_nf, unet_dec_nf = [256] * 4, [256] * 8
int_steps, int_downsize = 7, 2
bidir = False

torch_vxmdense = networks.VxmDense(
    inshape=vol_size,
    nb_unet_features=[unet_enc_nf, unet_dec_nf],
    int_steps=int_steps,
    int_downsize=int_downsize,
    bidir=bidir,
    unet_half_res=True,
)
torch_weights = torch_vxmdense.state_dict()

In [None]:
# Transfer the weights positionally: the layers are listed in the same order in
# both frameworks, so the i-th Keras weight maps onto the i-th Torch entry.
# `zip` would silently drop unmatched weights, so check the counts first.
assert len(keras_weights) == len(torch_weights), (
    f'Weight count mismatch: Keras has {len(keras_weights)}, '
    f'Torch has {len(torch_weights)}'
)

kernel_names = set(keras_kernels)
new_weights = {}
for k, t in zip(keras_weights.keys(), torch_weights.keys()):
    values = keras_weights[k][0]
    if k in kernel_names:
        # Move the last two kernel axes to the front (last -> 0, second-to-last -> 1),
        # converting the Keras kernel layout to the one Torch convolutions use.
        values = np.moveaxis(values, [-1, -2], [0, 1])
    tensor = torch.tensor(values, dtype=torch.float32)

    # Fail early, naming both keys, if the positional pairing is wrong.
    if tuple(tensor.shape) != tuple(torch_weights[t].shape):
        raise ValueError(
            f'Shape mismatch: Keras {k!r} {tuple(tensor.shape)} vs '
            f'Torch {t!r} {tuple(torch_weights[t].shape)}'
        )
    new_weights[t] = tensor

torch_vxmdense.load_state_dict(new_weights)
# torch.save(torch_vxmdense.state_dict(), Path(".") / 'authors.pth' )  # uncomment to save weights

### PyTorch Reimplementation Debug
Before running this section, make sure that both the TF and Torch models are using the same weights (i.e. run the weight-transfer cells above first).

#### Load test data

In [None]:
# Data preprocessing for TF
def tf_conform(x, in_shape=in_shape):
    '''Normalize an image and resample it to `in_shape`.

    The input is cast to float32, stripped of singleton axes, min-max
    normalized and zoomed so that each axis matches `in_shape`. The result
    gets a new leading and a new trailing axis of size one.
    '''
    image = np.squeeze(np.float32(x))
    image = ne.utils.minmax_norm(image)
    # Per-axis scale factor: target size over current size.
    factors = [target / current for target, current in zip(in_shape, image.shape)]
    resized = ne.utils.zoom(image, zoom_factor=factors)
    return np.expand_dims(resized, axis=(0, -1))


In [None]:
# Load MNIST in TF: take the last split returned by `load_data()` and
# pick two samples of the digit 6 as the moving / fixed pair.
images, digits = tf.keras.datasets.mnist.load_data()[1]
ind = np.flatnonzero(digits == 6)
moving, fixed = (tf_conform(images[ind[sample]]) for sample in (256, 22))

In [None]:
# Data preprocessing for Torch
def torch_conform(x, size):
    '''Run `dm.conform` on `x` with the given `size` and the notebook's `device`.'''
    return dm.conform(x, size, device)

torch_moving, torch_fixed = (torch_conform(image, (256, 256)) for image in (moving, fixed))

#### Custom models

In [None]:
# Build a Keras sub-model that shares the registration model's inputs and
# ends at a chosen layer (selected by index; here the last one).
custom_keras_layers = list(keras_vxmdense.layers)
chosen_keras_layer = custom_keras_layers[-1]    # output from chosen layer
custom_keras_model = Model(
    inputs=keras_vxmdense.inputs,
    outputs=chosen_keras_layer.output,
)

In [None]:
# Load Torch layers by specifying each module
# Note: SpatialTransformer is not compatible with nn.Sequential, hence separated
unet_torch = torch_vxmdense.unet_model
flow_torch = torch_vxmdense.flow
vecint_torch = torch_vxmdense.integrate
rescale_torch = torch_vxmdense.fullsize
spatial_torch = torch_vxmdense.transformer
custom_torch_model = nn.Sequential(
    unet_torch,
    flow_torch,
    vecint_torch,
    rescale_torch,
)
# Use the notebook-wide `device` rather than a hard-coded `.cuda()`, so the
# model lands on the same device as the conformed inputs and the cell does not
# fail on CPU-only machines.
custom_torch_model = custom_torch_model.eval().to(device)

#### SSIM of TF vs Torch

In [None]:
# Compare chosen output from TF and Torch registration models
keras_output = custom_keras_model.predict((moving, fixed))
# Move the last axis in front of the spatial axes and drop the batch axis,
# so the layout matches the Torch output below.
keras_output = keras_output.transpose(0, 3, 1, 2).squeeze(0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    torch_input = torch.cat([torch_moving, torch_fixed], dim=1)
    torch_output = custom_torch_model(torch_input)
    # Keep the next line only when comparing the moved image (i.e. whole reg model);
    # comment it out when comparing an intermediate output.
    torch_output = spatial_torch([torch_moving, torch_output])
torch_output = torch_output.squeeze(0).cpu().numpy()

data_range = torch_output.max() - torch_output.min()
# Treat axis 0 as channels only if a channel axis survived the squeeze.
channel_axis = 0 if torch_output.ndim > 2 else None
# `multichannel` is deprecated in favour of `channel_axis` and must not be
# passed alongside it, so only `channel_axis` is given here.
ssim_mean, ssim_full = structural_similarity(
    torch_output,
    keras_output,
    win_size=11,    # must be odd
    data_range=data_range,
    channel_axis=channel_axis,
    full=True,
)

In [None]:
# Show the full SSIM map of every channel side by side in a single row
num_plots = ssim_full.shape[0]
plot_scale = 6
fig, axs = plt.subplots(
    nrows=1,
    ncols=num_plots,
    figsize=(plot_scale * num_plots, plot_scale),
    squeeze=False,    # always get a 2D axes array, even for a single plot
)
fig.suptitle(f"SSIM Plot for each channel, Mean = {ssim_mean*100:.4f}")
for i, (ax, ssim_map) in enumerate(zip(axs[0], ssim_full)):
    im = ax.imshow(ssim_map, cmap='gray')
    ax.set_title(f'Channel {i+1}')
    # Hide ticks and labels on both axes.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

plt.tight_layout()
plt.show()



#### Inference time TF vs PyTorch

In [None]:
# Random benchmark inputs, seeded so both frameworks see the same data on every run.
rng = np.random.default_rng(seed=42)
width, height = 256, 256
channel = 1
batch_size = 1
shape = (batch_size, width, height, channel)
# Cast to float32 up front: `standard_normal` returns float64, which would give
# TF float64 tensors while Torch gets float32 and skew the timing comparison.
moving = rng.standard_normal(size=shape).astype(np.float32)
fixed = rng.standard_normal(size=shape).astype(np.float32)

# Torch input: move the last axis to position 1 and place the tensor on `device`.
prepare_torch = lambda x: torch.from_numpy(x).to(device, torch.float32).permute(0, -1, 1, 2)
torch_moving = prepare_torch(moving)
torch_fixed = prepare_torch(fixed)

tf_moving = tf.constant(moving)
tf_fixed = tf.constant(fixed)

In [None]:
%%timeit -n 100 -r 10 -p 4 
keras_vxmdense.predict((tf_moving, tf_fixed), verbose=None)

In [None]:
%%timeit -n 100 -r 10 -p 4
torch_vxmdense(torch_moving, torch_fixed)

## Affine Registration

## Debug Modules

In [4]:
# Seeded generator so the random inputs of the debug checks below are reproducible.
rng = np.random.default_rng(seed=42)

### TF to Torch functions

#### tf.gather(params, indices, axis, batch_dims)

In [12]:
params = rng.random((3, 3,))
ind = (1,)
axis = -1
tf_gather = tf.gather(params, ind, axis=axis)
# Torch counterpart for gathering along the last axis: index an actual Torch
# tensor with the list of indices after an ellipsis.
torch_gather = torch.from_numpy(params)[..., list(ind)]
# `array_equal` also compares shapes, which `np.all(a == b)` can hide through broadcasting.
np.array_equal(tf_gather.numpy(), torch_gather.numpy())

True