In [2]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples


In [5]:
class Focal_Tversky(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta  = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits, targets: (N, 1, H, W) or (N, H, W)
        probs = torch.sigmoid(logits)
        N = targets.size(0)

        # flatten per sample
        probs  = probs.view(N, -1)
        targets = targets.view(N, -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        smooth = torch.tensor(self.smooth, device=probs.device, dtype=probs.dtype)
        tversky = (TP + smooth) / (TP + self.alpha * FP + self.beta * FN + smooth)

        # focal modulation
        loss = (1 - tversky) ** self.gamma

        return loss.mean()


In [None]:
# Model Architecture 
class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class EncoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = ConvBlock(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        x = self.conv(x)
        p = self.pool(x)
        return x, p

class DecoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_c * 2, out_c)

    def forward(self, x, skip):
        x = self.up(x)
        x = nn.functional.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], axis=1)
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.e1 = EncoderBlock(1, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)
        self.b = ConvBlock(512, 1024)
        self.d1 = DecoderBlock(1024, 512)
        self.d2 = DecoderBlock(512, 256)
        self.d3 = DecoderBlock(256, 128)
        self.d4 = DecoderBlock(128, 64)
        self.outputs = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        b = self.b(p4)
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        return self.outputs(d4)


In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [8]:
seeding(42)
MODEL_NAME = "Green_UNet_Focal_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_UNet_Focal_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_UNet_Focal_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = Focal_Tversky()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [10]:
for epoch in range(45):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                  

Epoch 01 | Time: 1m 8s
	Train Loss: 0.8802 | Valid Loss: 0.8554
	Accuracy: 0.9239 | F1: 0.1820 | Dice: 0.2188 | Recall: 0.7921 | Precision: 0.1269 | Jaccard: 0.1096
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 02 | Time: 1m 10s
	Train Loss: 0.8292 | Valid Loss: 0.8232
	Accuracy: 0.9182 | F1: 0.1721 | Dice: 0.1984 | Recall: 0.8554 | Precision: 0.1122 | Jaccard: 0.1028
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 03 | Time: 1m 7s
	Train Loss: 0.7599 | Valid Loss: 0.7716
	Accuracy: 0.9569 | F1: 0.2418 | Dice: 0.2818 | Recall: 0.7282 | Precision: 0.1747 | Jaccard: 0.1508
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 04 | Time: 1m 7s
	Train Loss: 0.6648 | Valid Loss: 0.7083
	Accuracy: 0.9850 | F1: 0.3746 | Dice: 0.4089 | Recall: 0.4501 | Precision: 0.3746 | Jaccard: 0.2416
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 05 | Time: 1m 7s
	Train Loss: 0.5597 | Valid Loss: 0.6137
	Accuracy: 0.9804 | F1: 0.3368 | Dice: 0.3809 | Recall: 0.5584 | Precision: 0.2890 | Jaccard: 0.2194
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 06 | Time: 1m 7s
	Train Loss: 0.4542 | Valid Loss: 0.5531
	Accuracy: 0.9829 | F1: 0.3716 | Dice: 0.4019 | Recall: 0.5268 | Precision: 0.3249 | Jaccard: 0.2460
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 07 | Time: 1m 6s
	Train Loss: 0.3932 | Valid Loss: 0.5155
	Accuracy: 0.9801 | F1: 0.3669 | Dice: 0.4043 | Recall: 0.6704 | Precision: 0.2895 | Jaccard: 0.2381
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 08 | Time: 1m 6s
	Train Loss: 0.3384 | Valid Loss: 0.5060
	Accuracy: 0.9826 | F1: 0.3740 | Dice: 0.4080 | Recall: 0.5588 | Precision: 0.3213 | Jaccard: 0.2426
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 09 | Time: 1m 6s
	Train Loss: 0.2821 | Valid Loss: 0.4780
	Accuracy: 0.9881 | F1: 0.4345 | Dice: 0.4511 | Recall: 0.4316 | Precision: 0.4723 | Jaccard: 0.3010
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 10 | Time: 1m 5s
	Train Loss: 0.2562 | Valid Loss: 0.4502
	Accuracy: 0.9878 | F1: 0.4334 | Dice: 0.4478 | Recall: 0.4499 | Precision: 0.4457 | Jaccard: 0.3011
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 11 | Time: 1m 5s
	Train Loss: 0.2393 | Valid Loss: 0.5067
	Accuracy: 0.9866 | F1: 0.3913 | Dice: 0.4015 | Recall: 0.3813 | Precision: 0.4239 | Jaccard: 0.2645


                                                                  

Epoch 12 | Time: 1m 5s
	Train Loss: 0.2015 | Valid Loss: 0.4673
	Accuracy: 0.9878 | F1: 0.4217 | Dice: 0.4366 | Recall: 0.4158 | Precision: 0.4596 | Jaccard: 0.2879


                                                                  

Epoch 13 | Time: 1m 5s
	Train Loss: 0.1682 | Valid Loss: 0.5033
	Accuracy: 0.9871 | F1: 0.3905 | Dice: 0.4030 | Recall: 0.3746 | Precision: 0.4360 | Jaccard: 0.2641


                                                                  

