<a href="https://colab.research.google.com/github/dusarp/deep-learning-experiments/blob/main/basic_segmentation_unet_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

U-Net

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """U-Net style encoder/decoder producing one output map per input pixel.

    The encoder applies four double-convolution blocks separated by 2x2 max
    pooling, followed by a bottleneck block. The decoder upsamples four times
    with transposed convolutions; after each of the first three upsampling
    steps the result is concatenated with the encoder feature map of the same
    resolution (skip connection). The first encoder block (``enc1``) is not
    used as a skip connection. The final 1x1 convolution returns raw scores
    (no activation is applied).

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of channels of the output map.
        base_channels: Width of the first encoder block; every deeper level
            doubles it. The default of 64 gives 64/128/256/512/1024.
    """

    def __init__(self, in_channels, out_channels, base_channels=64):
        super().__init__()

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Encoder
        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2)
        self.encoder3 = self.conv_block(c2, c3)
        self.encoder4 = self.conv_block(c3, c4)

        # Downsampling between encoder levels. The layer has no parameters,
        # so it is created once here and shared by every level.
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Decoder
        # decoder4 reads the bottleneck directly. Every later decoder reads
        # the previous decoder output concatenated with an encoder skip of
        # equal width, so its input has twice that width.
        self.decoder4 = self.upconv_block(c5, c4)
        self.decoder3 = self.upconv_block(2 * c4, c3)
        self.decoder2 = self.upconv_block(2 * c3, c2)
        self.decoder1 = self.upconv_block(2 * c2, c1)

        # Final Convolution
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 convolutions (padding 1), each followed by ReLU.

        With padding 1 the spatial size of the input is preserved.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv_block(self, in_channels, out_channels):
        """Return a stride-2 transposed convolution followed by ReLU.

        Kernel size 2 with stride 2 doubles the height and width.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Batched input tensor of shape ``(N, in_channels, H, W)``.
                ``H`` and ``W`` should be multiples of 16: the four pooling
                steps round down, so other sizes can make the skip
                concatenations fail or change the output size.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` with raw scores.
        """
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder
        dec4 = self.decoder4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)  # Skip connection
        dec3 = self.decoder3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)  # Skip connection
        dec2 = self.decoder2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)  # Skip connection
        dec1 = self.decoder1(dec2)

        return self.final_conv(dec1)

# Example usage: build a U-Net that takes 3-channel images and emits a
# single-channel output map.
model = UNet(in_channels=3, out_channels=1)  # For binary segmentation

Data preparation

In [None]:
from torchvision import transforms
from PIL import Image
import os

class SegmentationDataset(torch.utils.data.Dataset):
    """Image/mask pairs read from two directories.

    Every regular file in ``images_dir`` is one sample. Its mask is looked up
    in ``masks_dir`` under the image's file name with ``.jpg`` replaced by
    ``_mask.png`` (e.g. ``cat.jpg`` -> ``cat_mask.png``).

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the corresponding masks.
        image_size: ``(height, width)`` every image and mask is resized to.
    """

    def __init__(self, images_dir, masks_dir, image_size=(256, 256)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.image_size = image_size
        # os.listdir returns entries in arbitrary order and includes
        # sub-directories (e.g. notebook checkpoint folders). Keep only files
        # and sort them so that indices are stable between runs.
        self.images = sorted(
            name for name in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize and convert the ``idx``-th image and its mask.

        Returns:
            ``(image, mask)`` tensors produced by ``ToTensor``: the image from
            an RGB picture, the mask from a single-channel ("L") picture.
        """
        name = self.images[idx]
        img_path = os.path.join(self.images_dir, name)
        mask_path = os.path.join(self.masks_dir, name.replace('.jpg', '_mask.png'))

        # convert() returns a fully loaded copy, so the underlying files can
        # be closed as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")  # Assuming mask is grayscale

        image_transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor()
        ])
        # Masks hold labels, not intensities: nearest-neighbour resizing keeps
        # the original values instead of blending them along object borders.
        mask_transform = transforms.Compose([
            transforms.Resize(
                self.image_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.ToTensor()
        ])

        image = image_transform(image)
        mask = mask_transform(mask)

        return image, mask

# Example usage: the two paths are placeholders. The constructor lists
# `images_dir` immediately, so point them at real directories before running.
dataset = SegmentationDataset(images_dir='path/to/images', masks_dir='path/to/masks')

Training loop

In [None]:
import torch.optim as optim

def train_model(model, dataset, num_epochs=10, batch_size=8, lr=0.001):
    """Train ``model`` on ``dataset`` with Adam and binary cross-entropy.

    The loss is ``BCEWithLogitsLoss``, so the model is expected to return raw
    scores. The loss of every batch is printed.

    Args:
        model: Network to optimise in place. It is left in training mode.
        dataset: Dataset yielding ``(image, mask)`` pairs. Masks may either
            have the same shape as the model output or lack its channel
            dimension.
        num_epochs: Number of passes over the dataset.
        batch_size: Number of samples per optimisation step.
        lr: Learning rate passed to Adam.
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.BCEWithLogitsLoss()  # Change based on your output classes
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Batches come from the loader on the CPU; move them to wherever the
    # model's weights are so a model on an accelerator also works.
    device = next(model.parameters()).device

    model.train()
    for epoch in range(num_epochs):  # Number of epochs
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # BCEWithLogitsLoss needs target and input of identical shape.
            # Squeezing the output would drop its channel axis (and the batch
            # axis for a batch of one), so align the mask instead.
            if masks.dim() == outputs.dim() - 1:
                masks = masks.unsqueeze(1)
            masks = masks.to(dtype=outputs.dtype)

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Example usage: train the model created above on the dataset created above.
train_model(model=model, dataset=dataset)

Inference

In [None]:
def predict(model, input_image):
    """Run ``model`` on a single image and return its output as a NumPy array.

    The model is switched to evaluation mode and stays in it afterwards.

    Args:
        model: Network to evaluate.
        input_image: Image tensor without a batch dimension, e.g. ``(C, H, W)``.

    Returns:
        The model output with all size-1 dimensions removed, as a NumPy array.
        No activation or threshold is applied to the values.
    """
    model.eval()

    # Run on the device holding the model's weights (if it has any).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_image = input_image.to(first_param.device)

    with torch.no_grad():
        output = model(input_image.unsqueeze(0))  # Add batch dimension

    # .numpy() only works on CPU tensors; .cpu() is a no-op if already there.
    return output.squeeze().cpu().numpy()

# Example usage:
# Assuming input_image is a preprocessed image tensor of shape (C,H,W).
# NOTE: input_image is not defined in the visible cells; create it before
# running this line.
predicted_mask = predict(model=model, input_image=input_image)