In [None]:
import os
import torch

# Select the PyTorch backend for neurite and voxelmorph.
# NOTE(review): presumably both packages read these variables when they are imported, which
# would be why they are set before the `import voxelmorph` / `import neurite` lines below — confirm.
os.environ['NEURITE_BACKEND'] = 'pytorch'
os.environ['VXM_BACKEND'] = 'pytorch'

# some third party very useful libraries
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

# our libraries
import voxelmorph as vxm
import neurite as ne

In [None]:
# Relative path of the normalized slice inside every OASIS subject folder under ./data.
files = []
for entry in os.listdir('./data'):
    if entry.startswith('OASIS_OAS1_'):
        files.append(entry + '/slice_norm.nii.gz')

# Read every slice into memory as a floating point array.
vols = []
for rel_path in files:
    vols.append(nib.load('./data/' + rel_path).get_fdata())

# Stack along a new leading (subject) axis; the volume shape drops the first and last axes.
x_vols = np.stack(vols, axis=0)
vol_shape = x_vols.shape[1:-1]

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

from voxelmorph.voxelmorph import default_unet_features
from voxelmorph.voxelmorph.torch import layers
from voxelmorph.voxelmorph.torch.modelio import LoadableModel, store_config_args


class Unet(nn.Module):
    """
    A unet architecture. Layer features can be specified directly as a list of encoder and decoder
    features or as a single integer along with a number of unet levels. The default network features
    per layer (when no options are specified) are:

        encoder: [16, 32, 32, 32]
        decoder: [32, 32, 32, 32, 32, 16, 16]

    The network is built from ConvBlock layers. Decoder features beyond the number of encoder
    features are applied as extra convolutions after the last decoder level.
    """

    def __init__(self,
                 inshape=None,
                 infeats=None,
                 nb_features=None,
                 nb_levels=None,
                 max_pool=2,
                 feat_mult=1,
                 nb_conv_per_level=1,
                 half_res=False):
        """
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192). Only its length (the number of
                spatial dimensions) is used here; despite the None default it is required,
                because len(inshape) is taken unconditionally.
            infeats: Number of input features.
            nb_features: Unet convolutional features. Can be specified via a list of lists with
                the form [[encoder feats], [decoder feats]], or as a single integer.
                If None (default), the unet features are defined by the default config described in
                the class documentation.
            nb_levels: Number of levels in unet. Only used when nb_features is an integer.
                Default is None.
            max_pool: Pooling window, also used as the upsampling scale factor. Either a single
                integer applied at every level or a sequence with one entry per level.
                Default is 2.
            feat_mult: Per-level feature multiplier. Only used when nb_features is an integer.
                Default is 1.
            nb_conv_per_level: Number of convolutions per unet level. Default is 1.
            half_res: Skip the last decoder upsampling. Default is False.

        Raises:
            ValueError: If nb_features is an integer and nb_levels is missing, or if nb_levels
                is given while nb_features is not an integer.
        """

        super().__init__()

        # ensure correct dimensionality
        ndims = len(inshape)
        assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims

        # cache some parameters
        self.half_res = half_res

        # default encoder and decoder layer features if nothing provided
        if nb_features is None:
            nb_features = default_unet_features()

        # build feature list automatically
        if isinstance(nb_features, int):
            if nb_levels is None:
                raise ValueError('must provide unet nb_levels if nb_features is an integer')
            # per-level feature counts: nb_features * feat_mult ** level, rounded to integers
            feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
            # the encoder uses all levels but the last; the decoder walks the levels in reverse
            nb_features = [
                np.repeat(feats[:-1], nb_conv_per_level),
                np.repeat(np.flip(feats), nb_conv_per_level)
            ]
        elif nb_levels is not None:
            raise ValueError('cannot use nb_levels if nb_features is not an integer')

        # extract any surplus (full resolution) decoder convolutions
        enc_nf, dec_nf = nb_features
        nb_dec_convs = len(enc_nf)
        final_convs = dec_nf[nb_dec_convs:]
        dec_nf = dec_nf[:nb_dec_convs]
        self.nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1

        # a scalar pooling factor is repeated for every level
        if isinstance(max_pool, int):
            max_pool = [max_pool] * self.nb_levels

        # cache downsampling / upsampling operations
        # (kept in plain Python lists rather than nn.ModuleList)
        MaxPooling = getattr(nn, 'MaxPool%dd' % ndims)
        self.pooling = [MaxPooling(s) for s in max_pool]
        self.upsampling = [nn.Upsample(scale_factor=s, mode='nearest') for s in max_pool]

        # configure encoder (down-sampling path)
        prev_nf = infeats
        # feature count of every tensor forward() stores for a skip connection, starting with
        # the raw input
        encoder_nfs = [prev_nf]
        self.encoder = nn.ModuleList()
        for level in range(self.nb_levels - 1):
            convs = nn.ModuleList()
            for conv in range(nb_conv_per_level):
                nf = enc_nf[level * nb_conv_per_level + conv]
                convs.append(ConvBlock(ndims, prev_nf, nf))
                prev_nf = nf
            self.encoder.append(convs)
            encoder_nfs.append(prev_nf)

        # configure decoder (up-sampling path)
        # reversed so that index `level` matches the order in which forward() pops the history
        encoder_nfs = np.flip(encoder_nfs)
        self.decoder = nn.ModuleList()
        for level in range(self.nb_levels - 1):
            convs = nn.ModuleList()
            for conv in range(nb_conv_per_level):
                nf = dec_nf[level * nb_conv_per_level + conv]
                convs.append(ConvBlock(ndims, prev_nf, nf))
                prev_nf = nf
            self.decoder.append(convs)
            # account for the channels added by the skip concatenation; with half_res the
            # last level performs no concatenation (same condition as in forward())
            if not half_res or level < (self.nb_levels - 2):
                prev_nf += encoder_nfs[level]

        # now we take care of any remaining convolutions
        self.remaining = nn.ModuleList()
        for num, nf in enumerate(final_convs):
            self.remaining.append(ConvBlock(ndims, prev_nf, nf))
            prev_nf = nf

        # cache final number of features
        self.final_nf = prev_nf

    def forward(self, x):
        """
        Runs the encoder, the decoder with skip connections and the remaining convolutions.

        Parameters:
            x: Input tensor. The skip connections are concatenated along dim=1, so that is
                treated as the feature axis.

        Returns:
            The output of the last convolution block (self.final_nf features).
        """

        # encoder forward pass
        # the history starts with the raw input so the last decoder level can concatenate it
        x_history = [x]
        for level, convs in enumerate(self.encoder):
            for conv in convs:
                x = conv(x)
            x_history.append(x)
            x = self.pooling[level](x)

        # decoder forward pass with upsampling and concatenation
        for level, convs in enumerate(self.decoder):
            for conv in convs:
                x = conv(x)
            # with half_res the last level is neither upsampled nor concatenated
            if not self.half_res or level < (self.nb_levels - 2):
                x = self.upsampling[level](x)
                # pop() yields the most recently stored (deepest remaining) encoder tensor
                x = torch.cat([x, x_history.pop()], dim=1)

        # remaining convs at full resolution
        for conv in self.remaining:
            x = conv(x)

        return x


