In [1]:
import os
from pathlib import Path

os.environ['MMWHS_CACHE_PATH'] = str(Path('.', '.cache'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp

from tqdm import tqdm
import wandb
import nibabel as nib

from slice_inflate.datasets.mmwhs_dataset import MMWHSDataset, load_data, extract_2d_data
from slice_inflate.utils.common_utils import DotDict, get_script_dir
from slice_inflate.utils.torch_utils import reset_determinism, ensure_dense, get_batch_dice_over_all, get_batch_dice_per_class, save_model
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from slice_inflate.datasets.align_mmwhs import cut_slice
from slice_inflate.utils.log_utils import get_global_idx, log_class_dices
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader

from mdl_seg_class.metrics import dice3d
import numpy as np
THIS_SCRIPT_DIR = get_script_dir()

PROJECT_NAME = "slice_inflate"

In [31]:

config_dict = DotDict({
    'num_folds': 5,
    'only_first_fold': True,                # If true do not contiue with training after the first fold
    # 'fold_override': 0,
    # 'checkpoint_epx': 0,

    'use_mind': False,                      # If true use MIND features (https://pubmed.ncbi.nlm.nih.gov/22722056/)
    'epochs': 100,

    'batch_size': 4,
    'val_batch_size': 1,
    'modality': 'mr',
    'use_2d_normal_to': None,               # Can be None or 'D', 'H', 'W'. If not None 2D slices will be selected for training

    'dataset': 'mmwhs',                 # The dataset prepared with our preprocessing scripts
    'data_base_path': str(Path(THIS_SCRIPT_DIR, "data/MMWHS")),
    'reg_state': None, # Registered (noisy) labels used in training. See prepare_data() for valid reg_states
    'train_set_max_len': None,              # Length to cut of dataloader sample count
    'crop_around_3d_label_center': (128,128,128),
    'crop_3d_region': ((0,128), (0,128), (0,128)),        # dimension range in which 3D samples are cropped
    'crop_2d_slices_gt_num_threshold': 0,   # Drop 2D slices if less than threshold pixels are positive

    'lr': 1e-5,
    'use_scheduling': True,

    'save_every': 'best',
    'mdl_save_prefix': 'data/models',

    'debug': False,
    'wandb_mode': 'disabled',                         # e.g. online, disabled. Use weights and biases online logging
    'do_sweep': False,                                # Run multiple trainings with varying config values defined in sweep_config_dict below

    # For a snapshot file: dummy-a2p2z76CxhCtwLJApfe8xD_fold0_epx0
    'checkpoint_name': None,                          # Training snapshot name, e.g. dummy-a2p2z76CxhCtwLJApfe8xD
    'fold_override': None,                            # Training fold, e.g. 0
    'checkpoint_epx': None,                           # Training epx, e.g. 0

    'do_plot': False,                                 # Generate plots (debugging purpose)
    'save_dp_figures': False,                         # Plot data parameter value distribution
    'save_labels': True,                              # Store training labels alongside data parameter values inside the training snapshot

    'device': 'cuda'
})

def prepare_data(config):
    training_dataset = MMWHSDataset(
        config.data_base_path,
        state="training",
        load_func=load_data,
        extract_slice_func=extract_2d_data,
        modality=config.modality,
        do_align_global=True,
        do_resample=False, # Prior to cropping, resample image?
        crop_3d_region=None, # Crop or pad the images to these dimensions
        crop_around_3d_label_center=config.crop_around_3d_label_center,
        pre_interpolation_factor=1., # When getting the data, resize the data by this factor
        ensure_labeled_pairs=True, # Only use fully labelled images (segmentation label available)
        use_2d_normal_to=config.use_2d_normal_to, # Use 2D slices cut normal to D,H,>W< dimensions
        crop_around_2d_label_center=(128,128),

        augment_angle_std=5,

        device=config.device,
        debug=config.debug
    )

    return training_dataset

In [32]:
if False:
    training_dataset = prepare_data(config_dict)
    training_dataset.train(augment=False)
    training_dataset.self_attributes['augment_angle_std'] = 2
    print(training_dataset.do_augment)
    for sample in [training_dataset[idx] for idx in [1]]:
        pass
        fig = plt.figure(figsize=(16., 4.))
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
            nrows_ncols=(1, 6),  # creates 2x2 grid of axes
            axes_pad=0.0,  # pad between axes in inch.
        )

        show_row = [
            cut_slice(sample['image']),
            cut_slice(sample['label']),

            sample['sa_image_slc'],
            sample['sa_label_slc'],

            sample['hla_image_slc'],
            sample['hla_label_slc'],
        ]

        for ax, im in zip(grid, show_row):
            ax.imshow(im, cmap='gray', interpolation='none')

        plt.show()

In [33]:
if False:
    training_dataset = prepare_data(config_dict)
    training_dataset.train()

    training_dataset.self_attributes['augment_angle_std'] = 10
    print(training_dataset.do_augment)
    import torch
    lbl, sa_label, hla_label = torch.zeros(128,128), torch.zeros(128,128), torch.zeros(128,128)
    for idx in range(15):
        sample = training_dataset[1]
        # nib.save(nib.Nifti1Image(sample['label'].cpu().numpy(), affine=torch.eye(4).numpy()), f'out{idx}.nii.gz')
        lbl += cut_slice(sample['label']).cpu()
        sa_label += sample['sa_label_slc'].cpu()
        hla_label += sample['hla_label_slc'].cpu()
    fig = plt.figure(figsize=(16., 4.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
        nrows_ncols=(1, 3),  # creates 2x2 grid of axes
        axes_pad=0.0,  # pad between axes in inch.
    )

    show_row = [
        lbl, sa_label, hla_label
    ]

    for ax, im in zip(grid, show_row):
        ax.imshow(im, cmap='magma', interpolation='none')

    plt.show()

In [34]:
if False:
    training_dataset = prepare_data(config_dict)
    training_dataset.train(augment=False)
    training_dataset.self_attributes['augment_angle_std'] = 2
    print(training_dataset.do_augment)

    lbl, sa_label, hla_label = torch.zeros(128,128), torch.zeros(128,128), torch.zeros(128,128)
    for tr_idx in range(len(training_dataset)):
        sample = training_dataset[tr_idx]

        lbl += cut_slice(sample['label']).cpu()
        sa_label += sample['sa_label_slc'].cpu()
        hla_label += sample['hla_label_slc'].cpu()

    fig = plt.figure(figsize=(16., 4.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
        nrows_ncols=(1, 3),  # creates 2x2 grid of axes
        axes_pad=0.0,  # pad between axes in inch.
    )

    show_row = [
        lbl, sa_label, hla_label
    ]

    for ax, im in zip(grid, show_row):
        ax.imshow(im, cmap='magma', interpolation='none')

    plt.show()

In [35]:
import contextlib

def get_named_layers_leaves(module):
    """ Returns all leaf layers of a pytorch module and a keychain as identifier.
        e.g.
        ...
        ('features.0.5', nn.ReLU())
        ...
        ('classifier.0', nn.BatchNorm2D())
        ('classifier.1', nn.Linear())
    """

    return [(keychain, sub_mod) for keychain, sub_mod in list(module.named_modules()) if not next(sub_mod.children(), None)]

@contextlib.contextmanager
def temp_forward_hooks(modules, pre_fwd_hook_fn=None, post_fwd_hook_fn=None):
    handles = []
    if pre_fwd_hook_fn:
        handles.extend([mod.register_forward_pre_hook(pre_fwd_hook_fn) for mod in modules])
    if post_fwd_hook_fn:
        handles.extend([mod.register_forward_hook(post_fwd_hook_fn) for mod in modules])

    yield
    for hand in handles:
        hand.remove()

def debug_forward_pass(module, inpt, STEP_MODE=False):
    named_leaves = get_named_layers_leaves(module)
    leave_mod_dict = {mod:keychain for keychain, mod in named_leaves}

    def get_shape_str(interface_var):
        if isinstance(interface_var, tuple):
            shps = [str(elem.shape) if isinstance(elem, torch.Tensor) else type(elem) for elem in interface_var]
            return ', '.join(shps)
        elif isinstance(interface_var, torch.Tensor):
            return interface_var.shape
        return type(interface_var)

    def print_pre_info(module, inpt):
        inpt_shapes = get_shape_str(inpt)
        print(f"in:  {inpt_shapes}")
        print(f"key: {leave_mod_dict[module]}")
        print(f"mod: {module}")
        if STEP_MODE:
            input("To continue forward pass press [ENTER]")

    def print_post_info(module, inpt, output):
        output_shapes = get_shape_str(output)
        print(f"out: {output_shapes}\n")

    with temp_forward_hooks(leave_mod_dict.keys(), print_pre_info, print_post_info):
        return module(inpt)

In [36]:
class BlendowskiAE(torch.nn.Module):

    class ConvBlock(torch.nn.Module):
        def __init__(self, in_channels: int, out_channels_list: list, strides_list: list, kernels_list:list=None, paddings_list:list=None):
            super().__init__()

            ops = []
            in_channels = [in_channels] + out_channels_list[:-1]
            if kernels_list is None:
                kernels_list = [3] * len(out_channels_list)
            if paddings_list is None:
                paddings_list = [1] * len(out_channels_list)

            for op_idx in range(len(out_channels_list)):
                ops.append(torch.nn.Conv3d(
                    in_channels[op_idx],
                    out_channels_list[op_idx],
                    kernel_size=kernels_list[op_idx],
                    stride=strides_list[op_idx],
                    padding=paddings_list[op_idx]
                ))

            self.block = torch.nn.Sequential(*ops)

        def forward(self, x):
            return self.block(x)



    def __init__(self, in_channels, out_channels, decoder_in_channels=2, debug_mode=False):
        super().__init__()

        self.debug_mode = debug_mode

        self.first_layer_encoder = self.ConvBlock(in_channels, out_channels_list=[8], strides_list=[1])
        self.first_layer_decoder = self.ConvBlock(8, out_channels_list=[8,out_channels], strides_list=[1,1])

        self.second_layer_encoder = self.ConvBlock(8, out_channels_list=[20,20,20], strides_list=[2,1,1])
        self.second_layer_decoder = self.ConvBlock(20, out_channels_list=[8], strides_list=[1])

        self.third_layer_encoder = self.ConvBlock(20, out_channels_list=[40,40,40], strides_list=[2,1,1])
        self.third_layer_decoder = self.ConvBlock(40, out_channels_list=[20], strides_list=[1])

        self.fourth_layer_encoder = self.ConvBlock(40, out_channels_list=[60,60,60], strides_list=[2,1,1])
        self.fourth_layer_decoder = self.ConvBlock(decoder_in_channels, out_channels_list=[40], strides_list=[1])

        self.deepest_layer = self.ConvBlock(60, out_channels_list=[60,20,2], strides_list=[2,1,1])

        self.encoder = torch.nn.Sequential(
            self.first_layer_encoder,
            self.second_layer_encoder,
            self.third_layer_encoder,
            self.fourth_layer_encoder,
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2),
            self.fourth_layer_decoder,
            torch.nn.Upsample(scale_factor=2),
            self.third_layer_decoder,
            torch.nn.Upsample(scale_factor=2),
            self.second_layer_decoder,
            torch.nn.Upsample(scale_factor=2),
            self.first_layer_decoder,
        )

    def encode(self, x):
        h = self.encoder(x)
        h = self.deepest_layer(h)
        return h
        # if self.debug_mode:
        #     return debug_forward_pass(self.encoder, x, STEP_MODE=False)
        # else:
        #     return self.encoder(x)

    def decode(self, z):
        if self.debug_mode:
            return debug_forward_pass(self.decoder, z, STEP_MODE=False)
        else:
            return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z), z



class BlendowskiVAE(BlendowskiAE):
    def __init__(self, *args, **kwargs):
        kwargs['decoder_in_channels'] = 1
        super().__init__(*args, **kwargs)

        self.deepest_layer = nn.ModuleList([
            self.ConvBlock(60, out_channels_list=[60,20,20,1], strides_list=[2,1,1,1], kernels_list=[3,3,3,1], paddings_list=[1,1,1,0]),
            self.ConvBlock(60, out_channels_list=[60,20,20,1], strides_list=[2,1,1,1], kernels_list=[3,3,3,1], paddings_list=[1,1,1,0]),
        ])

        self.log_var_scale = nn.Parameter(torch.Tensor([0.0]))

    def sample_z(self, mean, std):
        return torch.normal(mean=mean, std=std)

    def encode(self, x):
        h = self.encoder(x)
        mean = self.deepest_layer[0](h)
        log_var = self.deepest_layer[1](h)
        return mean, log_var

    def forward(self, x):
        mean, log_var = self.encode(x)
        std = torch.exp(log_var/2) + 1e-6
        z = self.sample_z(mean=mean, std=std)
        return self.decode(z), (z, mean, std)



In [37]:
# x = torch.zeros(1,8,128,128,128)
# bae = BlendowskiAE(in_channels=8, out_channels=8)

# y, z = bae(x)

# print("BAE")
# print("x", x.shape)
# print("z", z.shape)
# print("y", y.shape)
# print()

# bvae = BlendowskiVAE(in_channels=8, out_channels=8)

# y, z = bvae(x)

# print("BVAE")
# print("x", x.shape)
# print("z", z.shape)
# print("y", y.shape)

In [18]:
# model = BlendowskiVAE(in_channels=6, out_channels=6)
# model.cuda()
# with torch.no_grad():
#     smp = torch.nn.functional.one_hot(training_dataset[1]['label'], 6).unsqueeze(0).permute([0,4,1,2,3]).float().cuda()
# y, _ = model(smp)

In [19]:
training_dataset = prepare_data(config_dict)

Loading MMWHS training images and labels... (['mr'])


20 images, 20 labels: 100%|██████████| 40/40 [00:31<00:00,  1.26it/s]


Postprocessing 3D volumes
Removed 0 3D images in postprocessing
Equal image and label numbers: True (20)
Data import finished.
Dataloader will yield 3D samples


In [20]:
# def nan_hook(self, inp, output):
#     if not isinstance(output, tuple):
#         outputs = [output]
#     else:
#         outputs = output

#     for i, out in enumerate(outputs):
#         nan_mask = torch.isnan(out)
#         if nan_mask.any():
#             print("In", self.__class__.__name__)
#             raise RuntimeError(f"Found NAN in output {i} at indices: ", nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])

def get_model(config, dataset_len, num_classes, THIS_SCRIPT_DIR, _path=None, device='cpu'):
    _path = Path(THIS_SCRIPT_DIR).joinpath(_path).resolve()

    model = BlendowskiVAE(in_channels=num_classes, out_channels=num_classes)

    model.to(device)
    print(f"Param count model: {sum(p.numel() for p in model.parameters())}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    scaler = amp.GradScaler()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, verbose=True)

    if _path and _path.is_dir():
        print(f"Loading model, optimizers and grad scalers from {_path}")
        model.load_state_dict(torch.load(_path.joinpath('model.pth'), map_location=device))
        optimizer.load_state_dict(torch.load(_path.joinpath('optimizer.pth'), map_location=device))
        scheduler.load_state_dict(torch.load(_path.joinpath('scheduler.pth'), map_location=device))
        scaler.load_state_dict(torch.load(_path.joinpath('scaler.pth'), map_location=device))
    else:
        print(f"Generating fresh '{type(model).__name__}' model, optimizer and grad scaler.")

    # for submodule in model.modules():
    #     submodule.register_forward_hook(nan_hook)
    
    return (model, optimizer, scheduler, scaler)

In [38]:
def get_model_input(batch, config, num_classes):
    b_hla_slc_seg = batch['hla_label_slc']
    b_sa_slc_seg = batch['sa_label_slc']
    b_input = torch.cat(
        [b_sa_slc_seg.unsqueeze(1).repeat(1,64,1,1),
            b_hla_slc_seg.unsqueeze(1).repeat(1,64,1,1)],
            dim=1
    )
    b_seg = batch['label']

    b_input = b_input.to(device=config.device)
    b_seg = b_seg.to(device=config.device)

    b_input = F.one_hot(b_input, num_classes).permute(0,4,1,2,3)
    b_input = b_input.float()

    return b_input, b_seg

def inference_wrap(model, seg):
    with torch.inference_mode():
        b_seg = seg.unsqueeze(0).unsqueeze(0).float()
        b_out = model(b_seg)[0]
        b_out = b_out.argmax(1)
        return b_out



def gaussian_likelihood(y_hat, log_var_scale, y_target):
    B, *_ = y_hat.shape
    mean = y_hat
    scale = torch.exp(log_var_scale/2)
    dist = torch.distributions.Normal(mean, scale)

    # measure prob of seeing image under p(x|z)
    log_pxz = dist.log_prob(y_hat)

    # GLH, mean instead of sum..
    return log_pxz.view(B, -1).mean(-1)



def kl_divergence(z, mean, std):
    # See https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed
    B, *_ = z.shape
    p = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
    q = torch.distributions.Normal(mean, std)

    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

    # KL divergence
    kl = (log_qzx - log_pz)

    # Reduce spatial dimensions, mean instead of sum
    kl = kl.view(B, -1).mean(-1)
    return kl



def train_DL(run_name, config, training_dataset):
    reset_determinism()

    # Configure folds
    kf = KFold(n_splits=config.num_folds)
    # kf.get_n_splits(training_dataset.__len__(use_2d_override=False))
    fold_iter = enumerate(kf.split(range(training_dataset.__len__(use_2d_override=False))))

    if config.get('fold_override', None):
        selected_fold = config.get('fold_override', 0)
        fold_iter = list(fold_iter)[selected_fold:selected_fold+1]
    elif config.only_first_fold:
        fold_iter = list(fold_iter)[0:1]

    if config.use_2d_normal_to is not None:
        n_dims = (-2,-1)
    else:
        n_dims = (-3,-2,-1)

    fold_means_no_bg = []

    best_val_score = 0

    for fold_idx, (train_idxs, val_idxs) in fold_iter:
        train_idxs = torch.tensor(train_idxs)
        val_idxs = torch.tensor(val_idxs)
        val_ids = training_dataset.switch_3d_identifiers(val_idxs)

        print(f"Will run validation with these 3D samples (#{len(val_ids)}):", sorted(val_ids))

        ### Add train sampler and dataloaders ##
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idxs)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_idxs)

        train_dataloader = DataLoader(training_dataset, batch_size=config.batch_size,
            sampler=train_subsampler, pin_memory=False, drop_last=False,
            # collate_fn=training_dataset.get_efficient_augmentation_collate_fn()
        )
        val_dataloader = DataLoader(training_dataset, batch_size=config.val_batch_size,
            sampler=val_subsampler, pin_memory=False, drop_last=False,
        )

        ### Get model, data parameters, optimizers for model and data parameters, as well as grad scaler ###
        if 'checkpoint_epx' in config and config['checkpoint_epx'] is not None:
            epx_start = config['checkpoint_epx']
        else:
            epx_start = 0

        if config.checkpoint_name:
            # Load from checkpoint
            _path = f"{config.mdl_save_prefix}/{config.checkpoint_name}_fold{fold_idx}_epx{epx_start}"
        else:
            _path = f"{config.mdl_save_prefix}/{wandb.run.name}_fold{fold_idx}_epx{epx_start}"

        (model, optimizer, scheduler, scaler) = get_model(config, len(training_dataset), len(training_dataset.label_tags),
            THIS_SCRIPT_DIR=THIS_SCRIPT_DIR, _path=_path, device=config.device)

        all_bn_counts = torch.zeros([len(training_dataset.label_tags)], device='cpu')

        for bn_counts in training_dataset.bincounts_3d.values():
            all_bn_counts += bn_counts

        class_weights = 1 / (all_bn_counts).float().pow(.35)
        class_weights /= class_weights.mean()

        class_weights = class_weights.to(device=config.device)

        autocast_enabled = 'cuda' in config.device

        for epx in range(epx_start, config.epochs):
            global_idx = get_global_idx(fold_idx, epx, config.epochs)

            model.train()

            ### Disturb samples ###
            training_dataset.train(use_modified=False)

            epx_losses = []
            dices = []
            class_dices = []

            # Load data
            for batch_idx, batch in tqdm(enumerate(train_dataloader), desc="batch:", total=len(train_dataloader)):

                optimizer.zero_grad()

                b_input, b_seg = get_model_input(batch, config, len(training_dataset.label_tags))

                ### Forward pass ###
                with amp.autocast(enabled=autocast_enabled):
                    assert b_input.dim() == len(n_dims)+2, \
                        f"Input image for model must be {len(n_dims)+2}D: BxCxSPATIAL but is {b_input.shape}"
                    for param in model.parameters():
                        param.requires_grad = True

                    model.use_checkpointing = True
                    y_hat, (z, mean, std) = model(b_input)
                    logits = y_hat

                    ### Calculate loss ###
                    assert logits.dim() == len(n_dims)+2, \
                        f"Input shape for loss must be BxNUM_CLASSESxSPATIAL but is {logits.shape}"
                    assert b_seg.dim() == len(n_dims)+1, \
                        f"Target shape for loss must be BxSPATIAL but is {b_seg.shape}"

                    # ce_loss = nn.CrossEntropyLoss(class_weights)(logits, b_seg)

                    # Reconstruction loss
                    recon_loss = gaussian_likelihood(y_hat, model.log_var_scale, b_seg)

                    # kl
                    kl = kl_divergence(z, mean, std)

                    # elbo
                    elbo = (kl - recon_loss)
                    # print("ls", elbo, kl, recon_loss)
                    elbo = elbo.mean()
                    vae_loss = elbo

                    scaler.scale(vae_loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    epx_losses.append(vae_loss.item())

                logits_for_score = logits.argmax(1)

                # Calculate dice score
                b_dice = dice3d(
                    torch.nn.functional.one_hot(logits_for_score, len(training_dataset.label_tags)),
                    torch.nn.functional.one_hot(b_seg, len(training_dataset.label_tags)), # Calculate dice score with original segmentation (no disturbance)
                    one_hot_torch_style=True
                )

                dices.append(get_batch_dice_over_all(
                    b_dice, exclude_bg=True))
                class_dices.append(get_batch_dice_per_class(
                    b_dice, training_dataset.label_tags, exclude_bg=True))


                if config.debug:
                    break

            ###  Scheduler management ###
            if config.use_scheduling:
                scheduler.step(vae_loss)

            ### Logging ###
            print(f"### Log epoch {epx}")
            print("### Training")

            ### Log wandb data ###
            # Log the epoch idx per fold - so we can recover the diagram by setting
            # ref_epoch_idx as x-axis in wandb interface
            wandb.log({"ref_epoch_idx": epx}, step=global_idx)

            mean_loss = torch.tensor(epx_losses).mean()
            wandb.log({f'losses/loss_fold{fold_idx}': mean_loss}, step=global_idx)
            print(f'losses/loss_fold{fold_idx}', f"{mean_loss}")

            mean_dice = np.nanmean(dices)
            print(f'dice_mean_wo_bg_fold{fold_idx}', f"{mean_dice*100:.2f}%")
            wandb.log({f'scores/dice_mean_wo_bg_fold{fold_idx}': mean_dice}, step=global_idx)

            log_class_dices("scores/dice_mean_", f"_fold{fold_idx}", class_dices, global_idx)

            print()
            print("### Validation")
            model.eval()
            training_dataset.eval()

            val_dices = []
            val_class_dices = []

            with amp.autocast(enabled=autocast_enabled):
                with torch.no_grad():
                    for val_batch_idx, val_batch in tqdm(enumerate(val_dataloader), desc="batch:", total=len(val_dataloader)):

                        b_val_input, b_val_seg = get_model_input(val_batch, config, len(training_dataset.label_tags))

                        output_val = model(b_val_input)[0]
                        val_logits_for_score = output_val.argmax(1)

                        b_val_dice = dice3d(
                            torch.nn.functional.one_hot(val_logits_for_score, len(training_dataset.label_tags)),
                            torch.nn.functional.one_hot(b_val_seg, len(training_dataset.label_tags)),
                            one_hot_torch_style=True
                        )

                        # Get mean score over batch
                        val_dices.append(get_batch_dice_over_all(
                            b_val_dice, exclude_bg=True))

                        val_class_dices.append(get_batch_dice_per_class(
                            b_val_dice, training_dataset.label_tags, exclude_bg=True))

                    mean_val_dice = np.nanmean(val_dices)

                    print(f'val_dice_mean_wo_bg_fold{fold_idx}', f"{mean_val_dice*100:.2f}%")
                    wandb.log({f'scores/val_dice_mean_wo_bg_fold{fold_idx}': mean_val_dice}, step=global_idx)
                    log_class_dices("scores/val_dice_mean_", f"_fold{fold_idx}", val_class_dices, global_idx)

            print()

            # Save model
            if config.save_every is None:
                pass

            elif config.save_every == 'best':
                if mean_val_dice > best_val_score:
                    best_val_score = mean_val_dice
                    save_path = f"{config.mdl_save_prefix}/{wandb.run.name}_fold{fold_idx}_best"
                    save_model(
                        Path(THIS_SCRIPT_DIR, save_path),
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        scaler=scaler)

            elif (epx % config.save_every == 0) or (epx+1 == config.epochs):
                save_path = f"{config.mdl_save_prefix}/{wandb.run.name}_fold{fold_idx}_epx{epx}"
                save_model(
                    Path(THIS_SCRIPT_DIR, save_path),
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    scaler=scaler)

                # (model, optimizer, scheduler, scaler) = \
                #     get_model(
                #         config, len(training_dataset),
                #         len(training_dataset.label_tags),
                #         THIS_SCRIPT_DIR=THIS_SCRIPT_DIR,
                #         _path=_path, device=config.device)

            # End of training loop

            if config.debug:
                break

        # End of fold loop

In [39]:
# Config overrides
# config_dict['wandb_mode'] = 'disabled'
# config_dict['debug'] = True
# Model loading
# config_dict['checkpoint_name'] = 'ethereal-serenity-1138'
# config_dict['fold_override'] = 0
# config_dict['checkpoint_epx'] = 39

# Define sweep override dict
sweep_config_dict = dict(
    method='grid',
    metric=dict(goal='maximize', name='scores/val_dice_mean_left_atrium_fold0'),
    parameters=dict(
        # disturbance_mode=dict(
        #     values=[
        #        'LabelDisturbanceMode.AFFINE',
        #     ]
        # ),
        # disturbance_strength=dict(
        #     values=[0.1, 0.2, 0.5, 1.0, 2.0, 5.0]
        # ),
        # disturbed_percentage=dict(
        #     values=[0.3, 0.6]
        # ),
        # data_param_mode=dict(
        #     values=[
        #         DataParamMode.INSTANCE_PARAMS,
        #         DataParamMode.DISABLED,
        #     ]
        # ),
        use_risk_regularization=dict(
            values=[False, True]
        ),
        use_fixed_weighting=dict(
            values=[False, True]
        ),
        # fixed_weight_min_quantile=dict(
        #     values=[0.9, 0.8, 0.6, 0.4, 0.2, 0.0]
        # ),
    )
)

In [40]:

def normal_run():
    with wandb.init(project=PROJECT_NAME, group="training", job_type="train",
            config=config_dict, settings=wandb.Settings(start_method="thread"),
            mode=config_dict['wandb_mode']
        ) as run:

        run_name = run.name
        print("Running", run_name)
        # training_dataset = prepare_data(config_dict)
        config = wandb.config

        train_DL(run_name, config, training_dataset)

def sweep_run():
    with wandb.init() as run:
        run = wandb.init(
            settings=wandb.Settings(start_method="thread"),
            mode=config_dict['wandb_mode']
        )

        run_name = run.name
        print("Running", run_name)
        training_dataset = prepare_data(config)
        config = wandb.config

        train_DL(run_name, config, training_dataset)

if config_dict['do_sweep']:
    # Integrate all config_dict entries into sweep_dict.parameters -> sweep overrides config_dict
    cp_config_dict = copy.deepcopy(config_dict)
    # cp_config_dict.update(copy.deepcopy(sweep_config_dict['parameters']))
    for del_key in sweep_config_dict['parameters'].keys():
        if del_key in cp_config_dict:
            del cp_config_dict[del_key]
    merged_sweep_config_dict = copy.deepcopy(sweep_config_dict)
    # merged_sweep_config_dict.update(cp_config_dict)
    for key, value in cp_config_dict.items():
        merged_sweep_config_dict['parameters'][key] = dict(value=value)
    # Convert enum values in parameters to string. They will be identified by their numerical index otherwise
    for key, param_dict in merged_sweep_config_dict['parameters'].items():
        if 'value' in param_dict and isinstance(param_dict['value'], Enum):
            param_dict['value'] = str(param_dict['value'])
        if 'values' in param_dict:
            param_dict['values'] = [str(elem) if isinstance(elem, Enum) else elem for elem in param_dict['values']]

        merged_sweep_config_dict['parameters'][key] = param_dict

    sweep_id = wandb.sweep(merged_sweep_config_dict, project=PROJECT_NAME)
    wandb.agent(sweep_id, function=sweep_run)

else:
    normal_run()

Running dummy-EeuBbksLHDWAQXSzMw7Png
Will run validation with these 3D samples (#4): ['1001-mr', '1002-mr', '1003-mr', '1004-mr']
Param count model: 705933
Generating fresh 'BlendowskiVAE' model, optimizer and grad scaler.


batch:: 100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


### Log epoch 0
### Training
losses/loss_fold0 0.9365672469139099
dice_mean_wo_bg_fold0 3.12%
scores/dice_mean_left_myocardium_fold0 3.78%
scores/dice_mean_left_atrium_fold0 1.80%
scores/dice_mean_left_ventricle_fold0 2.27%
scores/dice_mean_right_atrium_fold0 4.07%
scores/dice_mean_right_ventricle_fold0 3.66%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


val_dice_mean_wo_bg_fold0 2.90%
scores/val_dice_mean_left_myocardium_fold0 3.72%
scores/val_dice_mean_left_atrium_fold0 1.87%
scores/val_dice_mean_left_ventricle_fold0 1.41%
scores/val_dice_mean_right_atrium_fold0 3.33%
scores/val_dice_mean_right_ventricle_fold0 4.16%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.99s/it]


### Log epoch 1
### Training
losses/loss_fold0 0.9364902973175049
dice_mean_wo_bg_fold0 3.22%
scores/dice_mean_left_myocardium_fold0 3.39%
scores/dice_mean_left_atrium_fold0 1.97%
scores/dice_mean_left_ventricle_fold0 3.06%
scores/dice_mean_right_atrium_fold0 4.05%
scores/dice_mean_right_ventricle_fold0 3.65%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


val_dice_mean_wo_bg_fold0 2.58%
scores/val_dice_mean_left_myocardium_fold0 3.88%
scores/val_dice_mean_left_atrium_fold0 1.24%
scores/val_dice_mean_left_ventricle_fold0 2.21%
scores/val_dice_mean_right_atrium_fold0 3.09%
scores/val_dice_mean_right_ventricle_fold0 2.49%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


### Log epoch 2
### Training
losses/loss_fold0 0.9391302466392517
dice_mean_wo_bg_fold0 3.13%
scores/dice_mean_left_myocardium_fold0 3.60%
scores/dice_mean_left_atrium_fold0 1.95%
scores/dice_mean_left_ventricle_fold0 2.09%
scores/dice_mean_right_atrium_fold0 4.08%
scores/dice_mean_right_ventricle_fold0 3.92%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.52it/s]