Epoch 14 | Time: 1m 5s
	Train Loss: 0.1617 | Valid Loss: 0.4342
	Accuracy: 0.9874 | F1: 0.4498 | Dice: 0.4678 | Recall: 0.4977 | Precision: 0.4412 | Jaccard: 0.3104
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 15 | Time: 1m 5s
	Train Loss: 0.1539 | Valid Loss: 0.4432
	Accuracy: 0.9885 | F1: 0.4449 | Dice: 0.4549 | Recall: 0.4550 | Precision: 0.4549 | Jaccard: 0.3150


                                                                  

Epoch 16 | Time: 1m 5s
	Train Loss: 0.1337 | Valid Loss: 0.3840
	Accuracy: 0.9875 | F1: 0.4721 | Dice: 0.4946 | Recall: 0.5591 | Precision: 0.4434 | Jaccard: 0.3293
Best Green_UNet_Focal_Tversky_Full Saved


                                                                  

Epoch 17 | Time: 1m 5s
	Train Loss: 0.1363 | Valid Loss: 0.4647
	Accuracy: 0.9880 | F1: 0.4249 | Dice: 0.4384 | Recall: 0.4066 | Precision: 0.4756 | Jaccard: 0.2892


                                                                  

Epoch 18 | Time: 1m 5s
	Train Loss: 0.1465 | Valid Loss: 0.4530
	Accuracy: 0.9880 | F1: 0.4248 | Dice: 0.4369 | Recall: 0.4414 | Precision: 0.4325 | Jaccard: 0.2961


                                                                  

Epoch 19 | Time: 1m 5s
	Train Loss: 0.1402 | Valid Loss: 0.4225
	Accuracy: 0.9877 | F1: 0.4505 | Dice: 0.4717 | Recall: 0.4998 | Precision: 0.4465 | Jaccard: 0.3082


                                                                  

Epoch 20 | Time: 1m 5s
	Train Loss: 0.1283 | Valid Loss: 0.4278
	Accuracy: 0.9875 | F1: 0.4447 | Dice: 0.4706 | Recall: 0.5139 | Precision: 0.4340 | Jaccard: 0.3050


                                                                  

Epoch 21 | Time: 1m 5s
	Train Loss: 0.1534 | Valid Loss: 0.4345
	Accuracy: 0.9869 | F1: 0.4382 | Dice: 0.4494 | Recall: 0.5032 | Precision: 0.4060 | Jaccard: 0.3031


                                                                  

Epoch 22 | Time: 1m 5s
	Train Loss: 0.1117 | Valid Loss: 0.4043
	Accuracy: 0.9882 | F1: 0.4770 | Dice: 0.4974 | Recall: 0.5299 | Precision: 0.4687 | Jaccard: 0.3303


                                                                  

Epoch 23 | Time: 1m 5s
	Train Loss: 0.0997 | Valid Loss: 0.3951
	Accuracy: 0.9879 | F1: 0.4794 | Dice: 0.4979 | Recall: 0.5396 | Precision: 0.4623 | Jaccard: 0.3315


                                                                  

Epoch 24 | Time: 1m 5s
	Train Loss: 0.0922 | Valid Loss: 0.3946
	Accuracy: 0.9882 | F1: 0.4805 | Dice: 0.4961 | Recall: 0.5123 | Precision: 0.4809 | Jaccard: 0.3312


                                                                  

Epoch 25 | Time: 1m 5s
	Train Loss: 0.0887 | Valid Loss: 0.3970
	Accuracy: 0.9883 | F1: 0.4804 | Dice: 0.4972 | Recall: 0.4989 | Precision: 0.4954 | Jaccard: 0.3293


                                                                  

Epoch 26 | Time: 1m 5s
	Train Loss: 0.0868 | Valid Loss: 0.3932
	Accuracy: 0.9881 | F1: 0.4816 | Dice: 0.4999 | Recall: 0.5066 | Precision: 0.4934 | Jaccard: 0.3302


                                                                  

Epoch 27 | Time: 1m 5s
	Train Loss: 0.0851 | Valid Loss: 0.4039
	Accuracy: 0.9879 | F1: 0.4689 | Dice: 0.4874 | Recall: 0.4889 | Precision: 0.4858 | Jaccard: 0.3185


                                                                  

Epoch 28 | Time: 1m 5s
	Train Loss: 0.0829 | Valid Loss: 0.4021
	Accuracy: 0.9882 | F1: 0.4774 | Dice: 0.4962 | Recall: 0.4997 | Precision: 0.4927 | Jaccard: 0.3276


                                                                  

Epoch 29 | Time: 1m 5s
	Train Loss: 0.0810 | Valid Loss: 0.4057
	Accuracy: 0.9880 | F1: 0.4719 | Dice: 0.4909 | Recall: 0.4885 | Precision: 0.4932 | Jaccard: 0.3212


                                                                  

