In [1]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
#  Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
#  Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples


In [4]:
class Focal_Tversky(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta  = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits, targets):
    
        probs = torch.sigmoid(logits)
        N = targets.size(0)

        probs  = probs.view(N, -1)
        targets = targets.view(N, -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        smooth = torch.tensor(self.smooth, device=probs.device, dtype=probs.dtype)
        tversky = (TP + smooth) / (TP + self.alpha * FP + self.beta * FN + smooth)

        loss = (1 - tversky) ** self.gamma

        return loss.mean()


In [None]:

#  DoubleConv and TripleConv

class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x

class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


#  UNetRetina (with additional subsampling and concatenation)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        #  ENCODER 
        # Block 1: 3 conv -> (32, 32, 64)
        self.down1 = TripleConv(
            in_c=1,       # green channel only
            mid1_c=32,
            mid2_c=32,
            out_c=64
        )
        # Block 2: 3 conv -> (64, 64, 128)
        self.down2 = TripleConv(
            in_c=64,
            mid1_c=64,
            mid2_c=64,
            out_c=128
        )
        # Block 3: 2 conv -> (128, 128, 256)
        self.down3 = DoubleConv(
            in_c=128,
            mid_c=128,
            out_c=256
        )
        # Block 4: 2 conv -> (256, 256, 256)
        self.down4 = DoubleConv(
            in_c=256,
            mid_c=256,
            out_c=256
        )
        
        self.pool = nn.MaxPool2d(2, 2)
        
        # BOTTLENECK 
        # Bottleneck: 2 conv -> (256 -> 512 -> 256)
        self.bottleneck = DoubleConv(
            in_c=256,
            mid_c=512,
            out_c=256
        )
        
        #  DECODER 
        # Each decoder block: upsample, concat skip connection, then decode.
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(in_c=256+256, mid_c=256, out_c=256)
        
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(in_c=256+256, mid_c=128, out_c=128)
        
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(in_c=128+128, mid1_c=64, mid2_c=64, out_c=64)
        
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(in_c=64+64, mid1_c=32, mid2_c=32, out_c=32)
        
        #  ADDITIONAL SUBSAMPLING & CONCATENATION 
        self.final_pool = nn.MaxPool2d(2, 2)  # Additional subsampling step
        self.final_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # After concatenation: 32 channels (from decoder) + 1 channel (from input) = 33 channels
        self.out_conv = nn.Conv2d(33, 1, kernel_size=1)
        

    def forward(self, x):
        # Encoder 
        input_image = x
        
        # Block 1
        x1 = self.down1(x)    # -> 64 channels
        x1p = self.pool(x1)   # subsampled
        
        # Block 2
        x2 = self.down2(x1p)  # -> 128 channels
        x2p = self.pool(x2)   # subsampled
        
        # Block 3
        x3 = self.down3(x2p)  # -> 256 channels
        x3p = self.pool(x3)   # subsampled
        
        # Block 4
        x4 = self.down4(x3p)  # -> 256 channels
        x4p = self.pool(x4)   # subsampled
        
        # Bottleneck 
        xb = self.bottleneck(x4p)  # 256 -> 512 -> 256
        
        #  Decoder 
        xd4 = self.up4(xb)               
        xd4 = torch.cat([x4, xd4], dim=1)  
        xd4 = self.dec4(xd4)             
        
        xd3 = self.up3(xd4)              
        xd3 = torch.cat([x3, xd3], dim=1) 
        xd3 = self.dec3(xd3)             
        
        xd2 = self.up2(xd3)              
        xd2 = torch.cat([x2, xd2], dim=1) 
        xd2 = self.dec2(xd2)             
        
        xd1 = self.up1(xd2)              
        xd1 = torch.cat([x1, xd1], dim=1) 
        xd1 = self.dec1(xd1)             
        
        #  Additional Subsampling & Concatenation 
        xd1_sub = self.final_pool(xd1)          # [B,32,H/2,W/2]
        input_sub = self.final_pool(input_image) # [B,1,H/2,W/2]
        final_cat = torch.cat([xd1_sub, input_sub], dim=1)  # [B,33,H/2,W/2]
        final_up = self.final_upsample(final_cat)  #  [B,33,H,W]
        
        out = self.out_conv(final_up)  # [B,1,H,W]
        
        return out


In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [7]:
seeding(42)
MODEL_NAME = "Green_Custom_Focal_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_Custom_Focal_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_Custom_Focal_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("../final_dataset/train/images/*"))
train_masks  = sorted(glob("../final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test
test_images = sorted(glob("../final_dataset/test/images/*"))
test_masks  = sorted(glob("../final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = Focal_Tversky()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [9]:
for epoch in range(50):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                    

Epoch 01 | Time: 12m 34s
	Train Loss: 0.7954 | Valid Loss: 0.5783
	Accuracy: 0.9897 | F1: 0.3506 | Dice: 0.4164 | Recall: 0.4953 | Precision: 0.3592 | Jaccard: 0.2193
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 02 | Time: 11m 45s
	Train Loss: 0.5675 | Valid Loss: 0.6188
	Accuracy: 0.9921 | F1: 0.3685 | Dice: 0.4205 | Recall: 0.2974 | Precision: 0.7176 | Jaccard: 0.2391


                                                                    

Epoch 03 | Time: 11m 44s
	Train Loss: 0.5286 | Valid Loss: 0.5209
	Accuracy: 0.9921 | F1: 0.4266 | Dice: 0.4957 | Recall: 0.4277 | Precision: 0.5894 | Jaccard: 0.2806
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 04 | Time: 11m 44s
	Train Loss: 0.5095 | Valid Loss: 0.5464
	Accuracy: 0.9923 | F1: 0.4238 | Dice: 0.4849 | Recall: 0.3921 | Precision: 0.6353 | Jaccard: 0.2808


                                                                    

Epoch 05 | Time: 11m 43s
	Train Loss: 0.5032 | Valid Loss: 0.5723
	Accuracy: 0.9920 | F1: 0.3878 | Dice: 0.4561 | Recall: 0.3496 | Precision: 0.6559 | Jaccard: 0.2503


                                                                    

Epoch 06 | Time: 11m 45s
	Train Loss: 0.4879 | Valid Loss: 0.5667
	Accuracy: 0.9922 | F1: 0.4042 | Dice: 0.4617 | Recall: 0.3706 | Precision: 0.6119 | Jaccard: 0.2649


                                                                    

Epoch 07 | Time: 11m 44s
	Train Loss: 0.4690 | Valid Loss: 0.4962
	Accuracy: 0.9925 | F1: 0.4745 | Dice: 0.5280 | Recall: 0.4504 | Precision: 0.6377 | Jaccard: 0.3219
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 08 | Time: 11m 43s
	Train Loss: 0.4701 | Valid Loss: 0.5368
	Accuracy: 0.9922 | F1: 0.4471 | Dice: 0.5063 | Recall: 0.4210 | Precision: 0.6350 | Jaccard: 0.2993


                                                                    

Epoch 09 | Time: 11m 44s
	Train Loss: 0.4539 | Valid Loss: 0.5431
	Accuracy: 0.9915 | F1: 0.3995 | Dice: 0.4626 | Recall: 0.4457 | Precision: 0.4809 | Jaccard: 0.2594


                                                                    

Epoch 10 | Time: 11m 43s
	Train Loss: 0.4461 | Valid Loss: 0.5047
	Accuracy: 0.9924 | F1: 0.4601 | Dice: 0.5230 | Recall: 0.4589 | Precision: 0.6078 | Jaccard: 0.3122


                                                                    

Epoch 11 | Time: 11m 50s
	Train Loss: 0.4403 | Valid Loss: 0.5232
	Accuracy: 0.9924 | F1: 0.4486 | Dice: 0.5163 | Recall: 0.4252 | Precision: 0.6573 | Jaccard: 0.3019


                                                                    

Epoch 12 | Time: 11m 42s
	Train Loss: 0.4329 | Valid Loss: 0.5307
	Accuracy: 0.9923 | F1: 0.4426 | Dice: 0.5070 | Recall: 0.4116 | Precision: 0.6602 | Jaccard: 0.2968


                                                                    

Epoch 13 | Time: 11m 13s
	Train Loss: 0.4302 | Valid Loss: 0.5088
	Accuracy: 0.9924 | F1: 0.4537 | Dice: 0.5234 | Recall: 0.4417 | Precision: 0.6420 | Jaccard: 0.3076


                                                                    

Epoch 14 | Time: 5m 44s
	Train Loss: 0.4035 | Valid Loss: 0.5013
	Accuracy: 0.9926 | F1: 0.4725 | Dice: 0.5338 | Recall: 0.4468 | Precision: 0.6629 | Jaccard: 0.3215


                                                                    

Epoch 15 | Time: 5m 43s
	Train Loss: 0.3960 | Valid Loss: 0.4945
	Accuracy: 0.9925 | F1: 0.4703 | Dice: 0.5353 | Recall: 0.4577 | Precision: 0.6445 | Jaccard: 0.3196
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 16 | Time: 5m 45s
	Train Loss: 0.3920 | Valid Loss: 0.4916
	Accuracy: 0.9925 | F1: 0.4752 | Dice: 0.5364 | Recall: 0.4625 | Precision: 0.6385 | Jaccard: 0.3238
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 17 | Time: 5m 43s
	Train Loss: 0.3888 | Valid Loss: 0.4936
	Accuracy: 0.9924 | F1: 0.4792 | Dice: 0.5390 | Recall: 0.4701 | Precision: 0.6315 | Jaccard: 0.3269


                                                                    

Epoch 18 | Time: 5m 43s
	Train Loss: 0.3856 | Valid Loss: 0.4912
	Accuracy: 0.9923 | F1: 0.4770 | Dice: 0.5403 | Recall: 0.4735 | Precision: 0.6290 | Jaccard: 0.3258
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 19 | Time: 5m 46s
	Train Loss: 0.3831 | Valid Loss: 0.4996
	Accuracy: 0.9923 | F1: 0.4725 | Dice: 0.5350 | Recall: 0.4650 | Precision: 0.6298 | Jaccard: 0.3224


                                                                    

Epoch 20 | Time: 5m 44s
	Train Loss: 0.3809 | Valid Loss: 0.4949
	Accuracy: 0.9923 | F1: 0.4721 | Dice: 0.5338 | Recall: 0.4592 | Precision: 0.6372 | Jaccard: 0.3222


                                                                    

Epoch 21 | Time: 5m 43s
	Train Loss: 0.3788 | Valid Loss: 0.4918
	Accuracy: 0.9923 | F1: 0.4780 | Dice: 0.5403 | Recall: 0.4673 | Precision: 0.6402 | Jaccard: 0.3269


                                                                    

Epoch 22 | Time: 5m 43s
	Train Loss: 0.3772 | Valid Loss: 0.4919
	Accuracy: 0.9925 | F1: 0.4788 | Dice: 0.5412 | Recall: 0.4742 | Precision: 0.6303 | Jaccard: 0.3271


                                                                    

Epoch 23 | Time: 5m 43s
	Train Loss: 0.3747 | Valid Loss: 0.4864
	Accuracy: 0.9925 | F1: 0.4851 | Dice: 0.5413 | Recall: 0.4687 | Precision: 0.6406 | Jaccard: 0.3320
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 24 | Time: 5m 43s
	Train Loss: 0.3732 | Valid Loss: 0.4792
	Accuracy: 0.9922 | F1: 0.4815 | Dice: 0.5371 | Recall: 0.4870 | Precision: 0.5987 | Jaccard: 0.3308
Best Green_Custom_Focal_Tversky_Full Saved


                                                                    

Epoch 25 | Time: 5m 43s
	Train Loss: 0.3715 | Valid Loss: 0.4843
	Accuracy: 0.9925 | F1: 0.4835 | Dice: 0.5412 | Recall: 0.4707 | Precision: 0.6366 | Jaccard: 0.3306


                                                                    

Epoch 26 | Time: 5m 45s
	Train Loss: 0.3692 | Valid Loss: 0.4922
	Accuracy: 0.9927 | F1: 0.4811 | Dice: 0.5366 | Recall: 0.4560 | Precision: 0.6518 | Jaccard: 0.3287


                                                                    

Epoch 27 | Time: 5m 44s
	Train Loss: 0.3685 | Valid Loss: 0.4931
	Accuracy: 0.9926 | F1: 0.4796 | Dice: 0.5382 | Recall: 0.4532 | Precision: 0.6625 | Jaccard: 0.3279


                                                                    

Epoch 28 | Time: 5m 43s
	Train Loss: 0.3671 | Valid Loss: 0.4856
	Accuracy: 0.9921 | F1: 0.4804 | Dice: 0.5378 | Recall: 0.4733 | Precision: 0.6224 | Jaccard: 0.3306


                                                                    

Epoch 29 | Time: 5m 45s
	Train Loss: 0.3653 | Valid Loss: 0.4821
	Accuracy: 0.9922 | F1: 0.4833 | Dice: 0.5364 | Recall: 0.4766 | Precision: 0.6135 | Jaccard: 0.3320


                                                                    

Epoch 30 | Time: 5m 44s
	Train Loss: 0.3646 | Valid Loss: 0.4834
	Accuracy: 0.9925 | F1: 0.4862 | Dice: 0.5385 | Recall: 0.4712 | Precision: 0.6281 | Jaccard: 0.3331


                                                                    

Epoch 31 | Time: 5m 43s
	Train Loss: 0.3619 | Valid Loss: 0.4843
	Accuracy: 0.9926 | F1: 0.4835 | Dice: 0.5365 | Recall: 0.4662 | Precision: 0.6319 | Jaccard: 0.3307


                                                                    

Epoch 32 | Time: 5m 45s
	Train Loss: 0.3607 | Valid Loss: 0.4874
	Accuracy: 0.9925 | F1: 0.4832 | Dice: 0.5372 | Recall: 0.4647 | Precision: 0.6364 | Jaccard: 0.3312


                                                                    

Epoch 33 | Time: 5m 43s
	Train Loss: 0.3601 | Valid Loss: 0.4837
	Accuracy: 0.9925 | F1: 0.4847 | Dice: 0.5389 | Recall: 0.4700 | Precision: 0.6314 | Jaccard: 0.3321


                                                                    

Epoch 34 | Time: 5m 43s
	Train Loss: 0.3599 | Valid Loss: 0.4820
	Accuracy: 0.9926 | F1: 0.4882 | Dice: 0.5417 | Recall: 0.4725 | Precision: 0.6346 | Jaccard: 0.3362


                                                                    

Epoch 35 | Time: 5m 42s
	Train Loss: 0.3593 | Valid Loss: 0.4827
	Accuracy: 0.9926 | F1: 0.4871 | Dice: 0.5400 | Recall: 0.4731 | Precision: 0.6289 | Jaccard: 0.3343


                                                                    

Epoch 36 | Time: 5m 43s
	Train Loss: 0.3595 | Valid Loss: 0.4849
	Accuracy: 0.9925 | F1: 0.4813 | Dice: 0.5369 | Recall: 0.4707 | Precision: 0.6247 | Jaccard: 0.3288


                                                                    

Epoch 37 | Time: 5m 44s
	Train Loss: 0.3590 | Valid Loss: 0.4837
	Accuracy: 0.9925 | F1: 0.4865 | Dice: 0.5403 | Recall: 0.4713 | Precision: 0.6329 | Jaccard: 0.3339


                                                                    

Epoch 38 | Time: 11m 0s
	Train Loss: 0.3588 | Valid Loss: 0.4888
	Accuracy: 0.9926 | F1: 0.4814 | Dice: 0.5359 | Recall: 0.4591 | Precision: 0.6434 | Jaccard: 0.3292


                                                                    

Epoch 39 | Time: 11m 33s
	Train Loss: 0.3588 | Valid Loss: 0.4852
	Accuracy: 0.9924 | F1: 0.4814 | Dice: 0.5354 | Recall: 0.4701 | Precision: 0.6217 | Jaccard: 0.3295


                                                                    

Epoch 40 | Time: 11m 38s
	Train Loss: 0.3589 | Valid Loss: 0.4871
	Accuracy: 0.9926 | F1: 0.4858 | Dice: 0.5392 | Recall: 0.4633 | Precision: 0.6448 | Jaccard: 0.3341


                                                                    

Epoch 41 | Time: 11m 29s
	Train Loss: 0.3591 | Valid Loss: 0.4843
	Accuracy: 0.9924 | F1: 0.4835 | Dice: 0.5383 | Recall: 0.4713 | Precision: 0.6274 | Jaccard: 0.3312


                                                                    

Epoch 42 | Time: 8m 53s
	Train Loss: 0.3585 | Valid Loss: 0.4886
	Accuracy: 0.9925 | F1: 0.4828 | Dice: 0.5375 | Recall: 0.4618 | Precision: 0.6429 | Jaccard: 0.3307


                                                                    

Epoch 43 | Time: 5m 47s
	Train Loss: 0.3588 | Valid Loss: 0.4892
	Accuracy: 0.9924 | F1: 0.4803 | Dice: 0.5358 | Recall: 0.4649 | Precision: 0.6322 | Jaccard: 0.3287


                                                                    

Epoch 44 | Time: 5m 45s
	Train Loss: 0.3590 | Valid Loss: 0.4866
	Accuracy: 0.9922 | F1: 0.4816 | Dice: 0.5349 | Recall: 0.4728 | Precision: 0.6159 | Jaccard: 0.3307


                                                                    

Epoch 45 | Time: 5m 46s
	Train Loss: 0.3586 | Valid Loss: 0.4851
	Accuracy: 0.9926 | F1: 0.4846 | Dice: 0.5378 | Recall: 0.4663 | Precision: 0.6352 | Jaccard: 0.3316


                                                                    

Epoch 46 | Time: 5m 45s
	Train Loss: 0.3587 | Valid Loss: 0.4827
	Accuracy: 0.9924 | F1: 0.4845 | Dice: 0.5376 | Recall: 0.4731 | Precision: 0.6225 | Jaccard: 0.3324


                                                                    

Epoch 47 | Time: 5m 43s
	Train Loss: 0.3589 | Valid Loss: 0.4834
	Accuracy: 0.9925 | F1: 0.4867 | Dice: 0.5391 | Recall: 0.4723 | Precision: 0.6279 | Jaccard: 0.3346


                                                                    

Epoch 48 | Time: 5m 45s
	Train Loss: 0.3583 | Valid Loss: 0.4863
	Accuracy: 0.9925 | F1: 0.4843 | Dice: 0.5379 | Recall: 0.4679 | Precision: 0.6325 | Jaccard: 0.3325


                                                                    

Epoch 49 | Time: 5m 43s
	Train Loss: 0.3585 | Valid Loss: 0.4841
	Accuracy: 0.9926 | F1: 0.4875 | Dice: 0.5401 | Recall: 0.4723 | Precision: 0.6308 | Jaccard: 0.3348


                                                                    

Epoch 50 | Time: 5m 43s
	Train Loss: 0.3587 | Valid Loss: 0.4814
	Accuracy: 0.9926 | F1: 0.4885 | Dice: 0.5408 | Recall: 0.4717 | Precision: 0.6336 | Jaccard: 0.3354




In [10]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


  model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
Testing: 100%|██████████| 113/113 [00:19<00:00,  5.84it/s]

Jaccard: 0.2689, F1: 0.3859, Recall: 0.4284, Precision: 0.4726, Accuracy: 0.9945
FPS: 257.81350316503375