val_dice_mean_wo_bg_fold0 2.55%
scores/val_dice_mean_left_myocardium_fold0 3.72%
scores/val_dice_mean_left_atrium_fold0 1.35%
scores/val_dice_mean_left_ventricle_fold0 2.56%
scores/val_dice_mean_right_atrium_fold0 2.79%
scores/val_dice_mean_right_ventricle_fold0 2.34%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


### Log epoch 3
### Training
losses/loss_fold0 0.9360880851745605
dice_mean_wo_bg_fold0 3.15%
scores/dice_mean_left_myocardium_fold0 3.72%
scores/dice_mean_left_atrium_fold0 2.02%
scores/dice_mean_left_ventricle_fold0 2.27%
scores/dice_mean_right_atrium_fold0 3.79%
scores/dice_mean_right_ventricle_fold0 3.96%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


val_dice_mean_wo_bg_fold0 2.91%
scores/val_dice_mean_left_myocardium_fold0 3.64%
scores/val_dice_mean_left_atrium_fold0 1.41%
scores/val_dice_mean_left_ventricle_fold0 2.83%
scores/val_dice_mean_right_atrium_fold0 3.52%
scores/val_dice_mean_right_ventricle_fold0 3.17%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.95s/it]


### Log epoch 4
### Training
losses/loss_fold0 0.93594890832901
dice_mean_wo_bg_fold0 3.25%
scores/dice_mean_left_myocardium_fold0 3.39%
scores/dice_mean_left_atrium_fold0 2.22%
scores/dice_mean_left_ventricle_fold0 2.76%
scores/dice_mean_right_atrium_fold0 3.89%
scores/dice_mean_right_ventricle_fold0 3.98%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.69%
scores/val_dice_mean_left_myocardium_fold0 3.91%
scores/val_dice_mean_left_atrium_fold0 1.90%
scores/val_dice_mean_left_ventricle_fold0 2.01%
scores/val_dice_mean_right_atrium_fold0 3.01%
scores/val_dice_mean_right_ventricle_fold0 2.64%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.97s/it]


