In [1]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples

In [4]:
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits).view(logits.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        return (1 - tversky).mean()



In [None]:
class AttentionBlock(nn.Module):
    """Additive attention block for U-Net skip connections."""
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g: gating signal transform
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # W_x: skip connection transform
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # psi: attention coefficient
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal (from decoder), x: skip features (from encoder)
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi  # attention


class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


class UNet(nn.Module):
    """U-Net with attention gates on skip connections and additional subsampling concat."""
    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = TripleConv(1, 32, 32, 64)
        self.down2 = TripleConv(64, 64, 64, 128)
        self.down3 = DoubleConv(128, 128, 256)
        self.down4 = DoubleConv(256, 256, 256)
        self.pool  = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512, 256)

        # Decoder up and conv
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(256+256, 256, 256)
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(256+256, 128, 128)
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(128+128, 64, 64, 64)
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(64+64, 32, 32, 32)

        # Attention blocks for skip connections
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64,  F_l=64,  F_int=32)

        # Final subsample, concat and output
        self.final_pool       = nn.MaxPool2d(2, 2)
        self.final_upsample   = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out_conv         = nn.Conv2d(33, 1, kernel_size=1)

    def forward(self, x):
        input_image = x
        # Encoder
        x1  = self.down1(x)
        x1p = self.pool(x1)
        x2  = self.down2(x1p)
        x2p = self.pool(x2)
        x3  = self.down3(x2p)
        x3p = self.pool(x3)
        x4  = self.down4(x3p)
        x4p = self.pool(x4)

        # Bottleneck
        xb  = self.bottleneck(x4p)

        # Decoder + Attention
        d4  = self.up4(xb)
        x4a = self.att4(g=d4, x=x4)
        d4  = torch.cat([x4a, d4], dim=1)
        d4  = self.dec4(d4)

        d3  = self.up3(d4)
        x3a = self.att3(g=d3, x=x3)
        d3  = torch.cat([x3a, d3], dim=1)
        d3  = self.dec3(d3)

        d2  = self.up2(d3)
        x2a = self.att2(g=d2, x=x2)
        d2  = torch.cat([x2a, d2], dim=1)
        d2  = self.dec2(d2)

        d1  = self.up1(d2)
        x1a = self.att1(g=d1, x=x1)
        d1  = torch.cat([x1a, d1], dim=1)
        d1  = self.dec1(d1)

        # Additional subsampling & concatenation
        d1s = self.final_pool(d1)
        ins = self.final_pool(input_image)
        cat = torch.cat([d1s, ins], dim=1)
        out = self.final_upsample(cat)
        out = self.out_conv(out)
        return out


In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [7]:
seeding(42)
MODEL_NAME = "Green_Custom_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_Custom_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_Custom_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
#  load full training set 
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test 
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = TverskyLoss()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [9]:
for epoch in range(50):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                  

Epoch 01 | Time: 0m 54s
	Train Loss: 0.9148 | Valid Loss: 0.8892
	Accuracy: 0.8770 | F1: 0.1406 | Dice: 0.1579 | Recall: 0.8988 | Precision: 0.0866 | Jaccard: 0.0832
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 02 | Time: 0m 53s
	Train Loss: 0.8538 | Valid Loss: 0.8153
	Accuracy: 0.9703 | F1: 0.2956 | Dice: 0.3344 | Recall: 0.7189 | Precision: 0.2179 | Jaccard: 0.1859
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 03 | Time: 0m 53s
	Train Loss: 0.7751 | Valid Loss: 0.7359
	Accuracy: 0.9808 | F1: 0.3835 | Dice: 0.4214 | Recall: 0.7033 | Precision: 0.3008 | Jaccard: 0.2489
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 04 | Time: 0m 53s
	Train Loss: 0.6811 | Valid Loss: 0.7034
	Accuracy: 0.9714 | F1: 0.3147 | Dice: 0.3481 | Recall: 0.7669 | Precision: 0.2252 | Jaccard: 0.2001
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 05 | Time: 0m 53s
	Train Loss: 0.5915 | Valid Loss: 0.6763
	Accuracy: 0.9774 | F1: 0.3380 | Dice: 0.3831 | Recall: 0.6639 | Precision: 0.2693 | Jaccard: 0.2123
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 06 | Time: 0m 53s
	Train Loss: 0.5392 | Valid Loss: 0.6279
	Accuracy: 0.9871 | F1: 0.4104 | Dice: 0.4282 | Recall: 0.4357 | Precision: 0.4210 | Jaccard: 0.2804
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 07 | Time: 0m 52s
	Train Loss: 0.4976 | Valid Loss: 0.6127
	Accuracy: 0.9859 | F1: 0.3877 | Dice: 0.4089 | Recall: 0.4370 | Precision: 0.3841 | Jaccard: 0.2615
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 08 | Time: 0m 53s
	Train Loss: 0.4385 | Valid Loss: 0.5123
	Accuracy: 0.9874 | F1: 0.4960 | Dice: 0.5169 | Recall: 0.5985 | Precision: 0.4549 | Jaccard: 0.3373
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 09 | Time: 0m 53s
	Train Loss: 0.4344 | Valid Loss: 0.6519
	Accuracy: 0.9799 | F1: 0.3211 | Dice: 0.3654 | Recall: 0.5143 | Precision: 0.2834 | Jaccard: 0.2030


                                                                  