class VxmDense(LoadableModel):
    """
    VoxelMorph network for (unsupervised) nonlinear registration between two images.

    A Unet predicts a flow field from the concatenated source and target images. The flow is
    optionally resized and integrated and then used to warp the source (and, for bidirectional
    models, the negated flow is used to warp the target).
    """

    @store_config_args
    def __init__(self,
                 inshape,
                 nb_unet_features=None,
                 nb_unet_levels=None,
                 unet_feat_mult=1,
                 nb_unet_conv_per_level=1,
                 int_steps=7,
                 int_downsize=1,
                 bidir=False,
                 use_probs=False,
                 src_feats=1,
                 trg_feats=1,
                 unet_half_res=False):
        """
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192)
            nb_unet_features: Unet convolutional features. Can be specified via a list of lists with
                the form [[encoder feats], [decoder feats]], or as a single integer.
                If None (default), the unet features are defined by the default config described in
                the unet class documentation.
            nb_unet_levels: Number of levels in unet. Only used when nb_features is an integer.
                Default is None.
            unet_feat_mult: Per-level feature multiplier. Only used when nb_features is an integer.
                Default is 1.
            nb_unet_conv_per_level: Number of convolutions per unet level. Default is 1.
            int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this
                value is 0. Default is 7.
            int_downsize: Integer specifying the flow downsample factor for vector integration.
                The flow field is not downsampled when this value is 1. Default is 1.
            bidir: Enable bidirectional cost function. Default is False.
            use_probs: Use probabilities in flow field. Default is False.
            src_feats: Number of source image features. Default is 1.
            trg_feats: Number of target image features. Default is 1.
            unet_half_res: Skip the last unet decoder upsampling. Requires that int_downsize=2.
                Default is False.

        Raises:
            NotImplementedError: If use_probs is True.
        """
        super().__init__()

        # internal flag indicating whether to return flow or integrated warp during inference
        # NOTE(review): forward() never reads this flag; it selects its return value from the
        # `registration` argument instead. Presumably LoadableModel derives from nn.Module, in
        # which case this writes the attribute that train()/eval() also manage — confirm.
        self.training = True

        # ensure correct dimensionality
        ndims = len(inshape)
        assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims

        # configure core unet model
        # source and target are concatenated in forward(), hence the summed feature count
        self.unet_model = Unet(
            inshape,
            infeats=(src_feats + trg_feats),
            nb_features=nb_unet_features,
            nb_levels=nb_unet_levels,
            feat_mult=unet_feat_mult,
            nb_conv_per_level=nb_unet_conv_per_level,
            half_res=unet_half_res,
        )

        # configure unet to flow field layer
        # one output channel per spatial dimension
        Conv = getattr(nn, 'Conv%dd' % ndims)
        self.flow = Conv(self.unet_model.final_nf, ndims, kernel_size=3, padding=1)

        # init flow layer with small weights and bias
        self.flow.weight = nn.Parameter(Normal(0, 1e-5).sample(self.flow.weight.shape))
        self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape))

        # probabilities are not supported in pytorch
        if use_probs:
            raise NotImplementedError(
                'Flow variance has not been implemented in pytorch - set use_probs to False')

        # configure optional resize layers (downsize)
        # skipped when the unet already stops at half resolution
        if not unet_half_res and int_steps > 0 and int_downsize > 1:
            self.resize = layers.ResizeTransform(int_downsize, ndims)
        else:
            self.resize = None

        # resize to full res
        if int_steps > 0 and int_downsize > 1:
            self.fullsize = layers.ResizeTransform(1 / int_downsize, ndims)
        else:
            self.fullsize = None

        # configure bidirectional training
        self.bidir = bidir

        # configure optional integration layer for diffeomorphic warp
        down_shape = [int(dim / int_downsize) for dim in inshape]
        self.integrate = layers.VecInt(down_shape, int_steps) if int_steps > 0 else None

        # configure transformer
        self.transformer = layers.SpatialTransformer(inshape)

    def forward(self, source, target, registration=False):
        '''
        Parameters:
            source: Source image tensor.
            target: Target image tensor.
            registration: Return transformed image and flow. Default is False.

        Returns:
            If registration is False: (y_source, y_target, preint_flow) for bidirectional
            models, otherwise (y_source, preint_flow), where preint_flow is the flow before
            integration. If registration is True: (y_source, pos_flow), where pos_flow is the
            flow that was used to warp the source.
        '''

        # concatenate inputs and propagate unet
        x = torch.cat([source, target], dim=1)

        x = self.unet_model(x)


        # transform into flow field
        flow_field = self.flow(x)

        # resize flow for integration
        pos_flow = flow_field
        if self.resize:
            pos_flow = self.resize(pos_flow)

        # keep the flow as it is before integration; this is what the training path returns
        preint_flow = pos_flow

        # negate flow for bidirectional model
        neg_flow = -pos_flow if self.bidir else None

        # integrate to produce diffeomorphic warp
        if self.integrate:
            pos_flow = self.integrate(pos_flow)
            neg_flow = self.integrate(neg_flow) if self.bidir else None

            # resize to final resolution
            if self.fullsize:
                pos_flow = self.fullsize(pos_flow)
                neg_flow = self.fullsize(neg_flow) if self.bidir else None

        # warp image with flow field
        y_source = self.transformer(source, pos_flow)
        y_target = self.transformer(target, neg_flow) if self.bidir else None

        # return non-integrated flow field if training
        if not registration:
            return (y_source, y_target, preint_flow) if self.bidir else (y_source, preint_flow)
        else:
            return y_source, pos_flow


