In [None]:
import argparse
import logging
import sys

import numpy as np
import torch
from monai.transforms import AsDiscrete, MaskIntensity, RandAffine, Affine
from monai.utils import set_determinism
from monai.losses import LocalNormalizedCrossCorrelationLoss
from torch.utils.tensorboard import SummaryWriter
import random

import utils_parser
from reg_data import getRegistrationDataset
from reg_model import getRegistrationModel
from utils import compute_mean_dice, getAdamOptimizer, getReducePlateauScheduler, loadExistingModel, getDevice
from utils import print_model_output, print_weights, add_weights_to_name, compute_landmarks_distance_local
from loss import get_deformable_registration_loss_from_weights, get_affine_registration_loss_from_weights, jacobian_loss, get_jacobian, antifolding_loss, JacobianDet
from models import TrilinearLocalNet
from torchinfo import summary
from miseval import evaluate

In [None]:
# Command-line interface describing every tunable of the training run.
# NOTE(review): nothing in this notebook calls `parser.parse_args()`; the
# configuration cell below assigns the run settings directly. The parser is
# kept as the reference for option names and defaults.
parser = argparse.ArgumentParser(description="Train 3D mouse brain registration model.")

# Optimisation settings.
parser.add_argument("-lr", "--learningrate", type=float, default=0.001, help="Specify learning rate.")
parser.add_argument("-p", "--patience", type=int, default=20, help="Specify patience for LR Plateau.")
parser.add_argument("-b", "--batchsize", type=int, default=1, help="Batch size for training")
parser.add_argument("-e", "--epochs", type=int, default=500, help="Max epochs for training")

# Input / output selection.
parser.add_argument("-o", "--output", help="Model name for save")
parser.add_argument("-d", "--dataset", default='IRIS', help="Dataset name for training.")
parser.add_argument("-ft", "--finetuning", help="Load existing model for finetuning.")
parser.add_argument("-ct", "--continuetraining", help="Load existing model to continue training.")
parser.add_argument("-pt", "--pretraining", help="Load existing model for affine registration.")

# Registration mode and data handling. `--type` has no default on purpose:
# argparse yields None when it is omitted.
parser.add_argument("-t", "--type", type=str, help="Specify affine/deformable/local registration")
parser.add_argument("-a", "--atlas", action='store_true', help="Perform to-atlas registration instead of paired registration")
parser.add_argument("-m", "--mask", action='store_true', help="Skullstrip dataset if available")

# Loss weighting. The help text renders the real default through %(default)s
# so the documentation cannot drift away from the value again.
parser.add_argument("-w", "--weights", nargs='+', type=float, default=[1.0, 0, 2.0], help="Loss weights for 1) ImageLoss 2) LabelLoss 3) DDF. Default: %(default)s")

# Architecture / validation / fine-tuning switches.
parser.add_argument("-newmodel", "--newmodel", action='store_true', help="True: Depth 5 Channels 32; False: Depth 4 Channels 16")
parser.add_argument("-validfeminad", "--validfeminad", action='store_true', help="True: Validate on Feminad with Landmarks")
parser.add_argument("-freeze", "--freeze", type=int, default=0, help="Freeze Xth layer")

# Optional extra loss terms and training schemes.
parser.add_argument("-cycleconsistenttraining", "--cycleconsistenttraining", action='store_true', help="UseCycleConsistentTraining")
parser.add_argument("-affineconsistenttraining", "--affineconsistenttraining", action='store_true', help="UseAffineConsistentTraining")
parser.add_argument("-jacobianloss", "--jacobianloss", action='store_true', help="UseJacobianDetLoss")
parser.add_argument("-antifoldingloss", "--antifoldingloss", action='store_true', help="UseAntiFoldingLoss")
parser.add_argument("-ddf", "--ddf", action='store_true', help="Use DDF instead of DVF2DDF")

