In [None]:
# Had to re-run this notebook before submitting. so decided to do reduce epoch to 30. first one was run with standard epoch set up as other experiment.

In [2]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples


In [5]:
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits).view(logits.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        return (1 - tversky).mean()



In [6]:
class AttentionBlock(nn.Module):
    """Additive attention block for U-Net skip connections."""
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g: gating signal transform
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # W_x: skip connection transform
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # psi: attention coefficient
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal (from decoder), x: skip features (from encoder)
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi  # apply attention


class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


class UNet(nn.Module):
    """U-Net with attention gates on skip connections and additional subsampling concat."""
    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = TripleConv(1, 32, 32, 64)
        self.down2 = TripleConv(64, 64, 64, 128)
        self.down3 = DoubleConv(128, 128, 256)
        self.down4 = DoubleConv(256, 256, 256)
        self.pool  = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512, 256)

        # Decoder up and conv
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(256+256, 256, 256)
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(256+256, 128, 128)
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(128+128, 64, 64, 64)
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(64+64, 32, 32, 32)

        # Attention blocks for skip connections
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64,  F_l=64,  F_int=32)

        # Final subsample, concat and output
        self.final_pool       = nn.MaxPool2d(2, 2)
        self.final_upsample   = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out_conv         = nn.Conv2d(33, 1, kernel_size=1)

    def forward(self, x):
        input_image = x
        # Encoder
        x1  = self.down1(x)
        x1p = self.pool(x1)
        x2  = self.down2(x1p)
        x2p = self.pool(x2)
        x3  = self.down3(x2p)
        x3p = self.pool(x3)
        x4  = self.down4(x3p)
        x4p = self.pool(x4)

        # Bottleneck
        xb  = self.bottleneck(x4p)

        # Decoder + Attention
        d4  = self.up4(xb)
        x4a = self.att4(g=d4, x=x4)
        d4  = torch.cat([x4a, d4], dim=1)
        d4  = self.dec4(d4)

        d3  = self.up3(d4)
        x3a = self.att3(g=d3, x=x3)
        d3  = torch.cat([x3a, d3], dim=1)
        d3  = self.dec3(d3)

        d2  = self.up2(d3)
        x2a = self.att2(g=d2, x=x2)
        d2  = torch.cat([x2a, d2], dim=1)
        d2  = self.dec2(d2)

        d1  = self.up1(d2)
        x1a = self.att1(g=d1, x=x1)
        d1  = torch.cat([x1a, d1], dim=1)
        d1  = self.dec1(d1)

        # Additional subsampling & concatenation
        d1s = self.final_pool(d1)
        ins = self.final_pool(input_image)
        cat = torch.cat([d1s, ins], dim=1)
        out = self.final_upsample(cat)
        out = self.out_conv(out)
        return out


In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [8]:
seeding(42)
MODEL_NAME = "Green_Attention_Custom_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_Attention_Custom_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_Attention_Custom_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = TverskyLoss()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [10]:
best_valid_loss = float('inf')
epochs_no_improve = 0
early_stop_patience = 10

for epoch in range(30):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s)")

    if epochs_no_improve >= early_stop_patience:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        break


                                                                    

Epoch 01 | Time: 13m 51s
	Train Loss: 0.7126 | Valid Loss: 0.6923
	Accuracy: 0.9824 | F1: 0.3762 | Dice: 0.4311 | Recall: 0.4174 | Precision: 0.4457 | Jaccard: 0.2443
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 02 | Time: 11m 53s
	Train Loss: 0.6228 | Valid Loss: 0.6379
	Accuracy: 0.9824 | F1: 0.4251 | Dice: 0.4745 | Recall: 0.5355 | Precision: 0.4259 | Jaccard: 0.2858
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 03 | Time: 11m 53s
	Train Loss: 0.6006 | Valid Loss: 0.6614
	Accuracy: 0.9841 | F1: 0.4227 | Dice: 0.4714 | Recall: 0.4623 | Precision: 0.4808 | Jaccard: 0.2847
No improvement for 1 epoch(s)


                                                                    

Epoch 04 | Time: 11m 52s
	Train Loss: 0.5907 | Valid Loss: 0.6517
	Accuracy: 0.9790 | F1: 0.4001 | Dice: 0.4527 | Recall: 0.5808 | Precision: 0.3709 | Jaccard: 0.2649
