In [1]:
import os

script_path = ""
try:
    os.path.dirname(os.path.abspath(__file__))
except NameError:
    for root, dirs, files in os.walk(os.getcwd()):
        # Skip 'data' directory and its subdirectories
        if "Data" in dirs:
            dirs.remove("Data")

        if "mainSegmentationChallenge.ipynb" in files:
            script_path = root
            break

if script_path == "":
    raise FileNotFoundError(
        "There is a problem in the folder structure.\nCONTACT gheith.abinader@icloud.com (514)699-5611"
    )

os.chdir(script_path)

print("Current Working Directory: ", os.getcwd())

from torch.utils.data import DataLoader
from torchvision import transforms
from progressBar import printProgressBar

import medicalDataLoader
import argparse
from utils import *

from UNet_Base import *
import random
import torch
import pdb

import numpy as np
from torch.nn.modules.loss import CrossEntropyLoss
import torch.optim as optim
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.utils import save_image


Current Working Directory:  c:\Users\gheith\OneDrive - ETS\0 2023 MTI 865 - Apprentissage profind pour la vision par ordinateur\CleanedGithub\MTI865-Competition\Gheith\Cross_Teaching_CNN_CNN_plus


In [2]:
import warnings

warnings.filterwarnings("ignore")

In [3]:
# put outside of the function for pickeling
def worker_init_fn(worker_id):
    random.seed(1208 + worker_id)


