# Paired Lung CT 3D registration with MONAI

This tutorial shows how to use MONAI to register CT images acquired at different time points for a single patient. The images being registered are taken at inspiration and expiration for each subject, so this is an intra-subject registration. Intra-subject registration is useful when certain features of a medical image, such as a tumor's location, need to be tracked while conducting invasive procedures.

The usage of the following features is illustrated in this tutorial:
1. Load Nifti image with metadata
1. Transforms for dictionary format data
1. Build GlobalNet
1. Warp an image with given dense displacement field (DDF) with Warp block
1. Compute DiceLoss and BendingEnergyLoss
1. Compute MeanDice metric

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_registration/paired_lung_ct.ipynb)

## Setup environment

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## Setup imports

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import matplotlib.pyplot as plt
import numpy as np
import shutil
import tempfile
import torch
from torch.nn import MSELoss
from monai.apps import download_url, download_and_extract
from monai.config import print_config
from monai.data import DataLoader, Dataset, CacheDataset
from monai.losses import BendingEnergyLoss, MultiScaleLoss, DiceLoss, LocalNormalizedCrossCorrelationLoss
from monai.metrics import compute_meandice, DiceMetric
from monai.networks.blocks import Warp
from monai.networks.nets import GlobalNet
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    RandAffined,
    Resized,
    ScaleIntensityd,
    ToTensord,
    SqueezeDimd,
    Spacingd,
    EnsureTyped,
)
from monai.utils import set_determinism, first
from glob import glob
import itertools

    
# Print MONAI's configuration report (monai.config.print_config) so the
# environment this notebook ran in is recorded alongside its output.
print_config()

## Set dataset path

In [None]:
# Locate the preprocessed images and their masks. Both lists are sorted so
# that image i and mask i belong to the same subject (assumes matching file
# names in the two folders — TODO confirm).
data_dir = os.path.join("datasets2", "IRIS")
mris = sorted(glob(os.path.join(data_dir, "MRI_N4_Resample_Norm", "*.nii.gz")))
labels = sorted(glob(os.path.join(data_dir, "Mask_Resample", "*.nii.gz")))
if len(mris) != len(labels):
    # Pairing is purely positional: one missing/extra file would silently
    # attach the wrong mask to every subsequent image.
    raise ValueError(
        f"found {len(mris)} images but {len(labels)} labels under {data_dir}"
    )

mri_number = 4  # subjects used for training
val_number = 2  # subjects (following the training ones) used for validation
train_mris = mris[:mri_number]
val_mris = mris[mri_number:mri_number + val_number]
train_labels = labels[:mri_number]
val_labels = labels[mri_number:mri_number + val_number]


def make_registration_pairs(images, masks, include_identity=True):
    """Build every ordered (fixed, moving) registration pair from one subject list.

    Args:
        images: image file paths, aligned by position with ``masks``.
        masks: label file paths, aligned by position with ``images``.
        include_identity: if True (the default), each subject is also paired
            with itself; if False, only pairs of distinct subjects are kept.

    Returns:
        A list of dicts with ``fixed_image``, ``fixed_label``, ``moving_image``
        and ``moving_label`` keys, ordered with the fixed index varying slowest.

    Raises:
        ValueError: if ``images`` and ``masks`` differ in length.
    """
    if len(images) != len(masks):
        raise ValueError(
            f"{len(images)} images but {len(masks)} masks; cannot pair them"
        )
    n = len(images)
    if include_identity:
        index_pairs = itertools.product(range(n), repeat=2)
    else:
        index_pairs = itertools.permutations(range(n), 2)
    return [
        {
            "fixed_image": images[i],
            "fixed_label": masks[i],
            "moving_image": images[j],
            "moving_label": masks[j],
        }
        for i, j in index_pairs
    ]


train_files = make_registration_pairs(train_mris, train_labels)
val_files = make_registration_pairs(val_mris, val_labels)

## Set deterministic training for reproducibility

In [None]:
# Seed the random number generators through MONAI's set_determinism so that
# augmentation and weight initialisation repeat between runs, as this
# section's heading intends (the call was previously commented out).
set_determinism(seed=0)

In [None]:
# Every item carries these four entries; only the two images are rescaled.
all_keys = ["moving_image", "moving_label", "fixed_image", "fixed_label"]
image_keys = ["moving_image", "fixed_image"]

