In [1]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
#Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples

In [4]:
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits).view(logits.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        return (1 - tversky).mean()



In [5]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        s = self.conv(x)
        p = self.pool(s)
        return s, p

class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out * s

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0]+out_c, out_c)

    def forward(self, x, s):
        x = self.up(x)
        s = self.ag(x, s)
        x = torch.cat([x, s], axis=1)
        x = self.c1(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.e1 = encoder_block(1, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        self.b1 = conv_block(512, 1024)

        self.d1 = decoder_block([1024, 512], 512)
        self.d2 = decoder_block([512, 256], 256)
        self.d3 = decoder_block([256, 128], 128)
        self.d4 = decoder_block([128, 64], 64)
        
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b1 = self.b1(p4)

        d1 = self.d1(b1, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        output = self.output(d4)
        return output

In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [7]:
seeding(42)
MODEL_NAME = "Green_Attention_UNet_Tversky_Full"
MODEL_DIRECTORY = "Green_Attention_Model_UNet_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Attention_Results_UNet_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = TverskyLoss()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [9]:
for epoch in range(50):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                  

Epoch 01 | Time: 1m 38s
	Train Loss: 0.9170 | Valid Loss: 0.8940
	Accuracy: 0.9508 | F1: 0.2120 | Dice: 0.2475 | Recall: 0.7041 | Precision: 0.1501 | Jaccard: 0.1283
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 02 | Time: 1m 33s
	Train Loss: 0.8677 | Valid Loss: 0.8537
	Accuracy: 0.9503 | F1: 0.2199 | Dice: 0.2636 | Recall: 0.7378 | Precision: 0.1605 | Jaccard: 0.1341
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 03 | Time: 1m 34s
	Train Loss: 0.8025 | Valid Loss: 0.7938
	Accuracy: 0.9615 | F1: 0.2553 | Dice: 0.2943 | Recall: 0.7247 | Precision: 0.1847 | Jaccard: 0.1604
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 04 | Time: 1m 31s
	Train Loss: 0.7108 | Valid Loss: 0.7452
	Accuracy: 0.9681 | F1: 0.2804 | Dice: 0.3260 | Recall: 0.7189 | Precision: 0.2108 | Jaccard: 0.1767
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 05 | Time: 1m 31s
	Train Loss: 0.6157 | Valid Loss: 0.6514
	Accuracy: 0.9846 | F1: 0.4057 | Dice: 0.4358 | Recall: 0.5285 | Precision: 0.3708 | Jaccard: 0.2662
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 06 | Time: 1m 30s
	Train Loss: 0.5568 | Valid Loss: 0.6171
	Accuracy: 0.9867 | F1: 0.4352 | Dice: 0.4609 | Recall: 0.4920 | Precision: 0.4335 | Jaccard: 0.2909
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 07 | Time: 1m 29s
	Train Loss: 0.4947 | Valid Loss: 0.6573
	Accuracy: 0.9858 | F1: 0.3760 | Dice: 0.3968 | Recall: 0.3462 | Precision: 0.4647 | Jaccard: 0.2405


                                                                  

Epoch 08 | Time: 1m 30s
	Train Loss: 0.4471 | Valid Loss: 0.5154
	Accuracy: 0.9863 | F1: 0.4782 | Dice: 0.5060 | Recall: 0.6385 | Precision: 0.4190 | Jaccard: 0.3243
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 09 | Time: 1m 26s
	Train Loss: 0.4195 | Valid Loss: 0.5629
	Accuracy: 0.9874 | F1: 0.4507 | Dice: 0.4727 | Recall: 0.4898 | Precision: 0.4568 | Jaccard: 0.3033


                                                                  

Epoch 10 | Time: 1m 26s
	Train Loss: 0.3889 | Valid Loss: 0.5442
	Accuracy: 0.9869 | F1: 0.4439 | Dice: 0.4592 | Recall: 0.5383 | Precision: 0.4004 | Jaccard: 0.3021


                                                                  

Epoch 11 | Time: 1m 26s
	Train Loss: 0.3924 | Valid Loss: 0.5087
	Accuracy: 0.9863 | F1: 0.4694 | Dice: 0.4922 | Recall: 0.6292 | Precision: 0.4043 | Jaccard: 0.3181
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 12 | Time: 1m 25s
	Train Loss: 0.3615 | Valid Loss: 0.5195
	Accuracy: 0.9876 | F1: 0.4951 | Dice: 0.5068 | Recall: 0.5905 | Precision: 0.4439 | Jaccard: 0.3435


                                                                  

Epoch 13 | Time: 1m 25s
	Train Loss: 0.3360 | Valid Loss: 0.5584
	Accuracy: 0.9886 | F1: 0.4645 | Dice: 0.4728 | Recall: 0.4416 | Precision: 0.5088 | Jaccard: 0.3218


                                                                  

Epoch 14 | Time: 1m 25s
	Train Loss: 0.3258 | Valid Loss: 0.5079
	Accuracy: 0.9879 | F1: 0.5268 | Dice: 0.5381 | Recall: 0.5693 | Precision: 0.5102 | Jaccard: 0.3668
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 15 | Time: 1m 25s
	Train Loss: 0.3029 | Valid Loss: 0.5321
	Accuracy: 0.9877 | F1: 0.4924 | Dice: 0.5009 | Recall: 0.5267 | Precision: 0.4775 | Jaccard: 0.3449


                                                                  

Epoch 16 | Time: 1m 30s
	Train Loss: 0.3127 | Valid Loss: 0.4809
	Accuracy: 0.9872 | F1: 0.4983 | Dice: 0.5200 | Recall: 0.6120 | Precision: 0.4520 | Jaccard: 0.3431
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 17 | Time: 1m 31s
	Train Loss: 0.2972 | Valid Loss: 0.5814
	Accuracy: 0.9883 | F1: 0.4458 | Dice: 0.4539 | Recall: 0.4086 | Precision: 0.5105 | Jaccard: 0.3041


                                                                  

Epoch 18 | Time: 1m 30s
	Train Loss: 0.2775 | Valid Loss: 0.4889
	Accuracy: 0.9884 | F1: 0.5287 | Dice: 0.5451 | Recall: 0.6109 | Precision: 0.4921 | Jaccard: 0.3694


                                                                  

Epoch 19 | Time: 1m 30s
	Train Loss: 0.2653 | Valid Loss: 0.4824
	Accuracy: 0.9884 | F1: 0.5346 | Dice: 0.5462 | Recall: 0.6090 | Precision: 0.4951 | Jaccard: 0.3752


                                                                  

Epoch 20 | Time: 1m 30s
	Train Loss: 0.2507 | Valid Loss: 0.4734
	Accuracy: 0.9889 | F1: 0.5424 | Dice: 0.5561 | Recall: 0.5687 | Precision: 0.5440 | Jaccard: 0.3800
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 21 | Time: 1m 30s
	Train Loss: 0.2618 | Valid Loss: 0.5042
	Accuracy: 0.9873 | F1: 0.5039 | Dice: 0.5275 | Recall: 0.5661 | Precision: 0.4937 | Jaccard: 0.3498


                                                                  

Epoch 22 | Time: 1m 29s
	Train Loss: 0.2623 | Valid Loss: 0.5252
	Accuracy: 0.9883 | F1: 0.4897 | Dice: 0.4995 | Recall: 0.4794 | Precision: 0.5212 | Jaccard: 0.3359


                                                                  

Epoch 23 | Time: 1m 32s
	Train Loss: 0.2487 | Valid Loss: 0.5486
	Accuracy: 0.9891 | F1: 0.4846 | Dice: 0.4979 | Recall: 0.4376 | Precision: 0.5775 | Jaccard: 0.3280


                                                                  

Epoch 24 | Time: 1m 30s
	Train Loss: 0.2475 | Valid Loss: 0.4609
	Accuracy: 0.9878 | F1: 0.5225 | Dice: 0.5528 | Recall: 0.6545 | Precision: 0.4785 | Jaccard: 0.3639
Best Green_Attention_UNet_Tversky_Full Saved


                                                                  

Epoch 25 | Time: 1m 28s
	Train Loss: 0.2535 | Valid Loss: 0.5606
	Accuracy: 0.9861 | F1: 0.4440 | Dice: 0.4694 | Recall: 0.5772 | Precision: 0.3955 | Jaccard: 0.2995


                                                                  

Epoch 26 | Time: 1m 28s
	Train Loss: 0.2394 | Valid Loss: 0.5096
	Accuracy: 0.9879 | F1: 0.4953 | Dice: 0.5148 | Recall: 0.5594 | Precision: 0.4767 | Jaccard: 0.3432


                                                                  

Epoch 27 | Time: 1m 29s
	Train Loss: 0.2246 | Valid Loss: 0.4643
	Accuracy: 0.9885 | F1: 0.5593 | Dice: 0.5750 | Recall: 0.6303 | Precision: 0.5286 | Jaccard: 0.3913


                                                                  

Epoch 28 | Time: 1m 30s
	Train Loss: 0.2304 | Valid Loss: 0.4759
	Accuracy: 0.9882 | F1: 0.5247 | Dice: 0.5407 | Recall: 0.5882 | Precision: 0.5003 | Jaccard: 0.3666


                                                                  

Epoch 29 | Time: 1m 29s
	Train Loss: 0.2216 | Valid Loss: 0.4994
	Accuracy: 0.9881 | F1: 0.5151 | Dice: 0.5358 | Recall: 0.5859 | Precision: 0.4936 | Jaccard: 0.3556


                                                                  

Epoch 30 | Time: 1m 30s
	Train Loss: 0.2267 | Valid Loss: 0.5026
	Accuracy: 0.9878 | F1: 0.5205 | Dice: 0.5383 | Recall: 0.6112 | Precision: 0.4809 | Jaccard: 0.3667


                                                                  

Epoch 31 | Time: 1m 30s
	Train Loss: 0.1984 | Valid Loss: 0.4951
	Accuracy: 0.9886 | F1: 0.5351 | Dice: 0.5551 | Recall: 0.5839 | Precision: 0.5290 | Jaccard: 0.3767


                                                                  

Epoch 32 | Time: 1m 29s
	Train Loss: 0.1945 | Valid Loss: 0.4936
	Accuracy: 0.9886 | F1: 0.5350 | Dice: 0.5537 | Recall: 0.5783 | Precision: 0.5311 | Jaccard: 0.3745


                                                                  

Epoch 33 | Time: 1m 27s
	Train Loss: 0.1911 | Valid Loss: 0.4954
	Accuracy: 0.9886 | F1: 0.5313 | Dice: 0.5498 | Recall: 0.5738 | Precision: 0.5278 | Jaccard: 0.3714


                                                                  

Epoch 34 | Time: 1m 28s
	Train Loss: 0.1873 | Valid Loss: 0.4823
	Accuracy: 0.9886 | F1: 0.5410 | Dice: 0.5631 | Recall: 0.5972 | Precision: 0.5327 | Jaccard: 0.3790


                                                                  

Epoch 35 | Time: 1m 28s
	Train Loss: 0.1865 | Valid Loss: 0.4807
	Accuracy: 0.9887 | F1: 0.5436 | Dice: 0.5640 | Recall: 0.5918 | Precision: 0.5386 | Jaccard: 0.3802


                                                                  

Epoch 36 | Time: 1m 28s
	Train Loss: 0.1838 | Valid Loss: 0.4829
	Accuracy: 0.9889 | F1: 0.5483 | Dice: 0.5657 | Recall: 0.5788 | Precision: 0.5532 | Jaccard: 0.3846


                                                                  

Epoch 37 | Time: 1m 28s
	Train Loss: 0.1813 | Valid Loss: 0.4962
	Accuracy: 0.9888 | F1: 0.5342 | Dice: 0.5510 | Recall: 0.5537 | Precision: 0.5483 | Jaccard: 0.3724


                                                                  

Epoch 38 | Time: 1m 28s
	Train Loss: 0.1809 | Valid Loss: 0.4753
	Accuracy: 0.9887 | F1: 0.5453 | Dice: 0.5659 | Recall: 0.5980 | Precision: 0.5371 | Jaccard: 0.3819


                                                                  

Epoch 39 | Time: 1m 28s
	Train Loss: 0.1809 | Valid Loss: 0.4916
	Accuracy: 0.9887 | F1: 0.5356 | Dice: 0.5537 | Recall: 0.5639 | Precision: 0.5439 | Jaccard: 0.3737


                                                                  

Epoch 40 | Time: 1m 28s
	Train Loss: 0.1806 | Valid Loss: 0.4889
	Accuracy: 0.9889 | F1: 0.5409 | Dice: 0.5587 | Recall: 0.5625 | Precision: 0.5551 | Jaccard: 0.3785


                                                                  

Epoch 41 | Time: 1m 28s
	Train Loss: 0.1797 | Valid Loss: 0.4839
	Accuracy: 0.9886 | F1: 0.5400 | Dice: 0.5632 | Recall: 0.5950 | Precision: 0.5347 | Jaccard: 0.3774


                                                                  

Epoch 42 | Time: 1m 31s
	Train Loss: 0.1803 | Valid Loss: 0.4852
	Accuracy: 0.9888 | F1: 0.5413 | Dice: 0.5608 | Recall: 0.5766 | Precision: 0.5458 | Jaccard: 0.3786


                                                                  

Epoch 43 | Time: 1m 31s
	Train Loss: 0.1793 | Valid Loss: 0.4829
	Accuracy: 0.9887 | F1: 0.5427 | Dice: 0.5628 | Recall: 0.5901 | Precision: 0.5380 | Jaccard: 0.3793


                                                                  

Epoch 44 | Time: 1m 31s
	Train Loss: 0.1799 | Valid Loss: 0.4943
	Accuracy: 0.9886 | F1: 0.5324 | Dice: 0.5510 | Recall: 0.5630 | Precision: 0.5396 | Jaccard: 0.3707


                                                                  

Epoch 45 | Time: 1m 31s
	Train Loss: 0.1793 | Valid Loss: 0.4889
	Accuracy: 0.9887 | F1: 0.5346 | Dice: 0.5531 | Recall: 0.5602 | Precision: 0.5462 | Jaccard: 0.3727


                                                                  

Epoch 46 | Time: 1m 31s
	Train Loss: 0.1803 | Valid Loss: 0.5014
	Accuracy: 0.9889 | F1: 0.5274 | Dice: 0.5434 | Recall: 0.5303 | Precision: 0.5572 | Jaccard: 0.3670


                                                                  

Epoch 47 | Time: 1m 29s
	Train Loss: 0.1792 | Valid Loss: 0.4842
	Accuracy: 0.9886 | F1: 0.5347 | Dice: 0.5552 | Recall: 0.5799 | Precision: 0.5326 | Jaccard: 0.3727


                                                                  

Epoch 48 | Time: 1m 26s
	Train Loss: 0.1794 | Valid Loss: 0.4865
	Accuracy: 0.9887 | F1: 0.5376 | Dice: 0.5571 | Recall: 0.5748 | Precision: 0.5404 | Jaccard: 0.3755


                                                                  

Epoch 49 | Time: 1m 26s
	Train Loss: 0.1798 | Valid Loss: 0.4814
	Accuracy: 0.9887 | F1: 0.5444 | Dice: 0.5660 | Recall: 0.5966 | Precision: 0.5385 | Jaccard: 0.3808


                                                                  

Epoch 50 | Time: 1m 26s
	Train Loss: 0.1788 | Valid Loss: 0.4872
	Accuracy: 0.9887 | F1: 0.5354 | Dice: 0.5563 | Recall: 0.5778 | Precision: 0.5364 | Jaccard: 0.3738




In [10]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


Testing: 100%|██████████| 14/14 [00:05<00:00,  2.55it/s]

Jaccard: 0.4198, F1: 0.5813, Recall: 0.7127, Precision: 0.5137, Accuracy: 0.9902
FPS: 191.7796894697994



