In [None]:
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchmetrics import Dice

from progressBar import printProgressBar

import medicalDataLoader
import argparse
from utils import *
from losses import *
from UNet_Base import *
import random
import torch
import pdb
import cv2

import matplotlib.pyplot as plt

from UNet_Attention import *
import copy

In [None]:
import warnings

# Pin every random number generator this notebook touches (NumPy, the stdlib
# `random` module and PyTorch) to the same seed so runs are repeatable.
np.random.seed(420)
random.seed(420)
torch.manual_seed(420)

# Suppress all Python warnings for the remainder of the session.
warnings.filterwarnings("ignore")

In [None]:
def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up schedule from https://arxiv.org/abs/1610.02242.

    Args:
        current: Current position in the schedule (e.g. the epoch index).
            Values outside ``[0, rampup_length]`` are clipped to that range.
        rampup_length: Length of the ramp-up. ``0`` disables the ramp.

    Returns:
        float: ``exp(-5 * (1 - current / rampup_length) ** 2)``, which grows
        from ``exp(-5)`` at the start to ``1.0`` at the end of the ramp.
        Always ``1.0`` when ``rampup_length`` is ``0``.
    """
    # No ramp requested: full weight right away (also avoids dividing by zero).
    if rampup_length == 0:
        return 1.0

    progress = np.clip(current, 0.0, rampup_length) / rampup_length
    remaining = 1.0 - progress
    return float(np.exp(-5.0 * remaining * remaining))

def linear_rampup(current, rampup_length):
    """Linear ramp-up schedule.

    Args:
        current: Current position in the schedule; must be non-negative.
        rampup_length: Length of the ramp; must be non-negative.

    Returns:
        ``current / rampup_length`` while ``current`` is below
        ``rampup_length``, and ``1.0`` from then on.

    Raises:
        AssertionError: If either argument is negative.
    """
    assert current >= 0 and rampup_length >= 0
    # `current >= rampup_length` also covers rampup_length == 0, so the
    # division below never sees a zero denominator.
    return 1.0 if current >= rampup_length else current / rampup_length


def get_current_consistency_weight(epoch,consistency,consistency_rampup):
    """Return the consistency-loss weight to use at the given epoch.

    Consistency ramp-up from https://arxiv.org/abs/1610.02242.

    Args:
        epoch: Current epoch index.
        consistency: Weight reached once the ramp-up has finished.
        consistency_rampup: Ramp-up length, in the same unit as ``epoch``.

    Returns:
        ``consistency`` scaled by ``sigmoid_rampup(epoch, consistency_rampup)``.
    """
    ramp_factor = sigmoid_rampup(epoch, consistency_rampup)
    return consistency * ramp_factor


def update_ema_variables(model, ema_model, alpha, global_step):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Each parameter is updated as ``ema = alpha * ema + (1 - alpha) * param``.

    Only tensors yielded by ``parameters()`` are averaged; module buffers are
    not touched by this function.

    Args:
        model: Module whose current parameters are the new observations.
        ema_model: Module holding the running average; modified in place.
            Its parameters are paired with those of ``model`` by position.
        alpha: Upper bound on the EMA decay factor.
        global_step: Number of updates performed so far (0 on the first call).
    """
    # Use the true average until the exponential average is more correct:
    # at step 0 the effective alpha is 0, so the EMA simply copies the model.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    # no_grad lets the in-place update run on leaf tensors without autograd
    # tracking, replacing the legacy `.data` access.
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            # `add_(other, alpha=...)` is the supported signature; the old
            # positional form `add_(scalar, tensor)` is deprecated.
            ema_param.mul_(alpha).add_(param.detach(), alpha=1 - alpha)

