In [2]:
import os
import glob
from PIL import Image
import random
import numpy as np
import cv2
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seeding(42)

In [3]:
MODEL_NAME = "BB_Custom_Attention_MODEL_Dice_BCE"
batch_size = 1

In [4]:
class DiceBCELoss(nn.Module):
    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.bce_fn = nn.BCEWithLogitsLoss(reduction='none')  # <- changed

    def forward(self, logits, targets, reduction='mean'):
        # BCE loss map (per-pixel)
        bce_loss = self.bce_fn(logits, targets)  # shape: (B, 1, H, W)

        # Dice loss (per-batch)
        probs = torch.sigmoid(logits)
        batch_size = probs.shape[0]
        dice_losses = []
        for i in range(batch_size):
            p = probs[i].view(-1)
            g = targets[i].view(-1)
            inter = (p * g).sum()
            dice = 1 - (2*inter + self.smooth) / (p.sum() + g.sum() + self.smooth)
            dice_losses.append(dice)
        dice_loss = torch.stack(dice_losses).mean()

        if reduction == 'none':
            # Used for pixel-wise loss map
            return self.alpha + (1 - self.alpha) * bce_loss
        else:
            return self.alpha * dice_loss + (1 - self.alpha) * bce_loss.mean()


In [None]:
# Dataset Definition 
class HemorrhageDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.image_paths = sorted(glob.glob(os.path.join(images_dir, '*')))
        self.mask_paths = sorted(glob.glob(os.path.join(masks_dir, '*')))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')

        # Extract only the green channel as a single-channel image
        img_np = np.array(img)[:, :, 1]  # green channel
        img = Image.fromarray(img_np)   # convert back to PIL Image (mode 'L')

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        mask = (mask > 0).float()

        filename = os.path.basename(self.image_paths[idx])
        return img, mask, filename



In [6]:
class AttentionBlock(nn.Module):
    """Additive attention block for U-Net skip connections."""
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g: gating signal transform
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # W_x: skip connection transform
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # psi: attention coefficient
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal (from decoder), x: skip features (from encoder)
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi  # apply attention


class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


class UNet(nn.Module):
    """U-Net with attention gates on skip connections and additional subsampling concat."""
    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = TripleConv(1, 32, 32, 64)
        self.down2 = TripleConv(64, 64, 64, 128)
        self.down3 = DoubleConv(128, 128, 256)
        self.down4 = DoubleConv(256, 256, 256)
        self.pool  = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512, 256)

        # Decoder up and conv
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(256+256, 256, 256)
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(256+256, 128, 128)
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(128+128, 64, 64, 64)
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(64+64, 32, 32, 32)

        # Attention blocks for skip connections
        self.att4 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64,  F_l=64,  F_int=32)

        # Final subsample, concat and output
        self.final_pool       = nn.MaxPool2d(2, 2)
        self.final_upsample   = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out_conv         = nn.Conv2d(33, 1, kernel_size=1)

    def forward(self, x):
        input_image = x
        # Encoder
        x1  = self.down1(x)
        x1p = self.pool(x1)
        x2  = self.down2(x1p)
        x2p = self.pool(x2)
        x3  = self.down3(x2p)
        x3p = self.pool(x3)
        x4  = self.down4(x3p)
        x4p = self.pool(x4)

        # Bottleneck
        xb  = self.bottleneck(x4p)

        # Decoder + Attention
        d4  = self.up4(xb)
        x4a = self.att4(g=d4, x=x4)
        d4  = torch.cat([x4a, d4], dim=1)
        d4  = self.dec4(d4)

        d3  = self.up3(d4)
        x3a = self.att3(g=d3, x=x3)
        d3  = torch.cat([x3a, d3], dim=1)
        d3  = self.dec3(d3)

        d2  = self.up2(d3)
        x2a = self.att2(g=d2, x=x2)
        d2  = torch.cat([x2a, d2], dim=1)
        d2  = self.dec2(d2)

        d1  = self.up1(d2)
        x1a = self.att1(g=d1, x=x1)
        d1  = torch.cat([x1a, d1], dim=1)
        d1  = self.dec1(d1)

        # Additional subsampling & concatenation
        d1s = self.final_pool(d1)
        ins = self.final_pool(input_image)
        cat = torch.cat([d1s, ins], dim=1)
        out = self.final_upsample(cat)
        out = self.out_conv(out)
        return out


