In [1]:
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, random_split
import torch
from torch import optim
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import logging
from evaluate import evaluate

import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import OrderedDict

from model.ensemblenet_model import EnsembleNet


from utils.dice_score import dice_loss
from utils.data_load import KittiDataset
from torchsummaryX import summary

In [2]:
Val_Percent = 0.3
Scale_Percent = 1.0
Batch_Size = 8
learning_rate = 0.0001
Pin_Memory = False
epochs = 50


Image_Size = [384, 1216]
Gradient_Clipping = 0.8


Num_Class = 2
Num_Channel = 3
amp = True

Model_Name = 'ensemble_fusion'


Img_Path =  'data/data_road/training/image_2'
Mask_Path =  'data/data_road/training/semantic'

save_checkpoint = True
checkpoint_dir = '../trained' + '_' + Model_Name
batch_size = Batch_Size

In [3]:
dirImg = Path(Img_Path)
dirMask = Path(Mask_Path)
dir_checkpoint = Path(checkpoint_dir)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
datasets =  KittiDataset(dirImg, dirMask, Image_Size, Scale_Percent)
n_val = int(len(datasets) * Val_Percent)
n_train = len(datasets) - n_val
train_set, val_set = random_split(datasets, [n_train, n_val], generator=torch.Generator().manual_seed(0))

loader_args = dict(batch_size=Batch_Size, num_workers= os.cpu_count(), pin_memory=Pin_Memory)
train_loader = DataLoader(train_set, shuffle=True, drop_last = True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

100%|██████████| 289/289 [00:00<00:00, 803.40it/s]


In [6]:
model = EnsembleNet(Model_Name, Num_Channel, Num_Class)
model = model.to(memory_format=torch.channels_last, device = device)

In [7]:
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
#optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
#optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)

if 'ensemble_voting' in Model_Name:
    unet_optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    segnet_optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    enet_optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    optims = [unet_optimizer, segnet_optimizer, enet_optimizer]
    
    unet_scheduler = optim.lr_scheduler.ReduceLROnPlateau(unet_optimizer, 'max', patience=2)  # goal: maximize Dice score
    segnet_scheduler = optim.lr_scheduler.ReduceLROnPlateau(segnet_optimizer, 'max', patience=2)  # goal: maximize Dice score
    enet_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enet_optimizer, 'max', patience=2)  # goal: maximize Dice score
       
else:
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-10)
    #optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-10)
    optims = [optimizer]
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score

grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
global_step = 0

In [8]:
def calculate_loss(pred, true_masks, nclass, multiclass):
    loss = criterion(pred, true_masks)
    loss += dice_loss(
        F.softmax(pred, dim=1).float(),
        F.one_hot(true_masks, nclass).permute(0, 3, 1, 2).float(),
        multiclass=multiclass
    )
    return loss

def grad_forback(models, losses, optim):
    optim.zero_grad(set_to_none=True)
    grad_scaler.scale(losses).backward()
    torch.nn.utils.clip_grad_norm_(models.parameters(), Gradient_Clipping)
    grad_scaler.step(optim)
    grad_scaler.update()    

def forward_and_backward(model, images, true_masks, amp, optimizers, grad_scaler, model_name):
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        
        try:
            mn_cls = model.n_classes
        except:
            mn_cls = model.classifier[-1].out_channels

        if model_name == 'ensemble_voting':
            unet_pred, segnet_pred, enet_pred = model(images)
            #deeplab_pred = deeplab_pred['out']
            
            unet_loss = calculate_loss(unet_pred, true_masks, mn_cls, multiclass=True)
            segnet_loss = calculate_loss(segnet_pred, true_masks, mn_cls, multiclass=True)
            enet_loss = calculate_loss(enet_pred, true_masks, mn_cls, multiclass=True)
            
            
        else:
            masks_pred = model(images)
            if isinstance(masks_pred, OrderedDict):
                masks_pred = masks_pred['out']
            loss = calculate_loss(masks_pred, true_masks, mn_cls, multiclass=True)
    
    
    if model_name == 'ensemble_voting':
        for _loss, _optiz in zip([unet_loss, segnet_loss, enet_loss], optimizers):
            grad_forback(model, _loss, _optiz)

        return model, unet_loss, segnet_loss, enet_loss
    else:
        for _loss, _optiz in zip([loss], optimizers):
            grad_forback(model, _loss, _optiz)
            
        return model, loss