# Training pipeline: load, add a channel dimension, rescale image intensities
# to [0, 1], apply a random affine (rotate_range=(0, 0, pi/15),
# scale_range=0.1 per axis) on every call (prob=1.0), then ensure tensor type.
train_transforms = Compose([
    LoadImaged(keys=all_keys),
    AddChanneld(keys=all_keys),
    ScaleIntensityd(keys=image_keys, minv=0.0, maxv=1.0),
    RandAffined(
        keys=all_keys,
        # per-key interpolation: bilinear for images, nearest for labels
        mode=("bilinear", "nearest", "bilinear", "nearest"),
        prob=1.0,
        rotate_range=(0, 0, np.pi / 15),
        scale_range=(0.1, 0.1, 0.1),
    ),
    EnsureTyped(keys=all_keys),
])

# Validation pipeline: identical to training minus the random affine.
val_transforms = Compose([
    LoadImaged(keys=all_keys),
    AddChanneld(keys=all_keys),
    ScaleIntensityd(keys=image_keys, minv=0.0, maxv=1.0),
    EnsureTyped(keys=all_keys),
])

## Check transforms in DataLoader
Visualize a single batch to check the transforms.

In [None]:
# Pull one augmented training pair through an uncached Dataset to sanity-check
# loading, channel handling and augmentation before caching everything.
check_ds = Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
if check_data is None:
    # first() falls back to its default (None) when the loader yields nothing;
    # fail clearly here instead of with a TypeError on the subscripts below.
    raise RuntimeError("no training pairs to check: train_files is empty")

# Take batch item 0 and channel 0 of each entry, leaving a single volume.
fixed_image = check_data["fixed_image"][0][0]
fixed_label = check_data["fixed_label"][0][0]
moving_image = check_data["moving_image"][0][0]
moving_label = check_data["moving_label"][0][0]

print(f"moving_image shape: {moving_image.shape}, "
      f"moving_label shape: {moving_label.shape}")
print(f"fixed_image shape: {fixed_image.shape}, "
      f"fixed_label shape: {fixed_label.shape}")

# Index along the last spatial axis to display. Named so that it does not
# shadow the built-in `slice` for the rest of the notebook session.
slice_idx = 40

# (title, volume, colormap); cmap=None keeps matplotlib's default colormap.
panels = [
    ("moving_image", moving_image, "gray"),
    ("moving_label", moving_label, None),
    ("fixed_image", fixed_image, "gray"),
    ("fixed_label", fixed_label, None),
]

plt.figure("check", (12, 6))
for position, (title, volume, cmap) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(title)
    plt.imshow(volume[:, :, slice_idx], cmap=cmap)
plt.show()

# Value ranges of each volume (images were rescaled with minv=0.0, maxv=1.0).
for title, volume, _ in panels:
    values = volume.numpy()
    print(f"{title}: min={np.min(values)}, max={np.max(values)}")

In [None]:
# Cache every training/validation pair up front (cache_rate=1.0) using four
# worker processes, then wrap each cache in a loader; only training shuffles.
batch_size = 1
n_workers = 4

train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,
    num_workers=n_workers,
)
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_workers,
)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=n_workers,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    num_workers=n_workers,
)

#### Create Model, Loss and Optimizer

In [None]:
#def forward(batch_data, model):
#    fixed_image = batch_data["fixed_image"].to(device)
#    moving_image = batch_data["moving_image"].to(device)
#    
#    # predict DDF through GlobalNet
#    ddf = model(torch.cat((moving_image, fixed_image), dim=1))
#
#    # warp moving image and label with the predicted ddf
#    pred_image = warp_layer(moving_image, ddf)
#
#    return ddf, pred_image

In [None]:
#max_epochs = 3
#val_interval = 1
#epoch_loss_values = []
#
#for epoch in range(max_epochs):
#    if (epoch + 1) % val_interval == 0 or epoch == 0:
#        model.eval()
#        with torch.no_grad():
#            for val_data in val_loader:
#
#                val_ddf, val_pred_image = forward(val_data, model)
#                val_fixed_image = val_data["fixed_image"].to(device)
#
#    print("-" * 10)
#    print(f"epoch {epoch + 1}/{max_epochs}")
#    model.train()
#    epoch_loss = 0
#    step = 0
#    for batch_data in train_loader:
#        step += 1
#        optimizer.zero_grad()
#
#        ddf, pred_image = forward(batch_data, model)
#        fixed_image = batch_data["fixed_image"].to(device)
#        
#        loss = image_loss(pred_image, fixed_image)
#        loss.backward()
#        optimizer.step()
#        epoch_loss += loss.item()
#
#    epoch_loss /= step
#    epoch_loss_values.append(epoch_loss)
#    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

In [None]:
# Use the first GPU when available; fall back to the CPU instead of failing
# with a CUDA error on machines without one.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# GlobalNet maps the channel-concatenated (moving, fixed) pair -- hence
# in_channels=2 -- to a dense displacement field (DDF); inputs are 128^3.
model = GlobalNet(
    image_size=(128, 128, 128),
    spatial_dims=3,
    in_channels=2,
    num_channel_initial=16,
    depth=5,
).to(device)

