In [2]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Patch / Batch Variables 
PATCH_SIZE = 128    # 32, 64, 128
OVERLAP    = 0.5   # 0.0, 0.25, 0.5, 0.8
BATCH_SIZE = 32
EPOCH = 40
PATIENCE = 10

MODEL_NAME       = f"G_Custom_Attention_Dice_BCE_P{PATCH_SIZE}_O{int(OVERLAP*100)}"
MODEL_DIRECTORY  = f"G_Model_Custom_Attention_Dice_BCE_P{PATCH_SIZE}_O{int(OVERLAP*100)}"
RESULT_DIRECTORY = f"G_Results_Custom_Attention_Dice_BCE_P{PATCH_SIZE}_O{int(OVERLAP*100)}"

In [4]:
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    mins = int(elapsed_time / 60)
    secs = int(elapsed_time - mins*60)
    return mins, secs

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)
    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)


In [5]:
class PatchRetinalDataset(Dataset):
    def __init__(self, images_path, masks_path, patch_size, overlap):
        self.patch_size = patch_size
        self.stride = int(patch_size * (1 - overlap))

        self.patches = []  # list of (image_index, y, x)

        self.images = []
        self.masks = []

        for img_path, msk_path in zip(images_path, masks_path):
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)[:, :, 1] / 255.0  # green channel
            msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE) / 255.0

            H, W = img.shape
            self.images.append(img)
            self.masks.append(msk)

            for y in range(0, H - patch_size + 1, self.stride):
                for x in range(0, W - patch_size + 1, self.stride):
                    self.patches.append((len(self.images) - 1, y, x))

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        img_idx, y, x = self.patches[idx]
        img = self.images[img_idx]
        msk = self.masks[img_idx]

        patch = img[y:y+self.patch_size, x:x+self.patch_size]
        mask  = msk[y:y+self.patch_size, x:x+self.patch_size]

        patch_tensor = torch.from_numpy(patch[None].astype(np.float32))
        mask_tensor  = torch.from_numpy(mask[None].astype(np.float32))
        return patch_tensor, mask_tensor


In [6]:
class DiceBCELoss(nn.Module):
    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.bce_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # BCE component (stable, with logits)
        bce_loss = self.bce_fn(logits, targets)

        # Dice component (per-sample)
        probs = torch.sigmoid(logits)
        batch_size = probs.shape[0]
        dice_losses = []
        for i in range(batch_size):
            p = probs[i].view(-1)
            
            g = targets[i].view(-1)
            inter = (p * g).sum()
            dice = 1 - (2*inter + self.smooth) / (p.sum() + g.sum() + self.smooth)
            dice_losses.append(dice)
        dice_loss = torch.stack(dice_losses).mean()

        return self.alpha * dice_loss + (1 - self.alpha) * bce_loss


In [7]:
class AttentionBlock(nn.Module):
    """Additive attention block for U-Net skip connections."""
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g: gating signal transform
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # W_x: skip connection transform
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # psi: attention coefficient
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal (from decoder), x: skip features (from encoder)
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi  # apply attention


class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


class UNet(nn.Module):
    """U-Net with attention gates on skip connections and additional subsampling concat."""
    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = TripleConv(1, 32, 32, 64)
        self.down2 = TripleConv(64, 64, 64, 128)
        self.down3 = DoubleConv(128, 128, 256)
        self.down4 = DoubleConv(256, 256, 256)
        self.pool  = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512, 256)

        # Decoder up and conv
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(256+256, 256, 256)
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(256+256, 128, 128)
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(128+128, 64, 64, 64)
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(64+64, 32, 32, 32)

        # Attention blocks for skip connections
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64,  F_l=64,  F_int=32)

        # Final subsample, concat and output
        self.final_pool       = nn.MaxPool2d(2, 2)
        self.final_upsample   = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out_conv         = nn.Conv2d(33, 1, kernel_size=1)

    def forward(self, x):
        input_image = x
        # Encoder
        x1  = self.down1(x)
        x1p = self.pool(x1)
        x2  = self.down2(x1p)
        x2p = self.pool(x2)
        x3  = self.down3(x2p)
        x3p = self.pool(x3)
        x4  = self.down4(x3p)
        x4p = self.pool(x4)

        # Bottleneck
        xb  = self.bottleneck(x4p)

        # Decoder + Attention
        d4  = self.up4(xb)
        x4a = self.att4(g=d4, x=x4)
        d4  = torch.cat([x4a, d4], dim=1)
        d4  = self.dec4(d4)

        d3  = self.up3(d4)
        x3a = self.att3(g=d3, x=x3)
        d3  = torch.cat([x3a, d3], dim=1)
        d3  = self.dec3(d3)

        d2  = self.up2(d3)
        x2a = self.att2(g=d2, x=x2)
        d2  = torch.cat([x2a, d2], dim=1)
        d2  = self.dec2(d2)

        d1  = self.up1(d2)
        x1a = self.att1(g=d1, x=x1)
        d1  = torch.cat([x1a, d1], dim=1)
        d1  = self.dec1(d1)

        # Additional subsampling & concatenation
        d1s = self.final_pool(d1)
        ins = self.final_pool(input_image)
        cat = torch.cat([d1s, ins], dim=1)
        out = self.final_upsample(cat)
        out = self.out_conv(out)
        return out


