In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import faiss
import time
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import cv2
from attention_unet import Attention_block, conv_block, up_conv
import matplotlib.pyplot as plt
from dataloaders import Birds_OneCluster
import torchvision.utils

In [2]:
import pickle
with open("/home/e_radionova/DeepCluster/deepcluster/checkpts_Birds_1500epochs_30cls/clusters", "rb") as f:
    clusters = pickle.load(f)

In [3]:
epoch = 1000 #-1
num_cluster = 3 #11
cluster_idxs = clusters[epoch][num_cluster]

In [4]:
def find_names_in_fold(prefix):
    images_names = np.sort(os.listdir(prefix))
    list_names = np.sort(os.listdir(prefix / images_names[0])).tolist()
    for i, x in enumerate(list_names):
        list_names[i] = os.path.join(images_names[0],x)
    for i in images_names[1:]:
        list_names_onefold = np.sort(os.listdir(prefix / i)).tolist()
        for j, x in enumerate(list_names_onefold):
            list_names_onefold[j] = os.path.join(i, x)
        list_names.extend(list_names_onefold)
    return list_names

In [5]:
dataset_path = Path('/home/e_radionova/Datasets/Caltech_birds/CUB_200_2011/dataset')
train_imgs_path = dataset_path /  'train/images/'
train_masks_path = dataset_path /  'train/masks/'

In [6]:
train_masks_names = find_names_in_fold(train_masks_path)
train_imgs_names = find_names_in_fold(train_imgs_path)

In [7]:
lst_cluster_masks = [train_masks_names[i] for i in cluster_idxs]
lst_cluster_imgs = [train_imgs_names[i] for i in cluster_idxs]

In [8]:
from sklearn.model_selection import train_test_split
train_imgs, test_imgs, train_masks, test_masks = train_test_split(lst_cluster_imgs, lst_cluster_masks, 
                                                                  test_size=0.3, random_state=42)

In [9]:
trans = transforms.Compose([
    transforms.Resize((390, 390), interpolation=Image.NEAREST),
#     transforms.CenterCrop((400, 400)),
    transforms.ToTensor(),
])

train_set = Birds_OneCluster(img_names = train_imgs, 
                             mask_names = train_masks, 
                             images_folder = dataset_path / 'train/images', 
                             masks_folder = dataset_path / 'train/masks',
                             clusters=None,
                             cluster_num=11,
                             cluster_epoch=-1, 
                             img_transform = trans, 
                             masks_transform = trans)

val_set = Birds_OneCluster(img_names = test_imgs, 
                           mask_names = test_masks,
                           images_folder = dataset_path / 'train/images',
                           masks_folder = dataset_path / 'train/masks',
                           clusters=None,
                           cluster_num=11,
                           cluster_epoch=-1, 
                           img_transform = trans,
                           masks_transform = trans)

In [10]:
image_datasets = {
    'train': train_set , 'val': val_set
}

batch_size_train = 4
batch_size_val = batch_size_train

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size_train, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size_val, shuffle=True, num_workers=0)
}

dataset_sizes = {
    x: len(image_datasets[x]) for x in image_datasets.keys()
}

dataset_sizes

{'train': 219, 'val': 95}

In [11]:
# SEED = 42
# random.seed(SEED)
# torch.manual_seed(SEED)
# torch.cuda.manual_seed_all(SEED)
# np.random.seed(SEED)

# n_pics_to_show = batch_size_train
# inputs, masks = next(iter(dataloaders['val']))
# fig, ax = plt.subplots(n_pics_to_show, 2, figsize=(7, 15))
# for i in range(n_pics_to_show):
#     pic, label = inputs[i], masks[i] 
#     label_np = label.data.numpy().transpose(1, 2, 0) 
#     pic_np = pic.data.numpy().transpose(1, 2, 0) 
#     ax[i,0].imshow(pic_np)
#     ax[i,1].imshow(label_np)

In [12]:
from torchsummary import summary
import torch
import torch.nn as nn
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss, calc_loss, print_metrics
from tqdm.notebook import tqdm

In [13]:
logs_base_dir = Path('./logs')
logs_base_dir.mkdir(exist_ok=True)

In [14]:
activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

def plot_attention_map(attention_map, N=10, n_columns=5):
    normed_att_map = transforms.Normalize(0, 1)(attention_map).detach().cpu().numpy()
    plt.imshow(normed_att_map.transpose(1,2,0))
#     clear_output(wait=True)
    
    n_rows = N // n_columns + int(N // n_columns * n_columns < N)
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(15,6))
    for map_i in range(N):
        if N==1:
            plt.imshow(normed_att_map[map_i])
        else:
            row_index = map_i // n_columns
            column_index = map_i % n_columns
            axes[row_index, column_index].imshow(normed_att_map[map_i])
    plt.show()

