In [1]:
import os

script_path = ""
try:
    os.path.dirname(os.path.abspath(__file__))
except NameError:
    for root, dirs, files in os.walk(os.getcwd()):
        # Skip 'data' directory and its subdirectories
        if "Data" in dirs:
            dirs.remove("Data")

        if "mainSegmentationChallenge.ipynb" in files:
            script_path = root
            break

if script_path == "":
    raise FileNotFoundError(
        "There is a problem in the folder structure.\nCONTACT gheith.abinader@icloud.com (514)699-5611"
    )

os.chdir(script_path)

print("Current Working Directory: ", os.getcwd())

from torch.utils.data import DataLoader
from torchvision import transforms
from progressBar import printProgressBar

import medicalDataLoader
import argparse
from utils import *

from UNet_Base import *
import random
import torch
import pdb

import numpy as np
from torch.nn.modules.loss import CrossEntropyLoss
import torch.optim as optim
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.utils import save_image

Current Working Directory:  c:\Users\gheith\OneDrive - ETS\0 2023 MTI 865 - Apprentissage profind pour la vision par ordinateur\CleanedGithub\MTI865-Competition\Gheith\Cross_Teaching_CNN_CNN_plus


In [2]:
import warnings

warnings.filterwarnings("ignore")

In [3]:
# put outside of the function for pickeling
def worker_init_fn(worker_id):
    random.seed(1208 + worker_id)


