In [9]:
from torch.utils.data import DataLoader
from torchvision import transforms
from progressBar import printProgressBar

import medicalDataLoader
import argparse
from utils import *

from UNet_Base import *
import random
import torch
import pdb

import segmentation_models_pytorch as smp

In [10]:
import warnings
warnings.filterwarnings("ignore")

In [11]:
class HLoss(nn.Module):
    def __init__(self):
        super(HLoss, self).__init__()

    def forward(self, x):
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        b = -1.0 * b.mean()
        # b = -1.0 * b.sum() # If using this one, the value on the final loss for eloss is 0.0000001
        return b

In [12]:
def runTraining():
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    ## DEFINE HYPERPARAMETERS (batch_size > 1)
    batch_size = 32
    batch_size_val = 16
    lr =  0.001   # Learning Rate
    epoch = 100 # Number of epochs

    loss_weight = [0.7, 0.3] # Weight for the loss function (CE, Dice)
    CE_weights = [0.05, 0.40, 0.30, 0.25] # Weight for the Cross Entropy loss function
    
    root_dir = './Data/'
    print(' Dataset: {} '.format(root_dir))

    ## DEFINE THE TRANSFORMATIONS TO DO AND THE VARIABLES FOR TRAINING AND VALIDATION
    
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True,
                                                      equalize=False)

    train_loader_full = DataLoader(train_set_full,
                              batch_size=batch_size,
                              worker_init_fn=np.random.seed(0),
                              num_workers=0,
                              shuffle=True)


    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            worker_init_fn=np.random.seed(0),
                            num_workers=0,
                            shuffle=False)


    ## INITIALIZE YOUR MODEL
    num_classes = 4 # NUMBER OF CLASSES

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName = 'Test_Model'
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    # net = UNet(num_classes)
    net = smp.Unet('resnet34', encoder_weights='imagenet', in_channels=1, classes=num_classes)

    print("Total params: {0:,}".format(sum(p.numel() for p in net.parameters() if p.requires_grad)))

    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    softMax = torch.nn.Softmax()
    # CE_loss = torch.nn.CrossEntropyLoss(weight=torch.Tensor([0.05, 0.35, 0.30, 0.30]))
    # DSC_loss_weight = 0.4
    # Dice_loss = smp.losses.DiceLoss(mode='multiclass', classes=[1, 2, 3])
    CE_loss = torch.nn.CrossEntropyLoss(weight=torch.Tensor(CE_weights))
    Dice_loss = smp.losses.DiceLoss(mode='multiclass', ignore_index=0)

    ## PUT EVERYTHING IN GPU RESOURCES    
    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    ## DEFINE YOUR OPTIMIZER
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=0.0001)

    ### To save statistics ####
    lossTotalTraining = []
    lossTotalValidation = []
    Best_loss_val = 1000
    BestEpoch = 0
    no_improvement_counter = 0
        
    directory = 'Results/Statistics/' + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    if os.path.exists(directory)==False:
        os.makedirs(directory)

    ## START THE TRAINING
    
    ## FOR EACH EPOCH
    for i in range(epoch):
        net.train()
        lossEpochTrain = []
        DSCEpoch = []
        DSCEpoch_w = []
        num_batches = len(train_loader_full)
        
        ## FOR EACH BATCH
        for j, data in enumerate(train_loader_full):
            ### Set to zero all the gradients
            net.zero_grad()
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, img_names = data

            ### From numpy to torch variables
            labels = to_var(labels)
            images = to_var(images)

            ################### Train ###################
            #-- The CNN makes its predictions (forward pass)
            net_predictions = net(images)

            #-- Compute the losses --#
            # THIS FUNCTION IS TO CONVERT LABELS TO A FORMAT TO BE USED IN THIS CODE
            segmentation_classes = getTargetSegmentation(labels)
            
            # COMPUTE THE LOSS
            CE_loss_value = CE_loss(net_predictions, segmentation_classes)
            
            predsoft = softMax(net_predictions)
            pred = predsoft.argmax(dim=1)

            #Show live view of model segmentation
            if not os.path.exists('Results/Segmentation/'):
                os.makedirs('Results/Segmentation/')
            torchvision.utils.save_image(torch.cat([pred.view(labels.shape[0], 1, 256, 256).data / 3.0]), 'Results/Segmentation/liveview.png'.format(i))

            # Dice_loss_value = computeDSC(pred.unsqueeze(1), segmentation_classes.unsqueeze(1))
            Dice_loss_value = Dice_loss(predsoft, segmentation_classes)

            lossTotal = loss_weight[0]*CE_loss_value + loss_weight[1]*(Dice_loss_value)

            # DO THE STEPS FOR BACKPROP (two things to be done in pytorch)
            lossTotal.backward()
            optimizer.step()
            
            # THIS IS JUST TO VISUALIZE THE TRAINING 
            lossEpochTrain.append(lossTotal.cpu().data.numpy())
            printProgressBar(j + 1, num_batches,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Loss: {:.4f}, CE: {:.4f}, Dice: {:.8f}".format(lossTotal, CE_loss_value, Dice_loss_value))

        lossEpochTrain = np.asarray(lossEpochTrain)
        lossEpochTrain = lossEpochTrain.mean()

        lossTotalTraining.append(lossEpochTrain)

        printProgressBar(num_batches, num_batches,
                             done="[Training] Epoch: {}, LossG: {:.4f}".format(i,lossEpochTrain))

        # Validation
        net.eval()
        lossEpochVal = []
        num_batches = len(val_loader)
        
        for j,data in enumerate(val_loader):
            images, labels, img_names = data
            labels = to_var(labels)
            images = to_var(images)

            net_predictions = net(images)
            segmentation_classes = getTargetSegmentation(labels)
            CE_loss_value = CE_loss(net_predictions, segmentation_classes)
            
            predsoft = softMax(net_predictions)
            pred = predsoft.argmax(dim=1)

            #Show live view of model segmentation
            if not os.path.exists('Results/Segmentation/'):
                os.makedirs('Results/Segmentation/')
            torchvision.utils.save_image(torch.cat([pred.view(labels.shape[0], 1, 256, 256).data / 3.0]), 'Results/Segmentation/liveview.png'.format(i))
            
            # Dice_loss_value = computeDSC(pred.unsqueeze(1), segmentation_classes.unsqueeze(1))
            Dice_loss_value = Dice_loss(predsoft, segmentation_classes)

            lossTotal = loss_weight[0]*CE_loss_value + loss_weight[1]*(Dice_loss_value)

            lossEpochVal.append(lossTotal.cpu().data.numpy())
            printProgressBar(j + 1, num_batches,
                             prefix="[Validation] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Loss: {:.4f}, CE: {:.4f}, Dice: {:.8f}".format(lossTotal, CE_loss_value, Dice_loss_value))
            
        lossEpochVal = np.asarray(lossEpochVal)
        lossEpochVal = lossEpochVal.mean()

        lossTotalValidation.append(lossEpochVal)

        printProgressBar(num_batches, num_batches,
                             done="[Validation] Epoch: {}, LossG: {:.4f}".format(i,lossEpochVal))

        # Save the model if it is the best so far

        if lossEpochVal < Best_loss_val:
            Best_loss_val = lossEpochVal
            BestEpoch = i
            no_improvement_counter = 0

            if not os.path.exists('./models/' + modelName):
                os.makedirs('./models/' + modelName)
            torch.save(net.state_dict(), './models/' + modelName + '/' + str(i) + '_Epoch')
            print('Best model saved at epoch {}'.format(i))
        else:
            no_improvement_counter = no_improvement_counter + 1
            print('No improvement in last epoch. Counter: {}'.format(no_improvement_counter))
            if no_improvement_counter % 3 == 0 and no_improvement_counter != 0:
                print('No improvement in last 3 epochs. Lowering learning rate.')
                lr = lr/10
                optimizer = torch.optim.Adam(net.parameters(), lr=lr)
            if i - BestEpoch > 7:
                print('No improvement in last 7 epochs. Stopping training.')
                break

    np.save(os.path.join(directory, 'Losses.npy'), lossTotalTraining)
    np.save(os.path.join(directory, 'Losses_val.npy'), lossTotalValidation)

    print('Training finished. Best model saved at epoch {}'.format(BestEpoch))

    graphLosses(lossTotalTraining, lossTotalValidation, modelName, directory)

    return BestEpoch

In [13]:
# Evaluation and segmentation
def runEvaluation(BestEpoch):

    #Load the best model
    modelName = 'Test_Model'
    # net = UNet(4)
    net = smp.Unet('resnet34', encoder_weights='imagenet', in_channels=1, classes=4)
    net.load_state_dict(torch.load('./models/' + modelName + '/' + str(BestEpoch) + '_Epoch'))
    # net.load_state_dict(torch.load('./Saved_Models/Model_loss_0_0190'))

    #Load the val set
    batch_size_val = 8
    root_dir = './Data/'
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            worker_init_fn=np.random.seed(0),
                            num_workers=0,
                            shuffle=False)
    

    loss, tp, fp, fn, tn = inference(net, val_loader, modelName, BestEpoch)

    def metric_print(tp, fp, fn, tn):
        f1_score = tp/(tp + 0.5*(fp + fn))
        precision = tp/(tp + fp)
        recall = tp/(tp + fn)
        accuracy = (tp + tn)/(tp + tn + fp + fn)
        ballanced_accuracy = 0.5*(tp/(tp + fn) + tn/(tn + fp))
        jaccard = tp/(tp + fp + fn)
        sensitivity = tp/(tp + fn)
        specificity = tn/(tn + fp)
        print('     ---------- Data ----------')
        print('                TP: {}'.format(tp))
        print('                FP: {}'.format(fp))
        print('                FN: {}'.format(fn))
        print('                TN: {}'.format(tn))
        print('     ---------- Metrics ----------')
        print('          F1 score: {}'.format(f1_score))
        print('         Precision: {}'.format(precision))
        print('            Recall: {}'.format(recall))
        print('          Accuracy: {}'.format(accuracy))
        print('Ballanced accuracy: {}'.format(ballanced_accuracy))
        print('           Jaccard: {}'.format(jaccard))
        print('       Sensitivity: {}'.format(sensitivity))
        print('       Specificity: {}'.format(specificity))
        print('')

    print('     ---------- Loss ----------')
    print('           CE Loss: {}'.format(loss))
    print('')
    print('     ---------- Class 1 (BG) ----------')
    metric_print(tn[0], fp[0], fn[0], tp[0])
    print('     ---------- Class 2  ----------')
    metric_print(tn[1], fp[1], fn[1], tp[1])
    print('     ---------- Class 3  ----------')
    metric_print(tn[2], fp[2], fn[2], tp[2])
    print('     ---------- Class 4  ----------')
    metric_print(tn[3], fp[3], fn[3], tp[3])
    print('     ---------- Total  ----------')
    metric_print(np.sum(tp), np.sum(fp), np.sum(fn), np.sum(tn))


    
    

