Import Libraries

In [None]:
import os
import re
import json
import glob
import shutil
import random
from pathlib import Path
from collections import Counter
import numpy as np
import pandas as pd
import cv2
from PIL import Image, ImageDraw
import requests
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import segmentation_models_pytorch as smp
from pycocotools import mask as mask_utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from scipy.ndimage import zoom
from scipy.spatial.distance import cdist
from numpy.polynomial.polynomial import polyvander2d, polyval2d
from skimage.filters import threshold_otsu
from skimage.morphology import binary_opening, rectangle, remove_small_objects
from skimage.measure import label, regionprops
from tqdm import tqdm


Preparing data loaders using PyTorch tensors

In [None]:
def extract_year_from_image_id(image_id):
    """Return the four-digit year encoded in an image ID.

    The year is taken from the first four characters of the second
    ``-``-separated field, e.g. ``"img-20150314"`` -> ``2015``.

    Args:
        image_id: Image identifier string containing at least one ``-``.

    Returns:
        The year as an ``int``.

    Raises:
        ValueError: If the ID is not a string, has no second field, or that
            field does not start with an integer. The original error is
            attached as ``__cause__``.
    """
    try:
        return int(image_id.split('-')[1][:4])
    except (AttributeError, IndexError, TypeError, ValueError) as e:
        # Chain explicitly so the root cause stays visible in the traceback.
        raise ValueError(f"Failed to extract year from image ID '{image_id}': {e}") from e

def split_dataset_by_year(dataset, train_years, val_years, test_years):
    """Partition dataset positions into train/val/test lists by image year.

    Each entry of ``dataset.image_ids`` is mapped to a year with
    ``extract_year_from_image_id`` and its position goes to the first split
    (train, then val, then test) whose year collection contains that year.
    Positions whose year matches no split are dropped.

    Args:
        dataset: Object exposing an ``image_ids`` sequence.
        train_years: Container of years assigned to training.
        val_years: Container of years assigned to validation.
        test_years: Container of years assigned to testing.

    Returns:
        Tuple ``(train_indices, val_indices, test_indices)`` of index lists.
    """
    buckets = ([], [], [])
    year_groups = (train_years, val_years, test_years)

    for position, image_id in enumerate(dataset.image_ids):
        year = extract_year_from_image_id(image_id)
        # First matching split wins, mirroring an if/elif chain.
        for bucket, years in zip(buckets, year_groups):
            if year in years:
                bucket.append(position)
                break

    train_indices, val_indices, test_indices = buckets
    return train_indices, val_indices, test_indices