### Log epoch 5
### Training
losses/loss_fold0 0.9378087520599365
dice_mean_wo_bg_fold0 3.11%
scores/dice_mean_left_myocardium_fold0 3.31%
scores/dice_mean_left_atrium_fold0 1.94%
scores/dice_mean_left_ventricle_fold0 2.54%
scores/dice_mean_right_atrium_fold0 3.87%
scores/dice_mean_right_ventricle_fold0 3.90%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.50it/s]


val_dice_mean_wo_bg_fold0 2.61%
scores/val_dice_mean_left_myocardium_fold0 4.00%
scores/val_dice_mean_left_atrium_fold0 1.50%
scores/val_dice_mean_left_ventricle_fold0 1.32%
scores/val_dice_mean_right_atrium_fold0 2.90%
scores/val_dice_mean_right_ventricle_fold0 3.32%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


### Log epoch 6
### Training
losses/loss_fold0 0.9361395835876465
dice_mean_wo_bg_fold0 3.31%
scores/dice_mean_left_myocardium_fold0 4.02%
scores/dice_mean_left_atrium_fold0 2.45%
scores/dice_mean_left_ventricle_fold0 2.31%
scores/dice_mean_right_atrium_fold0 4.05%
scores/dice_mean_right_ventricle_fold0 3.73%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