In [9]:
valScore_list1 = []
TrainLoss_list1 = []

valScore_list2 = []
TrainLoss_list2 = []

valScore_list3 = []
TrainLoss_list3 = []

valScore_list4 = []
TrainLoss_list4 = []

val_losses = []
val_accs = []
val_mious = []

# 5. Begin training
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    epoch_unet_loss = 0
    epoch_segnet_loss = 0
    epoch_enet_loss = 0
    epoch_voting_loss = 0
    
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)

                
            result = forward_and_backward(model, images, true_masks, amp, optims, grad_scaler, Model_Name)
            
            if len(result) == 4:
                model, unet_loss, segnet_loss, enet_loss = result
                
                pbar.update(images.shape[0])
                global_step += 1
                epoch_unet_loss += unet_loss.item()
                epoch_segnet_loss += segnet_loss.item()
                epoch_enet_loss += enet_loss.item()
                vot_loss = ((unet_loss.item() + segnet_loss.item() + enet_loss.item()) /3)
                epoch_voting_loss += vot_loss
                
                
            elif len(result) == 2:
                model, loss = result
                
                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()


        print('***')
        if len(result) == 4:
            print('Unet Loss: {}     Segnet Loss: {}     Enet Loss: {}'.format(unet_loss, segnet_loss, enet_loss))
            print('Voting Loss: {}'.format(vot_loss))
            
            
        elif len(result) == 2:
            print('{} Loss: {}'.format(Model_Name, loss))

        # Evaluation round
        division_step = (n_train // (5 * batch_size))
        if division_step > 0:
            #if global_step % division_step == 0:
            if len(result) == 4:
                unet_val_score, segnet_val_score, enet_val_score, voting_val_score, val_loss, val_acc, val_miou = evaluate(model, val_loader, criterion, device, Model_Name, amp)
                
                unet_scheduler.step(unet_val_score)
                segnet_scheduler.step(segnet_val_score)
                enet_scheduler.step(enet_val_score)
                #voting_scheduler.step(voting_val_score)
                
                valScore_list1.append(unet_val_score.cpu().detach().numpy())
                TrainLoss_list1.append(unet_loss.cpu().detach().numpy())
                valScore_list2.append(segnet_val_score.cpu().detach().numpy())
                TrainLoss_list2.append(segnet_loss.cpu().detach().numpy())                
                valScore_list3.append(enet_val_score.cpu().detach().numpy())
                TrainLoss_list3.append(enet_loss.cpu().detach().numpy())
                valScore_list4.append(voting_val_score.cpu().detach().numpy())
                TrainLoss_list4.append(vot_loss)
                
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                val_mious.append(val_miou)
                
                print('---')
                print('Unet Validation Dice Score: {}     Segnet Validation Dice Score: {}     Enet Validation Dice Score: {}'.format(unet_val_score, segnet_val_score, enet_val_score))
                print('---')
                print('Ensemble Voting Validation Dice Loss: {}'.format(val_loss))
                print('Ensemble Voting Validation Pixel Accuracy: {} '.format(val_acc))
                print('Ensemble Voting Validation MIoU: {}'.format(val_miou))                
                print('Ensemble Voting Validation Dice Score: {} '.format(voting_val_score))
                
            else:
                val_score, val_loss, val_acc, val_miou = evaluate(model, val_loader, criterion, device, Model_Name, amp)
                
                                
                scheduler.step(val_score)
                
                print('---')
                print('{} Validation Dice Loss: {}'.format(Model_Name, val_loss))   
                print('{} Validation Pixel Accuracy: {}'.format(Model_Name, val_acc))
                print('{} Validation MIoU: {}'.format(Model_Name, val_miou))
                print('{} Validation Dice Score: {}'.format(Model_Name, val_score))
                
            
                valScore_list1.append(val_score.cpu().detach().numpy())
                TrainLoss_list1.append(loss.cpu().detach().numpy())
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                val_mious.append(val_miou)

        
    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))

Epoch 1/50:  99%|█████████▊| 200/203 [00:18<00:00, 15.08img/s]