In [8]:
seeding(42)
create_directory(MODEL_DIRECTORY)
create_directory(RESULT_DIRECTORY)

In [None]:
# load file lists
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
test_images  = sorted(glob("./final_dataset/test/images/*"))
test_masks   = sorted(glob("./final_dataset/test/masks/*"))

# build datasets
train_ds_full = PatchRetinalDataset(train_images, train_masks, PATCH_SIZE, OVERLAP)
test_ds_full  = PatchRetinalDataset(test_images, test_masks, PATCH_SIZE, OVERLAP)

# split test into val/test
n_val = len(test_ds_full) // 2
n_test = len(test_ds_full) - n_val
valid_ds, test_ds = random_split(test_ds_full, [n_val, n_test])

train_loader = DataLoader(train_ds_full, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2)
valid_loader = DataLoader(valid_ds,      batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# setup
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = DiceBCELoss()

# training/eval functions
def train_epoch(model, loader):
    model.train()
    total = 0
    for x, y in tqdm(loader, desc="Train", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)

def eval_epoch(model, loader):
    model.eval()
    total = 0
    with torch.no_grad():
        for x, y in tqdm(loader, desc="Valid", leave=False):
            x, y = x.to(device), y.to(device)
            total += loss_fn(model(x), y).item()
    return total / len(loader)




In [None]:
# Training Loop 
early_stop_counter = 0
best_loss = float("inf")

best_loss = float("inf")
for epoch in range(EPOCH):
    start = time.time()
    tr_loss = train_epoch(model, train_loader)
    va_loss = eval_epoch(model, valid_loader)
    scheduler.step(va_loss)

    # compute metrics on validation patches
    model.eval()
    mets = [0.0]*5
    with torch.no_grad():
        for x,y in tqdm(valid_loader, desc="Val Metrics", leave=False):
            x,y = x.to(device), y.to(device)
            mets = list(map(add, mets, calculate_metrics(y, model(x))))
    mets = [m/len(valid_loader) for m in mets]
    j,f1,r,p,a = mets
    dice = (2*p*r)/(p+r+1e-7)
    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02}/{EPOCH} | {mins}m {secs}s")
    print(f"Train: {tr_loss:.4f} | Valid: {va_loss:.4f}")
    print(f"Acc:{a:.4f} F1:{f1:.4f} Dice:{dice:.4f} Rec:{r:.4f} Prec:{p:.4f} Jac:{j:.4f}")

    if va_loss < best_loss:
        best_loss = va_loss
        torch.save(model.state_dict(), os.path.join(MODEL_DIRECTORY, MODEL_NAME+".pth"))
        print("Saved best model")
        early_stop_counter = 0
   
    else:
        early_stop_counter += 1
        print(f"No improvement for {early_stop_counter}/{PATIENCE} epochs")

    if early_stop_counter >= PATIENCE:
        print(f"Stopping early after {PATIENCE} epochs without improvement.")
        break
    print()


                                                              

Epoch 01/40 | 35m 32s
Train: 0.4634 | Valid: 0.4577
Acc:0.9803 F1:0.4718 Dice:0.4829 Rec:0.5051 Prec:0.4625 Jac:0.3135
Saved best model



                                                              

Epoch 02/40 | 35m 32s
Train: 0.4386 | Valid: 0.4572
Acc:0.9817 F1:0.4972 Dice:0.5078 Rec:0.5208 Prec:0.4955 Jac:0.3365
Saved best model



                                                              

Epoch 03/40 | 35m 32s
Train: 0.4272 | Valid: 0.4629
Acc:0.9789 F1:0.4734 Dice:0.4854 Rec:0.5354 Prec:0.4440 Jac:0.3159
No improvement for 1/10 epochs



                                                              

Epoch 04/40 | 35m 32s
Train: 0.4166 | Valid: 0.4697
Acc:0.9840 F1:0.4557 Dice:0.4659 Rec:0.3760 Prec:0.6125 Jac:0.3004
No improvement for 2/10 epochs



                                                              

Epoch 05/40 | 35m 31s
Train: 0.4079 | Valid: 0.4694
Acc:0.9811 F1:0.4890 Dice:0.5012 Rec:0.5131 Prec:0.4899 Jac:0.3294
No improvement for 3/10 epochs



                                                              

Epoch 06/40 | 35m 34s
Train: 0.4007 | Valid: 0.4683
Acc:0.9826 F1:0.4886 Dice:0.4999 Rec:0.4739 Prec:0.5290 Jac:0.3283
No improvement for 4/10 epochs



                                                              

