GradICON deformable registration of FA (fractional anisotropy) images. (WIP)

In [1]:
import os
import glob
import random
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import monai
import torch
import torch.nn

import footsteps
import pickle

In [2]:
# Global configuration: compute device, padded image size, and number of pyramid scales.
device = torch.device("cuda")
spatial_size = (144,) * 3
num_scales = 4

The input images are known to be $140\times140\times140$, and we pad them to $144$ in each dimension.

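We assume that each dimension in `spatial_size` is divisible by $2^{\texttt{num}\_\texttt{scales}-1}$, because we repeatedly downsample by a factor of $2$ (that is, $\texttt{num}\_\texttt{scales}-1$ times) to produce the images at the coarser scales. For example, with $144$ and $\texttt{num}\_\texttt{scales}=4$ the coarsest images are $144/8 = 18$ voxels on a side.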

In [3]:
fa_dir = 'dti_fit_images/fa'
fa_keys = [f'fa{i}' for i in range(num_scales)]
fa_key = fa_keys[0] # a simpler way to refer to the first element of fa_keys
# glob returns files in filesystem-dependent order, and partition_dataset does not shuffle
# by default, so sort the paths to make the train/validation split reproducible.
fa_paths = sorted(glob.glob(os.path.join(fa_dir, '*')))
data = [{fa_key: path, "filename": os.path.basename(path)} for path in fa_paths]
data_train, data_valid = monai.data.utils.partition_dataset(data, ratios=(8,2))

`fa_keys` maps each scale index to the dictionary key that holds the image at that scale: $0$ is the base resolution, $1$ is downscaled by a factor of $2$, $2$ is downscaled by a further factor of $2$, and so on.

In [4]:
# Deterministic preprocessing shared by training and validation: load the image, add a
# leading channel axis, constant-pad to spatial_size, convert to a tensor, move to `device`.
base_transforms = [
    monai.transforms.LoadImageD(keys=fa_key),
    monai.transforms.AddChannelD(keys=fa_key),
    monai.transforms.SpatialPadD(keys=fa_key, mode="constant", spatial_size=spatial_size),
    monai.transforms.ToTensorD(keys=fa_key),
    monai.transforms.ToDeviceD(device=device, keys=fa_key),
]

In [5]:
# Control the overall scale of affine transform: every range below is proportional to `a`.
a = 0.5

S = spatial_size[0]  # edge length used to size the translation range

rand_affine_params = dict(
    prob=1.0,
    mode='bilinear',
    padding_mode='zeros',
    spatial_size=spatial_size,
    cache_grid=True,
    rotate_range=(a * np.pi / 2,) * 3,
    shear_range=(0,) * 6,  # no shearing
    translate_range=(a * S / 16,) * 3,
    scale_range=(a * 0.4,) * 3,
)

rand_affine_transform = monai.transforms.RandAffineD(keys=fa_key, **rand_affine_params)

In [6]:
# Copy the base image into the keys fa_keys[1:], then shrink the copy stored under
# fa_keys[i] by a factor of 2**i along every spatial axis.
add_scales_transforms = [
    monai.transforms.CopyItemsD(keys=[fa_key], times=num_scales - 1, names=fa_keys[1:]),
    *(
        monai.transforms.ResizeD(
            keys=[fa_keys[scale]],
            spatial_size=[dim // 2**scale for dim in spatial_size],
        )
        for scale in range(1, num_scales)
    ),
]

`add_scales_transforms` is a chain of transforms that adds downsampled copies of the base image, stored under the keys in `fa_keys`.

In [7]:
# Validation uses only the deterministic preprocessing. Training applies the random affine
# before the multiscale copies are made, so every scale sees the same augmentation.
transform_valid = monai.transforms.Compose([*base_transforms, *add_scales_transforms])
transform_train = monai.transforms.Compose(
    [*base_transforms, rand_affine_transform, *add_scales_transforms]
)

In [8]:
# Cached datasets for the training and validation splits.
ds_train = monai.data.CacheDataset(data=data_train, transform=transform_train)
ds_valid = monai.data.CacheDataset(data=data_valid, transform=transform_valid)

Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 68/68 [00:05<00:00, 12.83it/s]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 25.98it/s]


In [9]:
warp = monai.networks.blocks.Warp(padding_mode="zeros", mode="bilinear")  # shared displacement-field warping layer

def mse_loss(b1, b2):
    """Return the mean squared difference of batches b1 and b2, scaled up by a factor of 10000.

    Both inputs have shape (batch_size, channels, H, W, D); the mean is taken over all elements.
    """
    squared_error = (b1 - b2) ** 2
    return 10000 * squared_error.mean()

def ncc_loss(b1, b2, eps=1e-8):
    """Return the negative normalized cross-correlation (NCC) of batches b1 and b2.

    b1 and b2 have shape (batch_size, channels, H, W, D). NCC is computed separately for
    every (batch, channel) pair over the three spatial axes, and the result is averaged
    over batches and channels. The loss lies in [-1, 1]; -1 means perfectly positively
    correlated images.

    eps is added under the square root so that a constant (zero-variance) image yields a
    finite loss (0) instead of NaN.
    """
    spatial = (2, 3, 4)
    # Center first: computing E[x^2] - E[x]^2 directly loses precision in float32 and can
    # come out slightly negative, which would make the sqrt below NaN.
    c1 = b1 - b1.mean(dim=spatial, keepdim=True)
    c2 = b2 - b2.mean(dim=spatial, keepdim=True)
    covariance = (c1 * c2).mean(dim=spatial)
    variance_product = (c1 ** 2).mean(dim=spatial) * (c2 ** 2).mean(dim=spatial)
    ncc = covariance / torch.sqrt(variance_product + eps)
    return -ncc.mean() # average over batches and channels

def compose_ddf(u, v):
    """Compose displacement fields u and v.

    The result is the single displacement field that warps by v followed by u: it equals
    u plus v resampled (via `warp`) at the locations displaced by u.
    """
    v_resampled = warp(v, u)
    return u + v_resampled

H, W, D = ds_train[0][fa_key].shape[1:]  # spatial size of one (channel-first) base-resolution image

# Compute discrete spatial derivatives
def diff_and_trim(array, axis):
    """Take the forward difference of `array` along spatial axis `axis` (2, 3, or 4).

    `array` has shape (batch, channels, H, W, D). Every spatial axis of the result is
    trimmed to (size - 1), so differences taken along the three axes have the same shape
    and can be combined voxel by voxel. The sizes come from `array` itself, so fields
    at any resolution are handled, not just the base image size.
    """
    h, w, d = array.shape[2:]
    return torch.diff(array, dim=axis)[:, :, :(h-1), :(w-1), :(d-1)]

def size_of_spatial_derivative(u):
    """Return the per-sample spatial mean of the squared Frobenius norm of the derivative of u.

    u is a displacement field of shape (batch, 3, H, W, D). Its derivative is approximated by
    forward differences along the three spatial axes. The squared entries are summed over the
    vector components and the three directions, then averaged over voxels. This concerns the
    displacement field itself, not the deformation it defines. Returns a tensor of shape (batch,).
    """
    derivatives = [diff_and_trim(u, spatial_axis) for spatial_axis in (2, 3, 4)]
    squared_norm = sum(d ** 2 for d in derivatives)
    return squared_norm.sum(dim=1).mean(dim=(1, 2, 3))

monai.networks.blocks.Warp: Using PyTorch native grid_sample.


In [10]:
# Return type of the models' forward(): total loss, its two components, and the A<-B displacement field.
ModelOutput = namedtuple("ModelOutput", ["all_loss", "sim_loss", "gradicon_loss", "deformation_AB"])

class Model(torch.nn.Module):
    """Single-scale GradICON registration model.

    One UNet maps a (fixed, moving) image pair, stacked along the channel axis, to a 3D
    displacement field. The network is run on both orderings of the pair so that the
    similarity loss and the GradICON penalty on the composed fields can be applied symmetrically.
    """

    def __init__(self, lambda_reg, compute_sim_loss):
        super().__init__()
        unet_kwargs = dict(dropout=0.2, norm="batch")
        self.reg_net = monai.networks.nets.UNet(
            3,                     # spatial dims
            2,                     # input channels: fixed image + moving image
            3,                     # output channels: one per displacement component
            (32, 32, 32, 32, 64),  # channels per level
            (2, 2, 2, 2),          # strides between levels
            **unet_kwargs,
        )
        self.lambda_reg = lambda_reg
        self.compute_sim_loss = compute_sim_loss

    def update_lambda_reg(self, new_lambda_reg):
        """Set the weight of the GradICON term in the total loss."""
        self.lambda_reg = new_lambda_reg

    def forward(self, img_A, img_B) -> ModelOutput:
        """Register img_B to img_A and img_A to img_B; return the losses and the A<-B field."""
        stacked = torch.cat((img_A, img_B), dim=1)
        swapped = stacked[:, [1, 0]]

        ddf_AB = self.reg_net(stacked)  # deforms img_B to the space of img_A
        ddf_BA = self.reg_net(swapped)  # deforms img_A to the space of img_B

        similarity_A = self.compute_sim_loss(img_A, warp(img_B, ddf_AB))
        similarity_B = self.compute_sim_loss(img_B, warp(img_A, ddf_BA))

        round_trip_A = compose_ddf(ddf_AB, ddf_BA)
        round_trip_B = compose_ddf(ddf_BA, ddf_AB)
        penalty_A = size_of_spatial_derivative(round_trip_A).mean()
        penalty_B = size_of_spatial_derivative(round_trip_B).mean()

        total_sim = similarity_A + similarity_B
        total_gradicon = penalty_A + penalty_B

        return ModelOutput(
            all_loss=total_sim + self.lambda_reg * total_gradicon,
            sim_loss=total_sim,
            gradicon_loss=total_gradicon,
            deformation_AB=ddf_AB,
        )

In [11]:
class BatchResizer:
    """Resize every item of a batch with a MONAI Resize and stack the results back into a batch."""

    def __init__(self, spatial_dims):
        self.resize = monai.transforms.Resize(spatial_dims)

    def __call__(self, batch):
        items = monai.transforms.Decollated()(batch)
        resized_items = [self.resize(item) for item in items]
        return torch.stack(resized_items)
    
# A multiscale version of the Model idea above
class MultiscaleModel(torch.nn.Module):
    """Coarse-to-fine GradICON registration model with one UNet per scale.

    Scale i works on images downsampled by a factor of 2**i (scale 0 is full resolution).
    The coarsest scale is registered first. Each finer scale then registers the fixed image
    against the moving image pre-warped by the upsampled composite of all coarser warps.
    """
    def __init__(self, lambda_reg, compute_sim_loss, num_subnetworks):
        super().__init__()
        self.num_subnetworks = num_subnetworks
        self.reg_nets = torch.nn.ModuleList()
        for i in range(num_subnetworks):
            # i is scale. i=0 is the original input image scale. Scale i is at a downsample factor of 2**i.
            n = 4 # Amount of down-convolution for the original image size.
            # (We will assume that the original image size is divisible by 2**n.)
            num_twos = n-i # The number of 2's we will put in the sequence of convolutional strides.
            num_ones = min(i,n-i) # The number of 1's
            stride_sequence = (1,2)*num_ones + (2,)*(num_twos-num_ones)
            channel_sequence = [min(8*2**c,64) for c in range(num_twos+num_ones+1)]
            self.reg_nets.append(
                monai.networks.nets.UNet(
                    3,  # spatial dims
                    2,  # input channels (one for fixed image and one for moving image)
                    3,  # output channels (to represent 3D displacement vector field)
                    channel_sequence,
                    stride_sequence,
                    dropout=0.2,
                    norm="batch"
                )
            )
        self.lambda_reg = lambda_reg
        self.compute_sim_loss = compute_sim_loss

        # batch_resizers[i] resizes a batch to the spatial size of scale i. It is used to carry
        # fields computed at scale i+1 onto the grid of scale i.
        self.batch_resizers = [BatchResizer([s//2**i for s in spatial_size]) for i in range(num_subnetworks-1)]

    def update_lambda_reg(self, new_lambda_reg):
        """Set the weight of the GradICON term in the total loss."""
        self.lambda_reg = new_lambda_reg

    def _upsample_ddf(self, ddf, i):
        """Resample a displacement field from the grid of scale i+1 onto the grid of scale i.

        The displacements are measured in voxels of the grid they live on (they are consumed by
        `warp`), so besides resampling, each vector component is multiplied by the ratio of the
        grid sizes along its axis (2 for this pyramid). Without this, an upsampled field would
        move points only half as far as intended.
        """
        resized = self.batch_resizers[i](ddf)
        ratios = [new / old for new, old in zip(resized.shape[2:], ddf.shape[2:])]
        ratios = torch.tensor(ratios, dtype=resized.dtype, device=resized.device)
        return resized * ratios.view(1, -1, 1, 1, 1)

    def multiscale_reg_nets(self, img_A, img_B):
        """
        Here we expect img_A to be a list consisting of batches of target images:
            img_A[0] is a batch of target images at the original resolution,
            img_A[1] is a batch of target images downsampled by a factor of 2 in each dimension,
            img_A[2] is a batch of target images downsampled by a factor of 4 in each dimension,
            etc.
        and similarly img_B is a list consisting of batches of moving images.
        Returns the final displacement field (composed over all scales) for deforming img_B[0] to img_A[0].
        """
        coarsest = self.num_subnetworks - 1
        phi_comp = None
        for i in range(coarsest, -1, -1): # Run backwards from the coarsest scale down to scale 0
            if i == coarsest:
                # Base case: nothing has been estimated yet (identity map), so register the raw images.
                # (Handled specially to avoid useless compositions with the identity map.)
                warped_B = img_B[i]
            else:
                # phi_up = composite of warps up to scale i+1, operating at scale i
                phi_up = self._upsample_ddf(phi_comp, i)
                # warped_B = img_B at scale i warped by that composite
                warped_B = warp(img_B[i], phi_up)

            # phi = warp from scale i, operating at scale i
            phi = self.reg_nets[i](torch.cat([img_A[i], warped_B], dim=1))

            # phi_comp = composite of warps up to and including scale i, operating at scale i
            phi_comp = phi if i == coarsest else compose_ddf(phi, phi_up)

        return phi_comp

    def forward(self, img_A, img_B) -> ModelOutput:
        """
        Here we expect img_A to be a list consisting of batches of target images:
            img_A[0] is a batch of target images at the original resolution,
            img_A[1] is a batch of target images downsampled by a factor of 2 in each dimension,
            img_A[2] is a batch of target images downsampled by a factor of 4 in each dimension,
            etc.
        and similarly img_B is a list consisting of batches of moving images.
        Returns the total loss, the similarity and GradICON terms, and the full-resolution
        displacement field deforming img_B[0] to img_A[0].
        """
        deformation_AB = self.multiscale_reg_nets(img_A, img_B) # deforms img_B to the space of img_A
        deformation_BA = self.multiscale_reg_nets(img_B, img_A) # deforms img_A to the space of img_B

        img_B0_warped = warp(img_B[0], deformation_AB)
        img_A0_warped = warp(img_A[0], deformation_BA)
        sim_loss_A = self.compute_sim_loss(img_A[0], img_B0_warped)
        sim_loss_B = self.compute_sim_loss(img_B[0], img_A0_warped)
        composite_deformation_A = compose_ddf(deformation_AB, deformation_BA)
        composite_deformation_B = compose_ddf(deformation_BA, deformation_AB)
        gradicon_loss_A = size_of_spatial_derivative(composite_deformation_A).mean()
        gradicon_loss_B = size_of_spatial_derivative(composite_deformation_B).mean()

        sim_loss = sim_loss_A + sim_loss_B
        gradicon_loss = gradicon_loss_A + gradicon_loss_B

        return ModelOutput(
            all_loss = sim_loss + self.lambda_reg * gradicon_loss,
            sim_loss = sim_loss,
            gradicon_loss = gradicon_loss,
            deformation_AB = deformation_AB
        )

In [12]:
class JacobianDeterminant(torch.nn.Module):
    """Given a batch of displacement vector fields vf, compute the jacobian determinant scalar field."""

    def __init__(self, spatial_dims):
        super().__init__()
        # Kept for backward compatibility (e.g. with pickled LossCurves objects). The trimmed
        # sizes are now taken from the input field itself, so any field size works.
        self.spatial_dims = spatial_dims

    def diff_and_trim(self, array, axis):
        """Forward difference along spatial `axis` (2, 3 or 4), with every spatial axis trimmed to size-1."""
        H, W, D = array.shape[2:]
        return torch.diff(array, dim=axis)[:, :, :(H-1), :(W-1), :(D-1)]

    def forward(self, vf):
        """
        vf is assumed to be a vector field of shape (B,3,H,W,D),
        and it is interpreted as a displacement field.
        So it is defining a batch of discretely sampled maps from a subset of 3-space into 3-space,
        namely (for batch index b) the map that sends point (x,y,z) to the point (x,y,z)+vf[b,:,x,y,z].
        This function computes a jacobian determinant by taking discrete differences in each spatial direction.

        Returns a torch tensor of shape (B,H-1,W-1,D-1). Negative entries indicate folding.
        """
        dx = self.diff_and_trim(vf, 2)
        dy = self.diff_and_trim(vf, 3)
        dz = self.diff_and_trim(vf, 4)

        # Jacobian entries J[i][j] = d(map_i)/d(x_j), where map = identity + vf.
        # The "+ 1" on the diagonal is the derivative of the identity map. Entries are
        # built out of place, so the difference tensors are never mutated.
        j00, j01, j02 = dx[:, 0] + 1, dy[:, 0], dz[:, 0]
        j10, j11, j12 = dx[:, 1], dy[:, 1] + 1, dz[:, 1]
        j20, j21, j22 = dx[:, 2], dy[:, 2], dz[:, 2] + 1

        # Compute determinant at each spatial location (cofactor expansion along the first row)
        det = (j00 * (j11 * j22 - j12 * j21)
               - j01 * (j10 * j22 - j12 * j20)
               + j02 * (j10 * j21 - j11 * j20))

        return det

In [13]:
class LossCurves:
    """Record per-batch losses and aggregate them into per-epoch curves.

    add_to_buffer() collects the scalar losses of one batch; aggregate_buffers_for_epoch()
    appends the mean of each buffer to its curve and clears the buffers. If include_folds is
    set, the batch-averaged count of negative Jacobian-determinant entries in
    model_output.deformation_AB is tracked too.
    """
    def __init__(self, name : str, include_folds : bool = False, spatial_dims : tuple = None):
        self.name = name

        self.epochs = []
        self.all_losses = []
        self.sim_losses = []
        self.gradicon_losses = []

        self.include_folds = include_folds
        if include_folds:
            if spatial_dims is None:
                # ValueError subclasses Exception, so existing `except Exception` handlers still work.
                raise ValueError("Need argument spatial_dims to include fold count.")
            self.jacobian_determinant = JacobianDeterminant(spatial_dims)
            self.fold_counts = []

        self.clear_buffers()

    def clear_buffers(self):
        """Reset the per-batch buffers (called at construction and after each epoch is aggregated)."""
        self.all_losses_buffer = []
        self.sim_losses_buffer = []
        self.gradicon_losses_buffer = []

        if self.include_folds:
            self.fold_counts_buffer = []

    def add_to_buffer(self, model_output : "ModelOutput"):
        """Append one batch's scalar losses (and fold count, if enabled) to the buffers."""
        self.all_losses_buffer.append(model_output.all_loss.item())
        self.sim_losses_buffer.append(model_output.sim_loss.item())
        self.gradicon_losses_buffer.append(model_output.gradicon_loss.item())

        if self.include_folds:
            det = self.jacobian_determinant(model_output.deformation_AB)
            num_folds = (det<0).sum(dim=(1,2,3)) # negative-determinant entries per batch element
            num_folds_mean = num_folds.to(dtype=torch.float).mean().item() # average over batch
            self.fold_counts_buffer.append(num_folds_mean)

    def aggregate_buffers_for_epoch(self, epoch : int):
        """Append the mean of each buffer to its curve under `epoch`, then clear the buffers."""
        self.epochs.append(epoch)
        self.all_losses.append(np.mean(self.all_losses_buffer))
        self.sim_losses.append(np.mean(self.sim_losses_buffer))
        self.gradicon_losses.append(np.mean(self.gradicon_losses_buffer))
        if self.include_folds:
            self.fold_counts.append(np.mean(self.fold_counts_buffer))
        self.clear_buffers()

    def plot(self, savepath=None):
        """Plot each recorded curve against epoch in its own panel; optionally save the figure."""
        fig, axs = plt.subplots(1, 4 if self.include_folds else 3, figsize = (15,5))
        axs[0].plot(self.epochs, self.all_losses)
        axs[0].set_title(f"{self.name}: overall loss")
        axs[1].plot(self.epochs, self.sim_losses)
        axs[1].set_title(f"{self.name}: similarity loss")
        axs[2].plot(self.epochs, self.gradicon_losses, label="gradicon loss")
        axs[2].set_title(f"{self.name}: gradicon loss")
        if self.include_folds:
            axs[3].plot(self.epochs, self.fold_counts, label="average folds")
            axs[3].set_title(f"{self.name}: average folds")
        for ax in axs:
            ax.set_xlabel("epoch")
        if savepath is not None:
            fig.savefig(savepath)
        plt.show()

In [14]:
# Windowed (local) normalized cross-correlation loss from MONAI, with a 5-voxel kernel.
lncc_loss = monai.losses.LocalNormalizedCrossCorrelationLoss(
    spatial_dims=3, kernel_size=5, smooth_nr=0, smooth_dr=1e-4,
)

In [17]:
def batch_to_scales_list(b):
    """Return the images of batch `b` (from dl_train or dl_valid) as a list ordered by scale.

    Entry i is b[fa_keys[i]], which is the input layout a MultiscaleModel expects.
    """
    return [b[key] for key in fa_keys]

In [15]:
dl_train = monai.data.DataLoader(ds_train, shuffle=True, batch_size=1, drop_last=True)
dl_valid = monai.data.DataLoader(ds_valid, shuffle=True, batch_size=2, drop_last=True)
max_epochs = 300

def validate_when(e):
    """Return True on every even epoch index except 0, and on the final epoch."""
    return ((e % 2 == 0) and (e != 0)) or (e == max_epochs - 1)

lambda_reg_step_size = 0.01 # How much to increase lambda_reg each time it advances
cooldown = 5 # How many epochs to allow before checking whether training loss increases and advancing lambda_reg if so
cooldown_counter = cooldown
lambda_reg_goal = 0.2 # Stop training once lambda_reg advances past this
model = MultiscaleModel(
    lambda_reg = 0,
    compute_sim_loss = ncc_loss,
    num_subnetworks=num_scales
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# min_lr=1e-5
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.954992586)
loss_curves_train = LossCurves("training")
# Use spatial_size instead of repeating (144,144,144) so fold counting follows the padding config.
loss_curves_valid = LossCurves("validation", include_folds=True, spatial_dims=spatial_size)

In [None]:
# TESTING; DELETE THIS CELL
# Smoke test: run the model once on two training batches and display the total loss.
# eval() + no_grad() keep this check from updating BatchNorm running statistics before
# training. New names avoid clobbering the global `a` (the affine-augmentation scale).
smoke_iter = iter(dl_train)
d1 = next(smoke_iter)
d2 = next(smoke_iter)
model.eval()
with torch.no_grad():
    smoke_loss = model(batch_to_scales_list(d1), batch_to_scales_list(d2)).all_loss
smoke_loss

In [None]:
# TRAINING
# Each step registers two consecutive training batches against each other. After a cooldown,
# lambda_reg is raised whenever the epoch training loss increases; training stops once
# lambda_reg exceeds lambda_reg_goal.

for e in range(max_epochs):
    # current_lr = scheduler.get_last_lr()[0]
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch {e+1}/{max_epochs} (LR = {current_lr:.1e}, lambda_reg = {model.lambda_reg:.1e}):')

    # Train: consume the loader two batches at a time (a trailing unpaired batch is dropped).
    model.train()
    train_batches = iter(dl_train)
    for b1, b2 in zip(train_batches, train_batches):
        optimizer.zero_grad()
        model_output = model(batch_to_scales_list(b1), batch_to_scales_list(b2))
        model_output.all_loss.backward()
        optimizer.step()

        loss_curves_train.add_to_buffer(model_output)
        del model_output

    # if scheduler.get_last_lr()[0] > min_lr:
    #     scheduler.step()

    loss_curves_train.aggregate_buffers_for_epoch(e)
    print(f"\tTraining loss: {loss_curves_train.all_losses[-1]:.4f} (sim={loss_curves_train.sim_losses[-1]:.4f}, ic={loss_curves_train.gradicon_losses[-1]:.4f})")

    # Validate
    if validate_when(e):
        model.eval()
        valid_batches = iter(dl_valid)
        for b1, b2 in zip(valid_batches, valid_batches):
            with torch.no_grad():
                model_output = model(batch_to_scales_list(b1), batch_to_scales_list(b2))
                loss_curves_valid.add_to_buffer(model_output)
        loss_curves_valid.aggregate_buffers_for_epoch(e)
        print("\tValidation loss:", loss_curves_valid.all_losses[-1])
        print("\tAverage folds:", loss_curves_valid.fold_counts[-1])

    # lambda_reg schedule. The `and` short-circuits, so all_losses[-2] is only read after the cooldown.
    cooldown_counter -= 1
    if cooldown_counter <= 0 and loss_curves_train.all_losses[-1] > loss_curves_train.all_losses[-2]:
        print("Updating lambda_reg.")
        model.update_lambda_reg(model.lambda_reg + lambda_reg_step_size)
        cooldown_counter = cooldown
    if model.lambda_reg > lambda_reg_goal:
        print("Reached goal lambda_reg.")
        break
        
    

Epoch 1/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.1175 (-1.1175,43.7104)
Epoch 2/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.2127 (-1.2127,43.2884)
Epoch 3/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.2808 (-1.2808,42.9596)
	Validation loss: -1.5844969153404236
	Average folds: 152591.75
Epoch 4/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.2664 (-1.2664,42.3718)
Epoch 5/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.3154 (-1.3154,42.5186)
	Validation loss: -1.6106345355510712
	Average folds: 145397.25
Epoch 6/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.3178 (-1.3178,42.3591)
Epoch 7/300 (LR = 1.0e-03, lambda_reg = 0.0e+00):
	Training loss: -1.3027 (-1.3027,42.6714)
	Validation loss: -1.6265175342559814
	Average folds: 147743.625
Updating lambda_reg.
Epoch 8/300 (LR = 1.0e-03, lambda_reg = 1.0e-02):
	Training loss: -0.9415 (-1.2867,34.5146)
Epoch 9/300 (LR = 1.0e-03, lambda_reg = 1.0e-02):


In [None]:
# SAVE
# Write the loss plots, the pickled LossCurves objects and the model weights to footsteps.output_dir.
# os.path.join gives the same paths as plain concatenation when output_dir ends with a
# separator, and still inserts one when it does not.
out_dir = footsteps.output_dir

loss_curves_train.plot(savepath = os.path.join(out_dir, 'loss_plot_train.png'))
loss_curves_valid.plot(savepath = os.path.join(out_dir, 'loss_plot_valid.png'))

with open(os.path.join(out_dir, 'loss_curves.p'), 'wb') as f:
    pickle.dump([loss_curves_train, loss_curves_valid],f)

torch.save(model.state_dict(), os.path.join(out_dir, 'model_state_dict.pth'))

In [None]:
# LOAD
# Restore the model weights and pickled loss curves from results/<experiment_name_to_load>/.
# The file names match the ones written by the SAVE cell.
experiment_name_to_load = "increase lambda_reg over schedule"
load_dir = os.path.join('results', experiment_name_to_load)

state_dict_path = os.path.join(load_dir, 'model_state_dict.pth')
model.load_state_dict(torch.load(state_dict_path))

# NOTE: unpickling can execute arbitrary code; only load result files you produced yourself.
curves_path = os.path.join(load_dir, 'loss_curves.p')
with open(curves_path, 'rb') as f:
    loss_curves_train, loss_curves_valid = pickle.load(f)

In [None]:
import util

ds = ds_train # Choose whether to view performance on training or on validation data
d1 = random.choice(ds)
d2 = random.choice(ds)

# MultiscaleModel.forward expects each image as a list of batches ordered by scale
# (index i = the image stored under fa_keys[i]). Passing a single tensor would make it index
# the batch dimension by scale, which raises IndexError for a batch of one.
img_A_scales = [d1[k].unsqueeze(0) for k in fa_keys]
img_B_scales = [d2[k].unsqueeze(0) for k in fa_keys]
img_A = img_A_scales[0] # full-resolution target, shape (1, 1, H, W, D)
img_B = img_B_scales[0] # full-resolution moving image
model.eval()
with torch.no_grad():
    model_output = model(img_A_scales, img_B_scales)
    img_B_warped = warp(img_B, model_output.deformation_AB)

preview_slices = (80,80,80)

print("moving:")
util.preview_image(img_B[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("warped moving:")
util.preview_image(img_B_warped[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("target:")
util.preview_image(img_A[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("checkerboard of warped moving and target:")
util.preview_checkerboard(img_A[0,0].cpu(), img_B_warped[0,0].cpu(), figsize=(18,10), slices=preview_slices)
print("deformation vector field:")
util.preview_3D_vector_field(model_output.deformation_AB[0].cpu(), slices=preview_slices)
print("deformed grid:")
util.preview_3D_deformation(model_output.deformation_AB[0].cpu(),5, slices=preview_slices)
print("jacobian determinant:")
det = util.jacobian_determinant(model_output.deformation_AB[0].cpu())
util.preview_image(det, normalize_by='slice', threshold=0, slices=preview_slices)
print("sim loss:", model_output.sim_loss.item())
print("gradicon loss:", model_output.gradicon_loss.item())
print("overall loss:", model_output.all_loss.item())
num_folds = (det<0).sum()
print("Number of folds:", num_folds, f"(folding rate {100*num_folds/np.prod(det.shape)}%)")