***
ensemble_fusion Loss: 1.242800235748291


Epoch 1/50:  99%|█████████▊| 200/203 [00:24<00:00,  8.13img/s]

---
ensemble_fusion Validation Dice Loss: 1.2662577629089355
ensemble_fusion Validation Pixel Accuracy: 0.14941379480194628
ensemble_fusion Validation MIoU: 0.07470689740097317
ensemble_fusion Validation Dice Score: 0.28675708174705505



Epoch 2/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.08img/s]

***
ensemble_fusion Loss: 1.1503007411956787


Epoch 2/50:  99%|█████████▊| 200/203 [00:19<00:00, 10.07img/s]

---
ensemble_fusion Validation Dice Loss: 1.2080085277557373
ensemble_fusion Validation Pixel Accuracy: 0.17911034299616227
ensemble_fusion Validation MIoU: 0.09445139948828285
ensemble_fusion Validation Dice Score: 0.29405632615089417



Epoch 3/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.21img/s]

***
ensemble_fusion Loss: 1.1108816862106323


Epoch 3/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.88img/s]

---
ensemble_fusion Validation Dice Loss: 1.1318280696868896
ensemble_fusion Validation Pixel Accuracy: 0.6377970377604166
ensemble_fusion Validation MIoU: 0.4318198014988708
ensemble_fusion Validation Dice Score: 0.4702892303466797



Epoch 4/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.04img/s]

***
ensemble_fusion Loss: 1.098218321800232


Epoch 4/50:  99%|█████████▊| 200/203 [00:19<00:00, 10.04img/s]

---
ensemble_fusion Validation Dice Loss: 1.1302821636199951
ensemble_fusion Validation Pixel Accuracy: 0.7855406644051535
ensemble_fusion Validation MIoU: 0.5718720425792584
ensemble_fusion Validation Dice Score: 0.6074281334877014



Epoch 5/50:  99%|█████████▊| 200/203 [00:16<00:00, 14.98img/s]

***
ensemble_fusion Loss: 1.0703864097595215


Epoch 5/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.66img/s]

---
ensemble_fusion Validation Dice Loss: 1.1078870296478271
ensemble_fusion Validation Pixel Accuracy: 0.8428588331791392
ensemble_fusion Validation MIoU: 0.6373437836623443
ensemble_fusion Validation Dice Score: 0.6758133769035339



Epoch 6/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.00img/s]

***
ensemble_fusion Loss: 1.0235888957977295


Epoch 6/50:  99%|█████████▊| 200/203 [00:19<00:00, 10.03img/s]

---
ensemble_fusion Validation Dice Loss: 1.105132818222046
ensemble_fusion Validation Pixel Accuracy: 0.8817706191748903
ensemble_fusion Validation MIoU: 0.7061471826424262
ensemble_fusion Validation Dice Score: 0.7442137002944946



Epoch 7/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.88img/s]

***
ensemble_fusion Loss: 1.0405449867248535


Epoch 7/50:  99%|█████████▊| 200/203 [00:19<00:00, 10.14img/s]

---
ensemble_fusion Validation Dice Loss: 1.0900901556015015
ensemble_fusion Validation Pixel Accuracy: 0.914169846919545
ensemble_fusion Validation MIoU: 0.7601020101837672
ensemble_fusion Validation Dice Score: 0.7870831489562988



Epoch 8/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.96img/s]

***
ensemble_fusion Loss: 1.0629549026489258


Epoch 8/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.95img/s]

---
ensemble_fusion Validation Dice Loss: 1.0922679901123047
ensemble_fusion Validation Pixel Accuracy: 0.9037090100740132
ensemble_fusion Validation MIoU: 0.7418245636538219
ensemble_fusion Validation Dice Score: 0.7903817296028137



Epoch 9/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.86img/s]

***
ensemble_fusion Loss: 0.9818758964538574


Epoch 9/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.89img/s]

---
ensemble_fusion Validation Dice Loss: 1.1009740829467773
ensemble_fusion Validation Pixel Accuracy: 0.9039986593681469
ensemble_fusion Validation MIoU: 0.7428902168923223
ensemble_fusion Validation Dice Score: 0.7814001441001892



Epoch 10/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.94img/s]

***
ensemble_fusion Loss: 0.9982020258903503


