# LIDC IDRI 2D SEGMENTATION WITH TERNARY CLASSES

## Import Libraries

In [1]:
import pandas as pd
import argparse
import os
from collections import OrderedDict
from glob import glob
import yaml
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms
import torchsummary as summary

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, jaccard_score
from tqdm import tqdm

from Unet_new.unet_model import UNet
from UnetNested.Nested_Unet import NestedUNet

## Define Parameters

In [2]:
name = "UNet"           #default = "UNet"; can be NestedUNet
epochs = 100            #default = 400
batch_size = 12          #default = 12
early_stopping = 50     #default = 50
num_workers = 8        #default = 8
optimizer = 'Adam'      #default = 'Adam'; can be SGD
lr = 1e-5               #default = 1e-5
momentum = 0.9          #default = 0.9
weight_decay = 1e-4     #default = 1e-4
nesterov = False        #default = False
augmentation = True     #default = False

## Define Functions

### Dataset

In [3]:
class LidcDataset(Dataset):
    def __init__(self, IMAGES_PATHS, MASK_PATHS, transforms):
        self.image_paths = IMAGES_PATHS
        self.mask_paths = MASK_PATHS
        
        self.transforms = transforms

    def __getitem__(self, index):
        image = np.load(self.image_paths[index])
        mask = np.load(self.mask_paths[index])

        #Make image and mask 3 dimensional
        image = image.reshape(512,512,1)
        mask = mask.reshape(512,512,1)

        #Convert datatype
        mask = mask.astype('uint8')

        #Apply augmentation
        augmented = self.transforms(image=image,mask=mask)
        image = augmented['image']
        mask = augmented['mask']
        mask = mask.reshape([1,512,512])

        image, mask = image.type(torch.FloatTensor), mask.type(torch.FloatTensor)     

        return image, mask
    
    def __len__(self):
        return len(self.image_paths)

In [4]:
transform = A.Compose([
            A.ElasticTransform(alpha=1.1,alpha_affine=0.5,sigma=5,p=0.15),
            A.HorizontalFlip(p=0.5),
            ToTensorV2()
        ])

### Metrics

In [5]:
def sensitivity_metric(target, output):
    tn, fp, fn, tp = confusion_matrix(target, output)
    s0 = tp[0]/(tp[0]+fn[0])
    s1 = tp[1]/(tp[1]+fn[1])
    s2 = tp[2]/(tp[2]+fn[2])
    s3 = tp[3]/(tp[3]+fn[3])
    sensitivity = np.array([s0, s1, s2, s3]) 
    return sensitivity

### Utilities

In [6]:
def str_to_bool(v):
    if v.lower() in ['true', 1]:
        return True
    elif v.lower() in ['false', 0]:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class AverageMeter(object):
    #Computes and stores the average and current value
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Get Configuration

In [7]:
if augmentation == True:
    file_name = name + '_with_augmentation'
else:
    file_name = name + '_base'
os.makedirs('model_outputs/{}'.format(file_name), exist_ok=True)
print("Creating directory called ", file_name)

print('-' * 20)
print("Configuration Setting: ")
print("Model: ", name)
print("Max Epochs: ", epochs)
print("Batch Size: ", batch_size)
print("Number of Workers: ", num_workers)
print("Optimizer: ", optimizer)
print("Learning Rate: ", lr)
print("Augmentation: ", augmentation)

Creating directory called  UNet_with_augmentation
--------------------
Configuration Setting: 
Model:  UNet
Max Epochs:  100
Batch Size:  12
Number of Workers:  8
Optimizer:  Adam
Learning Rate:  1e-05
Augmentation:  True


## Create Model

In [8]:
criterion = torch.nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else torch.nn.CrossEntropyLoss()
cudnn.benchmark = True

#Creating the model
print("Creating model...")
if name == 'NestedUNet':
    model = NestedUNet(num_classes=4)
else:
    model = UNet(n_channels=1, n_classes=4)
model = model.cuda() if torch.cuda.is_available() else model

if torch.cuda.device_count() > 1:
    print("We can use ", torch.cuda.device_count(), " GPUs.")
    model = nn.DataParallel(model)