val_dice_mean_wo_bg_fold0 2.60%
scores/val_dice_mean_left_myocardium_fold0 3.66%
scores/val_dice_mean_left_atrium_fold0 0.95%
scores/val_dice_mean_left_ventricle_fold0 2.38%
scores/val_dice_mean_right_atrium_fold0 2.93%
scores/val_dice_mean_right_ventricle_fold0 3.09%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.03s/it]


### Log epoch 7
### Training
losses/loss_fold0 0.9398614168167114
dice_mean_wo_bg_fold0 3.08%
scores/dice_mean_left_myocardium_fold0 3.51%
scores/dice_mean_left_atrium_fold0 2.11%
scores/dice_mean_left_ventricle_fold0 2.11%
scores/dice_mean_right_atrium_fold0 3.91%
scores/dice_mean_right_ventricle_fold0 3.78%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


val_dice_mean_wo_bg_fold0 2.37%
scores/val_dice_mean_left_myocardium_fold0 3.61%
scores/val_dice_mean_left_atrium_fold0 1.43%
scores/val_dice_mean_left_ventricle_fold0 0.82%
scores/val_dice_mean_right_atrium_fold0 3.10%
scores/val_dice_mean_right_ventricle_fold0 2.90%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.97s/it]


### Log epoch 8
### Training
losses/loss_fold0 0.9344966411590576
dice_mean_wo_bg_fold0 3.01%
scores/dice_mean_left_myocardium_fold0 3.73%
scores/dice_mean_left_atrium_fold0 1.58%
scores/dice_mean_left_ventricle_fold0 2.28%
scores/dice_mean_right_atrium_fold0 3.89%
scores/dice_mean_right_ventricle_fold0 3.59%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.54it/s]