Epoch 10/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.85img/s]

---
ensemble_fusion Validation Dice Loss: 1.0657620429992676
ensemble_fusion Validation Pixel Accuracy: 0.9226173267029879
ensemble_fusion Validation MIoU: 0.7767294497540567
ensemble_fusion Validation Dice Score: 0.8417136073112488



Epoch 11/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.73img/s]

***
ensemble_fusion Loss: 0.9862491488456726


Epoch 11/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.83img/s]

---
ensemble_fusion Validation Dice Loss: 1.048461675643921
ensemble_fusion Validation Pixel Accuracy: 0.9400658858449835
ensemble_fusion Validation MIoU: 0.8075801691550514
ensemble_fusion Validation Dice Score: 0.8600160479545593



Epoch 12/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.96img/s]

***
ensemble_fusion Loss: 0.9910356402397156


Epoch 12/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.96img/s]

---
ensemble_fusion Validation Dice Loss: 1.0472723245620728
ensemble_fusion Validation Pixel Accuracy: 0.9339966355708608
ensemble_fusion Validation MIoU: 0.800845286754029
ensemble_fusion Validation Dice Score: 0.8554482460021973



Epoch 13/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.97img/s]

***
ensemble_fusion Loss: 0.9846231341362


Epoch 13/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.91img/s]

---
ensemble_fusion Validation Dice Loss: 1.0191850662231445
ensemble_fusion Validation Pixel Accuracy: 0.9497206838507402
ensemble_fusion Validation MIoU: 0.8359929568382598
ensemble_fusion Validation Dice Score: 0.8685641288757324



Epoch 14/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.95img/s]

***
ensemble_fusion Loss: 0.9916022419929504


Epoch 14/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.80img/s]

---
ensemble_fusion Validation Dice Loss: 1.036548376083374
ensemble_fusion Validation Pixel Accuracy: 0.9338601095634594
ensemble_fusion Validation MIoU: 0.801796139541394
ensemble_fusion Validation Dice Score: 0.845578670501709



Epoch 15/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.77img/s]

***
ensemble_fusion Loss: 0.9613921642303467


Epoch 15/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.68img/s]

---
ensemble_fusion Validation Dice Loss: 1.029972791671753
ensemble_fusion Validation Pixel Accuracy: 0.9290254827131305
ensemble_fusion Validation MIoU: 0.7925269692677145
ensemble_fusion Validation Dice Score: 0.847538948059082



Epoch 16/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.04img/s]

***
ensemble_fusion Loss: 0.9258512854576111


Epoch 16/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.89img/s]

---
ensemble_fusion Validation Dice Loss: 0.9937180280685425
ensemble_fusion Validation Pixel Accuracy: 0.953762924462034
ensemble_fusion Validation MIoU: 0.8489497678719743
ensemble_fusion Validation Dice Score: 0.8699674606323242



Epoch 17/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.76img/s]

***
ensemble_fusion Loss: 0.9297443628311157


Epoch 17/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.92img/s]

---
ensemble_fusion Validation Dice Loss: 0.9872747659683228
ensemble_fusion Validation Pixel Accuracy: 0.9523735715631854
ensemble_fusion Validation MIoU: 0.841588829624526
ensemble_fusion Validation Dice Score: 0.8861233592033386



Epoch 18/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.16img/s]

***
ensemble_fusion Loss: 0.9391117095947266


Epoch 18/50:  99%|█████████▊| 200/203 [00:19<00:00, 10.06img/s]

---
ensemble_fusion Validation Dice Loss: 0.9854241013526917
ensemble_fusion Validation Pixel Accuracy: 0.940704345703125
ensemble_fusion Validation MIoU: 0.8153433687217855
ensemble_fusion Validation Dice Score: 0.8701228499412537



Epoch 19/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.82img/s]

***
ensemble_fusion Loss: 0.9403440952301025


Epoch 19/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.75img/s]

---
ensemble_fusion Validation Dice Loss: 0.9888077974319458
ensemble_fusion Validation Pixel Accuracy: 0.9375642475328947
ensemble_fusion Validation MIoU: 0.8099403647347093
ensemble_fusion Validation Dice Score: 0.8609020113945007



Epoch 20/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.82img/s]

***
ensemble_fusion Loss: 0.9083320498466492


