# Here, we will try to segment skin leison using U-Net architecture.

01/03/2024

Author = <a href="https://jimut123.github.io/" target="_blank" alt="Jimut">Jimut</a>

Link to the data for single class segmentation: https://drive.google.com/drive/u/0/folders/1JWRW_GuolArm3mPsO3Elf1SFJEZPzWBR

## Download the data for segmentation, pre-process it to load in the train-val-test pipeline

In [1]:
! gdown 12VD4JR_ORiOoIuZHfy9xhO_QXcP5CEzU

Downloading...
From (original): https://drive.google.com/uc?id=12VD4JR_ORiOoIuZHfy9xhO_QXcP5CEzU
From (redirected): https://drive.google.com/uc?id=12VD4JR_ORiOoIuZHfy9xhO_QXcP5CEzU&confirm=t&uuid=c4ba6ffc-c7ae-4636-8b1e-369609ebb97b
To: /content/skin_lesion.zip
100% 50.2M/50.2M [00:00<00:00, 118MB/s]


In [2]:
! unzip -qq skin_lesion.zip

replace ph2_resized/trainx/X_img_0.bmp? [y]es, [n]o, [A]ll, [N]one, [r]ename: A


## Create the train-val-test set

In [59]:
import os
import sys
import cv2
import glob


FOLDER_NAME = "train_imgs"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass

FOLDER_NAME = "val_imgs"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass

FOLDER_NAME = "test_imgs"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass

FOLDER_NAME = "train_masks"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass

FOLDER_NAME = "val_masks"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass

FOLDER_NAME = "test_masks"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass

ALL_IMAGES = glob.glob('trainx/*.bmp')
train_per = 70
vali_per = 10
test_per = 20

TRAIN_IMG_NUM = int(train_per*len(ALL_IMAGES)/100)
VALID_IMG_NUM = int(vali_per*len(ALL_IMAGES)/100)
TEST_IMG_NUM = TRAIN_IMG_NUM + VALID_IMG_NUM

print("Training images = ",TRAIN_IMG_NUM)
print("Validation images = ",VALID_IMG_NUM)
print("Test images = ",TEST_IMG_NUM)

for image_name in ALL_IMAGES[:TRAIN_IMG_NUM]:
    img = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
    img_x_name = str(image_name.split('/')[-1]).split('.')[0]
    img_x_id = str(str(image_name.split('/')[-1]).split('_')[-1]).split('.')[0]
    mask_name = "trainy/"+'Y_img_'+str(img_x_id)+".bmp"
    mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
    save_img_name = "train_imgs/"+img_x_name+".bmp"
    save_mask_name = "train_masks/"+img_x_name+".bmp"
    cv2.imwrite(save_img_name,img)
    cv2.imwrite(save_mask_name,mask)

for image_name in ALL_IMAGES[TRAIN_IMG_NUM:TRAIN_IMG_NUM+VALID_IMG_NUM]:
    img = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
    img_x_name = str(image_name.split('/')[-1]).split('.')[0]
    img_x_id = str(str(image_name.split('/')[-1]).split('_')[-1]).split('.')[0]
    mask_name = "trainy/"+'Y_img_'+str(img_x_id)+".bmp"
    mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
    save_img_name = "val_imgs/"+img_x_name+".bmp"
    save_mask_name = "val_masks/"+img_x_name+".bmp"
    cv2.imwrite(save_img_name,img)
    cv2.imwrite(save_mask_name,mask)

for image_name in ALL_IMAGES[TEST_IMG_NUM:]:
    img = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
    img_x_name = str(image_name.split('/')[-1]).split('.')[0]
    img_x_id = str(str(image_name.split('/')[-1]).split('_')[-1]).split('.')[0]
    mask_name = "trainy/"+'Y_img_'+str(img_x_id)+".bmp"
    mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
    save_img_name = "test_imgs/"+img_x_name+".bmp"
    save_mask_name = "test_masks/"+img_x_name+".bmp"
    cv2.imwrite(save_img_name,img)
    cv2.imwrite(save_mask_name,mask)



Training images =  140
Validation images =  20
Test images =  160


In [60]:
all_train_imgs = glob.glob('train_imgs/*.bmp')
all_train_masks = glob.glob('train_masks/*.bmp')
all_val_imgs = glob.glob('val_imgs/*.bmp')
all_val_masks = glob.glob('val_masks/*.bmp')
all_test_imgs = glob.glob('test_imgs/*.bmp')
all_test_masks = glob.glob('test_masks/*.bmp')
print("Length of all_train_imgs = ",len(all_train_imgs))
print("Length of all_train_masks = ",len(all_train_masks))
print("Length of all_val_imgs = ",len(all_val_imgs))
print("Length of all_val_masks = ",len(all_val_masks))
print("Length of all_test_imgs = ",len(all_test_imgs))
print("Length of all_test_masks = ",len(all_test_masks))

Length of all_train_imgs =  140
Length of all_train_masks =  140
Length of all_val_imgs =  20
Length of all_val_masks =  20
Length of all_test_imgs =  40
Length of all_test_masks =  40


In [61]:

# mostly torch imports and plot imports
import os
import sys
import cv2
# mostly torch imports and plot imports
import torch
import shutil
import glob
import pickle
import random
random.seed(42)