In [15]:
from template_matching_funcs import thresholding, get_Fourier_coeffs_and_kernel

def get_feature_map(attention_map, index=9, threshold=True):
    normed_att_map = attention_map.detach().cpu().numpy()
    one_layer = normed_att_map.transpose(1,2,0) #[index]
    plt.imshow(one_layer)
    plt.title('Choosen map')
    plt.show()
    print('Unique els: ', np.unique(one_layer))
    if threshold:
        fm_thresh = thresholding(one_layer, 1.0)
        kern = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
        fm_thresh = cv2.morphologyEx(fm_thresh, cv2.MORPH_CLOSE, kern)
        fm_thresh = cv2.dilate(fm_thresh,kern,iterations = 1)
        fm_thresh = 255 * abs(fm_thresh / fm_thresh.max())
        plt.imshow(fm_thresh)
        plt.show()

In [16]:
def get_normalize_image(image, limit=1.):
    image = image-np.min(image)
    image = image / image.max()
    image = limit * image
    return image

In [17]:
def get_kernel_baseline(image, template, order, kernel_size,
                       morph_open, morph_close, dilate):
#     img_match = template_matching(image=image, template=template, method=cv2.TM_CCORR)
    img_thresh = thresholding(template, 1.0)
    
    kern = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))  # np.ones((5,5),np.uint8)
    if morph_open:
        img_thresh = cv2.morphologyEx(img_thresh, cv2.MORPH_OPEN, kern)
    if morph_close:
        img_thresh = cv2.morphologyEx(img_thresh, cv2.MORPH_CLOSE, kern)
    if dilate:
        img_thresh = cv2.dilate(img_thresh,kern,iterations = 1)
        
    _, kernel = get_Fourier_coeffs_and_kernel(img_thresh, order=order, kernel_size=kernel_size)

    return kernel, img_thresh

In [18]:
def plot_attention_map(attention_map, N=30, n_columns=5):
    normed_att_map = transforms.Normalize(0, 1)(attention_map)
    resize_map = normed_att_map.detach().cpu().numpy()
    
#     clear_output(wait=True)
    n_rows = N // n_columns + int(N // n_columns * n_columns < N)
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(40,20))
    for map_i in range(N):
        if N==1:
            plt.imshow(resize_map[map_i])
        else:
            row_index = map_i // n_columns
            column_index = map_i % n_columns
            axes[row_index, column_index].imshow(resize_map[map_i])
    plt.show()

In [19]:
import torchvision.transforms.functional as F
# from torchvision.transforms import InterpolationMode

In [20]:
def kernel_torch(att_map, kernel_size):
    uniq_els = torch.unique(att_map)
    mean = torch.mean(uniq_els)
    m = torch.nn.Threshold(mean, 0., inplace=False)
    thresh = m(att_map)
    height = att_map.shape[-2]
    width = att_map.shape[-1]
    kernel_resize = F.resize(thresh, kernel_size)
    return kernel_resize

In [21]:

def train_model(model, optimizer, scheduler, experiment_name, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    best_dice = 0
    
    writer = SummaryWriter(logs_base_dir / experiment_name)

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        since = time.time()
        i=0
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                    
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                
                inputs = inputs.to(device)
                labels = labels.to(device)    

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
#                     outputs_att = activation['Att4']
#                     if i ==10:
#                         plot_attention_map(activation['Att4'][0]) 
#                         print(activation['Att4'].shape)
#                     i+=1
                    #####----------------------
                    
#                     thresh_detach = thresh.permute(1,2,0).detach().cpu().numpy()
#                     outputs_detach = outputs_att.permute(1,2,0).detach().cpu().numpy()
#                     a = outputs_detach.shape[0]
#                     b = outputs_detach.shape[1]
#                     kernel_size = 20
#                     kernel_per_img, img_thresh = get_kernel_baseline(image=outputs_detach, 
#                                                                      template=thresh_detach, 
#                                                                      order=10, 
#                                                                      kernel_size=kernel_size,
#                                                                      morph_open=False, 
#                                                                      morph_close=False, 
#                                                                      dilate=True)

                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)
                

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples   
            writer.add_scalar(f'Loss_{phase}', epoch_loss.item(), global_step=epoch)
            
            dice_epoch = metrics['dice'] / epoch_samples
            writer.add_scalar(f'DICE_{phase}', dice_epoch.item(), global_step=epoch)
            
            imgs_to_tb = torchvision.utils.make_grid(inputs)
            writer.add_image('images', imgs_to_tb, global_step=epoch)
            
            outs_to_tb = torchvision.utils.make_grid(outputs)
            writer.add_image('images', outs_to_tb, global_step=epoch)
            