Epoch 20/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.9614239931106567
ensemble_fusion Validation Pixel Accuracy: 0.9435727303488213
ensemble_fusion Validation MIoU: 0.8139241914053544
ensemble_fusion Validation Dice Score: 0.8620874285697937



Epoch 21/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.84img/s]

***
ensemble_fusion Loss: 0.8830013275146484


Epoch 21/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.9584707021713257
ensemble_fusion Validation Pixel Accuracy: 0.9490974827816612
ensemble_fusion Validation MIoU: 0.8361191579953184
ensemble_fusion Validation Dice Score: 0.8839119076728821



Epoch 22/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.84img/s]

***
ensemble_fusion Loss: 0.8560265302658081


Epoch 22/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.87img/s]

---
ensemble_fusion Validation Dice Loss: 0.9514617323875427
ensemble_fusion Validation Pixel Accuracy: 0.9533142625239858
ensemble_fusion Validation MIoU: 0.8462266130180747
ensemble_fusion Validation Dice Score: 0.8917251825332642



Epoch 23/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.96img/s]

***
ensemble_fusion Loss: 0.883250892162323


Epoch 23/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.83img/s]

---
ensemble_fusion Validation Dice Loss: 0.9458096623420715
ensemble_fusion Validation Pixel Accuracy: 0.9550157513534814
ensemble_fusion Validation MIoU: 0.8506339519833606
ensemble_fusion Validation Dice Score: 0.8934969902038574



Epoch 24/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.92img/s]

***
ensemble_fusion Loss: 0.8901804089546204


Epoch 24/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.94img/s]

---
ensemble_fusion Validation Dice Loss: 0.9497979879379272
ensemble_fusion Validation Pixel Accuracy: 0.9538322582579496
ensemble_fusion Validation MIoU: 0.8487833116033401
ensemble_fusion Validation Dice Score: 0.8924872279167175



Epoch 25/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.95img/s]

***
ensemble_fusion Loss: 0.8808360695838928


Epoch 25/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.77img/s]

---
ensemble_fusion Validation Dice Loss: 0.9448935389518738
ensemble_fusion Validation Pixel Accuracy: 0.9563130161218476
ensemble_fusion Validation MIoU: 0.8544342100482362
ensemble_fusion Validation Dice Score: 0.8969358801841736



Epoch 26/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.90img/s]

***
ensemble_fusion Loss: 0.8577100038528442


Epoch 26/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.84img/s]

---
ensemble_fusion Validation Dice Loss: 0.9425569176673889
ensemble_fusion Validation Pixel Accuracy: 0.9569924337822094
ensemble_fusion Validation MIoU: 0.856691025516563
ensemble_fusion Validation Dice Score: 0.8969367146492004



Epoch 27/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.73img/s]

***
ensemble_fusion Loss: 0.8704963326454163


Epoch 27/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.77img/s]

---
ensemble_fusion Validation Dice Loss: 0.9398773908615112
ensemble_fusion Validation Pixel Accuracy: 0.9574477881716009
ensemble_fusion Validation MIoU: 0.8562843851837075
ensemble_fusion Validation Dice Score: 0.9012504816055298



Epoch 28/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.88img/s]

***
ensemble_fusion Loss: 0.8786061406135559


Epoch 28/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.88img/s]

---
ensemble_fusion Validation Dice Loss: 0.9379317164421082
ensemble_fusion Validation Pixel Accuracy: 0.9583491275185033
ensemble_fusion Validation MIoU: 0.8598538462155996
ensemble_fusion Validation Dice Score: 0.9007384181022644



Epoch 29/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.81img/s]

***
ensemble_fusion Loss: 0.9122899174690247


Epoch 29/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.88img/s]

---
ensemble_fusion Validation Dice Loss: 0.9470440149307251
ensemble_fusion Validation Pixel Accuracy: 0.9534700627912555
ensemble_fusion Validation MIoU: 0.849274426425251
ensemble_fusion Validation Dice Score: 0.8916950225830078



Epoch 30/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.92img/s]

***
ensemble_fusion Loss: 0.8891775608062744


Epoch 30/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.78img/s]

