In [1]:
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, random_split
import torch
from torch import optim
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import logging
from evaluate import evaluate

import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import OrderedDict

from model.ensemblenet_model import EnsembleNet


from utils.dice_score import dice_loss
from utils.data_load import KittiDataset
from torchsummaryX import summary

In [2]:
Val_Percent = 0.3
Scale_Percent = 1.0
Batch_Size = 8
learning_rate = 0.0001
Pin_Memory = False
epochs = 50


Image_Size = [384, 1216]
Gradient_Clipping = 0.8


Num_Class = 2
Num_Channel = 3
amp = True

Model_Name = 'ensemble_fusion'


Img_Path =  'data/data_road/training/image_2'
Mask_Path =  'data/data_road/training/semantic'

save_checkpoint = True
checkpoint_dir = '../trained' + '_' + Model_Name
batch_size = Batch_Size

In [3]:
dirImg = Path(Img_Path)
dirMask = Path(Mask_Path)
dir_checkpoint = Path(checkpoint_dir)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
datasets =  KittiDataset(dirImg, dirMask, Image_Size, Scale_Percent)
n_val = int(len(datasets) * Val_Percent)
n_train = len(datasets) - n_val
train_set, val_set = random_split(datasets, [n_train, n_val], generator=torch.Generator().manual_seed(0))

loader_args = dict(batch_size=Batch_Size, num_workers= os.cpu_count(), pin_memory=Pin_Memory)
train_loader = DataLoader(train_set, shuffle=True, drop_last = True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

100%|██████████| 289/289 [00:00<00:00, 821.83it/s]


In [6]:
model = EnsembleNet(Model_Name, Num_Channel, Num_Class)
model = model.to(memory_format=torch.channels_last, device = device)

In [7]:
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
#optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
#optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)

if 'ensemble_voting' in Model_Name:
    unet_optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    segnet_optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    enet_optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
    optims = [unet_optimizer, segnet_optimizer, enet_optimizer]
    
    unet_scheduler = optim.lr_scheduler.ReduceLROnPlateau(unet_optimizer, 'max', patience=2)  # goal: maximize Dice score
    segnet_scheduler = optim.lr_scheduler.ReduceLROnPlateau(segnet_optimizer, 'max', patience=2)  # goal: maximize Dice score
    enet_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enet_optimizer, 'max', patience=2)  # goal: maximize Dice score
       
else:
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-10)
    #optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-10)
    optims = [optimizer]
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score

grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
global_step = 0

In [8]:
def calculate_loss(pred, true_masks, nclass, multiclass):
    loss = criterion(pred, true_masks)
    loss += dice_loss(
        F.softmax(pred, dim=1).float(),
        F.one_hot(true_masks, nclass).permute(0, 3, 1, 2).float(),
        multiclass=multiclass
    )
    return loss

def grad_forback(models, losses, optim):
    optim.zero_grad(set_to_none=True)
    grad_scaler.scale(losses).backward()
    torch.nn.utils.clip_grad_norm_(models.parameters(), Gradient_Clipping)
    grad_scaler.step(optim)
    grad_scaler.update()    

def forward_and_backward(model, images, true_masks, amp, optimizers, grad_scaler, model_name):
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        
        try:
            mn_cls = model.n_classes
        except:
            mn_cls = model.classifier[-1].out_channels

        if model_name == 'ensemble_voting':
            unet_pred, segnet_pred, enet_pred = model(images)
            #deeplab_pred = deeplab_pred['out']
            
            unet_loss = calculate_loss(unet_pred, true_masks, mn_cls, multiclass=True)
            segnet_loss = calculate_loss(segnet_pred, true_masks, mn_cls, multiclass=True)
            enet_loss = calculate_loss(enet_pred, true_masks, mn_cls, multiclass=True)
        
        else:
            masks_pred = model(images)
            if isinstance(masks_pred, OrderedDict):
                masks_pred = masks_pred['out']
            loss = calculate_loss(masks_pred, true_masks, mn_cls, multiclass=True)
            
    
    if model_name == 'ensemble_voting':
        for _loss, _optiz in zip([unet_loss, segnet_loss, enet_loss], optimizers):
            grad_forback(model, _loss, _optiz)

        return model, unet_loss, segnet_loss, enet_loss
    else:
        for _loss, _optiz in zip([loss], optimizers):
            grad_forback(model, _loss, _optiz)
            
        return model, loss


