In [1]:
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, random_split
import torch
from torch import optim
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import logging
from evaluate import evaluate
import os
import torchsummary
import pytorch_model_summary

from model.unet.unet_model import UNet
from model.segnet.segnet_model import SegNet
from model.ensemblenet_model import EnsembleNet as esbnet
#from model.segnet.segnet2 import SegNet

from utils.dice_score import dice_loss
from utils.data_load import KittiDataset

In [2]:
Val_Percent = 0.4
Scale_Percent = 0.5
Batch_Size = 8
learning_rate = 0.001
Pin_Memory = False
epochs = 100
#Image_Size = [375, 1242]
Image_Size = [384, 1216]
Gradient_Clipping = 1.0


#Num_Class = 34
Num_Class = 31
Num_Channel = 3
amp = False

Img_Path =  'data/training/image_2'
Mask_Path = 'data/training/semantic'

#Img_Path =  'data/feature'
#Mask_Path = 'data/target'



save_checkpoint = True
checkpoint_dir = '../trained'
batch_size = Batch_Size

In [3]:
dirImg = Path(Img_Path)
dirMask = Path(Mask_Path)
dir_checkpoint = Path(checkpoint_dir)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
datasets =  KittiDataset(dirImg, dirMask, Image_Size, Scale_Percent)
n_val = int(len(datasets) * Val_Percent)
n_train = len(datasets) - n_val
train_set, val_set = random_split(datasets, [n_train, n_val], generator=torch.Generator().manual_seed(0))

loader_args = dict(batch_size=Batch_Size, num_workers= os.cpu_count(), pin_memory=Pin_Memory)
train_loader = DataLoader(train_set, shuffle=True, drop_last = True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 445.88it/s]


In [6]:
unet = UNet(Num_Channel, Num_Class)
segnet = SegNet(Num_Channel, Num_Class)

In [7]:
model = esbnet(unet, segnet)
model = model.to(memory_format=torch.channels_last, device = device)

print(pytorch_model_summary.summary(model, torch.cuda.FloatTensor(1, 3,384, 1216), show_parent_layers=True, max_depth= None))

model = esbnet(Num_Channel, Num_Class)
model = model.to(memory_format=torch.channels_last, device = device)

In [6]:
x1 = torch.cuda.FloatTensor(1, 31,384, 1216)
x2 = torch.cuda.FloatTensor(1, 31,384, 1216)

In [7]:
x1 = torch.FloatTensor(1, 31,384, 1216)
x2 = torch.FloatTensor(1, 31,384, 1216)

In [8]:
torch.dot(x1,x2)

RuntimeError: 1D tensors expected, but got 4D and 4D tensors

In [34]:
x3 = torch.mul(x1, x2)

In [35]:
x3 = x3.to(device)

In [31]:
con = nn.Conv2d(31, 31, kernel_size=3, padding=1)

In [37]:
con = con.to(device)

In [39]:
con(x3).shape

torch.Size([1, 31, 384, 1216])

torch.Size([4, 31, 192, 608])

In [8]:
print(pytorch_model_summary.summary(model, torch.cuda.FloatTensor(1, 3,384, 1216), show_parent_layers=True, max_depth= None))

-----------------------------------------------------------------------------------------------------------------------------------
                       Parent Layers          Layer (type)                             Output Shape         Param #     Tr. Param #
         EnsembleNet/UNet/DoubleConv              Conv2d-1                       [1, 64, 384, 1216]           1,728           1,728
         EnsembleNet/UNet/DoubleConv         BatchNorm2d-2                       [1, 64, 384, 1216]             128             128
         EnsembleNet/UNet/DoubleConv                ReLU-3                       [1, 64, 384, 1216]               0               0
         EnsembleNet/UNet/DoubleConv              Conv2d-4                       [1, 64, 384, 1216]          36,864          36,864
         EnsembleNet/UNet/DoubleConv         BatchNorm2d-5                       [1, 64, 384, 1216]             128             128
         EnsembleNet/UNet/DoubleConv                ReLU-6                  

print(pytorch_model_summary.summary(model, torch.cuda.FloatTensor(1, 3,384, 1216), show_parent_layers=True, max_depth= None))

