Import the necessary packages

In [None]:
from torch.utils.tensorboard import SummaryWriter
import os
import sys
import glob

from torch.utils.data import DataLoader
import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as plt
from natsort import natsorted
from torchvision import transforms

from MIR.models import SpatialTransformer, EncoderFeatureExtractor, SITReg, VFA, TransMorphTVF, TransMorph
from MIR.models.SITReg import ReLUFactory, GroupNormalizerFactory
from MIR.models.SITReg.composable_mapping import DataFormat
from MIR.models.SITReg.deformation_inversion_layer.fixed_point_iteration import (
    AndersonSolver,
    AndersonSolverArguments,
    MaxElementWiseAbsStopCriterion,
    RelativeL2ErrorStopCriterion,
)
import MIR.models.configs_TransMorph as configs_TransMorph
import MIR.models.configs_VFA as CONFIGS_VFA

Define the image size (H × W × D) shared by all models below

In [None]:
# Spatial dimensions of the input volumes; every model cell below derives
# its input size from these three values.
H = 160
W = 192
D = 224

Initialize the models
We begin with TransMorph.

In [None]:
# TransMorph built from the 3-level config, running at full resolution.
scale_factor = 1
config = configs_TransMorph.get_3DTransMorph3Lvl_config()
# Network input size: each image dimension divided by scale_factor.
config.img_size = tuple(dim // scale_factor for dim in (H, W, D))
# Window size is taken from the full-resolution dims (independent of scale_factor).
config.window_size = tuple(dim // 64 for dim in (H, W, D))
# 3 output channels — presumably one per displacement axis; TODO confirm.
config.out_chan = 3
print(config)
TM_model = TransMorph(config).cuda('cuda:0')

Next, TransMorph-TVF (configured at half resolution, `scale_factor = 2`)

In [None]:
# TransMorph-TVF built from the 3-level config, with its input size halved.
scale_factor = 2
config = configs_TransMorph.get_3DTransMorph3Lvl_config()
# Network input size: each image dimension divided by scale_factor.
config.img_size = tuple(dim // scale_factor for dim in (H, W, D))
# NOTE(review): window size uses the full-resolution dims even though
# img_size is halved here — confirm this is intended.
config.window_size = tuple(dim // 64 for dim in (H, W, D))
# 3 output channels — presumably one per displacement axis; TODO confirm.
config.out_chan = 3
print(config)
TMTVF_model = TransMorphTVF(config, time_steps=7).cuda('cuda:0')

Next, VFA

In [None]:
# VFA from its default config, at full resolution.
scale_factor = 1
config = CONFIGS_VFA.get_VFA_default_config()
# Network input size: each image dimension divided by scale_factor.
config.img_size = tuple(dim // scale_factor for dim in (H, W, D))
print(config)
# VFA takes the target device as a constructor argument instead of .cuda().
model = VFA(config, device='cuda:0')

Finally, SITReg

In [None]:
INPUT_SHAPE = (H, W, D)


def create_model(input_shape=None) -> SITReg:
    """Build a SITReg network using the settings hard-coded in this cell.

    The architecture is fixed here rather than read from a config object:
    feature widths, convolution counts, fixed-point solvers, and
    group-normalization settings.

    Args:
        input_shape: Spatial shape of the input volumes. If omitted, the
            module-level ``INPUT_SHAPE`` is used (looked up at call time, so
            existing ``create_model()`` calls behave as before).

    Returns:
        The SITReg network. Both the feature extractor and the assembled
        network have ``.cuda()`` applied, i.e. they are moved to the current
        default CUDA device.
    """
    shape = INPUT_SHAPE if input_shape is None else tuple(input_shape)

    # Feature extractor shared by the network: 6 resolution levels,
    # 2 convolutions per level, 1 input channel.
    feature_extractor = EncoderFeatureExtractor(
        n_input_channels=1,
        activation_factory=ReLUFactory(),
        n_features_per_resolution=[12, 16, 32, 64, 128, 128],
        n_convolutions_per_resolution=[2, 2, 2, 2, 2, 2],
        input_shape=shape,
        normalizer_factory=GroupNormalizerFactory(2),
    ).cuda()

    # Anderson solvers for SITReg's forward/backward fixed-point problems
    # (from the deformation_inversion_layer module). They share iteration
    # limits and memory length and differ only in their stopping criterion.
    forward_solver = AndersonSolver(
        MaxElementWiseAbsStopCriterion(min_iterations=2, max_iterations=50, threshold=1e-2),
        AndersonSolverArguments(memory_length=4),
    )
    backward_solver = AndersonSolver(
        RelativeL2ErrorStopCriterion(min_iterations=2, max_iterations=50, threshold=1e-2),
        AndersonSolverArguments(memory_length=4),
    )

    network = SITReg(
        feature_extractor=feature_extractor,
        n_transformation_convolutions_per_resolution=[2, 2, 2, 2, 2, 2],
        n_transformation_features_per_resolution=[12, 64, 128, 256, 256, 256],
        max_control_point_multiplier=0.99,
        affine_transformation_type=None,
        input_voxel_size=(1.0, 1.0, 1.0),
        input_shape=shape,
        transformation_downsampling_factor=(1.0, 1.0, 1.0),
        forward_fixed_point_solver=forward_solver,
        backward_fixed_point_solver=backward_solver,
        activation_factory=ReLUFactory(),
        normalizer_factory=GroupNormalizerFactory(4),
    ).cuda()
    return network

# Build SITReg, then pin it explicitly to the first GPU
# (create_model() only moves it to the current default CUDA device).
SITReg_model = create_model()
SITReg_model = SITReg_model.cuda('cuda:0')