In [None]:
import os
import numpy as np
import time
import cv2
import random
from operator import add
from glob import glob
from tqdm import tqdm
import imageio
from albumentations import HorizontalFlip, VerticalFlip, Rotate
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
def create_dir(path):
    """Create the directory ``path`` (and any missing parents) if it is not there yet.

    Does nothing when ``path`` already exists.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

In [None]:
def load_data(data_path):
    """Collect sorted DRIVE file paths for the training and test splits.

    Looks for ``*.tif`` images under ``<split>/images`` and ``*.gif`` manual
    annotations under ``<split>/1st_manual`` for ``split`` in
    ("training", "test").

    Returns:
        ((train_images, train_masks), (test_images, test_masks)) — four sorted
        lists of paths (empty when nothing matches).
    """
    def _sorted_files(split, folder, pattern):
        return sorted(glob(os.path.join(data_path, split, folder, pattern)))

    training_pair = (
        _sorted_files("training", "images", "*.tif"),
        _sorted_files("training", "1st_manual", "*.gif"),
    )
    testing_pair = (
        _sorted_files("test", "images", "*.tif"),
        _sorted_files("test", "1st_manual", "*.gif"),
    )
    return training_pair, testing_pair

In [None]:
def data_augmenteation(images, masks, save_path, augment=True, size=(512, 512)):
    """Resize (and optionally augment) image/mask pairs and save them as PNG files.

    Every pair is written to ``<save_path>/Training_augmented_data/images`` and
    ``<save_path>/Training_augmented_data/masks`` as ``<name>_<k>.png``. Here
    ``<name>`` is the image file's stem and ``k`` numbers the variant:
    0 = original, and when ``augment`` is True, 1 = horizontal flip,
    2 = vertical flip, 3 = random rotation with ``limit=45``.

    Args:
        images: paths of the input images (read with ``cv2.IMREAD_COLOR``).
        masks: paths of the matching masks, index-aligned with ``images``.
            Each is read with ``imageio.mimread`` and its first frame is used,
            which lets GIF annotations load.
        save_path: root directory for the output folders (created if missing).
        augment: also save the three augmented variants when True.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(image_ok, mask_ok)``: the ``cv2.imwrite`` results for the last pair
        written (``(False, False)`` when there was nothing to write).

    Raises:
        FileNotFoundError: if an image cannot be read by OpenCV.
        OSError: if ``cv2.imwrite`` reports a failed write.
    """
    image_dir = os.path.join(save_path, "Training_augmented_data", "images")
    mask_dir = os.path.join(save_path, "Training_augmented_data", "masks")
    # cv2.imwrite does not create missing folders; it just returns False.
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    image_ok, mask_ok = False, False
    for image_file, mask_file in tqdm(zip(images, masks), total=len(images)):
        # basename/splitext use the platform's separators instead of a hard-coded "\\".
        name = os.path.splitext(os.path.basename(image_file))[0]

        x = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if x is None:
            # cv2.imread signals unreadable files by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {image_file}")
        y = imageio.mimread(mask_file)[0]

        X, Y = [x], [y]
        if augment:
            # Applied in the same order as before: h-flip, v-flip, rotation.
            for aug in (HorizontalFlip(p=1.0), VerticalFlip(p=1.0), Rotate(limit=45, p=1.0)):
                augmented = aug(image=x, mask=y)
                X.append(augmented["image"])
                Y.append(augmented["mask"])

        for index, (i, m) in enumerate(zip(X, Y)):
            i = cv2.resize(i, size)
            # Nearest-neighbour keeps the mask binary; the default bilinear
            # interpolation would blend foreground and background at edges.
            m = cv2.resize(m, size, interpolation=cv2.INTER_NEAREST)

            file_name = f"{name}_{index}.png"
            image_ok = cv2.imwrite(os.path.join(image_dir, file_name), i)
            mask_ok = cv2.imwrite(os.path.join(mask_dir, file_name), m)
            if not (image_ok and mask_ok):
                raise OSError(f"Failed to write {file_name} under {save_path}")

    return image_ok, mask_ok

# Building the UNET

In [None]:
class ConvBlock(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` on the 3x3 convolutions keeps height and width unchanged;
    only the channel count goes from ``input_channel`` to ``output_channel``.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        # Stage 1: input_channel -> output_channel.
        self.conv1 = nn.Conv2d(
            in_channels=input_channel, out_channels=output_channel, kernel_size=3, padding=1
        )
        self.bn1 = nn.BatchNorm2d(output_channel)

        # Stage 2: output_channel -> output_channel.
        self.conv2 = nn.Conv2d(
            in_channels=output_channel, out_channels=output_channel, kernel_size=3, padding=1
        )
        self.bn2 = nn.BatchNorm2d(output_channel)

        # One parameter-free ReLU module, reused by both stages.
        self.activation1 = nn.ReLU()

    def forward(self, inputs):
        features = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = self.activation1(norm(conv(features)))
        return features

In [None]:
class EncoderBlock(nn.Module):
    """One U-Net encoder stage: a ConvBlock followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the ConvBlock
    output, later used as the skip connection. ``pooled`` is that output
    downsampled by 2 and fed to the next stage.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.conv = ConvBlock(input_channel, output_channel)
        # kernel_size=2 with stride=2 (the default stride equals the kernel size).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