In [None]:
# --- Run configuration -------------------------------------------------------
# Every setting the training cells below read is assigned here explicitly, so
# the notebook runs from a fresh kernel. Values mirror the defaults declared on
# the argparse parser unless noted otherwise.

# Naming and data.
modelname = "fakedata"
dataset = "fakedata"

# Optional checkpoints to start from (None = train from scratch).
ft = None   # fine-tuning checkpoint
ct = None   # continue-training checkpoint
pt = None   # pre-trained checkpoint handed to getRegistrationModel

# Optimisation.
batchsize = 1
max_epochs = 500
lr = 0.001
patience = 20
weights = [1.0, 0, 2.0]   # 1) image loss  2) label loss  3) DDF loss

# Registration mode: the training loop handles 'affine', 'deformable', 'local'.
# NOTE(review): the parser declares no default for --type; 'local' is chosen
# here because it is the only mode the cycle/affine-consistency branches of the
# training loop act on -- confirm this is the intended mode for this run.
registration_type = "local"

# Data handling switches.
atlas = False
mask = False
validfeminad = False

# Architecture / fine-tuning.
newmodel = False
freeze = 0   # 0 = train every parameter

# Optional extra loss terms and training schemes.
cycle_consistent_training = False
use_jacobian_loss = False
affine_consistent_training = False
use_antifolding_loss = False
use_ddf = False



# --- Runtime setup -----------------------------------------------------------
torch.multiprocessing.set_sharing_strategy('file_system')
# NOTE(review): MONAI's set_determinism(seed=...) called a few lines below is
# documented to force deterministic cuDNN and turn benchmark mode off again, so
# this flag probably does not survive -- confirm which behaviour is wanted.
torch.backends.cudnn.benchmark = True
modelname = add_weights_to_name(modelname, weights)
print_model_output(modelname)
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
device = getDevice()

# --- Model -------------------------------------------------------------------
# `newmodel` switches between the two network sizes.
if newmodel:
    channels = 32
    extract = [0, 1, 2, 3, 4]
else:
    channels = 16
    extract = [0, 1, 2, 3]
model = getRegistrationModel(registration_type, img_size=128, pretrain_model=pt,
                                 channels=channels, extract=extract, use_ddf=use_ddf)

# --- Optimiser, scheduler, optional checkpoint --------------------------------
optimizer = getAdamOptimizer(model, lr)
# Use the configured patience instead of a hard-coded 20 so the setting is
# actually honoured.
scheduler = getReducePlateauScheduler(optimizer, factor=0.5, patience=patience)
# loadExistingModel returns the loss weights that are used from here on.
weights = loadExistingModel(model, optimizer, ft, ct, weights=weights, registration=True)
print_weights(weights)

# --- Data --------------------------------------------------------------------
dataloaders, size = getRegistrationDataset(dataset=dataset,
                                           batch=batchsize,
                                           training=True,
                                           augment=True,
                                           eval_augment=False,
                                           atlas=atlas,
                                           mask=mask,
                                           validfeminad=validfeminad,
                                           )

# --- Bookkeeping -------------------------------------------------------------
best_loss = np.inf
best_epoch = -1

writer = SummaryWriter(comment='_'+modelname)
# Number of training batches; the training loop uses it as an upper bound on
# the batches processed per phase.
sizelol = len(dataloaders["train"])
print(sizelol)

# When True, the training loop halves the third loss weight every 400 epochs.
reduceLReveryepochs = False

# --- Optional layer freezing -------------------------------------------------
# `freeze` != 0 freezes the whole model and then re-enables only the parameters
# whose name contains one of the fragments listed for that freeze level.
# An unknown non-zero level leaves every parameter frozen.
UNFREEZE_FRAGMENTS = {
    1: ("encode_convs.2.",),
    2: ("encode_convs.1.", "encode_convs.2."),
    3: ("encode_convs.1.",),
    4: ("encode_convs.0.", "encode_convs.1.", "encode_convs.2."),
    5: ("encode_convs.1.", "encode_convs.2.", "encode_convs.3."),
    6: ("encode_convs.0.", "encode_convs.1.", "encode_convs.2.", "encode_convs.3."),
    7: ("encode_convs.1.", "encode_convs.2.", "encode_convs.3.", "bottom_block."),
}