In [9]:
valScore_list1 = []
TrainLoss_list1 = []

valScore_list2 = []
TrainLoss_list2 = []

valScore_list3 = []
TrainLoss_list3 = []

valScore_list4 = []
TrainLoss_list4 = []

val_losses = []
val_accs = []
val_mious = []

# 5. Begin training
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    epoch_unet_loss = 0
    epoch_segnet_loss = 0
    epoch_enet_loss = 0
    epoch_voting_loss = 0
    
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)

                
            result = forward_and_backward(model, images, true_masks, amp, optims, grad_scaler, Model_Name)
            
            if len(result) == 4:
                model, unet_loss, segnet_loss, enet_loss = result
                
                pbar.update(images.shape[0])
                global_step += 1
                epoch_unet_loss += unet_loss.item()
                epoch_segnet_loss += segnet_loss.item()
                epoch_enet_loss += enet_loss.item()
                vot_loss = ((unet_loss.item() + segnet_loss.item() + enet_loss.item()) /3)
                epoch_voting_loss += vot_loss
                
                
            elif len(result) == 2:
                model, loss = result
                
                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()


        print('***')
        if len(result) == 4:
            print('Unet Loss: {}     Segnet Loss: {}     Enet Loss: {}'.format(unet_loss, segnet_loss, enet_loss))
            print('Voting Loss: {}'.format(vot_loss))
            
            
        elif len(result) == 2:
            print('{} Loss: {}'.format(Model_Name, loss))

        # Evaluation round
        division_step = (n_train // (5 * batch_size))
        if division_step > 0:
            #if global_step % division_step == 0:
            if len(result) == 4:
                unet_val_score, segnet_val_score, enet_val_score, voting_val_score, val_loss, val_acc, val_miou = evaluate(model, val_loader, criterion, device, Model_Name, amp)
                
                unet_scheduler.step(unet_val_score)
                segnet_scheduler.step(segnet_val_score)
                enet_scheduler.step(enet_val_score)
                #voting_scheduler.step(voting_val_score)
                
                valScore_list1.append(unet_val_score.cpu().detach().numpy())
                TrainLoss_list1.append(unet_loss.cpu().detach().numpy())
                valScore_list2.append(segnet_val_score.cpu().detach().numpy())
                TrainLoss_list2.append(segnet_loss.cpu().detach().numpy())                
                valScore_list3.append(enet_val_score.cpu().detach().numpy())
                TrainLoss_list3.append(enet_loss.cpu().detach().numpy())
                valScore_list4.append(voting_val_score.cpu().detach().numpy())
                TrainLoss_list4.append(vot_loss)
                
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                val_mious.append(val_miou)
                
                print('---')
                print('Unet Validation Dice Score: {}     Segnet Validation Dice Score: {}     Enet Validation Dice Score: {}'.format(unet_val_score, segnet_val_score, enet_val_score))
                print('---')
                print('Ensemble Voting Validation Dice Loss: {}'.format(val_loss))
                print('Ensemble Voting Validation Pixel Accuracy: {} '.format(val_acc))
                print('Ensemble Voting Validation MIoU: {}'.format(val_miou))                
                print('Ensemble Voting Validation Dice Score: {} '.format(voting_val_score))
                
            else:
                val_score, val_loss, val_acc, val_miou = evaluate(model, val_loader, criterion, device, Model_Name, amp)
                
                                
                scheduler.step(val_score)
                
                print('---')
                print('{} Validation Dice Loss: {}'.format(Model_Name, val_loss))   
                print('{} Validation Pixel Accuracy: {}'.format(Model_Name, val_acc))
                print('{} Validation MIoU: {}'.format(Model_Name, val_miou))
                print('{} Validation Dice Score: {}'.format(Model_Name, val_score))
                
            
                valScore_list1.append(val_score.cpu().detach().numpy())
                TrainLoss_list1.append(loss.cpu().detach().numpy())
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                val_mious.append(val_miou)

        
    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))

Epoch 1/50:  99%|█████████▊| 200/203 [00:10<00:00, 39.12img/s]

***
ensemble_fusion Loss: 0.7602043151855469