---
ensemble_fusion Validation Dice Loss: 0.9385699033737183
ensemble_fusion Validation Pixel Accuracy: 0.9569019518400493
ensemble_fusion Validation MIoU: 0.8536007123828091
ensemble_fusion Validation Dice Score: 0.9008553624153137



Epoch 31/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.81img/s]

***
ensemble_fusion Loss: 0.8890393376350403


Epoch 31/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.9416272640228271
ensemble_fusion Validation Pixel Accuracy: 0.9562078107867324
ensemble_fusion Validation MIoU: 0.8536429508125728
ensemble_fusion Validation Dice Score: 0.9001812934875488



Epoch 32/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.78img/s]

***
ensemble_fusion Loss: 0.8562291860580444


Epoch 32/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.78img/s]

---
ensemble_fusion Validation Dice Loss: 0.939180850982666
ensemble_fusion Validation Pixel Accuracy: 0.957035265470806
ensemble_fusion Validation MIoU: 0.8560053846242328
ensemble_fusion Validation Dice Score: 0.9013611078262329



Epoch 33/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.87img/s]

***
ensemble_fusion Loss: 0.8644729852676392


Epoch 33/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.89img/s]

---
ensemble_fusion Validation Dice Loss: 0.9432032704353333
ensemble_fusion Validation Pixel Accuracy: 0.9566685191371984
ensemble_fusion Validation MIoU: 0.8552359053309373
ensemble_fusion Validation Dice Score: 0.9007154703140259



Epoch 34/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.02img/s]

***
ensemble_fusion Loss: 0.897464394569397


Epoch 34/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.83img/s]

---
ensemble_fusion Validation Dice Loss: 0.9376577138900757
ensemble_fusion Validation Pixel Accuracy: 0.9573883592036733
ensemble_fusion Validation MIoU: 0.8568022273238931
ensemble_fusion Validation Dice Score: 0.9023918509483337



Epoch 35/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.91img/s]

***
ensemble_fusion Loss: 0.8684409856796265


Epoch 35/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.79img/s]

---
ensemble_fusion Validation Dice Loss: 0.9400953054428101
ensemble_fusion Validation Pixel Accuracy: 0.9580142372532895
ensemble_fusion Validation MIoU: 0.8587074422098964
ensemble_fusion Validation Dice Score: 0.9026317000389099



Epoch 36/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.84img/s]

***
ensemble_fusion Loss: 0.8737529516220093


Epoch 36/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.81img/s]

---
ensemble_fusion Validation Dice Loss: 0.9401638507843018
ensemble_fusion Validation Pixel Accuracy: 0.9568722373560855
ensemble_fusion Validation MIoU: 0.8556310772817057
ensemble_fusion Validation Dice Score: 0.9019264578819275



Epoch 37/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.01img/s]

***
ensemble_fusion Loss: 0.8532878160476685


Epoch 37/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.85img/s]

---
ensemble_fusion Validation Dice Loss: 0.9396132230758667
ensemble_fusion Validation Pixel Accuracy: 0.9575149803830866
ensemble_fusion Validation MIoU: 0.8573263468633685
ensemble_fusion Validation Dice Score: 0.902149498462677



Epoch 38/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.80img/s]

***
ensemble_fusion Loss: 0.8603276610374451


Epoch 38/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.84img/s]

---
ensemble_fusion Validation Dice Loss: 0.9410092830657959
ensemble_fusion Validation Pixel Accuracy: 0.9580142372532895
ensemble_fusion Validation MIoU: 0.8585970302732769
ensemble_fusion Validation Dice Score: 0.903204083442688



Epoch 39/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.80img/s]

***
ensemble_fusion Loss: 0.8650626540184021


Epoch 39/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.78img/s]

---
ensemble_fusion Validation Dice Loss: 0.939513623714447
ensemble_fusion Validation Pixel Accuracy: 0.9569587038274396
ensemble_fusion Validation MIoU: 0.8560934350362381
ensemble_fusion Validation Dice Score: 0.9016685485839844



Epoch 40/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.81img/s]

***
ensemble_fusion Loss: 0.8501377105712891


Epoch 40/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.9385477304458618
ensemble_fusion Validation Pixel Accuracy: 0.9573484721936678
ensemble_fusion Validation MIoU: 0.8563091166545738
ensemble_fusion Validation Dice Score: 0.9030628204345703