Epoch 07/40 | 35m 34s
Train: 0.3946 | Valid: 0.4724
Acc:0.9811 F1:0.4715 Dice:0.4829 Rec:0.4773 Prec:0.4887 Jac:0.3143
No improvement for 5/10 epochs



                                                              

Epoch 08/40 | 35m 34s
Train: 0.3893 | Valid: 0.4720
Acc:0.9826 F1:0.4681 Dice:0.4782 Rec:0.4347 Prec:0.5314 Jac:0.3105
No improvement for 6/10 epochs



                                                              

Epoch 09/40 | 35m 34s
Train: 0.3781 | Valid: 0.4753
Acc:0.9834 F1:0.4852 Dice:0.4961 Rec:0.4417 Prec:0.5657 Jac:0.3259
No improvement for 7/10 epochs



                                                              

Epoch 10/40 | 35m 35s
Train: 0.3740 | Valid: 0.4782
Acc:0.9836 F1:0.4817 Dice:0.4930 Rec:0.4306 Prec:0.5764 Jac:0.3230
No improvement for 8/10 epochs



                                                              

Epoch 11/40 | 35m 35s
Train: 0.3716 | Valid: 0.4786
Acc:0.9836 F1:0.4800 Dice:0.4910 Rec:0.4276 Prec:0.5764 Jac:0.3214
No improvement for 9/10 epochs



                                                              

Epoch 12/40 | 35m 34s
Train: 0.3696 | Valid: 0.4813
Acc:0.9837 F1:0.4762 Dice:0.4875 Rec:0.4189 Prec:0.5828 Jac:0.3183
No improvement for 10/10 epochs
Stopping early after 10 epochs without improvement.




In [None]:
# Testing & Final Metrics 
model.load_state_dict(torch.load(
    os.path.join(MODEL_DIRECTORY, MODEL_NAME + ".pth"),
    map_location=device
))
model.eval()

metrics_score = [0.0] * 5

# Loop over original full images
for i in tqdm(range(len(test_images)), desc="Testing"):
    img_path = test_images[i]
    msk_path = test_masks[i]

    # Load full-size green channel and mask
    green = cv2.imread(img_path, cv2.IMREAD_COLOR)[:, :, 1] / 255.0
    mask = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE) / 255.0
    H, W = green.shape

    # Sliding window inference
    pred_accum = np.zeros((H, W), dtype=np.float32)
    count_accum = np.zeros((H, W), dtype=np.float32)
    stride = int(PATCH_SIZE * (1 - OVERLAP))

    for y in range(0, H - PATCH_SIZE + 1, stride):
        for x in range(0, W - PATCH_SIZE + 1, stride):
            patch = green[y:y+PATCH_SIZE, x:x+PATCH_SIZE]
            patch_tensor = torch.from_numpy(patch[None, None].astype(np.float32)).to(device)

            with torch.no_grad():
                pred = torch.sigmoid(model(patch_tensor))[0, 0].cpu().numpy()

            pred_accum[y:y+PATCH_SIZE, x:x+PATCH_SIZE] += pred
            count_accum[y:y+PATCH_SIZE, x:x+PATCH_SIZE] += 1.0

    pred_avg = pred_accum / np.maximum(count_accum, 1e-6)
    pred_bin = (pred_avg > 0.5).astype(np.uint8)

    # Metrics
    y_true = mask.reshape(-1) > 0.5
    y_pred = pred_bin.reshape(-1) > 0.5
    mets = [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]
    metrics_score = list(map(add, metrics_score, mets))

    # Save composite
    green_rgb = np.stack([(green * 255).astype(np.uint8)] * 3, axis=-1)
    mask_rgb = mask_parse((mask * 255).astype(np.uint8))
    pred_rgb = mask_parse((pred_bin * 255).astype(np.uint8))
    line = np.ones((H, 10, 3), dtype=np.uint8) * 128
    composite = np.concatenate([green_rgb, line, mask_rgb, line, pred_rgb], axis=1)
    out_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
    plt.imsave(os.path.join(RESULT_DIRECTORY, out_name), composite)

# Print average metrics
num = len(test_images)
j, f1, r, p, a = [m / num for m in metrics_score]
print(f"\nTest Set Metrics:")
print(f"  Jaccard:   {j:.4f}")
print(f"  F1 Score:  {f1:.4f}")
print(f"  Recall:    {r:.4f}")
print(f"  Precision: {p:.4f}")
print(f"  Accuracy:  {a:.4f}")


  model.load_state_dict(torch.load(
Testing: 100%|████████████████████████████████| 393/393 [01:34<00:00,  4.17it/s]


Test Set Metrics:
  Jaccard:   0.2443
  F1 Score:  0.3568
  Recall:    0.4251
  Precision: 0.3899
  Accuracy:  0.9856