In [None]:
# Paired image + mask transform wrapper
class ImageMaskTransform:
    """Apply matching preprocessing to an image and its segmentation mask.

    The image is resized to ``image_size``, converted to a tensor and
    normalized with mean 0.5 / std 0.5. The mask is resized to the same
    spatial size with nearest-neighbour interpolation and converted to a
    tensor; it is not normalized.

    Args:
        image_size: Target ``(height, width)`` for both image and mask.
    """

    def __init__(self, image_size=(512, 512)):
        self.image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # The mask must end up with the same spatial size as the image,
        # otherwise loss computation and batching fail on a shape mismatch.
        # Nearest-neighbour avoids blurring label edges, and the resize is a
        # no-op for masks that already have `image_size`.
        self.mask_transform = transforms.Compose([
            transforms.Resize(
                image_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.ToTensor(),
        ])

    def __call__(self, image, mask):
        """Return the transformed ``(image, mask)`` pair."""
        image = self.image_transform(image)
        mask = self.mask_transform(mask)
        return image, mask

# Custom Dataset
class FilamentDataset(Dataset):
    """Dataset of grayscale images paired with same-named PNG masks.

    Images are discovered recursively under ``image_dir``; the image ID is
    the file name without its extension. The mask for ID ``X`` is read from
    ``<mask_dir>/X.png``.

    Args:
        image_dir: Root directory searched recursively for images.
        mask_dir: Directory containing one ``<image_id>.png`` mask per image.
        transform: Optional callable ``(image, mask) -> (image, mask)``
            applied to the PIL pair. When omitted, both are converted to
            float tensors of shape ``[1, H, W]`` scaled to ``[0, 1]``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_index = self.index_images_by_id(image_dir)
        self.mask_dir = mask_dir
        self.image_ids = sorted(self.image_index.keys())
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, binary_mask, image_id)`` for position ``idx``."""
        image_id = self.image_ids[idx]
        img_path = self.image_index[image_id]
        mask_path = os.path.join(self.mask_dir, f"{image_id}.png")
        image = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        else:
            # Without a transform the PIL objects cannot be thresholded or
            # batched, so fall back to a plain tensor conversion.
            image = self._to_tensor(image)
            mask = self._to_tensor(mask)
        mask = (mask > 0).float()  # binarize mask
        return image, mask, image_id

    @staticmethod
    def _to_tensor(pil_image):
        """Convert an 8-bit single-channel image to a ``[1, H, W]`` float tensor in [0, 1]."""
        array = np.asarray(pil_image, dtype=np.float32) / 255.0
        return torch.from_numpy(array).unsqueeze(0)

    @staticmethod
    def index_images_by_id(image_dir, extensions=(".jpg", ".jpeg", ".png")):
        """Map image ID (file stem) to file path for every image under ``image_dir``.

        Matching on ``extensions`` is case-insensitive. Directories and files
        are visited in sorted order, so when several files share an ID the
        one visited last wins, reproducibly across runs and filesystems.
        """
        image_index = {}
        for root, dirs, files in os.walk(image_dir):
            dirs.sort()  # in-place sort fixes the order os.walk descends in
            for file in sorted(files):
                if file.lower().endswith(extensions):
                    image_id = os.path.splitext(file)[0]
                    image_index[image_id] = os.path.join(root, file)
        return image_index


In [None]:
image_dir = ""
mask_dir = ""

# Index the dataset once (no transform needed here) and partition it by year
full_dataset = FilamentDataset(image_dir=image_dir, mask_dir=mask_dir, transform=None)
train_years = set(range(2011, 2019))  # 2011 through 2018 inclusive
val_years = {2019, 2020}
test_years = {2021, 2022}
train_idx, val_idx, test_idx = split_dataset_by_year(
    full_dataset, train_years, val_years, test_years
)

# Toggle for quick prototyping
use_small_subset = False  # True: train/evaluate on a small random sample

if use_small_subset:
    RANDOM_SEED = 42
    random.seed(RANDOM_SEED)
    # Sample in train -> val -> test order so the seeded draws are reproducible.
    train_idx, val_idx, test_idx = (
        random.sample(indices, min(cap, len(indices)))
        for indices, cap in ((train_idx, 200), (val_idx, 50), (test_idx, 50))
    )

# Transform shared by every split
shared_transform = ImageMaskTransform(image_size=(512, 512))

In [None]:
# Subset Data
# A single transformed dataset backs all three splits. Building it once avoids
# re-walking the image directory for every split; the indices stay valid
# because image IDs are sorted the same way on every construction.
transformed_dataset = FilamentDataset(image_dir, mask_dir, transform=shared_transform)
train_dataset = Subset(transformed_dataset, train_idx)
val_dataset = Subset(transformed_dataset, val_idx)
test_dataset = Subset(transformed_dataset, test_idx)

# DataLoaders: only the training loader shuffles
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=2)

U-Net