val_dice_mean_wo_bg_fold0 2.51%
scores/val_dice_mean_left_myocardium_fold0 3.59%
scores/val_dice_mean_left_atrium_fold0 1.89%
scores/val_dice_mean_left_ventricle_fold0 1.19%
scores/val_dice_mean_right_atrium_fold0 2.41%
scores/val_dice_mean_right_ventricle_fold0 3.47%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


### Log epoch 9
### Training
losses/loss_fold0 0.9402871131896973
dice_mean_wo_bg_fold0 3.30%
scores/dice_mean_left_myocardium_fold0 3.47%
scores/dice_mean_left_atrium_fold0 2.10%
scores/dice_mean_left_ventricle_fold0 3.48%
scores/dice_mean_right_atrium_fold0 3.75%
scores/dice_mean_right_ventricle_fold0 3.68%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.88%
scores/val_dice_mean_left_myocardium_fold0 3.15%
scores/val_dice_mean_left_atrium_fold0 2.43%
scores/val_dice_mean_left_ventricle_fold0 2.64%
scores/val_dice_mean_right_atrium_fold0 3.37%
scores/val_dice_mean_right_ventricle_fold0 2.81%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


### Log epoch 10
### Training
losses/loss_fold0 0.9408731460571289
dice_mean_wo_bg_fold0 3.24%
scores/dice_mean_left_myocardium_fold0 3.75%
scores/dice_mean_left_atrium_fold0 1.92%
scores/dice_mean_left_ventricle_fold0 2.45%
scores/dice_mean_right_atrium_fold0 4.12%
scores/dice_mean_right_ventricle_fold0 3.94%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.34%
scores/val_dice_mean_left_myocardium_fold0 3.53%
scores/val_dice_mean_left_atrium_fold0 1.03%
scores/val_dice_mean_left_ventricle_fold0 1.44%
scores/val_dice_mean_right_atrium_fold0 2.48%
scores/val_dice_mean_right_ventricle_fold0 3.24%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


### Log epoch 11
### Training
losses/loss_fold0 0.9402219653129578
dice_mean_wo_bg_fold0 3.17%
scores/dice_mean_left_myocardium_fold0 3.33%
scores/dice_mean_left_atrium_fold0 2.14%
scores/dice_mean_left_ventricle_fold0 3.07%
scores/dice_mean_right_atrium_fold0 3.65%
scores/dice_mean_right_ventricle_fold0 3.65%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


val_dice_mean_wo_bg_fold0 3.24%
scores/val_dice_mean_left_myocardium_fold0 3.59%
scores/val_dice_mean_left_atrium_fold0 2.20%
scores/val_dice_mean_left_ventricle_fold0 3.63%
scores/val_dice_mean_right_atrium_fold0 3.63%
scores/val_dice_mean_right_ventricle_fold0 3.17%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.00s/it]


### Log epoch 12
### Training
losses/loss_fold0 0.9357752203941345
dice_mean_wo_bg_fold0 3.22%
scores/dice_mean_left_myocardium_fold0 3.64%
scores/dice_mean_left_atrium_fold0 2.11%
scores/dice_mean_left_ventricle_fold0 2.68%
scores/dice_mean_right_atrium_fold0 3.87%
scores/dice_mean_right_ventricle_fold0 3.81%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


val_dice_mean_wo_bg_fold0 2.38%
scores/val_dice_mean_left_myocardium_fold0 3.67%
scores/val_dice_mean_left_atrium_fold0 1.30%
scores/val_dice_mean_left_ventricle_fold0 1.44%
scores/val_dice_mean_right_atrium_fold0 2.88%
scores/val_dice_mean_right_ventricle_fold0 2.62%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.95s/it]


### Log epoch 13
### Training
losses/loss_fold0 0.9405025243759155
dice_mean_wo_bg_fold0 2.92%
scores/dice_mean_left_myocardium_fold0 3.50%
scores/dice_mean_left_atrium_fold0 1.96%
scores/dice_mean_left_ventricle_fold0 1.73%
scores/dice_mean_right_atrium_fold0 3.65%
scores/dice_mean_right_ventricle_fold0 3.76%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.48%
scores/val_dice_mean_left_myocardium_fold0 4.01%
scores/val_dice_mean_left_atrium_fold0 1.80%
scores/val_dice_mean_left_ventricle_fold0 1.16%
scores/val_dice_mean_right_atrium_fold0 2.88%
scores/val_dice_mean_right_ventricle_fold0 2.57%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


### Log epoch 14
### Training
losses/loss_fold0 0.9377096891403198
dice_mean_wo_bg_fold0 3.04%
scores/dice_mean_left_myocardium_fold0 3.82%
scores/dice_mean_left_atrium_fold0 1.75%
scores/dice_mean_left_ventricle_fold0 1.95%
scores/dice_mean_right_atrium_fold0 3.92%
scores/dice_mean_right_ventricle_fold0 3.78%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.57%
scores/val_dice_mean_left_myocardium_fold0 4.02%
scores/val_dice_mean_left_atrium_fold0 1.59%
scores/val_dice_mean_left_ventricle_fold0 1.15%
scores/val_dice_mean_right_atrium_fold0 3.48%
scores/val_dice_mean_right_ventricle_fold0 2.60%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


