In [3]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples

In [6]:
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits).view(logits.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        return (1 - tversky).mean()



In [7]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, act=True):
        super().__init__()

        layers = [
            nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_c)
        ]
        if act == True:
            layers.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class multires_block(nn.Module):
    def __init__(self, in_c, out_c, alpha=1.67):
        super().__init__()

        W = out_c * alpha
        self.c1 = conv_block(in_c, int(W*0.167))
        self.c2 = conv_block(int(W*0.167), int(W*0.333))
        self.c3 = conv_block(int(W*0.333), int(W*0.5))

        nf = int(W*0.167) + int(W*0.333) + int(W*0.5)
        self.b1 = nn.BatchNorm2d(nf)
        self.c4 = conv_block(in_c, nf)
        self.relu = nn.ReLU(inplace=True)
        self.b2 = nn.BatchNorm2d(nf)

    def forward(self, x):
        x0 = x
        x1 = self.c1(x0)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        xc = torch.cat([x1, x2, x3], dim=1)
        xc = self.b1(xc)

        sc = self.c4(x0)
        x = self.relu(xc + sc)
        x = self.b2(x)
        return x

class res_path_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.c1 = conv_block(in_c, out_c, act=False)
        self.s1 = conv_block(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        x1 = self.c1(x)
        s1 = self.s1(x)
        x = self.relu(x1 + s1)
        x = self.bn(x)
        return x

class res_path(nn.Module):
    def __init__(self, in_c, out_c, length):
        super().__init__()

        layers = []
        for i in range(length):
            layers.append(res_path_block(in_c, out_c))
            in_c = out_c

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

def cal_nf(ch, alpha=1.67):
    W = ch * alpha
    return int(W*0.167) + int(W*0.333) + int(W*0.5)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c, length):
        super().__init__()

        self.c1 = multires_block(in_c, out_c)
        nf = cal_nf(out_c)
        self.s1 = res_path(nf, out_c, length)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        x = self.c1(x)
        s = self.s1(x)
        p = self.pool(x)
        return s, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.c1 = nn.ConvTranspose2d(in_c[0], out_c, kernel_size=2, stride=2, padding=0)
        self.c2 = multires_block(out_c+in_c[1], out_c)

    def forward(self, x, s):
        x = self.c1(x)
        x = torch.cat([x, s], dim=1)
        x = self.c2(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.e1 = encoder_block(1, 32, 4)
        self.e2 = encoder_block(cal_nf(32), 64, 3)
        self.e3 = encoder_block(cal_nf(64), 128, 2)
        self.e4 = encoder_block(cal_nf(128), 256, 1)

        """ Bridge """
        self.b1 = multires_block(cal_nf(256), 512)

        """ Decoder """
        self.d1 = decoder_block([cal_nf(512), 256], 256)
        self.d2 = decoder_block([cal_nf(256), 128], 128)
        self.d3 = decoder_block([cal_nf(128), 64], 64)
        self.d4 = decoder_block([cal_nf(64), 32], 32)

        """ Output """
        self.output = nn.Conv2d(cal_nf(32), 1, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b1 = self.b1(p4)

        d1 = self.d1(b1, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        output = self.output(d4)
        return output



In [None]:
# Training and Evaluation
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [9]:
seeding(42)
MODEL_NAME = "Green_AMultiRes_UNet_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_MultiRes_UNet_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_MultiRes_UNet_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# Load full training data
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
full_train_dataset = RetinalDataset(train_images, train_masks)

# Randomly select 70% of training data 
total_train_size = len(full_train_dataset)
subset_size = int(0.7 * total_train_size)
train_subset, _ = random_split(full_train_dataset, [subset_size, total_train_size - subset_size])

#Load test data and split into validation and test (50/50) 
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)

n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

#  Setup device, model, optimizer, scheduler, and loss 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn = TverskyLoss()

#  Data loaders
train_loader = DataLoader(train_subset, batch_size=1, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

# Initialize best validation loss 
best_valid_loss = float("inf")



In [11]:
for epoch in range(40):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                      

Epoch 01 | Time: 7m 51s
	Train Loss: 0.9511 | Valid Loss: 0.8949
	Accuracy: 0.9049 | F1: 0.1061 | Dice: 0.1301 | Recall: 0.5543 | Precision: 0.0737 | Jaccard: 0.0642
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                      

Epoch 02 | Time: 7m 46s
	Train Loss: 0.7644 | Valid Loss: 0.7646
	Accuracy: 0.9939 | F1: 0.2400 | Dice: 0.3290 | Recall: 0.2753 | Precision: 0.4088 | Jaccard: 0.1533
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                      

Epoch 03 | Time: 7m 46s
	Train Loss: 0.6474 | Valid Loss: 0.7586
	Accuracy: 0.9925 | F1: 0.2281 | Dice: 0.2988 | Recall: 0.3423 | Precision: 0.2651 | Jaccard: 0.1413
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                      

Epoch 04 | Time: 7m 46s
	Train Loss: 0.6171 | Valid Loss: 0.7058
	Accuracy: 0.9943 | F1: 0.3084 | Dice: 0.3801 | Recall: 0.3063 | Precision: 0.5007 | Jaccard: 0.2040
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                      

Epoch 05 | Time: 7m 47s
	Train Loss: 0.5952 | Valid Loss: 0.6822
	Accuracy: 0.9942 | F1: 0.3208 | Dice: 0.3945 | Recall: 0.3600 | Precision: 0.4362 | Jaccard: 0.2125
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                      

Epoch 06 | Time: 7m 46s
	Train Loss: 0.5811 | Valid Loss: 0.7285
	Accuracy: 0.9946 | F1: 0.2943 | Dice: 0.3568 | Recall: 0.2637 | Precision: 0.5517 | Jaccard: 0.1974


                                                                      

Epoch 07 | Time: 7m 46s
	Train Loss: 0.5647 | Valid Loss: 0.7071
	Accuracy: 0.9945 | F1: 0.3021 | Dice: 0.3750 | Recall: 0.3116 | Precision: 0.4707 | Jaccard: 0.1999


                                                                      

Epoch 08 | Time: 7m 46s
	Train Loss: 0.5543 | Valid Loss: 0.6924
	Accuracy: 0.9943 | F1: 0.3130 | Dice: 0.3824 | Recall: 0.3427 | Precision: 0.4324 | Jaccard: 0.2095


                                                                      

Epoch 09 | Time: 7m 46s
	Train Loss: 0.5380 | Valid Loss: 0.7255
	Accuracy: 0.9946 | F1: 0.2951 | Dice: 0.3510 | Recall: 0.2637 | Precision: 0.5245 | Jaccard: 0.1990


                                                                      

Epoch 10 | Time: 7m 46s
	Train Loss: 0.5306 | Valid Loss: 0.7227
	Accuracy: 0.9930 | F1: 0.2894 | Dice: 0.3719 | Recall: 0.3047 | Precision: 0.4771 | Jaccard: 0.1907


                                                                      

Epoch 11 | Time: 7m 45s
	Train Loss: 0.5264 | Valid Loss: 0.6942
	Accuracy: 0.9945 | F1: 0.3203 | Dice: 0.3805 | Recall: 0.3096 | Precision: 0.4934 | Jaccard: 0.2163


                                                                      

Epoch 12 | Time: 7m 47s
	Train Loss: 0.4958 | Valid Loss: 0.7079
	Accuracy: 0.9943 | F1: 0.3108 | Dice: 0.3716 | Recall: 0.2874 | Precision: 0.5257 | Jaccard: 0.2129


                                                                      

Epoch 13 | Time: 7m 48s
	Train Loss: 0.4872 | Valid Loss: 0.6868
	Accuracy: 0.9947 | F1: 0.3328 | Dice: 0.3875 | Recall: 0.3059 | Precision: 0.5284 | Jaccard: 0.2274


                                                                      

Epoch 14 | Time: 7m 47s
	Train Loss: 0.4822 | Valid Loss: 0.7065
	Accuracy: 0.9945 | F1: 0.3125 | Dice: 0.3719 | Recall: 0.2880 | Precision: 0.5249 | Jaccard: 0.2134


                                                                      

Epoch 15 | Time: 7m 45s
	Train Loss: 0.4776 | Valid Loss: 0.6865
	Accuracy: 0.9922 | F1: 0.3287 | Dice: 0.3856 | Recall: 0.3228 | Precision: 0.4786 | Jaccard: 0.2260


                                                                      

Epoch 16 | Time: 7m 45s
	Train Loss: 0.4738 | Valid Loss: 0.7127
	Accuracy: 0.9941 | F1: 0.3088 | Dice: 0.3684 | Recall: 0.2802 | Precision: 0.5377 | Jaccard: 0.2124


                                                                      

Epoch 17 | Time: 7m 45s
	Train Loss: 0.4710 | Valid Loss: 0.6823
	Accuracy: 0.9917 | F1: 0.3301 | Dice: 0.3917 | Recall: 0.3322 | Precision: 0.4772 | Jaccard: 0.2253


                                                                      

Epoch 18 | Time: 7m 46s
	Train Loss: 0.4663 | Valid Loss: 0.7037
	Accuracy: 0.9944 | F1: 0.3181 | Dice: 0.3768 | Recall: 0.2873 | Precision: 0.5474 | Jaccard: 0.2183


                                                                      

Epoch 19 | Time: 7m 45s
	Train Loss: 0.4653 | Valid Loss: 0.6916
	Accuracy: 0.9909 | F1: 0.3251 | Dice: 0.3864 | Recall: 0.3152 | Precision: 0.4991 | Jaccard: 0.2237


                                                                      

Epoch 20 | Time: 7m 45s
	Train Loss: 0.4646 | Valid Loss: 0.6927
	Accuracy: 0.9938 | F1: 0.3260 | Dice: 0.3836 | Recall: 0.3067 | Precision: 0.5117 | Jaccard: 0.2237


                                                                      

Epoch 21 | Time: 7m 45s
	Train Loss: 0.4640 | Valid Loss: 0.7010
	Accuracy: 0.9934 | F1: 0.3195 | Dice: 0.3786 | Recall: 0.2972 | Precision: 0.5212 | Jaccard: 0.2174


                                                                      

Epoch 22 | Time: 7m 48s
	Train Loss: 0.4635 | Valid Loss: 0.6973
	Accuracy: 0.9907 | F1: 0.3206 | Dice: 0.3801 | Recall: 0.3065 | Precision: 0.5001 | Jaccard: 0.2210


                                                                      

Epoch 23 | Time: 7m 46s
	Train Loss: 0.4629 | Valid Loss: 0.7122
	Accuracy: 0.9944 | F1: 0.3086 | Dice: 0.3675 | Recall: 0.2785 | Precision: 0.5401 | Jaccard: 0.2115


                                                                      

Epoch 24 | Time: 7m 45s
	Train Loss: 0.4623 | Valid Loss: 0.7163
	Accuracy: 0.9943 | F1: 0.3050 | Dice: 0.3627 | Recall: 0.2752 | Precision: 0.5317 | Jaccard: 0.2080


                                                                      

Epoch 25 | Time: 7m 47s
	Train Loss: 0.4622 | Valid Loss: 0.6849
	Accuracy: 0.9901 | F1: 0.3315 | Dice: 0.3930 | Recall: 0.3242 | Precision: 0.4988 | Jaccard: 0.2265


                                                                      

Epoch 26 | Time: 7m 45s
	Train Loss: 0.4621 | Valid Loss: 0.7146
	Accuracy: 0.9919 | F1: 0.3063 | Dice: 0.3712 | Recall: 0.2827 | Precision: 0.5404 | Jaccard: 0.2101


                                                                      

Epoch 27 | Time: 7m 45s
	Train Loss: 0.4621 | Valid Loss: 0.7256
	Accuracy: 0.9920 | F1: 0.2938 | Dice: 0.3588 | Recall: 0.2757 | Precision: 0.5133 | Jaccard: 0.1987


                                                                      

Epoch 28 | Time: 7m 46s
	Train Loss: 0.4620 | Valid Loss: 0.6898
	Accuracy: 0.9897 | F1: 0.3280 | Dice: 0.3883 | Recall: 0.3153 | Precision: 0.5054 | Jaccard: 0.2247


                                                                      

Epoch 29 | Time: 7m 46s
	Train Loss: 0.4619 | Valid Loss: 0.6857
	Accuracy: 0.9900 | F1: 0.3299 | Dice: 0.3904 | Recall: 0.3245 | Precision: 0.4900 | Jaccard: 0.2248


                                                                      

Epoch 30 | Time: 7m 46s
	Train Loss: 0.4619 | Valid Loss: 0.6904
	Accuracy: 0.9897 | F1: 0.3259 | Dice: 0.3842 | Recall: 0.3180 | Precision: 0.4853 | Jaccard: 0.2242


                                                                      

Epoch 31 | Time: 7m 46s
	Train Loss: 0.4619 | Valid Loss: 0.7046
	Accuracy: 0.9946 | F1: 0.3181 | Dice: 0.3715 | Recall: 0.2813 | Precision: 0.5467 | Jaccard: 0.2173


                                                                      

Epoch 32 | Time: 7m 46s
	Train Loss: 0.4619 | Valid Loss: 0.6886
	Accuracy: 0.9924 | F1: 0.3306 | Dice: 0.3902 | Recall: 0.3134 | Precision: 0.5170 | Jaccard: 0.2268


                                                                      

Epoch 33 | Time: 7m 45s
	Train Loss: 0.4619 | Valid Loss: 0.6950
	Accuracy: 0.9922 | F1: 0.3238 | Dice: 0.3854 | Recall: 0.3068 | Precision: 0.5182 | Jaccard: 0.2205


                                                                      

Epoch 34 | Time: 7m 46s
	Train Loss: 0.4618 | Valid Loss: 0.7029
	Accuracy: 0.9909 | F1: 0.3164 | Dice: 0.3822 | Recall: 0.3022 | Precision: 0.5198 | Jaccard: 0.2152


                                                                      

Epoch 35 | Time: 7m 46s
	Train Loss: 0.4618 | Valid Loss: 0.6972
	Accuracy: 0.9924 | F1: 0.3228 | Dice: 0.3833 | Recall: 0.3027 | Precision: 0.5225 | Jaccard: 0.2216


                                                                      

Epoch 36 | Time: 7m 46s
	Train Loss: 0.4618 | Valid Loss: 0.6888
	Accuracy: 0.9903 | F1: 0.3282 | Dice: 0.3895 | Recall: 0.3185 | Precision: 0.5013 | Jaccard: 0.2245


                                                                      

Epoch 37 | Time: 7m 48s
	Train Loss: 0.4618 | Valid Loss: 0.7049
	Accuracy: 0.9922 | F1: 0.3174 | Dice: 0.3773 | Recall: 0.2903 | Precision: 0.5390 | Jaccard: 0.2180


                                                                      

Epoch 38 | Time: 7m 46s
	Train Loss: 0.4618 | Valid Loss: 0.6971
	Accuracy: 0.9926 | F1: 0.3240 | Dice: 0.3845 | Recall: 0.3020 | Precision: 0.5291 | Jaccard: 0.2224


                                                                      

Epoch 39 | Time: 7m 45s
	Train Loss: 0.4618 | Valid Loss: 0.6910
	Accuracy: 0.9922 | F1: 0.3279 | Dice: 0.3882 | Recall: 0.3122 | Precision: 0.5131 | Jaccard: 0.2244


                                                                      

Epoch 40 | Time: 7m 45s
	Train Loss: 0.4618 | Valid Loss: 0.6934
	Accuracy: 0.9916 | F1: 0.3253 | Dice: 0.3851 | Recall: 0.3090 | Precision: 0.5111 | Jaccard: 0.2234




In [12]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


  model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
Testing: 100%|██████████| 113/113 [00:21<00:00,  5.30it/s]

Jaccard: 0.2286, F1: 0.3406, Recall: 0.4037, Precision: 0.4456, Accuracy: 0.9910
FPS: 359.92493431144726