def runTraining():
    print("-" * 40)
    print("~~~~~~~~  Starting the training... ~~~~~~")
    print("-" * 40)

    ## DEFINE HYPERPARAMETERS (batch_size > 1)
    batch_size = 16
    secondaty_batch_size = 8
    batch_size_val = 24
    base_lr = 0.01  # Learning Rate
    max_iterations = 30000

    ## DEFINE THE TRANSFORMATIONS TO DO AND THE VARIABLES FOR TRAINING AND VALIDATION

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set_full = medicalDataLoader.MedicalImageDataset(
        "train",
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False,
    )

    total_slices = len(train_set_full)
    labeled_slice = 204
    print(
        "Total silices is: {}, labeled slices is: {}".format(
            total_slices, labeled_slice
        )
    )
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = medicalDataLoader.TwoStreamBatchSampler(
        labeled_idxs, unlabeled_idxs, batch_size, secondaty_batch_size
    )
    trainloader = DataLoader(train_set_full, batch_sampler=batch_sampler, num_workers=0)

    val_set = medicalDataLoader.MedicalImageDataset(
        "val", transform=transform, mask_transform=mask_transform, equalize=False
    )

    val_loader = DataLoader(
        val_set,
        batch_size=batch_size_val,
        worker_init_fn=np.random.seed(0),
        num_workers=0,
        shuffle=False,
    )

    ## INITIALIZE YOUR MODEL
    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName1, modelName2 = "K3", "K5"
    print(" Model Name1: {}".format(modelName1))
    print(" Model Name2: {}".format(modelName2))

    # ## CREATION OF YOUR MODEL
    UEncK3 = UNetEncoderK3()
    UDecK3 = UNetDecoderK3()
    UEncK5 = UNetEncoderK5()
    UDecK5 = UNetDecoderK5()

    print(
        "Total params: {0:,}".format(
            sum(
                p.numel()
                for p in list(UEncK3.parameters())
                + list(UDecK3.parameters())
                + list(UEncK5.parameters())
                + list(UDecK5.parameters())
                if p.requires_grad
            )
        )
    )

    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(4)
    softMax = torch.nn.Softmax()

    if torch.cuda.is_available():
        UEncK3.cuda()
        UDecK3.cuda()
        UEncK5.cuda()
        UDecK5.cuda()
        ce_loss.cuda()
        dice_loss.cuda()

    ## DEFINE YOUR OPTIMIZER
    optimizerK3 = optim.SGD(
        list(UEncK3.parameters()) + list(UDecK3.parameters()),
        lr=base_lr,
        momentum=0.9,
        weight_decay=0.0001,
    )
    optimizerK5 = optim.SGD(
        list(UEncK3.parameters()) + list(UDecK3.parameters()),
        lr=base_lr,
        momentum=0.9,
        weight_decay=0.0001,
    )
    lr_schedulerK3 = PolynomialLR(
        optimizerK3,
        total_iters=max_iterations,  # The number of steps that the scheduler decays the learning rate.
        power=1,
    )  # The power of the polynomial.
    lr_schedulerK5 = PolynomialLR(
        optimizerK5,
        total_iters=max_iterations,  # The number of steps that the scheduler decays the learning rate.
        power=1,
    )  # The power of the polynomial.

    ### To save statistics ####
    lossTotalTraining = []
    lossTotalValidation1 = []
    lossTotalValidation2 = []
    Best_loss_val1 = 1000
    Best_loss_val2 = 1000
    BestEpoch = 0
    performance1 = 0
    performance2 = 0

    directory = "Results/Statistics/" + "CrossTeachingK3K5"

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    if os.path.exists(directory) == False:
        os.makedirs(directory)

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    print("{} iterations per epoch".format(len(trainloader)))
    ## START THE TRAINING
    ## FOR EACH EPOCH
    for epoch_num in range(max_epoch):
        UEncK3.train()
        UDecK3.train()
        UEncK5.train()
        UDecK5.train()
        lossEpoch = []
        num_batches = len(trainloader)
        ## FOR EACH BATCH
        for i_batch, sampled_batch in enumerate(trainloader):
            ### Set to zero all the gradients
            optimizerK3.zero_grad()
            optimizerK5.zero_grad()

            ## GET imagesv, LABELS and IMG NAMES
            volume_batch, label_batch = sampled_batch["image"], sampled_batch["label"]
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            ################### Train ###################
            # -- The CNN makes its predictions (forward pass)
            featuresK3 = UEncK3(volume_batch)
            outK3 = UDecK3(featuresK3)
            outK3_soft = torch.softmax(outK3, dim=1)
            ##
            featuresK5 = UEncK5(volume_batch)
            outK5 = UDecK5(featuresK5)
            outK5_soft = torch.softmax(outK5, dim=1)

            # COMPUTE THE LOSS #adapted from https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py
            superv_ce_lossK3 = ce_loss(outK3[:8], label_batch.squeeze(1)[:8].long())
            superv_dice_lossK3 = dice_loss(outK3_soft[:8], label_batch[:8])
            supervised_lossK3 = 0.5 * (superv_ce_lossK3 + superv_dice_lossK3)
            ##
            superv_ce_lossK5 = ce_loss(outK5[:8], label_batch[:8].squeeze(1).long())
            superv_dice_lossK5 = dice_loss(outK5_soft[:8], label_batch[:8])
            supervised_lossK5 = 0.5 * (superv_ce_lossK5 + superv_dice_lossK5)

            pseudo_lblK3 = torch.argmax(outK3_soft[8:].detach(), dim=1, keepdim=False)
            pseudo_lblK5 = torch.argmax(outK5_soft[8:].detach(), dim=1, keepdim=False)

            pseudo_suprv_lossK3 = dice_loss(outK3_soft[8:], pseudo_lblK3.unsqueeze(1))
            pseudo_suprv_lossK5 = dice_loss(outK5_soft[8:], pseudo_lblK5.unsqueeze(1))

            consistency_weight = get_current_consistency_weight(iter_num // 150)

            K3Loss = supervised_lossK3 + consistency_weight * pseudo_suprv_lossK3
            K5Loss = supervised_lossK5 + consistency_weight * pseudo_suprv_lossK5

            loss = K3Loss + K5Loss
            # DO THE STEPS FOR BACKPROP (two things to be done in pytorch)
            loss.backward()

            optimizerK3.step()
            optimizerK5.step()
            lr_schedulerK3.step()
            lr_schedulerK5.step()

            iter_num = iter_num + 1

            # THIS IS JUST TO VISUALIZE THE TRAINING
            lossEpoch.append(loss.cpu().data.numpy())
            printProgressBar(
                i_batch + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch_num),
                length=15,
                suffix=" Loss: {:.4f}, ".format(loss),
            )
            #from the https://github.com/HiLab-git/SSL4MIS/val_2d.py
            # EVALUATE BOTH MODELS AND PICK THE BEST ONE
            if i_batch>0 and i_batch%24 == 0:
                UEncK3.eval()
                UDecK3.eval()
                UEncK5.eval()
                UDecK5.eval()
                metric_list_accum1 = 0.0
                metric_list_accum2 = 0.0
                total_batches = len(val_loader)    
                print("{} iterations per validaion epoch".format(total_batches))
                random_batch_index = random.randint(0, total_batches - 1)
                for i_batch2, sampled_batchv in enumerate(val_loader):
                    volume_batchv, label_batchv = sampled_batchv["image"], sampled_batchv["label"]
                    imagesv, labelsv = volume_batchv.squeeze(1).cpu().detach().numpy(), label_batchv.squeeze(1).cpu().detach()
                    gt_classes = getTargetSegmentation(labelsv).numpy()
                    labelsv = labelsv.numpy()
                    prediction1 = np.zeros_like(labelsv)
                    prediction2 = np.zeros_like(labelsv)
                    for ind in range(imagesv.shape[0]):
                        slice = imagesv[ind, :, :]
                        x, y = slice.shape[0], slice.shape[1]
                        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
                        with torch.no_grad():
                            featuresK3 = UEncK3(input)
                            outK3 = UDecK3(featuresK3)
                            out1 = torch.argmax(torch.softmax(outK3, dim=1), dim=1).squeeze(0)
                            pred1 = out1.cpu().detach().numpy()
                            prediction1[ind] = pred1
                            featuresK5 = UEncK5(input)
                            outK5 = UDecK5(featuresK5)
                            out2 = torch.argmax(torch.softmax(outK5, dim=1), dim=1).squeeze(0)
                            pred2 = out2.cpu().detach().numpy()
                            prediction2[ind] = pred2
                    metric_list1=[]
                    metric_list2=[]
                    for i in range(1, 4):
                        pred1 = np.where(prediction1 == i, 1, 0)
                        pred2 = np.where(prediction2 == i, 1, 0)
                        gt = np.where(gt_classes == i, 1, 0)
                        metric_list1.append(calculate_metric_percase(pred1, gt))
                        metric_list2.append(calculate_metric_percase(pred2, gt))
                    metric_list_accum1 += np.array(metric_list1)
                    metric_list_accum2 += np.array(metric_list2)
                    #SAVE IMAGE
                    if i_batch%1500 == 0 and ibatch2 == random_batch_index:
                        save_sample_imagesv(imagesv, gt_classes, prediction1, predictio2, os.join('Results', 'Segmentation1And2'))
                        
                    printProgressBar(
                            total_batches,
                            total_batches,
                            done="[Validation] Epoch: {}, Loss1: {:.4f}, Loss2: {:.4f}".format(epoch_num, metric_list_accum1[-1][0], metric_list_accum2[-1][0]),
                        )
                        
                metric_list_accum1 = metric_list_accum1 / len(val_loader)
                metric_list_accum2 = metric_list_accum2 / len(val_loader)
                for class_i in range(3):
                    print('info/model1_val_{}_dice'.format(class_i+1),
                                      metric_list_accum1[class_i, 0], iter_num)
                    print('info/model1_val_{}_hd95'.format(class_i+1),
                                      metric_list_accum1[class_i, 1], iter_num)
                    # print('info/model1_val_{}_asd'.format(class_i+1),
                    #                   metric_list_accum1[class_i, 2], iter_num)
                    print('info/model2_val_{}_dice'.format(class_i+1),
                                      metric_list_accum2[class_i, 0], iter_num)
                    print('info/model2_val_{}_hd95'.format(class_i+1),
                                      metric_list_accum2[class_i, 1], iter_num)
                    # print('info/model2_val_{}_asd'.format(class_i+1),
                    #                   metric_list_accum2[class_i, 2], iter_num)
                
                performance1 = np.mean(metric_list_accum1, axis=0)[0]
                performance2 = np.mean(metric_list_accum2, axis=0)[0]

                # mean_hd951 = np.mean(metric_list_accum1, axis=0)[1]
                # mean_hd952 = np.mean(metric_list_accum2, axis=0)[1]
                
                # mean_asd = np.mean(metric_list_accum1, axis=0)[2]
                # mean_asd = np.mean(metric_list_accum2, axis=0)[2]
                
                print('info/model1_val_mean_dice', performance1, iter_num)
                print('info/model2_val_mean_dice', performance2, iter_num)
                # print('info/model1_val_mean_hd95', mean_hd951, iter_num)
                # print('info/model2_val_mean_hd95', mean_hd952, iter_num)
                # print('info/model1_val_mean_asd', mean_asd, iter_num)
                # print('info/model2_val_mean_asd', mean_asd, iter_num)
                # EXIT
                UEncK3.train()
                UDecK3.train()
                UEncK5.train()
                UDecK5.train()
                

        lossEpoch = np.asarray(lossEpoch)
        lossEpoch = lossEpoch.mean()

        lossTotalTraining.append(lossEpoch)

        printProgressBar(
            num_batches,
            num_batches,
            done="[Training] Epoch: {}, LossG: {:.4f}".format(epoch_num, lossEpoch),
        )
        # Save the model if it is the best so far
        if performance1 > Best_loss_val1:
            Best_loss_val1 = performance1
            if not os.path.exists("./models/" + modelName1):
                os.makedirs("./models/" + modelName1)
            torch.save({"ENC": UEncK3.state_dict(), "DEC": UDecK3.state_dict()}, "./models/" + modelName + "/" + str(epoch_num) + "_Epoch")
        if performance2 > Best_loss_val2:
            Best_loss_val2 = performance2
            if not os.path.exists("./models/" + modelName2):
                os.makedirs("./models/" + modelName2)
            torch.save({"ENC": UEncK5.state_dict(), "DEC": UDecK5.state_dict()}, "./models/" + modelName + "/" + str(epoch_num) + "_Epoch")

        lossEpochVal1 = performance1
        lossEpochVal2 = performance2

        lossTotalValidation1.append(lossEpochVal1)
        lossTotalValidation2.append(lossEpochVal2)

    np.save(os.path.join(directory, "Losses.npy"), lossTotalTraining)
    np.save(os.path.join(directory, "Losses_val1.npy"), lossTotalValidation1)
    np.save(os.path.join(directory, "Losses_val2.npy"), lossTotalValidation2)

    print("Training finished. Best model saved at epoch {}".format(BestEpoch))

    graphLosses(lossTotalTraining, lossTotalValidation1, modelName1, directory)
    graphLosses(lossTotalTraining, lossTotalValidation2, modelName2, directory)

    return Best_loss_val1, Best_loss_val2

In [4]:
runTraining()

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
Total silices is: 1208, labeled slices is: 204
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name1: K3
 Model Name2: K5
Total params: 6,766,344
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
25 iterations per epoch


[Training] Epoch: 0 [DONE]                                 
4 iterations per validaion epoch
[Validation] Epoch: 0, Loss1: 0.0000, Loss2: 0.0016                                                          
[Validation] Epoch: 0, Loss1: 0.0000, Loss2: 0.0083                                                          
[Validation] Epoch: 0, Loss1: 0.0000, Loss2: 0.0148                                                          
[Validation] Epoch: 0, Loss1: 0.0000, Loss2: 0.0148                                                          
info/model1_val_1_dice 0.0 25
info/model1_val_1_hd95 0.0 25


IndexError: index 2 is out of bounds for axis 1 with size 2