Epoch 10 | Time: 0m 53s
	Train Loss: 0.4512 | Valid Loss: 0.5744
	Accuracy: 0.9879 | F1: 0.4327 | Dice: 0.4465 | Recall: 0.4288 | Precision: 0.4658 | Jaccard: 0.2955


                                                                  

Epoch 11 | Time: 0m 53s
	Train Loss: 0.3895 | Valid Loss: 0.5885
	Accuracy: 0.9864 | F1: 0.3920 | Dice: 0.4071 | Recall: 0.4034 | Precision: 0.4109 | Jaccard: 0.2627


                                                                  

Epoch 12 | Time: 0m 53s
	Train Loss: 0.3410 | Valid Loss: 0.5439
	Accuracy: 0.9875 | F1: 0.4471 | Dice: 0.4627 | Recall: 0.4896 | Precision: 0.4385 | Jaccard: 0.3067


                                                                  

Epoch 13 | Time: 0m 53s
	Train Loss: 0.3299 | Valid Loss: 0.5010
	Accuracy: 0.9854 | F1: 0.4662 | Dice: 0.4979 | Recall: 0.6415 | Precision: 0.4068 | Jaccard: 0.3112
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 14 | Time: 0m 53s
	Train Loss: 0.3274 | Valid Loss: 0.5281
	Accuracy: 0.9881 | F1: 0.4687 | Dice: 0.4799 | Recall: 0.4860 | Precision: 0.4739 | Jaccard: 0.3215


                                                                  

Epoch 15 | Time: 0m 53s
	Train Loss: 0.3063 | Valid Loss: 0.5216
	Accuracy: 0.9872 | F1: 0.4664 | Dice: 0.4824 | Recall: 0.5481 | Precision: 0.4308 | Jaccard: 0.3206


                                                                  

Epoch 16 | Time: 0m 51s
	Train Loss: 0.2851 | Valid Loss: 0.5358
	Accuracy: 0.9870 | F1: 0.4473 | Dice: 0.4692 | Recall: 0.5118 | Precision: 0.4331 | Jaccard: 0.3019


                                                                  

Epoch 17 | Time: 0m 50s
	Train Loss: 0.2806 | Valid Loss: 0.5498
	Accuracy: 0.9865 | F1: 0.4322 | Dice: 0.4617 | Recall: 0.5005 | Precision: 0.4284 | Jaccard: 0.2883


                                                                  

Epoch 18 | Time: 0m 50s
	Train Loss: 0.2971 | Valid Loss: 0.5127
	Accuracy: 0.9878 | F1: 0.4905 | Dice: 0.5084 | Recall: 0.5524 | Precision: 0.4709 | Jaccard: 0.3352


                                                                  

Epoch 19 | Time: 0m 50s
	Train Loss: 0.2588 | Valid Loss: 0.5000
	Accuracy: 0.9872 | F1: 0.4794 | Dice: 0.4944 | Recall: 0.5745 | Precision: 0.4340 | Jaccard: 0.3308
Best Green_Custom_Tversky_Full Saved


                                                                  

Epoch 20 | Time: 0m 51s
	Train Loss: 0.2609 | Valid Loss: 0.5236
	Accuracy: 0.9865 | F1: 0.4463 | Dice: 0.4668 | Recall: 0.5546 | Precision: 0.4030 | Jaccard: 0.3030


                                                                  

Epoch 21 | Time: 0m 51s
	Train Loss: 0.2548 | Valid Loss: 0.5484
	Accuracy: 0.9874 | F1: 0.4524 | Dice: 0.4695 | Recall: 0.4903 | Precision: 0.4504 | Jaccard: 0.3063


                                                                  

