In [1]:
import torch.nn as nn
import torch
import os

In [2]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(UNet, self).__init__()
        
        # Encoder blocks
        self.block1 = nn.Sequential()
        self.block1.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1))
        self.block1.add_module('relu1', nn.ReLU())   
        self.block1.add_module('conv2', nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1))
        self.block1.add_module('relu2', nn.ReLU())

        self.block2 = nn.Sequential()
        self.block2.add_module('maxpool1', nn.MaxPool2d(kernel_size=2, stride=2))
        self.block2.add_module('conv3', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1))
        self.block2.add_module('relu3', nn.ReLU())
        self.block2.add_module('conv4', nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1))
        self.block2.add_module('relu4', nn.ReLU())

        self.block3 = nn.Sequential()
        self.block3.add_module('maxpool2', nn.MaxPool2d(kernel_size=2, stride=2))
        self.block3.add_module('conv5', nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1))
        self.block3.add_module('relu5', nn.ReLU())
        self.block3.add_module('conv6', nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1))
        self.block3.add_module('relu6', nn.ReLU())

        self.block4 = nn.Sequential()
        self.block4.add_module('maxpool3', nn.MaxPool2d(kernel_size=2, stride=2))
        self.block4.add_module('conv7', nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1))
        self.block4.add_module('relu7', nn.ReLU())
        self.block4.add_module('conv8', nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1))
        self.block4.add_module('relu8', nn.ReLU())
        
        self.block5 = nn.Sequential()
        self.block5.add_module('maxpool4', nn.MaxPool2d(kernel_size=2, stride=2))
        self.block5.add_module('conv9', nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1))
        self.block5.add_module('relu9', nn.ReLU())
        self.block5.add_module('conv10', nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1))
        self.block5.add_module('relu10', nn.ReLU())
        self.block5.add_module('conv11', nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2))
        self.block5.add_module('relu10', nn.ReLU())

        # Decoder blocks
        self.block6 = nn.Sequential()
        self.block6.add_module('conv12', nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1))
        self.block6.add_module('relu12', nn.ReLU())
        self.block6.add_module('conv13', nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1))
        self.block6.add_module('relu13', nn.ReLU())
        self.block6.add_module('conv14', nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2))
        self.block6.add_module('relu14', nn.ReLU())

        self.block7 = nn.Sequential()
        self.block7.add_module('conv15', nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1))
        self.block7.add_module('relu15', nn.ReLU())
        self.block7.add_module('conv16', nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1))
        self.block7.add_module('relu16', nn.ReLU())
        self.block7.add_module('conv17', nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2))
        self.block7.add_module('relu17', nn.ReLU())

        self.block8 = nn.Sequential()
        self.block8.add_module('conv18', nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1))
        self.block8.add_module('relu18', nn.ReLU())
        self.block8.add_module('conv19', nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1))
        self.block8.add_module('relu19', nn.ReLU())
        self.block8.add_module('conv20', nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2))
        self.block8.add_module('relu20', nn.ReLU())
  
        self.block9 = nn.Sequential()
        self.block9.add_module('conv21', nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1))
        self.block9.add_module('relu21', nn.ReLU())
        self.block9.add_module('conv22', nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1))
        self.block9.add_module('relu22', nn.ReLU())
        self.block9.add_module('conv23', nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1)) # Conv1-1
        

    def forward(self, x):
        # Encoder
        features1 = self.block1(x) 
        features2 = self.block2(features1)
        features3 = self.block3(features2)
        features4 = self.block4(features3)
        features5 = self.block5(features4)

        # Decoder
        concat5 = torch.cat([features4, features5], dim=1)
        features6 = self.block6(concat5)
        concat6 = torch.cat([features3, features6], dim=1)
        features7 = self.block7(concat6)
        concat7 = torch.cat([features2, features7], dim=1)
        features8 = self.block8(concat7)
        concat8 = torch.cat([features1, features8], dim=1)
        mask = self.block9(concat8)

        # Permutate mask from [bN,ch,w,h] to [bN,w,h,ch]
        mask = torch.permute(mask, (0, 2, 3, 1))

        return mask

In [3]:
from torch.utils.data import Dataset
from PIL import Image

In [4]:
class SegmentationDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.images = sorted(os.listdir(images_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask

In [11]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Transforms definition
img_transform = transforms.Compose([
    transforms.ToTensor()
])
mask_transform = transforms.Compose([
    transforms.ToTensor()
])

# Datasets creation
train_dataset = SegmentationDataset(
    images_dir="../datasets/train-seg/images",
    masks_dir="../datasets/train-seg/masks",
    transform=img_transform,
    mask_transform=mask_transform
)
val_dataset = SegmentationDataset(
    images_dir="../datasets/val-seg/images",
    masks_dir="../datasets/val-seg/masks",
    transform=img_transform,
    mask_transform=mask_transform
)
test_dataset = SegmentationDataset(
    images_dir="../datasets/test-seg/images",
    masks_dir="../datasets/test-seg/masks",
    transform=img_transform,
    mask_transform=mask_transform
)

# DataLoaders creation
BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
import torch.nn as nn
import torch.optim as optim

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(in_channels=3, num_classes=1).to(device)

criterion = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-4)

EPOCHS = 20

train_losses = []
val_losses = []

for epoch in range(EPOCHS):
    model.train()  
    running_loss = 0.  # Pour accumuler la loss de l’epoch

    for images, masks in train_loader:
        images = images.to(device)  
        masks = masks.to(device)    
        masks = masks.float()       

        optimizer.zero_grad()       # Remet à zéro les gradients

        outputs = model(images)     # Passe les images dans le modèle
        outputs = outputs.permute(0, 3, 1, 2)  # [B, H, W, 1] -> [B, 1, H, W] pour la loss

        loss = criterion(outputs, masks)  # Calcule la perte
        loss.backward()                   # Rétro-propagation
        optimizer.step()                  # Mise à jour des poids

        running_loss += loss.item() * images.size(0)  # Accumule la loss pondérée par la taille du batch

    epoch_loss = running_loss / len(train_loader.dataset)  # Moyenne de la loss sur l’epoch
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss:.4f}")

    # Validation (désactive le calcul du gradient)
    model.eval()
    val_loss = 0.
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            masks = masks.float()
            outputs = model(images)
            outputs = outputs.permute(0, 3, 1, 2)
            loss = criterion(outputs, masks)
            val_loss += loss.item() * images.size(0)
    val_loss /= len(val_loader.dataset)
    print(f"Validation Loss: {val_loss:.4f}")
    
    train_losses.append(epoch_loss)
    val_losses.append(val_loss)

FileNotFoundError: [Errno 2] No such file or directory: '../datasets/train-seg/masks\\010230.jpg'

In [None]:
import matplotlib.pyplot as plt

plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

In [None]:
model.eval()
test_loss = 0.
with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)
        masks = masks.float()
        outputs = model(images)
        outputs = outputs.permute(0, 3, 1, 2)
        loss = criterion(outputs, masks)
        test_loss += loss.item() * images.size(0)
test_loss /= len(test_loader.dataset)
print(f"Test Loss: {test_loss:.4f}")