import numpy as np
np.random.seed(42)
import torch.utils
import torchvision
from PIL import Image
from torch import optim
import torchvision.transforms as T
import torch.distributions
torch.manual_seed(42)
from tqdm import tqdm
import torch.nn as nn
from torchsummary import summary
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from matplotlib import rc, rcParams

## Enter CUDA/CPU mode

In [70]:
# Check CUDA  and stuffs

use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))
device = torch.device("cuda" if use_cuda else "cpu")
print("Device to be used : ",device)


use_cuda: True
Device to be used :  cuda


## Create the dump folders

In [62]:

FOLDER_NAME = "checkpoint"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass


FOLDER_NAME = "inference_data"


try:
    os.makedirs(FOLDER_NAME)
except:
    pass


FOLDER_NAME = "history"

try:
    os.makedirs(FOLDER_NAME)
except:
    pass


## Evalutation metrics here

In [63]:

# Evaluation metrics  here
########################################################################################
########################################################################################
########################################################################################

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances


def dice_coef(y_true, y_pred):
    smooth = 0.0001
    y_true = np.where((y_true > 0.5) & (y_true <= 1.), 1, y_true)
    y_true = np.where((y_true > 0.0) & (y_true <= 0.5), 0, y_true)
    y_pred = np.where((y_pred > 0.5) & (y_pred <= 1.), 1, y_pred)
    y_pred = np.where((y_pred > 0.0) & (y_pred <= 0.5), 0, y_pred)
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)


def jacard(y_true, y_pred):
    smooth = 0.0001
    y_true = np.where((y_true > 0.5) & (y_true <= 1.), 1, y_true)
    y_true = np.where((y_true > 0.0) & (y_true <= 0.5), 0, y_true)
    y_pred = np.where((y_pred > 0.5) & (y_pred <= 1.), 1, y_pred)
    y_pred = np.where((y_pred > 0.0) & (y_pred <= 0.5), 0, y_pred)
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = np.sum(y_true_f * y_pred_f)
    union = np.sum(y_true_f + y_pred_f - y_true_f * y_pred_f)
    return intersection/union


def dice_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)


def dice_coef_multilabel(y_true, y_pred, numLabels):
    dice=0
    for index in range(numLabels):
        dice += dice_coef(y_true[:,index,:,:], y_pred[:,index,:,:])
    return dice/numLabels # taking average


def jacard_multilabel(y_true, y_pred, numLabels):
    jacard_sum=0
    for index in range(numLabels):
        jacard_sum += jacard(y_true[:,index,:,:], y_pred[:,index,:,:])
    return jacard_sum/numLabels # taking average


def accuracy_check(mask, prediction):
    ims = [mask, prediction]
    np_ims = []
    for item in ims:
        if 'str' in str(type(item)):
            item = np.array(Image.open(item))
        elif 'PIL' in str(type(item)):
            item = np.array(item)
        elif 'torch' in str(type(item)):
            item = item.numpy()
        np_ims.append(item)

    compare = np.equal(np_ims[0], np_ims[1])
    accuracy = np.sum(compare)

    return accuracy/len(np_ims[0].flatten())


def accuracy_check_for_batch(masks, predictions, batch_size):
    total_acc = 0
    for index in range(batch_size):
        total_acc += accuracy_check(masks[index], predictions[index])
    return total_acc/batch_size

## U-Net model here

In [64]:

# U-Net model here
########################################################################################
########################################################################################
########################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F

def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )


class UNet_Sigmoid(nn.Module):

    def __init__(self, n_class):
        super().__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)


    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)

        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)

        x = self.dconv_up1(x)

        out = self.conv_last(x)

        return torch.sigmoid(out)

## Some Image Augmentation stuffs that you can use

In [71]:

####################
def add_random_boxes(img, mask, n_k,size=32):
    h,w = size,size
    img = np.asarray(img)
    mask = np.asarray(mask)
    img_size = img.shape[1]
    boxes = []
    for k in range(n_k):
        y,x = np.random.randint(0,img_size-w,(2,))
        img[y:y+h,x:x+w] = 0
        mask[y:y+h,x:x+w] = 0
        boxes.append((x,y,h,w))
    # img = Image.fromarray(img.astype('uint8'), 'RGB')
    return img, mask

import torchvision.transforms.functional as TF
from skimage.filters import gaussian
from skimage.filters import unsharp_mask

def transformer(image, mask):
    # image and mask are PIL image object.
    img_w, img_h = mask.size

    # Random horizontal flipping
    if random.random() > 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)

    # Random vertical flipping
    if random.random() > 0.5:
        image = TF.vflip(image)
        mask = TF.vflip(mask)

    # Random affine
    affine_param = transforms.RandomAffine.get_params(
        degrees = [-180, 180], translate = [0.3,0.3],
        img_size = [img_w, img_h], scale_ranges = [1, 1.3],
        shears = [2,2])
    image = TF.affine(image,
                      affine_param[0], affine_param[1],
                      affine_param[2], affine_param[3])
    mask = TF.affine(mask,
                     affine_param[0], affine_param[1],
                     affine_param[2], affine_param[3])

    image = np.array(image)
    mask = np.array(mask)
    return image, mask

## Create the train and test data-loaders

