In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import faiss
import time
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import cv2
from attention_unet import Attention_block, conv_block, up_conv
import matplotlib.pyplot as plt
from dataloaders import Birds_OneCluster
import torchvision.utils

In [2]:
import pickle
with open("/home/e_radionova/DeepCluster/deepcluster/checkpts_Birds_1500epochs_30cls/clusters", "rb") as f:
    clusters = pickle.load(f)

In [3]:
epoch = 1000 #-1
num_cluster = 3 #11
cluster_idxs = clusters[epoch][num_cluster]

In [4]:
def find_names_in_fold(prefix):
    images_names = np.sort(os.listdir(prefix))
    list_names = np.sort(os.listdir(prefix / images_names[0])).tolist()
    for i, x in enumerate(list_names):
        list_names[i] = os.path.join(images_names[0],x)
    for i in images_names[1:]:
        list_names_onefold = np.sort(os.listdir(prefix / i)).tolist()
        for j, x in enumerate(list_names_onefold):
            list_names_onefold[j] = os.path.join(i, x)
        list_names.extend(list_names_onefold)
    return list_names

In [5]:
dataset_path = Path('/home/e_radionova/Datasets/Caltech_birds/CUB_200_2011/dataset')
train_imgs_path = dataset_path /  'train/images/'
train_masks_path = dataset_path /  'train/masks/'

In [6]:
train_masks_names = find_names_in_fold(train_masks_path)
train_imgs_names = find_names_in_fold(train_imgs_path)

In [7]:
lst_cluster_masks = [train_masks_names[i] for i in cluster_idxs]
lst_cluster_imgs = [train_imgs_names[i] for i in cluster_idxs]

In [8]:
from sklearn.model_selection import train_test_split
train_imgs, test_imgs, train_masks, test_masks = train_test_split(lst_cluster_imgs, lst_cluster_masks, 
                                                                  test_size=0.3, random_state=42)

In [9]:
trans = transforms.Compose([
    transforms.Resize((400, 400), interpolation=Image.NEAREST),
#     transforms.CenterCrop((400, 400)),
    transforms.ToTensor(),
])

train_set = Birds_OneCluster(img_names = train_imgs, 
                             mask_names = train_masks, 
                             images_folder = dataset_path / 'train/images', 
                             masks_folder = dataset_path / 'train/masks',
                             clusters=None,
                             cluster_num=11,
                             cluster_epoch=-1, 
                             img_transform = trans, 
                             masks_transform = trans)

val_set = Birds_OneCluster(img_names = test_imgs, 
                           mask_names = test_masks,
                           images_folder = dataset_path / 'train/images',
                           masks_folder = dataset_path / 'train/masks',
                           clusters=None,
                           cluster_num=11,
                           cluster_epoch=-1, 
                           img_transform = trans,
                           masks_transform = trans)

In [10]:
image_datasets = {
    'train': train_set , 'val': val_set
}

batch_size_train = 4
batch_size_val = batch_size_train

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size_train, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size_val, shuffle=True, num_workers=0)
}

dataset_sizes = {
    x: len(image_datasets[x]) for x in image_datasets.keys()
}

dataset_sizes

{'train': 219, 'val': 95}

In [11]:
# SEED = 42
# random.seed(SEED)
# torch.manual_seed(SEED)
# torch.cuda.manual_seed_all(SEED)
# np.random.seed(SEED)

# n_pics_to_show = batch_size_train
# inputs, masks = next(iter(dataloaders['val']))
# fig, ax = plt.subplots(n_pics_to_show, 2, figsize=(7, 15))
# for i in range(n_pics_to_show):
#     pic, label = inputs[i], masks[i] 
#     label_np = label.data.numpy().transpose(1, 2, 0) 
#     pic_np = pic.data.numpy().transpose(1, 2, 0) 
#     ax[i,0].imshow(pic_np)
#     ax[i,1].imshow(label_np)

In [12]:
from torchsummary import summary
import torch
import torch.nn as nn
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss, calc_loss, print_metrics
from tqdm.notebook import tqdm

In [13]:
logs_base_dir = Path('./logs')
logs_base_dir.mkdir(exist_ok=True)

In [14]:
activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

def plot_attention_map(attention_map, N=10, n_columns=5):
    normed_att_map = transforms.Normalize(0, 1)(attention_map).detach().cpu().numpy()
    plt.imshow(normed_att_map.transpose(1,2,0))
#     clear_output(wait=True)
    
    n_rows = N // n_columns + int(N // n_columns * n_columns < N)
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(15,6))
    for map_i in range(N):
        if N==1:
            plt.imshow(normed_att_map[map_i])
        else:
            row_index = map_i // n_columns
            column_index = map_i % n_columns
            axes[row_index, column_index].imshow(normed_att_map[map_i])
    plt.show()

In [15]:
from template_matching_funcs import thresholding, get_Fourier_coeffs_and_kernel