In [7]:
# -------- Utilities --------
def get_bounding_boxes(binary_mask):
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return [cv2.boundingRect(cnt) for cnt in contours]

def box_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[0]+boxA[2], boxB[0]+boxB[2])
    yB = min(boxA[1]+boxA[3], boxB[1]+boxB[3])
    interW = max(0, xB - xA)
    interH = max(0, yB - yA)
    interArea = interW * interH
    areaA = boxA[2] * boxA[3]
    areaB = boxB[2] * boxB[3]
    return interArea / float(areaA + areaB - interArea + 1e-6)

def compute_detection_metrics(gt_boxes, pred_boxes, iou_thresh=0.5):
    matched_gt = set()
    tp = 0
    for pb in pred_boxes:
        best_iou, best_j = 0, -1
        for j, gb in enumerate(gt_boxes):
            if j in matched_gt: continue
            iou = box_iou(pb, gb)
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_iou >= iou_thresh:
            tp += 1; matched_gt.add(best_j)
    fp = len(pred_boxes) - tp
    fn = len(gt_boxes) - tp
    prec = tp / (tp + fp + 1e-6)
    rec  = tp / (tp + fn + 1e-6)
    f1   = 2*prec*rec / (prec + rec + 1e-6)
    return prec, rec, f1

def mask_iou(gt, pred):
    gt_bool = gt.astype(bool)
    pred_bool = pred.astype(bool)
    inter = np.logical_and(gt_bool, pred_bool).sum()
    union = np.logical_or(gt_bool, pred_bool).sum()
    return inter / (union + 1e-6)

In [None]:
# Training Function 
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for imgs, masks, _ in tqdm(loader, desc="Training", leave=False):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(imgs)

        # Apply pixel-wise loss weighting based on bounding boxes
        weights = torch.ones_like(masks)

        for i in range(masks.size(0)):
            # Convert mask to numpy and extract boxes
            mask_np = (masks[i][0].cpu().numpy() * 255).astype(np.uint8)
            boxes = get_bounding_boxes(mask_np)
            
            # Create weight map for that sample
            for x, y, w, h in boxes:
                weights[i, 0, y:y+h, x:x+w] = 3.0  # Weight 3x inside the box

        # Compute loss with weights
        loss = criterion(preds, masks)
        weighted_loss = (loss * weights).mean()

        weighted_loss.backward()
        optimizer.step()
        running_loss += weighted_loss.item()

    return running_loss / len(loader)


# Validation Function 
def validate(model, loader, criterion, iou_thresh, device):
    model.eval()
    total_loss = 0.0
    metrics = {'prec': [], 'rec': [], 'f1': [], 'iou': []}
    with torch.no_grad():
        for img, mask, _ in tqdm(loader, desc="Validating", leave=False):
            img, mask = img.to(device), mask.to(device)
            pred = torch.sigmoid(model(img))
            loss = criterion(pred, mask)
            total_loss += loss.item()

            pred_bin = (pred > 0.5).float()
            gt_np = (mask[0][0].cpu().numpy()*255).astype(np.uint8)
            pr_np = (pred_bin[0][0].cpu().numpy()*255).astype(np.uint8)
            boxes_gt = get_bounding_boxes(gt_np)
            boxes_pr = get_bounding_boxes(pr_np)
            p, r, f1 = compute_detection_metrics(boxes_gt, boxes_pr, iou_thresh)
            j = mask_iou(gt_np, pr_np)
            metrics['prec'].append(p)
            metrics['rec'].append(r)
            metrics['f1'].append(f1)
            metrics['iou'].append(j)

    avg_loss = total_loss / len(loader)
    return avg_loss, {k: np.mean(v) for k, v in metrics.items()}




In [None]:

# Constants
dataset_dir = '../final_dataset'
lr = 1e-4
iou_thresh = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transforms
tf = transforms.Compose([
    transforms.ToTensor(),
])

# Datasets & Loaders
train_img = os.path.join(dataset_dir, 'train', 'images')
train_mask = os.path.join(dataset_dir, 'train', 'masks')
test_img = os.path.join(dataset_dir, 'test', 'images')
test_mask = os.path.join(dataset_dir, 'test', 'masks')



train_ds = HemorrhageDataset(train_img, train_mask, transform=tf)
full_test = HemorrhageDataset(test_img, test_mask, transform=tf)
val_size = len(full_test) // 2
val_ds, final_ds = random_split(full_test, [val_size, len(full_test) - val_size])

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
final_loader = DataLoader(final_ds, batch_size=1, shuffle=False)

# Model, Loss, Optimizer
model = UNet().to(device)
criterion = DiceBCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training Loop
num_epochs = 40
best_iou = -1.0
os.makedirs(f"{MODEL_NAME}", exist_ok=True)
save_path = os.path.join(f"{MODEL_NAME}", f"{MODEL_NAME}.pth")

In [10]:
best_iou = -1.0
patience = 15
patience_counter = 0

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_metrics = validate(model, val_loader, criterion, iou_thresh, device)

    print(f"Epoch {epoch}/{num_epochs} - Train Loss: {train_loss:.4f} | ",
          f"Val Loss: {val_loss:.4f} \nPrec: {val_metrics['prec']:.4f},",
          f"Rec: {val_metrics['rec']:.4f}, F1: {val_metrics['f1']:.4f}, IoU: {val_metrics['iou']:.4f}")

    if val_metrics['iou'] > best_iou:
        best_iou = val_metrics['iou']
        torch.save(model.state_dict(), save_path)
        print(f"Model saved with IoU: {best_iou:.4f}")
        patience_counter = 0  # reset if improved
    else:
        patience_counter += 1
        print(f"No improvement. Patience counter: {patience_counter}/{patience}")

    if patience_counter >= patience:
        print(f"Early stopping triggered after {patience} epochs without improvement.")
        break

    print()


                                                             

Epoch 1/40 - Train Loss: 0.4641 |  Val Loss: 0.8379 
Prec: 0.2375, Rec: 0.0858, F1: 0.1089, IoU: 0.1291
Model saved with IoU: 0.1291



                                                             

Epoch 2/40 - Train Loss: 0.3472 |  Val Loss: 0.8378 
Prec: 0.2731, Rec: 0.1964, F1: 0.1935, IoU: 0.2111
Model saved with IoU: 0.2111



                                                             

Epoch 3/40 - Train Loss: 0.3270 |  Val Loss: 0.8377 
Prec: 0.3856, Rec: 0.1563, F1: 0.1937, IoU: 0.1975
No improvement. Patience counter: 1/15



                                                             

Epoch 4/40 - Train Loss: 0.3175 |  Val Loss: 0.8376 
Prec: 0.3612, Rec: 0.1488, F1: 0.1848, IoU: 0.1945
No improvement. Patience counter: 2/15



                                                             

Epoch 5/40 - Train Loss: 0.3056 |  Val Loss: 0.8376 
Prec: 0.3594, Rec: 0.1490, F1: 0.1827, IoU: 0.1994
No improvement. Patience counter: 3/15



                                                             

Epoch 6/40 - Train Loss: 0.3009 |  Val Loss: 0.8377 
Prec: 0.3034, Rec: 0.1267, F1: 0.1563, IoU: 0.1888
No improvement. Patience counter: 4/15



                                                             

Epoch 7/40 - Train Loss: 0.2926 |  Val Loss: 0.8375 
Prec: 0.3497, Rec: 0.1714, F1: 0.2011, IoU: 0.2278
Model saved with IoU: 0.2278



                                                             

Epoch 8/40 - Train Loss: 0.2846 |  Val Loss: 0.8377 
Prec: 0.3647, Rec: 0.1497, F1: 0.1829, IoU: 0.2017
No improvement. Patience counter: 1/15



                                                             

