In [3]:
import os
import torch
import numpy as np
from PIL import Image
import dancher_tools_segmentation as dt
import torchvision.transforms as T
import sys

# Make the working directory importable. Note that "__file__" is a quoted string
# literal here, not the module attribute, so abspath() resolves it against the
# current working directory and dirname() then yields that directory.
module_path = os.path.dirname(os.path.abspath("__file__"))
sys.path.append(module_path)

def calculate_miou_per_image(pred_mask, true_mask, num_classes):
    """Return the mean intersection-over-union of one prediction/label pair.

    For every class index in ``range(num_classes)`` the IoU between the pixels
    labelled with that class in ``pred_mask`` and in ``true_mask`` is computed,
    and the per-class values are averaged with ``np.mean``.

    A class that appears in neither mask contributes an IoU of 1.0 (it is
    treated as a perfect match rather than being skipped).
    """

    def _class_iou(cls):
        in_pred = pred_mask == cls
        in_true = true_mask == cls
        union_count = np.logical_or(in_pred, in_true).sum()
        if union_count == 0:
            # Class absent from both masks: counted as a perfect match.
            return 1.0
        overlap_count = np.logical_and(in_pred, in_true).sum()
        return overlap_count / union_count

    return np.mean([_class_iou(cls) for cls in range(num_classes)])

def save_image_and_masks(image, pred_mask, save_dir, index):
    """Write a de-normalized image and its predicted mask to disk as PNG files.

    The image is written to ``<save_dir>/images/IwDA_<index>.png`` and the mask
    to ``<save_dir>/masks/IwDA_<index>.png``. Both sub-directories are created
    if they do not exist yet.

    Args:
        image: NumPy array in CHW ``(3, H, W)``, HWC ``(H, W, 3)`` or
            single-channel ``(H, W)`` layout. A 3-D array whose first axis has
            length 3 is treated as CHW. The values are de-normalized with the
            per-channel mean/std constants below (presumably the statistics
            applied by the data pipeline -- confirm against the dataset
            transform).
        pred_mask: NumPy array of predicted class indices. It is multiplied by
            255 and cast to uint8, so 0/1 become 0/255; larger class indices
            wrap around in the uint8 cast.
        save_dir: Root directory that receives the ``images`` and ``masks``
            sub-directories.
        index: Value used to build the ``IwDA_<index>.png`` file names.

    Raises:
        TypeError: If ``image`` is not a NumPy array.
        ValueError: If ``image`` has none of the supported shapes.
    """
    # Create the output directories.
    images_dir = os.path.join(save_dir, 'images')
    masks_dir = os.path.join(save_dir, 'masks')
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    if not isinstance(image, np.ndarray):
        raise TypeError(f"Expected numpy array, but got {type(image)}")

    # Bring the image into [C, H, W] layout.
    if image.ndim == 3 and image.shape[0] == 3:  # already CHW
        image_tensor = torch.tensor(image).float()
    elif image.ndim == 3 and image.shape[-1] == 3:  # HWC -> CHW
        image_tensor = torch.tensor(image).permute(2, 0, 1).float()
    elif image.ndim == 2:  # single channel: replicate to 3 channels
        image_tensor = torch.tensor(np.stack([image] * 3, axis=0)).float()
    else:
        raise ValueError(f"Unsupported image shape: {image.shape}")

    # Undo the per-channel normalization; the stats are shaped [C, 1, 1] so
    # they broadcast over the spatial dimensions.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    unnormalized_image = image_tensor * std + mean
    unnormalized_image = unnormalized_image.permute(1, 2, 0).numpy()  # back to HWC
    unnormalized_image = np.clip(unnormalized_image, 0, 1)
    # Round to the nearest integer before the uint8 cast. A plain cast
    # truncates, so float32 round-off (e.g. 127.99999) would otherwise darken
    # pixels by one level relative to the source image.
    unnormalized_image = np.rint(unnormalized_image * 255).astype(np.uint8)

    # Save the image.
    img_pil = Image.fromarray(unnormalized_image)
    img_pil.save(os.path.join(images_dir, f'IwDA_{index}.png'))

    # Scale the predicted mask from {0, 1} to {0, 255} before saving.
    pred_mask = pred_mask * 255
    pred_mask_pil = Image.fromarray(pred_mask.astype(np.uint8))
    pred_mask_pil.save(os.path.join(masks_dir, f'IwDA_{index}.png'))