torchsummary.summary(model, (3, 384, 1216))

In [8]:
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
#optimizer = optim.RMSprop(unet.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
global_step = 0

In [9]:
# 5. Begin training
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)

            with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                #masks_pred, masks_pred2 = model(images)
                masks_pred = model(images)
                
                if model.n_classes == 1:
                #if model.unet.n_classes == 1:    
                    loss = criterion(masks_pred.squeeze(1), true_masks.float())
                    loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                else:
                    loss = criterion(masks_pred, true_masks)
                    loss += dice_loss(
                        F.softmax(masks_pred, dim=1).float(),
                        F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                        #F.one_hot(true_masks, model.unet.n_classes).permute(0, 3, 1, 2).float(),
                        multiclass=True
                    )

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), Gradient_Clipping)
            grad_scaler.step(optimizer)
            grad_scaler.update()

            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()


            # Evaluation round
            division_step = (n_train // (5 * batch_size))
            if division_step > 0:
                if global_step % division_step == 0:

                    val_score = evaluate(model, val_loader, device, amp)
                    
                    scheduler.step(val_score)

                    #logging.info('Validation Dice score: {}'.format(val_score))
                    print('Validation Dice score: {}'.format(val_score))
                    
    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))

Epoch 1/100:  20%|██████████████████████████████████████████████▊                                                                                                                                                                                           | 24/120 [00:05<00:16,  5.87img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:22,  2.47s/batch][A
Validation round:  30%|████████████████████████████████████████████████████████████████████▋                                         

Validation Dice score: 0.5925632119178772


Epoch 1/100:  40%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                            | 48/120 [00:09<00:09,  7.50img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:22,  2.47s/batch][A
Validation round:  30%|████████████████████████████████████████████████████████████████████▋                                         

Validation Dice score: 0.5950719714164734


Epoch 1/100:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                             | 72/120 [00:13<00:05,  8.07img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:19,  2.16s/batch][A
Validation round:  20%|█████████████████████████████████████████████▊                                                                

Validation Dice score: 0.5651812553405762


Epoch 1/100:  80%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                              | 96/120 [00:17<00:02,  8.49img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:22,  2.52s/batch][A
Validation round:  30%|████████████████████████████████████████████████████████████████████▋                                         

Validation Dice score: 0.5954886078834534


Epoch 1/100: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:21<00:00,  8.32img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:20,  2.32s/batch][A
Validation round:  30%|████████████████████████████████████████████████████████████████████▋                                         

Validation Dice score: 0.48123255372047424


Epoch 2/100:  20%|██████████████████████████████████████████████▊                                                                                                                                                                                           | 24/120 [00:03<00:08, 10.68img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:20,  2.32s/batch][A
Validation round:  20%|█████████████████████████████████████████████▊                                                                

Validation Dice score: 0.47598153352737427


Epoch 2/100:  40%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                                                            | 48/120 [00:06<00:08,  8.96img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
Validation round:  10%|██████████████████████▉                                                                                                                                                                                                              | 1/10 [00:02<00:22,  2.52s/batch][A
Validation round:  20%|█████████████████████████████████████████████▊                                                                

Validation Dice score: 0.5604016780853271


Epoch 2/100:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                             | 72/120 [00:11<00:05,  8.35img/s]
Validation round:   0%|                                                                                                                                                                                                                                             | 0/10 [00:00<?, ?batch/s][A
                                                                                                                                                                                                                                                                                              [AException ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28a9c9d710>
Traceback (most recent call last):
  File "/

Traceback (most recent call last):
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_1156415/1565076085.py", line 45, in <module>
    val_score = evaluate(model, val_loader, device, amp)
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/user1/projects/ksh/Ensemble-Net/evaluate.py", line 16, in evaluate
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/tqdm/std.py", line 1182, in __iter__
    for obj in iterable:
  File "/home/user1/anaconda3/envs/ksh/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/user1/anaconda3/e

TypeError: object of type 'NoneType' has no len()

In [None]:
masks_pred.shape

In [14]:
masks_pred[0].shape

torch.Size([31, 192, 608])