params = filter(lambda p: p.requires_grad, model.parameters())

if optimizer == 'Adam':
    optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay)
elif optimizer == 'SGD':
    optimizer = optim.SGD(params, lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
else:
    raise NotImplementedError
    
summary.summary(model,(1,512,512))

Creating model...
Layer (type:depth-idx)                   Output Shape              Param #
├─inconv: 1-1                            [-1, 64, 512, 512]        --
|    └─double_conv: 2-1                  [-1, 64, 512, 512]        --
|    |    └─Sequential: 3-1              [-1, 64, 512, 512]        37,824
├─down: 1-2                              [-1, 128, 256, 256]       --
|    └─Sequential: 2-2                   [-1, 128, 256, 256]       --
|    |    └─MaxPool2d: 3-2               [-1, 64, 256, 256]        --
|    |    └─double_conv: 3-3             [-1, 128, 256, 256]       221,952
├─down: 1-3                              [-1, 256, 128, 128]       --
|    └─Sequential: 2-3                   [-1, 256, 128, 128]       --
|    |    └─MaxPool2d: 3-4               [-1, 128, 128, 128]       --
|    |    └─double_conv: 3-5             [-1, 256, 128, 128]       886,272
├─down: 1-4                              [-1, 512, 64, 64]         --
|    └─Sequential: 2-4                   [-1, 512, 64

Layer (type:depth-idx)                   Output Shape              Param #
├─inconv: 1-1                            [-1, 64, 512, 512]        --
|    └─double_conv: 2-1                  [-1, 64, 512, 512]        --
|    |    └─Sequential: 3-1              [-1, 64, 512, 512]        37,824
├─down: 1-2                              [-1, 128, 256, 256]       --
|    └─Sequential: 2-2                   [-1, 128, 256, 256]       --
|    |    └─MaxPool2d: 3-2               [-1, 64, 256, 256]        --
|    |    └─double_conv: 3-3             [-1, 128, 256, 256]       221,952
├─down: 1-3                              [-1, 256, 128, 128]       --
|    └─Sequential: 2-3                   [-1, 256, 128, 128]       --
|    |    └─MaxPool2d: 3-4               [-1, 128, 128, 128]       --
|    |    └─double_conv: 3-5             [-1, 256, 128, 128]       886,272
├─down: 1-4                              [-1, 512, 64, 64]         --
|    └─Sequential: 2-4                   [-1, 512, 64, 64]         --
|

## Load Dataset

In [9]:
#directory of Images and Masks folders (generated from preprocessing)                                         
IMAGE_DIR = '/scratch1/joseph.portugal/LIDC-IDRI Preprocessed Exp 3/Image/'
MASK_DIR = '/scratch1/joseph.portugal/LIDC-IDRI Preprocessed Exp 3/Mask/'                                                                 

#meta information
meta = pd.read_csv('/scratch1/joseph.portugal/LIDC-IDRI Preprocessed Exp 3/Meta/meta.csv')
meta = meta[meta['patient_diagnosis'] != 0]

#Get train/test label from metadata file
meta['original_image'] = meta['original_image'].apply(lambda x: IMAGE_DIR + "LIDC-IDRI-" + x[:4] + "/" + x + ".npy")
meta['mask_image'] = meta['mask_image'].apply(lambda x: MASK_DIR + "LIDC-IDRI-" + x[:4] + "/" + x + ".npy")


#Split into training and validation
train_meta = meta[meta['data_split']=='Train']
val_meta = meta[meta['data_split']=='Validation']

#Get training images into list
train_image_paths = list(train_meta['original_image'])
train_mask_paths = list(train_meta['mask_image'])

#Get validation images into list
val_image_paths = list(val_meta['original_image'])
val_mask_paths = list(val_meta['mask_image'])

print("*"*50)
print("Original images: {}, masks: {} for training.".format(len(train_image_paths),len(train_mask_paths)))
print("Original images: {}, masks: {} for validation.".format(len(val_image_paths),len(val_mask_paths)))
print("Ratio between Validation and Training is {:2f}".format(len(val_image_paths)/len(train_image_paths)))
print("*"*50)


#Creating custom LIDC dataset
train_dataset = LidcDataset(train_image_paths, train_mask_paths, transforms=transform)
val_dataset = LidcDataset(val_image_paths, val_mask_paths, transforms=transform)

#Creating Dataloader
train_loader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  shuffle=True,
  pin_memory=True,
  drop_last=True,
  num_workers=num_workers
)
val_loader = DataLoader(
  val_dataset,
  batch_size=batch_size,
  shuffle=False,
  pin_memory=True,
  drop_last=False,
  num_workers=num_workers
)

**************************************************
Original images: 980, masks: 980 for training.
Original images: 177, masks: 177 for validation.
Ratio between Validation and Training is 0.180612
**************************************************


## Train the Model

In [10]:
torch.cuda.empty_cache()

In [11]:
log = pd.DataFrame(index=[], columns=['epoch','lr','loss','iou','dice','sensitivity','val_loss','val_iou', 'val_sensitivity'])
# log = pd.DataFrame(index=[], columns=['epoch','lr','loss','iou','dice','val_loss','val_iou'])

best_dice = 0
trigger = 0

for epoch in range(epochs):

    #Model Training
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter(), 'dice': AverageMeter(), 'sensitivity': AverageMeter()}
#     avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter(), 'dice': AverageMeter()}
    model.train()
    pbar = tqdm(total=len(train_loader)) #progress bar

    for i, data in enumerate(train_loader):

        input = data[0].cuda()
        target = data[1].cuda()
        output = model(input)

        #Get loss and metric
        target_metric = target.flatten().cpu()
        output_metric = torch.argmax(output, dim=1).flatten().cpu()
        loss = criterion(output, torch.argmax(target, dim=1))
        iou = jaccard_score(target_metric, output_metric, average=None)
        dice = f1_score(target_metric, output_metric, average=None)
        sensitivity = sensitivity_metric(target_metric, output_metric)

        #Calculate the gradient and perform optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #Update average metrics
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))
        avg_meters['dice'].update(dice, input.size(0))
        avg_meters['sensitivity'].update(sensitivity, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
            ('dice', avg_meters['dice'].avg),
            ('sensitivity', avg_meters['sensitivity'].avg)
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    train_log = OrderedDict([
        ('loss', avg_meters['loss'].avg),
        ('iou', avg_meters['iou'].avg),
        ('dice', avg_meters['dice'].avg),
        ('sensitivity', avg_meters['sensitivity'].avg)
    ])


    #Model Validation
    val_avg_meters = {'val_loss': AverageMeter(), 'val_iou': AverageMeter(), 'val_dice': AverageMeter(), 'val_sensitivity': AverageMeter()}
#     val_avg_meters = {'val_loss': AverageMeter(), 'val_iou': AverageMeter(), 'val_dice': AverageMeter()}
    model.eval()

    with torch.no_grad():
        val_pbar = tqdm(total=len(val_loader))
        for i, val_data in enumerate(val_loader):

            val_input = val_data[0].cuda()
            val_target = val_data[1].cuda()
            val_output = model(val_input)
            
            val_target_metric = val_target.flatten().cpu()
            val_output_metric = torch.argmax(val_output, dim=1).flatten().cpu()
            val_loss = criterion(val_output, torch.argmax(val_target, dim=1))
            val_iou = jaccard_score(val_target_metric, val_output_metric, average=None)
            val_dice = f1_score(val_target_metric, val_output_metric, average=None)
            val_sensitivity = sensitivity_metric(val_target_metric, val_output_metric)

            val_avg_meters['val_loss'].update(val_loss.item(), val_input.size(0))
            val_avg_meters['val_iou'].update(val_iou, val_input.size(0))
            val_avg_meters['val_dice'].update(val_dice, val_input.size(0))
            val_avg_meters['val_sensitivity'].update(val_sensitivity, val_input.size(0))

            val_postfix = OrderedDict([
                ('val_loss', val_avg_meters['val_loss'].avg),
                ('val_iou', val_avg_meters['val_iou'].avg),
                ('val_dice', val_avg_meters['val_dice'].avg),
                ('val_sensitivity', val_avg_meters['val_sensitivity'].avg)
            ])
            val_pbar.set_postfix(val_postfix)
            val_pbar.update(1)
        val_pbar.close()

    val_log = OrderedDict([
        ('val_loss', val_avg_meters['val_loss'].avg),
        ('val_iou', val_avg_meters['val_iou'].avg),
        ('val_dice', val_avg_meters['val_dice'].avg),
        ('val_sensitivity', val_avg_meters['val_sensitivity'].avg)
    ])
    

    print('Training Epoch {}/{},  Training Loss: {:.4f},  Training DICE: {},  Training IOU: {},  Training Sensitivity: {},  Validation Loss: {:.4f},  Validation DICE: {},  Validation IOU: {},  Validation Sensitivity: {}'.format(
        epoch+1, epochs, train_log['loss'], train_log['dice'], train_log['iou'], train_log['sensitivity'], val_log['val_loss'], val_log['val_dice'], val_log['val_iou'], val_log['val_sensitivity']
    ))
    # print('Training Epoch {}/{},  Training Loss: {:.4f},  Training DICE: {:.4f},  Training IOU: {:.4f},  Validation Loss: {:.4f},  Validation DICE: {:.4f},  Validation IOU: {:.4f}'.format(
    #     epoch+1, epochs, train_log['loss'], train_log['dice'], train_log['iou'], val_log['val_loss'], val_log['val_dice'], val_log['val_iou']
    # ))
#     print('Training Epoch {}/{},  Training Loss: {:.4f},  Training DICE: {},  Training IOU: {},  Validation Loss: {:.4f},  Validation DICE: {},  Validation IOU: {}'.format(
#         epoch+1, epochs, train_log['loss'], train_log['dice'], train_log['iou'], val_log['val_loss'], val_log['val_dice'], val_log['val_iou']
#     ))

    #Save values to csv file
    tmp = pd.Series([
        epoch,
        lr,
        train_log['loss'],
        train_log['iou'],
        train_log['dice'],
        train_log['sensitivity'],
        val_log['val_loss'],
        val_log['val_iou'],
        val_log['val_dice'],
        val_log['val_sensitivity']
    ], index=['epoch', 'lr', 'loss', 'iou', 'dice', 'sensitivity', 'val_loss', 'val_iou', 'val_dice', 'val_sensitivity'])
#     index=['epoch', 'lr', 'loss', 'iou', 'dice', 'val_loss', 'val_iou', 'val_dice'])
    

    log = log.append(tmp, ignore_index=True)
    log.to_csv('model_outputs/{}/log_metrics.csv'.format(file_name), index=False)

    trigger += 1

    #If best DICE score, save the model
    if np.mean(val_log['val_dice']) > best_dice:
        torch.save(model.state_dict(), 'model_outputs/{}/model_metrics.pth'.format(file_name))
        best_dice = np.mean(val_log['val_dice'])
        print("Saved new best model based on DICE metric!")
        trigger = 0
    
    if early_stopping >= 0 and trigger >= early_stopping:
        print("Early stopping.")
        break

    torch.cuda.empty_cache()

100%|██████████| 81/81 [10:15<00:00,  7.60s/it, loss=1.26, iou=[8.31553101e-01 2.82317894e-04 3.06457467e-03 1.13835248e-03], dice=[9.07770616e-01 5.64087031e-04 6.10417563e-03 2.27337243e-03], sensitivity=[0.2917769  0.35687026 0.38858356 0.41027167]]
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
100%|██████████| 15/15 [00:58<00:00,  3.88s/it, val_loss=1.2, val_iou=[8.55119945e-01 1.39049913e-03 2.53345583e-03 2.37538365e-04], val_dice=[9.20825658e-01 2.74479782e-03 5.04404216e-03 4.74655422e-04], val_sensitivity=[nan nan nan nan]]


Training Epoch 1/100,  Training Loss: 1.2583,  Training DICE: [9.07770616e-01 5.64087031e-04 6.10417563e-03 2.27337243e-03],  Training IOU: [8.31553101e-01 2.82317894e-04 3.06457467e-03 1.13835248e-03],  Training Sensitivity: [0.2917769  0.35687026 0.38858356 0.41027167],  Validation Loss: 1.1988,  Validation DICE: [9.20825658e-01 2.74479782e-03 5.04404216e-03 4.74655422e-04],  Validation IOU: [8.55119945e-01 1.39049913e-03 2.53345583e-03 2.37538365e-04],  Validation Sensitivity: [nan nan nan nan]
Saved new best model based on DICE metric!


100%|██████████| 81/81 [09:49<00:00,  7.28s/it, loss=1.18, iou=[8.75525857e-01 5.08337082e-04 2.67418636e-03 1.60968582e-03], dice=[0.93356233 0.00101442 0.00532933 0.0032123 ], sensitivity=[0.38283649 0.39185259 0.39379956 0.37978425]]                
  s0 = tp[0]/(tp[0]+fn[0])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
100%|██████████| 15/15 [00:54<00:00,  3.66s/it, val_loss=1.16, val_iou=[8.77463499e-01 1.09686045e-03 2.95487213e-03 3.83179005e-04], val_dice=[9.33612978e-01 2.17449317e-03 5.86758477e-03 7.65242085e-04], val_sensitivity=[nan nan nan nan]]


Training Epoch 2/100,  Training Loss: 1.1809,  Training DICE: [0.93356233 0.00101442 0.00532933 0.0032123 ],  Training IOU: [8.75525857e-01 5.08337082e-04 2.67418636e-03 1.60968582e-03],  Training Sensitivity: [0.38283649 0.39185259 0.39379956 0.37978425],  Validation Loss: 1.1615,  Validation DICE: [9.33612978e-01 2.17449317e-03 5.86758477e-03 7.65242085e-04],  Validation IOU: [8.77463499e-01 1.09686045e-03 2.95487213e-03 3.83179005e-04],  Validation Sensitivity: [nan nan nan nan]
Saved new best model based on DICE metric!


100%|██████████| 81/81 [10:07<00:00,  7.50s/it, loss=1.16, iou=[8.93616453e-01 5.23835943e-04 2.95767012e-03 3.17440332e-03], dice=[0.94372831 0.00104576 0.00589234 0.00632121], sensitivity=[0.37517614 0.37559897 0.3770714  0.40786095]]               
  s0 = tp[0]/(tp[0]+fn[0])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
  s3 = tp[3]/(tp[3]+fn[3])
  s0 = tp[0]/(tp[0]+fn[0])
  s0 = tp[0]/(tp[0]+fn[0])
100%|██████████| 15/15 [00:53<00:00,  3.60s/it, val_loss=1.15, val_iou=[8.88095697e-01 9.58661049e-04 3.10252827e-03 8.08829881e-04], val_dice=[0.93953584 0.00190556 0.00616638 0.00161252], val_sensitivity=[nan nan nan nan]]


Training Epoch 3/100,  Training Loss: 1.1555,  Training DICE: [0.94372831 0.00104576 0.00589234 0.00632121],  Training IOU: [8.93616453e-01 5.23835943e-04 2.95767012e-03 3.17440332e-03],  Training Sensitivity: [0.37517614 0.37559897 0.3770714  0.40786095],  Validation Loss: 1.1451,  Validation DICE: [0.93953584 0.00190556 0.00616638 0.00161252],  Validation IOU: [8.88095697e-01 9.58661049e-04 3.10252827e-03 8.08829881e-04],  Validation Sensitivity: [nan nan nan nan]
Saved new best model based on DICE metric!


100%|██████████| 81/81 [10:00<00:00,  7.41s/it, loss=1.14, iou=[9.12939819e-01 4.57557282e-04 3.47131903e-03 4.57729677e-03], dice=[9.54434054e-01 9.13649292e-04 6.90815811e-03 9.10055541e-03], sensitivity=[0.3677389  0.38256793 0.37127605 0.45373469]]
  s0 = tp[0]/(tp[0]+fn[0])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
100%|██████████| 15/15 [00:53<00:00,  3.60s/it, val_loss=1.13, val_iou=[0.91679845 0.00107327 0.00405174 0.00092473], val_dice=[0.95578662 0.00212863 0.00802116 0.00184307], val_sensitivity=[nan nan nan nan]]


Training Epoch 4/100,  Training Loss: 1.1387,  Training DICE: [9.54434054e-01 9.13649292e-04 6.90815811e-03 9.10055541e-03],  Training IOU: [9.12939819e-01 4.57557282e-04 3.47131903e-03 4.57729677e-03],  Training Sensitivity: [0.3677389  0.38256793 0.37127605 0.45373469],  Validation Loss: 1.1275,  Validation DICE: [0.95578662 0.00212863 0.00802116 0.00184307],  Validation IOU: [0.91679845 0.00107327 0.00405174 0.00092473],  Validation Sensitivity: [nan nan nan nan]
Saved new best model based on DICE metric!


100%|██████████| 81/81 [09:54<00:00,  7.35s/it, loss=1.12, iou=[9.45334551e-01 5.01332748e-04 3.63053956e-03 2.90102894e-03], dice=[0.97187308 0.00100075 0.00721736 0.00577285], sensitivity=[0.40131623 0.35078457 0.39121892 0.41512058]]
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s0 = tp[0]/(tp[0]+fn[0])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
100%|██████████| 15/15 [00:54<00:00,  3.60s/it, val_loss=1.11, val_iou=[9.52104384e-01 1.46291197e-03 2.10956904e-03 6.27396638e-04], val_dice=[0.97494019 0.00287737 0.00418529 0.00124983], val_sensitivity=[nan nan nan nan]]


Training Epoch 5/100,  Training Loss: 1.1217,  Training DICE: [0.97187308 0.00100075 0.00721736 0.00577285],  Training IOU: [9.45334551e-01 5.01332748e-04 3.63053956e-03 2.90102894e-03],  Training Sensitivity: [0.40131623 0.35078457 0.39121892 0.41512058],  Validation Loss: 1.1102,  Validation DICE: [0.97494019 0.00287737 0.00418529 0.00124983],  Validation IOU: [9.52104384e-01 1.46291197e-03 2.10956904e-03 6.27396638e-04],  Validation Sensitivity: [nan nan nan nan]
Saved new best model based on DICE metric!


  s3 = tp[3]/(tp[3]+fn[3])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s3 = tp[3]/(tp[3]+fn[3])
  s3 = tp[3]/(tp[3]+fn[3])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s3 = tp[3]/(tp[3]+fn[3])
  s3 = tp[3]/(tp[3]+fn[3])


  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
100%|██████████| 81/81 [10:06<00:00,  7.49s/it, loss=1.11, iou=[9.72536731e-01 2.87471502e-04 2.67746926e-03 1.11276241e-03], dice=[9.86063992e-01 5.73548060e-04 5.31148892e-03 2.20110692e-03], sensitivity=[0.38500887        nan        nan        nan]]
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s0 = tp[0]/(tp[0]+fn[0])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1]+fn[1])
  s2 = tp[2]/(tp[2]+fn[2])
  s3 = tp[3]/(tp[3]+fn[3])
  s1 = tp[1]/(tp[1

 87%|████████▋ | 13/15 [00:51<00:07,  3.64s/it, val_loss=1.11, val_iou=[9.42536579e-01 1.28231094e-04 8.36955657e-04 0.00000000e+00], val_dice=[9.68841649e-01 2.56195776e-04 1.66677609e-03 0.00000000e+00], val_sensitivity=[nan nan nan nan]] 93%|█████████▎| 14/15 [00:51<00:03,  3.63s/it, val_loss=1.11, val_iou=[9.42536579e-01 1.28231094e-04 8.36955657e-04 0.00000000e+00], val_dice=[9.68841649e-01 2.56195776e-04 1.66677609e-03 0.00000000e+00], val_sensitivity=[nan nan nan nan]]

ValueError: not enough values to unpack (expected 4, got 3)