In [1]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from operator import add
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score

In [None]:
# Utility Functions
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
# Dataset 
class RetinalDataset(Dataset):
    def __init__(self, images_path, masks_path):
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    def __getitem__(self, index):
        img_path = self.images_path[index]
        mask_path = self.masks_path[index]

        # Load image in BGR format
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        # Extract only the green channel
        green = image[:, :, 1]  # Index 1 = green channel in BGR

        # Normalize and convert to tensor
        green = green / 255.0
        green = np.expand_dims(green, axis=0).astype(np.float32)  # Shape: [1, H, W]
        image = torch.from_numpy(green)

        # Load and process the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.0
        mask = np.expand_dims(mask, axis=0).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Get the filename
        filename = os.path.basename(img_path)

        return image, mask, filename
    def __len__(self):
            return self.n_samples

In [4]:
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits).view(logits.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        TP = (probs * targets).sum(dim=1)
        FP = (probs * (1 - targets)).sum(dim=1)
        FN = ((1 - probs) * targets).sum(dim=1)

        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        return (1 - tversky).mean()



In [5]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, act=True):
        super().__init__()

        layers = [
            nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_c)
        ]
        if act == True:
            layers.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class multires_block(nn.Module):
    def __init__(self, in_c, out_c, alpha=1.67):
        super().__init__()

        W = out_c * alpha
        self.c1 = conv_block(in_c, int(W*0.167))
        self.c2 = conv_block(int(W*0.167), int(W*0.333))
        self.c3 = conv_block(int(W*0.333), int(W*0.5))

        nf = int(W*0.167) + int(W*0.333) + int(W*0.5)
        self.b1 = nn.BatchNorm2d(nf)
        self.c4 = conv_block(in_c, nf)
        self.relu = nn.ReLU(inplace=True)
        self.b2 = nn.BatchNorm2d(nf)

    def forward(self, x):
        x0 = x
        x1 = self.c1(x0)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        xc = torch.cat([x1, x2, x3], dim=1)
        xc = self.b1(xc)

        sc = self.c4(x0)
        x = self.relu(xc + sc)
        x = self.b2(x)
        return x

class res_path_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.c1 = conv_block(in_c, out_c, act=False)
        self.s1 = conv_block(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        x1 = self.c1(x)
        s1 = self.s1(x)
        x = self.relu(x1 + s1)
        x = self.bn(x)
        return x

class res_path(nn.Module):
    def __init__(self, in_c, out_c, length):
        super().__init__()

        layers = []
        for i in range(length):
            layers.append(res_path_block(in_c, out_c))
            in_c = out_c

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

def cal_nf(ch, alpha=1.67):
    W = ch * alpha
    return int(W*0.167) + int(W*0.333) + int(W*0.5)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c, length):
        super().__init__()

        self.c1 = multires_block(in_c, out_c)
        nf = cal_nf(out_c)
        self.s1 = res_path(nf, out_c, length)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        x = self.c1(x)
        s = self.s1(x)
        p = self.pool(x)
        return s, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.c1 = nn.ConvTranspose2d(in_c[0], out_c, kernel_size=2, stride=2, padding=0)
        self.c2 = multires_block(out_c+in_c[1], out_c)

    def forward(self, x, s):
        x = self.c1(x)
        x = torch.cat([x, s], dim=1)
        x = self.c2(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.e1 = encoder_block(1, 32, 4)
        self.e2 = encoder_block(cal_nf(32), 64, 3)
        self.e3 = encoder_block(cal_nf(64), 128, 2)
        self.e4 = encoder_block(cal_nf(128), 256, 1)

        """ Bridge """
        self.b1 = multires_block(cal_nf(256), 512)

        """ Decoder """
        self.d1 = decoder_block([cal_nf(512), 256], 256)
        self.d2 = decoder_block([cal_nf(256), 128], 128)
        self.d3 = decoder_block([cal_nf(128), 64], 64)
        self.d4 = decoder_block([cal_nf(64), 32], 32)

        """ Output """
        self.output = nn.Conv2d(cal_nf(32), 1, kernel_size=1, padding=0)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b1 = self.b1(p4)

        d1 = self.d1(b1, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        output = self.output(d4)
        return output



In [None]:
# Training and Evaluation 
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y, _ in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y, _ in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            total_loss += loss.item()
    return total_loss / len(loader)

def calculate_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy() > 0.5
    y_pred = torch.sigmoid(y_pred).cpu().numpy() > 0.5
    y_true = y_true.astype(np.uint8).reshape(-1)
    y_pred = y_pred.astype(np.uint8).reshape(-1)

    return [
        jaccard_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        precision_score(y_true, y_pred, zero_division=0),
        accuracy_score(y_true, y_pred)
    ]

def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask]*3, axis=-1)

