# U-Net Model Notebook

# In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedUNet(nn.Module):
    """U-Net style encoder/decoder producing per-pixel class logits.

    The encoder applies five double-convolution stages, each followed by a
    2x2 max-pool, so the bottleneck sits at 1/32 of the input resolution.
    The decoder upsamples with transposed convolutions, concatenates the
    matching encoder feature map (skip connection) and refines the result
    with a double convolution *without* pooling.

    Args:
        n_class: Number of output channels (one logit map per class).

    Input:
        Float tensor of shape ``(N, 3, H, W)``. ``H`` and ``W`` must be at
        least 32 so that five successive 2x poolings leave a non-empty map.

    Output:
        Raw logits of shape ``(N, n_class, H, W)`` (no softmax/sigmoid applied).
    """

    def __init__(self, n_class):
        super().__init__()

        # Encoder: every stage ends in a max-pool and halves the spatial size.
        self.encoder1 = self.conv_block(3, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.encoder5 = self.conv_block(512, 1024)

        # Decoder: the refinement blocks must NOT pool. Pooling here would undo
        # the preceding upsampling and make the next skip concatenation fail
        # with a spatial-size mismatch.
        self.upconv1 = self.upconv_block(1024, 512)
        self.decoder1 = self.conv_block(1024, 512, pool=False)
        self.upconv2 = self.upconv_block(512, 256)
        self.decoder2 = self.conv_block(512, 256, pool=False)
        self.upconv3 = self.upconv_block(256, 128)
        self.decoder3 = self.conv_block(256, 128, pool=False)
        self.upconv4 = self.upconv_block(128, 64)
        self.decoder4 = self.conv_block(128, 64, pool=False)

        # Output layer: 1x1 convolution mapping features to class logits.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def conv_block(self, in_channels, out_channels, pool=True):
        """Build a (Conv3x3 -> BatchNorm -> ReLU) x 2 block.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by both convolutions.
            pool: When True (default) append a 2x2 max-pool that halves the
                spatial size; pass False for decoder blocks.

        Returns:
            An ``nn.Sequential`` containing the layers described above.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if pool:
            # Appended last so the parameter indices (0, 1, 3, 4) are the same
            # with or without pooling.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        """Return a transposed convolution that doubles the spatial size."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _concat_skip(upsampled, skip):
        """Concatenate an upsampled map with its encoder skip along channels.

        Max-pooling floors odd sizes, so after upsampling the map can be one
        pixel smaller than the skip tensor; in that case it is resized to the
        skip's spatial size before concatenation.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(
                upsampled, size=skip.shape[-2:], mode="bilinear", align_corners=False
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network and return logits at the input's spatial size."""
        # Encoder (spatial size relative to the input: 1/2 ... 1/32)
        xe1 = self.encoder1(x)
        xe2 = self.encoder2(xe1)
        xe3 = self.encoder3(xe2)
        xe4 = self.encoder4(xe3)
        xe5 = self.encoder5(xe4)

        # Decoder (1/16 ... 1/2)
        xu1 = self.decoder1(self._concat_skip(self.upconv1(xe5), xe4))
        xu2 = self.decoder2(self._concat_skip(self.upconv2(xu1), xe3))
        xu3 = self.decoder3(self._concat_skip(self.upconv3(xu2), xe2))
        xu4 = self.decoder4(self._concat_skip(self.upconv4(xu3), xe1))

        # Output layer
        out = self.outconv(xu4)

        # The shallowest skip (xe1) is already pooled, so the logits are at
        # half resolution; resize them so they align with the input pixels.
        if out.shape[-2:] != x.shape[-2:]:
            out = F.interpolate(
                out, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return out


# In [8]:
import os
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class SegmentationDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs read from two directories.

    Every regular file in ``images_dir`` is treated as an image. Its mask is
    looked up in ``masks_dir`` under the same file name with ``".png"``
    replaced by ``"_mask.png"``.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.
        transform: Optional callable applied to the image only (not the mask).
            NOTE(review): a transform that resizes or crops the image will leave
            the mask at its original size — confirm the intended pipeline.
        subset_size: If given, keep only the first ``subset_size`` file names
            (in sorted order); ``None`` keeps all of them.
    """

    def __init__(self, images_dir, masks_dir, transform=None, subset_size=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order, so sort to make the
        # subset reproducible; skip sub-directories (e.g. .ipynb_checkpoints)
        # that cannot be decoded as images.
        names = sorted(
            name
            for name in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, name))
        )
        self.image_names = names[:subset_size]

    def __len__(self):
        """Return the number of images in the (possibly truncated) dataset."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Returns:
            A tuple ``(img, mask)``. ``img`` is the output of ``transform`` when
            one was supplied, otherwise the decoded RGB image array. ``mask``
            is a float32 tensor with a leading channel dimension holding the
            raw grayscale values as read from disk (not rescaled).

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        img_name = self.image_names[idx]
        img_path = os.path.join(self.images_dir, img_name)

        mask_name = img_name.replace(".png", "_mask.png")  # Adjust this based on your mask naming convention
        mask_path = os.path.join(self.masks_dir, mask_name)

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # OpenCV decodes colour images as BGR; convert to the RGB order the
        # rest of the pipeline is documented to expect.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Apply transformations if specified
        if self.transform:
            img = self.transform(img)

        # Always return the mask as a (1, H, W) float tensor so the return
        # type does not depend on whether an image transform was supplied.
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

        return img, mask
