In [1]:
import os
import glob
from PIL import Image
import random
import numpy as np
import cv2
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seeding(42)

In [2]:
MODEL_NAME = "BB_Custom_MODEL_Dice_BCE"
batch_size = 1

In [None]:
class DiceBCELoss(nn.Module):
    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.bce_fn = nn.BCEWithLogitsLoss(reduction='none')  #  changed

    def forward(self, logits, targets, reduction='mean'):
        # BCE loss map (per-pixel)
        bce_loss = self.bce_fn(logits, targets)  # shape: (B, 1, H, W)

        # Dice loss (per-batch)
        probs = torch.sigmoid(logits)
        batch_size = probs.shape[0]
        dice_losses = []
        for i in range(batch_size):
            p = probs[i].view(-1)
            g = targets[i].view(-1)
            inter = (p * g).sum()
            dice = 1 - (2*inter + self.smooth) / (p.sum() + g.sum() + self.smooth)
            dice_losses.append(dice)
        dice_loss = torch.stack(dice_losses).mean()

        if reduction == 'none':
            # Used for pixel-wise loss map
            return self.alpha + (1 - self.alpha) * bce_loss
        else:
            return self.alpha * dice_loss + (1 - self.alpha) * bce_loss.mean()


In [4]:
# -------- Dataset Definition --------
class HemorrhageDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.image_paths = sorted(glob.glob(os.path.join(images_dir, '*')))
        self.mask_paths = sorted(glob.glob(os.path.join(masks_dir, '*')))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')

        # Extract only the green channel as a single-channel image
        img_np = np.array(img)[:, :, 1]  # green channel
        img = Image.fromarray(img_np)   # convert back to PIL Image (mode 'L')

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        mask = (mask > 0).float()

        filename = os.path.basename(self.image_paths[idx])
        return img, mask, filename



In [None]:

#  DoubleConv and TripleConv

class TripleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 3 times."""
    def __init__(self, in_c, mid1_c, mid2_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid1_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid1_c)
        self.conv2 = nn.Conv2d(mid1_c, mid2_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(mid2_c)
        self.conv3 = nn.Conv2d(mid2_c, out_c, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x

class DoubleConv(nn.Module):
    """Conv -> BN -> ReLU repeated 2 times."""
    def __init__(self, in_c, mid_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, mid_c, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(mid_c)
        self.conv2 = nn.Conv2d(mid_c, out_c, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


#  UNetRetina (with additional subsampling and concatenation)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # ENCODER 
        # Block 1: 3 conv -> (32, 32, 64)
        self.down1 = TripleConv(
            in_c=1,       # green channel only
            mid1_c=32,
            mid2_c=32,
            out_c=64
        )
        # Block 2: 3 conv -> (64, 64, 128)
        self.down2 = TripleConv(
            in_c=64,
            mid1_c=64,
            mid2_c=64,
            out_c=128
        )
        # Block 3: 2 conv -> (128, 128, 256)
        self.down3 = DoubleConv(
            in_c=128,
            mid_c=128,
            out_c=256
        )
        # Block 4: 2 conv -> (256, 256, 256)
        self.down4 = DoubleConv(
            in_c=256,
            mid_c=256,
            out_c=256
        )
        
        self.pool = nn.MaxPool2d(2, 2)
        
    
        # Bottleneck: 2 conv -> (256 -> 512 -> 256)

        self.bottleneck = DoubleConv(
            in_c=256,
            mid_c=512,
            out_c=256
        )
        
        # DECODER 
        # Each decoder block: upsample, concat skip connection, then decode.
        self.up4  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = DoubleConv(in_c=256+256, mid_c=256, out_c=256)
        
        self.up3  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = DoubleConv(in_c=256+256, mid_c=128, out_c=128)
        
        self.up2  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = TripleConv(in_c=128+128, mid1_c=64, mid2_c=64, out_c=64)
        
        self.up1  = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = TripleConv(in_c=64+64, mid1_c=32, mid2_c=32, out_c=32)
        
        #  ADDITIONAL SUBSAMPLING & CONCATENATION 
        self.final_pool = nn.MaxPool2d(2, 2)  # Additional subsampling step
        self.final_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # After concatenation: 32 channels (from decoder) + 1 channel (from input) = 33 channels
        self.out_conv = nn.Conv2d(33, 1, kernel_size=1)
        

    def forward(self, x):
        # -------- Encoder --------
        input_image = x
        
        # Block 1
        x1 = self.down1(x)    #  64 channels
        x1p = self.pool(x1)   # subsampled
        
        # Block 2
        x2 = self.down2(x1p)  # 128 channels
        x2p = self.pool(x2)   # subsampled
        
        # Block 3
        x3 = self.down3(x2p)  # 256 channels
        x3p = self.pool(x3)   # subsampled
        
        # Block 4
        x4 = self.down4(x3p)  # 256 channels
        x4p = self.pool(x4)   # subsampled
        
        #  Bottleneck 
        xb = self.bottleneck(x4p)  # 256 -> 512 -> 256
        
        #  Decoder 
        xd4 = self.up4(xb)               
        xd4 = torch.cat([x4, xd4], dim=1)  
        xd4 = self.dec4(xd4)             
        
        xd3 = self.up3(xd4)              
        xd3 = torch.cat([x3, xd3], dim=1) 
        xd3 = self.dec3(xd3)             
        
        xd2 = self.up2(xd3)              
        xd2 = torch.cat([x2, xd2], dim=1) 
        xd2 = self.dec2(xd2)             
        
        xd1 = self.up1(xd2)              
        xd1 = torch.cat([x1, xd1], dim=1) 
        xd1 = self.dec1(xd1)             
        
        # Additional Subsampling & Concatenation 
        xd1_sub = self.final_pool(xd1)          # [B,32,H/2,W/2]
        input_sub = self.final_pool(input_image) # [B,1,H/2,W/2]
        final_cat = torch.cat([xd1_sub, input_sub], dim=1)  # [B,33,H/2,W/2]
        final_up = self.final_upsample(final_cat)  # [B,33,H,W]
        
        out = self.out_conv(final_up)  # [B,1,H,W]
        
        return out


In [None]:
# Utilities 
def get_bounding_boxes(binary_mask):
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return [cv2.boundingRect(cnt) for cnt in contours]

def box_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[0]+boxA[2], boxB[0]+boxB[2])
    yB = min(boxA[1]+boxA[3], boxB[1]+boxB[3])
    interW = max(0, xB - xA)
    interH = max(0, yB - yA)
    interArea = interW * interH
    areaA = boxA[2] * boxA[3]
    areaB = boxB[2] * boxB[3]
    return interArea / float(areaA + areaB - interArea + 1e-6)

def compute_detection_metrics(gt_boxes, pred_boxes, iou_thresh=0.5):
    matched_gt = set()
    tp = 0
    for pb in pred_boxes:
        best_iou, best_j = 0, -1
        for j, gb in enumerate(gt_boxes):
            if j in matched_gt: continue
            iou = box_iou(pb, gb)
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_iou >= iou_thresh:
            tp += 1; matched_gt.add(best_j)
    fp = len(pred_boxes) - tp
    fn = len(gt_boxes) - tp
    prec = tp / (tp + fp + 1e-6)
    rec  = tp / (tp + fn + 1e-6)
    f1   = 2*prec*rec / (prec + rec + 1e-6)
    return prec, rec, f1

def mask_iou(gt, pred):
    gt_bool = gt.astype(bool)
    pred_bool = pred.astype(bool)
    inter = np.logical_and(gt_bool, pred_bool).sum()
    union = np.logical_or(gt_bool, pred_bool).sum()
    return inter / (union + 1e-6)

In [None]:
# Training Function 
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for imgs, masks, _ in tqdm(loader, desc="Training", leave=False):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(imgs)

        # Apply pixel-wise loss weighting based on bounding boxes
        weights = torch.ones_like(masks)

        for i in range(masks.size(0)):
            # Convert mask to numpy and extract boxes
            mask_np = (masks[i][0].cpu().numpy() * 255).astype(np.uint8)
            boxes = get_bounding_boxes(mask_np)
            
            # Create weight map for that sample
            for x, y, w, h in boxes:
                weights[i, 0, y:y+h, x:x+w] = 3.0  # Weight 3x inside the box

        # Compute loss with weights
        loss = criterion(preds, masks)
        weighted_loss = (loss * weights).mean()

        weighted_loss.backward()
        optimizer.step()
        running_loss += weighted_loss.item()

    return running_loss / len(loader)


# Validation Function 
def validate(model, loader, criterion, iou_thresh, device):
    model.eval()
    total_loss = 0.0
    metrics = {'prec': [], 'rec': [], 'f1': [], 'iou': []}
    with torch.no_grad():
        for img, mask, _ in tqdm(loader, desc="Validating", leave=False):
            img, mask = img.to(device), mask.to(device)
            pred = torch.sigmoid(model(img))
            loss = criterion(pred, mask)
            total_loss += loss.item()

            pred_bin = (pred > 0.5).float()
            gt_np = (mask[0][0].cpu().numpy()*255).astype(np.uint8)
            pr_np = (pred_bin[0][0].cpu().numpy()*255).astype(np.uint8)
            boxes_gt = get_bounding_boxes(gt_np)
            boxes_pr = get_bounding_boxes(pr_np)
            p, r, f1 = compute_detection_metrics(boxes_gt, boxes_pr, iou_thresh)
            j = mask_iou(gt_np, pr_np)
            metrics['prec'].append(p)
            metrics['rec'].append(r)
            metrics['f1'].append(f1)
            metrics['iou'].append(j)

    avg_loss = total_loss / len(loader)
    return avg_loss, {k: np.mean(v) for k, v in metrics.items()}




In [None]:
#  Setup & Execution (Notebook) 
# Constants
dataset_dir = './final_dataset'
lr = 1e-4
iou_thresh = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transforms
tf = transforms.Compose([
    transforms.ToTensor(),
])

# Datasets & Loaders
train_img = os.path.join(dataset_dir, 'train', 'images')
train_mask = os.path.join(dataset_dir, 'train', 'masks')
test_img = os.path.join(dataset_dir, 'test', 'images')
test_mask = os.path.join(dataset_dir, 'test', 'masks')



train_ds = HemorrhageDataset(train_img, train_mask, transform=tf)
full_test = HemorrhageDataset(test_img, test_mask, transform=tf)
val_size = len(full_test) // 2
val_ds, final_ds = random_split(full_test, [val_size, len(full_test) - val_size])

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
final_loader = DataLoader(final_ds, batch_size=1, shuffle=False)

# Model, Loss, Optimizer
model = UNet().to(device)
criterion = DiceBCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training Loop
num_epochs = 40
best_iou = -1.0
os.makedirs(f"{MODEL_NAME}", exist_ok=True)
save_path = os.path.join(f"{MODEL_NAME}", f"{MODEL_NAME}.pth")

In [10]:
best_iou = -1.0
patience = 10
patience_counter = 0

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_metrics = validate(model, val_loader, criterion, iou_thresh, device)

    print(f"Epoch {epoch}/{num_epochs} - Train Loss: {train_loss:.4f} | ",
          f"Val Loss: {val_loss:.4f} \nPrec: {val_metrics['prec']:.4f},",
          f"Rec: {val_metrics['rec']:.4f}, F1: {val_metrics['f1']:.4f}, IoU: {val_metrics['iou']:.4f}")

    if val_metrics['iou'] > best_iou:
        best_iou = val_metrics['iou']
        torch.save(model.state_dict(), save_path)
        print(f"Model saved with IoU: {best_iou:.4f}")
        patience_counter = 0  # reset if improved
    else:
        patience_counter += 1
        print(f"No improvement. Patience counter: {patience_counter}/{patience}")

    if patience_counter >= patience:
        print(f"Early stopping triggered after {patience} epochs without improvement.")
        break

    print()


                                                           

Epoch 1/40 - Train Loss: 0.5832 |  Val Loss: 0.8461 
Prec: 0.0817, Rec: 0.3168, F1: 0.1171, IoU: 0.1994
Model saved with IoU: 0.1994



                                                           

Epoch 2/40 - Train Loss: 0.4689 |  Val Loss: 0.8361 
Prec: 0.0984, Rec: 0.3807, F1: 0.1428, IoU: 0.1940
No improvement. Patience counter: 1/10



                                                           

Epoch 3/40 - Train Loss: 0.3815 |  Val Loss: 0.8306 
Prec: 0.1625, Rec: 0.2904, F1: 0.1994, IoU: 0.2881
Model saved with IoU: 0.2881



                                                           

Epoch 4/40 - Train Loss: 0.3311 |  Val Loss: 0.8298 
Prec: 0.2327, Rec: 0.2897, F1: 0.2511, IoU: 0.3077
Model saved with IoU: 0.3077



                                                           

Epoch 5/40 - Train Loss: 0.2926 |  Val Loss: 0.8299 
Prec: 0.1961, Rec: 0.2983, F1: 0.2255, IoU: 0.3078
Model saved with IoU: 0.3078



                                                           

Epoch 6/40 - Train Loss: 0.2711 |  Val Loss: 0.8298 
Prec: 0.2464, Rec: 0.3466, F1: 0.2777, IoU: 0.3069
No improvement. Patience counter: 1/10



                                                           

Epoch 7/40 - Train Loss: 0.2438 |  Val Loss: 0.8298 
Prec: 0.2486, Rec: 0.3078, F1: 0.2649, IoU: 0.3021
No improvement. Patience counter: 2/10



                                                           

Epoch 8/40 - Train Loss: 0.2290 |  Val Loss: 0.8297 
Prec: 0.2519, Rec: 0.2805, F1: 0.2609, IoU: 0.2970
No improvement. Patience counter: 3/10



                                                           

Epoch 9/40 - Train Loss: 0.2056 |  Val Loss: 0.8302 
Prec: 0.2694, Rec: 0.2278, F1: 0.2401, IoU: 0.2664
No improvement. Patience counter: 4/10



                                                           

Epoch 10/40 - Train Loss: 0.2127 |  Val Loss: 0.8301 
Prec: 0.2460, Rec: 0.2672, F1: 0.2521, IoU: 0.2830
No improvement. Patience counter: 5/10



                                                           

Epoch 11/40 - Train Loss: 0.1808 |  Val Loss: 0.8295 
Prec: 0.1828, Rec: 0.3603, F1: 0.2323, IoU: 0.3234
Model saved with IoU: 0.3234



                                                           

Epoch 12/40 - Train Loss: 0.1715 |  Val Loss: 0.8296 
Prec: 0.2468, Rec: 0.2848, F1: 0.2530, IoU: 0.3005
No improvement. Patience counter: 1/10



                                                           

Epoch 13/40 - Train Loss: 0.1800 |  Val Loss: 0.8298 
Prec: 0.2484, Rec: 0.3346, F1: 0.2790, IoU: 0.3112
No improvement. Patience counter: 2/10



                                                           

Epoch 14/40 - Train Loss: 0.1593 |  Val Loss: 0.8294 
Prec: 0.2030, Rec: 0.3112, F1: 0.2360, IoU: 0.3140
No improvement. Patience counter: 3/10



                                                           

Epoch 15/40 - Train Loss: 0.1566 |  Val Loss: 0.8301 
Prec: 0.1963, Rec: 0.2381, F1: 0.1974, IoU: 0.2435
No improvement. Patience counter: 4/10



                                                           

Epoch 16/40 - Train Loss: 0.1447 |  Val Loss: 0.8298 
Prec: 0.2611, Rec: 0.2680, F1: 0.2600, IoU: 0.2708
No improvement. Patience counter: 5/10



                                                           

Epoch 17/40 - Train Loss: 0.1677 |  Val Loss: 0.8303 
Prec: 0.2531, Rec: 0.2645, F1: 0.2389, IoU: 0.2616
No improvement. Patience counter: 6/10



                                                           

Epoch 18/40 - Train Loss: 0.1516 |  Val Loss: 0.8297 
Prec: 0.2274, Rec: 0.2565, F1: 0.2276, IoU: 0.2952
No improvement. Patience counter: 7/10



                                                           

Epoch 19/40 - Train Loss: 0.1381 |  Val Loss: 0.8298 
Prec: 0.2722, Rec: 0.2877, F1: 0.2737, IoU: 0.2829
No improvement. Patience counter: 8/10



                                                           

Epoch 20/40 - Train Loss: 0.1277 |  Val Loss: 0.8296 
Prec: 0.3468, Rec: 0.3636, F1: 0.3246, IoU: 0.3236
Model saved with IoU: 0.3236



                                                           

Epoch 21/40 - Train Loss: 0.1278 |  Val Loss: 0.8299 
Prec: 0.2686, Rec: 0.3222, F1: 0.2876, IoU: 0.3011
No improvement. Patience counter: 1/10



                                                           

Epoch 22/40 - Train Loss: 0.1410 |  Val Loss: 0.8303 
Prec: 0.2047, Rec: 0.2869, F1: 0.2218, IoU: 0.2737
No improvement. Patience counter: 2/10



                                                           

Epoch 23/40 - Train Loss: 0.1262 |  Val Loss: 0.8295 
Prec: 0.2408, Rec: 0.3518, F1: 0.2791, IoU: 0.3110
No improvement. Patience counter: 3/10



                                                           

Epoch 24/40 - Train Loss: 0.1141 |  Val Loss: 0.8297 
Prec: 0.2160, Rec: 0.3042, F1: 0.2462, IoU: 0.2964
No improvement. Patience counter: 4/10



                                                           

Epoch 25/40 - Train Loss: 0.1124 |  Val Loss: 0.8296 
Prec: 0.2189, Rec: 0.3050, F1: 0.2492, IoU: 0.3031
No improvement. Patience counter: 5/10



                                                           

Epoch 26/40 - Train Loss: 0.1332 |  Val Loss: 0.8295 
Prec: 0.3070, Rec: 0.2778, F1: 0.2852, IoU: 0.3058
No improvement. Patience counter: 6/10



                                                           

Epoch 27/40 - Train Loss: 0.1084 |  Val Loss: 0.8296 
Prec: 0.1808, Rec: 0.2858, F1: 0.2079, IoU: 0.2953
No improvement. Patience counter: 7/10



                                                           

Epoch 28/40 - Train Loss: 0.1069 |  Val Loss: 0.8293 
Prec: 0.2116, Rec: 0.3257, F1: 0.2523, IoU: 0.3208
No improvement. Patience counter: 8/10



                                                           

Epoch 29/40 - Train Loss: 0.1048 |  Val Loss: 0.8300 
Prec: 0.2263, Rec: 0.3137, F1: 0.2574, IoU: 0.2839
No improvement. Patience counter: 9/10



                                                           

Epoch 30/40 - Train Loss: 0.1097 |  Val Loss: 0.8335 
Prec: 0.0978, Rec: 0.3785, F1: 0.1403, IoU: 0.1663
No improvement. Patience counter: 10/10
Early stopping triggered after 10 epochs without improvement.




In [None]:
import os
import torch
import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image

def test_model(model, loader, iou_thresh, device, out_dir):
    os.makedirs(out_dir, exist_ok=True)

    model.eval()
    results = {'prec': [], 'rec': [], 'f1': [], 'ious': []}
    total_gt_area = 0
    covered_gt_area = 0

    with torch.no_grad():
        for img, mask, name in tqdm(loader, desc="Testing", leave=False):
            fname = name[0] if isinstance(name, (list, tuple)) else name
            img, mask = img.to(device), mask.to(device)
            pred = torch.sigmoid(model(img))
            pred_bin = (pred > 0.5).float()

            # Convert tensors to numpy arrays
            img_np = (img[0][0].cpu().numpy() * 255).astype(np.uint8)  # shape: (H, W)
            gt_np  = (mask[0][0].cpu().numpy() * 255).astype(np.uint8)
            pr_np  = (pred_bin[0][0].cpu().numpy() * 255).astype(np.uint8)

            # Bounding boxes and metrics
            boxes_gt = get_bounding_boxes(gt_np)
            boxes_pr = get_bounding_boxes(pr_np)
            p, r, f1 = compute_detection_metrics(boxes_gt, boxes_pr, iou_thresh)
            iou = mask_iou(gt_np, pr_np)

            # Track metrics
            results['prec'].append(p)
            results['rec'].append(r)
            results['f1'].append(f1)
            results['ious'].append(iou)

            # Area-based Box Coverage Calculation
            for gb in boxes_gt:
                x, y, w, h = gb
                gt_area = w * h
                total_gt_area += gt_area

                max_iou = 0
                for pb in boxes_pr:
                    iou_val = box_iou(pb, gb)
                    if iou_val > max_iou:
                        max_iou = iou_val

                if max_iou >= iou_thresh:
                    covered_gt_area += gt_area

            # Create composite image (4-panel BGR)
            h, w = img_np.shape
            spacing = 10
            comp_w = w * 4 + spacing * 3
            comp_h = h
            composite = np.ones((comp_h, comp_w, 3), dtype=np.uint8) * 255

            # Convert grayscale to BGR
            img_color = cv2.cvtColor(img_np, cv2.COLOR_GRAY2BGR)
            gt_overlay = img_color.copy()
            pr_overlay = img_color.copy()
            both_overlay = img_color.copy()

            # Draw boxes
            for x, y, ww, hh in boxes_gt:
                cv2.rectangle(gt_overlay, (x, y), (x + ww, y + hh), (0, 255, 0), 2)
                cv2.rectangle(both_overlay, (x, y), (x + ww, y + hh), (0, 255, 0), 2)
            for x, y, ww, hh in boxes_pr:
                cv2.rectangle(pr_overlay, (x, y), (x + ww, y + hh), (255, 0, 0), 2)
                cv2.rectangle(both_overlay, (x, y), (x + ww, y + hh), (255, 0, 0), 2)

            # Assemble composite
            composite[0:h, 0:w] = img_color
            composite[0:h, w + spacing:w * 2 + spacing] = gt_overlay
            composite[0:h, w * 2 + spacing * 2:w * 3 + spacing * 2] = pr_overlay
            composite[0:h, w * 3 + spacing * 3:w * 4 + spacing * 3] = both_overlay

            # Save composite image
            Image.fromarray(composite).save(os.path.join(out_dir, fname))

    # Compute area-based Box Coverage
    box_coverage = covered_gt_area / (total_gt_area + 1e-6)

    # Summary metrics
    summary = {
        'Precision': np.mean(results['prec']),
        'Recall': np.mean(results['rec']),
        'F1': np.mean(results['f1']),
        'IoU': np.mean(results['ious']) if results['ious'] else 0.0,
        'Box Coverage': box_coverage
    }

    return summary

# Test
model.load_state_dict(torch.load(save_path))
summary = test_model(model, final_loader, iou_thresh, device, out_dir=os.path.join('.', f"{MODEL_NAME}_Result"))
print("Test Summary:", summary)


                                                        

Test Summary: {'Precision': 0.28094175126760934, 'Recall': 0.34385879927422897, 'F1': 0.3012129937553158, 'IoU': 0.4164911056037921, 'Box Coverage': 0.5668325797849816}