class ConvBlock(nn.Module):
    """
    Unet building block: one N-d convolution (kernel size 3, padding 1) whose output is
    passed through a LeakyReLU with negative slope 0.2.
    """

    def __init__(self, ndims, in_channels, out_channels, stride=1):
        """
        Parameters:
            ndims: Number of spatial dimensions; selects nn.Conv1d, nn.Conv2d or nn.Conv3d.
            in_channels: Number of input features.
            out_channels: Number of output features.
            stride: Convolution stride. Default is 1.
        """
        super().__init__()

        conv_cls = getattr(nn, 'Conv%dd' % ndims)
        self.main = conv_cls(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Applies the convolution and then the activation to x."""
        return self.activation(self.main(x))

In [None]:
class MeanStream(nn.Module):
    """
    Maintain stream of data mean.

    cap refers to maintaining an approximation of up to that number of subjects -- that is,
    any incoming data point will have at least 1/cap weight.

    In training mode every call folds the batch mean into the running mean and returns the
    updated mean, scaled by min(1, count / cap) and repeated along the batch axis. In
    evaluation mode the stored statistics are returned without being updated.

    If you find this class useful, please cite the original paper this was written for:
        A.V. Dalca, M. Rakic, J. Guttag, M.R. Sabuncu.
        Learning Conditional Deformable Templates with Convolutional Networks
        NeurIPS: Advances in Neural Information Processing Systems. pp 804-816, 2019.
    """

    def __init__(self, cap=100, mean_shape=(1, 2, 160, 192), **kwargs):
        """
        Parameters:
            cap: Maximum number of samples the running mean approximates. Default is 100.
            mean_shape: Initial shape of the stored mean, with a leading axis of size 1.
                The first training batch replaces the stored mean, so the buffer adopts the
                shape of the data that is actually streamed. Default is (1, 2, 160, 192).
            kwargs: Forwarded to nn.Module.
        """
        super(MeanStream, self).__init__(**kwargs)
        self.cap = float(cap)
        # stored as non-trainable parameters so they follow the module across devices and
        # are part of its state dict
        self.mean = nn.Parameter(torch.zeros(*mean_shape), requires_grad=False)
        self.count = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, x):
        """
        Parameters:
            x: Batch of samples with the batch on the first axis.

        Returns:
            Tensor with the batch size of x holding the (scaled) running mean per sample.
        """
        batch_size = x.size(0)

        # If calling in inference mode, use moving stats
        if not self.training:
            return self._scaled_mean(self.mean, self.count, batch_size)

        # Get new mean and count
        new_mean, new_count = self._mean_update(x)

        # Update the stored statistics; detach so no autograd graph is kept alive in them
        self.count.data = new_count.detach()
        self.mean.data = new_mean.detach()

        # The returned value keeps its graph so a loss on it still reaches x
        return self._scaled_mean(new_mean, new_count, batch_size)

    def _scaled_mean(self, mean, count, batch_size):
        """Scales mean by min(1, count / cap) and repeats it batch_size times."""
        # The first few samples should not matter that much towards the cost
        weight = torch.clamp(count / self.cap, max=1.0)
        return weight * mean.expand(batch_size, *mean.shape[1:])

    def _mean_update(self, x):
        """
        Folds the batch x into the running statistics without storing them.

        Returns:
            (new_mean, new_count): the updated mean (leading axis of size 1) and the updated
            sample count.
        """
        # Convert to float for calculations
        x_float = x.float()

        # Batch statistics
        batch_count = float(x.size(0))
        batch_mean = torch.sum(x_float, dim=0, keepdim=True) / batch_count

        prev_count = self.count.data.to(x_float.device)
        new_count = prev_count + batch_count

        # Nothing accumulated yet: the batch mean is the running mean
        if prev_count.item() == 0:
            return batch_mean, new_count

        # Weight of the incoming batch. Once the count reaches cap the denominator stops
        # growing, so every batch keeps at least batch_count / cap weight. Clamped to 1 so
        # the previous mean never receives a negative weight when batch_count > cap.
        alpha = torch.clamp(batch_count / torch.clamp(new_count, max=self.cap), max=1.0)

        prev_mean = self.mean.data.to(x_float.device)
        new_mean = prev_mean * (1.0 - alpha) + batch_mean * alpha

        return new_mean, new_count

In [None]:
class ReferenceContainer:
    """
    Plain attribute holder with four slots (atlas_layer, vxm_model, pos_flow, neg_flow),
    all of which start out as None.
    """

    def __init__(self):
        # every slot is empty until a caller assigns it
        self.atlas_layer = self.vxm_model = self.pos_flow = self.neg_flow = None

In [None]:
class TemplateCreation(nn.Module):
    """
    VoxelMorph network to generate an unconditional template image.

    The template (atlas) is a learnable parameter stored channels-last. It is registered to
    every input image with a bidirectional VxmDense model, and a MeanStream tracks the mean
    of the negated pre-integration flow.
    """

    def __init__(self, inshape, nb_unet_features=None, mean_cap=100, atlas_feats=1, src_feats=1,
                 mean_temp=None, load_weights=False, weights_path='./models/1000.pt', **kwargs):
        """
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192)
            nb_unet_features: Unet convolutional features.
                See VxmDense documentation for more information.
            mean_cap: Cap for mean stream. Default is 100.
            atlas_feats: Number of atlas/template features. Default is 1.
            src_feats: Number of source image features. Default is 1.
            mean_temp: Optional array-like used as the initial atlas, expected in the
                channels-last layout (*inshape, atlas_feats). If None (default), the atlas
                starts from small random values.
            load_weights: Load registration weights from weights_path and freeze every
                parameter of the registration model whose name does not contain 'flow'.
                Default is False.
            weights_path: Checkpoint file read when load_weights is True. It must contain a
                'model_state' entry. Default is './models/1000.pt'.
            kwargs: Forwarded to the internal VxmDense model.
        """
        super(TemplateCreation, self).__init__()

        # configure inputs
        self.src_feats = src_feats

        # pre-warp (atlas) model
        # `is not None` also works for the default None, unlike calling .any() on it
        if mean_temp is not None:
            atlas_init = torch.tensor(mean_temp, dtype=torch.float32)
        else:
            atlas_init = torch.randn(*inshape, atlas_feats) * 1e-7
        self.atlas_layer = nn.Parameter(atlas_init, requires_grad=True)

        # warp model: the atlas is the moving image and the input scan is the fixed image,
        # so the feature counts are forwarded in that order
        vxm_kwargs = dict(kwargs)
        vxm_kwargs.setdefault('trg_feats', src_feats)
        self.vxm_model = VxmDense(inshape, nb_unet_features=nb_unet_features, bidir=True,
                                  src_feats=atlas_feats, **vxm_kwargs)
        if load_weights:
            # map to CPU so the checkpoint loads on machines without a GPU;
            # load_state_dict copies the values into the existing parameters
            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
            checkpoint = torch.load(weights_path, map_location='cpu')
            self.vxm_model.load_state_dict(checkpoint['model_state'], strict=False)
            for name, param in self.vxm_model.named_parameters():
                if 'flow' not in name:
                    param.requires_grad = False

        # mean stream (follows the device of this module, e.g. after model.cuda())
        self.mean_stream = MeanStream(mean_cap)

        # cache references
        self.references = ReferenceContainer()
        self.references.atlas_layer = self.atlas_layer
        self.references.vxm_model = self.vxm_model

    def forward(self, source_input):
        """
        Parameters:
            source_input: Batch of input images, channels-first.

        Returns:
            (y_source, y_target, mean_stream, preint_flow): the warped atlas, the warped
            input, the mean stream of the negated flow and the pre-integration flow.
        """
        # atlas tensor: repeat along the batch axis and move the trailing feature axis
        # to position 1 (channels-first), for any number of spatial dimensions
        batch_size = source_input.shape[0]
        atlas_batch = self.atlas_layer.expand(batch_size, *self.atlas_layer.shape)
        feat_axis = atlas_batch.dim() - 1
        atlas_tensor = atlas_batch.permute(0, feat_axis, *range(1, feat_axis))

        # warp model
        y_source, y_target, preint_flow = self.vxm_model(atlas_tensor, source_input)

        # get mean stream of negative flow
        mean_stream = self.mean_stream(-1 * preint_flow)

        return y_source, y_target, mean_stream, preint_flow

    def set_atlas(self, atlas):
        """
        Sets the atlas weights from a numpy array. An array with one more axis than the
        atlas parameter has its leading axis removed first.
        """
        if len(atlas.shape) > len(self.atlas_layer.shape):
            atlas = np.reshape(atlas, atlas.shape[1:])
        # keep the dtype and device of the parameter so the model stays consistent
        self.atlas_layer.data = torch.as_tensor(
            atlas, dtype=self.atlas_layer.dtype, device=self.atlas_layer.device)

    def get_atlas(self):
        """
        Gets the atlas weights as a numpy array with singleton axes removed.
        """
        return self.atlas_layer.data.squeeze().cpu().numpy()

    def get_registration_model(self):
        """
        Returns a callable (src, trg) -> (moved image, flow) that runs the registration
        model in registration mode.
        """
        return lambda src, trg: self.vxm_model(src, trg, registration=True)

    def register(self, src, trg):
        """
        Predicts the transform from src to trg tensors.

        Returns:
            (moved image, flow) as produced by the registration model.
        """
        return self.get_registration_model()(src, trg)

    def apply_transform(self, src, trg, img, interp_method='linear', fill_value=None):
        """
        Predicts the transform from src to trg and applies it to the img tensor.

        Parameters:
            src: Source image tensor.
            trg: Target image tensor.
            img: Tensor to warp with the predicted flow.
            interp_method: Interpolation mode; 'linear' is mapped to 'bilinear', any other
                value is passed to the spatial transformer unchanged. Default is 'linear'.
            fill_value: Not supported; must be None.

        Raises:
            NotImplementedError: If fill_value is not None.
        """
        if fill_value is not None:
            raise NotImplementedError('fill_value is not supported by the pytorch transformer')
        # register() returns (moved image, flow); only the flow is needed here
        flow = self.register(src, trg)[1]
        mode = 'bilinear' if interp_method == 'linear' else interp_method
        warper = layers.SpatialTransformer(img.shape[2:], mode=mode).to(img.device)
        return warper(img, flow)

In [None]:
def get_mean_temp(invol):
    """
    Returns the element-wise mean of a stack of volumes.

    Parameters:
        invol: Array-like with the individual volumes stacked along the first axis.

    Returns:
        Array with the shape of a single volume, holding the mean over the first axis.
    """
    return np.mean(invol, axis=0)

In [None]:
# Mean over all loaded volumes; passed to TemplateCreation below as the initial atlas.
mean_temp = get_mean_temp(x_vols)

In [None]:
# Unet encoder / decoder feature counts (the same values the Unet docstring lists as defaults).
enc_nf = [16, 32, 32, 32]
dec_nf = [32, 32, 32, 32, 32, 16, 16]
# Template model whose atlas is initialised from the mean volume computed above.
model = TemplateCreation(vol_shape, nb_unet_features=[enc_nf, dec_nf], mean_temp = mean_temp)

In [None]:
def volgen(
    vol_names,
    batch_size=1,
    segs=None,
    np_var='vol',
    pad_shape=None,
    resize_factor=1,
    add_feat_axis=True
):
    """
    Base generator for random volume loading. Volumes can be passed as a path to
    the parent directory, a glob pattern, a list of file paths, or a list of
    preloaded volumes. Corresponding segmentations are additionally loaded if
    `segs` is provided as a list (of file paths or preloaded segmentations) or set
    to True. If `segs` is True, npz files with variable names 'vol' and 'seg' are
    expected. Passing in preloaded volumes (with optional preloaded segmentations)
    allows volumes preloaded in memory to be passed to a generator.

    Parameters:
        vol_names: Path, glob pattern, list of volume files to load, or list of
            preloaded volumes.
        batch_size: Batch size. Default is 1.
        segs: Loads corresponding segmentations. Default is None.
        np_var: Name of the volume variable if loading npz files. Default is 'vol'.
        pad_shape: Zero-pads loaded volumes to a given shape. Default is None.
        resize_factor: Volume resize factor. Default is 1.
        add_feat_axis: Load volume arrays with added feature axis. Default is True.

    Yields:
        A tuple holding the batch of volumes and, when segmentations are requested, the
        batch of segmentations. Sampling is random with replacement and never ends.

    Raises:
        ValueError: If no volumes are available or if the number of segmentations does
            not match the number of volumes. Because this is a generator, the error is
            raised on the first next() call.
    """
    # Function-scope imports: the notebook does not import these names at the top, and the
    # module path follows the `voxelmorph.voxelmorph...` imports used elsewhere in the file.
    import glob
    from voxelmorph.voxelmorph.py.utils import load_volfile

    # convert glob path to filenames
    if isinstance(vol_names, str):
        if os.path.isdir(vol_names):
            vol_names = os.path.join(vol_names, '*')
        vol_names = glob.glob(vol_names)

    if len(vol_names) == 0:
        raise ValueError('No volumes found to sample from.')

    if isinstance(segs, list) and len(segs) != len(vol_names):
        raise ValueError('Number of image files must match number of seg files.')

    while True:
        # generate [batchsize] random image indices
        indices = np.random.randint(len(vol_names), size=batch_size)

        # load volumes and concatenate
        load_params = dict(np_var=np_var, add_batch_axis=True, add_feat_axis=add_feat_axis,
                           pad_shape=pad_shape, resize_factor=resize_factor)
        imgs = [load_volfile(vol_names[i], **load_params) for i in indices]
        vols = [np.concatenate(imgs, axis=0)]

        # optionally load segmentations and concatenate
        if segs is True:
            # assume inputs are npz files with 'seg' key
            load_params['np_var'] = 'seg'  # be sure to load seg
            s = [load_volfile(vol_names[i], **load_params) for i in indices]
            vols.append(np.concatenate(s, axis=0))
        elif isinstance(segs, list):
            # assume segs is a corresponding list of files or preloaded volumes
            s = [load_volfile(segs[i], **load_params) for i in indices]
            vols.append(np.concatenate(s, axis=0))

        yield tuple(vols)

def template_creation(vol_names, bidir=False, batch_size=1, **kwargs):
    """
    Generator feeding unconditional template training.

    Every item is a pair (inputs, outputs): the inputs hold the sampled scan and the outputs
    hold the same scan followed by zero arrays (three for bidirectional models, two
    otherwise). The zero array is created once and has shape (1, *spatial, ndims).

    Parameters:
        vol_names: List of volume files to load, or list of preloaded volumes.
        bidir: Yield input image as output for bidirectional models. Default is False.
        batch_size: Batch size. Default is 1.
        kwargs: Forwarded to the internal volgen generator.
    """
    volume_stream = volgen(vol_names, batch_size=batch_size, **kwargs)
    nb_zero_outputs = 3 if bidir else 2
    flow_zeros = None

    while True:
        scan = next(volume_stream)[0]

        # build the zero target lazily, once the spatial shape of the data is known
        if flow_zeros is None:
            spatial_shape = scan.shape[1:-1]
            flow_zeros = np.zeros((1, *spatial_shape, len(spatial_shape)))

        yield ([scan], [scan] + [flow_zeros] * nb_zero_outputs)


In [None]:
# Endless training generator over the slice files; bidir=True yields four target entries per
# item (the scan plus three zero arrays), with two scans per batch.
gen = template_creation(["./data/"+ i for i in files], bidir=True, batch_size=2)

In [None]:
from voxelmorph.voxelmorph import losses as lss

import torch.optim as optim



# Define optimizer: Adam with its default hyper-parameters over everything returned by
# model.parameters().
optimizer = optim.Adam(model.parameters())

# Function to compute total loss
def compute_total_loss_(outputs, targets):
    """
    Combines the four template-training loss terms into one scalar.

    Reads the atlas from the module-level `model` (model.references.atlas_layer).

    Parameters:
        outputs: Model outputs (y_source, y_target, mean_stream, preint_flow).
        targets: Matching targets; targets[0] is the input scan, targets[2] is the target
            of the mean stream and targets[3] is the placeholder for the flow term.

    Returns:
        image loss + atlas loss + 0.01 * mean-stream loss + flow smoothness loss.
    """
    # voxelmorph's torch losses are called as loss(y_true, y_pred)
    mse = lss.MSE().loss

    # warped atlas against the input scan
    image_loss = mse(targets[0], outputs[0])

    # every scan of the batch warped into atlas space against the atlas itself;
    # the channels-last atlas becomes (1, C, H, W) and broadcasts over the batch
    atlas_image = model.references.atlas_layer.permute(2, 0, 1).unsqueeze(0)
    atlas_loss = mse(atlas_image, outputs[1])

    # mean stream of the negated flow against its target
    mean_flow_loss = mse(targets[2], outputs[2])

    # Grad only evaluates its second argument (y_pred), so the predicted flow has to be
    # passed there; passing the zero target instead makes this term identically zero
    smooth_loss = lss.Grad('l2', loss_mult=2).loss(targets[3], outputs[3])

    total_loss = image_loss + atlas_loss + 0.01 * mean_flow_loss + smooth_loss
    return total_loss

    

In [None]:
import random

def dice_loss(pred, target):
    """
    Soft Dice coefficient that generalizes to real valued pred and target tensors.
    It is differentiable.

    Despite the name, the value returned is the coefficient itself (higher means more
    overlap, 1 for identical inputs), not one minus the coefficient.

    Parameters:
        pred: tensor with first dimension as batch
        target: tensor with first dimension as batch

    Returns:
        Scalar tensor (2 * sum(pred * target) + 1) / (sum(pred ** 2) + sum(target ** 2) + 1).
    """

    smooth = 1.

    # have to use contiguous since they may come from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    # squared magnitude of each input; the prediction term uses pred * pred,
    # not pred * target (which would just repeat the intersection)
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    return (2. * intersection + smooth) / (A_sum + B_sum + smooth)
    
def test_dice(n_samples=100):
    """
    Scores the current model on randomly drawn (source, target) pairs.

    Every source is registered to the current atlas, the moved image is registered to its
    target, and dice_loss between that result and the target is averaged over all pairs.
    Uses the module-level `model` and `x_vols`; volumes are drawn with replacement via the
    `random` module.

    Parameters:
        n_samples: Number of pairs to evaluate. Default is 100.

    Returns:
        Mean dice_loss value over the evaluated pairs.

    Raises:
        ValueError: If n_samples is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError('n_samples must be at least 1')

    # evaluate on whatever device the model currently lives on
    device = next(model.parameters()).device

    def _draw_batch():
        # sample indices and slice the array once instead of converting a Python list of
        # arrays element by element; result is channels-first (n_samples, C, H, W)
        picks = random.choices(range(len(x_vols)), k=n_samples)
        batch = torch.as_tensor(x_vols[picks], dtype=torch.float32)
        return batch.permute(0, 3, 1, 2).to(device)

    sources = _draw_batch()
    targets = _draw_batch()

    # atlas as a (1, 1, H, W) image
    template = torch.tensor(model.get_atlas(), dtype=torch.float32)
    template = template.unsqueeze(0).unsqueeze(0).to(device)

    dice_score = 0
    with torch.no_grad():
        for src, trg in zip(sources, targets):
            src = src.unsqueeze(0)
            trg = trg.unsqueeze(0)
            # register() returns (moved image, flow); only the moved image is used
            moved_to_template = model.register(src, template)[0]
            moved_to_target = model.register(moved_to_template, trg)[0]
            dice_score += dice_loss(moved_to_target[0], trg[0])

    return dice_score / n_samples

In [None]:

epochs = 1000
steps_per_epoch = 25
model.cuda()
total_dice_loss = []  # dice score recorded every 50 epochs
loss_list = []        # summed training loss of every epoch (plain floats)

# Training loop
for epoch in range(epochs):
    total_loss = 0.0
    for _ in range(steps_per_epoch):
        # Extract inputs and targets from the next batch of the endless generator
        inputs, targets = next(gen)

        # Perform a training step
        optimizer.zero_grad()

        # the generator yields channels-last numpy arrays; the model expects channels-first
        inputs = torch.tensor(inputs[0], dtype=torch.float32).permute(0, 3, 1, 2).cuda()
        targets = [torch.from_numpy(d).cuda().float().permute(0, 3, 1, 2) for d in targets]

        outputs = model(inputs)

        loss = compute_total_loss_(outputs, targets)
        loss.backward()
        optimizer.step()

        # accumulate a Python float: adding the tensor itself would keep the autograd
        # graph of every step alive
        total_loss += loss.item()

    loss_list.append(total_loss)

    # periodic report and evaluation
    if epoch % 50 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss / steps_per_epoch}')
        dice_score = test_dice()
        total_dice_loss.append(dice_score)
        print(dice_score)