def create_model(num_classes,ema=False,n1=8):
    """Build a ``UNet`` and optionally prepare it for use as an EMA copy.

    Args:
        num_classes: Forwarded to ``UNet`` as ``num_classes``.
        ema: When true, every parameter is detached in place so that it is
            no longer part of the autograd graph.
        n1: Forwarded to ``UNet`` as ``n1``.

    Returns:
        The newly constructed ``UNet`` instance.
    """
    # Network definition
    network = UNet(num_classes=num_classes,n1=n1)
    if not ema:
        return network

    # The EMA copy is updated by hand, never by backprop.
    for weight in network.parameters():
        weight.detach_()
    return network

In [None]:
import itertools
from torch.utils.data.sampler import Sampler

class TwoStreamBatchSampler(Sampler):
    """Batch sampler that draws every batch from two index pools.

    Each batch is the concatenation of ``primary_batch_size`` indices from
    the primary pool followed by ``secondary_batch_size`` indices from the
    secondary pool. An 'epoch' is one pass through the primary indices;
    the secondary indices are cycled through (reshuffled on every pass) as
    many times as needed during that epoch.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        """Store both pools and split ``batch_size`` between them.

        Args:
            primary_indices: Indices that define the epoch length.
            secondary_indices: Indices that are repeated as needed.
            batch_size: Total number of indices per batch.
            secondary_batch_size: How many of those come from the
                secondary pool; the rest come from the primary pool.

        Raises:
            AssertionError: If either per-pool batch size is not positive
                or is larger than its pool.
        """
        self.primary_batch_size = batch_size - secondary_batch_size
        self.secondary_batch_size = secondary_batch_size
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        """Return an iterator over batches (tuples of indices)."""
        # The primary pool is shuffled once; incomplete trailing chunks are
        # dropped by `grouper`.
        primary_chunks = grouper(
            iterate_once(self.primary_indices), self.primary_batch_size)
        # The secondary pool is an endless stream, so `zip` stops as soon as
        # the primary chunks run out.
        secondary_chunks = grouper(
            iterate_eternally(self.secondary_indices), self.secondary_batch_size)
        return (
            primary_part + secondary_part
            for primary_part, secondary_part in zip(primary_chunks, secondary_chunks)
        )

    def __len__(self):
        """Number of full batches in one pass over the primary indices."""
        return len(self.primary_indices) // self.primary_batch_size


def iterate_once(iterable):
    """Return one random permutation of ``iterable`` as a NumPy array.

    Uses NumPy's global random state.
    """
    shuffled = np.random.permutation(iterable)
    return shuffled


def iterate_eternally(indices):
    """Return an endless iterator over ``indices``.

    The indices are emitted one full random permutation after another; a new
    permutation is drawn (from NumPy's global random state) only when the
    previous one has been exhausted.
    """
    # Lazily produce one fresh shuffle per pass, forever.
    endless_passes = (np.random.permutation(indices) for _ in itertools.count())
    return itertools.chain.from_iterable(endless_passes)


def grouper(iterable, n):
    """Collect data into fixed-length chunks; a short final chunk is dropped.

    Example: ``grouper('ABCDEFG', 3)`` yields ``('A', 'B', 'C')`` and
    ``('D', 'E', 'F')``.
    """
    # Every slot of the zip pulls from the SAME iterator, so each output
    # tuple consumes n consecutive items.
    shared_iterator = iter(iterable)
    slots = [shared_iterator for _ in range(n)]
    return zip(*slots)

In [None]:
def runTraining():
    """Train a UNet with a mean-teacher scheme using CE + Dice supervision.

    Every training batch holds ``labeled_batch_size`` labeled samples followed
    by unlabeled ones. The labeled part is supervised with
    ``0.5 * (cross-entropy + dice)``; from epoch 20 onwards the unlabeled part
    additionally gets a mean-squared consistency loss between the student's
    softmax output and the softmax output of an EMA copy of the student that
    sees a noisy version of the same images.

    Side effects:
        * Saves the student's ``state_dict`` under ``./models/ssl/UNet_Model/``
          each time the mean validation cross-entropy improves.
        * Rewrites ``TrainLosses.npy`` and ``ValLosses.npy`` under
          ``Results/Statistics/ssl/UNet_Model`` after every epoch.
        * Reseeds NumPy's global random state with 0 before building the
          data loaders.
    """
    batch_size = 16
    labeled_batch_size = 4
    batch_size_val = 8
    lr = 0.001     # Learning Rate
    weight_decay = 1e-5
    epochs = 200 # Number of epochs
    consistency = 0.1
    consistency_rampup = 200
    consistency_start_epoch = 20  # consistency loss is forced to 0 before this epoch
    ema_decay = 0.99

    root_dir = './Data/'
    train_img_path = os.path.join(root_dir, 'train', 'Img')
    unlabeled_img_path = os.path.join(root_dir, 'train', 'Img-Unlabeled')

    # Labeled samples are addressed as [0, num_labeled) and unlabeled samples
    # as the range right after it.
    # NOTE(review): this assumes the 'semi' dataset lists the labeled images
    # first, then the unlabeled ones - confirm against medicalDataLoader.
    num_labeled = len(os.listdir(train_img_path))
    num_unlabeled = len(os.listdir(unlabeled_img_path))
    labeled_idx = list(range(0, num_labeled))
    unlabeled_idxs = list(range(num_labeled, num_labeled + num_unlabeled))
    batch_sampler = TwoStreamBatchSampler(
            labeled_idx, unlabeled_idxs, batch_size,batch_size-labeled_batch_size)

    ## DEFINE THE TRANSFORMATIONS TO DO AND THE VARIABLES FOR TRAINING AND VALIDATION
    transform = transforms.Compose([
    transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
    transforms.ToTensor()
    ])

    train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                           'semi',
                                                            root_dir,
                                                            transform=transform,
                                                            mask_transform=mask_transform,
                                                            augment=False,
                                                            equalize=False)

    # The previous code passed `worker_init_fn=np.random.seed(0)`, which calls
    # the seeding function immediately and hands `None` to the DataLoader.
    # The reseed is kept (now explicit) so the sampling order is unchanged;
    # no worker_init_fn is needed because num_workers is 0.
    np.random.seed(0)
    train_loader_full = DataLoader(train_set_full,
                            batch_sampler=batch_sampler,
                            num_workers=0)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    'semi',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

    np.random.seed(0)
    val_loader = DataLoader(val_set,
                        batch_size=batch_size_val,
                        num_workers=0,
                        shuffle=False)

    ## INITIALIZE YOUR MODEL
    num_classes = 4 # NUMBER OF CLASSES

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName = 'UNet_Model'
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    net = create_model(num_classes,n1=16)                 # student, trained by backprop
    ema_model = create_model(num_classes,ema=True,n1=16)  # teacher, updated as an EMA of the student

    print("Total params: {0:,}".format(sum(p.numel() for p in net.parameters() if p.requires_grad)))

    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    softMax = torch.nn.Softmax(dim=1)
    CE_loss = torch.nn.CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    ## PUT EVERYTHING IN GPU RESOURCES
    if torch.cuda.is_available():
        net.cuda()
        ema_model.cuda()
        softMax.cuda()
        CE_loss.cuda()
        dice_loss.cuda()

    ## DEFINE YOUR OPTIMIZER
    optimizer = torch.optim.Adam(params=net.parameters(),lr=lr,weight_decay=weight_decay)

    ### To save statistics ####
    lossTotalTraining = []
    lossTotalValidation = []
    Best_loss_val = float('inf')
    BestEpoch = 0

    directory = 'Results/Statistics/ssl/' + modelName
    model_dir = './models/ssl/' + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    # Create both output folders once, up front, instead of checking every epoch.
    os.makedirs(directory, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    ## START THE TRAINING
    num_iter = 0  # global step counter used by the EMA warm-up
    ## FOR EACH EPOCH
    for epoch in range(epochs):
        lossEpoch = []
        vlossEpoch = []
        DSCEpoch = []
        vDSCEpoch = []
        ConsistEpoch = []
        num_batches = len(train_loader_full)
        v_num_batches = len(val_loader)

        ########## Training ##########
        net.train(True)
        ## FOR EACH BATCH
        for j, data in enumerate(train_loader_full):
            ### Set to zero all the gradients (the optimizer owns every net parameter)
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, _ = data
            # Samples after the first `labeled_batch_size` are the unlabeled ones.
            unlabeled_batch = images[labeled_batch_size:]
            # The teacher sees the unlabeled images with clipped Gaussian noise added.
            noise = torch.clamp(torch.randn_like(
                unlabeled_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_batch + noise
            ### From numpy to torch variables
            labels = to_var(labels)
            images = to_var(images)
            ema_inputs = to_var(ema_inputs)

            ################### Train ###################
            #-- The CNN makes its predictions (forward pass)
            net_predictions = net(images)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
                ema_output_soft = torch.softmax(ema_output, dim=1)

            #-- Compute the losses --#
            # THIS FUNCTION IS TO CONVERT LABELS TO A FORMAT TO BE USED IN THIS CODE
            segmentation_classes = getTargetSegmentation(labels)
            labeled_logits = net_predictions[:labeled_batch_size]
            labeled_targets = segmentation_classes[:labeled_batch_size]
            # CrossEntropyLoss applies log-softmax internally, so it must get
            # the raw logits; feeding it softmax output (as before) applies
            # softmax twice and flattens the gradients.
            CE_loss_value = CE_loss(labeled_logits, labeled_targets)
            dice_loss_value = dice_loss(softMax(labeled_logits), labeled_targets.unsqueeze(1))

            # consistency loss
            consistency_weight = get_current_consistency_weight(epoch,consistency,consistency_rampup)
            if epoch < consistency_start_epoch:
                # Keep the placeholder on the same device as the other losses.
                consistency_loss = torch.zeros((), device=net_predictions.device)
            else:
                consistency_loss = torch.mean(
                (softMax(net_predictions[labeled_batch_size:])-ema_output_soft)**2)

            lossTotal = 0.5*(CE_loss_value + dice_loss_value) + consistency_weight*consistency_loss
            # DO THE STEPS FOR BACKPROP (two things to be done in pytorch)
            lossTotal.backward()
            optimizer.step()
            update_ema_variables(net, ema_model, ema_decay, num_iter)
            num_iter+=1
            # THIS IS JUST TO VISUALIZE THE TRAINING
            CE_value = CE_loss_value.item()
            dice_value = dice_loss_value.item()
            consistency_value = consistency_loss.item()
            lossEpoch.append(CE_value)
            DSCEpoch.append(dice_value)
            ConsistEpoch.append(consistency_value)
            printProgressBar(j + 1, num_batches,
                                prefix="[Training] Epoch: {} ".format(epoch),
                                length=15,
                                suffix=" CE_Loss: {:.4f}, dice_loss:  {:.4f}, consist_loss:  {:.4f}".format(CE_value,dice_value,consistency_value))

        lossEpoch = np.asarray(lossEpoch).mean()
        DSCEpoch = np.asarray(DSCEpoch).mean()
        ConsistEpoch = np.asarray(ConsistEpoch).mean()
        # Logged value is the plain sum of the three epoch means (not the
        # weighted sum that was optimised).
        lossTotalTraining.append(lossEpoch+DSCEpoch+ConsistEpoch)
        printProgressBar(num_batches, num_batches,
                                done="[Training] Epoch: {}, LossG: {:.4f}".format(epoch,lossEpoch+DSCEpoch+ConsistEpoch))

        ######### Validation ############
        net.train(False)
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for j, vdata in enumerate(val_loader):
                vimages, vlabels, _ = vdata
                vlabels = to_var(vlabels)
                vimages = to_var(vimages)
                voutputs = net(vimages)
                vsegmentation_classes = getTargetSegmentation(vlabels)
                # Raw logits for cross-entropy, probabilities for dice (see training loop).
                vloss = CE_loss(voutputs, vsegmentation_classes)
                vdice = dice_loss(softMax(voutputs), vsegmentation_classes.unsqueeze(1))
                vloss_value = vloss.item()
                vdice_value = vdice.item()
                vlossEpoch.append(vloss_value)
                vDSCEpoch.append(vdice_value)
                # Progress is measured against the number of VALIDATION batches.
                printProgressBar(j + 1, v_num_batches,
                                    prefix="[Validation] Epoch: {} ".format(epoch),
                                    length=15,
                                    suffix=" CE_val_Loss: {:.4f}, dice_val_loss: {:.4f}".format(vloss_value,vdice_value))

        vlossEpoch = np.asarray(vlossEpoch).mean()
        vDSCEpoch = np.asarray(vDSCEpoch).mean()
        lossTotalValidation.append(vlossEpoch+vDSCEpoch)
        printProgressBar(v_num_batches, v_num_batches,
                                done="[Validation] Epoch: {}, val_LossG: {:.4f}".format(epoch,vlossEpoch+vDSCEpoch))

        ## Save the model only when it works best on the validation set.
        ## Model selection uses the mean validation cross-entropy alone.
        if vlossEpoch < Best_loss_val:
            Best_loss_val = vlossEpoch
            BestEpoch = epoch
            torch.save(net.state_dict(), model_dir + '/'  + str(BestEpoch)+'_Epoch')

        np.save(os.path.join(directory, 'TrainLosses.npy'), lossTotalTraining)
        np.save(os.path.join(directory, 'ValLosses.npy'), lossTotalValidation)

# Launch the CE + Dice mean-teacher training run. This executes the full
# training loop and writes checkpoints and loss curves to disk.
runTraining()


In [None]:
def runTrainingTversky():
    """Train a UNet with a mean-teacher scheme using Tversky supervision.

    Every training batch holds ``labeled_batch_size`` labeled samples followed
    by unlabeled ones. The labeled part is supervised with a Tversky loss
    (``alpha=0.3``, ``beta=0.8``); from epoch 20 onwards the unlabeled part
    additionally gets a mean-squared consistency loss between the student's
    softmax output and the softmax output of an EMA copy of the student that
    sees a noisy version of the same images. The learning rate is managed by
    ``ReduceLROnPlateau``, stepped once per epoch on the training loss.

    Side effects:
        * Saves the student's ``state_dict`` under
          ``./models/ssl/UNet_Model_Tversky_lr/`` each time the mean
          validation loss improves.
        * Rewrites ``TrainLosses.npy`` and ``ValLosses.npy`` under
          ``Results/Statistics/ssl/UNet_Model_Tversky_lr`` after every epoch.
        * Reseeds NumPy's global random state with 0 before building the
          data loaders.
    """
    batch_size = 16
    labeled_batch_size = 4
    batch_size_val = 8
    lr = 0.001     # Learning Rate
    weight_decay = 1e-5
    epochs = 220 # Number of epochs
    consistency = 0.1
    consistency_rampup = 200
    consistency_start_epoch = 20  # consistency loss is forced to 0 before this epoch
    alpha = 0.3# 0.5 0.3 0.8
    beta = 0.8 # 0.5 0.8 0.3
    ema_decay = 0.99

    root_dir = './Data/'
    train_img_path = os.path.join(root_dir, 'train', 'Img')
    unlabeled_img_path = os.path.join(root_dir, 'train', 'Img-Unlabeled')

    # Labeled samples are addressed as [0, num_labeled) and unlabeled samples
    # as the range right after it.
    # NOTE(review): this assumes the 'semi' dataset lists the labeled images
    # first, then the unlabeled ones - confirm against medicalDataLoader.
    num_labeled = len(os.listdir(train_img_path))
    num_unlabeled = len(os.listdir(unlabeled_img_path))
    labeled_idx = list(range(0, num_labeled))
    unlabeled_idxs = list(range(num_labeled, num_labeled + num_unlabeled))
    batch_sampler = TwoStreamBatchSampler(
            labeled_idx, unlabeled_idxs, batch_size,batch_size-labeled_batch_size)

    ## DEFINE THE TRANSFORMATIONS TO DO AND THE VARIABLES FOR TRAINING AND VALIDATION
    transform = transforms.Compose([
    transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
    transforms.ToTensor()
    ])

    train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                           'semi',
                                                            root_dir,
                                                            transform=transform,
                                                            mask_transform=mask_transform,
                                                            augment=False,
                                                            equalize=False)

    # The previous code passed `worker_init_fn=np.random.seed(0)`, which calls
    # the seeding function immediately and hands `None` to the DataLoader.
    # The reseed is kept (now explicit) so the sampling order is unchanged;
    # no worker_init_fn is needed because num_workers is 0.
    np.random.seed(0)
    train_loader_full = DataLoader(train_set_full,
                            batch_sampler=batch_sampler,
                            num_workers=0)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    'semi',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

    np.random.seed(0)
    val_loader = DataLoader(val_set,
                        batch_size=batch_size_val,
                        num_workers=0,
                        shuffle=False)

    ## INITIALIZE YOUR MODEL
    num_classes = 4 # NUMBER OF CLASSES

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName = 'UNet_Model_Tversky_lr'
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    net = create_model(num_classes,n1=16)                 # student, trained by backprop
    ema_model = create_model(num_classes,ema=True,n1=16)  # teacher, updated as an EMA of the student

    print("Total params: {0:,}".format(sum(p.numel() for p in net.parameters() if p.requires_grad)))

    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    softMax = torch.nn.Softmax(dim=1)
    tversky = TverskyLoss(alpha=alpha,beta=beta,n_classes=num_classes)
    ## PUT EVERYTHING IN GPU RESOURCES
    if torch.cuda.is_available():
        net.cuda()
        ema_model.cuda()
        softMax.cuda()
        tversky.cuda()

    ## DEFINE YOUR OPTIMIZER
    optimizer = torch.optim.Adam(params=net.parameters(),lr=lr,weight_decay=weight_decay)
    lr_scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    ### To save statistics ####
    lossTotalTraining = []
    lossTotalValidation = []
    Best_loss_val = float('inf')
    BestEpoch = 0

    directory = 'Results/Statistics/ssl/' + modelName
    model_dir = './models/ssl/' + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    # Create both output folders once, up front, instead of checking every epoch.
    os.makedirs(directory, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    ## START THE TRAINING
    num_iter = 0  # global step counter used by the EMA warm-up
    ## FOR EACH EPOCH
    for epoch in range(epochs):
        lossEpoch = []
        ConsistEpoch = []
        vlossEpoch = []

        num_batches = len(train_loader_full)
        v_num_batches = len(val_loader)

        ########## Training ##########
        net.train(True)
        ## FOR EACH BATCH
        for j, data in enumerate(train_loader_full):
            ### Set to zero all the gradients (the optimizer owns every net parameter)
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, _ = data
            # Samples after the first `labeled_batch_size` are the unlabeled ones.
            unlabeled_batch = images[labeled_batch_size:]
            # The teacher sees the unlabeled images with clipped Gaussian noise added.
            noise = torch.clamp(torch.randn_like(
                unlabeled_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_batch + noise
            ### From numpy to torch variables
            labels = to_var(labels)
            images = to_var(images)
            ema_inputs = to_var(ema_inputs)

            ################### Train ###################
            #-- The CNN makes its predictions (forward pass)
            net_predictions = net(images)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
                ema_output_soft = torch.softmax(ema_output, dim=1)

            #-- Compute the losses --#
            # THIS FUNCTION IS TO CONVERT LABELS TO A FORMAT TO BE USED IN THIS CODE
            segmentation_classes = getTargetSegmentation(labels)
            # COMPUTE THE LOSS (labeled part of the batch only)
            Tversky_loss_value = tversky(softMax(net_predictions[:labeled_batch_size]),segmentation_classes[:labeled_batch_size])
            # consistency loss
            consistency_weight = get_current_consistency_weight(epoch,consistency,consistency_rampup)
            if epoch < consistency_start_epoch:
                # Keep the placeholder on the same device as the other losses.
                consistency_loss = torch.zeros((), device=net_predictions.device)
            else:
                consistency_loss = torch.mean(
                (softMax(net_predictions[labeled_batch_size:])-ema_output_soft)**2)
            lossTotal =  Tversky_loss_value + consistency_weight*consistency_loss
            # DO THE STEPS FOR BACKPROP (two things to be done in pytorch)
            lossTotal.backward()
            optimizer.step()
            update_ema_variables(net, ema_model, ema_decay, num_iter)
            num_iter+=1
            # THIS IS JUST TO VISUALIZE THE TRAINING
            tversky_value = Tversky_loss_value.item()
            consistency_value = consistency_loss.item()
            lossEpoch.append(tversky_value)
            ConsistEpoch.append(consistency_value)
            printProgressBar(j + 1, num_batches,
                                prefix="[Training] Epoch: {} ".format(epoch),
                                length=15,
                                suffix=" Tversky_Loss: {:.4f}, Conssistency_loss: {:.4f}".format(tversky_value,consistency_value))

        lossEpoch = np.asarray(lossEpoch).mean()
        ConsistEpoch = np.asarray(ConsistEpoch).mean()
        # Logged value is the plain sum of the two epoch means (not the
        # weighted sum that was optimised).
        lossTotalMean = lossEpoch+ConsistEpoch
        lossTotalTraining.append(lossTotalMean)
        # The plateau scheduler monitors the TRAINING loss of the epoch.
        lr_scheduler.step(lossTotalMean)
        printProgressBar(num_batches, num_batches,
                                done="[Training] Epoch: {}, LossG: {:.4f}, lr: {:.8f} ".format(epoch,lossTotalMean,optimizer.param_groups[0]["lr"]))

        ######### Validation ############
        net.train(False)
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for j, vdata in enumerate(val_loader):
                vimages, vlabels, _ = vdata
                vlabels = to_var(vlabels)
                vimages = to_var(vimages)
                voutputs = net(vimages)
                vsegmentation_classes = getTargetSegmentation(vlabels)
                vloss = tversky(softMax(voutputs), vsegmentation_classes)
                vloss_value = vloss.item()
                vlossEpoch.append(vloss_value)
                # Progress is measured against the number of VALIDATION batches.
                printProgressBar(j + 1, v_num_batches,
                                    prefix="[Validation] Epoch: {} ".format(epoch),
                                    length=15,
                                    suffix=" Tversky_val_Loss: {:.4f}".format(vloss_value))

        vlossEpoch = np.asarray(vlossEpoch).mean()
        lossTotalValidation.append(vlossEpoch)
        printProgressBar(v_num_batches, v_num_batches,
                                done="[Validation] Epoch: {}, val_LossG: {:.4f}".format(epoch,vlossEpoch))

        ## Save the model only when it works best on the validation set.
        if vlossEpoch < Best_loss_val:
            Best_loss_val = vlossEpoch
            BestEpoch = epoch
            torch.save(net.state_dict(), model_dir + '/'  + str(BestEpoch)+'_Epoch')

        np.save(os.path.join(directory, 'TrainLosses.npy'), lossTotalTraining)
        np.save(os.path.join(directory, 'ValLosses.npy'), lossTotalValidation)

# Launch the Tversky-loss mean-teacher training run. This executes the full
# training loop and writes checkpoints and loss curves to disk.
runTrainingTversky()


In [None]:
# Rebuild the network with the same constructor arguments the training
# functions use (num_classes=4, n1=16) so the saved weights fit.
# NOTE(review): the earlier comment here said "attention unet", but this
# constructs `UNet`, the class `create_model` trains; `AttU_Net` only appears
# in commented-out code.
net = UNet(num_classes=4,n1=16)
# Forward slashes work on every OS (the previous backslash path was
# Windows-only). `map_location='cpu'` lets a checkpoint saved from a GPU run
# load on a CPU-only machine; `net` is kept on the CPU in this notebook.
# NOTE(review): this reads from 'UNet_Model_Tversky', while
# runTrainingTversky writes to 'UNet_Model_Tversky_lr' - confirm that this
# checkpoint exists.
net.load_state_dict(torch.load('models/ssl/UNet_Model_Tversky/145_Epoch', map_location='cpu'))
net.eval()

In [None]:
transform = transforms.Compose([
    transforms.ToTensor()
    ])

test_image_path = "Data/val/Img/patient001_12_1.png"
test_label_path = "Data/val/GT/patient001_12_1.png"

# Flag 0 makes OpenCV load the files as single-channel grayscale images.
test_image = cv2.imread(test_image_path,0)
test_image_label = cv2.imread(test_label_path,0)

# cv2.imread returns None instead of raising when a file cannot be read;
# fail here with a clear message rather than deep inside the transform.
if test_image is None:
    raise FileNotFoundError("Could not read image: {}".format(test_image_path))
if test_image_label is None:
    raise FileNotFoundError("Could not read image: {}".format(test_label_path))

test_image = transform(test_image).float()
# Add a leading batch dimension so the image can be fed to the network.
test_image = test_image.unsqueeze(0)
test_image_label = transform(test_image_label).float()

In [None]:
softMax = torch.nn.Softmax(dim=1)
with torch.no_grad():
    probs = softMax(net(test_image))

# Gray level used to draw each class index in the combined mask.
color_map = {0:0,1:1/3,2:2/3,3:1}
preds = predToSegmentation(probs)

# Size the canvas from the prediction instead of hard-coding 256x256.
# NOTE(review): assumes `predToSegmentation` returns a tensor/ndarray whose
# last two axes are the image height and width - confirm in utils.
seg = torch.zeros(tuple(preds.shape[-2:]))
for i in range(len(preds[0])):
    # Paint every pixel assigned to class i in a single vectorised step
    # (replaces a per-pixel Python loop over the whole image).
    class_mask = torch.as_tensor(preds[0][i] == 1, dtype=torch.bool, device=seg.device)
    seg[class_mask] = color_map[i]

In [None]:
# Show the input image, its ground-truth mask and the predicted mask side by side.
fig,axs = plt.subplots(nrows=1,ncols=3,figsize=(15,15))
# squeeze(0) drops the batch dimension added when the image was loaded;
# [0] then selects the first channel.
axs[0].imshow(test_image.squeeze(0)[0],cmap='gray')
axs[0].set_title('Image')
axs[1].imshow(test_image_label[0],cmap='gray')
axs[1].set_title('GT')
# `seg` holds one gray level per predicted class (see color_map).
axs[2].imshow(seg,cmap='gray')
axs[2].set_title('prediction mask')

In [None]:
# preds[0][1:] passes every class channel of the first sample except channel 0.
# NOTE(review): presumably channel 0 is background and computeHD returns a
# Hausdorff distance - confirm against the helper's definition in utils.
HD = computeHD(preds[0][1:],test_image_label[0])
HD

In [None]:
# preds[0][1:] passes every class channel of the first sample except channel 0.
# NOTE(review): presumably channel 0 is background and computeDSC returns a
# Dice similarity coefficient - confirm against the helper's definition in utils.
DSC = computeDSC(preds[0][1:],test_image_label[0])
DSC

In [None]:
# Per-epoch loss histories written by runTrainingTversky. Forward slashes
# match the paths used when the files are saved and work on every OS (the
# previous backslash paths were Windows-only).
train_loss = np.load('Results/Statistics/ssl/UNet_Model_Tversky_lr/TrainLosses.npy')
val_loss = np.load('Results/Statistics/ssl/UNet_Model_Tversky_lr/ValLosses.npy')
plt.plot(train_loss,label='train loss')
plt.plot(val_loss,label='val loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()