def runTraining():
    print("-" * 40)
    print("~~~~~~~~  Starting the training... ~~~~~~")
    print("-" * 40)

    ## DEFINE HYPERPARAMETERS (batch_size > 1)
    batch_size = 16
    secondaty_batch_size = 8
    batch_size_val = 24
    base_lr = 0.01  # Learning Rate
    max_iterations = 30000

    ## DEFINE THE TRANSFORMATIONS TO DO AND THE VARIABLES FOR TRAINING AND VALIDATION

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set_full = medicalDataLoader.MedicalImageDataset(
        "train",
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False,
    )

    total_slices = len(train_set_full)
    labeled_slice = 204
    print(
        "Total silices is: {}, labeled slices is: {}".format(
            total_slices, labeled_slice
        )
    )
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = medicalDataLoader.TwoStreamBatchSampler(
        labeled_idxs, unlabeled_idxs, batch_size, secondaty_batch_size
    )
    trainloader = DataLoader(train_set_full, batch_sampler=batch_sampler, num_workers=0)

    val_set = medicalDataLoader.MedicalImageDataset(
        "val", transform=transform, mask_transform=mask_transform, equalize=False
    )

    val_loader = DataLoader(
        val_set,
        batch_size=batch_size_val,
        worker_init_fn=np.random.seed(0),
        num_workers=0,
        shuffle=False,
    )

    ## INITIALIZE YOUR MODEL
    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName1, modelName2 = "K3", "K5"
    print(" Model Name1: {}".format(modelName1))
    print(" Model Name2: {}".format(modelName2))

    # ## CREATION OF YOUR MODEL
    UEncK3 = UNetEncoderK3()
    UDecK3 = UNetDecoderK3()
    UEncK5 = UNetEncoderK5()
    UDecK5 = UNetDecoderK5()

    print(
        "Total params: {0:,}".format(
            sum(
                p.numel()
                for p in list(UEncK3.parameters())
                + list(UDecK3.parameters())
                + list(UEncK5.parameters())
                + list(UDecK5.parameters())
                if p.requires_grad
            )
        )
    )

    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(4)
    softMax = torch.nn.Softmax()

    if torch.cuda.is_available():
        UEncK3.cuda()
        UDecK3.cuda()
        UEncK5.cuda()
        UDecK5.cuda()
        ce_loss.cuda()
        dice_loss.cuda()

    ## DEFINE YOUR OPTIMIZER
    optimizerK3 = optim.SGD(
        list(UEncK3.parameters()) + list(UDecK3.parameters()),
        lr=base_lr,
        momentum=0.9,
        weight_decay=0.0001,
    )
    optimizerK5 = optim.SGD(
        list(UEncK3.parameters()) + list(UDecK3.parameters()),
        lr=base_lr,
        momentum=0.9,
        weight_decay=0.0001,
    )
    lr_schedulerK3 = PolynomialLR(
        optimizerK3,
        total_iters=max_iterations,  # The number of steps that the scheduler decays the learning rate.
        power=1,
    )  # The power of the polynomial.
    lr_schedulerK5 = PolynomialLR(
        optimizerK5,
        total_iters=max_iterations,  # The number of steps that the scheduler decays the learning rate.
        power=1,
    )  # The power of the polynomial.

    ### To save statistics ####
    lossTotalTraining = []
    lossTotalValidation = []
    Best_loss_val = 1000
    BestEpoch = 0
    no_improvement_counter = 0

    directory = "Results/Statistics/" + "CrossTeachingK3K5"

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    if os.path.exists(directory) == False:
        os.makedirs(directory)

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    print("{} iterations per epoch".format(len(trainloader)))
    ## START THE TRAINING
    ## FOR EACH EPOCH
    for epoch_num in range(max_epoch):
        UEncK3.train()
        UDecK3.train()
        UEncK5.train()
        UDecK5.train()
        lossEpoch = []
        num_batches = len(trainloader)
        ## FOR EACH BATCH
        for i_batch, sampled_batch in enumerate(trainloader):
            ### Set to zero all the gradients
            optimizerK3.zero_grad()
            optimizerK5.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            volume_batch, label_batch = sampled_batch["image"], sampled_batch["label"]
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            ################### Train ###################
            # -- The CNN makes its predictions (forward pass)
            featuresK3 = UEncK3(volume_batch)
            outK3 = UDecK3(featuresK3)
            outK3_soft = torch.softmax(outK3, dim=1)
            ##
            featuresK5 = UEncK5(volume_batch)
            outK5 = UDecK5(featuresK5)
            outK5_soft = torch.softmax(outK5, dim=1)

            # COMPUTE THE LOSS #adapted from https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py
            superv_ce_lossK3 = ce_loss(outK3[:8], label_batch.squeeze(1)[:8].long())
            superv_dice_lossK3 = dice_loss(outK3_soft[:8], label_batch[:8])
            supervised_lossK3 = 0.5 * (superv_ce_lossK3 + superv_dice_lossK3)
            ##
            superv_ce_lossK5 = ce_loss(outK5[:8], label_batch[:8].squeeze(1).long())
            superv_dice_lossK5 = dice_loss(outK5_soft[:8], label_batch[:8])
            supervised_lossK5 = 0.5 * (superv_ce_lossK5 + superv_dice_lossK5)

            pseudo_lblK3 = torch.argmax(outK3_soft[8:].detach(), dim=1, keepdim=False)
            pseudo_lblK5 = torch.argmax(outK5_soft[8:].detach(), dim=1, keepdim=False)

            pseudo_suprv_lossK3 = dice_loss(outK3_soft[8:], pseudo_lblK3.unsqueeze(1))
            pseudo_suprv_lossK5 = dice_loss(outK5_soft[8:], pseudo_lblK5.unsqueeze(1))

            consistency_weight = get_current_consistency_weight(iter_num // 150)

            K3Loss = supervised_lossK3 + consistency_weight * pseudo_suprv_lossK3
            K5Loss = supervised_lossK5 + consistency_weight * pseudo_suprv_lossK5

            loss = K3Loss + K5Loss
            # DO THE STEPS FOR BACKPROP (two things to be done in pytorch)
            loss.backward()

            optimizerK3.step()
            optimizerK5.step()
            lr_schedulerK3.step()
            lr_schedulerK5.step()

            iter_num = iter_num + 1

            # THIS IS JUST TO VISUALIZE THE TRAINING
            lossEpoch.append(loss.cpu().data.numpy())
            printProgressBar(
                i_batch + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch_num),
                length=15,
                suffix=" Loss: {:.4f}, ".format(loss),
            )

        lossEpoch = np.asarray(lossEpoch)
        lossEpoch = lossEpoch.mean()

        lossTotalTraining.append(lossEpoch)

        printProgressBar(
            num_batches,
            num_batches,
            done="[Training] Epoch: {}, LossG: {:.4f}".format(epoch_num, lossEpoch),
        )
        # VALIDATION
        UEncK3.eval()
        UDecK3.eval()
        UEncK5.eval()
        UDecK5.eval()
        lossEpochVal = []
        lossEpochVal1 = []
        lossEpochVal2 = []
        num_batches = len(val_loader)
        for i_batch, sampled_batch in enumerate(val_loader):
            images, labels = sampled_batch["image"], sampled_batch["label"]
            labels = to_var(labels)
            images = to_var(images)

            featuresK3 = UEncK3(images)
            outK3 = UDecK3(featuresK3)

            segmentation_classes = getTargetSegmentation(labels)
            CE_loss_value = ce_loss(outK3, segmentation_classes)

            predsoft = softMax(outK3)
            pred = predsoft.argmax(dim=1)

            # Show live view of model segmentation
            if not os.path.exists("Results/Segmentation1/"):
                os.makedirs("Results/Segmentation1/")
            save_image(
                torch.cat([pred.view(labels.shape[0], 1, 256, 256).data / 3.0]),
                "Results/Segmentation1/liveview.png".format(epoch_num),
            )

            # Dice_loss_value = computeDSC(pred.unsqueeze(1), segmentation_classes.unsqueeze(1))
            Dice_loss_value = dice_loss(predsoft, segmentation_classes.unsqueeze(1))

            lossTotal = CE_loss_value + Dice_loss_value

            lossEpochVal.append(lossTotal.cpu().data.numpy())
            printProgressBar(
                i_batch + 1,
                num_batches,
                prefix="[Validation] Epoch: {} ".format(epoch_num),
                length=15,
                suffix=" Loss: {:.4f}, CE: {:.4f}, Dice: {:.8f}".format(
                    lossTotal, CE_loss_value, Dice_loss_value
                ),
            )

            # SECOND

            featuresK5 = UEncK5(images)
            outK5 = UDecK5(featuresK5)

            segmentation_classes = getTargetSegmentation(labels)
            CE_loss_value = ce_loss(outK5, segmentation_classes)

            predsoft = softMax(outK5)
            pred = predsoft.argmax(dim=1)

            # Show live view of model segmentation
            if not os.path.exists("Results/Segmentation2/"):
                os.makedirs("Results/Segmentation2/")
            save_image(
                torch.cat([pred.view(labels.shape[0], 1, 256, 256).data / 3.0]),
                "Results/Segmentation2/liveview.png".format(epoch_num),
            )

            # Dice_loss_value = computeDSC(pred.unsqueeze(1), segmentation_classes.unsqueeze(1))
            Dice_loss_value = dice_loss(predsoft, segmentation_classes.unsqueeze(1))

            lossTotal = CE_loss_value + Dice_loss_value

            lossEpochVal.append(lossTotal.cpu().data.numpy())
            printProgressBar(
                i_batch + 1,
                num_batches,
                prefix="[Validation] Epoch: {} ".format(epoch_num),
                length=15,
                suffix=" Loss: {:.4f}, CE: {:.4f}, Dice: {:.8f}".format(
                    lossTotal, CE_loss_value, Dice_loss_value
                ),
            )
        # Save the model if it is the best so far
        modelName = modelName1
        lossEpochVal = lossEpochVal1
        modelStateDict = {"ENC": UEncK3.state_dict(), "DEC": UDecK3.state_dict()}
        if lossEpochVal2 < lossEpochVal1:
            modelName = modelName2
            lossEpochVal = lossEpochVal2
            modelStateDict = {"ENC": UEncK5.state_dict(), "DEC": UDecK5.state_dict()}

        lossEpochVal = np.asarray(lossEpochVal)
        lossEpochVal = lossEpochVal.mean()

        lossTotalValidation.append(lossEpochVal)

        printProgressBar(
            num_batches,
            num_batches,
            done="[Validation] Epoch: {}, LossG: {:.4f}".format(epoch_num, lossEpochVal),
        )

        if lossEpochVal < Best_loss_val:
            Best_loss_val = lossEpochVal
            BestEpoch = epoch_num
            no_improvement_counter = 0

            if not os.path.exists("./models/" + modelName):
                os.makedirs("./models/" + modelName)
            torch.save(modelStateDict, "./models/" + modelName + "/" + str(epoch_num) + "_Epoch")
            print("Best model saved at epoch {}".format(epoch_num))
        else:
            no_improvement_counter = no_improvement_counter + 1
            print(
                "No improvement in last epoch. Counter: {}".format(no_improvement_counter))
            
            if no_improvement_counter % 3 == 0 and no_improvement_counter != 0:
                print("No improvement in last 3 epochs.")
                # lr = lr/10
                # optimizer = torch.optim.Adam(net.parameters(), lr=lr)
            if epoch_num - BestEpoch > 7:
                print("No improvement in last 7 epochs. Stopping training.")
                break

    np.save(os.path.join(directory, "Losses.npy"), lossTotalTraining)
    np.save(os.path.join(directory, "Losses_val.npy"), lossTotalValidation)

    print("Training finished. Best model saved at epoch {}".format(BestEpoch))

    graphLosses(lossTotalTraining, lossTotalValidation, modelName, directory)

    return BestEpoch

    # ## THIS IS HOW YOU WILL SAVE THE TRAINED MODELS AFTER EACH EPOCH.
    #     ## WARNING!!!!! YOU DON'T WANT TO SAVE IT AT EACH EPOCH, BUT ONLY WHEN THE MODEL WORKS BEST ON THE VALIDATION SET!!
    #     if not os.path.exists('./models/' + modelName):
    #             os.makedirs('./models/' + modelName)

    #         torch.save(net.state_dict(), './models/' + modelName + '/' + str(i) + '_Epoch')

    #     np.save(os.path.join(directory, 'Losses.npy'), lossTotalTraining)

In [4]:
runTraining()

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
Total silices is: 1208, labeled slices is: 204
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name1: K3
 Model Name2: K5
Total params: 6,766,344
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
25 iterations per epoch
[Training] Epoch: 0 [DONE]                                 
[Training] Epoch: 0, LossG: 1.6998                                                                           
[Validation] Epoch: 0 [DONE]                                                             
[Validation] Epoch: 0 [DONE]                                                             
[Validation] Epoch: 0, LossG: nan                                                                            
No improvement in last epoch. Counter: 1
[Training] Epoch: 1 [DONE]                                 
[Training] Epoch: 1, LossG: 1.5756                                                            

KeyboardInterrupt: 