Epoch 30 | Time: 1m 5s
	Train Loss: 0.0811 | Valid Loss: 0.4044
	Accuracy: 0.9881 | F1: 0.4724 | Dice: 0.4902 | Recall: 0.4798 | Precision: 0.5010 | Jaccard: 0.3216


                                                                  

Epoch 31 | Time: 1m 5s
	Train Loss: 0.0810 | Valid Loss: 0.4018
	Accuracy: 0.9882 | F1: 0.4776 | Dice: 0.4962 | Recall: 0.4974 | Precision: 0.4950 | Jaccard: 0.3273


                                                                  

Epoch 32 | Time: 1m 5s
	Train Loss: 0.0806 | Valid Loss: 0.4044
	Accuracy: 0.9881 | F1: 0.4741 | Dice: 0.4929 | Recall: 0.4839 | Precision: 0.5021 | Jaccard: 0.3229


                                                                  

Epoch 33 | Time: 1m 5s
	Train Loss: 0.0799 | Valid Loss: 0.3962
	Accuracy: 0.9880 | F1: 0.4824 | Dice: 0.5016 | Recall: 0.5158 | Precision: 0.4882 | Jaccard: 0.3314


                                                                  

Epoch 34 | Time: 1m 5s
	Train Loss: 0.0799 | Valid Loss: 0.4047
	Accuracy: 0.9881 | F1: 0.4732 | Dice: 0.4921 | Recall: 0.4849 | Precision: 0.4996 | Jaccard: 0.3227


                                                                  

Epoch 35 | Time: 1m 5s
	Train Loss: 0.0798 | Valid Loss: 0.4033
	Accuracy: 0.9880 | F1: 0.4757 | Dice: 0.4955 | Recall: 0.4989 | Precision: 0.4921 | Jaccard: 0.3251


                                                                  

Epoch 36 | Time: 1m 5s
	Train Loss: 0.0793 | Valid Loss: 0.4043
	Accuracy: 0.9879 | F1: 0.4729 | Dice: 0.4921 | Recall: 0.4933 | Precision: 0.4909 | Jaccard: 0.3222


                                                                  

Epoch 37 | Time: 1m 5s
	Train Loss: 0.0804 | Valid Loss: 0.4027
	Accuracy: 0.9883 | F1: 0.4777 | Dice: 0.4966 | Recall: 0.4939 | Precision: 0.4992 | Jaccard: 0.3274


                                                                  

Epoch 38 | Time: 1m 5s
	Train Loss: 0.0797 | Valid Loss: 0.4053
	Accuracy: 0.9881 | F1: 0.4719 | Dice: 0.4899 | Recall: 0.4834 | Precision: 0.4965 | Jaccard: 0.3214


                                                                  

Epoch 39 | Time: 1m 5s
	Train Loss: 0.0790 | Valid Loss: 0.4019
	Accuracy: 0.9881 | F1: 0.4766 | Dice: 0.4958 | Recall: 0.5003 | Precision: 0.4914 | Jaccard: 0.3263


                                                                  

Epoch 40 | Time: 1m 5s
	Train Loss: 0.0798 | Valid Loss: 0.4038
	Accuracy: 0.9882 | F1: 0.4749 | Dice: 0.4943 | Recall: 0.4901 | Precision: 0.4986 | Jaccard: 0.3243


                                                                  

Epoch 41 | Time: 1m 5s
	Train Loss: 0.0794 | Valid Loss: 0.4030
	Accuracy: 0.9880 | F1: 0.4769 | Dice: 0.4960 | Recall: 0.5051 | Precision: 0.4872 | Jaccard: 0.3268


                                                                  

Epoch 42 | Time: 1m 5s
	Train Loss: 0.0788 | Valid Loss: 0.4003
	Accuracy: 0.9881 | F1: 0.4797 | Dice: 0.4995 | Recall: 0.5047 | Precision: 0.4943 | Jaccard: 0.3290


                                                                  

Epoch 43 | Time: 1m 5s
	Train Loss: 0.0794 | Valid Loss: 0.4032
	Accuracy: 0.9881 | F1: 0.4749 | Dice: 0.4939 | Recall: 0.4976 | Precision: 0.4902 | Jaccard: 0.3249


                                                                  

Epoch 44 | Time: 1m 5s
	Train Loss: 0.0789 | Valid Loss: 0.4054
	Accuracy: 0.9881 | F1: 0.4732 | Dice: 0.4922 | Recall: 0.4830 | Precision: 0.5017 | Jaccard: 0.3224


                                                                  

Epoch 45 | Time: 1m 5s
	Train Loss: 0.0798 | Valid Loss: 0.4069
	Accuracy: 0.9881 | F1: 0.4722 | Dice: 0.4914 | Recall: 0.4801 | Precision: 0.5031 | Jaccard: 0.3210




In [11]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


Testing: 100%|██████████| 14/14 [00:05<00:00,  2.66it/s]

Jaccard: 0.4182, F1: 0.5812, Recall: 0.7133, Precision: 0.5065, Accuracy: 0.9904
FPS: 241.41665570319694