def get_feature_map(attention_map, index=9, threshold=True):
    normed_att_map = attention_map.detach().cpu().numpy()
    one_layer = normed_att_map.transpose(1,2,0) #[index]
    plt.imshow(one_layer)
    plt.title('Choosen map')
    plt.show()
    print('Unique els: ', np.unique(one_layer))
    if threshold:
        fm_thresh = thresholding(one_layer, 1.0)
        kern = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
        fm_thresh = cv2.morphologyEx(fm_thresh, cv2.MORPH_CLOSE, kern)
        fm_thresh = cv2.dilate(fm_thresh,kern,iterations = 1)
        fm_thresh = 255 * abs(fm_thresh / fm_thresh.max())
        plt.imshow(fm_thresh)
        plt.show()

In [16]:
def get_normalize_image(image, limit=1.):
    image = image-np.min(image)
    image = image / image.max()
    image = limit * image
    return image

In [17]:
def get_kernel_baseline(image, template, order, kernel_size,
                       morph_open, morph_close, dilate):
#     img_match = template_matching(image=image, template=template, method=cv2.TM_CCORR)
    img_thresh = thresholding(template, 1.0)
    
    kern = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))  # np.ones((5,5),np.uint8)
    if morph_open:
        img_thresh = cv2.morphologyEx(img_thresh, cv2.MORPH_OPEN, kern)
    if morph_close:
        img_thresh = cv2.morphologyEx(img_thresh, cv2.MORPH_CLOSE, kern)
    if dilate:
        img_thresh = cv2.dilate(img_thresh,kern,iterations = 1)
        
    _, kernel = get_Fourier_coeffs_and_kernel(img_thresh, order=order, kernel_size=kernel_size)

    return kernel, img_thresh

In [18]:
def plot_attention_map(attention_map, N=30, n_columns=5):
    normed_att_map = transforms.Normalize(0, 1)(attention_map)
    resize_map = normed_att_map.detach().cpu().numpy()
    
#     clear_output(wait=True)
    n_rows = N // n_columns + int(N // n_columns * n_columns < N)
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(40,20))
    for map_i in range(N):
        if N==1:
            plt.imshow(resize_map[map_i])
        else:
            row_index = map_i // n_columns
            column_index = map_i % n_columns
            axes[row_index, column_index].imshow(resize_map[map_i])
    plt.show()

In [19]:
import torchvision.transforms.functional as F
# from torchvision.transforms import InterpolationMode

In [20]:
def kernel_torch(att_map, kernel_size):
    uniq_els = torch.unique(att_map)
    mean = torch.mean(uniq_els)
    m = torch.nn.Threshold(mean, 0., inplace=False)
    thresh = m(att_map)
    height = att_map.shape[-2]
    width = att_map.shape[-1]
    kernel_resize = F.resize(thresh, kernel_size)
    return kernel_resize

In [21]:

def train_model(model, optimizer, scheduler, experiment_name, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    best_dice = 0
    
    writer = SummaryWriter(logs_base_dir / experiment_name)

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        since = time.time()
        i=0
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                    
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                
                inputs = inputs.to(device)
                labels = labels.to(device)    

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
#                     outputs_att = activation['Att4']
#                     if i ==10:
#                         plot_attention_map(activation['Att4'][0]) 
#                         print(activation['Att4'].shape)
#                     i+=1
                    #####----------------------
                    
#                     thresh_detach = thresh.permute(1,2,0).detach().cpu().numpy()
#                     outputs_detach = outputs_att.permute(1,2,0).detach().cpu().numpy()
#                     a = outputs_detach.shape[0]
#                     b = outputs_detach.shape[1]
#                     kernel_size = 20
#                     kernel_per_img, img_thresh = get_kernel_baseline(image=outputs_detach, 
#                                                                      template=thresh_detach, 
#                                                                      order=10, 
#                                                                      kernel_size=kernel_size,
#                                                                      morph_open=False, 
#                                                                      morph_close=False, 
#                                                                      dilate=True)

                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)
                

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples   
            writer.add_scalar(f'Loss_{phase}', epoch_loss.item(), global_step=epoch)
            
            dice_epoch = metrics['dice'] / epoch_samples
            writer.add_scalar(f'DICE_{phase}', dice_epoch.item(), global_step=epoch)
            
#             imgs_to_tb = torchvision.utils.make_grid(inputs)
#             writer.add_image('images', imgs_to_tb, global_step=epoch)

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best loss")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
            
            if phase == 'val' and dice_epoch > best_dice:
                print("saving best DICE")
                best_dice = dice_epoch

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    print('Best val DICE: {:4f}'.format(best_dice))
    
#     file_name = "Best_DICE_" + Model_name + '_' + model_type 
#     with open(f"{files_dir}/{file_name}.txt","a") as the_file:
#             the_file.write('Fold №{}, best DICE: {}\n'.format(fold, best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [23]:
# import attention_unet
# model = attention_unet.AttU_Net(3, 3).to(device)
import attention_kernel2conv1
model = attention_kernel2conv1.AttU_Net(3, 3).to(device)

In [24]:
# model.Att4.register_forward_hook(get_activation('Att4'))

In [25]:
optimizer_ft = optim.Adam(model.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=15, gamma=0.1)

exp_name = f'ModifiedModel_cluster3from1000_30epochs_b{batch_size_train}_kernel_7_' + datetime.now().isoformat(timespec='minutes') 