Epoch 9/40 - Train Loss: 0.2786 |  Val Loss: 0.8378 
Prec: 0.3481, Rec: 0.1931, F1: 0.2160, IoU: 0.2309
Model saved with IoU: 0.2309



                                                             

Epoch 10/40 - Train Loss: 0.2747 |  Val Loss: 0.8376 
Prec: 0.3576, Rec: 0.1743, F1: 0.2008, IoU: 0.2218
No improvement. Patience counter: 1/15



                                                             

Epoch 11/40 - Train Loss: 0.2685 |  Val Loss: 0.8373 
Prec: 0.3352, Rec: 0.1710, F1: 0.2002, IoU: 0.2279
No improvement. Patience counter: 2/15



                                                             

Epoch 12/40 - Train Loss: 0.2644 |  Val Loss: 0.8373 
Prec: 0.3793, Rec: 0.1939, F1: 0.2253, IoU: 0.2434
Model saved with IoU: 0.2434



                                                             

Epoch 13/40 - Train Loss: 0.2605 |  Val Loss: 0.8375 
Prec: 0.3869, Rec: 0.1384, F1: 0.1816, IoU: 0.2009
No improvement. Patience counter: 1/15



                                                             

Epoch 14/40 - Train Loss: 0.2566 |  Val Loss: 0.8374 
Prec: 0.4069, Rec: 0.1756, F1: 0.2202, IoU: 0.2312
No improvement. Patience counter: 2/15



                                                             

Epoch 15/40 - Train Loss: 0.2523 |  Val Loss: 0.8373 
Prec: 0.3925, Rec: 0.1745, F1: 0.2136, IoU: 0.2321
No improvement. Patience counter: 3/15



                                                             

Epoch 16/40 - Train Loss: 0.2489 |  Val Loss: 0.8376 
Prec: 0.3421, Rec: 0.1817, F1: 0.2027, IoU: 0.2242
No improvement. Patience counter: 4/15



                                                             

Epoch 17/40 - Train Loss: 0.2462 |  Val Loss: 0.8374 
Prec: 0.3546, Rec: 0.1868, F1: 0.2149, IoU: 0.2369
No improvement. Patience counter: 5/15



                                                             

Epoch 18/40 - Train Loss: 0.2427 |  Val Loss: 0.8372 
Prec: 0.3792, Rec: 0.1664, F1: 0.2088, IoU: 0.2519
Model saved with IoU: 0.2519



                                                             

Epoch 19/40 - Train Loss: 0.2387 |  Val Loss: 0.8374 
Prec: 0.3678, Rec: 0.1570, F1: 0.1959, IoU: 0.2275
No improvement. Patience counter: 1/15



                                                             

Epoch 20/40 - Train Loss: 0.2380 |  Val Loss: 0.8377 
Prec: 0.3689, Rec: 0.1434, F1: 0.1816, IoU: 0.2021
No improvement. Patience counter: 2/15



                                                             

Epoch 21/40 - Train Loss: 0.2337 |  Val Loss: 0.8373 
Prec: 0.3564, Rec: 0.1628, F1: 0.1950, IoU: 0.2403
No improvement. Patience counter: 3/15



                                                             

Epoch 22/40 - Train Loss: 0.2324 |  Val Loss: 0.8375 
Prec: 0.3748, Rec: 0.1551, F1: 0.1963, IoU: 0.2202
No improvement. Patience counter: 4/15



                                                             

Epoch 23/40 - Train Loss: 0.2291 |  Val Loss: 0.8374 
Prec: 0.3717, Rec: 0.1879, F1: 0.2198, IoU: 0.2387
No improvement. Patience counter: 5/15



                                                             

Epoch 24/40 - Train Loss: 0.2269 |  Val Loss: 0.8373 
Prec: 0.4104, Rec: 0.1860, F1: 0.2246, IoU: 0.2479
No improvement. Patience counter: 6/15



                                                             