In [66]:
NUM_LABELS = 1

# Use the data generator to load the dataset
class DataGenerator(Dataset):
    def __init__(self, image_list,mask_path):
        self.files = image_list
        self.select_path = mask_path

    #NUMBER OF FILES IN THE DATASET
    def __len__(self):
        return len(self.files)

    #GETTING SINGLE PAIR OF DATA
    def __getitem__(self,idx):
        file_name = self.files[idx].split('/')[-1]
        mask_name = self.select_path+file_name
        img = cv2.imread(self.files[idx],cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_name,cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (256, 256))
        mask = cv2.resize(mask, (256, 256))
        mask = mask/255.0
        img_transpose = np.transpose(img, (2, 0, 1))
        mask = np.expand_dims(mask, axis=2)
        mask_one_hot_transpose = np.transpose(mask, (2, 0, 1))
        return torch.FloatTensor(img_transpose/img_transpose.max()), torch.FloatTensor(mask_one_hot_transpose)


class TestDataGenerator(Dataset):
    def __init__(self, image_list,mask_path):
        self.files = image_list
        self.select_path = mask_path

    #NUMBER OF FILES IN THE DATASET
    def __len__(self):
        return len(self.files)

    #GETTING SINGLE PAIR OF DATA
    def __getitem__(self,idx):
        file_name = self.files[idx].split('/')[-1]
        mask_name = self.select_path+file_name
        img = cv2.imread(self.files[idx],cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_name,cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (256, 256))
        mask = cv2.resize(mask, (256, 256))
        mask = mask/255.0
        img_transpose = np.transpose(img, (2, 0, 1))
        mask = np.expand_dims(mask, axis=2)
        mask_one_hot_transpose = np.transpose(mask, (2, 0, 1))
        print("Unique values of image == ",np.unique(torch.FloatTensor(img_transpose/img_transpose.max())))
        return torch.FloatTensor(img_transpose/img_transpose.max()), torch.FloatTensor(mask_one_hot_transpose)


def load_data(image_list, mask_path, batch_size=2, num_workers=10, shuffle=True):
    dataset = DataGenerator(image_list, mask_path)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
    return data_loader

def load_test_data(image_list, mask_path, batch_size=2, num_workers=10, shuffle=True):
    dataset = TestDataGenerator(image_list, mask_path)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
    return data_loader

def get_image_address(image_data_folder):
    image_address_list = []
    image_address_list = glob.glob(image_data_folder+"/*.bmp")
    print("Number of Files : ", len(image_address_list))
    for img_addr in image_address_list:
        try :
            img = cv2.imread(img_addr)
            x = img.shape
        except :
            image_address_list.remove(img_addr)
            os.remove(img_addr)
    print("Number of Files after removing : ", len(image_address_list))
    return image_address_list

## Save and load checkpoint

In [None]:

# save checkpoint in pytorch
def save_ckp(checkpoint, checkpoint_path):
    torch.save(checkpoint, checkpoint_path)


# load checkpoint in pytorch
def load_ckp(checkpoint_path, model, model_opt):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    model_opt.load_state_dict(checkpoint['optimizer'])
    return model, model_opt, checkpoint['epoch']


## Train-Val-Test pipeline

In [67]:


# Train, Val and Test
#####################################################################

def train_epoch(train_loader, model, optimizer, epoch):
    print("\n\n---------------------------------------------------------------------------------------------------------------\n")

    progress_bar = tqdm(enumerate(train_loader))
    total_loss = 0.0
    total_dice = 0.0
    total_jacard = 0.0
    total_acc = 0.0
    N = 0
    for step, (inp__, gt__) in progress_bar:
        # if mask_1 is None and low_res_img is None:
        #     continue
        model.train()

        #TRANSFERRING DATA TO DEVICE
        inp__ = inp__.to(device)
        gt__ = gt__.to(device)

        # clear the gradient
        optimizer.zero_grad()

        #GETTING THE PREDICTED IMAGE
        pred_img = model.forward(inp__)
        #LOSS FUNCTIONS
        BCELOSS = nn.BCELoss()

        #CALCULATING LOSSES
        BCE_loss = BCELOSS(pred_img, gt__)

        #LOSS TAKEN INTO CONSIDERATION
        loss = BCE_loss

        # CALCULATING METRICS
        total_loss += loss
        pred_img_detached = pred_img.cpu().detach().numpy()
        gt_img_detached = gt__.cpu().detach().numpy()
        numLabels = NUM_LABELS
        dice = dice_coef(gt_img_detached, pred_img_detached)
        jacard_ = jacard(gt_img_detached, pred_img_detached)
        acc_ = accuracy_check_for_batch(np.round(pred_img_detached), gt_img_detached, pred_img_detached.shape[0])
        total_dice += float(dice)
        total_jacard += float(jacard_)
        total_acc += float(acc_)
        N += 1

        #BACKPROPAGATING THE LOSS
        loss.backward()
        optimizer.step()

        #DISPLAYING THE LOSS
        progress_bar.set_description("Epoch: {} -  Loss: {} - Acc: {} - Dice: {} - Jaccard: {}".format(epoch, loss, acc_, dice, jacard_))

    mean_loss = total_loss/N
    mean_dice = total_dice/N
    mean_jacard = total_jacard/N
    mean_acc = total_acc/N

    with open("history/train_logs.txt", "a") as text_file:
        text_file.write("{} {} {} {} {}\n".format(epoch, mean_loss, mean_acc, mean_dice, mean_jacard))
    print("[Training] Training Epoch: {} |  Loss: {} | Acc: {} | Dice: {} | Jaccard: {}".format(epoch, mean_loss, mean_acc, mean_dice, mean_jacard))

    return model, optimizer




