In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from segmentation_models_pytorch import Unet
import numpy as np
import os


In [42]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))  # Ensure consistent ordering
        self.masks = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
        image = Image.open(image_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # Convert mask to grayscale

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        mask = mask.unsqueeze(0)  # Add channel dimension for mask
        return image, mask


In [43]:
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

mask_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])


In [44]:
train_dataset = SegmentationDataset(
    image_dir='/home/abid/Code/U/isic2017/train/images', 
    mask_dir='/home/abid/Code/U/isic2017/train/masks',
    transform=image_transform
)

val_dataset = SegmentationDataset(
    image_dir='/home/abid/Code/U/isic2017/val/images', 
    mask_dir='/home/abid/Code/U/isic2017/val/masks',
    transform=image_transform
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)


In [45]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=1)
model = model.to(device)


In [46]:
criterion = nn.BCEWithLogitsLoss()  # Binary Cross Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)


In [47]:
def dice_coefficient(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred = (pred > 0.5).float()
    pred = pred.view(-1)
    target = target.view(-1)
    intersection = (pred * target).sum()
    dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
    return dice

def miou(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred = (pred > 0.5).float()
    pred = pred.view(-1)
    target = target.view(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    miou = (intersection + smooth) / (union + smooth)
    return miou


In [50]:
num_epochs = 50

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    running_dice = 0.0
    running_miou = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device).float()  # Ensure masks are float type

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks.squeeze(1))  # Remove extra channel dimension from masks

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute metrics
        pred = torch.sigmoid(outputs)
        pred_binary = (pred > 0.5).float()  # Binarize predictions

        running_loss += loss.item()
        running_dice += dice_coefficient(pred_binary, masks.squeeze(1)).item()
        running_miou += miou(pred_binary, masks.squeeze(1)).item()

    # Average loss and metrics for the epoch
    epoch_loss = running_loss / len(train_loader)
    epoch_dice = running_dice / len(train_loader)
    epoch_miou = running_miou / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Dice: {epoch_dice:.4f}, MIoU: {epoch_miou:.4f}")

    # Validation loop
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    val_dice = 0.0
    val_miou = 0.0
    with torch.no_grad():  # Disable gradient computation for validation
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device).float()  # Ensure masks are float type

            outputs = model(images)
            loss = criterion(outputs, masks.squeeze(1))  # Remove extra channel dimension from masks
            val_loss += loss.item()

            pred = torch.sigmoid(outputs)
            pred_binary = (pred > 0.5).float()  # Binarize predictions

            val_dice += dice_coefficient(pred_binary, masks.squeeze(1)).item()
            val_miou += miou(pred_binary, masks.squeeze(1)).item()

    # Average validation loss and metrics
    val_loss /= len(val_loader)
    val_dice /= len(val_loader)
    val_miou /= len(val_loader)
    print(f"Validation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}, Validation MIoU: {val_miou:.4f}")

    # Update learning rate if validation loss plateaus
    scheduler.step(val_loss)


Epoch [1/50], Loss: 0.1126, Dice: 0.9267, MIoU: 0.8647
Validation Loss: 0.1352, Validation Dice: 0.8237, Validation MIoU: 0.7132
Epoch [2/50], Loss: 0.0902, Dice: 0.9344, MIoU: 0.8775
Validation Loss: 0.1528, Validation Dice: 0.8301, Validation MIoU: 0.7180
Epoch [3/50], Loss: 0.0856, Dice: 0.9283, MIoU: 0.8676
Validation Loss: 0.1515, Validation Dice: 0.7993, Validation MIoU: 0.6821
Epoch [4/50], Loss: 0.0825, Dice: 0.9253, MIoU: 0.8628
Validation Loss: 0.1213, Validation Dice: 0.8454, Validation MIoU: 0.7409
Epoch [5/50], Loss: 0.0634, Dice: 0.9432, MIoU: 0.8930
Validation Loss: 0.1147, Validation Dice: 0.8538, Validation MIoU: 0.7540
Epoch [6/50], Loss: 0.0547, Dice: 0.9504, MIoU: 0.9059
Validation Loss: 0.1377, Validation Dice: 0.8348, Validation MIoU: 0.7280
Epoch [7/50], Loss: 0.0484, Dice: 0.9551, MIoU: 0.9144
Validation Loss: 0.1259, Validation Dice: 0.8495, Validation MIoU: 0.7474
Epoch [8/50], Loss: 0.0441, Dice: 0.9587, MIoU: 0.9208
Validation Loss: 0.1311, Validation Dice: 

In [59]:
test_dataset = SegmentationDataset(
    image_dir='/home/abid/Code/U/isic2017/val/images', 
    mask_dir='/home/abid/Code/U/isic2017/val/masks',
    transform=image_transform
)

test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)


In [60]:
test_model(model, test_loader, device)


Test Loss: 0.1351, Test Dice: 0.8512, Test MIoU: 0.7514
