In [None]:
## UNet Model

import torch
import torch.nn as nn

import torch
import torch.nn as nn



def double_convolution(in_channels, out_channels):
    """
    Build the two-stage convolution block used at every U-Net level.

    Each stage is a 3x3 convolution followed by batch normalisation and ReLU.
    Unlike the original paper, the convolutions are padded (padding=1) so the
    spatial size of the output equals that of the input.

    Args:
        in_channels: number of channels entering the first convolution.
        out_channels: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the six layers in execution order.
    """
    def _conv_bn_relu(channels_in):
        # One conv -> batch-norm -> ReLU stage; batch-norm has no learnable
        # affine parameters and keeps no running statistics.
        return [
            nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels, affine=False, track_running_stats=False),
            nn.ReLU(inplace=True),
        ]

    stages = _conv_bn_relu(in_channels) + _conv_bn_relu(out_channels)
    return nn.Sequential(*stages)




class UNet(nn.Module):
    """
    U-Net encoder/decoder built from padded double-convolution blocks.

    The contracting path applies five double convolutions (64, 128, 256, 512
    and 1024 channels) separated by 2x2 max pooling. The expanding path
    upsamples with stride-2 transposed convolutions, concatenates the matching
    encoder feature map along the channel axis and applies another double
    convolution. A final 1x1 convolution maps to ``num_classes`` channels.

    Because the input is halved four times and the upsampled maps are
    concatenated with the encoder maps, the input height and width must be
    divisible by 16; otherwise ``torch.cat`` fails on mismatched sizes.

    Args:
        in_channels: number of channels of the input image (default 3).
        num_classes: number of channels of the output logits (default 3).
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path. Each level applies the convolution twice.
        self.down_convolution_1 = double_convolution(in_channels, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)

        # Expanding path. After every transposed convolution the result is
        # concatenated with the encoder feature map of the same level, which
        # doubles the channel count fed to the following double convolution.
        self.up_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512,
            kernel_size=2,
            stride=2)
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256,
            kernel_size=2,
            stride=2)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128,
            kernel_size=2,
            stride=2)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64,
            kernel_size=2,
            stride=2)
        self.up_convolution_4 = double_convolution(128, 64)

        # Output head: one channel of raw logits per class.
        self.out = nn.Conv2d(
            in_channels=64, out_channels=num_classes,
            kernel_size=1
        )

    def forward(self, x):
        """
        Run the network on a batch of images.

        Args:
            x: tensor of shape (batch, in_channels, H, W) with H and W
                divisible by 16.

        Returns:
            Raw (un-activated) logits of shape (batch, num_classes, H, W).
        """
        # Contracting path: keep every pre-pooling feature map for the skips.
        down_1 = self.down_convolution_1(x)
        down_2 = self.max_pool2d(down_1)
        down_3 = self.down_convolution_2(down_2)
        down_4 = self.max_pool2d(down_3)
        down_5 = self.down_convolution_3(down_4)
        down_6 = self.max_pool2d(down_5)
        down_7 = self.down_convolution_4(down_6)
        down_8 = self.max_pool2d(down_7)
        down_9 = self.down_convolution_5(down_8)

        # Expanding path: upsample, concatenate the skip connection on the
        # channel axis (dim=1), then convolve.
        up_1 = self.up_transpose_1(down_9)
        up_2 = self.up_convolution_1(torch.cat([down_7, up_1], 1))
        up_3 = self.up_transpose_2(up_2)
        up_4 = self.up_convolution_2(torch.cat([down_5, up_3], 1))
        up_5 = self.up_transpose_3(up_4)
        up_6 = self.up_convolution_3(torch.cat([down_3, up_5], 1))
        up_7 = self.up_transpose_4(up_6)
        up_8 = self.up_convolution_4(torch.cat([down_1, up_7], 1))

        return self.out(up_8)

In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms 
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as A

import time
import os
from tqdm.notebook import tqdm


# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



## Dice loss (the `DiceLoss` class itself is defined in a later cell)
import torch.nn.functional as F


In [None]:



## Load dataset
import os
from glob import glob
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """
    Paired image/mask dataset read from two directories of ``*.png`` files.

    Images and masks are paired by position after sorting the file paths of
    each directory, so both directories must contain the same number of files
    with names that sort in the same order.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the ground-truth masks.
        transform_img: callable applied to each RGB PIL image. When ``None``
            the image is resized to 256x256 and converted to a tensor.
        transform_msk: callable applied to each PIL mask. When ``None`` the
            mask is resized to 256x256 and converted to a tensor.

    Raises:
        ValueError: if the two directories hold a different number of PNGs.
    """

    def __init__(self, image_dir, mask_dir, transform_img=None, transform_msk=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir

        # Honour caller-supplied transforms; build the defaults only when the
        # caller passed nothing. Normalize(mean=0, std=1) leaves values as-is.
        if transform_img is None:
            transform_img = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            ])
        if transform_msk is None:
            transform_msk = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            ])
        self.transform_img = transform_img
        self.transform_msk = transform_msk

        self.image_paths = sorted(glob(os.path.join(self.image_dir, '*.png')))
        self.mask_paths = sorted(glob(os.path.join(self.mask_dir, '*.png')))

        # Pairing is purely positional, so unequal counts would either raise
        # IndexError later or silently pair an image with the wrong mask.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {self.image_dir!r} but "
                f"{len(self.mask_paths)} masks in {self.mask_dir!r}; counts must match."
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load and transform the pair at position ``idx``.

        Returns:
            A tuple ``(image, mask)`` holding the outputs of ``transform_img``
            and ``transform_msk`` respectively.

        Raises:
            FileNotFoundError: if either file has disappeared since the
                directory listing was taken.
        """
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        if not os.path.exists(img_path) or not os.path.exists(mask_path):
            raise FileNotFoundError(f"Image or mask file not found at index {idx}")

        image = Image.open(img_path).convert("RGB")
        image = self.transform_img(image)

        # The mask keeps whatever mode the file was stored in.
        mask = Image.open(mask_path)
        mask = self.transform_msk(mask)

        return image, mask