def val_epoch(val_loader, model, optimizer, epoch):
    progress_bar = tqdm(enumerate(val_loader))
    total_loss = 0.0
    total_dice = 0.0
    total_jacard = 0.0
    total_acc = 0.0
    N = 0
    for step, (inp__, gt__) in progress_bar:
        if inp__ is None and gt__ is None:
            continue

        inp__ = inp__.to(device)
        gt__ = gt__.to(device)

        model.eval()
        pred_img = model.forward(inp__)

        #LOSS FUNCTIONS
        BCELOSS = nn.BCELoss()

        #CALCULATING LOSSES
        BCE_loss = BCELOSS(pred_img, gt__)

        #LOSS TAKEN INTO CONSIDERATION
        loss = BCE_loss

        # CALCULATING METRICS
        total_loss += loss
        pred_img_detached = pred_img.cpu().detach().numpy()
        gt_img_detached = gt__.cpu().detach().numpy()
        numLabels = NUM_LABELS
        #dice = dice_coef(gt_img_detached, pred_img_detached)
        dice = dice_coef(gt_img_detached, pred_img_detached)
        #dice_coef(gt_img_detached, pred_img_detached)
        jacard_ = jacard(gt_img_detached, pred_img_detached)
        acc_ = accuracy_check_for_batch(np.round(pred_img_detached), gt_img_detached, pred_img_detached.shape[0])
        total_dice += float(dice)
        total_jacard += float(jacard_)
        total_acc += float(acc_)
        N += 1
        progress_bar.set_description("Epoch: {} -  Loss: {} - Acc: {} - Dice: {} - Jaccard: {}".format(epoch, loss, acc_, dice, jacard_))

    # mean_hausdroff = total_hausdroff/N
    mean_loss = total_loss/N
    mean_dice = total_dice/N
    mean_jacard = total_jacard/N
    mean_acc = total_acc/N

    with open("history/val_logs.txt", "a") as text_file:
        text_file.write("{} {} {} {} {}\n".format(epoch, mean_loss, mean_acc, mean_dice, mean_jacard))

    print("[Validation] Validation Epoch: {} |  Loss: {} | Acc: {} | Dice: {} | Jaccard: {}".format(epoch, mean_loss, mean_acc, mean_dice, mean_jacard))



def test_epoch(test_loader, model, optimizer, epoch):

    progress_bar = tqdm(enumerate(test_loader))
    total_loss = 0.0

    #SETTING THE NUMBER OF IMAGES TO CHECK AFTER EACH ITERATION
    no_img_to_write = 20

    total_loss = 0.0
    total_dice = 0.0
    total_jacard = 0.0
    total_acc = 0.0
    N = 0
    for step, (inp__, gt__) in progress_bar:
        if inp__ is None and gt__ is None:
            continue

        inp__ = inp__.to(device)
        gt__ = gt__.to(device)


        #PREDICTED IMAGE
        pred_img = model.forward(inp__)

        #LOSS FUNCTIONS
        BCELOSS = nn.BCELoss()

        model.eval()

        #CALCULATING LOSSES
        BCE_loss = BCELOSS(pred_img, gt__)

        #LOSS TAKEN INTO CONSIDERATION
        loss = BCE_loss

        # CALCULATING METRICS
        total_loss += loss
        pred_img_detached = pred_img.cpu().detach().numpy()
        gt_img_detached = gt__.cpu().detach().numpy()
        numLabels = NUM_LABELS
        dice = dice_coef(gt_img_detached, pred_img_detached)
        jacard_ = jacard(gt_img_detached, pred_img_detached)
        acc_ = accuracy_check_for_batch(np.round(pred_img_detached), gt_img_detached, pred_img_detached.shape[0])
        total_dice += float(dice)
        total_jacard += float(jacard_)
        total_acc += float(acc_)
        N += 1

        progress_bar.set_description("[Test] Epoch: {} -  Loss: {} - Acc: {} - Dice: {} - Jaccard: {}".format(epoch, loss, acc_, dice, jacard_))

        #WRITING THE IMAGES INTO THE SPECIFIED DIRECTORY
        if(step < no_img_to_write):

            p_img = pred_img.cpu().numpy() * 255
            gt_img = gt__.cpu().numpy() * 255
            inp_img = inp__.cpu().numpy() * 255.0

            #FOLDER PATH TO WRITE THE INFERENCES
            inference_folder = "inference_data"
            if not os.path.isdir(inference_folder):
                os.mkdir(inference_folder)

            print("\n Saving inferences at epoch === ",epoch)
            for p_image_loop, gt_img_loop, inp_img_loop in zip(p_img, gt_img, inp_img):

                p_image_loop = np.transpose(p_image_loop, (1, 2, 0))
                cv2.imwrite(os.path.join(inference_folder, "img_"+str(step)+"_pred.png"),p_image_loop.round())

                gt_img_loop = np.transpose(gt_img_loop, (1, 2, 0))
                cv2.imwrite(os.path.join(inference_folder, "img_"+str(step)+"_gt.png"), gt_img_loop)

                inp_img_loop = np.transpose(inp_img_loop, (1, 2, 0))
                cv2.imwrite(os.path.join(inference_folder, "img_"+str(step)+"_inp.png"), inp_img_loop)

    mean_loss = total_loss/N
    mean_dice = total_dice/N
    mean_jacard = total_jacard/N
    mean_acc = total_acc/N

    with open("history/test_logs.txt", "a") as text_file:
        text_file.write("{} {} {} {} {}\n".format(epoch, mean_loss, mean_acc, mean_dice, mean_jacard))

    print("Test Epoch: {} |  Loss: {} | Accuracy: {} | Dice: {} | Jaccard: {}".format(epoch, mean_loss, mean_acc, mean_dice, mean_jacard))
    print("---------------------------------------------------------------------------------------------------------------")