### Log epoch 15
### Training
losses/loss_fold0 0.9377087354660034
dice_mean_wo_bg_fold0 3.30%
scores/dice_mean_left_myocardium_fold0 3.73%
scores/dice_mean_left_atrium_fold0 2.03%
scores/dice_mean_left_ventricle_fold0 2.88%
scores/dice_mean_right_atrium_fold0 3.81%
scores/dice_mean_right_ventricle_fold0 4.02%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


val_dice_mean_wo_bg_fold0 2.90%
scores/val_dice_mean_left_myocardium_fold0 3.61%
scores/val_dice_mean_left_atrium_fold0 2.02%
scores/val_dice_mean_left_ventricle_fold0 2.44%
scores/val_dice_mean_right_atrium_fold0 3.42%
scores/val_dice_mean_right_ventricle_fold0 3.02%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.97s/it]


### Log epoch 16
### Training
losses/loss_fold0 0.940077006816864
dice_mean_wo_bg_fold0 3.12%
scores/dice_mean_left_myocardium_fold0 3.73%
scores/dice_mean_left_atrium_fold0 1.83%
scores/dice_mean_left_ventricle_fold0 2.36%
scores/dice_mean_right_atrium_fold0 4.33%
scores/dice_mean_right_ventricle_fold0 3.36%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


val_dice_mean_wo_bg_fold0 2.62%
scores/val_dice_mean_left_myocardium_fold0 3.68%
scores/val_dice_mean_left_atrium_fold0 1.21%
scores/val_dice_mean_left_ventricle_fold0 2.99%
scores/val_dice_mean_right_atrium_fold0 2.49%
scores/val_dice_mean_right_ventricle_fold0 2.76%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.97s/it]


### Log epoch 17
### Training
losses/loss_fold0 0.93935227394104
dice_mean_wo_bg_fold0 3.20%
scores/dice_mean_left_myocardium_fold0 3.62%
scores/dice_mean_left_atrium_fold0 1.97%
scores/dice_mean_left_ventricle_fold0 2.42%
scores/dice_mean_right_atrium_fold0 4.09%
scores/dice_mean_right_ventricle_fold0 3.91%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.65%
scores/val_dice_mean_left_myocardium_fold0 3.44%
scores/val_dice_mean_left_atrium_fold0 1.66%
scores/val_dice_mean_left_ventricle_fold0 2.55%
scores/val_dice_mean_right_atrium_fold0 2.33%
scores/val_dice_mean_right_ventricle_fold0 3.26%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.01s/it]


### Log epoch 18
### Training
losses/loss_fold0 0.9408173561096191
dice_mean_wo_bg_fold0 3.07%
scores/dice_mean_left_myocardium_fold0 3.59%
scores/dice_mean_left_atrium_fold0 1.86%
scores/dice_mean_left_ventricle_fold0 2.18%
scores/dice_mean_right_atrium_fold0 3.72%
scores/dice_mean_right_ventricle_fold0 4.02%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.72%
scores/val_dice_mean_left_myocardium_fold0 3.64%
scores/val_dice_mean_left_atrium_fold0 1.29%
scores/val_dice_mean_left_ventricle_fold0 2.10%
scores/val_dice_mean_right_atrium_fold0 3.23%
scores/val_dice_mean_right_ventricle_fold0 3.32%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


### Log epoch 19
### Training
losses/loss_fold0 0.9377627372741699
dice_mean_wo_bg_fold0 3.34%
scores/dice_mean_left_myocardium_fold0 3.71%
scores/dice_mean_left_atrium_fold0 2.14%
scores/dice_mean_left_ventricle_fold0 2.86%
scores/dice_mean_right_atrium_fold0 4.18%
scores/dice_mean_right_ventricle_fold0 3.78%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


val_dice_mean_wo_bg_fold0 2.55%
scores/val_dice_mean_left_myocardium_fold0 3.68%
scores/val_dice_mean_left_atrium_fold0 1.57%
scores/val_dice_mean_left_ventricle_fold0 1.06%
scores/val_dice_mean_right_atrium_fold0 3.17%
scores/val_dice_mean_right_ventricle_fold0 3.26%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.98s/it]


### Log epoch 20
### Training
losses/loss_fold0 0.9356688261032104
dice_mean_wo_bg_fold0 3.13%
scores/dice_mean_left_myocardium_fold0 3.71%
scores/dice_mean_left_atrium_fold0 2.20%
scores/dice_mean_left_ventricle_fold0 2.20%
scores/dice_mean_right_atrium_fold0 3.79%
scores/dice_mean_right_ventricle_fold0 3.78%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


val_dice_mean_wo_bg_fold0 2.59%
scores/val_dice_mean_left_myocardium_fold0 3.72%
scores/val_dice_mean_left_atrium_fold0 1.72%
scores/val_dice_mean_left_ventricle_fold0 1.98%
scores/val_dice_mean_right_atrium_fold0 2.58%
scores/val_dice_mean_right_ventricle_fold0 2.95%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.93s/it]


### Log epoch 21
### Training
losses/loss_fold0 0.9373108148574829
dice_mean_wo_bg_fold0 3.01%
scores/dice_mean_left_myocardium_fold0 3.82%
scores/dice_mean_left_atrium_fold0 1.72%
scores/dice_mean_left_ventricle_fold0 2.03%
scores/dice_mean_right_atrium_fold0 3.72%
scores/dice_mean_right_ventricle_fold0 3.76%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.62%
scores/val_dice_mean_left_myocardium_fold0 3.24%
scores/val_dice_mean_left_atrium_fold0 1.21%
scores/val_dice_mean_left_ventricle_fold0 3.03%
scores/val_dice_mean_right_atrium_fold0 2.65%
scores/val_dice_mean_right_ventricle_fold0 2.98%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


### Log epoch 22
### Training
losses/loss_fold0 0.9388963580131531
dice_mean_wo_bg_fold0 3.24%
scores/dice_mean_left_myocardium_fold0 3.73%
scores/dice_mean_left_atrium_fold0 2.24%
scores/dice_mean_left_ventricle_fold0 2.50%
scores/dice_mean_right_atrium_fold0 3.89%
scores/dice_mean_right_ventricle_fold0 3.83%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.53%
scores/val_dice_mean_left_myocardium_fold0 3.33%
scores/val_dice_mean_left_atrium_fold0 1.85%
scores/val_dice_mean_left_ventricle_fold0 2.16%
scores/val_dice_mean_right_atrium_fold0 2.67%
scores/val_dice_mean_right_ventricle_fold0 2.67%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


### Log epoch 23
### Training
losses/loss_fold0 0.9394581913948059
dice_mean_wo_bg_fold0 3.03%
scores/dice_mean_left_myocardium_fold0 3.59%
scores/dice_mean_left_atrium_fold0 1.91%
scores/dice_mean_left_ventricle_fold0 1.82%
scores/dice_mean_right_atrium_fold0 4.06%
scores/dice_mean_right_ventricle_fold0 3.75%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


val_dice_mean_wo_bg_fold0 2.96%
scores/val_dice_mean_left_myocardium_fold0 3.66%
scores/val_dice_mean_left_atrium_fold0 2.04%
scores/val_dice_mean_left_ventricle_fold0 3.15%
scores/val_dice_mean_right_atrium_fold0 2.83%
scores/val_dice_mean_right_ventricle_fold0 3.10%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.00s/it]


### Log epoch 24
### Training
losses/loss_fold0 0.9394792318344116
dice_mean_wo_bg_fold0 3.17%
scores/dice_mean_left_myocardium_fold0 3.55%
scores/dice_mean_left_atrium_fold0 2.20%
scores/dice_mean_left_ventricle_fold0 2.24%
scores/dice_mean_right_atrium_fold0 4.11%
scores/dice_mean_right_ventricle_fold0 3.75%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


val_dice_mean_wo_bg_fold0 2.69%
scores/val_dice_mean_left_myocardium_fold0 3.53%
scores/val_dice_mean_left_atrium_fold0 1.58%
scores/val_dice_mean_left_ventricle_fold0 3.21%
scores/val_dice_mean_right_atrium_fold0 2.26%
scores/val_dice_mean_right_ventricle_fold0 2.89%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