Epoch 1/50:  99%|█████████▊| 200/203 [00:14<00:00, 13.67img/s]

---
ensemble_fusion Validation Dice Loss: 1.5698583126068115
ensemble_fusion Validation Pixel Accuracy: 0.1499060915227522
ensemble_fusion Validation MIoU: 0.07503425123185858
ensemble_fusion Validation Dice Score: 0.28688129782676697



Epoch 2/50:  99%|█████████▊| 200/203 [00:07<00:00, 39.13img/s]

***
ensemble_fusion Loss: 0.5481064915657043


Epoch 2/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.49img/s]

---
ensemble_fusion Validation Dice Loss: 0.9922606348991394
ensemble_fusion Validation Pixel Accuracy: 0.8306459125719572
ensemble_fusion Validation MIoU: 0.6265882538154197
ensemble_fusion Validation Dice Score: 0.6550341248512268



Epoch 3/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.42img/s]

***
ensemble_fusion Loss: 0.6199971437454224


Epoch 3/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.78img/s]

---
ensemble_fusion Validation Dice Loss: 0.7172279357910156
ensemble_fusion Validation Pixel Accuracy: 0.8965652198122259
ensemble_fusion Validation MIoU: 0.7181075056523072
ensemble_fusion Validation Dice Score: 0.7234596610069275



Epoch 4/50:  99%|█████████▊| 200/203 [00:07<00:00, 39.05img/s]

***
ensemble_fusion Loss: 0.42933768033981323


Epoch 4/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.52img/s]

---
ensemble_fusion Validation Dice Loss: 0.5526084899902344
ensemble_fusion Validation Pixel Accuracy: 0.9210095321922972
ensemble_fusion Validation MIoU: 0.7626234752126163
ensemble_fusion Validation Dice Score: 0.7505028247833252



Epoch 5/50:  99%|█████████▊| 200/203 [00:07<00:00, 39.20img/s]

***
ensemble_fusion Loss: 0.47486403584480286


Epoch 5/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.82img/s]

---
ensemble_fusion Validation Dice Loss: 0.5072843432426453
ensemble_fusion Validation Pixel Accuracy: 0.9184714869449013
ensemble_fusion Validation MIoU: 0.7467993835737035
ensemble_fusion Validation Dice Score: 0.7589864730834961



Epoch 6/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.99img/s]

***
ensemble_fusion Loss: 0.4251369535923004


Epoch 6/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.41img/s]

---
ensemble_fusion Validation Dice Loss: 0.48111119866371155
ensemble_fusion Validation Pixel Accuracy: 0.9225078381990132
ensemble_fusion Validation MIoU: 0.7585328033305672
ensemble_fusion Validation Dice Score: 0.7968313097953796



Epoch 7/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.88img/s]

***
ensemble_fusion Loss: 0.38579732179641724


Epoch 7/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.71img/s]

---
ensemble_fusion Validation Dice Loss: 0.5662614107131958
ensemble_fusion Validation Pixel Accuracy: 0.9151648805852521
ensemble_fusion Validation MIoU: 0.7615362100179089
ensemble_fusion Validation Dice Score: 0.7626034617424011



Epoch 8/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.65img/s]

***
ensemble_fusion Loss: 0.43186938762664795


Epoch 8/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.99img/s]

---
ensemble_fusion Validation Dice Loss: 0.46159008145332336
ensemble_fusion Validation Pixel Accuracy: 0.91200309887267
ensemble_fusion Validation MIoU: 0.7131100863399817
ensemble_fusion Validation Dice Score: 0.6637545824050903



Epoch 9/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.69img/s]

***
ensemble_fusion Loss: 0.34398892521858215


Epoch 9/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.78img/s]

---
ensemble_fusion Validation Dice Loss: 0.458099901676178
ensemble_fusion Validation Pixel Accuracy: 0.9257279780873081
ensemble_fusion Validation MIoU: 0.7747561287323692
ensemble_fusion Validation Dice Score: 0.7855696082115173



Epoch 10/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.74img/s]

***
ensemble_fusion Loss: 0.2948160171508789


Epoch 10/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.24img/s]

---
ensemble_fusion Validation Dice Loss: 0.3713113069534302
ensemble_fusion Validation Pixel Accuracy: 0.9381917317708334
ensemble_fusion Validation MIoU: 0.8010610698225016
ensemble_fusion Validation Dice Score: 0.8396633267402649