Epoch 25/40 - Train Loss: 0.2253 |  Val Loss: 0.8377 
Prec: 0.3600, Rec: 0.1463, F1: 0.1818, IoU: 0.2005
No improvement. Patience counter: 7/15



                                                             

Epoch 26/40 - Train Loss: 0.2226 |  Val Loss: 0.8373 
Prec: 0.3667, Rec: 0.1812, F1: 0.2126, IoU: 0.2462
No improvement. Patience counter: 8/15



                                                             

Epoch 27/40 - Train Loss: 0.2202 |  Val Loss: 0.8372 
Prec: 0.3774, Rec: 0.1887, F1: 0.2268, IoU: 0.2596
Model saved with IoU: 0.2596



                                                             

Epoch 28/40 - Train Loss: 0.2214 |  Val Loss: 0.8374 
Prec: 0.3582, Rec: 0.1603, F1: 0.1947, IoU: 0.2155
No improvement. Patience counter: 1/15



                                                             

Epoch 29/40 - Train Loss: 0.2179 |  Val Loss: 0.8373 
Prec: 0.3825, Rec: 0.1714, F1: 0.2107, IoU: 0.2356
No improvement. Patience counter: 2/15



                                                             

Epoch 30/40 - Train Loss: 0.2157 |  Val Loss: 0.8376 
Prec: 0.3716, Rec: 0.1924, F1: 0.2186, IoU: 0.2589
No improvement. Patience counter: 3/15



                                                             

Epoch 31/40 - Train Loss: 0.2142 |  Val Loss: 0.8373 
Prec: 0.3765, Rec: 0.1830, F1: 0.2184, IoU: 0.2506
No improvement. Patience counter: 4/15



                                                             

Epoch 32/40 - Train Loss: 0.2139 |  Val Loss: 0.8373 
Prec: 0.4124, Rec: 0.1975, F1: 0.2364, IoU: 0.2621
Model saved with IoU: 0.2621



                                                             

Epoch 33/40 - Train Loss: 0.2107 |  Val Loss: 0.8374 
Prec: 0.3645, Rec: 0.1932, F1: 0.2201, IoU: 0.2486
No improvement. Patience counter: 1/15



                                                             

Epoch 34/40 - Train Loss: 0.2102 |  Val Loss: 0.8373 
Prec: 0.4149, Rec: 0.1856, F1: 0.2293, IoU: 0.2528
No improvement. Patience counter: 2/15



                                                             

Epoch 35/40 - Train Loss: 0.2087 |  Val Loss: 0.8374 
Prec: 0.3972, Rec: 0.1541, F1: 0.1973, IoU: 0.2254
No improvement. Patience counter: 3/15



                                                             

Epoch 36/40 - Train Loss: 0.2066 |  Val Loss: 0.8373 
Prec: 0.3575, Rec: 0.2085, F1: 0.2338, IoU: 0.2555
No improvement. Patience counter: 4/15



                                                             

Epoch 37/40 - Train Loss: 0.2064 |  Val Loss: 0.8373 
Prec: 0.3963, Rec: 0.2068, F1: 0.2360, IoU: 0.2644
Model saved with IoU: 0.2644



                                                             

Epoch 38/40 - Train Loss: 0.2040 |  Val Loss: 0.8373 
Prec: 0.4042, Rec: 0.1799, F1: 0.2179, IoU: 0.2522
No improvement. Patience counter: 1/15



                                                             

Epoch 39/40 - Train Loss: 0.2040 |  Val Loss: 0.8372 
Prec: 0.4123, Rec: 0.1868, F1: 0.2264, IoU: 0.2548
No improvement. Patience counter: 2/15



                                                             

Epoch 40/40 - Train Loss: 0.2019 |  Val Loss: 0.8372 
Prec: 0.3871, Rec: 0.1752, F1: 0.2124, IoU: 0.2489
No improvement. Patience counter: 3/15





In [None]:
import os
import torch
import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image