### Log epoch 25
### Training
losses/loss_fold0 0.9413373470306396
dice_mean_wo_bg_fold0 2.92%
scores/dice_mean_left_myocardium_fold0 3.90%
scores/dice_mean_left_atrium_fold0 1.69%
scores/dice_mean_left_ventricle_fold0 1.54%
scores/dice_mean_right_atrium_fold0 3.58%
scores/dice_mean_right_ventricle_fold0 3.88%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.63%
scores/val_dice_mean_left_myocardium_fold0 3.35%
scores/val_dice_mean_left_atrium_fold0 1.47%
scores/val_dice_mean_left_ventricle_fold0 1.74%
scores/val_dice_mean_right_atrium_fold0 3.20%
scores/val_dice_mean_right_ventricle_fold0 3.37%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.03s/it]


### Log epoch 26
### Training
losses/loss_fold0 0.9414917826652527
dice_mean_wo_bg_fold0 3.03%
scores/dice_mean_left_myocardium_fold0 3.42%
scores/dice_mean_left_atrium_fold0 2.01%
scores/dice_mean_left_ventricle_fold0 2.06%
scores/dice_mean_right_atrium_fold0 3.75%
scores/dice_mean_right_ventricle_fold0 3.94%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.59%
scores/val_dice_mean_left_myocardium_fold0 3.34%
scores/val_dice_mean_left_atrium_fold0 1.44%
scores/val_dice_mean_left_ventricle_fold0 2.37%
scores/val_dice_mean_right_atrium_fold0 2.44%
scores/val_dice_mean_right_ventricle_fold0 3.36%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


### Log epoch 27
### Training
losses/loss_fold0 0.9379237294197083
dice_mean_wo_bg_fold0 3.11%
scores/dice_mean_left_myocardium_fold0 3.56%
scores/dice_mean_left_atrium_fold0 2.07%
scores/dice_mean_left_ventricle_fold0 2.13%
scores/dice_mean_right_atrium_fold0 3.82%
scores/dice_mean_right_ventricle_fold0 3.97%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.84%
scores/val_dice_mean_left_myocardium_fold0 3.54%
scores/val_dice_mean_left_atrium_fold0 1.98%
scores/val_dice_mean_left_ventricle_fold0 2.39%
scores/val_dice_mean_right_atrium_fold0 2.98%
scores/val_dice_mean_right_ventricle_fold0 3.31%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


### Log epoch 28
### Training
losses/loss_fold0 0.9413904547691345
dice_mean_wo_bg_fold0 3.22%
scores/dice_mean_left_myocardium_fold0 3.71%
scores/dice_mean_left_atrium_fold0 1.88%
scores/dice_mean_left_ventricle_fold0 2.78%
scores/dice_mean_right_atrium_fold0 3.82%
scores/dice_mean_right_ventricle_fold0 3.91%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


val_dice_mean_wo_bg_fold0 2.73%
scores/val_dice_mean_left_myocardium_fold0 3.74%
scores/val_dice_mean_left_atrium_fold0 1.47%
scores/val_dice_mean_left_ventricle_fold0 2.48%
scores/val_dice_mean_right_atrium_fold0 3.15%
scores/val_dice_mean_right_ventricle_fold0 2.82%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


### Log epoch 29
### Training
losses/loss_fold0 0.93989098072052
dice_mean_wo_bg_fold0 3.03%
scores/dice_mean_left_myocardium_fold0 3.71%
scores/dice_mean_left_atrium_fold0 1.68%
scores/dice_mean_left_ventricle_fold0 2.77%
scores/dice_mean_right_atrium_fold0 3.58%
scores/dice_mean_right_ventricle_fold0 3.40%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.54it/s]


val_dice_mean_wo_bg_fold0 2.42%
scores/val_dice_mean_left_myocardium_fold0 3.79%
scores/val_dice_mean_left_atrium_fold0 1.14%
scores/val_dice_mean_left_ventricle_fold0 1.13%
scores/val_dice_mean_right_atrium_fold0 3.32%
scores/val_dice_mean_right_ventricle_fold0 2.71%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


### Log epoch 30
### Training
losses/loss_fold0 0.932558536529541
dice_mean_wo_bg_fold0 3.10%
scores/dice_mean_left_myocardium_fold0 3.62%
scores/dice_mean_left_atrium_fold0 1.93%
scores/dice_mean_left_ventricle_fold0 2.35%
scores/dice_mean_right_atrium_fold0 4.05%
scores/dice_mean_right_ventricle_fold0 3.53%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


val_dice_mean_wo_bg_fold0 2.64%
scores/val_dice_mean_left_myocardium_fold0 3.64%
scores/val_dice_mean_left_atrium_fold0 2.35%
scores/val_dice_mean_left_ventricle_fold0 0.98%
scores/val_dice_mean_right_atrium_fold0 2.85%
scores/val_dice_mean_right_ventricle_fold0 3.41%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


### Log epoch 31
### Training
losses/loss_fold0 0.9417567849159241
dice_mean_wo_bg_fold0 3.32%
scores/dice_mean_left_myocardium_fold0 3.68%
scores/dice_mean_left_atrium_fold0 1.99%
scores/dice_mean_left_ventricle_fold0 2.78%
scores/dice_mean_right_atrium_fold0 4.06%
scores/dice_mean_right_ventricle_fold0 4.07%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


val_dice_mean_wo_bg_fold0 2.59%
scores/val_dice_mean_left_myocardium_fold0 3.80%
scores/val_dice_mean_left_atrium_fold0 1.65%
scores/val_dice_mean_left_ventricle_fold0 1.79%
scores/val_dice_mean_right_atrium_fold0 2.73%
scores/val_dice_mean_right_ventricle_fold0 2.95%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.99s/it]


### Log epoch 32
### Training
losses/loss_fold0 0.9402580261230469
dice_mean_wo_bg_fold0 3.27%
scores/dice_mean_left_myocardium_fold0 3.69%
scores/dice_mean_left_atrium_fold0 1.80%
scores/dice_mean_left_ventricle_fold0 3.19%
scores/dice_mean_right_atrium_fold0 4.02%
scores/dice_mean_right_ventricle_fold0 3.66%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


val_dice_mean_wo_bg_fold0 2.60%
scores/val_dice_mean_left_myocardium_fold0 3.34%
scores/val_dice_mean_left_atrium_fold0 1.66%
scores/val_dice_mean_left_ventricle_fold0 1.90%
scores/val_dice_mean_right_atrium_fold0 2.92%
scores/val_dice_mean_right_ventricle_fold0 3.18%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.96s/it]


Epoch 00034: reducing learning rate of group 0 to 1.0000e-07.
### Log epoch 33
### Training
losses/loss_fold0 0.9379405975341797
dice_mean_wo_bg_fold0 3.27%
scores/dice_mean_left_myocardium_fold0 3.81%
scores/dice_mean_left_atrium_fold0 1.97%
scores/dice_mean_left_ventricle_fold0 2.95%
scores/dice_mean_right_atrium_fold0 3.89%
scores/dice_mean_right_ventricle_fold0 3.74%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.54it/s]


val_dice_mean_wo_bg_fold0 2.66%
scores/val_dice_mean_left_myocardium_fold0 3.88%
scores/val_dice_mean_left_atrium_fold0 2.01%
scores/val_dice_mean_left_ventricle_fold0 2.02%
scores/val_dice_mean_right_atrium_fold0 2.32%
scores/val_dice_mean_right_ventricle_fold0 3.09%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.99s/it]


### Log epoch 34
### Training
losses/loss_fold0 0.943678617477417
dice_mean_wo_bg_fold0 3.01%
scores/dice_mean_left_myocardium_fold0 3.48%
scores/dice_mean_left_atrium_fold0 1.96%
scores/dice_mean_left_ventricle_fold0 2.02%
scores/dice_mean_right_atrium_fold0 3.99%
scores/dice_mean_right_ventricle_fold0 3.58%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.54%
scores/val_dice_mean_left_myocardium_fold0 3.47%
scores/val_dice_mean_left_atrium_fold0 1.81%
scores/val_dice_mean_left_ventricle_fold0 1.06%
scores/val_dice_mean_right_atrium_fold0 3.09%
scores/val_dice_mean_right_ventricle_fold0 3.28%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.08s/it]