No improvement for 2 epoch(s)


                                                                    

Epoch 05 | Time: 11m 51s
	Train Loss: 0.5828 | Valid Loss: 0.6418
	Accuracy: 0.9822 | F1: 0.4268 | Dice: 0.4713 | Recall: 0.5214 | Precision: 0.4301 | Jaccard: 0.2895
No improvement for 3 epoch(s)


                                                                    

Epoch 06 | Time: 11m 56s
	Train Loss: 0.5775 | Valid Loss: 0.6370
	Accuracy: 0.9812 | F1: 0.4206 | Dice: 0.4706 | Recall: 0.5659 | Precision: 0.4028 | Jaccard: 0.2838
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 07 | Time: 11m 58s
	Train Loss: 0.5715 | Valid Loss: 0.6499
	Accuracy: 0.9835 | F1: 0.4317 | Dice: 0.4764 | Recall: 0.4840 | Precision: 0.4691 | Jaccard: 0.2969
No improvement for 1 epoch(s)


                                                                    

Epoch 08 | Time: 11m 52s
	Train Loss: 0.5617 | Valid Loss: 0.6343
	Accuracy: 0.9810 | F1: 0.4170 | Dice: 0.4651 | Recall: 0.5651 | Precision: 0.3952 | Jaccard: 0.2809
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 09 | Time: 12m 17s
	Train Loss: 0.5585 | Valid Loss: 0.6289
	Accuracy: 0.9796 | F1: 0.4151 | Dice: 0.4679 | Recall: 0.6032 | Precision: 0.3822 | Jaccard: 0.2786
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 10 | Time: 11m 54s
	Train Loss: 0.5576 | Valid Loss: 0.6192
	Accuracy: 0.9830 | F1: 0.4473 | Dice: 0.4928 | Recall: 0.5450 | Precision: 0.4497 | Jaccard: 0.3068
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 11 | Time: 11m 55s
	Train Loss: 0.5518 | Valid Loss: 0.6304
	Accuracy: 0.9823 | F1: 0.4360 | Dice: 0.4899 | Recall: 0.5539 | Precision: 0.4392 | Jaccard: 0.2976
No improvement for 1 epoch(s)


                                                                    

Epoch 12 | Time: 11m 54s
	Train Loss: 0.5476 | Valid Loss: 0.6177
	Accuracy: 0.9825 | F1: 0.4409 | Dice: 0.4909 | Recall: 0.5563 | Precision: 0.4393 | Jaccard: 0.3011
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 13 | Time: 11m 58s
	Train Loss: 0.5396 | Valid Loss: 0.6193
	Accuracy: 0.9810 | F1: 0.4315 | Dice: 0.4840 | Recall: 0.5799 | Precision: 0.4153 | Jaccard: 0.2935
No improvement for 1 epoch(s)


                                                                    

Epoch 14 | Time: 11m 58s
	Train Loss: 0.5330 | Valid Loss: 0.6221
	Accuracy: 0.9822 | F1: 0.4339 | Dice: 0.4819 | Recall: 0.5497 | Precision: 0.4291 | Jaccard: 0.2972
No improvement for 2 epoch(s)


                                                                    

Epoch 15 | Time: 11m 55s
	Train Loss: 0.5332 | Valid Loss: 0.6195
	Accuracy: 0.9815 | F1: 0.4361 | Dice: 0.4895 | Recall: 0.5806 | Precision: 0.4232 | Jaccard: 0.2971
No improvement for 3 epoch(s)


                                                                    

Epoch 16 | Time: 11m 56s
	Train Loss: 0.5238 | Valid Loss: 0.6484
	Accuracy: 0.9843 | F1: 0.4350 | Dice: 0.4782 | Recall: 0.4583 | Precision: 0.4998 | Jaccard: 0.2994
No improvement for 4 epoch(s)


                                                                    

Epoch 17 | Time: 11m 55s
	Train Loss: 0.5221 | Valid Loss: 0.6139
	Accuracy: 0.9824 | F1: 0.4430 | Dice: 0.4895 | Recall: 0.5627 | Precision: 0.4331 | Jaccard: 0.3021
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 18 | Time: 11m 53s
	Train Loss: 0.5155 | Valid Loss: 0.6113
	Accuracy: 0.9799 | F1: 0.4278 | Dice: 0.4758 | Recall: 0.6141 | Precision: 0.3884 | Jaccard: 0.2891
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 19 | Time: 11m 58s
	Train Loss: 0.5112 | Valid Loss: 0.6223
	Accuracy: 0.9832 | F1: 0.4490 | Dice: 0.4978 | Recall: 0.5338 | Precision: 0.4663 | Jaccard: 0.3079
