In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """Additive attention gate applied to a skip-connection feature map.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels by a 1x1 convolution followed by batch normalisation,
    summed, passed through ReLU, and collapsed to a single-channel map squashed
    by a sigmoid. That map multiplies ``x`` element-wise (broadcast over the
    channel dimension).

    Args:
        F_g: number of channels of the gating signal ``g``.
        F_l: number of channels of the skip features ``x``.
        F_int: number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Creation order (W_g, W_x, psi, relu) is kept so that state_dict
        # keys and seeded parameter initialisation are unaffected.
        self.W_g = self._project(F_g, F_int)
        self.W_x = self._project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _project(in_channels, out_channels):
        """Return a 1x1 convolution (with bias) followed by batch norm."""
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        return nn.Sequential(pointwise, nn.BatchNorm2d(out_channels))

    def forward(self, g, x):
        """Return ``x`` scaled by the attention coefficients computed from ``g`` and ``x``.

        ``g`` and ``x`` must have the same batch and spatial sizes, since
        their projections are added element-wise.
        """
        combined = self.relu(self.W_g(g) + self.W_x(x))
        coefficients = self.psi(combined)
        return x * coefficients

class AttentionUNet(nn.Module):
    """U-Net whose skip connections are filtered by attention gates.

    The encoder has five double-convolution stages (64, 128, 256, 512 and
    1024 channels) separated by 2x2 max pooling. Each of the four decoder
    stages upsamples with a stride-2 transposed convolution, gates the
    matching encoder feature map with an ``AttentionBlock``, concatenates the
    gated skip with the upsampled map and fuses them with another double
    convolution. A final 1x1 convolution maps to ``out_channels``; no
    activation is applied to the result.

    Input height and width need not be multiples of 16: when an upsampled map
    differs in size from its skip connection it is bilinearly resized to the
    skip's size. Each spatial dimension must still be at least 16 so that the
    four pooling steps leave a non-empty feature map.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map.
    """

    def __init__(self, in_channels, out_channels):
        super(AttentionUNet, self).__init__()

        # Encoder
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = self.conv_block(in_channels, 64)
        self.conv2 = self.conv_block(64, 128)
        self.conv3 = self.conv_block(128, 256)
        self.conv4 = self.conv_block(256, 512)
        self.conv5 = self.conv_block(512, 1024)

        # Decoder: each stage is (upsample, attention gate, fusing conv block).
        # The fusing block takes twice the stage width because the gated skip
        # and the upsampled map are concatenated along the channel axis.
        self.up_conv5 = self.up_conv(1024, 512)
        self.att5 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.conv_up5 = self.conv_block(1024, 512)

        self.up_conv4 = self.up_conv(512, 256)
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.conv_up4 = self.conv_block(512, 256)

        self.up_conv3 = self.up_conv(256, 128)
        self.att3 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.conv_up3 = self.conv_block(256, 128)

        self.up_conv2 = self.up_conv(128, 64)
        self.att2 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.conv_up2 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv -> batch norm -> ReLU layers (spatial size preserved)."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        return block

    def up_conv(self, in_channels, out_channels):
        """Return a 2x2, stride-2 transposed convolution (doubles height and width)."""
        up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        return up

    def _decode(self, upsample, gate, fuse, deep, skip):
        """Run one decoder stage: upsample ``deep``, gate ``skip``, concatenate and fuse."""
        upsampled = upsample(deep)
        # Max pooling floors odd sizes, so doubling does not always restore
        # the skip's size; resize so the gate and the concat see equal shapes.
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = F.interpolate(
                upsampled, size=skip.shape[2:], mode='bilinear', align_corners=False
            )
        gated = gate(g=upsampled, x=skip)
        return fuse(torch.cat((gated, upsampled), dim=1))

    def forward(self, x):
        """Map a ``(N, in_channels, H, W)`` batch to ``(N, out_channels, H, W)`` raw outputs."""
        # Encoder path
        c1 = self.conv1(x)
        c2 = self.conv2(self.max_pool(c1))
        c3 = self.conv3(self.max_pool(c2))
        c4 = self.conv4(self.max_pool(c3))
        c5 = self.conv5(self.max_pool(c4))

        # Decoder path
        u5 = self._decode(self.up_conv5, self.att5, self.conv_up5, c5, c4)
        u4 = self._decode(self.up_conv4, self.att4, self.conv_up4, u5, c3)
        u3 = self._decode(self.up_conv3, self.att3, self.conv_up3, u4, c2)
        u2 = self._decode(self.up_conv2, self.att2, self.conv_up2, u3, c1)

        out = self.final_conv(u2)
        return out


In [47]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [48]:
# Define dataset
class SegmentationDataset(Dataset):
    """Image/mask pairs read from two directories.

    File names in each directory are sorted and paired by position, so the
    two directories are expected to contain matching files in matching order;
    this is not verified here.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the segmentation masks.
        transform: optional callable applied to the image and, separately,
            to the mask. It must return tensors, because the mask is
            ``unsqueeze``-d afterwards.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``.

        The image is loaded as RGB and the mask as single-channel greyscale.
        The mask gets one extra leading dimension after the transform.
        """
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # convert() returns a fully loaded copy, so the underlying files can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')

        if self.transform is not None:
            image = self.transform(image)
            # NOTE(review): the same transform is reused for the mask; any
            # interpolating or random step in it will affect mask values or
            # image/mask alignment — confirm this is intended.
            mask = self.transform(mask)

        mask = mask.unsqueeze(0)
        return image, mask

In [49]:
# Preprocessing pipeline: resize to a fixed 256x256 and convert to a tensor.
# NOTE(review): SegmentationDataset applies this same pipeline to the masks;
# if Resize interpolates (bilinear is presumably the torchvision default),
# mask edges will hold fractional values — confirm this is acceptable.
image_size = (256, 256)
preprocessing_steps = [
    transforms.Resize(image_size),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Define paths
# The dataset root can be overridden with the ISIC2017_ROOT environment
# variable so the notebook runs on machines other than the original author's;
# when the variable is unset the original location is used.
DATA_ROOT = os.environ.get('ISIC2017_ROOT', '/home/abid/Code/U/isic2017')
train_image_dir = os.path.join(DATA_ROOT, 'train', 'images')
train_mask_dir = os.path.join(DATA_ROOT, 'train', 'masks')
val_image_dir = os.path.join(DATA_ROOT, 'val', 'images')
val_mask_dir = os.path.join(DATA_ROOT, 'val', 'masks')


In [50]:
# Build the training and validation datasets and wrap each in a DataLoader.
# Only the training loader shuffles; validation keeps the sorted file order.
batch_size = 8

train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = SegmentationDataset(val_image_dir, val_mask_dir, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [51]:
# Initialize model, criterion, optimizer, and scheduler
# `device` has to exist before the model is moved onto it; defining it here
# keeps this cell runnable on a fresh kernel (Restart & Run All).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AttentionUNet(in_channels=3, out_channels=1).to(device)
# The model emits raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Cut the learning rate by 10x once the monitored value (the validation loss
# passed to scheduler.step) has not improved for more than 3 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

In [52]:
# Dice and mIoU functions
def dice_coefficient(pred, target, smooth=1e-6):
    """Dice coefficient between thresholded predictions and a target mask.

    The score is computed over all elements at once (the whole batch is
    flattened), not per sample.

    Args:
        pred: tensor of raw logits; a sigmoid is applied and the result is
            binarised at 0.5.
        target: tensor with the same number of elements as ``pred``. It is
            used as given (not thresholded).
        smooth: small constant added to numerator and denominator so that two
            empty masks score 1 instead of dividing by zero.

    Returns:
        A 0-dim tensor holding ``(2 * |P ∩ T| + smooth) / (|P| + |T| + smooth)``.
    """
    pred = (torch.sigmoid(pred) > 0.5).float()
    # reshape() rather than view(): view() raises on non-contiguous tensors
    # (e.g. after a transpose/permute), reshape() copies only when required.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def miou(pred, target, smooth=1e-6):
    """Intersection-over-union between thresholded predictions and a target mask.

    Despite the name, this is a single IoU computed over all elements at once
    (the whole batch is flattened); no averaging over classes or samples
    happens inside this function.

    Args:
        pred: tensor of raw logits; a sigmoid is applied and the result is
            binarised at 0.5.
        target: tensor with the same number of elements as ``pred``. It is
            used as given (not thresholded).
        smooth: small constant added to numerator and denominator so that two
            empty masks score 1 instead of dividing by zero.

    Returns:
        A 0-dim tensor holding ``(|P ∩ T| + smooth) / (|P ∪ T| + smooth)``.
    """
    pred = (torch.sigmoid(pred) > 0.5).float()
    # reshape() rather than view(): view() raises on non-contiguous tensors
    # (e.g. after a transpose/permute), reshape() copies only when required.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + smooth) / (union + smooth)

In [53]:
# Train for a fixed number of epochs; after each epoch evaluate on the
# validation set and let the scheduler react to the validation loss.
num_epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss, running_dice, running_miou = 0.0, 0.0, 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device).float()
        # Drop the extra dimension that the dataset's unsqueeze(0) adds.
        targets = masks.squeeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_dice += dice_coefficient(outputs, targets).item()
        running_miou += miou(outputs, targets).item()

    # Per-batch averages (every batch counts equally).
    n_train_batches = len(train_loader)
    epoch_loss = running_loss / n_train_batches
    epoch_dice = running_dice / n_train_batches
    epoch_miou = running_miou / n_train_batches

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Dice: {epoch_dice:.4f}, mIoU: {epoch_miou:.4f}")

    # ---- validation pass (no gradient tracking, no weight updates) ----
    model.eval()
    val_loss, val_dice, val_miou = 0.0, 0.0, 0.0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device).float()
            targets = masks.squeeze(1)

            outputs = model(images)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            val_dice += dice_coefficient(outputs, targets).item()
            val_miou += miou(outputs, targets).item()

    n_val_batches = len(val_loader)
    val_loss = val_loss / n_val_batches
    val_dice = val_dice / n_val_batches
    val_miou = val_miou / n_val_batches

    print(f"Validation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}, Validation mIoU: {val_miou:.4f}")

    # The plateau scheduler monitors the averaged validation loss.
    scheduler.step(val_loss)

Epoch [1/50], Loss: 0.3288, Dice: 0.6663, mIoU: 0.5166
Validation Loss: 0.2753, Validation Dice: 0.6405, Validation mIoU: 0.4906
Epoch [2/50], Loss: 0.2369, Dice: 0.7587, mIoU: 0.6209
Validation Loss: 0.3279, Validation Dice: 0.3608, Validation mIoU: 0.2444
Epoch [3/50], Loss: 0.2157, Dice: 0.7715, mIoU: 0.6365
Validation Loss: 0.2277, Validation Dice: 0.6404, Validation mIoU: 0.4988
Epoch [4/50], Loss: 0.1992, Dice: 0.7953, mIoU: 0.6679
Validation Loss: 0.2117, Validation Dice: 0.6749, Validation mIoU: 0.5338
Epoch [5/50], Loss: 0.1947, Dice: 0.7946, mIoU: 0.6678
Validation Loss: 0.2022, Validation Dice: 0.6426, Validation mIoU: 0.5031
Epoch [6/50], Loss: 0.1773, Dice: 0.8178, mIoU: 0.6990
Validation Loss: 0.2750, Validation Dice: 0.6079, Validation mIoU: 0.4556
Epoch [7/50], Loss: 0.1723, Dice: 0.8242, mIoU: 0.7080
Validation Loss: 0.2199, Validation Dice: 0.7049, Validation mIoU: 0.5684
Epoch [8/50], Loss: 0.1678, Dice: 0.8318, mIoU: 0.7182
Validation Loss: 0.2029, Validation Dice: 