Epoch 41/50:  99%|█████████▊| 200/203 [00:16<00:00, 14.98img/s]

***
ensemble_fusion Loss: 0.8582183122634888


Epoch 41/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.81img/s]

---
ensemble_fusion Validation Dice Loss: 0.9390400648117065
ensemble_fusion Validation Pixel Accuracy: 0.9576019822505483
ensemble_fusion Validation MIoU: 0.8579672098107038
ensemble_fusion Validation Dice Score: 0.9018060564994812



Epoch 42/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.80img/s]

***
ensemble_fusion Loss: 0.8719002604484558


Epoch 42/50:  99%|█████████▊| 200/203 [00:20<00:00, 10.00img/s]

---
ensemble_fusion Validation Dice Loss: 0.9396172165870667
ensemble_fusion Validation Pixel Accuracy: 0.9573578416255483
ensemble_fusion Validation MIoU: 0.8573070335088766
ensemble_fusion Validation Dice Score: 0.9015771150588989



Epoch 43/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.88img/s]

***
ensemble_fusion Loss: 0.8569780588150024


Epoch 43/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.93img/s]

---
ensemble_fusion Validation Dice Loss: 0.9395996332168579
ensemble_fusion Validation Pixel Accuracy: 0.9583038865474233
ensemble_fusion Validation MIoU: 0.8594388977839178
ensemble_fusion Validation Dice Score: 0.9034644961357117



Epoch 44/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.95img/s]

***
ensemble_fusion Loss: 0.8526350259780884


Epoch 44/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.9418407082557678
ensemble_fusion Validation Pixel Accuracy: 0.9572124815823739
ensemble_fusion Validation MIoU: 0.8571335972504259
ensemble_fusion Validation Dice Score: 0.9011543393135071



Epoch 45/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.94img/s]

***
ensemble_fusion Loss: 0.8532992601394653


Epoch 45/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.91img/s]

---
ensemble_fusion Validation Dice Loss: 0.9377349019050598
ensemble_fusion Validation Pixel Accuracy: 0.9580391331722862
ensemble_fusion Validation MIoU: 0.8588908183633689
ensemble_fusion Validation Dice Score: 0.9032840132713318



Epoch 46/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.06img/s]

***
ensemble_fusion Loss: 0.8780179619789124


Epoch 46/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.74img/s]

---
ensemble_fusion Validation Dice Loss: 0.9395672082901001
ensemble_fusion Validation Pixel Accuracy: 0.9577682227419134
ensemble_fusion Validation MIoU: 0.8580798668963154
ensemble_fusion Validation Dice Score: 0.9027735590934753



Epoch 47/50:  99%|█████████▊| 200/203 [00:15<00:00, 15.01img/s]

***
ensemble_fusion Loss: 0.8877195715904236


Epoch 47/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.9385584592819214
ensemble_fusion Validation Pixel Accuracy: 0.9569825289542215
ensemble_fusion Validation MIoU: 0.8568480890029293
ensemble_fusion Validation Dice Score: 0.9001303911209106



Epoch 48/50:  99%|█████████▊| 200/203 [00:15<00:00, 14.94img/s]

***
ensemble_fusion Loss: 0.8496623635292053


Epoch 48/50:  99%|█████████▊| 200/203 [00:20<00:00,  9.80img/s]

---
ensemble_fusion Validation Dice Loss: 0.937703549861908
ensemble_fusion Validation Pixel Accuracy: 0.9578134637129935
ensemble_fusion Validation MIoU: 0.8581274413331605
ensemble_fusion Validation Dice Score: 0.9026322364807129



Epoch 49/50:   0%|          | 0/203 [00:01<?, ?img/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [None]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
df = pd.DataFrame([TrainLoss_list1, val_losses, valScore_list1, val_accs, val_mious]).T
df.columns = ['train_loss', 'val_loss', 'val_score', 'val_acc', 'val_miou']
df.to_csv(checkpoint_dir + '/model_check.csv', encoding = 'UTF-8')

In [None]:
plt.figure(figsize= (10,5))
plt.plot(TrainLoss_list1)
plt.plot(val_losses)

In [None]:
plt.figure(figsize= (10,5))
plt.plot(valScore_list1)
plt.plot(val_accs)
plt.plot(val_mious)