In [None]:
class Decoder(nn.Module):
    """One U-Net decoder stage: 2x upsampling, skip concatenation, ConvBlock.

    ``skip_connection`` must have ``output_channel`` channels and the same
    spatial size as the upsampled tensor, because the two are concatenated
    along the channel axis before a ConvBlock that expects
    ``2 * output_channel`` channels.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        # A 2x2 transposed conv with stride 2 doubles height and width.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels=input_channel,
            out_channels=output_channel,
            kernel_size=2,
            stride=2,
            padding=0,
        )
        self.conv = ConvBlock(2 * output_channel, output_channel)

    def forward(self, inputs, skip_connection):
        upsampled = self.conv_transpose(inputs)
        merged = torch.cat((upsampled, skip_connection), dim=1)
        return self.conv(merged)
    

In [None]:
class BuildUNet(nn.Module):
    """U-Net for binary segmentation.

    The network has four encoder stages (64/128/256/512 channels), a
    1024-channel bottleneck, four decoder stages that consume the encoder
    skip connections, and a 1x1 convolution classifier.

    Args:
        in_channels: channels of the input image (default 3, as before).
        out_channels: channels of the output logits (default 1, as before).

    ``forward`` returns raw logits (no sigmoid). The four 2x2 poolings mean input
    height and width should be divisible by 16, otherwise the upsampled
    tensors and skip connections differ in size and ``torch.cat`` fails.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder (attribute names are kept stable so saved state_dicts still load).
        self.e1 = EncoderBlock(in_channels, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)

        # Bottleneck
        self.b = ConvBlock(512, 1024)

        # Decoder
        self.decoder = Decoder(1024, 512)
        self.decoder2 = Decoder(512, 256)
        self.decoder3 = Decoder(256, 128)
        self.decoder4 = Decoder(128, 64)

        # Classifier: 1x1 conv mapping 64 features to out_channels logits.
        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Encoder: s* are pre-pool features kept for the skip connections.
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck
        b = self.b(p4)

        # Decoder: each stage upsamples and merges the matching skip connection.
        d1 = self.decoder(b, s4)
        d2 = self.decoder2(d1, s3)
        d3 = self.decoder3(d2, s2)
        d4 = self.decoder4(d3, s1)

        return self.outputs(d4)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    Args:
        weight, size_average: accepted for signature compatibility; unused.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` as a scalar tensor.

        Args:
            inputs: raw logits; a sigmoid is applied here.
            targets: ground truth with the same number of elements as
                ``inputs`` (the shapes may differ; both are flattened).
            smooth: additive smoothing in numerator and denominator, which
                keeps the ratio finite for empty masks.
        """
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors; reshape also handles
        # non-contiguous tensors where view would raise.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice

In [None]:
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss, both computed from logits.

    Args:
        weight, size_average: accepted for signature compatibility; unused.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + (1 - dice)`` as a scalar tensor.

        Args:
            inputs: raw logits (no sigmoid applied by the caller).
            targets: ground truth in [0, 1] with the same number of elements
                as ``inputs`` (the shapes may differ; both are flattened).
            smooth: additive smoothing in numerator and denominator of the
                Dice ratio, which keeps it in (0, 1] even for empty masks.
        """
        # Flatten label and prediction tensors (reshape tolerates non-contiguous input).
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1).to(logits.dtype)

        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        # The logits form fuses the sigmoid into the log terms, which is more
        # numerically stable than applying BCE to already-squashed probabilities.
        BCE = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

In [None]:
class DriveDataset(Dataset):
    """Map-style dataset pairing images with their segmentation masks.

    Args:
        image_path: list of image file paths (read with ``cv2.IMREAD_COLOR``).
        mask_path: list of mask file paths, index-aligned with ``image_path``
            (read with ``cv2.IMREAD_GRAYSCALE``).

    Each item is ``(image, mask)``:
        image: float32 tensor of shape (3, H, W), pixel values divided by 255.
        mask: float32 tensor of shape (1, H, W), pixel values divided by 255.
    """

    def __init__(self, image_path, mask_path):
        self.image_path = image_path
        self.mask_path = mask_path
        self.n_samples = len(image_path)

    def __getitem__(self, index):
        # Reading image
        image = cv2.imread(self.image_path[index], cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None instead of raising for unreadable files.
            raise FileNotFoundError(f"Could not read image: {self.image_path[index]}")
        image = image / 255.0
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image.astype(np.float32))

        # Reading mask
        mask = cv2.imread(self.mask_path[index], cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.mask_path[index]}")
        mask = mask / 255.0
        # Add a channel axis so batches are (B, 1, H, W), matching the model output,
        # and use float32 like the image instead of NumPy's default float64.
        mask = np.expand_dims(mask, axis=0)
        mask = torch.from_numpy(mask.astype(np.float32))

        return image, mask

    def __len__(self):
        # Required by DataLoader's samplers (e.g. shuffle=True uses len(dataset)).
        return self.n_samples

In [None]:
def __len__(self):
    """Return ``self.n_samples``.

    NOTE(review): this is a module-level function, not a method of
    DriveDataset, so defining it here does not give that class a length.
    """
    sample_count = self.n_samples
    return sample_count

In [None]:
def seeding(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per run and may pick
    # non-deterministic ones, so it is turned off explicitly.
    torch.backends.cudnn.benchmark = False

In [None]:
def epoch_time(start_time, end_time):
    """Split the elapsed time into whole minutes and leftover whole seconds.

    Args:
        start_time, end_time: timestamps in seconds (e.g. from ``time.time()``).

    Returns:
        ``(elapsed_mins, elapsed_secs)`` as ints.
    """
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    # Return the minute count, not the raw elapsed seconds.
    return elapsed_mins, elapsed_secs

In [None]:
if __name__ == "__main__":
    # Smoke test: push a random batch of two 3x512x512 images through the U-Net
    # and print the input and output shapes.
    sample_batch = torch.randn((2, 3, 512, 512))
    print(sample_batch.shape)
    unet = BuildUNet()
    logits = unet(sample_batch)
    print(logits.shape)

In [None]:
if __name__ == "__main__":
    # Smoke test for the building blocks on a random (2, 32, 128, 128) input.
    sample = torch.randn((2, 32, 128, 128))
    conv_block = ConvBlock(32, 64)
    print(conv_block(sample).shape)

    encoder_block = EncoderBlock(32, 64)
    # EncoderBlock returns (pre-pool features used as the skip connection, pooled output).
    skip_features, pooled = encoder_block(sample)
    print(skip_features.shape, pooled.shape)

In [None]:
# Root of the DRIVE dataset (load_data expects "training/" and "test/" inside it).
# Set the DRIVE_DATA_PATH environment variable to use another location without
# editing the notebook; otherwise the original absolute path is used.
data_path = os.environ.get("DRIVE_DATA_PATH", r"E:\python\segmentation\Computer Vision\UNET\data\blood\DRIVE")

In [None]:
if __name__ == "__main__":

    """ Seeding """
    # Rotate draws a random angle. Seed Python's `random` as well as NumPy (and
    # torch), since which RNG albumentations draws from depends on its version.
    # NOTE(review): newer albumentations releases may need their own seed too — confirm.
    seeding(42)

    """ Load the data """
    # data_path comes from the configuration cell above.
    (train_images, train_masks), (test_images, test_masks) = load_data(data_path)
    print(f"Training images : {len(train_images)}, Training masks : {len(train_masks)}")
    print(f"Test images : {len(test_images)}, Test masks : {len(test_masks)}")
    # zip() in data_augmenteation would silently drop unmatched files.
    assert len(train_images) == len(train_masks), "training images and masks are not paired"

    """ Creating some folders for saving the augmented dataset """
    # os.makedirs (used by create_dir) also creates the parent folder.
    for sub_dir in ("images", "masks"):
        create_dir(os.path.join(data_path, "Training_augmented_data", sub_dir))
    data_augmenteation(train_images, train_masks, data_path, augment=True)

In [None]:
def train(model, loader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module to optimize; it is switched to train mode.
        loader: iterable of ``(x, y)`` batches that supports ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        device: device both tensors are moved to (cast to float32).

    Returns:
        Sum of batch losses divided by ``len(loader)``.
    """
    epoch_loss = 0.0

    model.train()
    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        # Accumulate, so the returned value is a mean over all batches.
        epoch_loss += loss.item()

    epoch_loss = epoch_loss / len(loader)
    return epoch_loss

In [None]:
def evaluate(model, val_loader, loss_fn, device):
    """Compute the mean per-batch loss over ``val_loader`` without updating weights.

    Args:
        model: module to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(x, y)`` batches that supports ``len()``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        device: device both tensors are moved to (cast to float32).

    Returns:
        Sum of batch losses divided by ``len(val_loader)``.
    """
    epoch_loss = 0.0

    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

    epoch_loss = epoch_loss / len(val_loader)
    return epoch_loss

In [None]:
if __name__ == "__main__":
    """ Seeding """
    seeding(42)

    """ Directories """
    create_dir(os.path.join(data_path, "files"))

    """ Loading dataset """
    augmented_data_path = data_path
    X_train = sorted(glob(os.path.join(augmented_data_path, "Training_augmented_data", "images", "*")))
    y_train = sorted(glob(os.path.join(augmented_data_path, "Training_augmented_data", "masks", "*")))

    # NOTE(review): this expects a prepared "test/masks" folder of PNG masks whose images
    # match the training resolution. The raw DRIVE test split keeps its GIF annotations in
    # "1st_manual" (see load_data), which cv2.imread may not decode — confirm before training.
    X_val = sorted(glob(os.path.join(data_path, "test", "images", "*")))
    y_val = sorted(glob(os.path.join(data_path, "test", "masks", "*")))
    print(f"Dataset Size:\n Train: {len(X_train)} \n Valid : {len(X_val)}")
    # Fail early with a clear message instead of an IndexError deep inside the DataLoader.
    assert X_train and len(X_train) == len(y_train), f"train: {len(X_train)} images vs {len(y_train)} masks"
    assert X_val and len(X_val) == len(y_val), f"valid: {len(X_val)} images vs {len(y_val)} masks"

    """ Hyperparameters """
    H, W = 512, 512
    size = (H, W)
    batch_size = 2
    num_epochs = 52
    lr = 1e-4
    checkpoint_path = os.path.join(data_path, "files", "Retina_Blood_Vessel_with_pytorch.pth")
    # Windows starts DataLoader workers with "spawn"; those workers typically cannot find
    # classes defined interactively in a notebook, so load in the main process there.
    num_workers = 0 if os.name == "nt" else 2

    """ Dataset and Loader """
    train_dataset = DriveDataset(X_train, y_train)
    val_dataset = DriveDataset(X_val, y_val)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BuildUNet()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # `verbose` is deprecated in recent PyTorch releases; the current LR is printed each epoch instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)

    loss_fn = DiceBCELoss()

    """ Training The Model """

    best_valid_loss = float("inf")
    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, loss_fn, device)
        val_loss = evaluate(model, val_loader, loss_fn, device)
        # ReduceLROnPlateau only reacts when stepped with the monitored metric.
        scheduler.step(val_loss)

        """ Saving the Model """
        if val_loss < best_valid_loss:
            data_str = f"Valid loss improve from {best_valid_loss:2.4f} to {val_loss:2.4f}"
            print(data_str)

            best_valid_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        current_lr = optimizer.param_groups[0]["lr"]
        data_str = f"Epoch: {epoch+1:02} | Epoch Time : {epoch_mins}m {epoch_secs}s | LR: {current_lr:.2e}\n"
        data_str += f"\tTrain loss: {train_loss:.3f}\n"
        data_str += f"\t Val loss : {val_loss:.3f}\n"
        print(data_str)

In [None]:
def calculate_metrics(y_true, y_pred):
    """Binary segmentation metrics between two tensors, thresholded at 0.5.

    Both tensors are moved to CPU, binarized with ``> 0.5``, cast to uint8
    and flattened before scoring.

    Returns:
        ``[jaccard, f1, recall, precision, accuracy]`` in that order.
    """
    def _to_binary_vector(tensor):
        return (tensor.cpu().numpy() > 0.5).astype(np.uint8).reshape(-1)

    truth = _to_binary_vector(y_true)
    prediction = _to_binary_vector(y_pred)

    scorers = (jaccard_score, f1_score, recall_score, precision_score, accuracy_score)
    return [scorer(truth, prediction) for scorer in scorers]

In [None]:
def mask_parse(mask):
    """Append a trailing axis holding three identical copies of ``mask``.

    A 2-D (H, W) mask becomes (H, W, 3), so it can be placed side by side with
    a 3-channel image. The dtype is preserved.
    """
    return np.stack((mask, mask, mask), axis=-1)

In [None]:
if __name__ == "__main__":
    """ Seeding """
    seeding(42)

    """ Folders """
    create_dir(os.path.join(data_path, "results"))

    """ Load dataset """
    # Same DRIVE layout as load_data: test/images/*.tif inputs and
    # test/1st_manual/*.gif vessel annotations as ground truth.
    _, (X_test, y_test) = load_data(data_path)
    assert X_test, f"No test images found under {data_path}"
    assert len(X_test) == len(y_test), f"{len(X_test)} test images but {len(y_test)} masks"

    """ Hyperparameters """
    H, W = 512, 512
    size = (W, H)  # cv2.resize takes (width, height)
    # Same file the training cell saves its best weights to.
    checkpoint_path = os.path.join(data_path, "files", "Retina_Blood_Vessel_with_pytorch.pth")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = BuildUNet()
    model = model.to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0]
    time_taken = []

    for x, y in tqdm(zip(X_test, y_test), total=len(X_test)):
        """ Extracting the Name """
        name = os.path.splitext(os.path.basename(y))[0]

        """ Reading Images """
        image = cv2.imread(x, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {x}")
        # The model was trained on 512x512 inputs, and its four 2x2 poolings need
        # sides divisible by 16 for the skip connections to line up.
        image = cv2.resize(image, size)
        x = np.transpose(image, (2, 0, 1))
        x = x / 255.0
        x = np.expand_dims(x, axis=0)
        x = x.astype(np.float32)
        x = torch.from_numpy(x)
        x = x.to(device)

        """ Reading Masks """
        # Read with imageio (as data_augmenteation does) so GIF annotations load.
        mask = imageio.mimread(y)[0]
        if mask.ndim == 3:
            # NOTE(review): in case the reader returns a multi-channel frame, keep one channel.
            mask = np.ascontiguousarray(mask[..., 0])
        # Nearest-neighbour keeps the resized ground truth binary.
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        y = np.expand_dims(mask, axis=0)
        y = y / 255.0
        y = np.expand_dims(y, axis=0)
        y = y.astype(np.float32)
        y = torch.from_numpy(y)
        y = y.to(device)

        with torch.no_grad():
            """ Prediction and Calculation FPS """
            start_time = time.perf_counter()
            y_pred = model(x)
            y_pred = torch.sigmoid(y_pred)
            if device.type == "cuda":
                # CUDA kernels run asynchronously; wait for them so the timing is real.
                torch.cuda.synchronize()
            total_time = time.perf_counter() - start_time
            time_taken.append(total_time)

            score = calculate_metrics(y, y_pred)
            metrics_score = list(map(add, metrics_score, score))
            y_pred = y_pred[0].cpu().numpy()
            y_pred = np.squeeze(y_pred, axis=0)
            y_pred = y_pred > 0.5
            y_pred = np.array(y_pred, dtype=np.uint8)

        """ Saving masks """
        ori_mask = mask_parse(mask)
        y_pred = mask_parse(y_pred)
        # Grey 10-pixel separator, as tall as the resized image.
        line = np.ones((image.shape[0], 10, 3)) * 128

        cat_images = np.concatenate(
            [image, line, ori_mask, line, y_pred * 255], axis=1
        ).astype(np.uint8)
        cv2.imwrite(os.path.join(data_path, "results", f"{name}.png"), cat_images)

    num_images = len(X_test)
    jaccard = metrics_score[0] / num_images
    f1 = metrics_score[1] / num_images
    recall = metrics_score[2] / num_images
    precision = metrics_score[3] / num_images
    acc = metrics_score[4] / num_images
    print(f"Jaccard: {jaccard:1.4f} - F1:{f1:1.4f} - Recall:{recall:1.4f} - Precision:{precision:1.4f} - Accuracy:{acc:1.4f}")

    fps = 1 / np.mean(time_taken)
    print("FPS: ", fps)