In [2]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Patch / Batch Variables 
PATCH_SIZE = 128    # 32, 64, 128
OVERLAP    = 0.5   # 0.0, 0.25, 0.5, 0.8
BATCH_SIZE = 32
EPOCH = 50
PATIENCE = 13

MODEL_NAME       = f"G_MultiRes_UNet_Tversky_P{PATCH_SIZE}_O{int(OVERLAP*100)}"
MODEL_DIRECTORY  = f"G_Model_MultiRes_UNet_Tversky_P{PATCH_SIZE}_O{int(OVERLAP*100)}"
RESULT_DIRECTORY = f"G_Results_MultiRes_UNet_Tversky_P{PATCH_SIZE}_O{int(OVERLAP*100)}"

In [4]:
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    mins = int(elapsed_time / 60)
    secs = int(elapsed_time - mins*60)
    return mins, secs

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)
    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)


In [None]:
# Patching Dataset  
class PatchRetinalDataset(Dataset):
    def __init__(self, images_path, masks_path, patch_size, overlap):
        self.images_path = images_path
        self.masks_path  = masks_path
        self.patch_size  = patch_size
        self.stride      = int(patch_size * (1 - overlap))
        self.n_images    = len(images_path)

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        img = cv2.imread(self.images_path[idx], cv2.IMREAD_COLOR)[:,:,1] / 255.0
        msk = cv2.imread(self.masks_path[idx],  cv2.IMREAD_GRAYSCALE) / 255.0
        H, W = img.shape

        patches, mask_patches = [], []
        for y in range(0, H - self.patch_size + 1, self.stride):
            for x in range(0, W - self.patch_size + 1, self.stride):
                p = img[y:y+self.patch_size, x:x+self.patch_size]
                m = msk[y:y+self.patch_size, x:x+self.patch_size]
                patches.append(torch.from_numpy(p[None].astype(np.float32)))
                mask_patches.append(torch.from_numpy(m[None].astype(np.float32)))

        return torch.stack(patches, dim=0), torch.stack(mask_patches, dim=0)


In [6]:
class DiceBCELoss(nn.Module):
    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.bce_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # BCE component (stable, with logits)
        bce_loss = self.bce_fn(logits, targets)

        # Dice component (per-sample)
        probs = torch.sigmoid(logits)
        batch_size = probs.shape[0]
        dice_losses = []
        for i in range(batch_size):
            p = probs[i].view(-1)
            g = targets[i].view(-1)
            inter = (p * g).sum()
            dice = 1 - (2*inter + self.smooth) / (p.sum() + g.sum() + self.smooth)
            dice_losses.append(dice)
        dice_loss = torch.stack(dice_losses).mean()

        return self.alpha * dice_loss + (1 - self.alpha) * bce_loss


In [7]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, act=True):
        super().__init__()

        layers = [
            nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_c)
        ]
        if act == True:
            layers.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class multires_block(nn.Module):
    def __init__(self, in_c, out_c, alpha=1.67):
        super().__init__()

        W = out_c * alpha
        self.c1 = conv_block(in_c, int(W*0.167))
        self.c2 = conv_block(int(W*0.167), int(W*0.333))
        self.c3 = conv_block(int(W*0.333), int(W*0.5))

        nf = int(W*0.167) + int(W*0.333) + int(W*0.5)
        self.b1 = nn.BatchNorm2d(nf)
        self.c4 = conv_block(in_c, nf)
        self.relu = nn.ReLU(inplace=True)
        self.b2 = nn.BatchNorm2d(nf)

    def forward(self, x):
        x0 = x
        x1 = self.c1(x0)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        xc = torch.cat([x1, x2, x3], dim=1)
        xc = self.b1(xc)

        sc = self.c4(x0)
        x = self.relu(xc + sc)
        x = self.b2(x)
        return x