In [7]:
seeding(42)
MODEL_NAME = "Green_AMultiRes_UNet_Tversky_Full"
MODEL_DIRECTORY = "Green_Model_MultiRes_UNet_Tversky_Full"
create_directory(MODEL_DIRECTORY)
RESULT_DIRECTORY = "Green_Results_MultiRes_UNet_Tversky_Full"
create_directory(RESULT_DIRECTORY)

In [None]:
# load full training set 
train_images = sorted(glob("./final_dataset/train/images/*"))
train_masks  = sorted(glob("./final_dataset/train/masks/*"))
train_dataset = RetinalDataset(train_images, train_masks)

# load test set and split it 50/50 into validation and test 
test_images = sorted(glob("./final_dataset/test/images/*"))
test_masks  = sorted(glob("./final_dataset/test/masks/*"))
full_test_dataset = RetinalDataset(test_images, test_masks)
n_val = len(full_test_dataset) // 2
n_test = len(full_test_dataset) - n_val
valid_dataset, test_dataset = random_split(full_test_dataset, [n_val, n_test])

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn   = TverskyLoss()

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

best_valid_loss = float("inf")



In [9]:
for epoch in range(50):
    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)
    scheduler.step(valid_loss)

    model.eval()
    valid_metrics = [0.0] * 5
    with torch.no_grad():
        for x_val, y_val, _ in tqdm(valid_loader, desc="Calculating Metrics", leave=False):
            x_val, y_val = x_val.to(device), y_val.to(device)
            y_pred = model(x_val)
            valid_metrics = list(map(add, valid_metrics, calculate_metrics(y_val, y_pred)))

    metrics_avg = [m / len(valid_loader) for m in valid_metrics]
    jaccard, f1, recall, precision, accuracy = metrics_avg
    dice = (2 * precision * recall) / (precision + recall + 1e-7)

    mins, secs = epoch_time(start, time.time())
    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"\tTrain Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
    print(f"\tAccuracy: {accuracy:.4f} | F1: {f1:.4f} | Dice: {dice:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | Jaccard: {jaccard:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  
        torch.save(model.state_dict(), f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth")
        print(f"Best {MODEL_NAME} Saved")

                                                                    

Epoch 01 | Time: 1m 22s
	Train Loss: 0.9161 | Valid Loss: 0.8832
	Accuracy: 0.7986 | F1: 0.1088 | Dice: 0.1341 | Recall: 0.8586 | Precision: 0.0727 | Jaccard: 0.0662
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 02 | Time: 1m 20s
	Train Loss: 0.8620 | Valid Loss: 0.8409
	Accuracy: 0.8969 | F1: 0.1516 | Dice: 0.1783 | Recall: 0.8561 | Precision: 0.0995 | Jaccard: 0.0924
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 03 | Time: 1m 20s
	Train Loss: 0.8004 | Valid Loss: 0.7777
	Accuracy: 0.9425 | F1: 0.2070 | Dice: 0.2383 | Recall: 0.7927 | Precision: 0.1402 | Jaccard: 0.1268
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 04 | Time: 1m 20s
	Train Loss: 0.7192 | Valid Loss: 0.7022
	Accuracy: 0.9648 | F1: 0.2950 | Dice: 0.3315 | Recall: 0.7625 | Precision: 0.2118 | Jaccard: 0.1944
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 05 | Time: 1m 17s
	Train Loss: 0.6310 | Valid Loss: 0.6788
	Accuracy: 0.9750 | F1: 0.3032 | Dice: 0.3382 | Recall: 0.5792 | Precision: 0.2388 | Jaccard: 0.1934
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 06 | Time: 1m 17s
	Train Loss: 0.5535 | Valid Loss: 0.6344
	Accuracy: 0.9826 | F1: 0.3640 | Dice: 0.4106 | Recall: 0.4917 | Precision: 0.3524 | Jaccard: 0.2515
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 07 | Time: 1m 17s
	Train Loss: 0.4967 | Valid Loss: 0.6238
	Accuracy: 0.9831 | F1: 0.3630 | Dice: 0.4078 | Recall: 0.4852 | Precision: 0.3516 | Jaccard: 0.2485
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 08 | Time: 1m 18s
	Train Loss: 0.4425 | Valid Loss: 0.6198
	Accuracy: 0.9840 | F1: 0.3721 | Dice: 0.4187 | Recall: 0.4603 | Precision: 0.3839 | Jaccard: 0.2585
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 09 | Time: 1m 17s
	Train Loss: 0.4141 | Valid Loss: 0.6523
	Accuracy: 0.9854 | F1: 0.3583 | Dice: 0.3985 | Recall: 0.3660 | Precision: 0.4375 | Jaccard: 0.2477


                                                                    

Epoch 10 | Time: 1m 17s
	Train Loss: 0.3720 | Valid Loss: 0.5751
	Accuracy: 0.9857 | F1: 0.4126 | Dice: 0.4425 | Recall: 0.4869 | Precision: 0.4055 | Jaccard: 0.2896
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 11 | Time: 1m 17s
	Train Loss: 0.3520 | Valid Loss: 0.6099
	Accuracy: 0.9849 | F1: 0.3886 | Dice: 0.4341 | Recall: 0.4359 | Precision: 0.4322 | Jaccard: 0.2692


                                                                    

Epoch 12 | Time: 1m 17s
	Train Loss: 0.3360 | Valid Loss: 0.5993
	Accuracy: 0.9844 | F1: 0.3849 | Dice: 0.4240 | Recall: 0.4703 | Precision: 0.3860 | Jaccard: 0.2619


                                                                    

Epoch 13 | Time: 1m 17s
	Train Loss: 0.3069 | Valid Loss: 0.5960
	Accuracy: 0.9864 | F1: 0.4079 | Dice: 0.4372 | Recall: 0.4249 | Precision: 0.4504 | Jaccard: 0.2907


                                                                    

Epoch 14 | Time: 1m 17s
	Train Loss: 0.2889 | Valid Loss: 0.6300
	Accuracy: 0.9864 | F1: 0.3868 | Dice: 0.4182 | Recall: 0.3672 | Precision: 0.4856 | Jaccard: 0.2710


                                                                    

Epoch 15 | Time: 1m 17s
	Train Loss: 0.2768 | Valid Loss: 0.5573
	Accuracy: 0.9806 | F1: 0.3926 | Dice: 0.4289 | Recall: 0.6408 | Precision: 0.3223 | Jaccard: 0.2734
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 16 | Time: 1m 17s
	Train Loss: 0.2733 | Valid Loss: 0.5978
	Accuracy: 0.9858 | F1: 0.4095 | Dice: 0.4373 | Recall: 0.4094 | Precision: 0.4692 | Jaccard: 0.2828


                                                                    

Epoch 17 | Time: 1m 18s
	Train Loss: 0.2564 | Valid Loss: 0.6223
	Accuracy: 0.9852 | F1: 0.3802 | Dice: 0.4127 | Recall: 0.3944 | Precision: 0.4328 | Jaccard: 0.2654


                                                                    

Epoch 18 | Time: 1m 17s
	Train Loss: 0.2436 | Valid Loss: 0.6021
	Accuracy: 0.9861 | F1: 0.4002 | Dice: 0.4305 | Recall: 0.4172 | Precision: 0.4447 | Jaccard: 0.2812


                                                                    

Epoch 19 | Time: 1m 17s
	Train Loss: 0.2410 | Valid Loss: 0.5506
	Accuracy: 0.9865 | F1: 0.4385 | Dice: 0.4607 | Recall: 0.4857 | Precision: 0.4381 | Jaccard: 0.3078
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 20 | Time: 1m 17s
	Train Loss: 0.2236 | Valid Loss: 0.6217
	Accuracy: 0.9864 | F1: 0.3927 | Dice: 0.4198 | Recall: 0.3706 | Precision: 0.4839 | Jaccard: 0.2746


                                                                    

Epoch 21 | Time: 1m 17s
	Train Loss: 0.2158 | Valid Loss: 0.6325
	Accuracy: 0.9861 | F1: 0.3876 | Dice: 0.4152 | Recall: 0.3552 | Precision: 0.4996 | Jaccard: 0.2683


                                                                    

Epoch 22 | Time: 1m 20s
	Train Loss: 0.2165 | Valid Loss: 0.5206
	Accuracy: 0.9845 | F1: 0.4480 | Dice: 0.4760 | Recall: 0.5756 | Precision: 0.4059 | Jaccard: 0.3159
Best Green_AMultiRes_UNet_Tversky_Full Saved


                                                                    

Epoch 23 | Time: 1m 21s
	Train Loss: 0.2170 | Valid Loss: 0.5875
	Accuracy: 0.9865 | F1: 0.4194 | Dice: 0.4583 | Recall: 0.4303 | Precision: 0.4902 | Jaccard: 0.3001


                                                                    

Epoch 24 | Time: 1m 20s
	Train Loss: 0.2130 | Valid Loss: 0.5479
	Accuracy: 0.9850 | F1: 0.4373 | Dice: 0.4700 | Recall: 0.5040 | Precision: 0.4403 | Jaccard: 0.3038


                                                                    

Epoch 25 | Time: 1m 20s
	Train Loss: 0.2036 | Valid Loss: 0.5375
	Accuracy: 0.9859 | F1: 0.4580 | Dice: 0.4891 | Recall: 0.4963 | Precision: 0.4820 | Jaccard: 0.3202


                                                                    

Epoch 26 | Time: 1m 21s
	Train Loss: 0.2037 | Valid Loss: 0.5807
	Accuracy: 0.9857 | F1: 0.4076 | Dice: 0.4384 | Recall: 0.4699 | Precision: 0.4109 | Jaccard: 0.2891


                                                                    

Epoch 27 | Time: 1m 20s
	Train Loss: 0.1918 | Valid Loss: 0.5785
	Accuracy: 0.9868 | F1: 0.4299 | Dice: 0.4547 | Recall: 0.4284 | Precision: 0.4844 | Jaccard: 0.3009


                                                                    

Epoch 28 | Time: 1m 21s
	Train Loss: 0.1895 | Valid Loss: 0.5797
	Accuracy: 0.9867 | F1: 0.4252 | Dice: 0.4517 | Recall: 0.4298 | Precision: 0.4759 | Jaccard: 0.3031


                                                                    

Epoch 29 | Time: 1m 20s
	Train Loss: 0.1699 | Valid Loss: 0.5886
	Accuracy: 0.9866 | F1: 0.4192 | Dice: 0.4430 | Recall: 0.4139 | Precision: 0.4765 | Jaccard: 0.2983


                                                                    

Epoch 30 | Time: 1m 19s
	Train Loss: 0.1634 | Valid Loss: 0.6065
	Accuracy: 0.9866 | F1: 0.4049 | Dice: 0.4286 | Recall: 0.3903 | Precision: 0.4752 | Jaccard: 0.2876


                                                                    

Epoch 31 | Time: 1m 19s
	Train Loss: 0.1595 | Valid Loss: 0.5937
	Accuracy: 0.9866 | F1: 0.4157 | Dice: 0.4406 | Recall: 0.4071 | Precision: 0.4800 | Jaccard: 0.2960


                                                                    

Epoch 32 | Time: 1m 17s
	Train Loss: 0.1569 | Valid Loss: 0.5784
	Accuracy: 0.9867 | F1: 0.4373 | Dice: 0.4798 | Recall: 0.4177 | Precision: 0.5637 | Jaccard: 0.3086


                                                                    

Epoch 33 | Time: 1m 18s
	Train Loss: 0.1547 | Valid Loss: 0.5956
	Accuracy: 0.9867 | F1: 0.4144 | Dice: 0.4409 | Recall: 0.4038 | Precision: 0.4855 | Jaccard: 0.2951


                                                                    

Epoch 34 | Time: 1m 17s
	Train Loss: 0.1521 | Valid Loss: 0.5917
	Accuracy: 0.9865 | F1: 0.4200 | Dice: 0.4533 | Recall: 0.4078 | Precision: 0.5101 | Jaccard: 0.2976


                                                                    

Epoch 35 | Time: 1m 17s
	Train Loss: 0.1503 | Valid Loss: 0.6264
	Accuracy: 0.9867 | F1: 0.3995 | Dice: 0.4245 | Recall: 0.3500 | Precision: 0.5395 | Jaccard: 0.2810


                                                                    

Epoch 36 | Time: 1m 18s
	Train Loss: 0.1491 | Valid Loss: 0.6103
	Accuracy: 0.9868 | F1: 0.4083 | Dice: 0.4327 | Recall: 0.3748 | Precision: 0.5118 | Jaccard: 0.2896


                                                                    

Epoch 37 | Time: 1m 17s
	Train Loss: 0.1484 | Valid Loss: 0.6323
	Accuracy: 0.9867 | F1: 0.3952 | Dice: 0.4191 | Recall: 0.3405 | Precision: 0.5448 | Jaccard: 0.2773


                                                                    

Epoch 38 | Time: 1m 18s
	Train Loss: 0.1478 | Valid Loss: 0.6181
	Accuracy: 0.9866 | F1: 0.4061 | Dice: 0.4515 | Recall: 0.3596 | Precision: 0.6067 | Jaccard: 0.2861


                                                                    

Epoch 39 | Time: 1m 23s
	Train Loss: 0.1473 | Valid Loss: 0.6133
	Accuracy: 0.9867 | F1: 0.4045 | Dice: 0.4283 | Recall: 0.3741 | Precision: 0.5008 | Jaccard: 0.2859


                                                                    

Epoch 40 | Time: 1m 23s
	Train Loss: 0.1468 | Valid Loss: 0.5943
	Accuracy: 0.9867 | F1: 0.4153 | Dice: 0.4400 | Recall: 0.4063 | Precision: 0.4798 | Jaccard: 0.2963


                                                                    

Epoch 41 | Time: 1m 23s
	Train Loss: 0.1464 | Valid Loss: 0.6299
	Accuracy: 0.9866 | F1: 0.3948 | Dice: 0.4177 | Recall: 0.3470 | Precision: 0.5245 | Jaccard: 0.2776


                                                                    

Epoch 42 | Time: 1m 22s
	Train Loss: 0.1463 | Valid Loss: 0.6184
	Accuracy: 0.9867 | F1: 0.3983 | Dice: 0.4231 | Recall: 0.3730 | Precision: 0.4886 | Jaccard: 0.2811


                                                                    

Epoch 43 | Time: 1m 18s
	Train Loss: 0.1462 | Valid Loss: 0.6065
	Accuracy: 0.9867 | F1: 0.4094 | Dice: 0.4333 | Recall: 0.3832 | Precision: 0.4985 | Jaccard: 0.2907


                                                                    

Epoch 44 | Time: 1m 19s
	Train Loss: 0.1462 | Valid Loss: 0.6168
	Accuracy: 0.9867 | F1: 0.4044 | Dice: 0.4285 | Recall: 0.3642 | Precision: 0.5203 | Jaccard: 0.2860


                                                                    

Epoch 45 | Time: 1m 19s
	Train Loss: 0.1461 | Valid Loss: 0.6308
	Accuracy: 0.9866 | F1: 0.3928 | Dice: 0.4161 | Recall: 0.3473 | Precision: 0.5190 | Jaccard: 0.2759


                                                                    

Epoch 46 | Time: 1m 19s
	Train Loss: 0.1461 | Valid Loss: 0.5948
	Accuracy: 0.9867 | F1: 0.4157 | Dice: 0.4410 | Recall: 0.4036 | Precision: 0.4859 | Jaccard: 0.2961


                                                                    

Epoch 47 | Time: 1m 19s
	Train Loss: 0.1460 | Valid Loss: 0.6157
	Accuracy: 0.9868 | F1: 0.4058 | Dice: 0.4299 | Recall: 0.3650 | Precision: 0.5229 | Jaccard: 0.2872


                                                                    

Epoch 48 | Time: 1m 18s
	Train Loss: 0.1460 | Valid Loss: 0.6240
	Accuracy: 0.9867 | F1: 0.3998 | Dice: 0.4221 | Recall: 0.3538 | Precision: 0.5230 | Jaccard: 0.2816


                                                                    

Epoch 49 | Time: 1m 18s
	Train Loss: 0.1460 | Valid Loss: 0.6063
	Accuracy: 0.9866 | F1: 0.4055 | Dice: 0.4305 | Recall: 0.3887 | Precision: 0.4824 | Jaccard: 0.2884


                                                                    

Epoch 50 | Time: 1m 19s
	Train Loss: 0.1460 | Valid Loss: 0.6281
	Accuracy: 0.9867 | F1: 0.3968 | Dice: 0.4193 | Recall: 0.3492 | Precision: 0.5246 | Jaccard: 0.2784




In [10]:
# Load the best model
model.load_state_dict(torch.load(f"{MODEL_DIRECTORY}/{MODEL_NAME}.pth", map_location=device))
model.eval()

# Evaluate on the held-out half of the test set
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
metrics_score = [0.0] * 5
time_taken = []

for x, y, fname in tqdm(test_loader, desc="Testing", total=len(test_loader)):
    with torch.no_grad():
        x, y = x.to(device), y.to(device)
        start = time.time()
        pred_y = model(x)
        time_taken.append(time.time() - start)
        metrics_score = list(map(add, metrics_score, calculate_metrics(y, pred_y)))

    # Use green-channel image directly (grayscale)
    green_img = (x.cpu().numpy()[0, 0] * 255).astype(np.uint8)  # Shape: [H, W]

    # Process ground truth and prediction
    mask = (y.cpu().numpy()[0, 0] * 255).astype(np.uint8)
    pred = (torch.sigmoid(pred_y).cpu().numpy()[0, 0] > 0.5).astype(np.uint8) * 255

    # Convert masks to RGB overlays
    mask_img = mask_parse(mask)
    pred_img = mask_parse(pred)

    # Resize masks to match green image if needed
    h, w = green_img.shape
    mask_img = cv2.resize(mask_img, (w, h))
    pred_img = cv2.resize(pred_img, (w, h))

    # Create vertical separator
    line = np.ones((h, 10, 3), dtype=np.uint8) * 128

    # Convert green image to 3-channel grayscale for compatibility
    green_rgb = np.stack([green_img]*3, axis=-1)

    # Concatenate images: green | line | mask | line | prediction
    result_uint8 = np.concatenate([green_rgb, line, mask_img, line, pred_img], axis=1)

    # Safe filename
    if isinstance(fname, (list, tuple)):
        fname = fname[0]
    save_name = os.path.splitext(fname)[0] + ".png"

    # Save image
    plt.imsave(f"{RESULT_DIRECTORY}/{save_name}", result_uint8)

# Final metrics
j, f1, r, p, a = [m / len(test_loader) for m in metrics_score]
print(f"Jaccard: {j:.4f}, F1: {f1:.4f}, Recall: {r:.4f}, Precision: {p:.4f}, Accuracy: {a:.4f}")
print("FPS:", 1 / np.mean(time_taken))


Testing: 100%|██████████| 14/14 [00:05<00:00,  2.76it/s]

Jaccard: 0.3281, F1: 0.4840, Recall: 0.6992, Precision: 0.3934, Accuracy: 0.9861
FPS: 95.58469023826073