#             kernels = model.Conv_upd_x1.conv[0].weight.data.detach().cpu()
#             kernels = kernels[:10]
#             kernels_to_tb = torchvision.utils.make_grid(kernels, nrow=5, pad_value=255)
#             writer.add_image('kernels', kernels_to_tb, global_step=epoch)

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best loss")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
            
            if phase == 'val' and dice_epoch > best_dice:
                print("saving best DICE")
                best_dice = dice_epoch

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    print('Best val DICE: {:4f}'.format(best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [23]:
# import attention_unet
# model = attention_unet.AttU_Net(3, 3).to(device)
# import attention_kernel2conv1
# model = attention_kernel2conv1.AttU_Net(3, 3).to(device)
import NewNet
model = NewNet.UNet(3).to(device)

In [24]:
# model.Att4.register_forward_hook(get_activation('Att4'))

In [25]:
optimizer_ft = optim.Adam(model.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=15, gamma=0.1)

exp_name = f'NewNet_cluster3from1000_50epochs_b{batch_size_train}_' + datetime.now().isoformat(timespec='minutes') 

In [26]:
model = train_model(model, optimizer_ft, exp_lr_scheduler, experiment_name=exp_name, num_epochs=50) 

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))

Epoch 0/49
----------
LR 0.001




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))