### Log epoch 35
### Training
losses/loss_fold0 0.9380826354026794
dice_mean_wo_bg_fold0 3.05%
scores/dice_mean_left_myocardium_fold0 3.63%
scores/dice_mean_left_atrium_fold0 1.92%
scores/dice_mean_left_ventricle_fold0 2.19%
scores/dice_mean_right_atrium_fold0 3.70%
scores/dice_mean_right_ventricle_fold0 3.82%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


val_dice_mean_wo_bg_fold0 2.70%
scores/val_dice_mean_left_myocardium_fold0 3.50%
scores/val_dice_mean_left_atrium_fold0 1.98%
scores/val_dice_mean_left_ventricle_fold0 2.07%
scores/val_dice_mean_right_atrium_fold0 3.63%
scores/val_dice_mean_right_ventricle_fold0 2.33%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.01s/it]


### Log epoch 36
### Training
losses/loss_fold0 0.9391775131225586
dice_mean_wo_bg_fold0 3.25%
scores/dice_mean_left_myocardium_fold0 3.59%
scores/dice_mean_left_atrium_fold0 1.81%
scores/dice_mean_left_ventricle_fold0 2.92%
scores/dice_mean_right_atrium_fold0 3.94%
scores/dice_mean_right_ventricle_fold0 3.99%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.50it/s]


val_dice_mean_wo_bg_fold0 2.89%
scores/val_dice_mean_left_myocardium_fold0 3.60%
scores/val_dice_mean_left_atrium_fold0 2.13%
scores/val_dice_mean_left_ventricle_fold0 2.88%
scores/val_dice_mean_right_atrium_fold0 3.33%
scores/val_dice_mean_right_ventricle_fold0 2.50%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


### Log epoch 37
### Training
losses/loss_fold0 0.9453803896903992
dice_mean_wo_bg_fold0 3.22%
scores/dice_mean_left_myocardium_fold0 3.71%
scores/dice_mean_left_atrium_fold0 2.08%
scores/dice_mean_left_ventricle_fold0 2.86%
scores/dice_mean_right_atrium_fold0 4.07%
scores/dice_mean_right_ventricle_fold0 3.37%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


val_dice_mean_wo_bg_fold0 2.46%
scores/val_dice_mean_left_myocardium_fold0 3.25%
scores/val_dice_mean_left_atrium_fold0 1.80%
scores/val_dice_mean_left_ventricle_fold0 2.24%
scores/val_dice_mean_right_atrium_fold0 2.58%
scores/val_dice_mean_right_ventricle_fold0 2.44%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


### Log epoch 38
### Training
losses/loss_fold0 0.9395834803581238
dice_mean_wo_bg_fold0 3.24%
scores/dice_mean_left_myocardium_fold0 3.71%
scores/dice_mean_left_atrium_fold0 1.87%
scores/dice_mean_left_ventricle_fold0 3.24%
scores/dice_mean_right_atrium_fold0 3.79%
scores/dice_mean_right_ventricle_fold0 3.61%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.33%
scores/val_dice_mean_left_myocardium_fold0 3.74%
scores/val_dice_mean_left_atrium_fold0 1.17%
scores/val_dice_mean_left_ventricle_fold0 1.03%
scores/val_dice_mean_right_atrium_fold0 2.63%
scores/val_dice_mean_right_ventricle_fold0 3.08%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.03s/it]


### Log epoch 39
### Training
losses/loss_fold0 0.9411622285842896
dice_mean_wo_bg_fold0 3.01%
scores/dice_mean_left_myocardium_fold0 3.66%
scores/dice_mean_left_atrium_fold0 1.66%
scores/dice_mean_left_ventricle_fold0 2.38%
scores/dice_mean_right_atrium_fold0 3.62%
scores/dice_mean_right_ventricle_fold0 3.73%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


val_dice_mean_wo_bg_fold0 2.40%
scores/val_dice_mean_left_myocardium_fold0 3.35%
scores/val_dice_mean_left_atrium_fold0 1.46%
scores/val_dice_mean_left_ventricle_fold0 1.21%
scores/val_dice_mean_right_atrium_fold0 2.79%
scores/val_dice_mean_right_ventricle_fold0 3.18%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.01s/it]


### Log epoch 40
### Training
losses/loss_fold0 0.9365200996398926
dice_mean_wo_bg_fold0 2.98%
scores/dice_mean_left_myocardium_fold0 3.67%
scores/dice_mean_left_atrium_fold0 1.76%
scores/dice_mean_left_ventricle_fold0 1.95%
scores/dice_mean_right_atrium_fold0 3.82%
scores/dice_mean_right_ventricle_fold0 3.67%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.49it/s]


val_dice_mean_wo_bg_fold0 2.78%
scores/val_dice_mean_left_myocardium_fold0 3.51%
scores/val_dice_mean_left_atrium_fold0 1.47%
scores/val_dice_mean_left_ventricle_fold0 3.38%
scores/val_dice_mean_right_atrium_fold0 2.82%
scores/val_dice_mean_right_ventricle_fold0 2.72%



batch:: 100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


### Log epoch 41
### Training
losses/loss_fold0 0.9362409114837646
dice_mean_wo_bg_fold0 3.11%
scores/dice_mean_left_myocardium_fold0 3.49%
scores/dice_mean_left_atrium_fold0 1.47%
scores/dice_mean_left_ventricle_fold0 2.77%
scores/dice_mean_right_atrium_fold0 3.88%
scores/dice_mean_right_ventricle_fold0 3.95%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.54it/s]


val_dice_mean_wo_bg_fold0 2.78%
scores/val_dice_mean_left_myocardium_fold0 3.58%
scores/val_dice_mean_left_atrium_fold0 2.07%
scores/val_dice_mean_left_ventricle_fold0 1.30%
scores/val_dice_mean_right_atrium_fold0 3.28%
scores/val_dice_mean_right_ventricle_fold0 3.69%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


### Log epoch 42
### Training
losses/loss_fold0 0.9368400573730469
dice_mean_wo_bg_fold0 3.05%
scores/dice_mean_left_myocardium_fold0 3.65%
scores/dice_mean_left_atrium_fold0 1.95%
scores/dice_mean_left_ventricle_fold0 2.29%
scores/dice_mean_right_atrium_fold0 3.80%
scores/dice_mean_right_ventricle_fold0 3.54%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


val_dice_mean_wo_bg_fold0 2.77%
scores/val_dice_mean_left_myocardium_fold0 3.22%
scores/val_dice_mean_left_atrium_fold0 1.63%
scores/val_dice_mean_left_ventricle_fold0 3.21%
scores/val_dice_mean_right_atrium_fold0 2.90%
scores/val_dice_mean_right_ventricle_fold0 2.90%



batch:: 100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


### Log epoch 43
### Training
losses/loss_fold0 0.934790849685669
dice_mean_wo_bg_fold0 3.08%
scores/dice_mean_left_myocardium_fold0 3.51%
scores/dice_mean_left_atrium_fold0 1.88%
scores/dice_mean_left_ventricle_fold0 2.26%
scores/dice_mean_right_atrium_fold0 3.82%
scores/dice_mean_right_ventricle_fold0 3.95%

### Validation


batch:: 100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


val_dice_mean_wo_bg_fold0 2.60%
scores/val_dice_mean_left_myocardium_fold0 3.19%
scores/val_dice_mean_left_atrium_fold0 1.63%
scores/val_dice_mean_left_ventricle_fold0 2.26%
scores/val_dice_mean_right_atrium_fold0 3.00%
scores/val_dice_mean_right_ventricle_fold0 2.91%



batch::  25%|██▌       | 1/4 [00:05<00:15,  5.16s/it]


KeyboardInterrupt: 

In [None]:
if not in_notebook():
    sys.exit(0)

In [22]:
# Do any postprocessing / visualization in notebook here

Running dummy-EzxVz3eRm8H7haihXtEusU
Will run validation with these 3D samples (#4): ['1001-mr', '1002-mr', '1003-mr', '1004-mr']
Param count model: 556732
Generating fresh 'BlendowskiAE' model, optimizer and grad scaler.


batch:: 100%|██████████| 4/4 [00:11<00:00,  2.87s/it]

### Log epoch 0
### Training
losses/loss_fold0 3.6996378898620605
losses/loss_fold0 3.6996378898620605





AttributeError: 'ReduceLROnPlateau' object has no attribute 'get_lr'