# U-Net Model Notebook

In [15]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """U-Net encoder/decoder producing a per-pixel prediction map.

    The network has five resolution levels.  Each level applies a "double
    conv" (two 3x3 Conv -> BatchNorm -> ReLU groups); levels are separated by
    2x2 max-pooling on the way down and by 2x2 transposed convolutions on the
    way up.  The output of every encoder level is concatenated (along the
    channel axis) with the upsampled decoder tensor of the same resolution.

    Args:
        n_class: Number of output channels produced by the final 1x1 conv.
        in_channels: Number of channels of the input image (default 3).
        base_channels: Width of the first level; each deeper level doubles it
            (default 64, i.e. 64-128-256-512-1024).

    The layer layout of ``self.encoder`` / ``self.decoder`` (and therefore
    the ``state_dict`` keys) is fixed:

    * encoder: double conv at 0-5, 7-12, 14-19, 21-26, 28-33;
      max-pool at 6, 13, 20, 27.
    * decoder: transposed conv at 0, 7, 14, 21;
      double conv at 1-6, 8-13, 15-20, 22-27.
    """

    def __init__(self, n_class, in_channels=3, base_channels=64):
        super().__init__()

        # Channel widths of the five resolution levels.
        c1, c2, c3, c4, c5 = [base_channels * 2 ** level for level in range(5)]

        # Encoder
        self.encoder = nn.Sequential(
            *self._double_conv(in_channels, c1),      # 0-5
            nn.MaxPool2d(kernel_size=2, stride=2),    # 6
            *self._double_conv(c1, c2),               # 7-12
            nn.MaxPool2d(kernel_size=2, stride=2),    # 13
            *self._double_conv(c2, c3),               # 14-19
            nn.MaxPool2d(kernel_size=2, stride=2),    # 20
            *self._double_conv(c3, c4),               # 21-26
            nn.MaxPool2d(kernel_size=2, stride=2),    # 27
            *self._double_conv(c4, c5),               # 28-33 (bottleneck)
        )

        # Decoder.  Every double conv takes 2*C input channels because its
        # input is the upsampled tensor (C) concatenated with the skip (C).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2),  # 0
            *self._double_conv(c5, c4),                           # 1-6
            nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2),  # 7
            *self._double_conv(c4, c3),                           # 8-13
            nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2),  # 14
            *self._double_conv(c3, c2),                           # 15-20
            nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2),  # 21
            *self._double_conv(c2, c1),                           # 22-27
        )

        # Output layer
        self.outconv = nn.Conv2d(c1, n_class, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return the six layers of one (Conv3x3 -> BN -> ReLU) x 2 group."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    @staticmethod
    def _run(layers, start, stop, x):
        """Apply ``layers[start:stop]`` to ``x`` in order and return the result."""
        for index in range(start, stop):
            x = layers[index](x)
        return x

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate an upsampled decoder tensor with its encoder skip.

        Max-pooling floors odd sizes, so the upsampled tensor can be one
        pixel smaller than the skip; it is zero-padded on the bottom/right
        to match before concatenation.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h or pad_w:
            # F.pad order for 4-D input: (left, right, top, bottom).
            upsampled = nn.functional.pad(upsampled, [0, pad_w, 0, pad_h])
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Compute the prediction map for a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``.  ``H`` and ``W``
                must be at least 16 so that four poolings leave a non-empty
                bottleneck.

        Returns:
            Tensor of shape ``(N, n_class, H, W)``.  No activation is applied
            after the final 1x1 convolution.
        """
        enc, dec = self.encoder, self.decoder

        # Encoder: keep the output of each level *before* pooling as a skip.
        skip1 = self._run(enc, 0, 6, x)
        skip2 = self._run(enc, 6, 13, skip1)
        skip3 = self._run(enc, 13, 20, skip2)
        skip4 = self._run(enc, 20, 27, skip3)
        x = self._run(enc, 27, 34, skip4)

        # Decoder: upsample, merge with the matching skip, then double conv.
        x = dec[0](x)
        x = self._run(dec, 1, 7, self._merge(x, skip4))
        x = dec[7](x)
        x = self._run(dec, 8, 14, self._merge(x, skip3))
        x = dec[14](x)
        x = self._run(dec, 15, 21, self._merge(x, skip2))
        x = dec[21](x)
        x = self._run(dec, 22, 28, self._merge(x, skip1))

        # Output layer
        return self.outconv(x)


In [45]:
import os
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class SegmentationDataset(Dataset):
    """Dataset of (image, binary mask) pairs read from two directories.

    Every regular file in ``images_dir`` is treated as an image.  Its mask is
    looked up in ``masks_dir``: ``name.png`` maps to ``name_mask.png``; any
    other file name is looked up unchanged.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.
        transform: Optional callable applied to the RGB image array
            (H x W x 3, as returned by OpenCV after BGR->RGB conversion).
            It is NOT applied to the mask, so a transform that changes the
            spatial size will leave image and mask misaligned.
        subset_size: If given, only the first ``subset_size`` file names (in
            sorted order) are used.
    """

    def __init__(self, images_dir, masks_dir, transform=None, subset_size=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        # os.listdir order is arbitrary: sort so that the dataset order and
        # the chosen subset are reproducible.  Directories (for example
        # notebook checkpoint folders) are skipped since they cannot be read.
        names = sorted(
            name
            for name in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, name))
        )
        self.image_names = names[:subset_size]

    def __len__(self):
        """Return the number of images in the (possibly truncated) dataset."""
        return len(self.image_names)

    @staticmethod
    def _mask_name(img_name):
        """Return the mask file name that corresponds to ``img_name``."""
        suffix = ".png"
        # Only rewrite the trailing extension, not every ".png" occurrence.
        if img_name.endswith(suffix):
            return img_name[: -len(suffix)] + "_mask" + suffix
        return img_name

    def __getitem__(self, idx):
        """Load the ``idx``-th sample.

        Returns:
            ``(img, mask)`` where ``img`` is the RGB image (after
            ``transform`` if one was given) and ``mask`` is a float32 tensor
            of shape ``(1, H, W)`` holding 0.0 for background and 1.0 for
            every non-zero mask pixel.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        img_name = self.image_names[idx]
        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, self._mask_name(img_name))

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # OpenCV decodes colour images as BGR; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Apply transformations if specified
        if self.transform is not None:
            img = self.transform(img)

        # Convert mask to binary format (0s and 1s) with a channel axis.
        mask = (torch.from_numpy(mask) > 0).to(torch.float32).unsqueeze(0)

        return img, mask