train: bce: 0.566471, dice: 0.188483, loss: 0.688994


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.427761, dice: 0.189910, loss: 0.618925
saving best loss
saving best DICE
0m 20s
Epoch 1/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.432578, dice: 0.198449, loss: 0.617065


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.426587, dice: 0.209481, loss: 0.608553
saving best loss
saving best DICE
0m 20s
Epoch 2/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.393141, dice: 0.250529, loss: 0.571306


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.414381, dice: 0.269523, loss: 0.572429
saving best loss
saving best DICE
0m 20s
Epoch 3/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.391260, dice: 0.283779, loss: 0.553740


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.438511, dice: 0.339160, loss: 0.549676
saving best loss
saving best DICE
0m 20s
Epoch 4/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.372438, dice: 0.329256, loss: 0.521591


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.414429, dice: 0.348727, loss: 0.532851
saving best loss
saving best DICE
0m 20s
Epoch 5/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.349279, dice: 0.405170, loss: 0.472054


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.409089, dice: 0.421128, loss: 0.493980
saving best loss
saving best DICE
0m 20s
Epoch 6/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.328845, dice: 0.442158, loss: 0.443344


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.473575, dice: 0.464512, loss: 0.504532
saving best DICE
0m 20s
Epoch 7/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.323103, dice: 0.460365, loss: 0.431369


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.426444, dice: 0.434644, loss: 0.495900
0m 20s
Epoch 8/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.320272, dice: 0.455185, loss: 0.432543


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.329847, dice: 0.281723, loss: 0.524062
0m 20s
Epoch 9/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.324287, dice: 0.477636, loss: 0.423325


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.330504, dice: 0.436166, loss: 0.447169
saving best loss
0m 20s
Epoch 10/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.328665, dice: 0.468357, loss: 0.430154


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.363886, dice: 0.444266, loss: 0.459810
0m 20s
Epoch 11/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.324497, dice: 0.483996, loss: 0.420250


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.317765, dice: 0.456012, loss: 0.430876
saving best loss
0m 20s
Epoch 12/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.317703, dice: 0.481374, loss: 0.418165


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.944847, dice: 0.478397, loss: 0.733225
saving best DICE
0m 20s
Epoch 13/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.352225, dice: 0.431571, loss: 0.460327


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.405543, dice: 0.460196, loss: 0.472673
0m 20s
Epoch 14/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.297751, dice: 0.500960, loss: 0.398395


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.381161, dice: 0.496249, loss: 0.442456
saving best DICE
0m 20s
Epoch 15/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.285559, dice: 0.529480, loss: 0.378039


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.368985, dice: 0.488253, loss: 0.440366
0m 20s
Epoch 16/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.288044, dice: 0.531126, loss: 0.378459


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.377421, dice: 0.506955, loss: 0.435233
saving best DICE
0m 20s
Epoch 17/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.282471, dice: 0.538657, loss: 0.371907


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.355459, dice: 0.506689, loss: 0.424385
saving best loss
0m 20s
Epoch 18/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.279412, dice: 0.540367, loss: 0.369523


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.378457, dice: 0.521854, loss: 0.428302
saving best DICE
0m 20s
Epoch 19/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.280049, dice: 0.544629, loss: 0.367710


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.370512, dice: 0.520178, loss: 0.425167
0m 20s
Epoch 20/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.272976, dice: 0.549289, loss: 0.361844


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.342993, dice: 0.512874, loss: 0.415060
saving best loss
0m 20s
Epoch 21/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.271622, dice: 0.551494, loss: 0.360064


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.345938, dice: 0.521469, loss: 0.412235
saving best loss
0m 20s
Epoch 22/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.271754, dice: 0.555515, loss: 0.358119


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.354242, dice: 0.526625, loss: 0.413809
saving best DICE
0m 20s
Epoch 23/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.266978, dice: 0.554849, loss: 0.356064


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.373484, dice: 0.529049, loss: 0.422217
saving best DICE
0m 20s
Epoch 24/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.269626, dice: 0.559729, loss: 0.354948


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.404880, dice: 0.543111, loss: 0.430884
saving best DICE
0m 20s
Epoch 25/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.268544, dice: 0.561566, loss: 0.353489


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.324939, dice: 0.514111, loss: 0.405414
saving best loss
0m 20s
Epoch 26/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.270977, dice: 0.555225, loss: 0.357876


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.378726, dice: 0.545018, loss: 0.416854
saving best DICE
0m 20s
Epoch 27/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.267812, dice: 0.562437, loss: 0.352688


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.344711, dice: 0.533897, loss: 0.405407
saving best loss
0m 20s
Epoch 28/49
----------
LR 0.0001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.261282, dice: 0.569920, loss: 0.345681


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.332303, dice: 0.533233, loss: 0.399535
saving best loss
0m 20s
Epoch 29/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.256002, dice: 0.575342, loss: 0.340330


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.335901, dice: 0.537454, loss: 0.399223
saving best loss
0m 20s
Epoch 30/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251842, dice: 0.572631, loss: 0.339606


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.346419, dice: 0.541817, loss: 0.402301
0m 20s
Epoch 31/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.255234, dice: 0.576969, loss: 0.339132


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.337828, dice: 0.539186, loss: 0.399321
0m 20s
Epoch 32/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.254587, dice: 0.577233, loss: 0.338677


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.342503, dice: 0.541625, loss: 0.400439
0m 20s
Epoch 33/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.254979, dice: 0.579928, loss: 0.337526


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.336565, dice: 0.539277, loss: 0.398644
saving best loss
0m 20s
Epoch 34/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251593, dice: 0.576874, loss: 0.337360


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.339544, dice: 0.541426, loss: 0.399059
0m 20s
Epoch 35/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.254475, dice: 0.579988, loss: 0.337244


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.339482, dice: 0.541570, loss: 0.398956
0m 20s
Epoch 36/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251722, dice: 0.578550, loss: 0.336586


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.339379, dice: 0.541862, loss: 0.398758
0m 20s
Epoch 37/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.254680, dice: 0.581708, loss: 0.336486


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.335588, dice: 0.540410, loss: 0.397589
saving best loss
0m 20s
Epoch 38/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251716, dice: 0.579692, loss: 0.336012


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.339336, dice: 0.542931, loss: 0.398202
0m 20s
Epoch 39/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.252603, dice: 0.581235, loss: 0.335684


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.336486, dice: 0.541539, loss: 0.397474
saving best loss
0m 20s
Epoch 40/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251819, dice: 0.581401, loss: 0.335209


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.342530, dice: 0.544856, loss: 0.398837
0m 20s
Epoch 41/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.253733, dice: 0.583575, loss: 0.335079


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.337874, dice: 0.543381, loss: 0.397247
saving best loss
0m 20s
Epoch 42/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251503, dice: 0.581698, loss: 0.334902


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.335053, dice: 0.542361, loss: 0.396346
saving best loss
0m 20s
Epoch 43/49
----------
LR 1e-05


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.251300, dice: 0.581589, loss: 0.334855


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.336511, dice: 0.543034, loss: 0.396738
0m 20s
Epoch 44/49
----------
LR 1.0000000000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.249534, dice: 0.582514, loss: 0.333510


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.337801, dice: 0.543809, loss: 0.396996
0m 20s
Epoch 45/49
----------
LR 1.0000000000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.250005, dice: 0.583338, loss: 0.333333


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.337656, dice: 0.543833, loss: 0.396911
0m 20s
Epoch 46/49
----------
LR 1.0000000000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.250415, dice: 0.583782, loss: 0.333317


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.338366, dice: 0.544213, loss: 0.397077
0m 20s
Epoch 47/49
----------
LR 1.0000000000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.250036, dice: 0.583552, loss: 0.333242


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.337730, dice: 0.543995, loss: 0.396868
0m 20s
Epoch 48/49
----------
LR 1.0000000000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.250407, dice: 0.583933, loss: 0.333237


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.339091, dice: 0.544624, loss: 0.397234
0m 20s
Epoch 49/49
----------
LR 1.0000000000000002e-06


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=55.0), HTML(value='')))


train: bce: 0.250320, dice: 0.583921, loss: 0.333199


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))


val: bce: 0.338340, dice: 0.544326, loss: 0.397007
0m 20s

Best val loss: 0.396346
Best val DICE: 0.545018