## Function to integrate all the train, validation, test and checkpointing stuffs using epoch numbers

In [68]:

def train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume):

    #PATH TO SAVE THE CHECKPOINT
    checkpoint_path = "checkpoint/checkpoint.pt"

    epoch = 0
    #IF TRAINING IS TO RESUMED FROM A CERTAIN CHECKPOINT
    if resume:
        model, optimizer, epoch = load_ckp(
            checkpoint_path, model, optimizer)

    while epoch <= n_epoch:
        epoch += 1
        model, optimizer = train_epoch(train_loader, model, optimizer, epoch)

        #CHECKPOINT CREATION
        checkpoint = {'epoch': epoch+1, 'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict()}

        #CHECKPOINT SAVING
        save_ckp(checkpoint, checkpoint_path)
        print("Checkpoint Saved")

        #CHECKPOINT LOADING
        model, optimizer, epoch = load_ckp(checkpoint_path, model, optimizer)
        print("Checkpoint Loaded")
        with torch.no_grad():
            val_epoch(val_loader, model, optimizer, epoch)

    print("************************ Final Test Epoch *****************************")

    with torch.no_grad():
        test_epoch(test_loader, model, optimizer, epoch)




## Main function to load all the data, set the optimizer etc.

In [69]:
def main():

    train_image_address_list = get_image_address("train_imgs/")
    random.shuffle(train_image_address_list)

    train_loader = load_data(train_image_address_list, mask_path="train_masks/", batch_size=2, num_workers=2, shuffle=True)
    check = iter(train_loader)
    val_image_address_list = get_image_address("val_imgs/")

    random.shuffle(val_image_address_list)
    val_loader = load_data(val_image_address_list, mask_path="val_masks/", batch_size=2, num_workers=2, shuffle=True)
    check = iter(val_loader)

    test_image_address_list = get_image_address("test_imgs/")
    random.shuffle(test_image_address_list)
    test_loader = load_data(test_image_address_list, mask_path="test_masks/", batch_size=2, num_workers=2, shuffle=True)
    check = iter(test_loader)

    print("Train : {} Val : {} ".format(len(train_image_address_list), len(val_image_address_list)))

    #1. CALLING THE MODEL - U-Net, batch_size=2, num_workers=2, lr=1e-04
    model = UNet_Sigmoid(1)
    model = model.to(device)
    summary(model, input_size=(3, 128, 128))

    #DEFINING THE OPTIMIZER
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-04, weight_decay=5e-4)

    n_epoch = 50

    #INDICATOR VARIABLE TO RESUME TRAINING OR START AFRESH
    resume = False
    train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume)


--- Starting the main function ----
Number of Files :  140
Number of Files after removing :  140
Number of Files :  20
Number of Files after removing :  20
Number of Files :  40
Number of Files after removing :  40
Train : 140 Val : 20 
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,792
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4         [-1, 64, 128, 128]          36,928
       BatchNorm2d-5         [-1, 64, 128, 128]             128
              ReLU-6         [-1, 64, 128, 128]               0
         MaxPool2d-7           [-1, 64, 64, 64]               0
            Conv2d-8          [-1, 128, 64, 64]          73,856
       BatchNorm2d-9          [-1, 128, 64, 64]             256
             ReLU-10          [-1, 128, 64, 64]           

Epoch: 23 -  Loss: 0.0888863205909729 - Acc: 0.980010986328125 - Dice: 0.9730233251695067 - Jaccard: 0.947463870048523: : 70it [00:08,  7.91it/s]


[Training] Training Epoch: 23 |  Loss: 0.18352265655994415 | Acc: 0.9373699733189174 | Dice: 0.9112663090247424 | Jaccard: 0.846001717022487
Checkpoint Saved
Checkpoint Loaded


Epoch: 24 -  Loss: 0.38152655959129333 - Acc: 0.8269729614257812 - Dice: 0.7513650527023403 - Jaccard: 0.6017491817474365: : 10it [00:00, 20.21it/s]

[Validation] Validation Epoch: 24 |  Loss: 0.37019628286361694 | Acc: 0.8555625915527344 | Dice: 0.8361379039690633 | Jaccard: 0.7276356458663941


---------------------------------------------------------------------------------------------------------------