In [14]:
def runTesting(modelName='Test_Model', BestEpoch=0):
    print('-' * 40)
    print('~~~~~~~~  Starting the testing... ~~~~~~')
    print('-' * 40)

    batch_size_val = 1
    root_dir = './Data/'

    # https://sparrow.dev/pytorch-normalize/
    transform = transforms.Compose([
        transforms.ToTensor()
        # transforms.Normalize((0.5), (0.20))
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    test_set = medicalDataLoader.MedicalImageDataset('val',
                                                     root_dir,
                                                     transform=transform,
                                                     mask_transform=mask_transform,
                                                     equalize=False)

    test_loader = DataLoader(test_set,
                             batch_size=batch_size_val,
                             num_workers=5,
                             shuffle=False)

    # Initialize
    num_classes = 4

    # Create and load model
    # net = UNet(num_classes)
    net = smp.Unet('resnet34', encoder_weights='imagenet', in_channels=1, classes=4)

    # Load
    # net.load_state_dict(torch.load('./models/'+modelName))
    net.load_state_dict(torch.load('./models/' + modelName + '/' + str(BestEpoch) + '_Epoch'))
    # net.load_state_dict(torch.load('./Saved_Models/Model_loss_0_0190'))
    net.eval()

    if torch.cuda.is_available():
        net.cuda()

    print("~~~~~~~~~~~ Starting the testing ~~~~~~~~~~")
    [DSC1, DSC1s, DSC2, DSC2s, DSC3, DSC3s, HD1, HD1s, HD2, HD2s, HD3, HD3s, ASD1,
        ASD1s, ASD2, ASD2s, ASD3, ASD3s] = inferenceTest(net, test_loader, modelName)

    print("###                                                       ###")
    print("###         TEST RESULTS                                  ###")
    print("###  Dice : c1: {:.4f} ({:.4f}) c2: {:.4f} ({:.4f}) c3: {:.4f} ({:.4f}) Mean: {:.4f} ({:.4f}) ###".format(DSC1,
                                                                                                                     DSC1s,
                                                                                                                     DSC2,
                                                                                                                     DSC2s,
                                                                                                                     DSC3,
                                                                                                                     DSC3s,
                                                                                                                     (DSC1+DSC2+DSC3)/3,
                                                                                                                     (DSC1s+DSC2s+DSC3s)/3))
    print("###  HD   : c1: {:.4f} ({:.4f}) c2: {:.4f} ({:.4f}) c3: {:.4f} ({:.4f}) Mean: {:.4f} ({:.4f}) ###".format(HD1,
                                                                                                                     HD1s,
                                                                                                                     HD2,
                                                                                                                     HD2s,
                                                                                                                     HD3,
                                                                                                                     HD3s,
                                                                                                                     (HD1 + HD2 + HD3) / 3,
                                                                                                                     (HD1s + HD2s + HD3s) / 3))
    print("###  ASD  : c1: {:.4f} ({:.4f}) c2: {:.4f} ({:.4f}) c3: {:.4f} ({:.4f}) Mean: {:.4f} ({:.4f}) ###".format(ASD1,
                                                                                                                     ASD1s,
                                                                                                                     ASD2,
                                                                                                                     ASD2s,
                                                                                                                     ASD3,
                                                                                                                     ASD3s,
                                                                                                                     (ASD1 + ASD2 + ASD3) / 3,
                                                                                                                     (ASD1s + ASD2s + ASD3s) / 3))
    print("###                                                       ###")

In [15]:
BESTEPOCH = runTraining()
runTesting(BestEpoch=BESTEPOCH)
runEvaluation(BESTEPOCH)

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
 Dataset: ./Data/ 
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name: Test_Model
Total params: 24,430,532
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
[Training] Epoch: 0 [DONE]                                                             
[Training] Epoch: 0, LossG: 1.2238                                                                           
[Validation] Epoch: 0 [DONE]                                                             
[Validation] Epoch: 0, LossG: 1.4389                                                                         
Best model saved at epoch 0
[Training] Epoch: 1 [DONE]                                                             
[Training] Epoch: 1, LossG: 0.8646                                                                           
[Validation] Epoch: 1 [DONE]                                                             
[V

KeyboardInterrupt: 

In [8]:
# runEvaluation(22)
runTesting(BestEpoch=22)
# runTesting(modelName='Model_loss_0_0190')


----------------------------------------
~~~~~~~~  Starting the testing... ~~~~~~
----------------------------------------
~~~~~~~~~~~ Starting the testing ~~~~~~~~~~
[Inference] Segmentation Done !                                                                              
###                                                       ###
###         TEST RESULTS                                  ###
###  Dice : c1: 0.5687 (0.3683) c2: 0.7404 (0.2356) c3: 0.8610 (0.2595) Mean: 0.7233 (0.2878) ###
###  HD   : c1: 23.9789 (19.4527) c2: 6.8428 (10.6456) c3: 4.5790 (8.2024) Mean: 11.8002 (12.7669) ###
###  ASD  : c1: 7.3617 (7.6343) c2: 2.2034 (4.3106) c3: 1.4603 (3.1765) Mean: 3.6751 (5.0405) ###
###                                                       ###


In [None]:
#For all ground truth images in the training and val sets, compute class distribution
root_dir = './Data/'
transform = transforms.Compose([
    transforms.ToTensor()
])
mask_transform = transforms.Compose([
    transforms.ToTensor()
])
train_set = medicalDataLoader.MedicalImageDataset('train',
                                                  root_dir,
                                                  transform=transform,
                                                  mask_transform=mask_transform,
                                                  equalize=False)
train_loader = DataLoader(train_set,
                            batch_size=1,
                            worker_init_fn=np.random.seed(0),
                            num_workers=0,
                            shuffle=False)
val_set = medicalDataLoader.MedicalImageDataset('val',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)
val_loader = DataLoader(val_set,
                        batch_size=1,
                        worker_init_fn=np.random.seed(0),
                        num_workers=0,
                        shuffle=False)

class_distribution = np.zeros(4)
for i, data in enumerate(train_loader):
    images, labels, img_names = data
    labels = to_var(labels)
    images = to_var(images)
    segmentation_classes = getTargetSegmentation(labels)
    for j in range(4):
        class_distribution[j] += torch.sum(segmentation_classes == j).item()

for i, data in enumerate(val_loader):
    images, labels, img_names = data
    labels = to_var(labels)
    images = to_var(images)
    segmentation_classes = getTargetSegmentation(labels)
    for j in range(4):
        class_distribution[j] += torch.sum(segmentation_classes == j).item()

class_distribution = class_distribution / np.sum(class_distribution)

print('Class distribution: {}'.format(class_distribution))
