<div align="center">
    <h1>
Image Registration of In Vivo Micro-Ultrasound and Ex Vivo Pseudo-Whole Mount Histopathology Images of the Prostate: A Proof-of-Concept Study
    </h1>
</div>

**This notebook walks you through the steps required to train a deep learning architecture for registering micro-US and histopathological images**


## Contents

1. [Imports](#imports)
2. [Parameters](#parameters)
3. [Training Affine Registration Network](#affinenetwork)
4. [Inference for Affine Registration [Optional]](#inference_affine_registration)
5. [Training Deformable Registration Network](#deformablenetwork)
6. [Inference for Deformable Registration [Optional]](#inference_deformable_registration)
    

---

## 1. Imports <a id="imports"></a>

Let's import all the necessary packages.

---

In [None]:
import os
import glob
import csv
import torch

import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F


from PIL import Image
from torchvision import transforms  # Import for image transformations
from torch.utils.data import Dataset, DataLoader
from torchmetrics.functional.image import image_gradients
from tqdm import tqdm  # Import tqdm for progress bar
from torchvision.utils import save_image

# User Defined Utilities
from utils.Dataset import ImageRegistrationDataset
from utils.AffineRegistrationModel import AffineNet
from utils.DeformableRegistrationNetwork import DeformRegNet
from utils.SpatialTransformation import SpatialTransformer
from utils.miscellaneous import apply_affine_transformation, smoothness, loss_function, ssd_loss

---

Uncomment the following cell to print the version of each module.

---

In [None]:
# import cv2, torchvision, PIL, scipy, matplotlib
# from platform import python_version

# print(f"Python Version: {python_version()}")
# print(f"Open CV Version: {cv2.__version__}")
# print(f"Numpy Version: {np.__version__}")
# print(f"PIL Version: {PIL.__version__}")
# print(f"Matplotlib Version: {matplotlib.__version__}")
# print(f"Torchvision Version: {torchvision.__version__}")
# print(f"Scipy Version: {scipy.__version__}")



In [None]:
# If using MIQ_Kernel, we have the following versions:
#     Python Version: 3.10.14
#     Open CV Version: 4.10.0
#     Numpy Version: 1.26.4
#     PIL Version: 10.3.0
#     Matplotlib Version: 3.8.3
#     Torchvision Version: 0.15.2
#     Scipy Version: 1.12.0
#     CSV Version 1.0

---

[OPTIONAL] Let's visualize the dataset by displaying one micro-US image, one histopathology image, and their associated masks.

---

In [None]:
# dataset_csv_path_file = "../data/processed_png_data/Training_Label_Paths_For_Fold1.csv"
# dataset = ImageRegistrationDataset(dataset_csv_path_file)
# micro_image, micro_mask, hist_image, hist_mask = dataset[3]
# import matplotlib.pyplot as plt
# plt.figure(figsize=(12,5))
# plt.subplot(141);plt.imshow(micro_image.permute(1,2,0));plt.title("Micro Image");plt.axis('off');
# plt.subplot(142);plt.imshow(micro_mask.permute(1,2,0));plt.title("Micro Mask");plt.axis('off');
# plt.subplot(143);plt.imshow(hist_image.permute(1,2,0));plt.title("Hist Image");plt.axis('off');
# plt.subplot(144);plt.imshow(hist_mask.permute(1,2,0));plt.title("Hist Mask");plt.axis('off');
# plt.tight_layout()
# plt.show();
# print(f"Micro Image: Min: {torch.min(micro_image)} | Max: {torch.max(micro_image)} | Shape: {micro_image.shape} | Unique: {len(torch.unique(micro_image))} | DType: {micro_image.dtype}")
# print(f"Micro Mask: Min: {torch.min(micro_mask)} | Max: {torch.max(micro_mask)} | Shape: {micro_mask.shape} | Unique: {len(torch.unique(micro_mask))} | DType: {micro_mask.dtype}")
# print(f"Hist Image: Min: {torch.min(hist_image)} | Max: {torch.max(hist_image)} | Shape: {hist_image.shape} | Unique: {len(torch.unique(hist_image))} | DType: {hist_image.dtype}")
# print(f"Hist Mask: Min: {torch.min(hist_mask)} | Max: {torch.max(hist_mask)} | Shape: {hist_mask.shape} | Unique: {len(torch.unique(hist_mask))} | DType: {hist_mask.dtype}")

---

## 2. Parameters <a id="parameters"></a>

Let's define all the parameters

---

In [None]:
# Global configuration shared by every training / inference cell below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_of_folds = 6  # number of cross-validation folds; CSVs are named ...Fold1.csv .. ...Fold6.csv
saved_model_dir = "./saved_models/"  # location where the trained models will be saved
results_dir = "./results/"  # directory where the results will be saved
data_dir = "../data/processed_png_data/"
batch_size = 1

# Seed the RNGs so network weight initialisation and DataLoader shuffling are
# reproducible when the notebook is re-run from a fresh kernel.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

os.makedirs(saved_model_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

---

<div align="center">
    
## 3. Training Affine Registration Network <a id="affinenetwork"></a>

</div>

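In this section we define and train the affine registration network. It predicts an affine transform that aligns the moving (histopathology) mask with the fixed (micro-US) mask.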

---
    
### 3.1. Defining the Affine Registration Network
    

In [None]:
# Instantiate the affine registration network and report how many trainable weights it has.
affine_model = AffineNet().to(device)
trainable_params = 0
for param in affine_model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f"affine model trainable params: {trainable_params}")

---

### 3.2. Training the Affine Registration Network

---

In [None]:
def train_affine(affine_model, train_loader, val_loader, optimizer_affine, criterion_affine, device, path_to_save_the_model, num_epochs=10):
    """Train the affine registration network on mask pairs and checkpoint the best epoch.

    For every batch the network predicts affine parameters ``theta`` from the
    ``(fixed_mask, moving_mask)`` pair, the moving mask is warped with
    ``apply_affine_transformation`` and ``criterion_affine(fixed_mask, deformed_mask)``
    is minimised. After each epoch the mean training loss is compared with the best
    seen so far; if it is lower, the model's ``state_dict`` is saved.

    Args:
        affine_model: network called as ``affine_model((fixed_mask, moving_mask))``.
        train_loader: iterable yielding ``(fixed_image, fixed_mask, moving_image, moving_mask)``.
        val_loader: currently unused. NOTE(review): in this notebook it wraps the
            per-fold *testing* CSV, so it is intentionally not used for checkpoint selection.
        optimizer_affine: optimizer over ``affine_model``'s parameters.
        criterion_affine: loss called as ``criterion_affine(fixed_mask, deformed_mask)``.
        device: device the masks are moved to.
        path_to_save_the_model: file path for the best checkpoint.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        The lowest mean epoch training loss (``float('inf')`` if no batch was seen).
    """
    min_loss = float('inf')  # best mean epoch loss so far
    # Make sure layers such as dropout / batch-norm (if AffineNet has any) run in training mode.
    affine_model.train()
    for epoch in range(num_epochs):
        train_loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", dynamic_ncols=True)
        epoch_loss_sum = 0.0
        num_batches = 0

        for fixed_image, fixed_mask, moving_image, moving_mask in train_loader_iter:
            ############# AFFINE REGISTRATION ###############################
            optimizer_affine.zero_grad()
            fixed_mask = fixed_mask.to(device)
            moving_mask = moving_mask.to(device)
            theta = affine_model((fixed_mask, moving_mask))
            # Warp the moving mask with the predicted transform and compare with the fixed mask.
            deformed_mask = apply_affine_transformation(moving_mask, theta)
            loss_affine = criterion_affine(fixed_mask, deformed_mask)
            loss_affine.backward()
            optimizer_affine.step()

            batch_loss = loss_affine.item()
            epoch_loss_sum += batch_loss
            num_batches += 1
            train_loader_iter.set_postfix(loss=batch_loss)

        if num_batches == 0:
            # Empty loader: nothing was trained, so there is nothing to checkpoint.
            continue

        # Select checkpoints on the mean loss over the whole epoch rather than on the
        # last batch alone, which (with batch_size=1) is a single arbitrary sample.
        epoch_loss = epoch_loss_sum / num_batches
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            print("Saving model with improved loss:", min_loss)
            torch.save(affine_model.state_dict(), path_to_save_the_model)

    return min_loss


---


### 3.3. Initializing Training Process for Each Fold


---

In [None]:
# Train one affine registration network per cross-validation fold.
for fold_index in range(num_of_folds):
    fold_number = fold_index + 1  # CSV and checkpoint names are numbered from 1
    print(f"Processing Fold {fold_number} ...")

    # Per-fold train / test splits.
    train_csv_path_file = os.path.join(data_dir, f"Training_Label_Paths_For_Fold{fold_number}.csv")
    test_csv_path_file = os.path.join(data_dir, f"Testing_Label_Paths_For_Fold{fold_number}.csv")
    train_dataset = ImageRegistrationDataset(train_csv_path_file)
    test_dataset = ImageRegistrationDataset(test_csv_path_file)
    print(f"Train Dataset: {len(train_dataset)} | Test Dataset: {len(test_dataset)}")

    loader_kwargs = dict(batch_size=batch_size, num_workers=4)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # A fresh network and optimiser for every fold.
    affine_model = AffineNet().to(device)
    optimizer_affine = torch.optim.Adam(affine_model.parameters(), lr=0.0001)
    criterion_affine = ssd_loss
    path_to_save_the_model = os.path.join(
        saved_model_dir, f"trained_affine_registration_model_for_Fold{fold_number}.pth"
    )
    train_affine(
        affine_model,
        train_dataloader,
        val_dataloader,
        optimizer_affine,
        criterion_affine,
        device,
        path_to_save_the_model,
        num_epochs=30,
    )

---


### 3.4 Evaluation

Let's define a helper to evaluate the trained affine model on a single image pair and see how it works.


---

In [None]:
def inference_single_image(affine_model, fixed_image, fixed_mask, moving_image, moving_mask, device):
    """
    Warp ``moving_image`` with the affine transform predicted from the mask pair.

    The model is switched to eval mode and run without autograd. ``fixed_image``
    is part of the signature but is not used by this function.

    Returns:
        The moving image warped by ``apply_affine_transformation`` with bilinear
        interpolation.
    """
    affine_model.eval()
    with torch.no_grad():
        mask_pair = (fixed_mask.to(device), moving_mask.to(device))
        predicted_theta = affine_model(mask_pair)
        warped_image = apply_affine_transformation(
            moving_image.to(device), predicted_theta, mode="bilinear"
        )
    return warped_image

---

## 4. Inference for Affine Registration <a id = "inference_affine_registration"></a>

[OPTIONAL] If you like, you can check the affine registration network's performance by uncommenting the following cell:

---

In [None]:
# # Which fold are you interested in?
# fold = 1 # 1, 2, 3, 4, 5 or 6
# # which image index are you interested in to see the results?
# index = 2 # 0, 1, 2, or 3
# path_to_affine_model = os.path.join(saved_model_dir, "trained_affine_registration_model_for_Fold" + str(fold) + ".pth")

# dataset_csv_path_file = os.path.join(data_dir, "Testing_Label_Paths_For_Fold" + str(fold) + ".csv")
# dataset = ImageRegistrationDataset(dataset_csv_path_file)
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
# for batch in dataloader:
#     batch = batch
#     break
# fixed_image, fixed_mask, moving_image, moving_mask = batch
# fixed_image = fixed_image[index].unsqueeze(0).to(device)
# fixed_mask = fixed_mask[index].unsqueeze(0).to(device)
# moving_image = moving_image[index].unsqueeze(0).to(device)
# moving_mask = moving_mask[index].unsqueeze(0).to(device)
# # Loading the best model
# trained_affine_model = AffineNet().to(device)
# trained_affine_model.load_state_dict(torch.load(path_to_affine_model))
# deformed_image = inference_single_image(trained_affine_model, fixed_image, fixed_mask, moving_image, moving_mask, device)

# import matplotlib.pyplot as plt
# plt.figure(figsize=(10,6));
# plt.subplot(131);plt.imshow(fixed_image[0].permute(1,2,0).detach().cpu());plt.axis('off');plt.title("Fixed Image")
# plt.subplot(132);plt.imshow(moving_image[0].permute(1,2,0).detach().cpu());plt.axis('off');plt.title("Moving Image")
# plt.subplot(133);plt.imshow(deformed_image[0].permute(1,2,0).detach().cpu());plt.axis('off');plt.title("Affine Deformed Image")
# plt.tight_layout()
# plt.show()

---
<div align="center">
    
## 5. Training Deformable Registration Network <a id="deformablenetwork"></a>
    
</div>
    
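In this section we define and train the deformable registration network. It takes the affine-warped moving image concatenated with the fixed image and predicts a 2-channel flow field, which is applied with a spatial transformer.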

---
    
### 5.1. Defining the Deformable Registration Network

In [None]:
# Instantiate the deformable registration network and report how many trainable weights it has.
deformable_model = DeformRegNet(in_channels=6, out_channels=2, init_features=4).to(device)
trainable_params = 0
for param in deformable_model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Deformable Model Trainable Params: {trainable_params}")

---

### 5.2. Training the Deformable Registration Network

---

In [None]:
def train_DeformRegNet(affine_model, deformable_model, train_loader, val_loader, optimizer, criterion, device, path_to_save_the_model, results_path, stn, num_epochs=10):
    """Train the deformable registration network on top of a frozen affine network.

    Per batch: the affine network predicts ``affine_theta`` from the mask pair, and
    the moving image (bilinear) and mask are warped with it. The deformable network
    predicts ``flow`` from ``cat([affine_deformed_image, fixed_image], dim=1)``, and
    ``stn`` applies ``flow`` to both the image and the mask. The loss is
    ``criterion(registered_img, fixed_image, flow) + MSE(registered_mask, fixed_mask)``.
    After each epoch the mean training loss is compared with the best seen so far;
    if it is lower, the deformable model's ``state_dict`` is saved.

    Args:
        affine_model: pre-trained affine network; it is not optimised here.
        deformable_model: network whose parameters ``optimizer`` updates.
        train_loader: iterable yielding ``(fixed_image, fixed_mask, moving_image, moving_mask)``.
        val_loader: currently unused. NOTE(review): in this notebook it wraps the
            per-fold *testing* CSV, so it is intentionally not used for checkpoint selection.
        optimizer: optimizer over ``deformable_model``'s parameters.
        criterion: image loss called as ``criterion(registered_img, fixed_image, flow)``.
        device: device the tensors are moved to.
        path_to_save_the_model: file path for the best checkpoint.
        results_path: directory for the optional debug images below.
        stn: spatial transformer called as ``stn(image, flow)``.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        The lowest mean epoch training loss (``float('inf')`` if no batch was seen).
    """
    min_loss = float('inf')  # best mean epoch loss so far
    mask_criterion = nn.MSELoss()  # stateless; build it once instead of once per batch
    # The affine network is a frozen first stage (only deformable params are in
    # `optimizer`): run it in eval mode and without autograd so no gradients
    # accumulate on its parameters.
    affine_model.eval()
    deformable_model.train()

    for epoch in range(num_epochs):
        train_loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", dynamic_ncols=True)
        epoch_loss_sum = 0.0
        num_batches = 0

        for fixed_image, fixed_mask, moving_image, moving_mask in train_loader_iter:
            optimizer.zero_grad()
            fixed_mask = fixed_mask.to(device)
            moving_mask = moving_mask.to(device)
            fixed_image = fixed_image.to(device)
            moving_image = moving_image.to(device)

            # Stage 1: affine pre-alignment (frozen).
            with torch.no_grad():
                affine_theta = affine_model((fixed_mask, moving_mask))
                affine_deformed_image = apply_affine_transformation(moving_image, affine_theta, mode="bilinear")
                affine_deformed_mask = apply_affine_transformation(moving_mask, affine_theta)

            # Stage 2: deformable refinement.
            input_tensor = torch.cat([affine_deformed_image, fixed_image], dim=1)
            flow = deformable_model(input_tensor)
            registered_img = stn(affine_deformed_image, flow)
            loss_image = criterion(registered_img, fixed_image, flow)
            # Mask term: the same flow must also align the masks.
            registered_mask = stn(affine_deformed_mask, flow)
            loss_label = mask_criterion(registered_mask, fixed_mask)
            loss = loss_image + loss_label
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss_sum += batch_loss
            num_batches += 1
            train_loader_iter.set_postfix(loss=batch_loss)

        if num_batches == 0:
            # Empty loader: nothing was trained, so there is nothing to checkpoint.
            continue

        # Select checkpoints on the mean loss over the whole epoch rather than on the
        # last batch alone, which (with batch_size=1) is a single arbitrary sample.
        epoch_loss = epoch_loss_sum / num_batches
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            print("Saving model with improved loss:", min_loss)
            torch.save(deformable_model.state_dict(), path_to_save_the_model)

            # If you want to save intermediate results (last batch of this epoch) for debugging,
            # uncomment the following lines
            ######## Saving Results (if you want to save results) ###########
            # save_image(moving_image, f'{results_path}/Epoch_{epoch}_Original_Moving_Images.png')
            # save_image(moving_mask, f'{results_path}/Epoch_{epoch}_Original_Moving_Masks.png')
            # save_image(fixed_image, f'{results_path}/Epoch_{epoch}_Original_Fixed_Images.png')
            # save_image(fixed_mask, f'{results_path}/Epoch_{epoch}_Original_Fixed_Masks.png')
            # save_image(affine_deformed_image, f'{results_path}/Epoch_{epoch}_Affine_Deformed_Moving_Image.png')
            # save_image(registered_img, f'{results_path}/Epoch_{epoch}_Registered_Images.png')
            # torch.save(affine_theta, f'affine_theta_epoch_{epoch}.pt')

    return min_loss

---


### 5.3. Initializing Training Process for Each Fold


---

In [None]:
# Train one deformable registration network per fold, on top of that fold's affine model.
for i in range(num_of_folds):
    fold_id = i + 1  # folds are numbered 1..num_of_folds in every CSV / model / results name
    print(f"Processing Fold {fold_id} ...")

    # Load the affine model trained for this fold in Section 3.
    print(f"Let's load Affine Trained Model for Fold {fold_id}...")
    path_to_affine_model = os.path.join(saved_model_dir, "trained_affine_registration_model_for_Fold" + str(fold_id) + ".pth")
    trained_affine_model = AffineNet().to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    trained_affine_model.load_state_dict(torch.load(path_to_affine_model, map_location=device))

    # Results directory, numbered like the CSV and model files (previously 0-based: Fold0..).
    results_path = os.path.join(results_dir, "Fold" + str(fold_id))
    os.makedirs(results_path, exist_ok=True)

    # Dataset Preparation
    train_csv_path_file = os.path.join(data_dir, "Training_Label_Paths_For_Fold" + str(fold_id) + ".csv")
    test_csv_path_file = os.path.join(data_dir, "Testing_Label_Paths_For_Fold" + str(fold_id) + ".csv")

    train_dataset = ImageRegistrationDataset(train_csv_path_file)
    test_dataset = ImageRegistrationDataset(test_csv_path_file)
    print(f"Train Dataset: {len(train_dataset)} | Test Dataset: {len(test_dataset)}")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Defining Model
    deformable_model = DeformRegNet(in_channels=6, out_channels=2, init_features=4).to(device)
    stn = SpatialTransformer()
    optimizer = torch.optim.Adam(deformable_model.parameters(), lr=0.01)
    criterion = loss_function
    path_to_save_deformable_model = os.path.join(saved_model_dir, "trained_deformable_registration_model_for_Fold" + str(fold_id) + ".pth")
    train_DeformRegNet(affine_model=trained_affine_model,
                       deformable_model=deformable_model,
                       train_loader=train_dataloader,
                       val_loader=val_dataloader,
                       optimizer=optimizer,
                       criterion=criterion,
                       device=device,
                       path_to_save_the_model=path_to_save_deformable_model,
                       results_path=results_path,
                       stn=stn,
                       num_epochs=200)
    print('')
    print('--'*90)
    print('')

---


### 5.4 Evaluation

Let's define a helper that runs the trained affine and deformable models on a single image pair.


---

In [None]:
def deformable_inference_single_image(affine_model, deformable_model, fixed_image, fixed_mask, moving_image, moving_mask, stn, device):
    """Register one moving image to one fixed image with the two-stage pipeline.

    Stage 1 predicts affine parameters from the ``(fixed_mask, moving_mask)`` pair
    and warps ``moving_image``. Stage 2 predicts a flow from
    ``cat([affine_deformed_image, fixed_image], dim=1)`` and applies it with ``stn``.
    Both models are switched to eval mode and run without autograd.

    Returns:
        The registered moving image produced by ``stn``.
    """
    affine_model.eval()
    deformable_model.eval()
    with torch.no_grad():
        fixed_image = fixed_image.to(device)
        fixed_mask = fixed_mask.to(device)
        moving_mask = moving_mask.to(device)
        moving_image = moving_image.to(device)

        theta = affine_model((fixed_mask, moving_mask))
        # Pass mode="bilinear" explicitly: this is how train_DeformRegNet warps the
        # image during training, so inference must use the same interpolation.
        affine_deformed_image = apply_affine_transformation(moving_image, theta, mode="bilinear")

        input_tensor = torch.cat([affine_deformed_image, fixed_image], dim=1)
        flow = deformable_model(input_tensor)
        registered_img = stn(affine_deformed_image, flow)
    return registered_img

---

## 6. Inference for Deformable Registration <a id="inference_deformable_registration"></a>

Choose a fold and a test-image index, load the trained models for that fold, and visualize the registration result.

---

In [None]:
# Which fold are you interested in?
fold = 1  # 1, 2, ..., num_of_folds
# Which test image (row of the fold's testing CSV) do you want to see? Must be < len(dataset).
index = 0

# Defining and Loading Trained Affine Model
# (map_location lets GPU-saved checkpoints load on a CPU-only machine)
path_to_affine_model = os.path.join(saved_model_dir, "trained_affine_registration_model_for_Fold" + str(fold) + ".pth")
trained_affine_model = AffineNet().to(device)
trained_affine_model.load_state_dict(torch.load(path_to_affine_model, map_location=device))

# Defining and Loading Trained Deformable Registration Network
path_to_deformable_model = os.path.join(saved_model_dir, "trained_deformable_registration_model_for_Fold" + str(fold) + ".pth")
trained_deformable_model = DeformRegNet(in_channels=6, out_channels=2, init_features=4).to(device)
trained_deformable_model.load_state_dict(torch.load(path_to_deformable_model, map_location=device))
stn = SpatialTransformer()

# Take the requested sample straight from the dataset so the same index always selects
# the same image (indexing into a shuffled DataLoader batch picked a random one each run).
dataset_csv_path_file = os.path.join(data_dir, "Testing_Label_Paths_For_Fold" + str(fold) + ".csv")
dataset = ImageRegistrationDataset(dataset_csv_path_file)
fixed_image, fixed_mask, moving_image, moving_mask = (
    sample.unsqueeze(0).to(device) for sample in dataset[index]  # add a batch dimension of 1
)

registered_image = deformable_inference_single_image(affine_model=trained_affine_model,
                                                     deformable_model=trained_deformable_model,
                                                     fixed_image=fixed_image,
                                                     fixed_mask=fixed_mask,
                                                     moving_image=moving_image,
                                                     moving_mask=moving_mask,
                                                     stn=stn,
                                                     device=device)

fig, axes = plt.subplots(1, 3, figsize=(10, 6))
panels = (
    (fixed_image, "Fixed Image"),
    (moving_image, "Moving Image"),
    (registered_image, "Registered Image"),
)
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img[0].permute(1, 2, 0).detach().cpu())
    ax.axis('off')
    ax.set_title(title)
fig.tight_layout()
plt.show()


def threshold_binary(image, threshold=0.001):
    """
    Binarise ``image``: entries strictly greater than ``threshold`` become 1.0 and
    all others 0.0. Returns a float32 tensor with the same shape as ``image``.
    """
    foreground = torch.gt(image, threshold)
    return foreground.to(dtype=torch.float32)

def dice_coefficient(image1, image2):
    """
    Dice similarity coefficient of two binary masks: 2 * |A & B| / (|A| + |B|).

    A 1e-8 epsilon is added to the denominator, so two empty masks give 0
    instead of NaN. Returns a 0-dim tensor.
    """
    overlap = (image1 * image2).sum()
    total_size = image1.sum() + image2.sum()
    return 2. * overlap / (total_size + 1e-8)

# Binarise the fixed, moving and registered images and measure the foreground overlap
# with the fixed image before (moving) and after (registered) registration.
# NOTE(review): this thresholds the *images*, not the masks, so it only approximates
# prostate overlap if the image background is (near) zero — confirm.
fixed_image_binary = threshold_binary(fixed_image)
moving_image_binary = threshold_binary(moving_image)
registered_image_binary = threshold_binary(registered_image)

# dice1: before registration, dice2: after registration.
dice1 = dice_coefficient(fixed_image_binary.detach().cpu(), moving_image_binary.detach().cpu())
dice2 = dice_coefficient(fixed_image_binary.detach().cpu(), registered_image_binary.detach().cpu())

print("Dice coefficient between fixed_image and moving_image:", dice1.item())
print("Dice coefficient between fixed_image and registered_image:", dice2.item())