def test_model(model, loader, iou_thresh, device, out_dir):
    os.makedirs(out_dir, exist_ok=True)

    model.eval()
    results = {'prec': [], 'rec': [], 'f1': [], 'ious': []}
    total_gt_area = 0
    covered_gt_area = 0

    with torch.no_grad():
        for img, mask, name in tqdm(loader, desc="Testing", leave=False):
            fname = name[0] if isinstance(name, (list, tuple)) else name
            img, mask = img.to(device), mask.to(device)
            pred = torch.sigmoid(model(img))
            pred_bin = (pred > 0.5).float()

            # Convert tensors to numpy arrays
            img_np = (img[0][0].cpu().numpy() * 255).astype(np.uint8)  # shape: (H, W)
            gt_np  = (mask[0][0].cpu().numpy() * 255).astype(np.uint8)
            pr_np  = (pred_bin[0][0].cpu().numpy() * 255).astype(np.uint8)

            # Bounding boxes and metrics
            boxes_gt = get_bounding_boxes(gt_np)
            boxes_pr = get_bounding_boxes(pr_np)
            p, r, f1 = compute_detection_metrics(boxes_gt, boxes_pr, iou_thresh)
            iou = mask_iou(gt_np, pr_np)

            # Track metrics
            results['prec'].append(p)
            results['rec'].append(r)
            results['f1'].append(f1)
            results['ious'].append(iou)

            # Area-based Box Coverage Calculation
            for gb in boxes_gt:
                x, y, w, h = gb
                gt_area = w * h
                total_gt_area += gt_area

                max_iou = 0
                for pb in boxes_pr:
                    iou_val = box_iou(pb, gb)
                    if iou_val > max_iou:
                        max_iou = iou_val

                if max_iou >= iou_thresh:
                    covered_gt_area += gt_area

            # Create composite image (4-panel BGR)
            h, w = img_np.shape
            spacing = 10
            comp_w = w * 4 + spacing * 3
            comp_h = h
            composite = np.ones((comp_h, comp_w, 3), dtype=np.uint8) * 255

            # Convert grayscale to BGR
            img_color = cv2.cvtColor(img_np, cv2.COLOR_GRAY2BGR)
            gt_overlay = img_color.copy()
            pr_overlay = img_color.copy()
            both_overlay = img_color.copy()

            # Draw boxes
            for x, y, ww, hh in boxes_gt:
                cv2.rectangle(gt_overlay, (x, y), (x + ww, y + hh), (0, 255, 0), 2)
                cv2.rectangle(both_overlay, (x, y), (x + ww, y + hh), (0, 255, 0), 2)
            for x, y, ww, hh in boxes_pr:
                cv2.rectangle(pr_overlay, (x, y), (x + ww, y + hh), (255, 0, 0), 2)
                cv2.rectangle(both_overlay, (x, y), (x + ww, y + hh), (255, 0, 0), 2)

            # Assemble composite
            composite[0:h, 0:w] = img_color
            composite[0:h, w + spacing:w * 2 + spacing] = gt_overlay
            composite[0:h, w * 2 + spacing * 2:w * 3 + spacing * 2] = pr_overlay
            composite[0:h, w * 3 + spacing * 3:w * 4 + spacing * 3] = both_overlay

            # Save composite image
            Image.fromarray(composite).save(os.path.join(out_dir, fname))

    # Compute area-based Box Coverage
    box_coverage = covered_gt_area / (total_gt_area + 1e-6)

    # Summary metrics
    summary = {
        'Precision': np.mean(results['prec']),
        'Recall': np.mean(results['rec']),
        'F1': np.mean(results['f1']),
        'IoU': np.mean(results['ious']) if results['ious'] else 0.0,
        'Box Coverage': box_coverage
    }

    return summary

# Test
model.load_state_dict(torch.load(save_path))
summary = test_model(model, final_loader, iou_thresh, device, out_dir=os.path.join('.', f"{MODEL_NAME}_Result"))
print("Test Summary:", summary)


  model.load_state_dict(torch.load(save_path))
                                                                                

Test Summary: {'Precision': 0.3541874737398547, 'Recall': 0.2522272299324383, 'F1': 0.23466042208540522, 'IoU': 0.23988951882244078, 'Box Coverage': 0.1878297493758683}