Epoch: 25 -  Loss: 0.12611716985702515 - Acc: 0.9554443359375 - Dice: 0.910437932317356 - Jaccard: 0.8355998992919922: : 70it [00:08,  8.52it/s]


[Training] Training Epoch: 25 |  Loss: 0.1594979614019394 | Acc: 0.942220197405134 | Dice: 0.9152315358855887 | Jaccard: 0.8487675062247685
Checkpoint Saved
Checkpoint Loaded


Epoch: 26 -  Loss: 0.046911220997571945 - Acc: 0.9877548217773438 - Dice: 0.9269197099092563 - Jaccard: 0.8637934327125549: : 10it [00:00, 19.95it/s]

[Validation] Validation Epoch: 26 |  Loss: 0.18455012142658234 | Acc: 0.9269264221191407 | Dice: 0.8975368200914413 | Jaccard: 0.8175832509994507


---------------------------------------------------------------------------------------------------------------




Epoch: 27 -  Loss: 0.22883456945419312 - Acc: 0.904388427734375 - Dice: 0.7484218539912778 - Jaccard: 0.5979825258255005: : 70it [00:08,  8.58it/s]


[Training] Training Epoch: 27 |  Loss: 0.15649254620075226 | Acc: 0.943410382952009 | Dice: 0.9110251382802573 | Jaccard: 0.8414019192968096
Checkpoint Saved
Checkpoint Loaded


Epoch: 28 -  Loss: 0.5518957376480103 - Acc: 0.7391891479492188 - Dice: 0.7483893099770408 - Jaccard: 0.597940981388092: : 10it [00:00, 20.52it/s]

[Validation] Validation Epoch: 28 |  Loss: 0.22723403573036194 | Acc: 0.9220077514648437 | Dice: 0.9001185539273104 | Jaccard: 0.8224467933177948


---------------------------------------------------------------------------------------------------------------




Epoch: 29 -  Loss: 0.16836152970790863 - Acc: 0.9228439331054688 - Dice: 0.8950832202200628 - Jaccard: 0.8100910782814026: : 70it [00:07,  8.78it/s]


[Training] Training Epoch: 29 |  Loss: 0.16156858205795288 | Acc: 0.9391842433384486 | Dice: 0.9077660170125561 | Jaccard: 0.8381702572107315
Checkpoint Saved
Checkpoint Loaded


Epoch: 30 -  Loss: 0.14536038041114807 - Acc: 0.9379501342773438 - Dice: 0.9324671881811241 - Jaccard: 0.8734787106513977: : 10it [00:00, 18.33it/s]

[Validation] Validation Epoch: 30 |  Loss: 0.189841628074646 | Acc: 0.9202751159667969 | Dice: 0.8801837164878344 | Jaccard: 0.8015475898981095


---------------------------------------------------------------------------------------------------------------




Epoch: 31 -  Loss: 0.14428335428237915 - Acc: 0.9392776489257812 - Dice: 0.9030526670153889 - Jaccard: 0.8232415914535522: : 70it [00:07,  8.79it/s]


[Training] Training Epoch: 31 |  Loss: 0.152564138174057 | Acc: 0.9481772286551339 | Dice: 0.9239071425943928 | Jaccard: 0.8618105888366699
Checkpoint Saved
Checkpoint Loaded


Epoch: 32 -  Loss: 0.13208596408367157 - Acc: 0.9571609497070312 - Dice: 0.94140882165853 - Jaccard: 0.8893035054206848: : 10it [00:00, 20.96it/s]

[Validation] Validation Epoch: 32 |  Loss: 0.25574037432670593 | Acc: 0.8956626892089844 | Dice: 0.8769705302646897 | Jaccard: 0.7916630893945694


---------------------------------------------------------------------------------------------------------------




Epoch: 33 -  Loss: 0.08827266097068787 - Acc: 0.9659194946289062 - Dice: 0.9593439828289281 - Jaccard: 0.9218646287918091: : 70it [00:08,  8.48it/s]


[Training] Training Epoch: 33 |  Loss: 0.1482209414243698 | Acc: 0.9469093322753906 | Dice: 0.9212760091692709 | Jaccard: 0.8578563954148973
Checkpoint Saved
Checkpoint Loaded


Epoch: 34 -  Loss: 0.1305321455001831 - Acc: 0.9526748657226562 - Dice: 0.8722967367025717 - Jaccard: 0.7735161781311035: : 10it [00:00, 20.82it/s]

[Validation] Validation Epoch: 34 |  Loss: 0.187783882021904 | Acc: 0.9243293762207031 | Dice: 0.8855300250799027 | Jaccard: 0.810866829752922


---------------------------------------------------------------------------------------------------------------




Epoch: 35 -  Loss: 0.11251140385866165 - Acc: 0.9586639404296875 - Dice: 0.909929241678467 - Jaccard: 0.8347432613372803: : 70it [00:08,  8.51it/s]


[Training] Training Epoch: 35 |  Loss: 0.15674889087677002 | Acc: 0.9449176243373326 | Dice: 0.9177840900219861 | Jaccard: 0.8511473493916648
Checkpoint Saved
Checkpoint Loaded


Epoch: 36 -  Loss: 0.1074071004986763 - Acc: 0.9509201049804688 - Dice: 0.8677129562339799 - Jaccard: 0.7663365602493286: : 10it [00:00, 17.00it/s]

