In [2]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions 
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples


In [5]:
class Focal_Tversky(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta  = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits, targets):
    
        probs = torch.sigmoid(logits)
        N = targets.size(0)

        probs  = probs.view(N, -1)
        targets = targets.view(N, -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        smooth = torch.tensor(self.smooth, device=probs.device, dtype=probs.dtype)
        tversky = (TP + smooth) / (TP + self.alpha * FP + self.beta * FN + smooth)

        loss = (1 - tversky) ** self.gamma

        return loss.mean()


In [None]:

#  DoubleConv and TripleConv

class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x

class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


#  UNetRetina (with additional subsampling and concatenation)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # ========== ENCODER ==========
        # Block 1: 3 conv -> (32, 32, 64)
        self.down1 = TripleConv(
            in_c=1,       # green channel only
            mid1_c=32,
            mid2_c=32,
            out_c=64
        )
        # Block 2: 3 conv -> (64, 64, 128)
        self.down2 = TripleConv(
            in_c=64,
            mid1_c=64,
            mid2_c=64,
            out_c=128
        )
        # Block 3: 2 conv -> (128, 128, 256)
        self.down3 = DoubleConv(
            in_c=128,
            mid_c=128,
            out_c=256
        )
        # Block 4: 2 conv -> (256, 256, 256)
        self.down4 = DoubleConv(
            in_c=256,
            mid_c=256,
            out_c=256
        )
        
        self.pool = nn.MaxPool2d(2, 2)
        
        #  BOTTLENECK 
        # Bottleneck: 2 conv -> (256 -> 512 -> 256)
        self.bottleneck = DoubleConv(
            in_c=256,
            mid_c=512,
            out_c=256
        )
        
        # DECODER 
        # Each decoder block: upsample, concat skip connection, then decode.
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(in_c=256+256, mid_c=256, out_c=256)
        
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(in_c=256+256, mid_c=128, out_c=128)
        
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(in_c=128+128, mid1_c=64, mid2_c=64, out_c=64)
        
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(in_c=64+64, mid1_c=32, mid2_c=32, out_c=32)
        
        # ADDITIONAL SUBSAMPLING & CONCATENATION 
        self.final_pool = nn.MaxPool2d(2, 2)  # Additional subsampling step
        self.final_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # After concatenation: 32 channels (from decoder) + 1 channel (from input) = 33 channels
        self.out_conv = nn.Conv2d(33, 1, kernel_size=1)
        

    def forward(self, x):
        # Encoder
        input_image = x
        
        # Block 1
        x1 = self.down1(x)    # 64 channels
        x1p = self.pool(x1)   # subsampled
        
        # Block 2
        x2 = self.down2(x1p)  #  128 channels
        x2p = self.pool(x2)   # subsampled
        
        # Block 3
        x3 = self.down3(x2p)  # 256 channels
        x3p = self.pool(x3)   # subsampled
        
        # Block 4
        x4 = self.down4(x3p)  #  256 channels
        x4p = self.pool(x4)   # subsampled
        
        # Bottleneck 
        xb = self.bottleneck(x4p)  # 256 -> 512 -> 256
        
        #  Decoder 
        xd4 = self.up4(xb)               
        xd4 = torch.cat([x4, xd4], dim=1)  
        xd4 = self.dec4(xd4)             
        
        xd3 = self.up3(xd4)              
        xd3 = torch.cat([x3, xd3], dim=1) 
        xd3 = self.dec3(xd3)             
        
        xd2 = self.up2(xd3)              
        xd2 = torch.cat([x2, xd2], dim=1) 
        xd2 = self.dec2(xd2)             
        
        xd1 = self.up1(xd2)              
        xd1 = torch.cat([x1, xd1], dim=1) 
        xd1 = self.dec1(xd1)             
        
        # Additional Subsampling & Concatenation 
        xd1_sub = self.final_pool(xd1)          #  [B,32,H/2,W/2]
        input_sub = self.final_pool(input_image) # [B,1,H/2,W/2]
        final_cat = torch.cat([xd1_sub, input_sub], dim=1)  #  [B,33,H/2,W/2]
        final_up = self.final_upsample(final_cat)  # [B,33,H,W]
        
        out = self.out_conv(final_up)  # [B,1,H,W]
        
        return out


In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [8]:
seeding(42)
MODEL_NAME = "Green_Custom_Focal_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_Custom_Focal_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_Custom_Focal_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
#  load full training set 
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = Focal_Tversky()

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [10]:
for epoch in range(50):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                  

Epoch 01 | Time: 0m 44s
	Train Loss: 0.8483 | Valid Loss: 0.7878
	Accuracy: 0.9661 | F1: 0.2700 | Dice: 0.3129 | Recall: 0.7194 | Precision: 0.2000 | Jaccard: 0.1670
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 02 | Time: 0m 43s
	Train Loss: 0.7080 | Valid Loss: 0.6926
	Accuracy: 0.9788 | F1: 0.3365 | Dice: 0.3791 | Recall: 0.5819 | Precision: 0.2811 | Jaccard: 0.2121
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 03 | Time: 0m 43s
	Train Loss: 0.5545 | Valid Loss: 0.5863
	Accuracy: 0.9840 | F1: 0.3769 | Dice: 0.4137 | Recall: 0.5058 | Precision: 0.3499 | Jaccard: 0.2442
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 04 | Time: 0m 43s
	Train Loss: 0.4424 | Valid Loss: 0.5341
	Accuracy: 0.9871 | F1: 0.4117 | Dice: 0.4366 | Recall: 0.4192 | Precision: 0.4556 | Jaccard: 0.2723
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 05 | Time: 0m 43s
	Train Loss: 0.3448 | Valid Loss: 0.4924
	Accuracy: 0.9873 | F1: 0.4250 | Dice: 0.4472 | Recall: 0.4233 | Precision: 0.4739 | Jaccard: 0.2804
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 06 | Time: 0m 43s
	Train Loss: 0.2987 | Valid Loss: 0.4874
	Accuracy: 0.9859 | F1: 0.4033 | Dice: 0.4254 | Recall: 0.4667 | Precision: 0.3908 | Jaccard: 0.2746
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 07 | Time: 0m 43s
	Train Loss: 0.2557 | Valid Loss: 0.5232
	Accuracy: 0.9881 | F1: 0.3929 | Dice: 0.4007 | Recall: 0.3544 | Precision: 0.4608 | Jaccard: 0.2679


                                                                  

Epoch 08 | Time: 0m 43s
	Train Loss: 0.2238 | Valid Loss: 0.4783
	Accuracy: 0.9869 | F1: 0.4187 | Dice: 0.4398 | Recall: 0.4823 | Precision: 0.4042 | Jaccard: 0.2858
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 09 | Time: 0m 43s
	Train Loss: 0.1971 | Valid Loss: 0.4416
	Accuracy: 0.9880 | F1: 0.4377 | Dice: 0.4511 | Recall: 0.4404 | Precision: 0.4623 | Jaccard: 0.3009
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 10 | Time: 0m 43s
	Train Loss: 0.1810 | Valid Loss: 0.4294
	Accuracy: 0.9879 | F1: 0.4474 | Dice: 0.4671 | Recall: 0.4518 | Precision: 0.4836 | Jaccard: 0.3044
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 11 | Time: 0m 43s
	Train Loss: 0.1599 | Valid Loss: 0.5225
	Accuracy: 0.9882 | F1: 0.4100 | Dice: 0.4212 | Recall: 0.3447 | Precision: 0.5411 | Jaccard: 0.2731


                                                                  

Epoch 12 | Time: 0m 43s
	Train Loss: 0.1592 | Valid Loss: 0.3886
	Accuracy: 0.9880 | F1: 0.4601 | Dice: 0.4859 | Recall: 0.5301 | Precision: 0.4485 | Jaccard: 0.3152
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 13 | Time: 0m 43s
	Train Loss: 0.1515 | Valid Loss: 0.4377
	Accuracy: 0.9872 | F1: 0.4393 | Dice: 0.4592 | Recall: 0.4789 | Precision: 0.4411 | Jaccard: 0.2986


                                                                  

Epoch 14 | Time: 0m 43s
	Train Loss: 0.1400 | Valid Loss: 0.4277
	Accuracy: 0.9874 | F1: 0.4507 | Dice: 0.4720 | Recall: 0.5184 | Precision: 0.4333 | Jaccard: 0.3105


                                                                  

Epoch 15 | Time: 0m 43s
	Train Loss: 0.1456 | Valid Loss: 0.4283
	Accuracy: 0.9884 | F1: 0.4506 | Dice: 0.4613 | Recall: 0.4607 | Precision: 0.4620 | Jaccard: 0.3073


                                                                  

Epoch 16 | Time: 0m 43s
	Train Loss: 0.1189 | Valid Loss: 0.4552
	Accuracy: 0.9867 | F1: 0.4279 | Dice: 0.4469 | Recall: 0.4759 | Precision: 0.4212 | Jaccard: 0.2935


                                                                  

Epoch 17 | Time: 0m 43s
	Train Loss: 0.1375 | Valid Loss: 0.4504
	Accuracy: 0.9883 | F1: 0.4525 | Dice: 0.4726 | Recall: 0.4628 | Precision: 0.4829 | Jaccard: 0.3132


                                                                  

Epoch 18 | Time: 0m 44s
	Train Loss: 0.1135 | Valid Loss: 0.3700
	Accuracy: 0.9883 | F1: 0.4800 | Dice: 0.4941 | Recall: 0.4548 | Precision: 0.5409 | Jaccard: 0.3272
Best Green_Custom_Focal_Tversky_Full Saved


                                                                  

Epoch 19 | Time: 0m 43s
	Train Loss: 0.1221 | Valid Loss: 0.3977
	Accuracy: 0.9880 | F1: 0.4607 | Dice: 0.4783 | Recall: 0.5030 | Precision: 0.4560 | Jaccard: 0.3179


                                                                  

Epoch 20 | Time: 0m 45s
	Train Loss: 0.1076 | Valid Loss: 0.3847
	Accuracy: 0.9870 | F1: 0.4474 | Dice: 0.4760 | Recall: 0.5202 | Precision: 0.4386 | Jaccard: 0.3043


                                                                  

Epoch 21 | Time: 0m 45s
	Train Loss: 0.0967 | Valid Loss: 0.4480
	Accuracy: 0.9883 | F1: 0.4428 | Dice: 0.4675 | Recall: 0.4496 | Precision: 0.4869 | Jaccard: 0.3036


                                                                  

Epoch 22 | Time: 0m 45s
	Train Loss: 0.0966 | Valid Loss: 0.4253
	Accuracy: 0.9869 | F1: 0.4514 | Dice: 0.4757 | Recall: 0.5626 | Precision: 0.4120 | Jaccard: 0.3160


                                                                  

Epoch 23 | Time: 0m 45s
	Train Loss: 0.1049 | Valid Loss: 0.4654
	Accuracy: 0.9876 | F1: 0.4323 | Dice: 0.4515 | Recall: 0.4416 | Precision: 0.4619 | Jaccard: 0.2974


                                                                  

Epoch 24 | Time: 0m 44s
	Train Loss: 0.0930 | Valid Loss: 0.4261
	Accuracy: 0.9879 | F1: 0.4544 | Dice: 0.4761 | Recall: 0.5094 | Precision: 0.4469 | Jaccard: 0.3157


                                                                  

Epoch 25 | Time: 0m 44s
	Train Loss: 0.0774 | Valid Loss: 0.4399
	Accuracy: 0.9885 | F1: 0.4477 | Dice: 0.4705 | Recall: 0.4565 | Precision: 0.4853 | Jaccard: 0.3087


                                                                  

Epoch 26 | Time: 0m 44s
	Train Loss: 0.0744 | Valid Loss: 0.4311
	Accuracy: 0.9886 | F1: 0.4568 | Dice: 0.4799 | Recall: 0.4805 | Precision: 0.4794 | Jaccard: 0.3172


                                                                  

Epoch 27 | Time: 0m 44s
	Train Loss: 0.0719 | Valid Loss: 0.4350
	Accuracy: 0.9885 | F1: 0.4528 | Dice: 0.4763 | Recall: 0.4691 | Precision: 0.4837 | Jaccard: 0.3130


                                                                  

Epoch 28 | Time: 0m 43s
	Train Loss: 0.0708 | Valid Loss: 0.4388
	Accuracy: 0.9883 | F1: 0.4484 | Dice: 0.4714 | Recall: 0.4715 | Precision: 0.4712 | Jaccard: 0.3086


                                                                  

Epoch 29 | Time: 0m 43s
	Train Loss: 0.0696 | Valid Loss: 0.4236
	Accuracy: 0.9884 | F1: 0.4596 | Dice: 0.4862 | Recall: 0.4936 | Precision: 0.4791 | Jaccard: 0.3178


                                                                  

Epoch 30 | Time: 0m 43s
	Train Loss: 0.0682 | Valid Loss: 0.4213
	Accuracy: 0.9884 | F1: 0.4621 | Dice: 0.4877 | Recall: 0.4997 | Precision: 0.4763 | Jaccard: 0.3195


                                                                  

Epoch 31 | Time: 0m 42s
	Train Loss: 0.0669 | Valid Loss: 0.4320
	Accuracy: 0.9885 | F1: 0.4549 | Dice: 0.4791 | Recall: 0.4794 | Precision: 0.4787 | Jaccard: 0.3137


                                                                  

Epoch 32 | Time: 0m 42s
	Train Loss: 0.0666 | Valid Loss: 0.4254
	Accuracy: 0.9885 | F1: 0.4610 | Dice: 0.4865 | Recall: 0.4910 | Precision: 0.4821 | Jaccard: 0.3192


                                                                  

Epoch 33 | Time: 0m 42s
	Train Loss: 0.0667 | Valid Loss: 0.4269
	Accuracy: 0.9885 | F1: 0.4587 | Dice: 0.4835 | Recall: 0.4847 | Precision: 0.4823 | Jaccard: 0.3170


                                                                  

Epoch 34 | Time: 0m 42s
	Train Loss: 0.0667 | Valid Loss: 0.4360
	Accuracy: 0.9883 | F1: 0.4495 | Dice: 0.4735 | Recall: 0.4782 | Precision: 0.4690 | Jaccard: 0.3093


                                                                  

Epoch 35 | Time: 0m 42s
	Train Loss: 0.0663 | Valid Loss: 0.4372
	Accuracy: 0.9884 | F1: 0.4487 | Dice: 0.4740 | Recall: 0.4622 | Precision: 0.4863 | Jaccard: 0.3082


                                                                  

Epoch 36 | Time: 0m 42s
	Train Loss: 0.0662 | Valid Loss: 0.4400
	Accuracy: 0.9885 | F1: 0.4487 | Dice: 0.4720 | Recall: 0.4591 | Precision: 0.4856 | Jaccard: 0.3093


                                                                  

Epoch 37 | Time: 0m 42s
	Train Loss: 0.0658 | Valid Loss: 0.4344
	Accuracy: 0.9885 | F1: 0.4538 | Dice: 0.4765 | Recall: 0.4725 | Precision: 0.4806 | Jaccard: 0.3134


                                                                  

Epoch 38 | Time: 0m 43s
	Train Loss: 0.0665 | Valid Loss: 0.4323
	Accuracy: 0.9885 | F1: 0.4553 | Dice: 0.4792 | Recall: 0.4802 | Precision: 0.4782 | Jaccard: 0.3146


                                                                  

Epoch 39 | Time: 0m 42s
	Train Loss: 0.0659 | Valid Loss: 0.4307
	Accuracy: 0.9883 | F1: 0.4549 | Dice: 0.4786 | Recall: 0.4927 | Precision: 0.4652 | Jaccard: 0.3131


                                                                  

Epoch 40 | Time: 0m 42s
	Train Loss: 0.0665 | Valid Loss: 0.4391
	Accuracy: 0.9884 | F1: 0.4483 | Dice: 0.4708 | Recall: 0.4698 | Precision: 0.4717 | Jaccard: 0.3080


                                                                  

Epoch 41 | Time: 0m 42s
	Train Loss: 0.0658 | Valid Loss: 0.4381
	Accuracy: 0.9883 | F1: 0.4494 | Dice: 0.4713 | Recall: 0.4778 | Precision: 0.4650 | Jaccard: 0.3090


                                                                  

Epoch 42 | Time: 0m 42s
	Train Loss: 0.0657 | Valid Loss: 0.4334
	Accuracy: 0.9883 | F1: 0.4524 | Dice: 0.4756 | Recall: 0.4826 | Precision: 0.4687 | Jaccard: 0.3112


                                                                  

Epoch 43 | Time: 0m 42s
	Train Loss: 0.0657 | Valid Loss: 0.4357
	Accuracy: 0.9883 | F1: 0.4514 | Dice: 0.4740 | Recall: 0.4813 | Precision: 0.4669 | Jaccard: 0.3107


                                                                  

Epoch 44 | Time: 0m 42s
	Train Loss: 0.0662 | Valid Loss: 0.4392
	Accuracy: 0.9884 | F1: 0.4500 | Dice: 0.4720 | Recall: 0.4713 | Precision: 0.4726 | Jaccard: 0.3097


                                                                  

Epoch 45 | Time: 0m 42s
	Train Loss: 0.0660 | Valid Loss: 0.4448
	Accuracy: 0.9884 | F1: 0.4435 | Dice: 0.4652 | Recall: 0.4583 | Precision: 0.4722 | Jaccard: 0.3043


                                                                  

Epoch 46 | Time: 0m 42s
	Train Loss: 0.0657 | Valid Loss: 0.4377
	Accuracy: 0.9884 | F1: 0.4481 | Dice: 0.4726 | Recall: 0.4690 | Precision: 0.4763 | Jaccard: 0.3071


                                                                  

Epoch 47 | Time: 0m 42s
	Train Loss: 0.0658 | Valid Loss: 0.4396
	Accuracy: 0.9883 | F1: 0.4484 | Dice: 0.4703 | Recall: 0.4763 | Precision: 0.4645 | Jaccard: 0.3080


                                                                  

Epoch 48 | Time: 0m 42s
	Train Loss: 0.0658 | Valid Loss: 0.4376
	Accuracy: 0.9884 | F1: 0.4494 | Dice: 0.4726 | Recall: 0.4713 | Precision: 0.4739 | Jaccard: 0.3091


                                                                  

Epoch 49 | Time: 0m 42s
	Train Loss: 0.0656 | Valid Loss: 0.4330
	Accuracy: 0.9884 | F1: 0.4531 | Dice: 0.4785 | Recall: 0.4784 | Precision: 0.4787 | Jaccard: 0.3120


                                                                  

Epoch 50 | Time: 0m 42s
	Train Loss: 0.0656 | Valid Loss: 0.4337
	Accuracy: 0.9884 | F1: 0.4526 | Dice: 0.4765 | Recall: 0.4766 | Precision: 0.4764 | Jaccard: 0.3118




In [11]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


Testing: 100%|██████████| 14/14 [00:05<00:00,  2.66it/s]

Jaccard: 0.3965, F1: 0.5622, Recall: 0.5737, Precision: 0.5859, Accuracy: 0.9912
FPS: 243.50802635780428



