In [1]:
import torch
from torch import nn
from torchvision import transforms
from torchinfo import summary
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import tqdm

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tensors/modules created without an explicit device now land on DEVICE.
# NOTE(review): a CUDA default device can clash with DataLoader samplers that
# build a CPU torch.Generator (e.g. shuffle=True) — TODO confirm against the
# training loop.
torch.set_default_device(DEVICE)

# Pinned (page-locked) host memory only helps host->GPU copies, so enable it on CUDA only.
PIN_MEMORY = DEVICE == "cuda"

In [3]:
# Root locations of the input images and of their ground-truth masks.
# TODO: fill these in before running; both are empty placeholders here.
DATASET_PATH, MASK_DATASET_PATH = "", ""

In [4]:
# ---------------------------------------------------------------------------
# Hyperparameters and output locations
# ---------------------------------------------------------------------------

# Model shape.
# NOTE(review): the dataset cell converts every image to RGB (3 channels) and the
# model cell's first convolution defaults to 3 input channels, so NUM_CHANNELS = 1
# does not match the current pipeline — TODO reconcile.
NUM_CHANNELS = 1
NUM_CLASSES = 1
# NOTE(review): not referenced in the visible cells; the model's encoder has
# four pooling stages.
NUM_LEVELS = 3

# Optimisation.
LR = 1e-4
NUM_EPOCHS = 40
BATCH_SIZE = 8

# Square size that images and masks are resized to by the dataset transform.
INPUT_IMAGE_HEIGHT = INPUT_IMAGE_WIDTH = 512

# Presumably the cutoff applied to predicted mask probabilities; not used in the
# visible cells.
THRESHOLD = 0.5

# Output artefacts; the three specific paths are empty placeholders.
BASE_OUTPUT = "output"
MODEL_PATH = PLOT_PATH = TEST_PATH = ""

In [5]:
# Creating the custom segmentation dataset class
from torch.utils.data import Dataset
class cv2:
    """Minimal stand-in for the OpenCV names that ``SegmentationDataset`` uses.

    The notebook calls ``cv2.imread``, ``cv2.cvtColor`` and ``cv2.COLOR_BGR2RGB``
    but does not import OpenCV. This namespace provides those names with
    OpenCV-compatible results, using torchvision (already a notebook dependency):

    * colour reads return (H, W, 3) uint8 arrays in BGR channel order;
    * grayscale reads (flag 0) return (H, W) uint8 arrays.

    Replace this class with ``import cv2`` if OpenCV is installed.

    Differences from OpenCV:

    * ``imread`` raises on unreadable files instead of returning ``None``;
    * ``cvtColor`` only supports ``COLOR_BGR2RGB``.
    """

    IMREAD_GRAYSCALE = 0
    IMREAD_COLOR = 1
    COLOR_BGR2RGB = 4  # same numeric value as OpenCV's constant

    @staticmethod
    def imread(path, flags=1):
        """Read an image file into a numpy array laid out like ``cv2.imread``.

        Args:
            path: path of the image file (str or path-like).
            flags: ``0`` for a single-channel (H, W) grayscale array; any other
                value for a 3-channel (H, W, 3) array in BGR order.

        Returns:
            numpy.ndarray of dtype uint8.
        """
        # Imported lazily so the shim costs nothing unless images are read.
        from torchvision.io import ImageReadMode, read_image

        if flags == cv2.IMREAD_GRAYSCALE:
            gray = read_image(str(path), ImageReadMode.GRAY)  # (1, H, W)
            return gray[0].cpu().numpy()
        rgb = read_image(str(path), ImageReadMode.RGB)  # (3, H, W)
        # CHW/RGB -> HWC/BGR, matching OpenCV's default colour layout.
        return np.ascontiguousarray(rgb.permute(1, 2, 0).cpu().numpy()[..., ::-1])

    @staticmethod
    def cvtColor(image, code):
        """Convert a colour image; only ``COLOR_BGR2RGB`` (channel reversal) is supported.

        Raises:
            NotImplementedError: for any other conversion code.
        """
        if code != cv2.COLOR_BGR2RGB:
            raise NotImplementedError(f"cv2 shim only supports COLOR_BGR2RGB, got {code!r}")
        # Reverse the channel axis; copy so later torch.from_numpy calls never see
        # negative strides.
        return np.ascontiguousarray(np.asarray(image)[..., ::-1])
    
class SegmentationDataset(Dataset):
    """Map-style dataset pairing image files with their segmentation-mask files.

    Item ``i`` is built from ``imagePaths[i]`` and ``maskPaths[i]``:

    * the image is read in colour and converted BGR -> RGB;
    * the mask is read as a single-channel grayscale array;
    * ``transforms`` (if not None) is applied to each of them independently.

    Args:
        imagePaths: sequence of image file paths.
        maskPaths: sequence of mask file paths, aligned index-by-index with
            ``imagePaths``.
        transforms: callable applied separately to the image and to the mask,
            or None to return the arrays as read.

    Raises:
        ValueError: if ``imagePaths`` and ``maskPaths`` differ in length.
    """

    def __init__(self, imagePaths, maskPaths, transforms):
        super().__init__()
        # Fail early rather than with an IndexError deep inside a DataLoader.
        if len(imagePaths) != len(maskPaths):
            raise ValueError(
                f"got {len(imagePaths)} image paths but {len(maskPaths)} mask paths"
            )
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.imagePaths)

    def __getitem__(self, idx):
        """Load, colour-convert and optionally transform the ``idx``-th pair.

        Returns:
            tuple ``(image, mask)``.
        """
        image = cv2.imread(self.imagePaths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.maskPaths[idx], 0)  # flag 0 = grayscale read in OpenCV

        if self.transforms is not None:
            # NOTE(review): the same transform is applied to the mask, so an
            # interpolating resize will blend label values at region boundaries;
            # consider nearest-neighbour resizing for masks.
            image = self.transforms(image)
            mask = self.transforms(mask)

        return image, mask



