In [None]:
# install packages
! pip install torch
! pip install numpy
! pip install matplotlib
! pip install torchvision
! pip install voxelmorph

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['NEURITE_BACKEND'] = 'pytorch'
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm

## Registration Optimization

Registration optimization consists of three key components:

### 1. Transformation
- **Rigid**: Translation and rotation only.
- **Affine**: Includes scaling and shearing.
- **Deformable**: Allows complex, non-linear transformations.

### 2. Optimization
- **Gradient-based**: E.g., Gradient Descent, L-BFGS.

### 3. Cost Function
- **Mean Squared Error (MSE)**: Suitable for intensity-based similarity.
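- **Normalized Cross Correlation (NCC)**: Suitable for same-modality images whose brightness or contrast differ, because it normalizes out differences in mean and variance.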
- **Mutual Information (MI)**: Suitable for multi-modality images.


First, generate a digital phantom of the heart in short-axis (SAX) view.

In [None]:
# Image geometry: size as (rows, cols) and the heart centre as (row, col)
image_size = (40, 64)
center = [20, 32]

# Radii in pixels (adjust as needed)
r_lv = [6, 12]  # LV: [endocardium, epicardium]
r_rv = [16, 20, 9, 12]  # RV ellipse semi-axes: (endo_a, epi_a, endo_b, epi_b)

# Open coordinate grid with its origin at `center`:
# y varies along the rows (axis 0), x along the columns (axis 1)
y, x = np.ogrid[-center[0]:image_size[0] - center[0], -center[1]:image_size[1] - center[1]]

# LV: two concentric discs; the myocardium is the ring between them
radius_sq = x**2 + y**2
mask_lv_endo = radius_sq <= r_lv[0]**2  # Endocardium
mask_lv_epi = radius_sq <= r_lv[1]**2  # Epicardium
lv_mask = (mask_lv_epi & ~mask_lv_endo).astype(float)

# RV: two concentric ellipses, again keeping only the wall between them
mask_rv_endo = ((x**2) / r_rv[0]**2 + (y**2) / r_rv[2]**2) <= 1  # Endocardium
mask_rv_epi = ((x**2) / r_rv[1]**2 + (y**2) / r_rv[3]**2) <= 1  # Epicardium
rv_mask = (mask_rv_epi & ~mask_rv_endo).astype(float)
rv_mask[:, center[1]:] = 0  # Keep RV on the left half

# Union of both walls; overlapping pixels are capped at 1 so the mask stays binary
heart_mask = np.minimum(lv_mask + rv_mask, 1)

# Display the result
plt.imshow(heart_mask, cmap='gray')
plt.title("I am a binary heart image in short axis view.")
plt.axis("off")
plt.show()

# Transformation

To perform a subpixel translation, we shift the pixel coordinates and then interpolate the image at the shifted (non-integer) positions.

Here is an implementation using nearest-neighbour interpolation.

In [None]:
# Sub-pixel shift (shift_x acts on axis 0, shift_y on axis 1) and the value
# used wherever the shifted position falls outside the image
shift_x = 2.4
shift_y = -7.6
padding_value = 0

n_rows, n_cols = heart_mask.shape
translated_values = np.zeros((n_rows, n_cols))

# Backward mapping: for every output pixel, look up the source pixel nearest
# to the shifted position and copy its value (or pad when it lies outside).
for dst_row, dst_col in np.ndindex(n_rows, n_cols):
    src_row = int(np.round(dst_row + shift_x))
    src_col = int(np.round(dst_col + shift_y))
    inside = 0 <= src_row < n_rows and 0 <= src_col < n_cols
    translated_values[dst_row, dst_col] = heart_mask[src_row, src_col] if inside else padding_value

# Show the original and the translated image side by side
plt.figure(figsize=(12, 6))
panels = [(heart_mask, 'Original Heart'), (translated_values, 'Translated Heart')]
for position, (image, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(image,
               cmap=plt.get_cmap('gray'),
               aspect='equal',
               origin='lower',
               interpolation='nearest')
    plt.title(title)
plt.show()

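Similarly, we can implement (bi)linear interpolation, which blends the four pixels surrounding each shifted position, weighted by their distance to it.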

In [None]:
# Bilinear interpolation of `heart_mask` shifted by (shift_x, shift_y).
#
# Start from a fresh, zero-initialised accumulator. Re-using the array filled
# by the nearest-neighbour cell would add the bilinear result on top of the
# old one (and re-running this cell would keep accumulating).
n_rows, n_cols = heart_mask.shape
translated_values = np.zeros(heart_mask.shape)

for new_index_x in range(n_rows):
    for new_index_y in range(n_cols):
        # Non-integer source position of this output pixel
        source_x = new_index_x + shift_x
        source_y = new_index_y + shift_y
        # Top-left corner of the 2x2 neighbourhood and the fractional offset
        # (in [0, 1)) of the source position inside it
        previous_index_x = int(np.floor(source_x))
        previous_index_y = int(np.floor(source_y))
        relative_x = source_x - previous_index_x
        relative_y = source_y - previous_index_y

        interpolated = 0.0
        for a, b in [(0, 0), (0, 1), (1, 0), (1, 1)]:
            # a/b = 0 selects the lower neighbour (weight 1 - relative),
            # a/b = 1 selects the upper neighbour (weight relative)
            weight = abs(1 - a - relative_x) * abs(1 - b - relative_y)
            neighbour_x = previous_index_x + a
            neighbour_y = previous_index_y + b
            if 0 <= neighbour_x < n_rows and 0 <= neighbour_y < n_cols:
                interpolated += weight * heart_mask[neighbour_x, neighbour_y]
            else:
                # Out-of-image neighbours contribute the padding value.
                # This must be accumulated, not assigned, so contributions
                # from in-image neighbours are not discarded.
                interpolated += weight * padding_value
        translated_values[new_index_x, new_index_y] = interpolated

# show the image
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(heart_mask,
           cmap=plt.get_cmap('gray'),
           aspect='equal',
           origin='lower',
           interpolation='nearest')
plt.title('I am the original heart.')
plt.subplot(1, 2, 2)
plt.imshow(translated_values,
           cmap=plt.get_cmap('gray'),
           aspect='equal',
           origin='lower',
           interpolation='nearest')
plt.title('I am a translated heart.')
plt.show()

### 2D Affine Transformation

A 2D affine transformation can be represented using a $3 \times 3$ matrix to describe transformations like translation, rotation, scaling, shearing, and combinations thereof.

The general affine transformation matrix $ \mathbf{A} $ is:

$$
\mathbf{A} =
\begin{bmatrix}
a_{11} & a_{12} & t_x \\
a_{21} & a_{22} & t_y \\
0      & 0      & 1
\end{bmatrix}
$$

#### Parameters:
1. **Translation** ($t_x, t_y$): Moves the image along the x-axis and y-axis.
2. **Rotation and Scaling** ($a_{11}, a_{12}, a_{21}, a_{22}$):
   - The diagonal elements $a_{11}, a_{22}$ alone give axis-aligned scaling; a rotation involves all four elements (see the special cases below).
3. **Shearing** ($a_{12}, a_{21}$):
   - These off-diagonal elements introduce shearing along the x or y axes.

#### Expanded Equation:
For a point $(x, y)$ in the original image, the transformed coordinates $(x', y')$ are computed as:

$$
\begin{bmatrix}
x' \\
y' \\
1
\end{bmatrix}
=
\begin{bmatrix}
a_{11} & a_{12} & t_x \\
a_{21} & a_{22} & t_y \\
0      & 0      & 1
\end{bmatrix}
\begin{bmatrix}
x \\
y \\
1
\end{bmatrix}
$$

This expands to:

$$
x' = a_{11}x + a_{12}y + t_x
$$
$$
y' = a_{21}x + a_{22}y + t_y
$$

### Special Cases:
- **Pure Translation**:
  $$
  \mathbf{A} =
  \begin{bmatrix}
  1 & 0 & t_x \\
  0 & 1 & t_y \\
  0 & 0 & 1
  \end{bmatrix}
  $$

- **Pure Rotation** (by angle $ \theta $):
  $$
  \mathbf{A} =
  \begin{bmatrix}
  \cos\theta & -\sin\theta & 0 \\
  \sin\theta &  \cos\theta & 0 \\
  0          &  0          & 1
  \end{bmatrix}
  $$

- **Scaling**:
  $$
  \mathbf{A} =
  \begin{bmatrix}
  s_x & 0   & 0 \\
  0   & s_y & 0 \\
  0   & 0   & 1
  \end{bmatrix}
  $$

- **Shearing**:
  $$
  \mathbf{A} =
  \begin{bmatrix}
  1 & k_x & 0 \\
  k_y & 1 & 0 \\
  0   & 0 & 1
  \end{bmatrix}
  $$

## Exercise 1: Applying a 30-Degree Clockwise Rotation to the Heart Image

In this exercise, you will revise the affine matrix to apply a **30-degree rotation** about the center of the original heart image. The transformation will use **linear interpolation** to ensure smooth results.

### Instructions

1. Modify the affine transformation matrix to include a rotation by **30 degrees**.
2. Ensure the rotation is centered on the image to maintain alignment.
3. Use **linear interpolation** for warping the image.

### Key Points
- Rotation matrix for a 30-degree rotation is defined as:
  $$
  R =
  \begin{bmatrix}
  \cos\theta & -\sin\theta & 0 \\
  \sin\theta & \cos\theta & 0
  \end{bmatrix}
  $$
  Where $\theta = 30^\circ$.
- Ensure the affine matrix includes translation to keep the rotation centered.


### Interpolation and Transformation Customization

You can also experiment with the **shearing** and **translation** parameters, and compare different interpolation methods such as:

- **Bilinear**
- **Cubic**

### Useful Packages for Interpolation and Transformation

The process of interpolation and transformation has been streamlined into convenient packages. Here are some you can explore:

- **`scipy.ndimage`**: Use `affine_transform` for affine transformations.
- **`torch.nn.functional`**: Combine `affine_grid` and `grid_sample` for efficient grid-based sampling.

### Important Notes
- Be **cautious about indexing**, as it may vary between different packages and implementations.


## Optimization and Loss Function

Pairwise registration typically involves a pair of images: the **moving image** and the **fixed image**. The goal is to find the optimal transformation parameters that align the moving image to the fixed image by minimizing an objective loss function.

Below is a custom implementation of an affine registration process:




In [6]:
def affine_register(moving_img,
                    fixed_img,
                    lr=1E-5,
                    epochs=1000,
                    device='cpu',
                    criterions=None, # one of nn.MSELoss(), NCCLoss, NMILoss()
                    ):
  """Register `moving_img` to `fixed_img` by optimising a 2D affine matrix.

  The six affine parameters start at the identity and are updated with Adam
  to minimise `criterions(moved_img, fixed_img)`. The loss curve is plotted
  once the optimisation has finished.

  Args:
    moving_img: 4-D tensor (N, C, H, W) to be warped. A single affine matrix
      (batch size 1) is optimised, so N is expected to be 1.
    fixed_img: reference tensor, passed to the loss together with the warped
      moving image.
    lr: Adam learning rate.
    epochs: number of optimisation steps.
    device: device on which the parameters and both images are placed.
    criterions: callable `loss = criterions(moved, fixed)` returning a scalar
      tensor. Defaults to `nn.MSELoss()` when None.

  Returns:
    (moved_img, moved_affine): the moving image warped with the final
    parameters, and the detached (2, 3) affine matrix. The matrix is the
    `theta` consumed by `F.affine_grid`, i.e. it is expressed in that
    function's coordinate convention rather than in pixels.
  """
  # Fall back to MSE so the documented default `criterions=None` is usable
  # instead of failing with "'NoneType' object is not callable".
  if criterions is None:
      criterions = nn.MSELoss()

  # Honour `device`: images and parameters must live on the same device.
  moving_img = moving_img.to(device)
  fixed_img = fixed_img.to(device)

  # params initialization (identity transform)
  params = torch.tensor([1.0, 0.0, 0.0,  # a, b, tx
                         0.0, 1.0, 0.0],  # c, d, ty
                        device=device,
                        requires_grad=True)

  # Optimizer
  optimizer = torch.optim.Adam([params], lr=lr)

  losses = []
  # Optimization loop
  for iteration in range(epochs):
      # Create affine matrix
      affine_matrix = params.view(2, 3).unsqueeze(0)  # Shape: (1, 2, 3)

      # Warp the moving image
      moved_grid = F.affine_grid(affine_matrix, moving_img.size(), align_corners=True)
      moved_img = F.grid_sample(moving_img, moved_grid, align_corners=True)

      # Compute similarity loss (e.g., Mean Squared Error)
      loss = criterions(moved_img, fixed_img)

      # Backpropagation
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      losses.append(loss.item())

  # plot the error.
  plt.plot(losses, label='Error')
  plt.title('Optimization Criterion')
  plt.xlabel('Epoch')
  plt.ylabel('Error')
  plt.legend()
  plt.show()

  # Final transformation matrix
  moved_affine = params.view(2, 3).detach()

  # Warp the moving image with the final transformation; no gradients are
  # needed for the returned result.
  with torch.no_grad():
      moved_grid = F.affine_grid(moved_affine.unsqueeze(0), moving_img.size(), align_corners=True)
      moved_img = F.grid_sample(moving_img, moved_grid, align_corners=True)
  return moved_img, moved_affine

In [7]:
# Run the registration on the CPU
device = 'cpu'

# Convert both NumPy images to float32 tensors and prepend the batch and
# channel axes, giving shape (1, 1, H, W)
fixed_img = torch.as_tensor(heart_mask, dtype=torch.float32)[None, None].to(device)
moving_img = torch.as_tensor(translated_values, dtype=torch.float32)[None, None].to(device)

## Exercise 2: Defining and Using MSE Loss

In this exercise, we define the **Mean Squared Error (MSE)** as the loss function and use the `affine_register` function to optimize the transformation parameters that align the moving image to the fixed image.

The **Mean Squared Error (MSE)** measures the average squared difference between corresponding pixels of two images. The formula for MSE is given by:

$$
\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( \text{moved}_i - \text{fixed}_i \right)^2
$$

Where:
- $\text{moved}_i$: Pixel intensity of the moving image after transformation.
- $\text{fixed}_i$: Pixel intensity of the fixed reference image.
- $N$: Total number of pixels.


## Exercise 3: Revising to Rigid Transformation

In this exercise, we revise the `affine_register` function to create a `rigid_register` function. For **rigid transformations**, the transformation is limited to **rotation** and **translation**. This means scaling and shearing are not allowed, reducing the number of free parameters to **three**: $ \theta, t_x, t_y $.

### Definition of Rigid Transformation

The rigid transformation for 2D images is represented by the following equation:

$$
T_{\text{rigid}} =
\begin{bmatrix}
\cos\theta & -\sin\theta & t_x \\
\sin\theta & \cos\theta & t_y
\end{bmatrix}
$$

Where:
- $ \theta $: The rotation angle.
- $ t_x, t_y $: The translation in the x and y directions.


## Exercise 4: Defining NCC Loss and Using `rigid_register`

In this exercise, we define the **Normalized Cross-Correlation (NCC)** as the loss function and use the `rigid_register` function to optimize the transformation parameters. If the `rigid_register` function fails, you can alternatively refer back to the `affine_register` function.

### What is NCC?

The **Normalized Cross-Correlation (NCC)** measures the similarity between two images, accounting for differences in mean and variance. NCC is a commonly used similarity metric in image registration, especially for intensity-based methods.

The formula for NCC is:

$$
\text{NCC} = \frac{\text{Covariance}(moved, fix)}{\sqrt{\text{Var}(moved) \cdot \text{Var}(fix)} + \epsilon}
$$

Where:
- $ \text{Covariance}(moved, fix) $: Measures the relationship between the pixel intensities of the moving and fixed images.
- $ \text{Var}(moved) $: Variance of the moving image.
- $ \text{Var}(fix) $: Variance of the fixed image.
- $ \epsilon $: A small constant to avoid division by zero.

## Exercise 5: Define Mutual Information (MI) and Using `affine_register` on real-world Cine and T2 Cardiac Magnetic Resonance (CMR) images



In this exercise, we define the **Mutual Information (MI)** as the loss function and use the `affine_register` function to optimize the transformation parameters on real cardiac magnetic resonance images (cine and T2).

We now transition from **binary CMR images** to **real-world Cine and T2 CMR images**, adopted from the public **MSCMR dataset**. Despite the use of motion control techniques such as **breath-holding** and **cardiac triggering**, shifts remain a common occurrence, necessitating robust image registration techniques.

Unlike the simplified digital phantoms we previously created, these real-world images feature **complex backgrounds**, introducing significant challenges to the registration process. These complexities highlight the need for advanced and adaptable registration frameworks capable of handling variability and noise in real-collected datasets.

---

### What is Mutual Information?

**Mutual Information (MI)** measures the amount of information shared between two images. It is defined as:

$$
I(X; Y) = \sum_{x \in X} \sum_{y \in Y} p(x, y) \log \frac{p(x, y)}{p(x) p(y) + \epsilon}
$$

Where:
- $X$ and $Y$: Intensity values of the moving and fixed images, respectively.
- $p(x, y)$: Joint probability distribution of $X$ and $Y$.
- $p(x)$ and $p(y)$: Marginal probability distributions of $X$ and $Y$.
- $\epsilon$: A small constant to avoid division by zero.

---

### Why Use MI?

- **Robust to Intensity Differences**: MI is ideal for registering images with different intensity distributions (e.g., multi-modal MRI or CT images).
- **Nonlinear Relationships**: Unlike Mean Squared Error or NCC, MI can capture complex, nonlinear relationships between pixel intensities.

---

### Implementation Steps
0. Remember to upload the 'realSCMR.npz' to your Google Drive.
1. Compute the joint histogram of the moving and fixed images.
2. Normalize the histogram to get the joint probability distribution, $p(x, y)$.
3. Calculate the marginal distributions, $p(x)$ and $p(y)$.
4. Compute the MI using the formula above.

---

In [None]:
# Load the real CMR example archive; the path is relative to the current
# working directory, so 'realSCMR.npz' must be present there.
MSCMR = np.load('realSCMR.npz')
# List the arrays stored in the archive (the next cell reads 'C0', 'T2'
# and 'T2_transformed')
print(MSCMR.keys())

In [None]:
# Plot the three images side by side: cine, T2 and the transformed T2
panels = [('C0', 'Cine'), ('T2', 'T2'), ('T2_transformed', 'T2 Transformed')]
for position, (key, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(MSCMR[key], cmap='gray')
    plt.title(title)
# Render explicitly, like the other plotting cells, so the cell does not end
# on a bare expression whose repr is echoed below the figure
plt.show()

# Deep-Learning Baseline: VoxelMorph with PyTorch

VoxelMorph is a deep-learning-based framework for image registration. It uses a U-Net-like architecture to predict dense deformation fields for aligning images, making it faster and more scalable compared to traditional methods. Below, we demonstrate its application using PyTorch and the MNIST dataset as a simplified example.

---

## Key Features of VoxelMorph

1. **End-to-End Learning**:
   - Learns a dense deformation field directly from the data in a single forward pass.
2. **CNN-Based**:
   - Employs a U-Net to predict the deformation field.
3. **Loss Function**:
   - Combines similarity loss (e.g., Mean Squared Error) with a regularization term to enforce smooth deformations.
4. **Applications**:
   - Medical imaging (e.g., MRI, CT scans), computer vision tasks requiring alignment.

---

## VoxelMorph Workflow

1. **Input**: A pair of images (e.g., a moving image and a fixed image).
2. **Network**: A U-Net predicts a dense deformation field.
3. **Transformation**: The deformation field is applied to the moving image to align it with the fixed image.
4. **Loss Function**:
   - Similarity Loss: Ensures alignment of the two images.
   - Regularization Loss: Encourages smoothness in the deformation field.
5. **Output**: The warped image and the deformation field.

---

## Implementation with PyTorch

Below is a PyTorch implementation of VoxelMorph using the MNIST dataset:

### Data Preparation

We use pairs of MNIST digits as an illustrative example for image registration.


In [None]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [None]:
# Download (if necessary) and open the MNIST training split
train_data = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Select every training image labelled as the digit 5
is_five = train_data.targets == 5
all_fives = train_data.data[is_five]
print(all_fives.shape)

# Rescale intensities to [0, 1] and keep only the first 1000 samples
# for the experiments
digit_5 = (all_fives.float() / 255)[:1000]

In [None]:
# Hold-out sizes; a small number of samples keeps training fast
nb_val = 200
nb_tst = 100

# The last `nb_val + nb_tst` samples are held out:
# validation first, then the final `nb_tst` samples for testing
nb_holdout = nb_val + nb_tst
x_trn, x_val, x_tst = (
    digit_5[:-nb_holdout],
    digit_5[-nb_holdout:-nb_tst],
    digit_5[-nb_tst:],
)

In [None]:
# Zero-pad 2 pixels on each side of both spatial axes; the sample axis is
# left untouched. np.pad returns NumPy arrays.
# NOTE: the splits are overwritten in place, so re-running this cell pads again.
pad_amount = ((0, 0), (2, 2), (2, 2))
x_trn, x_val, x_tst = (
    np.pad(split, pad_amount, 'constant') for split in (x_trn, x_val, x_tst)
)

# verify
print('shape of training data', x_trn.shape)

In [None]:
# Spatial shape of a single sample (everything after the sample axis)
ndim = 2
inshape = x_trn.shape[1:]

# U-Net feature counts, passed to the model as [encoder, decoder]
enc_nf = [32] * 3
dec_nf = [32] * 4 + [16]

# Dense (deformable) VoxelMorph network
model = vxm.networks.VxmDense(inshape=inshape,
                              nb_unet_features=[enc_nf, dec_nf],
                              bidir=False,
                              int_steps=0,
                              int_downsize=2)

In [None]:
def vxm_data_generator(x_data, batch_size=1):
    """
    Endless generator of random (moving, fixed) image pairs for the vxm model.

    x_data is an array of size [N, H, W]. Every batch draws `batch_size`
    moving images and, independently, `batch_size` fixed images (with
    replacement) and appends a trailing channel axis to both.

    yields (inputs, outputs) where
      inputs:  [moving [bs, H, W, 1], fixed [bs, H, W, 1]]
      outputs: [fixed [bs, H, W, 1], zeros [bs, H, W, 2]]
    """
    n_samples = x_data.shape[0]
    spatial_shape = x_data.shape[1:]

    # All-zero target for the deformation output: one component per spatial
    # dimension. The same array object is yielded with every batch.
    zero_phi = np.zeros((batch_size, *spatial_shape, len(spatial_shape)))

    def _random_batch():
        # Sample indices with replacement and add the channel axis
        picks = np.random.randint(0, n_samples, size=batch_size)
        return x_data[picks, ..., np.newaxis]

    while True:
        moving_images = _random_batch()
        fixed_images = _random_batch()

        # There is no ground-truth moved image; the moved image is compared
        # against the fixed image, and the deformation against zero_phi.
        yield ([moving_images, fixed_images], [fixed_images, zero_phi])

In [None]:
# let's train
# Endless batch generator over the training split (default batch_size=1)
train_generator = vxm_data_generator(x_trn)
# Draw one batch: in_sample is [moving, fixed], out_sample is [fixed, zero_phi]
in_sample, out_sample = next(train_generator)

Define the loss, configs and start training

In [None]:
# Loss functions, applied in order to the corresponding model outputs.
# Only one term is listed, so only the first output (compared against the
# fixed image) is penalised.
losses = [nn.MSELoss()]
# Multiplier for each entry of `losses`
weights = [1]
# Number of passes of the outer training loop
epochs = 200
# Learning rate handed to the Adam optimizer
learning_rate = 1E-3

In [None]:
# Move the network to the target device, switch to training mode and set up Adam
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

steps_per_epoch = 5
total_loss = []
for epoch in range(epochs):

    epoch_loss, epoch_total_loss, epoch_step_time = [], [], []

    for step in range(steps_per_epoch):

        # Draw a batch and convert the channels-last NumPy arrays into
        # channels-first float tensors on the target device
        batch_inputs, batch_targets = next(train_generator)
        inputs = [torch.from_numpy(arr).to(device).float().permute(0, 3, 1, 2)
                  for arr in batch_inputs]
        y_true = [torch.from_numpy(arr).to(device).float().permute(0, 3, 1, 2)
                  for arr in batch_targets]

        # Forward pass through the model
        y_pred = model(*inputs)

        # One weighted term per configured loss; the total is their sum
        weighted_terms = [criterion(y_true[n], y_pred[n]) * weights[n]
                          for n, criterion in enumerate(losses)]
        loss_list = [term.item() for term in weighted_terms]
        loss = sum(weighted_terms)

        epoch_loss.append(loss_list)
        epoch_total_loss.append(loss.item())

        # Backpropagate and take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Track the mean total loss of this epoch
    total_loss.append(np.mean(epoch_total_loss))

In [None]:
# Plot the mean training loss recorded for each epoch
fig, ax = plt.subplots()
ax.plot(total_loss, label='Error')
ax.set_title('Optimization Criterion')
ax.set_xlabel('Epoch')
ax.set_ylabel('Error')
ax.legend()
plt.show()

In [None]:
# Evaluation
model.eval()

# Draw one random (moving, fixed) pair from the validation split
val_generator = vxm_data_generator(x_val)
in_sample, out_sample = next(val_generator)

# Convert inputs and true outputs to channels-first float tensors
inputs = [torch.from_numpy(d).to(device).float().permute(0, 3, 1, 2) for d in in_sample]
y_true = [torch.from_numpy(d).to(device).float().permute(0, 3, 1, 2) for d in out_sample]

# Forward pass through the model to generate predictions
with torch.no_grad():  # Disable gradients for evaluation
    y_pred = model(*inputs)

# Plot moving / moved / fixed side by side.
# The moving image is indexed down to 2-D (first sample, channel axis
# dropped) so it has the same shape as the other two panels. The default
# image origin (top-left) is used: origin='lower' flips the rows and would
# draw the digits upside-down.
panels = [
    ('moving', in_sample[0][0, ..., 0]),
    ('moved', y_pred[0].squeeze().cpu().numpy()),
    ('fixed', y_true[0].squeeze().cpu().numpy()),
]
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(image, cmap=plt.get_cmap('gray'), aspect='equal',
               interpolation='nearest')
    plt.title(title)
# Render explicitly, like the other plotting cells, so the cell does not end
# on a bare expression whose repr is echoed below the figure
plt.show()

## Exercise 7 (Challenging): Train and Validate VoxelMorph with the MedMNIST Dataset

In this exercise, we will tackle the challenge of medical imaging registration using the **MedMNIST** dataset. MedMNIST is a large-scale, lightweight benchmark designed for biomedical image classification and segmentation tasks. Specifically, we will train and validate a **VoxelMorph** model to align medical images, leveraging the diverse subsets and labels provided by MedMNIST. This hands-on exercise combines state-of-the-art registration techniques with the versatility of the MedMNIST dataset to deepen your understanding of medical image analysis.