In [None]:

import os
from glob import glob
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from torchvision import transforms
import cv2
import albumentations as album


In [None]:
# Define the paths to the training data
train_image_dir = "/kaggle/input/cvcclinicdb/PNG/Original"
train_mask_dir = "/kaggle/input/cvcclinicdb/PNG/Ground Truth"

# Initialize the full dataset
full_dataset = CustomDataset(train_image_dir, train_mask_dir)

# Define the percentages for splitting (80% training, 10% validation, 10% testing)
train_percentage = 0.8
val_percentage = 0.1
test_percentage = 0.1

# Ensure percentages sum to 1. Compare with a tolerance: exact float equality
# can fail for perfectly valid fractions because of rounding.
assert abs(train_percentage + val_percentage + test_percentage - 1.0) < 1e-9, "Splitting percentages must sum to 1."

# Calculate lengths for each split
train_size = int(train_percentage * len(full_dataset))
val_size = int(val_percentage * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size  # Remaining samples for the test set

# Split the dataset into training, validation, and testing sets. A dedicated
# seeded generator makes the split identical on every run, independent of how
# much randomness was consumed earlier in the session.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)

# Create DataLoaders for the training, validation, and testing datasets.
# These three subsets are disjoint; they must not be rebuilt from the full
# directory afterwards, or validation/test images would leak into training.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:

class DiceLoss(nn.Module):
    """
    Soft Dice loss computed from raw logits over the whole batch.

    The logits are passed through a sigmoid, both tensors are flattened, and
    the loss is ``1 - (2 * |P * T| + smooth) / (|P| + |T| + smooth)``.

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """
        Args:
            inputs: raw logits; remove the sigmoid below if the model already
                ends in a sigmoid (or equivalent) activation.
            targets: ground-truth tensor with the same number of elements.
            smooth: additive constant that keeps the ratio defined when both
                tensors are all zeros.

        Returns:
            Scalar tensor holding ``1 - dice``.
        """
        probs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors. reshape (unlike view) also
        # accepts non-contiguous tensors such as permuted or sliced batches.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return 1 - dice



In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Seed before the model is created so that the weight initialisation, and not
# only the batch shuffling, is reproducible.
torch.manual_seed(42)

# UNet, DiceLoss, device and the data loaders come from earlier cells.
model = UNet().to(device)

# The training objective is the sum of Dice loss and binary cross-entropy;
# both loss functions take raw logits.
loss_fn_1 = DiceLoss()
loss_fn_2 = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

# Number of epochs for training
epochs = 300
train_total_losses = []  # mean combined training loss, one entry per epoch
val_total_losses = []    # mean combined validation loss, one entry per epoch

# Training loop
for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch+1} of {epochs}")

    ### Training
    train_loss_1, train_loss_2, train_loss = 0, 0, 0
    model.train()

    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        # Forward pass
        y_pred = model(X)

        # Calculate loss
        loss_1 = loss_fn_1(y_pred, y)
        loss_2 = loss_fn_2(y_pred, y)
        loss = loss_1 + loss_2
        train_loss += loss.item()
        train_loss_1 += loss_1.item()
        train_loss_2 += loss_2.item()

        # Zero gradients, backward pass, and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Average training loss over the batches of this epoch
    train_loss /= len(train_loader)
    train_loss_1 /= len(train_loader)
    train_loss_2 /= len(train_loader)

    ### Validation
    test_loss_1, test_loss_2, test_loss = 0, 0, 0
    model.eval()

    # No gradients are needed for validation; disabling autograd avoids
    # building a graph for every batch and reduces memory use.
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)

            # Forward pass
            y_pred = model(X)

            # Calculate loss
            loss_1 = loss_fn_1(y_pred, y)
            loss_2 = loss_fn_2(y_pred, y)
            loss = loss_1 + loss_2
            test_loss += loss.item()
            test_loss_1 += loss_1.item()
            test_loss_2 += loss_2.item()

    # Average validation loss over the batches of this epoch
    test_loss /= len(val_loader)
    test_loss_1 /= len(val_loader)
    test_loss_2 /= len(val_loader)

    # Print out losses for training and validation
    print(f"Train loss: {train_loss:.5f}, Dice: {train_loss_1:.5f}, BCE: {train_loss_2:.5f} | Test loss: {test_loss:.5f}, Dice: {test_loss_1:.5f}, BCE: {test_loss_2:.5f}\n")

    train_total_losses.append(train_loss)
    val_total_losses.append(test_loss)

    # Display predictions every 5 epochs, using the first sample of the last
    # validation batch (X, y and y_pred still hold that batch here).
    if epoch % 5 == 0:
        plt.figure(figsize=(15, 5))
        plt.subplot(131)
        plt.imshow(X[0, 0].cpu().detach().numpy(), cmap='gray')
        plt.title("Input Image")
        plt.axis('off')

        plt.subplot(132)
        plt.imshow(y[0, 0].cpu().detach().numpy(), cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')

        plt.subplot(133)
        # Use sigmoid to convert logits to probabilities for display
        y_pred_mask = torch.sigmoid(y_pred[0, 0]).cpu().detach().numpy()
        plt.imshow(y_pred_mask, cmap='gray')
        plt.title("Predicted Mask")
        plt.axis('off')

        plt.show()

    # Save the model and plot the losses every 5 epochs (skipping epoch 0)
    if epoch % 5 == 0 and epoch != 0:
        torch.save(model.state_dict(), f"./model-{epoch}.pth")
        plt.figure(figsize=(20, 5))
        plt.subplot(1, 2, 1)
        plt.plot(train_total_losses, label='train_loss')
        plt.plot(val_total_losses, label='val_loss')
        plt.title("Training & Validation Losses")
        plt.ylabel("Loss")
        plt.xlabel("Epochs")
        plt.legend()
        plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

# Directory to save predicted masks
output_dir = "predicted_masks"
os.makedirs(output_dir, exist_ok=True)

# Set the model to evaluation mode
model.eval()

# Iterate through the entire test dataset
with torch.no_grad():
    for batch_idx, (images, gt_masks) in enumerate(test_loader):
        # Move data to the device
        images = images.to(device)
        gt_masks = gt_masks.to(device)

        # Forward pass. The model returns raw logits (it is trained with
        # BCEWithLogitsLoss), so convert them to probabilities before applying
        # the 0.5 decision threshold.
        logits = model(images)
        pr_masks = (torch.sigmoid(logits) > 0.5).float()

        # Process each image in the batch
        for i, (image, gt_mask, pr_mask) in enumerate(zip(images, gt_masks, pr_masks)):
            # Convert tensors to NumPy arrays
            image_np = image.permute(1, 2, 0).cpu().numpy()
            gt_mask_np = gt_mask.cpu().numpy()
            pr_mask_np = pr_mask.cpu().numpy()

            # Masks arrive as (channels, H, W); imshow/imsave with a colormap
            # need a 2-D array, so average over the channel axis.
            if gt_mask_np.ndim == 3:
                gt_mask_np = gt_mask_np.mean(axis=0)
            if pr_mask_np.ndim == 3:
                pr_mask_np = pr_mask_np.mean(axis=0)

            # Display the results
            plt.figure(figsize=(15, 5))

            plt.subplot(1, 3, 1)
            plt.imshow(image_np)
            plt.title("Image")
            plt.axis("off")

            plt.subplot(1, 3, 2)
            plt.imshow(gt_mask_np, cmap='gray')
            plt.title("Ground Truth")
            plt.axis("off")

            plt.subplot(1, 3, 3)
            plt.imshow(pr_mask_np, cmap='gray')
            plt.title("Prediction")
            plt.axis("off")

            plt.show()

            # Save the predicted mask
            pr_mask_filename = os.path.join(output_dir, f"pred_mask_batch{batch_idx}_img{i}.png")
            plt.imsave(pr_mask_filename, pr_mask_np, cmap='gray')


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import numpy as np
from PIL import Image


# Run the model over the test set; X, y and y_pred keep the final batch.
model.eval()
with torch.inference_mode():
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)

        # Forward pass (raw logits)
        y_pred = model(X)

# First channel of the first sample: ground truth and predicted probability.
# The logits are passed through a sigmoid so the prediction lies in [0, 1],
# which is the same range as the mask and a valid RGB intensity below.
y1 = y[0, 0].cpu().detach().numpy()
y2 = torch.sigmoid(y_pred[0, 0]).cpu().detach().numpy()

plt.subplot(121)
plt.imshow(y1, cmap='gray')

plt.subplot(122)
plt.imshow(y2, cmap='gray')
plt.show()

# Colour overlay: red channel = ground truth, green and blue = prediction.
# Agreement shows up as white, misses as red and false positives as cyan.
y_color = np.zeros((*y2.shape, 3))
y_color[..., 0] = y1
y_color[..., 1] = y2
y_color[..., 2] = y2


plt.subplot(121)
plt.imshow(X[0, 0].cpu().detach().numpy())

plt.subplot(122)
plt.imshow(y_color)
plt.show()



import torch
import matplotlib.pyplot as plt
import numpy as np

def show_first_batch_predictions(model, loader, device, threshold=0.5):
    """
    Plot image, ground truth and thresholded prediction for the first batch.

    Args:
        model: network returning raw logits of shape (batch, channels, H, W).
        loader: iterable whose first item is an ``(images, masks)`` pair.
        device: device on which the forward pass is run.
        threshold: probability above which a pixel counts as foreground.
    """
    batch = next(iter(loader))

    model.eval()
    with torch.no_grad():
        logits = model(batch[0].to(device))
    # The model outputs logits, so apply a sigmoid before thresholding;
    # comparing raw logits with 0.5 would correspond to a probability of ~0.62.
    pr_masks = (torch.sigmoid(logits) > threshold).float()

    for image, gt_mask, pr_mask in zip(batch[0], batch[1], pr_masks):
        gt_mask_np = gt_mask.cpu().numpy()
        pr_mask_np = pr_mask.cpu().numpy()
        # Collapse (channels, H, W) masks to a 2-D grayscale array; this works
        # for single-channel and RGB masks alike.
        if gt_mask_np.ndim == 3:
            gt_mask_np = gt_mask_np.mean(axis=0)
        if pr_mask_np.ndim == 3:
            pr_mask_np = pr_mask_np.mean(axis=0)

        plt.figure(figsize=(15, 5))  # Wide figure for three side-by-side panels

        plt.subplot(1, 3, 1)
        plt.imshow(image.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC
        plt.title("Image")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(gt_mask_np, cmap='gray')
        plt.title("Ground truth")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(pr_mask_np, cmap='gray')
        plt.title("Prediction")
        plt.axis("off")
        plt.show()


# The notebook previously held two identical copies of this visualisation;
# they are folded into the single function above and shown once.
show_first_batch_predictions(model, test_loader, device)