def collect_and_save_predictions(loader, model, device, num_classes, save_dir, threshold=0.8):
    """Run the model over a loader and save the samples it predicts well.

    The model is switched to eval mode and every batch yielded by ``loader``
    (unpacked as ``images, masks``) is evaluated with gradients disabled. The
    predicted mask of each sample is the argmax over dimension 1 of the model
    output. Samples whose per-image mIoU against the provided mask is strictly
    greater than ``threshold`` are written with ``save_image_and_masks``,
    numbered consecutively from 0 in the order they are encountered.
    """
    model.eval()
    saved_count = 0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            true_masks = masks.cpu().numpy()

            for sample in range(images.size(0)):
                pred = preds[sample]
                score = calculate_miou_per_image(pred, true_masks[sample], num_classes)
                # Written as a negated '>' so that a NaN score is skipped,
                # exactly like a score at or below the threshold.
                if not (score > threshold):
                    continue
                sample_image = images[sample].cpu().numpy()
                save_image_and_masks(sample_image, pred, save_dir, saved_count)
                saved_count += 1

def main():
    """Load the best checkpoint and export its predictions for both splits.

    Reads the configuration from ``configs/IwDA55/reset_UNet.yaml``, builds the
    data loaders and the model through ``dt.utils``, loads the weights with
    ``mode='best'`` from ``args.model_save_dir``, and then runs
    ``collect_and_save_predictions`` on the training loader (output under
    ``reset_train``) followed by the test loader (output under ``reset_test``).
    """
    args = dt.utils.get_config('configs/IwDA55/reset_UNet.yaml')

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    train_loader, test_loader = dt.utils.get_dataloaders(args)

    model = dt.utils.get_model(args, device)
    model.load(model_dir=args.model_save_dir, mode='best')

    # Training split first, then the test split; each gets its own directory
    # below the model save directory.
    exports = (
        (train_loader, os.path.join(args.model_save_dir, 'reset_train')),
        (test_loader, os.path.join(args.model_save_dir, 'reset_test')),
    )
    for split_loader, split_dir in exports:
        collect_and_save_predictions(split_loader, model, device, model.num_classes, split_dir)

# Run the export only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()


Successfully loaded dataset module: iw_dataset
Successfully loaded color maps and class names for dataset 'iw_dataset'
Loaded and validated config from configs/IwDA55/reset_UNet.yaml: {'model_name': 'UNet', 'img_size': 224, 'num_classes': 2, 'weights': 'results/IwDA55/UNet/UNet_best.pth', 'transfer_weights': None, 'load_mode': 'best', 'learning_rate': 0.001, 'batch_size': 16, 'num_workers': 4, 'patience': 5, 'delta': 0.005, 'loss': 'ce', 'loss_weights': None, 'num_epochs': 500, 'model_save_dir': './results/IwDA_reset_UNet', 'metrics': ['mIoU', 'precision', 'recall', 'f1_score'], 'export': False, 'conf_threshold': None, 'save_interval': 5, 'in_channels': 3, 'ds': {'name': 'iw_dataset', 'train_paths': ['datasets/IwDA55/train', 'datasets/IwDA55/test'], 'test_paths': ['datasets/IwDA55/test'], 'color_to_class': {(0, 0, 0): 0, (255, 255, 255): 1}, 'class_to_color': {0: (0, 0, 0), 1: (255, 255, 255)}, 'class_name': ['background', 'internal wave']}}
Successfully loaded dataset module: iw_datas