No improvement for 1 epoch(s)


                                                                    

Epoch 20 | Time: 11m 56s
	Train Loss: 0.5075 | Valid Loss: 0.6073
	Accuracy: 0.9829 | F1: 0.4501 | Dice: 0.4951 | Recall: 0.5558 | Precision: 0.4463 | Jaccard: 0.3091
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 21 | Time: 11m 54s
	Train Loss: 0.5050 | Valid Loss: 0.6148
	Accuracy: 0.9823 | F1: 0.4498 | Dice: 0.4979 | Recall: 0.5730 | Precision: 0.4402 | Jaccard: 0.3065
No improvement for 1 epoch(s)


                                                                    

Epoch 22 | Time: 11m 55s
	Train Loss: 0.5049 | Valid Loss: 0.6113
	Accuracy: 0.9825 | F1: 0.4530 | Dice: 0.4988 | Recall: 0.5664 | Precision: 0.4457 | Jaccard: 0.3094
No improvement for 2 epoch(s)


                                                                    

Epoch 23 | Time: 11m 54s
	Train Loss: 0.4954 | Valid Loss: 0.6125
	Accuracy: 0.9819 | F1: 0.4498 | Dice: 0.4963 | Recall: 0.5750 | Precision: 0.4365 | Jaccard: 0.3071
No improvement for 3 epoch(s)


                                                                    

Epoch 24 | Time: 11m 57s
	Train Loss: 0.4915 | Valid Loss: 0.5672
	Accuracy: 0.9802 | F1: 0.4371 | Dice: 0.4833 | Recall: 0.5901 | Precision: 0.4092 | Jaccard: 0.2974
Best Green_Attention_Custom_Tversky_Full Saved


                                                                    

Epoch 25 | Time: 11m 57s
	Train Loss: 0.6025 | Valid Loss: 0.6126
	Accuracy: 0.9818 | F1: 0.4089 | Dice: 0.4590 | Recall: 0.4948 | Precision: 0.4280 | Jaccard: 0.2731
No improvement for 1 epoch(s)


                                                                    

Epoch 26 | Time: 11m 53s
	Train Loss: 0.5707 | Valid Loss: 0.5950
	Accuracy: 0.9797 | F1: 0.4158 | Dice: 0.4662 | Recall: 0.5817 | Precision: 0.3890 | Jaccard: 0.2783
No improvement for 2 epoch(s)


                                                                    

Epoch 27 | Time: 11m 53s
	Train Loss: 0.6182 | Valid Loss: 0.5906
	Accuracy: 0.9814 | F1: 0.4191 | Dice: 0.4634 | Recall: 0.5271 | Precision: 0.4134 | Jaccard: 0.2837
No improvement for 3 epoch(s)


                                                                    

Epoch 28 | Time: 11m 52s
	Train Loss: 0.5700 | Valid Loss: 0.6327
	Accuracy: 0.9761 | F1: 0.3587 | Dice: 0.4010 | Recall: 0.5405 | Precision: 0.3188 | Jaccard: 0.2352
No improvement for 4 epoch(s)


                                                                    

Epoch 29 | Time: 11m 54s
	Train Loss: 0.5935 | Valid Loss: 0.6240
	Accuracy: 0.9802 | F1: 0.3688 | Dice: 0.4181 | Recall: 0.4668 | Precision: 0.3786 | Jaccard: 0.2444
No improvement for 5 epoch(s)


                                                                    

Epoch 30 | Time: 11m 55s
	Train Loss: 0.6525 | Valid Loss: 0.8061
	Accuracy: 0.9818 | F1: 0.0000 | Dice: 0.0000 | Recall: 0.0000 | Precision: 0.0000 | Jaccard: 0.0000
No improvement for 6 epoch(s)




In [11]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


  model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
Testing: 100%|██████████| 197/197 [00:34<00:00,  5.79it/s]

Jaccard: 0.2406, F1: 0.3486, Recall: 0.4847, Precision: 0.3293, Accuracy: 0.9831
FPS: 243.24519141462505