In [6]:
class litsUnet(nn.Module):
    """U-Net with four down-sampling stages and a 1024-channel bottleneck.

    Spatial size changes only at two kinds of layer:

    * 3x3 convolutions use padding=1, so they preserve spatial size;
    * 2x2 max-pools halve H and W;
    * 2x2 stride-2 transposed convolutions double H and W.

    The input height and width must therefore be divisible by 16, and the
    output has the same height and width as the input.

    Args:
        n_class: number of output channels (one score map per class).
        in_channels: number of input image channels (default 3, the value that
            was previously hard-coded).

    Shapes:
        input:  (N, in_channels, H, W) with H % 16 == 0 and W % 16 == 0
        output: (N, n_class, H, W) raw scores; no sigmoid/softmax is applied.
    """

    def __init__(self, n_class: int, in_channels: int = 3):
        super().__init__()

        # Encoder: two 3x3 convs per stage, then a 2x2 max-pool.
        self.encoder11 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.encoder12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.encoder22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.encoder32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.encoder42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck.
        self.encoder51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.encoder52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Decoder: each transposed conv doubles H/W and halves the channels. The
        # result is concatenated with the matching encoder output (skip
        # connection), which is why every decoderX1 takes twice the channels.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.decoder12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.decoder22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.decoder32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.decoder42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Output layer: 1x1 conv maps 64 feature channels to one score map per class.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Raises:
            ValueError: if H or W is not divisible by 16. The four 2x pools would
                otherwise floor the size, and the skip concatenations would fail.
        """
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f"input height and width must be divisible by 16, got {height}x{width}"
            )

        # Functional ReLU. Calling nn.ReLU(tensor) would construct a module,
        # not apply the activation.
        relu = torch.relu

        # Encoder
        x_encoder11 = relu(self.encoder11(x))
        x_encoder12 = relu(self.encoder12(x_encoder11))
        x_pool1 = self.pool1(x_encoder12)

        x_encoder21 = relu(self.encoder21(x_pool1))
        x_encoder22 = relu(self.encoder22(x_encoder21))
        x_pool2 = self.pool2(x_encoder22)

        x_encoder31 = relu(self.encoder31(x_pool2))
        x_encoder32 = relu(self.encoder32(x_encoder31))
        x_pool3 = self.pool3(x_encoder32)

        x_encoder41 = relu(self.encoder41(x_pool3))
        x_encoder42 = relu(self.encoder42(x_encoder41))
        x_pool4 = self.pool4(x_encoder42)

        x_encoder51 = relu(self.encoder51(x_pool4))
        x_encoder52 = relu(self.encoder52(x_encoder51))

        # Decoder: upsample, concatenate the skip connection along channels, two convs.
        xu1 = self.upconv1(x_encoder52)
        xu11 = torch.cat([xu1, x_encoder42], dim=1)
        x_decoder11 = relu(self.decoder11(xu11))
        x_decoder12 = relu(self.decoder12(x_decoder11))

        xu2 = self.upconv2(x_decoder12)
        xu22 = torch.cat([xu2, x_encoder32], dim=1)
        x_decoder21 = relu(self.decoder21(xu22))
        x_decoder22 = relu(self.decoder22(x_decoder21))

        xu3 = self.upconv3(x_decoder22)
        xu33 = torch.cat([xu3, x_encoder22], dim=1)
        x_decoder31 = relu(self.decoder31(xu33))
        x_decoder32 = relu(self.decoder32(x_decoder31))

        xu4 = self.upconv4(x_decoder32)
        xu44 = torch.cat([xu4, x_encoder12], dim=1)
        x_decoder41 = relu(self.decoder41(xu44))
        x_decoder42 = relu(self.decoder42(x_decoder41))

        # Output layer: raw per-class scores (no activation).
        return self.outconv(x_decoder42)

In [7]:
# Placeholders for the train/test image and mask path sequences (TODO: fill in).
# Each image list must be aligned index-by-index with its mask list.
train_images = ""
train_masks = ""
test_images = ""
test_masks = ""

# Array -> PIL image -> resized to the configured size -> float tensor.
# NOTE(review): Resize interpolates bilinearly by default, and the dataset applies
# this same pipeline to the masks, so mask edges get blended values. Consider a
# separate nearest-neighbour pipeline for masks.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)),
    transforms.ToTensor(),
])

# Pass the composed pipeline ``transform``. ``transforms`` is the torchvision
# module itself, which is not callable.
train_dataset = SegmentationDataset(imagePaths=train_images, maskPaths=train_masks, transforms=transform)

test_dataset = SegmentationDataset(imagePaths=test_images, maskPaths=test_masks, transforms=transform)