[Validation] Validation Epoch: 36 |  Loss: 0.19269852340221405 | Acc: 0.908880615234375 | Dice: 0.8874214712221355 | Jaccard: 0.8067354023456573


---------------------------------------------------------------------------------------------------------------




Epoch: 37 -  Loss: 0.07511694729328156 - Acc: 0.9802322387695312 - Dice: 0.9505332078951251 - Jaccard: 0.9057296514511108: : 70it [00:08,  8.70it/s]


[Training] Training Epoch: 37 |  Loss: 0.14370808005332947 | Acc: 0.9485885620117187 | Dice: 0.9203547149082484 | Jaccard: 0.8571001819201878
Checkpoint Saved
Checkpoint Loaded


Epoch: 38 -  Loss: 0.10042503476142883 - Acc: 0.9716339111328125 - Dice: 0.9687301922970918 - Jaccard: 0.9393566846847534: : 10it [00:00, 21.13it/s]


[Validation] Validation Epoch: 38 |  Loss: 0.17204692959785461 | Acc: 0.9302330017089844 | Dice: 0.9002593093723791 | Jaccard: 0.8242182791233063


---------------------------------------------------------------------------------------------------------------



Epoch: 39 -  Loss: 0.14391465485095978 - Acc: 0.9379806518554688 - Dice: 0.9030104516889903 - Jaccard: 0.8231714367866516: : 70it [00:07,  8.79it/s]


[Training] Training Epoch: 39 |  Loss: 0.136617049574852 | Acc: 0.9488369532993861 | Dice: 0.9226427970265094 | Jaccard: 0.8610921604292733
Checkpoint Saved
Checkpoint Loaded


Epoch: 40 -  Loss: 0.15069858729839325 - Acc: 0.933135986328125 - Dice: 0.9483724585836508 - Jaccard: 0.9018140435218811: : 10it [00:00, 20.21it/s]

[Validation] Validation Epoch: 40 |  Loss: 0.16155119240283966 | Acc: 0.9321678161621094 | Dice: 0.8831296604196043 | Jaccard: 0.8032901167869568


---------------------------------------------------------------------------------------------------------------




Epoch: 41 -  Loss: 0.10289975255727768 - Acc: 0.9630813598632812 - Dice: 0.9131186855503426 - Jaccard: 0.8401273488998413: : 70it [00:07,  8.75it/s]


[Training] Training Epoch: 41 |  Loss: 0.15304164588451385 | Acc: 0.9458350045340401 | Dice: 0.9167451456167836 | Jaccard: 0.8523496883256095
Checkpoint Saved
Checkpoint Loaded


Epoch: 42 -  Loss: 0.1621592491865158 - Acc: 0.9424591064453125 - Dice: 0.9155804577133575 - Jaccard: 0.844304621219635: : 10it [00:00, 20.82it/s]

[Validation] Validation Epoch: 42 |  Loss: 0.24151642620563507 | Acc: 0.9024467468261719 | Dice: 0.8854645936714787 | Jaccard: 0.8030771553516388


---------------------------------------------------------------------------------------------------------------




Epoch: 43 -  Loss: 0.09561961889266968 - Acc: 0.9659805297851562 - Dice: 0.9052244566628286 - Jaccard: 0.8268585205078125: : 70it [00:07,  8.82it/s]


[Training] Training Epoch: 43 |  Loss: 0.1416119635105133 | Acc: 0.948189217703683 | Dice: 0.9226854473791649 | Jaccard: 0.8609118419034141
Checkpoint Saved
Checkpoint Loaded


Epoch: 44 -  Loss: 0.12160157412290573 - Acc: 0.972076416015625 - Dice: 0.911588119888454 - Jaccard: 0.8375396728515625: : 10it [00:00, 17.58it/s]

[Validation] Validation Epoch: 44 |  Loss: 0.16445055603981018 | Acc: 0.9347679138183593 | Dice: 0.9122026282993344 | Jaccard: 0.8401764392852783


---------------------------------------------------------------------------------------------------------------




Epoch: 45 -  Loss: 0.11099831759929657 - Acc: 0.9605331420898438 - Dice: 0.9090292690152761 - Jaccard: 0.8332297205924988: : 70it [00:07,  8.77it/s]


[Training] Training Epoch: 45 |  Loss: 0.1513439565896988 | Acc: 0.9413157871791294 | Dice: 0.9142584127075801 | Jaccard: 0.8476653716393879
Checkpoint Saved
Checkpoint Loaded


Epoch: 46 -  Loss: 0.22811681032180786 - Acc: 0.8896255493164062 - Dice: 0.7039673625078736 - Jaccard: 0.5431709885597229: : 10it [00:00, 20.14it/s]

[Validation] Validation Epoch: 46 |  Loss: 0.1787782460451126 | Acc: 0.9307502746582031 | Dice: 0.8976005208142557 | Jaccard: 0.820579445362091


---------------------------------------------------------------------------------------------------------------




Epoch: 47 -  Loss: 0.36446574330329895 - Acc: 0.8256759643554688 - Dice: 0.7402740045779824 - Jaccard: 0.5876468420028687: : 70it [00:08,  8.61it/s]