In [26]:
model = train_model(model, optimizer_ft, exp_lr_scheduler, experiment_name=exp_name, num_epochs=30) 

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch 0/29
----------
LR 0.001




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))




train: bce: 0.434502, dice: 0.323897, loss: 0.555302


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 1.469298, dice: 0.359884, loss: 1.054707
saving best loss
saving best DICE
0m 29s
Epoch 1/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.332099, dice: 0.411643, loss: 0.460228


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 1.049104, dice: 0.414179, loss: 0.817463
saving best loss
saving best DICE
0m 29s
Epoch 2/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.313505, dice: 0.461530, loss: 0.425988


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.712230, dice: 0.471547, loss: 0.620342
saving best loss
saving best DICE
0m 29s
Epoch 3/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.281754, dice: 0.516952, loss: 0.382401


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.829142, dice: 0.495202, loss: 0.666970
saving best DICE
0m 30s
Epoch 4/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.290484, dice: 0.543485, loss: 0.373499


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.409728, dice: 0.314831, loss: 0.547448
saving best loss
0m 30s
Epoch 5/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.264361, dice: 0.562033, loss: 0.351164


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.504797, dice: 0.168204, loss: 0.668296
0m 31s
Epoch 6/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.245014, dice: 0.591971, loss: 0.326522


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.579228, dice: 0.117820, loss: 0.730704
0m 30s
Epoch 7/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.252977, dice: 0.591211, loss: 0.330883


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.370520, dice: 0.561158, loss: 0.404681
saving best loss
saving best DICE
0m 32s
Epoch 8/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.240085, dice: 0.611396, loss: 0.314345


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.363863, dice: 0.553360, loss: 0.405251
0m 31s
Epoch 9/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.228259, dice: 0.627903, loss: 0.300178


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.567682, dice: 0.585718, loss: 0.490982
saving best DICE
0m 32s
Epoch 10/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.232966, dice: 0.627163, loss: 0.302901


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.301727, dice: 0.609400, loss: 0.346163
saving best loss
saving best DICE
0m 32s
Epoch 11/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.212047, dice: 0.642961, loss: 0.284543


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.310539, dice: 0.562102, loss: 0.374219
0m 31s
Epoch 12/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.207422, dice: 0.659521, loss: 0.273951


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.316100, dice: 0.642614, loss: 0.336743
saving best loss
saving best DICE
0m 33s
Epoch 13/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.203080, dice: 0.677778, loss: 0.262651


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.271729, dice: 0.610879, loss: 0.330425
saving best loss
0m 34s
Epoch 14/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.191084, dice: 0.685661, loss: 0.252711


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.231152, dice: 0.591038, loss: 0.320057
saving best loss
0m 32s
Epoch 15/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.173122, dice: 0.701777, loss: 0.235673


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.218086, dice: 0.622931, loss: 0.297578
saving best loss
0m 32s
Epoch 16/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.169254, dice: 0.709425, loss: 0.229914


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.227510, dice: 0.652558, loss: 0.287476
saving best loss
saving best DICE
0m 31s
Epoch 17/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.164552, dice: 0.719361, loss: 0.222595


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.218007, dice: 0.649335, loss: 0.284336
saving best loss
0m 32s
Epoch 18/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.162837, dice: 0.728736, loss: 0.217051


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.226822, dice: 0.631245, loss: 0.297788
0m 33s
Epoch 19/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.159096, dice: 0.739066, loss: 0.210015


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.214018, dice: 0.639278, loss: 0.287370
0m 31s
Epoch 20/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.160220, dice: 0.732792, loss: 0.213714


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.236713, dice: 0.607125, loss: 0.314794
0m 33s
Epoch 21/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.154933, dice: 0.737109, loss: 0.208912


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.220816, dice: 0.662088, loss: 0.279364
saving best loss
saving best DICE
0m 33s
Epoch 22/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.155798, dice: 0.740891, loss: 0.207453


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.242309, dice: 0.683276, loss: 0.279517
saving best DICE
0m 31s
Epoch 23/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.153555, dice: 0.746177, loss: 0.203689


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.226097, dice: 0.664907, loss: 0.280595
0m 31s
Epoch 24/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.158184, dice: 0.743497, loss: 0.207343


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.220623, dice: 0.667232, loss: 0.276695
saving best loss
0m 32s
Epoch 25/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.164389, dice: 0.739419, loss: 0.212485


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.205259, dice: 0.666484, loss: 0.269387
saving best loss
0m 32s
Epoch 26/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.152068, dice: 0.748416, loss: 0.201826


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.219603, dice: 0.675686, loss: 0.271959
0m 32s
Epoch 27/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.158331, dice: 0.744924, loss: 0.206704


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.234490, dice: 0.639358, loss: 0.297566
0m 31s
Epoch 28/29
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.158924, dice: 0.743896, loss: 0.207514


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.222961, dice: 0.681163, loss: 0.270899
0m 33s
Epoch 29/29
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.153002, dice: 0.750174, loss: 0.201414


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.202505, dice: 0.661931, loss: 0.270287
0m 31s

Best val loss: 0.269387
Best val DICE: 0.683276