# Resamples an image with a DDF: bilinear interpolation, "border" padding.
warp_layer = Warp("bilinear", "border").to(device)

# Image-similarity term and DDF smoothness regulariser.
image_loss = LocalNormalizedCrossCorrelationLoss()
regularization = BendingEnergyLoss()

# Multi-scale Dice on warped labels plus a Dice metric; currently unused by
# the training loop, which optimises image similarity only.
label_loss = DiceLoss()
label_loss = MultiScaleLoss(label_loss, scales=[0, 1, 2, 4, 8, 16])
dice_metric = DiceMetric(include_background=True, reduction="mean")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:
# Train GlobalNet with an image-similarity loss plus DDF regularisation.
# Validation runs every `val_interval` epochs (and always on the final epoch),
# and a checkpoint is saved whenever the mean validation loss improves.
max_epochs = 3
val_interval = 1
best_metric = -1  # kept for compatibility; no metric is computed in this loop
best_loss = float("inf")  # the first validated epoch always becomes the best
epoch_loss_values = []  # mean training loss of each epoch
metric_values = []  # kept for compatibility; not updated in this loop
val_loss_values = []  # mean validation loss of each validated epoch

# Relative weights of the two loss terms.
image_weight = 1
ddf_weight = 10

checkpoint_dir = os.path.join(".", "models", "registration")
checkpoint_path = os.path.join(checkpoint_dir, "test_registration_resample.pth")

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    phases = ["train"]
    if (epoch + 1) % val_interval == 0 or epoch + 1 == max_epochs:
        phases.append("valid")

    for phase in phases:
        is_train = phase == "train"
        if is_train:
            model.train()
            loader = train_loader
        else:
            model.eval()
            loader = val_loader

        n_batches = len(loader)
        if n_batches == 0:
            # The per-epoch averages below would otherwise divide by zero.
            raise RuntimeError(f"{phase} loader is empty; check the file lists")

        running_loss = 0.0
        running_img_loss = 0.0
        running_ddf_loss = 0.0

        for i, data in enumerate(loader):
            print(i, end='\r')  # in-place batch counter

            if is_train:
                optimizer.zero_grad()

            # Track gradients only while training; validation runs without autograd.
            with torch.set_grad_enabled(is_train):
                # Only the images are used; label entries in the batch are ignored.
                fixed_image = data["fixed_image"].to(device)
                moving_image = data["moving_image"].to(device)

                # Predict a DDF from the (moving, fixed) pair and warp the
                # moving image towards the fixed one.
                ddf = model(torch.cat((moving_image, fixed_image), dim=1))
                pred_image = warp_layer(moving_image, ddf)

                img_loss = image_weight * image_loss(pred_image, fixed_image)
                ddf_loss = ddf_weight * regularization(ddf)
                loss = img_loss + ddf_loss

                if is_train:
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item()
            running_img_loss += img_loss.item()
            running_ddf_loss += ddf_loss.item()

        epoch_loss = running_loss / n_batches
        epoch_img_loss = running_img_loss / n_batches
        epoch_ddf_loss = running_ddf_loss / n_batches

        if is_train:
            train_loss = epoch_loss
            train_img_loss = epoch_img_loss
            train_ddf_loss = epoch_ddf_loss
            epoch_loss_values.append(epoch_loss)
        else:
            valid_loss = epoch_loss
            valid_img_loss = epoch_img_loss
            valid_ddf_loss = epoch_ddf_loss
            val_loss_values.append(epoch_loss)

        print(
            "{}: loss: {:.4f} -- img: {:.4f}, ddf: {:.4f}".format(
                phase, epoch_loss, epoch_img_loss, epoch_ddf_loss,
            )
        )

        if not is_train and epoch_loss < best_loss:
            best_loss = epoch_loss
            best_epoch = epoch + 1
            # Create the checkpoint folder on first use; saving into a missing
            # directory would otherwise fail mid-training.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                },
                checkpoint_path,
            )
            print(
                "best loss {:.4f} at epoch {}".format(
                    best_loss, best_epoch
                )
            )


In [None]:
# Show one slice of the most recent batch left over from the training cell:
# the moving image, the fixed image, and the moving image warped by the
# predicted DDF. `.detach()` keeps this working even if that batch was
# produced with gradient tracking enabled.
depth = 50  # slice index along the last spatial axis

panels = [
    ("moving_image", moving_image),
    ("fixed_image", fixed_image),
    ("pred_image", pred_image),
]

plt.figure("check", (18, 6))
for position, (title, volume) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(title)
    # [0, 0] selects batch item 0, channel 0.
    plt.imshow(volume.detach().cpu().numpy()[0, 0, :, :, depth])
plt.show()