Epoch 11/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.54img/s]

***
ensemble_fusion Loss: 0.2976301908493042


Epoch 11/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.54img/s]

---
ensemble_fusion Validation Dice Loss: 0.35588181018829346
ensemble_fusion Validation Pixel Accuracy: 0.9432249905770285
ensemble_fusion Validation MIoU: 0.8140148678803532
ensemble_fusion Validation Dice Score: 0.8473102450370789



Epoch 12/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.84img/s]

***
ensemble_fusion Loss: 0.3143421411514282


Epoch 12/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.24img/s]

---
ensemble_fusion Validation Dice Loss: 0.3327033519744873
ensemble_fusion Validation Pixel Accuracy: 0.949877287212171
ensemble_fusion Validation MIoU: 0.8301101970024788
ensemble_fusion Validation Dice Score: 0.8483392000198364



Epoch 13/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.71img/s]

***
ensemble_fusion Loss: 0.2943533658981323


Epoch 13/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.48img/s]

---
ensemble_fusion Validation Dice Loss: 0.34705978631973267
ensemble_fusion Validation Pixel Accuracy: 0.9466196695963541
ensemble_fusion Validation MIoU: 0.8242462366165337
ensemble_fusion Validation Dice Score: 0.8597230315208435



Epoch 14/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.76img/s]

***
ensemble_fusion Loss: 0.33583757281303406


Epoch 14/50:  99%|█████████▊| 200/203 [00:11<00:00, 17.97img/s]

---
ensemble_fusion Validation Dice Loss: 0.32044750452041626
ensemble_fusion Validation Pixel Accuracy: 0.9533113178453947
ensemble_fusion Validation MIoU: 0.8398328441747032
ensemble_fusion Validation Dice Score: 0.8552243113517761



Epoch 15/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.69img/s]

***
ensemble_fusion Loss: 0.2755896747112274


Epoch 15/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.44img/s]

---
ensemble_fusion Validation Dice Loss: 0.3402639627456665
ensemble_fusion Validation Pixel Accuracy: 0.9480539957682291
ensemble_fusion Validation MIoU: 0.8266051792108802
ensemble_fusion Validation Dice Score: 0.8637052774429321



Epoch 16/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.67img/s]

***
ensemble_fusion Loss: 0.32273250818252563


Epoch 16/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.52img/s]

---
ensemble_fusion Validation Dice Loss: 0.32955870032310486
ensemble_fusion Validation Pixel Accuracy: 0.9497487921463815
ensemble_fusion Validation MIoU: 0.8324214848989885
ensemble_fusion Validation Dice Score: 0.859879195690155



Epoch 17/50:  99%|█████████▊| 200/203 [00:07<00:00, 38.44img/s]

***
ensemble_fusion Loss: 0.2603551745414734


Epoch 17/50:  99%|█████████▊| 200/203 [00:10<00:00, 18.31img/s]

---
ensemble_fusion Validation Dice Loss: 0.3274252116680145
ensemble_fusion Validation Pixel Accuracy: 0.9521883245100055
ensemble_fusion Validation MIoU: 0.8410275481883932
ensemble_fusion Validation Dice Score: 0.866078794002533



Epoch 18/50:   0%|          | 0/203 [00:01<?, ?img/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_2900603/2213695298.py", line 27, in <module>
    for batch in train_loader:
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1316, in _next_data
    idx, data = self._get_data()
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1282, in _get_data
    success, data = self._try_get_data()
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1120, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/user1/anaconda3/envs/k

TypeError: object of type 'NoneType' has no len()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
df = pd.DataFrame([TrainLoss_list1, val_losses, valScore_list1, val_accs, val_mious]).T
df.columns = ['train_loss', 'val_loss', 'val_score', 'val_acc', 'val_miou']
df.to_csv(checkpoint_dir + '/model_check.csv', encoding = 'UTF-8')

In [None]:
plt.figure(figsize= (10,5))
plt.plot(TrainLoss_list1)
plt.plot(val_losses)

In [None]:
plt.figure(figsize= (10,5))
plt.plot(valScore_list1)
plt.plot(val_accs)
plt.plot(val_mious)