[Training] Training Epoch: 47 |  Loss: 0.14164027571678162 | Acc: 0.9476068769182477 | Dice: 0.9209722698384976 | Jaccard: 0.8577516734600067
Checkpoint Saved
Checkpoint Loaded


Epoch: 48 -  Loss: 0.15454278886318207 - Acc: 0.934783935546875 - Dice: 0.9133045580661011 - Jaccard: 0.8404420614242554: : 10it [00:00, 20.56it/s]

[Validation] Validation Epoch: 48 |  Loss: 0.26220035552978516 | Acc: 0.8906776428222656 | Dice: 0.8560172409951731 | Jaccard: 0.7694890618324279


---------------------------------------------------------------------------------------------------------------




Epoch: 49 -  Loss: 0.1700117141008377 - Acc: 0.928314208984375 - Dice: 0.9335437437809394 - Jaccard: 0.8753699064254761: : 70it [00:08,  8.36it/s]


[Training] Training Epoch: 49 |  Loss: 0.18005679547786713 | Acc: 0.935740007672991 | Dice: 0.9064158435883276 | Jaccard: 0.83616561123303
Checkpoint Saved
Checkpoint Loaded


Epoch: 50 -  Loss: 0.21636676788330078 - Acc: 0.9084854125976562 - Dice: 0.8674508368161505 - Jaccard: 0.7659277319908142: : 10it [00:00, 21.08it/s]

[Validation] Validation Epoch: 50 |  Loss: 0.17529618740081787 | Acc: 0.9283973693847656 | Dice: 0.8950949684455803 | Jaccard: 0.8175874829292298


---------------------------------------------------------------------------------------------------------------




Epoch: 51 -  Loss: 0.10013596713542938 - Acc: 0.9649276733398438 - Dice: 0.9572773657805719 - Jaccard: 0.9180555939674377: : 70it [00:07,  8.88it/s]


[Training] Training Epoch: 51 |  Loss: 0.1283828616142273 | Acc: 0.9518438066755023 | Dice: 0.9285838672710117 | Jaccard: 0.8691147071974618
Checkpoint Saved
Checkpoint Loaded


Epoch: 52 -  Loss: 0.11723795533180237 - Acc: 0.9663467407226562 - Dice: 0.8954529363829524 - Jaccard: 0.8106969594955444: : 10it [00:00, 20.97it/s]

[Validation] Validation Epoch: 52 |  Loss: 0.1680155098438263 | Acc: 0.9325592041015625 | Dice: 0.9120001256300092 | Jaccard: 0.839715051651001
************************ Final Test Epoch *****************************



[Test] Epoch: 52 -  Loss: 0.052819184958934784 - Acc: 0.981353759765625 - Dice: 0.9322530817683772 - Jaccard: 0.8731030225753784: : 0it [00:00, ?it/s]


 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.1439286470413208 - Acc: 0.9519577026367188 - Dice: 0.8955998668988621 - Jaccard: 0.8109378218650818: : 2it [00:00, 13.01it/s]


 Saving inferences at epoch ===  52

 Saving inferences at epoch === 

[Test] Epoch: 52 -  Loss: 0.050094857811927795 - Acc: 0.9848480224609375 - Dice: 0.9515396123407761 - Jaccard: 0.9075589776039124: : 4it [00:00, 13.29it/s]

 52

 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.25319376587867737 - Acc: 0.8924560546875 - Dice: 0.9346606856167818 - Jaccard: 0.8773361444473267: : 6it [00:00, 13.58it/s]


 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.05726483464241028 - Acc: 0.9775772094726562 - Dice: 0.9760712836306016 - Jaccard: 0.9532609581947327: : 8it [00:00, 13.82it/s]


 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.07344847917556763 - Acc: 0.9754104614257812 - Dice: 0.9567653261350982 - Jaccard: 0.9171141982078552: : 8it [00:00, 13.82it/s]


 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.29391032457351685 - Acc: 0.8756027221679688 - Dice: 0.8205421790194285 - Jaccard: 0.6956943869590759: : 10it [00:00, 14.18it/s]


 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.08151714503765106 - Acc: 0.9737701416015625 - Dice: 0.9530401448909439 - Jaccard: 0.91029292345047: : 12it [00:00, 14.27it/s]


 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.1857537478208542 - Acc: 0.917877197265625 - Dice: 0.8187490281222471 - Jaccard: 0.6931203007698059: : 14it [00:01, 14.43it/s]  


 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.0796670913696289 - Acc: 0.9778976440429688 - Dice: 0.9106221138596098 - Jaccard: 0.8359102010726929: : 18it [00:01, 14.84it/s]


 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52

 Saving inferences at epoch ===  52


[Test] Epoch: 52 -  Loss: 0.10149592906236649 - Acc: 0.9619369506835938 - Dice: 0.9311562092635588 - Jaccard: 0.8711808323860168: : 20it [00:01, 13.96it/s]



 Saving inferences at epoch ===  52
Test Epoch: 52 |  Loss: 0.15378613770008087 | Accuracy: 0.9406269073486329 | Dice: 0.9066531451843179 | Jaccard: 0.8329307198524475
---------------------------------------------------------------------------------------------------------------


## Run the code

In [None]:
if __name__ == "__main__":
    print("--- Starting the main function ----")
    main()