Epoch 22 | Time: 0m 51s
	Train Loss: 0.2394 | Valid Loss: 0.5160
	Accuracy: 0.9885 | F1: 0.4988 | Dice: 0.5085 | Recall: 0.4935 | Precision: 0.5244 | Jaccard: 0.3398


                                                                  

Epoch 23 | Time: 0m 50s
	Train Loss: 0.2710 | Valid Loss: 0.6613
	Accuracy: 0.9873 | F1: 0.3509 | Dice: 0.3634 | Recall: 0.2858 | Precision: 0.4991 | Jaccard: 0.2291


                                                                  

Epoch 24 | Time: 0m 50s
	Train Loss: 0.2440 | Valid Loss: 0.5190
	Accuracy: 0.9883 | F1: 0.4776 | Dice: 0.4914 | Recall: 0.5225 | Precision: 0.4637 | Jaccard: 0.3317


                                                                  

Epoch 25 | Time: 0m 50s
	Train Loss: 0.2193 | Valid Loss: 0.5177
	Accuracy: 0.9886 | F1: 0.4767 | Dice: 0.4883 | Recall: 0.4889 | Precision: 0.4877 | Jaccard: 0.3298


                                                                  

Epoch 26 | Time: 0m 50s
	Train Loss: 0.2013 | Valid Loss: 0.5019
	Accuracy: 0.9882 | F1: 0.4881 | Dice: 0.5015 | Recall: 0.5533 | Precision: 0.4586 | Jaccard: 0.3371


                                                                  

Epoch 27 | Time: 0m 51s
	Train Loss: 0.1975 | Valid Loss: 0.5058
	Accuracy: 0.9885 | F1: 0.4876 | Dice: 0.5002 | Recall: 0.5312 | Precision: 0.4725 | Jaccard: 0.3377


                                                                  

Epoch 28 | Time: 0m 50s
	Train Loss: 0.1929 | Valid Loss: 0.5241
	Accuracy: 0.9887 | F1: 0.4792 | Dice: 0.4889 | Recall: 0.4870 | Precision: 0.4909 | Jaccard: 0.3312


                                                                  

Epoch 29 | Time: 0m 50s
	Train Loss: 0.1920 | Valid Loss: 0.5159
	Accuracy: 0.9884 | F1: 0.4760 | Dice: 0.4864 | Recall: 0.5020 | Precision: 0.4717 | Jaccard: 0.3293


                                                                  

Epoch 30 | Time: 0m 50s
	Train Loss: 0.1908 | Valid Loss: 0.5072
	Accuracy: 0.9885 | F1: 0.4821 | Dice: 0.4951 | Recall: 0.5211 | Precision: 0.4715 | Jaccard: 0.3340


                                                                  

Epoch 31 | Time: 0m 50s
	Train Loss: 0.1872 | Valid Loss: 0.5218
	Accuracy: 0.9885 | F1: 0.4780 | Dice: 0.4898 | Recall: 0.5061 | Precision: 0.4745 | Jaccard: 0.3294


                                                                  

Epoch 32 | Time: 0m 50s
	Train Loss: 0.1846 | Valid Loss: 0.5314
	Accuracy: 0.9883 | F1: 0.4712 | Dice: 0.4836 | Recall: 0.5026 | Precision: 0.4660 | Jaccard: 0.3230


                                                                  

Epoch 33 | Time: 0m 50s
	Train Loss: 0.1840 | Valid Loss: 0.5322
	Accuracy: 0.9885 | F1: 0.4703 | Dice: 0.4822 | Recall: 0.4895 | Precision: 0.4752 | Jaccard: 0.3232


                                                                  

Epoch 34 | Time: 0m 50s
	Train Loss: 0.1838 | Valid Loss: 0.5276
	Accuracy: 0.9885 | F1: 0.4746 | Dice: 0.4849 | Recall: 0.4920 | Precision: 0.4779 | Jaccard: 0.3265


                                                                  

Epoch 35 | Time: 0m 53s
	Train Loss: 0.1828 | Valid Loss: 0.5314
	Accuracy: 0.9884 | F1: 0.4708 | Dice: 0.4843 | Recall: 0.5082 | Precision: 0.4625 | Jaccard: 0.3233


                                                                  

Epoch 36 | Time: 0m 53s
	Train Loss: 0.1827 | Valid Loss: 0.5287
	Accuracy: 0.9884 | F1: 0.4744 | Dice: 0.4857 | Recall: 0.5003 | Precision: 0.4720 | Jaccard: 0.3259


                                                                  