class res_path_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.c1 = conv_block(in_c, out_c, act=False)
        self.s1 = conv_block(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        x1 = self.c1(x)
        s1 = self.s1(x)
        x = self.relu(x1 + s1)
        x = self.bn(x)
        return x

class res_path(nn.Module):
    def __init__(self, in_c, out_c, length):
        super().__init__()

        layers = []
        for i in range(length):
            layers.append(res_path_block(in_c, out_c))
            in_c = out_c

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

def cal_nf(ch, alpha=1.67):
    W = ch * alpha
    return int(W*0.167) + int(W*0.333) + int(W*0.5)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c, length):
        super().__init__()

        self.c1 = multires_block(in_c, out_c)
        nf = cal_nf(out_c)
        self.s1 = res_path(nf, out_c, length)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        x = self.c1(x)
        s = self.s1(x)
        p = self.pool(x)
        return s, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.c1 = nn.ConvTranspose2d(in_c[0], out_c, kernel_size=2, stride=2, padding=0)
        self.c2 = multires_block(out_c+in_c[1], out_c)

    def forward(self, x, s):
        x = self.c1(x)
        x = torch.cat([x, s], dim=1)
        x = self.c2(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.e1 = encoder_block(1, 32, 4)
        self.e2 = encoder_block(cal_nf(32), 64, 3)
        self.e3 = encoder_block(cal_nf(64), 128, 2)
        self.e4 = encoder_block(cal_nf(128), 256, 1)

        """ Bridge """
        self.b1 = multires_block(cal_nf(256), 512)

        """ Decoder """
        self.d1 = decoder_block([cal_nf(512), 256], 256)
        self.d2 = decoder_block([cal_nf(256), 128], 128)
        self.d3 = decoder_block([cal_nf(128), 64], 64)
        self.d4 = decoder_block([cal_nf(64), 32], 32)

        """ Output """
        self.output = nn.Conv2d(cal_nf(32), 1, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b1 = self.b1(p4)

        d1 = self.d1(b1, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        output = self.output(d4)
        return output



In [8]:
seeding(42)
create_directory(MODEL_DIRECTORY)
create_directory(RESULT_DIRECTORY)

In [9]:
# load file lists
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
test_images  = sorted(glob("./final_dataset/test/images/*"))
test_masks   = sorted(glob("./final_dataset/test/masks/*"))

# build datasets
train_ds_full = PatchRetinalDataset(train_images, train_masks, PATCH_SIZE, OVERLAP)
test_ds_full  = PatchRetinalDataset(test_images, test_masks, PATCH_SIZE, OVERLAP)

# split test into val/test
n_val = len(test_ds_full)//2
n_test= len(test_ds_full)-n_val
valid_ds, test_ds = random_split(test_ds_full, [n_val, n_test])

# flatten into patches
def flatten(ds):
    xs, ys = [], []
    for xp, yp in ds:
        xs.append(xp); ys.append(yp)
    return TensorDataset(torch.cat(xs,0), torch.cat(ys,0))

train_ds = flatten(train_ds_full)
valid_ds = flatten(valid_ds)

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = DiceBCELoss()

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# standard train/eval fns
def train_epoch(model, loader):
    model.train(); total=0
    for x,y in tqdm(loader, desc="Train", leave=False):
        x,y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward(); optimizer.step()
        total += loss.item()
    return total/len(loader)

def eval_epoch(model, loader):
    model.eval(); total=0
    with torch.no_grad():
        for x,y in tqdm(loader, desc="Valid", leave=False):
            x,y = x.to(device), y.to(device)
            total += loss_fn(model(x), y).item()
    return total/len(loader)



In [None]:
# Training Loop 
early_stop_counter = 0
best_loss = float("inf")

best_loss = float("inf")
for epoch in range(EPOCH):
    start = time.time()
    tr_loss = train_epoch(model, train_loader)
    va_loss = eval_epoch(model, valid_loader)
    scheduler.step(va_loss)

    # compute metrics on validation patches
    model.eval()
    mets = [0.0]*5
    with torch.no_grad():
        for x,y in tqdm(valid_loader, desc="Val Metrics", leave=False):
            x,y = x.to(device), y.to(device)
            mets = list(map(add, mets, calculate_metrics(y, model(x))))
    mets = [m/len(valid_loader) for m in mets]
    j,f1,r,p,a = mets
    dice = (2*p*r)/(p+r+1e-7)
    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02}/{EPOCH} | {mins}m {secs}s")
    print(f"Train: {tr_loss:.4f} | Valid: {va_loss:.4f}")
    print(f"Acc:{a:.4f} F1:{f1:.4f} Dice:{dice:.4f} Rec:{r:.4f} Prec:{p:.4f} Jac:{j:.4f}")

    if va_loss < best_loss:
        best_loss = va_loss
        torch.save(model.state_dict(), os.path.join(MODEL_DIRECTORY, MODEL_NAME+".pth"))
        print("Saved best model")
        early_stop_counter = 0
   
    else:
        early_stop_counter += 1
        print(f"No improvement for {early_stop_counter}/{PATIENCE} epochs")

    if early_stop_counter >= PATIENCE:
        print(f"Stopping early after {PATIENCE} epochs without improvement.")
        break
    print()


                                                            

Epoch 01/50 | 3m 42s
Train: 0.7692 | Valid: 0.6838
Acc:0.9792 F1:0.3014 Dice:0.3564 Rec:0.4585 Prec:0.2915 Jac:0.1957
Saved best model



                                                            

Epoch 02/50 | 3m 41s
Train: 0.6006 | Valid: 0.5453
Acc:0.9837 F1:0.3493 Dice:0.3870 Rec:0.4172 Prec:0.3609 Jac:0.2299
Saved best model



                                                            

Epoch 03/50 | 3m 41s
Train: 0.4948 | Valid: 0.4889
Acc:0.9844 F1:0.3655 Dice:0.3997 Rec:0.4290 Prec:0.3741 Jac:0.2444
Saved best model



                                                            

Epoch 04/50 | 3m 41s
Train: 0.4279 | Valid: 0.4553
Acc:0.9833 F1:0.3651 Dice:0.4083 Rec:0.4836 Prec:0.3533 Jac:0.2423
Saved best model



                                                            

Epoch 05/50 | 3m 41s
Train: 0.3705 | Valid: 0.4640
Acc:0.9858 F1:0.3528 Dice:0.3866 Rec:0.3254 Prec:0.4760 Jac:0.2314
No improvement for 1/13 epochs



                                                            

Epoch 06/50 | 3m 41s
Train: 0.3344 | Valid: 0.4441
Acc:0.9851 F1:0.3875 Dice:0.4140 Rec:0.4542 Prec:0.3803 Jac:0.2637
Saved best model



                                                            

Epoch 07/50 | 3m 41s
Train: 0.3145 | Valid: 0.4437
Acc:0.9866 F1:0.3963 Dice:0.4191 Rec:0.4009 Prec:0.4390 Jac:0.2727
Saved best model



                                                            

Epoch 08/50 | 3m 41s
Train: 0.3001 | Valid: 0.4565
Acc:0.9863 F1:0.3924 Dice:0.4168 Rec:0.3520 Prec:0.5108 Jac:0.2656
No improvement for 1/13 epochs



                                                            

Epoch 09/50 | 3m 41s
Train: 0.2918 | Valid: 0.4488
Acc:0.9862 F1:0.3829 Dice:0.4031 Rec:0.3515 Prec:0.4725 Jac:0.2592
No improvement for 2/13 epochs



                                                            

Epoch 10/50 | 3m 41s
Train: 0.2840 | Valid: 0.4441
Acc:0.9859 F1:0.3844 Dice:0.4104 Rec:0.4078 Prec:0.4130 Jac:0.2592
No improvement for 3/13 epochs



                                                            

Epoch 11/50 | 3m 41s
Train: 0.2766 | Valid: 0.4429
Acc:0.9866 F1:0.3953 Dice:0.4213 Rec:0.4055 Prec:0.4383 Jac:0.2688
Saved best model



                                                            

Epoch 12/50 | 3m 41s
Train: 0.2732 | Valid: 0.4429
Acc:0.9864 F1:0.3963 Dice:0.4203 Rec:0.4102 Prec:0.4309 Jac:0.2724
Saved best model



                                                            

Epoch 13/50 | 3m 41s
Train: 0.2664 | Valid: 0.4495
Acc:0.9865 F1:0.3871 Dice:0.4119 Rec:0.3705 Prec:0.4638 Jac:0.2629
No improvement for 1/13 epochs



                                                            

Epoch 14/50 | 3m 41s
Train: 0.2628 | Valid: 0.4427
Acc:0.9867 F1:0.3954 Dice:0.4150 Rec:0.3892 Prec:0.4445 Jac:0.2730
Saved best model



                                                            

Epoch 15/50 | 3m 41s
Train: 0.2587 | Valid: 0.4499
Acc:0.9865 F1:0.3968 Dice:0.4242 Rec:0.3815 Prec:0.4777 Jac:0.2662
No improvement for 1/13 epochs



                                                            

Epoch 16/50 | 3m 41s
Train: 0.2566 | Valid: 0.4444
Acc:0.9862 F1:0.3771 Dice:0.3952 Rec:0.3841 Prec:0.4069 Jac:0.2568
No improvement for 2/13 epochs



                                                            

Epoch 17/50 | 3m 41s
Train: 0.2543 | Valid: 0.4482
Acc:0.9867 F1:0.3777 Dice:0.3985 Rec:0.3441 Prec:0.4733 Jac:0.2543
No improvement for 3/13 epochs



                                                            

Epoch 18/50 | 3m 41s
Train: 0.2499 | Valid: 0.4576
Acc:0.9865 F1:0.3709 Dice:0.3955 Rec:0.3404 Prec:0.4719 Jac:0.2476
No improvement for 4/13 epochs



                                                            

Epoch 19/50 | 3m 41s
Train: 0.2465 | Valid: 0.4582
Acc:0.9866 F1:0.3647 Dice:0.3880 Rec:0.3203 Prec:0.4920 Jac:0.2431
No improvement for 5/13 epochs



                                                            

Epoch 20/50 | 3m 41s
Train: 0.2447 | Valid: 0.4576
Acc:0.9866 F1:0.3692 Dice:0.3908 Rec:0.3369 Prec:0.4651 Jac:0.2482
No improvement for 6/13 epochs



                                                            

Epoch 21/50 | 3m 41s
Train: 0.2402 | Valid: 0.4567
Acc:0.9867 F1:0.3781 Dice:0.4045 Rec:0.3467 Prec:0.4853 Jac:0.2549
No improvement for 7/13 epochs



                                                            

Epoch 22/50 | 3m 41s
Train: 0.2370 | Valid: 0.4587
Acc:0.9867 F1:0.3745 Dice:0.3991 Rec:0.3354 Prec:0.4927 Jac:0.2520
No improvement for 8/13 epochs



                                                            

Epoch 23/50 | 3m 41s
Train: 0.2354 | Valid: 0.4576
Acc:0.9867 F1:0.3791 Dice:0.4034 Rec:0.3396 Prec:0.4967 Jac:0.2553
No improvement for 9/13 epochs



                                                            

Epoch 24/50 | 3m 41s
Train: 0.2343 | Valid: 0.4621
Acc:0.9866 F1:0.3635 Dice:0.3882 Rec:0.3206 Prec:0.4919 Jac:0.2439
No improvement for 10/13 epochs



                                                            

Epoch 25/50 | 3m 41s
Train: 0.2336 | Valid: 0.4580
Acc:0.9867 F1:0.3768 Dice:0.4012 Rec:0.3396 Prec:0.4901 Jac:0.2529
No improvement for 11/13 epochs



                                                            

Epoch 26/50 | 3m 41s
Train: 0.2327 | Valid: 0.4606
Acc:0.9866 F1:0.3694 Dice:0.3933 Rec:0.3322 Prec:0.4819 Jac:0.2476
No improvement for 12/13 epochs



                                                            

Epoch 27/50 | 3m 41s
Train: 0.2319 | Valid: 0.4603
Acc:0.9866 F1:0.3695 Dice:0.3930 Rec:0.3302 Prec:0.4854 Jac:0.2475
No improvement for 13/13 epochs
Stopping early after 13 epochs without improvement.




In [None]:
# Testing & Final Metrics 
model.load_state_dict(torch.load(os.path.join(MODEL_DIRECTORY, MODEL_NAME + ".pth"), map_location=device))
model.eval()

# accumulators for final metrics
metrics_score = [0.0] * 5

# get only the held-out test indices
test_indices = test_ds.indices

for idx in tqdm(test_indices, total=len(test_indices), desc="Testing"):
    img_path = test_images[idx]
    msk_path = test_masks[idx]

    # load full-size
    green = cv2.imread(img_path,    cv2.IMREAD_COLOR)[:,:,1] / 255.0
    mask  = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)  / 255.0
    H, W  = green.shape

    # sliding-window recon
    pred_accum  = np.zeros((H, W), dtype=np.float32)
    count_accum = np.zeros((H, W), dtype=np.float32)
    stride = int(PATCH_SIZE * (1 - OVERLAP))

    for y in range(0, H - PATCH_SIZE + 1, stride):
        for x in range(0, W - PATCH_SIZE + 1, stride):
            patch = torch.from_numpy(
                green[y:y+PATCH_SIZE, x:x+PATCH_SIZE][None,None].astype(np.float32)
            ).to(device)
            with torch.no_grad():
                out = torch.sigmoid(model(patch))[0,0].cpu().numpy()
            pred_accum[y:y+PATCH_SIZE, x:x+PATCH_SIZE]  += out
            count_accum[y:y+PATCH_SIZE, x:x+PATCH_SIZE] += 1.0

    # final binary mask
    pred_avg = pred_accum / np.maximum(count_accum, 1e-6)
    pred_bin = (pred_avg > 0.5).astype(np.uint8)

    # compute this image's metrics
    y_true = mask.reshape(-1) > 0.5
    y_pred = pred_bin.reshape(-1) > 0.5
    mets = [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]
    metrics_score = list(map(add, metrics_score, mets))

    # build & save composite as before
    green_rgb = np.stack([ (green*255).astype(np.uint8) ]*3, axis=-1)
    mask_rgb  = mask_parse((mask*255).astype(np.uint8))
    pred_rgb  = mask_parse((pred_bin*255).astype(np.uint8))
    line = np.ones((H,10,3),dtype=np.uint8)*128
    composite = np.concatenate([green_rgb, line, mask_rgb, line, pred_rgb], axis=1)
    out_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
    plt.imsave(os.path.join(RESULT_DIRECTORY, out_name), composite)

# average and print final metrics over the held-out test subset
num = len(test_indices)
j, f1, r, p, a = [m/num for m in metrics_score]
print(f"\nTest Set Metrics over Test images:")
print(f"  Jaccard:  {j:.4f}")
print(f"  F1 Score: {f1:.4f}")
print(f"  Recall:   {r:.4f}")
print(f"  Precision:{p:.4f}")
print(f"  Accuracy: {a:.4f}")

Testing: 100%|██████████| 14/14 [00:10<00:00,  1.29it/s]


Test Set Metrics over Test images:
  Jaccard:  0.3867
  F1 Score: 0.5475
  Recall:   0.5417
  Precision:0.6043
  Accuracy: 0.9913