In [None]:
# Basic Double Convolution (Conv -> ReLU -> Conv -> ReLU)
class DoubleConv(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First conv maps in -> out, second conv maps out -> out.
        for source_channels in (in_channels, out_channels):
            stages.append(nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both conv/ReLU stages."""
        return self.conv(x)

# Vanilla U-Net
class UNet(nn.Module):
    """Plain U-Net built from ``DoubleConv`` stages.

    The encoder applies one ``DoubleConv`` per entry in ``features``, each
    followed by 2x2 max pooling. The bottleneck doubles the last feature
    width. The decoder mirrors the encoder with a transposed convolution,
    concatenation with the matching skip connection, and a ``DoubleConv``.
    A final 1x1 convolution maps to ``out_channels`` (raw logits; no
    activation is applied here).

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        features: Feature widths of the encoder stages, shallowest first.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # An immutable default avoids sharing one list across instances;
        # callers may still pass any sequence.
        features = list(features)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling path
        ch = in_channels
        for feature in features:
            self.downs.append(DoubleConv(ch, feature))
            ch = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Upsampling path: `ups` alternates [up-conv, double-conv, ...]
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        # Final output conv
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the output map for input ``x``."""
        skip_connections = []

        # Encoder: keep each pre-pooling activation as a skip connection
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder: consume the skips deepest-first
        skip_connections = skip_connections[::-1]
        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x)  # up-conv
            skip = skip_connections[i // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; only the spatial dimensions are relevant here.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])

            # Concatenate skip connection along the channel axis
            x = torch.cat((skip, x), dim=1)
            x = self.ups[i + 1](x)  # double conv

        return self.final_conv(x)


In [None]:
# Soft IoU
def soft_iou(pred, target, smooth=1.0):
    """Return the batch-mean soft IoU between logits and a target mask.

    ``pred`` is passed through a sigmoid; intersection and union are summed
    over dimensions 1-3 per sample, ``smooth`` is added to both, and the
    per-sample ratios are averaged.

    Args:
        pred: 4-D tensor of logits.
        target: Tensor of the same shape as ``pred``.
        smooth: Constant added to numerator and denominator.

    Returns:
        The mean soft IoU as a Python ``float``.
    """
    reduce_dims = (1, 2, 3)
    probs = torch.sigmoid(pred)
    overlap = probs * target
    intersection = overlap.sum(dim=reduce_dims)
    union = (probs + target - overlap).sum(dim=reduce_dims)
    per_sample = (intersection + smooth) / (union + smooth)
    return per_sample.mean().item()

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.BCEWithLogitsLoss()  # expects raw logits; the model applies no sigmoid

if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("Train or validation loader is empty.")

# Per-epoch histories, plotted after training
train_losses, val_losses = [], []
train_ious, val_ious = [], []

epochs = 50
vis_every = 1  # show one random validation example every `vis_every` epochs

for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    train_loss, train_iou = 0.0, 0.0
    for img, mask, _ in train_loader:
        img, mask = img.to(device), mask.to(device)
        pred = model(img)
        loss = loss_fn(pred, mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # The metric is for monitoring only; detach so it is computed
        # without autograd tracking.
        train_iou += soft_iou(pred.detach(), mask)

    # ---- Validation pass ----
    model.eval()
    val_loss, val_iou = 0.0, 0.0
    with torch.no_grad():
        for img, mask, _ in val_loader:
            img, mask = img.to(device), mask.to(device)
            pred = model(img)
            val_loss += loss_fn(pred, mask).item()
            val_iou += soft_iou(pred, mask)

    # Averages are per batch, not per sample
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    avg_train_iou = train_iou / len(train_loader)
    avg_val_iou = val_iou / len(val_loader)

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_ious.append(avg_train_iou)
    val_ious.append(avg_val_iou)

    print(f"Epoch {epoch+1}/{epochs} - "
          f"Train Loss: {avg_train_loss:.4f}, IoU: {avg_train_iou:.4f} | "
          f"Val Loss: {avg_val_loss:.4f}, IoU: {avg_val_iou:.4f}")

    # ---- Qualitative check on one random validation sample ----
    if epoch % vis_every == 0:
        # The model is still in eval mode from the validation pass.
        with torch.no_grad():
            random_idx = np.random.randint(0, len(val_dataset))
            image, mask, _ = val_dataset[random_idx]
            image = image.unsqueeze(0).to(device)
            mask = mask.unsqueeze(0).to(device)
            pred = model(image)
            pred_bin = (torch.sigmoid(pred[0]) > 0.5).float().cpu().numpy()[0]
            mask_vis = mask[0].cpu().numpy()[0]
            img_vis = image[0].cpu().numpy()[0]

            plt.figure(figsize=(12, 3))
            plt.subplot(1, 3, 1); plt.imshow(img_vis, cmap='gray'); plt.title('Input Image')
            plt.subplot(1, 3, 2); plt.imshow(mask_vis, cmap='gray'); plt.title('Ground Truth')
            plt.subplot(1, 3, 3); plt.imshow(pred_bin, cmap='gray'); plt.title('Prediction (Binary)')
            plt.tight_layout(); plt.show()

In [None]:

# Plot loss and IoU histories side by side
sns.set_context("paper", font_scale=1.4)
sns.set_style("white")
colors = sns.color_palette("Dark2", 4)

fig, axes = plt.subplots(1, 2, figsize=(14, 5), dpi=150)

# (axis, title, y-label, [(series, legend label, marker, colour), ...])
panels = [
    (axes[0], 'Loss vs Epoch', 'Loss', [
        (train_losses, 'Train Loss', 'D', colors[0]),
        (val_losses, 'Val Loss', 'o', colors[1]),
    ]),
    (axes[1], 'IoU vs Epoch', 'IoU', [
        (train_ious, 'Train IoU', 'D', colors[2]),
        (val_ious, 'Val IoU', 'o', colors[3]),
    ]),
]

for panel_ax, title, y_label, curves in panels:
    for series, legend_label, marker, colour in curves:
        panel_ax.plot(series, label=legend_label, marker=marker, markersize=6,
                      color=colour, linewidth=2)
    panel_ax.set_title(title, fontsize=14, weight='bold')
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.legend(frameon=False)

# Shared cosmetics: white background, no top/right spines
for ax in axes:
    ax.set_facecolor('white')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

plt.tight_layout()
plt.show()


EVALUATION

In [None]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from scipy.ndimage import label as connected_components
import time

# 1. Pairwise IoU per image
def compute_pairwise_iou(gt_mask, pred_mask):
    """Compute IoU for every overlapping (ground-truth, predicted) object pair.

    Both masks are split into connected components. For each ground-truth
    component, an IoU is computed against every predicted component that
    shares at least one pixel with it.

    Args:
        gt_mask: 2-D array; non-zero pixels are ground-truth foreground.
        pred_mask: 2-D array; non-zero pixels are predicted foreground.

    Returns:
        List of IoU values, one per overlapping pair, or ``[0.0]`` when no
        pair overlaps.
    """
    gt_labels, gt_count = connected_components(gt_mask)
    pred_labels, _ = connected_components(pred_mask)

    matched_ious = []
    for gt_id in range(1, gt_count + 1):
        gt_region = gt_labels == gt_id
        if gt_region.sum() == 0:
            continue

        # Labels of predicted components touching this ground-truth
        # component; label 0 is background.
        touching = np.unique(pred_labels[gt_region])
        for pred_id in touching[touching != 0]:
            pred_region = pred_labels == pred_id
            overlap = np.logical_and(gt_region, pred_region).sum()
            combined = np.logical_or(gt_region, pred_region).sum()
            if combined > 0:
                matched_ious.append(overlap / combined)

    if not matched_ious:
        return [0.0]
    return matched_ious

# Mean pairwise IoU per image
def compute_mean_pairwise_iou_per_image(gt_mask, pred_mask):
    """Return the mean of the pairwise object IoUs for a single image."""
    return np.mean(compute_pairwise_iou(gt_mask, pred_mask))

# Dataset-level mean pairwise IoU (over all images and objects)
def compute_dataset_pairwise_miou(gt_masks, pred_masks):
    """Return the mean pairwise IoU pooled over all images and object pairs.

    Channel 0 of every sample is thresholded at 0.5 and handed to
    ``compute_pairwise_iou``; the resulting lists are concatenated and
    averaged. The elapsed wall-clock time is printed.

    Args:
        gt_masks: Array indexed as ``[sample, channel, ...]``.
        pred_masks: Array with the same layout as ``gt_masks``.

    Returns:
        The pooled mean IoU as a ``float`` (``0.0`` if nothing was pooled).
    """
    started = time.time()

    pooled = []
    for sample in range(gt_masks.shape[0]):
        gt_binary = gt_masks[sample, 0] > 0.5
        pred_binary = pred_masks[sample, 0] > 0.5
        pooled += compute_pairwise_iou(gt_binary, pred_binary)

    if pooled:
        result = float(np.mean(pooled))
    else:
        result = 0.0

    duration = time.time() - started
    print(f"Dataset Pairwise mIoU computation took: {duration:.2f}s")
    return result

# 2. Multiscale IoU per image
def downsample_mask(mask, scale):
    """Resize a ``[B, C, H, W]`` tensor by ``scale`` with bilinear interpolation.

    The target height and width are ``int(H * scale)`` and
    ``int(W * scale)``, each clamped to at least 1.
    """
    _, _, height, width = mask.shape
    target_size = (max(1, int(height * scale)), max(1, int(width * scale)))
    return F.interpolate(mask, size=target_size, mode='bilinear', align_corners=False)

def multiscale_ratio(gt, pred, scales):
    """Average, over ``scales``, the fraction of ground truth covered by the prediction.

    At each scale both masks are downsampled and thresholded at 0.5, and the
    score is ``|gt AND pred| / |gt|``. A scale at which the downsampled
    ground truth is empty scores 1.0. Note that the denominator is the
    ground-truth pixel count, not the union.

    Args:
        gt: Ground-truth tensor of shape ``[1, 1, H, W]``.
        pred: Prediction tensor of shape ``[1, 1, H, W]``.
        scales: Sized iterable of scale factors.

    Returns:
        The mean per-scale score.
    """
    running_total = 0.0
    for scale in scales:
        gt_small = (downsample_mask(gt, scale) > 0.5).cpu().numpy().astype(bool).squeeze()
        pred_small = (downsample_mask(pred, scale) > 0.5).cpu().numpy().astype(bool).squeeze()

        covered = np.logical_and(gt_small, pred_small).sum()
        gt_pixels = gt_small.sum()
        if gt_pixels > 0:
            running_total += covered / gt_pixels
        else:
            running_total += 1.0
    return running_total / len(scales)

def compute_multiscale_iou_per_sample(gt_masks, pred_masks, scales=np.logspace(-2, 0, 10)):
    """Return one ``multiscale_ratio`` score per sample.

    Each sample is sliced out with its leading dimension kept, converted to
    a float tensor and scored. The elapsed wall-clock time is printed.

    Args:
        gt_masks: Array indexed by sample along the first axis.
        pred_masks: Array with the same layout as ``gt_masks``.
        scales: Scale factors forwarded to ``multiscale_ratio``.

    Returns:
        List of per-sample scores.
    """
    started = time.time()
    iou_per_sample = [
        multiscale_ratio(
            torch.tensor(gt_masks[sample:sample + 1]).float(),
            torch.tensor(pred_masks[sample:sample + 1]).float(),
            scales,
        )
        for sample in range(gt_masks.shape[0])
    ]
    duration = time.time() - started
    print(f"Multiscale IoU computation took: {duration:.2f}s")
    return iou_per_sample

def compute_mean_multiscale_iou(gt_masks, pred_masks, scales=np.logspace(-2, 0, 10)):
    """Return the mean of the per-sample multiscale scores."""
    return np.mean(compute_multiscale_iou_per_sample(gt_masks, pred_masks, scales))

# 3. Collect predictions and labels from model + dataloader
def collect_predictions_and_labels(model, dataloader, device):
    """Run the model over a dataloader and gather probabilities and targets.

    The model is put in eval mode and run without gradients. Sigmoid
    probabilities and target masks are concatenated along the first axis.
    The elapsed wall-clock time is printed.

    Args:
        model: Module returning logits for a batch of images.
        dataloader: Iterable yielding ``(images, masks, _)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(y_true_flat, y_score_flat, preds_05, all_targets)``:
        flattened targets, flattened probabilities, probabilities
        thresholded at 0.5 as ``uint8``, and the unflattened targets.
    """
    started = time.time()
    model.eval()

    prob_batches = []
    target_batches = []
    with torch.no_grad():
        for images, masks, _ in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            masks = masks.to(device)
            logits = model(images)
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            target_batches.append(masks.cpu().numpy())

    all_probs = np.concatenate(prob_batches, axis=0)
    all_targets = np.concatenate(target_batches, axis=0)

    preds_05 = (all_probs > 0.5).astype(np.uint8)
    y_true_flat = all_targets.flatten()
    y_score_flat = all_probs.flatten()

    duration = time.time() - started
    print(f"Prediction collection took: {duration:.2f}s")

    return y_true_flat, y_score_flat, preds_05, all_targets

# 4. Evaluation pipeline to combine metrics
def evaluate_model(model, test_loader, device):
    """Compute and print the pairwise and multiscale metrics for a model.

    Predictions are collected once, thresholded at 0.5, and passed to
    ``compute_dataset_pairwise_miou`` and ``compute_mean_multiscale_iou``.
    Progress, timing and results are printed.

    Args:
        model: Module returning logits for a batch of images.
        test_loader: Iterable yielding ``(images, masks, _)`` batches.
        device: Device used for inference.

    Returns:
        Dict with keys ``"Pairwise_mIoU@0.5"`` and ``"MIoU_multiscale"``.
    """
    started = time.time()

    print("→ Collecting predictions...")
    # The flattened targets/scores are returned as well but not needed here.
    _, _, preds_05, y_targets = collect_predictions_and_labels(model, test_loader, device)

    print("→ Computing Pairwise mIoU@0.5...")
    pairwise_miou_05 = compute_dataset_pairwise_miou(y_targets, preds_05)

    print("→ Computing Multiscale IoU...")
    miou_multiscale = compute_mean_multiscale_iou(y_targets, preds_05)

    total_duration = time.time() - started
    print(f"\n=== Total Evaluation Time: {total_duration:.2f} seconds ===")

    print("\n=== Evaluation Results ===")
    print(f"Pairwise mIoU@0.5: {pairwise_miou_05:.4f}")
    print(f"Multiscale IoU (MIoU): {miou_multiscale:.4f}")

    metrics = {
        "Pairwise_mIoU@0.5": pairwise_miou_05,
        "MIoU_multiscale": miou_multiscale,
    }
    return metrics


In [None]:
# Evaluate the current `model` on the test split; `results` holds the metric dict.
results = evaluate_model(model, test_loader, device)

In [None]:
# Rebuild the architecture and restore weights from a checkpoint.
model = UNet()
# map_location remaps stored tensors onto the active device, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("", map_location=device))
model = model.to(device)

In [None]:
def visualize_prediction(model, dataset, device, index=0):
    """Show one dataset sample with ground-truth and predicted mask overlays.

    Three panels are drawn: the input image, the image with the ground-truth
    mask in green, and the image with the predicted mask (sigmoid output
    thresholded at 0.5) in purple. Masks whose size differs from the image
    are resized with nearest-neighbour interpolation first.

    Args:
        model: Module returning logits for a batch of images.
        dataset: Indexable returning ``(image, mask, _)`` tensors.
        device: Device used for inference.
        index: Position of the sample to display.
    """
    # Function-scope import: `resize` is not imported at file level.
    from skimage.transform import resize

    model.eval()

    image, gt_mask, _ = dataset[index]
    image_tensor = image.unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(image_tensor)
        pred_mask = torch.sigmoid(pred).squeeze().cpu().numpy()

    pred_mask = (pred_mask > 0.5).astype(np.uint8)
    gt_mask = gt_mask.squeeze().cpu().numpy().astype(np.uint8)

    # Convert image tensor to numpy array (channels-last if multi-channel)
    img_np = image.squeeze().cpu().numpy()
    if img_np.ndim == 3:
        img_np_show = np.moveaxis(img_np, 0, -1)
    else:
        img_np_show = img_np

    H_img, W_img = img_np_show.shape[:2]

    # Resize masks if needed; order=0 keeps them binary
    if gt_mask.shape != (H_img, W_img):
        gt_mask = resize(gt_mask, (H_img, W_img), order=0, preserve_range=True).astype(np.uint8)
    if pred_mask.shape != (H_img, W_img):
        pred_mask = resize(pred_mask, (H_img, W_img), order=0, preserve_range=True).astype(np.uint8)

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # Every panel shows the input image as its background
    for ax in axs:
        if img_np_show.ndim == 2:
            ax.imshow(img_np_show, cmap='gray')
        else:
            ax.imshow(img_np_show)
        ax.axis('off')

    axs[0].set_title("Input Image")

    # GT overlay (green, semi-transparent); RGBA zeros stay fully transparent
    gt_overlay = np.zeros((H_img, W_img, 4), dtype=np.float32)
    gt_overlay[gt_mask == 1] = [0, 0.5, 0, 0.6]
    axs[1].imshow(gt_overlay)
    axs[1].set_title("Image + Ground Truth Mask")

    # Pred overlay (purple, opaque)
    pred_overlay = np.zeros((H_img, W_img, 4), dtype=np.float32)
    pred_overlay[pred_mask == 1] = [0.5, 0, 0.5, 1]
    axs[2].imshow(pred_overlay)
    axs[2].set_title("Image + Predicted Mask")

    plt.tight_layout()
    plt.show()


In [None]:
# Inspect the sample at position 13 of the training split.
visualize_prediction(model, train_dataset, device, index=13)


Generating IoU boxplots

In [None]:
def generate_all_boxplots_in_row(model, dataloader, device, max_images=30, save_name="boxplots.pdf"):
    """Draw one pairwise-IoU boxplot per image in a single row and save it as PDF.

    For each image the sigmoid output is thresholded at 0.5 and compared with
    the ground truth using ``compute_pairwise_iou``. Images with at most one
    pairwise IoU, or whose median IoU is not positive, are skipped and do not
    count towards ``max_images``. On top of each boxplot the multiscale score
    (red triangle) and the mean pairwise IoU (yellow square) are marked. No
    legend is drawn in this figure.

    The figure is written to ``os.getcwd()/save_name`` and the IDs of the
    plotted images are printed in plot order.

    Args:
        model: Module returning logits for a batch of images.
        dataloader: Iterable yielding ``(images, masks, filenames)`` batches.
        device: Device used for inference.
        max_images: Maximum number of boxplots; also sets the figure width.
        save_name: Output file name, relative to the working directory.
    """
    # Function-scope import: only `Patch` is imported from matplotlib.patches
    # at file level.
    from matplotlib.patches import Rectangle

    model.eval()
    spacing = 0.7
    fig_width = max_images * spacing
    fig, ax = plt.subplots(figsize=(fig_width, 7), dpi=600)

    image_counter = 0
    boxplot_positions = []
    xtick_labels = []
    scales = np.logspace(-2, 0, 10)

    with torch.no_grad():
        for images, masks, filenames in tqdm(dataloader, desc="Generating Boxplots"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(float)

            for i in range(images.size(0)):
                if image_counter >= max_images:
                    break

                gt_mask = masks[i, 0].cpu().numpy() > 0.5
                pred_mask = preds[i, 0]
                pairwise_ious = compute_pairwise_iou(gt_mask, pred_mask)
                # A single value cannot form a meaningful box.
                if len(pairwise_ious) <= 1:
                    continue

                multiscale_iou = multiscale_ratio(
                    torch.tensor(gt_mask).unsqueeze(0).unsqueeze(0).float(),
                    torch.tensor(pred_mask).unsqueeze(0).unsqueeze(0).float(),
                    scales
                )
                mean_iou = np.mean(pairwise_ious)
                median_iou = np.median(pairwise_ious)
                if median_iou <= 0:
                    continue

                pos = (image_counter + 1) * spacing
                boxplot_positions.append(pos)
                xtick_labels.append(str(filenames[i]))

                ax.boxplot(
                    pairwise_ious,
                    positions=[pos],
                    patch_artist=True,
                    showfliers=True,
                    widths=0.35,
                    boxprops=dict(facecolor='#393432', edgecolor='#000000', linewidth=0.5),
                    medianprops=dict(color='#00bfff', linewidth=4),
                    flierprops=dict(marker='o', markerfacecolor='none', markeredgecolor='black', markersize=10),
                    whiskerprops=dict(color='#000000', linewidth=0.6),
                    capprops=dict(color='#000000', linewidth=1.8),
                )

                ax.plot(pos, multiscale_iou, marker='^', color='red', markersize=10, linestyle='None')
                ax.plot(pos, mean_iou, marker='s', markerfacecolor='none', markeredgecolor='yellow', markersize=10, linestyle='None')

                image_counter += 1
            if image_counter >= max_images:
                break

    bottom_margin, top_margin = 0.02, 0.03
    left_margin = spacing * 0.5
    right_margin = spacing * 0.5
    # Fall back to the full nominal width when nothing was plotted
    pos_max = boxplot_positions[-1] if boxplot_positions else max_images * spacing

    ax.set_xlim(left_margin, pos_max + right_margin)
    ax.set_ylim(-bottom_margin, 1 + top_margin)

    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_xticks(boxplot_positions)
    # Tick labels are hidden; the matching IDs are printed at the end instead
    ax.tick_params(axis='x', which='both', labelbottom=False)

    ax.set_ylabel('IoU', fontsize=25, fontfamily='DejaVu Sans', color='black')

    ax.tick_params(axis='both', which='both', direction='out',
                   length=6, width=1.5, labelsize=15, color='black')

    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.tick1line.set_clip_on(False)
        tick.tick2line.set_clip_on(False)

    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Unclipped frame drawn around the full data area
    rect = Rectangle(
        (left_margin, -bottom_margin),
        pos_max + right_margin - left_margin,
        1 + top_margin + bottom_margin,
        linewidth=1.5, edgecolor='black', facecolor='none',
        zorder=10, clip_on=False
    )
    ax.add_patch(rect)

    plt.tight_layout()
    save_path = os.path.join(os.getcwd(), save_name)
    plt.savefig(save_path, bbox_inches="tight", format="pdf")
    plt.show()

    print("\nFilenames corresponding to boxplots:")
    for name in xtick_labels:
        print(name)


In [None]:
# Render pairwise-IoU boxplots for up to 30 images of the test split.
# NOTE(review): save_name is an empty string here, and the function joins it
# onto os.getcwd() — this looks like a placeholder; supply a real file name.
generate_all_boxplots_in_row(
    model=model,
    dataloader=test_loader,
    device=device,
    max_images=30,
    save_name=""
)

In [None]:
# Line2D is not imported at file level; import it in this cell so the legend
# handles below can be built.
from matplotlib.lines import Line2D

# Standalone legend figure for the boxplot markers
legend_elements = [
    Line2D([0], [0], marker='^', linestyle='None', color='red', label=r'$IoU_{\mathit{multiscale}}$', markersize=12),
    Line2D([0], [0], marker='s', linestyle='None', markerfacecolor='none', markeredgecolor='yellow',
           label=r'$mIoU_{\mathit{pairwise}}$', markersize=12),
    Line2D([0], [0], marker='o', linestyle='None', color='none', markerfacecolor='none', markeredgecolor='black',
           label=r'$IoU_{\mathit{pairwise}}\ \text{Outliers}$', markersize=12),
    Line2D([0], [0], linestyle='-', color='#00bfff', linewidth=3,
           label=r'$IoU_{\mathit{pairwise}}\ \text{Median}$')
]

# Empty axes: the figure exists only to host the legend
fig, ax = plt.subplots(figsize=(10, 2), dpi=600)
ax.axis('off')

legend = ax.legend(
    handles=legend_elements,
    loc='center',
    frameon=True,
    facecolor='#d3d3d3',
    edgecolor='none',
    framealpha=1,
    ncol=len(legend_elements),  # all entries on a single row
    handlelength=1.2,
    handletextpad=0.2,
    columnspacing=1.8,
    fontsize=25,
    borderpad=0.8
)

legend.get_frame().set_boxstyle("square")

# Draw once so the legend has a final extent to measure
fig.canvas.draw()

# Legend extent converted from pixels to inches, used to crop the saved file
bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())

plt.savefig("", bbox_inches=bbox, dpi=600, transparent=True)
plt.show()


In [None]:
def visualize_prediction_on_raw_image(model, image_path, device, input_size=(512, 512)):
    """Predict a mask for an image file and show it over the original image.

    The image is loaded as grayscale, resized to ``input_size`` and
    normalized with mean 0.5 / std 0.5 before inference. The sigmoid output
    is thresholded at 0.5, resized back to the original resolution with
    nearest-neighbour interpolation, and painted in red over the
    min-max-scaled original.

    Args:
        model: Module returning logits, or a dict holding them under
            ``'logits'`` (otherwise the dict's first value is used).
        image_path: Path of the image file to load.
        device: Device used for inference.
        input_size: ``(height, width)`` fed to the model.
    """
    model.eval()
    img = Image.open(image_path).convert("L")
    original_size = img.size  # PIL order: (width, height)

    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    image_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(image_tensor)
        if isinstance(pred, dict):
            pred = pred['logits'] if 'logits' in pred else list(pred.values())[0]
        # Bring the logits to input_size in case the model output differs.
        # The original called a bare `interpolate`, which is undefined here.
        pred = F.interpolate(pred, size=input_size, mode="bilinear", align_corners=False)
        pred_mask = torch.sigmoid(pred).squeeze().cpu().numpy()

    pred_mask_bin = (pred_mask > 0.5).astype(np.uint8)
    # Scale to 0/255 for PIL, resize without interpolation, return to 0/1
    pred_mask_resized = np.array(Image.fromarray(pred_mask_bin * 255).resize(original_size, Image.NEAREST)) // 255

    # Min-max scale for display; the epsilon guards against a constant image
    img_np = np.array(img).astype(np.float32)
    img_vis = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

    img_rgb = np.stack([img_vis]*3, axis=-1)

    pred_overlay = img_rgb.copy()
    pred_overlay[pred_mask_resized == 1] = [1, 0, 0]  # pure red

    fig, axs = plt.subplots(1, 2, figsize=(14, 6))
    axs[0].imshow(img_rgb, cmap="gray")
    axs[0].set_title("Original H-alpha Image")
    axs[0].axis("off")

    axs[1].imshow(pred_overlay)
    # Title matches the overlay colour actually drawn
    axs[1].set_title("Model Prediction Overlay (Red)")
    axs[1].axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
image_path = ""
# Use the shared `device` the model was moved to instead of a hard-coded
# "cuda", which fails on CPU-only machines.
visualize_prediction_on_raw_image(model=model, image_path=image_path, device=device)