if freeze != 0:
    print("=> Freezing model")
    trainable_fragments = UNFREEZE_FRAGMENTS.get(freeze, ())
    # Applied once: nothing inside the epoch loop re-enables gradients, so
    # repeating this every epoch would change nothing.
    for name, p in model.named_parameters():
        p.requires_grad = any(fragment in name for fragment in trainable_fragments)

# Fail early with a readable message instead of a NameError deep in the loop.
reg_type = str(registration_type).lower()
if reg_type not in ('affine', 'deformable', 'local'):
    raise ValueError(f"Unsupported registration type: {registration_type!r}")

if cycle_consistent_training and affine_consistent_training:
    # The two schemes are applied only when exactly one of them is enabled.
    logging.warning("Both cycle- and affine-consistent training are enabled; "
                    "neither term is applied and both are logged as 0.")

# --- Training / validation loop ----------------------------------------------
# Epoch -1 is a validation-only pass run before any training step.
for epoch in range(-1, max_epochs):
    if reduceLReveryepochs:
        if epoch != 0 and epoch % 400 == 0:
            weights = [weights[0], 0, weights[2]/2]
            print("New weights: " + str(weights))
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # Per-epoch summaries written to TensorBoard at the end of the epoch.
    train_loss, train_metric, train_dice, train_lbl_loss, train_img_loss, train_ddf_loss = 0, 0, 0, 0, 0, 0
    valid_loss, valid_metric, valid_dice, valid_lbl_loss, valid_img_loss, valid_ddf_loss = 0, 0, 0, 0, 0, 0
    train_cycle_loss, valid_cycle_loss = 0, 0
    train_jcb_loss, valid_jcb_loss = 0, 0
    train_aff_loss, valid_aff_loss = 0, 0
    train_fold_loss, valid_fold_loss = 0, 0

    for phase in ['train', 'valid']:
        if epoch == -1 and phase == 'train':
            continue
        if pt is None:
            if phase == 'train':
                model.train()
            elif phase == 'valid':
                model.eval()
        else:
            # With a pre-trained model, the global sub-network stays frozen and
            # in eval mode; only the local sub-network is trained.
            for param in model.globalnet.parameters():
                param.requires_grad = False
            model.globalnet.eval()
            if phase == 'train':
                model.localnet.train()
            elif phase == 'valid':
                model.localnet.eval()

        running_loss = 0.0
        running_cycle_loss = 0.0
        running_jcb_loss = 0.0
        running_aff_loss = 0.0
        running_fold_loss = 0.0
        running_metric = 0.0
        running_dice = 0.0
        running_img_loss = 0.0
        running_lbl_loss = 0.0
        running_ddf_loss = 0.0
        samples_seen = 0  # denominator for the per-sample averages of this phase

        for i, data in enumerate(dataloaders[phase]):
            # Process at most as many batches as the training loader holds.
            if i >= sizelol:
                break

            print(i, end='\r')
            optimizer.zero_grad(set_to_none=True)

            # Optional loss terms default to zero so the bookkeeping below never
            # reads an undefined (or stale) value when a term is not computed
            # for this registration type / flag combination.
            cycle_loss = jcb_loss = aff_loss = fold_loss = torch.zeros(())

            with torch.set_grad_enabled(phase == 'train'):
                if reg_type == 'deformable':
                    affine_ddf, ddf, pred_image, pred_label, affine_image, affine_label = model(data)
                    # This model output has no separate `dvf`; fall back to the
                    # ddf so the regularisation term below is always defined.
                    dvf = ddf
                else:  # 'affine' or 'local'
                    ddf, pred_image, pred_label, dvf = model(data)

                # Labels are thresholded at 0.5 and used to mask the images.
                pred_image = pred_image.to(device, non_blocking=True)
                pred_label = pred_label.to(device, non_blocking=True)
                pred_mask = AsDiscrete(threshold=0.5)(pred_label)
                pred_image_masked = MaskIntensity(mask_data=pred_mask)(pred_image)

                fixed_image = data['fixed_image'].to(device, non_blocking=True)
                fixed_label = data['fixed_label'].to(device, non_blocking=True)
                fixed_mask = AsDiscrete(threshold=0.5)(fixed_label)
                fixed_image_masked = MaskIntensity(mask_data=fixed_mask)(fixed_image)

                if "neatin" in dataset:
                    # Region maps are warped with the predicted field and later
                    # compared with the fixed regions.
                    fixed_regions = data['fixed_regions'].to(device, non_blocking=True)
                    fixed_regions_np = fixed_regions.cpu().detach().numpy().squeeze()
                    moving_regions = data['moving_regions'].to(device, dtype=torch.float, non_blocking=True)
                    pred_regions = model.warp_nearest(moving_regions, ddf)
                    pred_regions_np = pred_regions.cpu().detach().numpy().squeeze()

                if reg_type == 'affine':
                    img_loss, lbl_loss, ddf_loss = get_affine_registration_loss_from_weights(
                        pred_image_masked, pred_mask, fixed_image_masked, fixed_mask, weights)
                    # The third returned term is logged but not optimised here.
                    loss = img_loss + lbl_loss
                else:  # 'deformable' or 'local'
                    reg_field = ddf if use_ddf else dvf
                    img_loss, lbl_loss, ddf_loss = get_deformable_registration_loss_from_weights(
                        pred_image_masked, pred_mask, fixed_image_masked, fixed_mask, reg_field, weights)

                    loss = img_loss + lbl_loss + ddf_loss
                    if use_jacobian_loss:
                        # NOTE(review): 128**3 presumably matches img_size=128 -- confirm.
                        jcb_loss = jacobian_loss(ddf) / (128*128*128)
                        loss = loss + jcb_loss
                    if use_antifolding_loss:
                        fold_loss = antifolding_loss(ddf)
                        loss = loss + fold_loss

                if cycle_consistent_training and not affine_consistent_training:
                    # Register the prediction back onto the original moving image.
                    cycle_data = {
                        "fixed_image": data["moving_image"],
                        "fixed_label": data["moving_label"],
                        "moving_image": pred_image,
                        "moving_label": pred_label,
                    }
                    if reg_type == 'local':
                        cycle_ddf, cycle_pred_image, cycle_pred_label, cycle_dvf = model(cycle_data)
                        cycle_pred_image = cycle_pred_image.to(device, non_blocking=True)
                        cycle_pred_label = cycle_pred_label.to(device, non_blocking=True)
                        cycle_pred_mask = AsDiscrete(threshold=0.5)(cycle_pred_label)

                        cycle_fixed_image = data["moving_image"].to(device, non_blocking=True)
                        cycle_fixed_label = data["moving_label"].to(device, non_blocking=True)
                        cycle_fixed_mask = AsDiscrete(threshold=0.5)(cycle_fixed_label)

                        cycle_reg_field = cycle_ddf if use_ddf else cycle_dvf
                        cycle_img_loss, cycle_lbl_loss, cycle_ddf_loss = get_deformable_registration_loss_from_weights(
                            cycle_pred_image, cycle_pred_mask, cycle_fixed_image, cycle_fixed_mask,
                            cycle_reg_field, weights)
                        cycle_loss = cycle_img_loss + cycle_lbl_loss + cycle_ddf_loss
                        if use_jacobian_loss:
                            cycle_jcb_loss = jacobian_loss(cycle_ddf) / (128*128*128)
                            cycle_loss = cycle_loss + cycle_jcb_loss
                        if use_antifolding_loss:
                            cycle_fold_loss = antifolding_loss(cycle_ddf)
                            cycle_loss = cycle_loss + cycle_fold_loss
                        loss = loss + cycle_loss
                if not cycle_consistent_training and affine_consistent_training:
                    # Perturb the moving pair with a random affine and require
                    # the two predictions to agree.
                    randaffine_transform = RandAffine(
                        mode='bilinear',
                        prob=1.0,
                        rotate_range=(np.pi/16, np.pi/16, np.pi/16),
                        scale_range=(0.05, 0.05, 0.05),
                        translate_range=(5, 5, 5),
                    )
                    # Only batch element 0 is transformed (then re-batched).
                    randaffine_moving_image = randaffine_transform(data["moving_image"][0, :, :, :, :]).unsqueeze(0)
                    randaffine_matrix = randaffine_transform.rand_affine_grid.get_transformation_matrix()
                    # Re-apply the same matrix to the label with nearest sampling.
                    randaffine_transform_nearest = Affine(
                        mode='nearest',
                        affine=randaffine_matrix
                    )
                    randaffine_moving_label, _ = randaffine_transform_nearest(data["moving_label"][0, :, :, :, :])
                    randaffine_moving_label = randaffine_moving_label.unsqueeze(0)
                    randaffine_data = {
                        "fixed_image": fixed_image,
                        "fixed_label": fixed_label,
                        "moving_image": randaffine_moving_image,
                        "moving_label": randaffine_moving_label,
                    }

                    if reg_type == 'local':
                        randaffine_ddf, randaffine_pred_image, randaffine_pred_label, _ = model(randaffine_data)

                        # u1 - A o u2 = 0
                        # => pred1 - pred2 = 0
                        randaffine_pred_image = randaffine_pred_image.to(device, non_blocking=True)
                        randaffine_pred_label = randaffine_pred_label.to(device, non_blocking=True)

                        aff_loss = LocalNormalizedCrossCorrelationLoss()(pred_image, randaffine_pred_image)
                        loss = loss + aff_loss

                dice_metric = compute_mean_dice(pred_mask, fixed_mask)
                if reg_type == 'local' and phase == 'valid' and (validfeminad or "feminad" in dataset):
                    # The first entry of the landmark distances is excluded from the mean.
                    metric = np.mean(compute_landmarks_distance_local(ddf, data)[1:])
                elif reg_type == 'local' and "neatin" in dataset:
                    metric = evaluate(fixed_regions_np, pred_regions_np, metric="DSC", multi_class=True, n_classes=41)
                    metric = np.mean(metric)
                else:
                    metric = torch.zeros(1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # Accumulate batch-size-weighted sums; they are divided by the number
            # of samples actually processed once the phase is finished.
            batch_n = fixed_image.size(0)
            samples_seen += batch_n
            running_loss += loss.item() * batch_n
            running_cycle_loss += cycle_loss.item() * batch_n
            running_jcb_loss += jcb_loss.item() * batch_n
            running_aff_loss += aff_loss.item() * batch_n
            running_fold_loss += fold_loss.item() * batch_n
            running_metric += metric.item() * batch_n
            running_dice += dice_metric.item() * batch_n
            running_img_loss += img_loss.item() * batch_n
            running_lbl_loss += lbl_loss.item() * batch_n
            running_ddf_loss += ddf_loss.item() * batch_n

        # Per-sample averages for this phase. Dividing by the training loader
        # length (as before) mis-scaled validation numbers whenever the two
        # loaders differ in size or the batch size is not 1.
        denom = max(samples_seen, 1)
        running_loss /= denom
        running_cycle_loss /= denom
        running_jcb_loss /= denom
        running_aff_loss /= denom
        running_fold_loss /= denom
        running_metric /= denom
        running_dice /= denom
        running_img_loss /= denom
        running_lbl_loss /= denom
        running_ddf_loss /= denom

        if phase == 'train':
            train_loss, train_metric, train_img_loss, train_lbl_loss, train_ddf_loss = (
                running_loss, running_metric, running_img_loss, running_lbl_loss, running_ddf_loss)
            train_dice = running_dice
            train_cycle_loss = running_cycle_loss
            train_jcb_loss = running_jcb_loss
            train_aff_loss = running_aff_loss
            train_fold_loss = running_fold_loss
        elif phase == 'valid':
            valid_loss, valid_metric, valid_img_loss, valid_lbl_loss, valid_ddf_loss = (
                running_loss, running_metric, running_img_loss, running_lbl_loss, running_ddf_loss)
            valid_dice = running_dice
            valid_cycle_loss = running_cycle_loss
            valid_jcb_loss = running_jcb_loss
            valid_aff_loss = running_aff_loss
            valid_fold_loss = running_fold_loss

        outmessage = "{}: loss: {:.4f} - metric: {:.4f} -- img: {:.4f}, lbl: {:.4f}, ddf: {:.4f}".format(
                phase, running_loss, running_metric, running_img_loss, running_lbl_loss, running_ddf_loss)
        if cycle_consistent_training:
            outmessage += " -- cycle: {:.4f}".format(running_cycle_loss)
        if use_jacobian_loss:
            outmessage += " -- jac: {:.4f}".format(running_jcb_loss)
        if affine_consistent_training:
            outmessage += " -- aff: {:.4f}".format(running_aff_loss)
        if use_antifolding_loss:
            outmessage += " -- fold: {:.4f}".format(running_fold_loss)
        outmessage += " -- dice: {:.4f}".format(running_dice)
        print(outmessage)

        # Model selection uses the validation loss, except with `validfeminad`,
        # where the training loss decides.
        # NOTE(review): `scheduler` is created during setup but never stepped
        # here, so the learning rate stays constant -- confirm that is intended.
        if (phase == 'valid' and not validfeminad) or (phase == 'train' and validfeminad):
            if running_loss < best_loss:
                best_loss = running_loss
                best_epoch = epoch + 1
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'weights': weights,
                    'epoch': epoch,
                    'lr': lr,
                    },
                    './models/' + modelname
                )
                print(
                    "best loss {:.4f} at epoch {}".format(
                        best_loss, best_epoch
                    )
                )

    # One TensorBoard scalar group per tracked quantity (train vs. valid).
    scalar_groups = [
        ('epoch_loss', train_loss, valid_loss),
        ('epoch_metric', train_metric, valid_metric),
        ('epoch_dice', train_dice, valid_dice),
        ('epoch_lbl_loss', train_lbl_loss, valid_lbl_loss),
        ('epoch_img_loss', train_img_loss, valid_img_loss),
        ('epoch_ddf_loss', train_ddf_loss, valid_ddf_loss),
    ]
    if cycle_consistent_training:
        scalar_groups.append(('epoch_cycle_loss', train_cycle_loss, valid_cycle_loss))
    if use_jacobian_loss:
        scalar_groups.append(('epoch_jcb_loss', train_jcb_loss, valid_jcb_loss))
    if affine_consistent_training:
        scalar_groups.append(('epoch_aff_loss', train_aff_loss, valid_aff_loss))
    if use_antifolding_loss:
        scalar_groups.append(('epoch_fold_loss', train_fold_loss, valid_fold_loss))
    for tag, train_value, valid_value in scalar_groups:
        writer.add_scalars(tag, {'train': train_value, 'valid': valid_value}, epoch + 1)

# Report the best loss reached and the epoch it occurred at, then release the
# TensorBoard writer.
summary_line = "train completed, best_loss: {:.4f}  at epoch: {}".format(best_loss, best_epoch)
print(summary_line)
writer.close()