Epoch 37 | Time: 0m 53s
	Train Loss: 0.1832 | Valid Loss: 0.5329
	Accuracy: 0.9884 | F1: 0.4697 | Dice: 0.4829 | Recall: 0.5041 | Precision: 0.4635 | Jaccard: 0.3222


                                                                  

Epoch 38 | Time: 0m 52s
	Train Loss: 0.1827 | Valid Loss: 0.5339
	Accuracy: 0.9882 | F1: 0.4697 | Dice: 0.4827 | Recall: 0.5093 | Precision: 0.4588 | Jaccard: 0.3210


                                                                  

Epoch 39 | Time: 0m 53s
	Train Loss: 0.1818 | Valid Loss: 0.5340
	Accuracy: 0.9884 | F1: 0.4696 | Dice: 0.4824 | Recall: 0.5011 | Precision: 0.4650 | Jaccard: 0.3222


                                                                  

Epoch 40 | Time: 0m 52s
	Train Loss: 0.1824 | Valid Loss: 0.5325
	Accuracy: 0.9883 | F1: 0.4710 | Dice: 0.4850 | Recall: 0.5148 | Precision: 0.4585 | Jaccard: 0.3225


                                                                  

Epoch 41 | Time: 0m 52s
	Train Loss: 0.1815 | Valid Loss: 0.5299
	Accuracy: 0.9887 | F1: 0.4787 | Dice: 0.4889 | Recall: 0.4854 | Precision: 0.4925 | Jaccard: 0.3286


                                                                  

Epoch 42 | Time: 0m 52s
	Train Loss: 0.1822 | Valid Loss: 0.5357
	Accuracy: 0.9886 | F1: 0.4717 | Dice: 0.4832 | Recall: 0.4857 | Precision: 0.4808 | Jaccard: 0.3232


                                                                  

Epoch 43 | Time: 0m 52s
	Train Loss: 0.1818 | Valid Loss: 0.5310
	Accuracy: 0.9884 | F1: 0.4704 | Dice: 0.4837 | Recall: 0.5032 | Precision: 0.4657 | Jaccard: 0.3233


                                                                  

Epoch 44 | Time: 0m 53s
	Train Loss: 0.1822 | Valid Loss: 0.5479
	Accuracy: 0.9883 | F1: 0.4599 | Dice: 0.4724 | Recall: 0.4877 | Precision: 0.4580 | Jaccard: 0.3130


                                                                  

Epoch 45 | Time: 0m 52s
	Train Loss: 0.1819 | Valid Loss: 0.5295
	Accuracy: 0.9885 | F1: 0.4729 | Dice: 0.4851 | Recall: 0.4968 | Precision: 0.4740 | Jaccard: 0.3247


                                                                  

Epoch 46 | Time: 0m 53s
	Train Loss: 0.1821 | Valid Loss: 0.5346
	Accuracy: 0.9884 | F1: 0.4723 | Dice: 0.4840 | Recall: 0.4975 | Precision: 0.4713 | Jaccard: 0.3237


                                                                  

Epoch 47 | Time: 0m 53s
	Train Loss: 0.1830 | Valid Loss: 0.5249
	Accuracy: 0.9883 | F1: 0.4746 | Dice: 0.4892 | Recall: 0.5180 | Precision: 0.4634 | Jaccard: 0.3264


                                                                  

Epoch 48 | Time: 0m 52s
	Train Loss: 0.1819 | Valid Loss: 0.5261
	Accuracy: 0.9884 | F1: 0.4774 | Dice: 0.4891 | Recall: 0.5041 | Precision: 0.4749 | Jaccard: 0.3277


                                                                  

Epoch 49 | Time: 0m 52s
	Train Loss: 0.1820 | Valid Loss: 0.5337
	Accuracy: 0.9882 | F1: 0.4679 | Dice: 0.4822 | Recall: 0.5172 | Precision: 0.4517 | Jaccard: 0.3201


                                                                  

Epoch 50 | Time: 0m 51s
	Train Loss: 0.1826 | Valid Loss: 0.5276
	Accuracy: 0.9885 | F1: 0.4779 | Dice: 0.4899 | Recall: 0.5072 | Precision: 0.4738 | Jaccard: 0.3281




In [10]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


Testing: 100%|██████████| 14/14 [00:05<00:00,  2.75it/s]

Jaccard: 0.4061, F1: 0.5707, Recall: 0.6969, Precision: 0.4965, Accuracy: 0.9893
FPS: 174.98556494582385



