In [16]:
import os
import sys
import torch
import random
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
import numpy as np
import scipy.ndimage as ndimage

# Make modules next to the notebook importable (needed for the project package below).
# NOTE: "__file__" is passed as a string literal, not the __file__ variable, so
# abspath() resolves it against the current working directory and dirname()
# returns that working directory (presumably the notebook's folder — confirm).
module_path = os.path.dirname(os.path.abspath("__file__"))
sys.path.append(module_path)

# Project package providing config loading and model construction (dt.utils.*).
import dancher_tools_segmentation as dt

# Run-mode defaults; the configuration cell further down may override them.
display = False
export = False
specified_dir = None

# %%


def cut_image(image, crop_size=224, step_size=None):
    """
    Split an image into overlapping crop_size x crop_size tiles.

    Tiles are laid on a regular grid with stride ``step_size`` (default
    ``crop_size // 2``, i.e. 50% overlap). When the grid does not land exactly
    on the right/bottom border, one extra column/row of tiles is anchored to
    that border, so every pixel is covered by at least one tile. If the image
    is smaller than ``crop_size`` along an axis, a single tile spans that whole
    axis (it is upscaled to ``crop_size`` by the resize below).

    Args:
        image (PIL.Image.Image): Input image.
        crop_size (int): Side length of each tile in pixels.
        step_size (int | None): Grid stride in pixels; defaults to
            ``crop_size // 2``. Values below 1 are clamped to 1.

    Returns:
        tuple[list, list]:
            blocks    -- tiles resized to (crop_size, crop_size) with LANCZOS.
            positions -- matching ``(x, y, w, h)`` tuples: each tile's top-left
                         corner and its true (pre-resize) size in ``image``.
    """
    width, height = image.size
    if step_size is None:
        step_size = crop_size // 2
    step_size = max(1, step_size)  # avoid a zero stride for tiny crop sizes

    def _starts(length):
        # Tile origins along one axis. The regular grid alone can stop short of
        # the border, leaving trailing pixels unpredicted; add a border-anchored
        # origin in that case.
        last = max(length - crop_size, 0)
        starts = list(range(0, last + 1, step_size))
        if starts[-1] != last:
            starts.append(last)
        return starts

    x_starts = _starts(width)
    blocks = []
    positions = []
    for y in _starts(height):
        for x in x_starts:
            # min() only matters when the image is smaller than crop_size.
            x_end = min(x + crop_size, width)
            y_end = min(y + crop_size, height)
            block = image.crop((x, y, x_end, y_end)).resize((crop_size, crop_size), Image.LANCZOS)
            blocks.append(block)
            positions.append((x, y, x_end - x, y_end - y))

    return blocks, positions


def assemble_image(blocks, positions, original_size):
    """
    Stitch per-tile predicted masks back into one full-size mask, then denoise it.

    Overlapping tiles are merged with a pixel-wise maximum, so a pixel ends up
    white if any tile covering it predicted white.

    Args:
        blocks (list[PIL.Image.Image]): Predicted tile masks.
        positions (list[tuple[int, int, int, int]]): ``(x, y, w, h)`` of each
            tile in the original image, in the same order as ``blocks``.
        original_size (tuple[int, int]): Shape of the output mask as ``(H, W)``.

    Returns:
        np.ndarray: The merged uint8 mask after ``remove_small_objects``.
    """
    canvas = np.zeros(original_size, dtype=np.uint8)  # start fully black

    for tile, (left, top, tile_w, tile_h) in zip(blocks, positions):
        # Undo the square resize done during tiling; NEAREST keeps mask values crisp.
        restored = np.array(tile.resize((tile_w, tile_h), Image.NEAREST))
        window = (slice(top, top + tile_h), slice(left, left + tile_w))
        canvas[window] = np.maximum(canvas[window], restored)

    # Drop small speckles left over from individual tiles.
    return remove_small_objects(canvas)


def remove_small_objects(mask, min_size=40):
    """
    Remove small connected foreground regions (noise) from a mask.

    Every pixel with value > 0 counts as foreground. Foreground pixels are
    grouped into connected regions with ``scipy.ndimage.label`` (default
    structuring element). Regions with fewer than ``min_size`` pixels are
    dropped, and the surviving regions are written as 255.

    Args:
        mask (np.ndarray): Mask to clean; any value > 0 is foreground.
        min_size (int): Minimum region area in pixels; smaller regions are removed.

    Returns:
        np.ndarray: Array with the same shape and dtype as ``mask``, holding 255
        on kept regions and 0 elsewhere.
    """
    binary_mask = mask > 0
    label_im, num_labels = ndimage.label(binary_mask)

    # Pixel count of every label in a single pass (label 0 is the background).
    # This replaces a per-label `label_im == i` scan that cost O(labels * pixels).
    sizes = np.bincount(label_im.ravel(), minlength=num_labels + 1)

    # Lookup table: label -> keep?  The background is never kept.
    keep = sizes >= min_size
    keep[0] = False

    mask_cleaned = np.zeros_like(mask)
    mask_cleaned[keep[label_im]] = 255
    return mask_cleaned


def predict_block(block, model, device):
    """
    Run the segmentation model on one tile and return its predicted mask.

    The tile is resized to 224x224, converted to a tensor, and normalized with
    the ImageNet mean/std. The per-pixel argmax class index is then scaled by
    255, which gives a 0/255 mask for a two-class model.

    Args:
        block (PIL.Image.Image): RGB tile to segment.
        model: Segmentation network returning logits of shape (N, C, H, W)
            (assumed from the argmax over dim=1 — confirm).
        device (torch.device): Device the model lives on.

    Returns:
        PIL.Image.Image: uint8 mask image of the predicted classes * 255.
    """
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model.eval()
    with torch.no_grad():
        batch = preprocess(block).unsqueeze(0).to(device)
        class_map = model(batch).argmax(dim=1).squeeze().cpu().numpy()
    return Image.fromarray((class_map * 255).astype(np.uint8))


def display_samples_with_mask(indices, all_test_images, all_test_masks, model, device):
    """
    Show the input image, ground-truth mask (if available), and predicted mask
    side by side for each selected test sample.

    Args:
        indices (Iterable[int]): Positions in ``all_test_images`` to display.
        all_test_images (list[str]): Image file paths.
        all_test_masks (list[str]): Mask file paths, paired with images by
            position. May be empty or shorter than ``all_test_images``; samples
            without a mask just leave the middle panel blank.
        model: Segmentation model passed to ``predict_block``.
        device (torch.device): Device the model runs on.
    """
    for idx in indices:
        image_path = all_test_images[idx]
        # Guard against a mask list shorter than the image list instead of
        # raising IndexError.
        mask_path = all_test_masks[idx] if all_test_masks and idx < len(all_test_masks) else None

        # Decode inside a context manager so the source file is closed promptly.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        original_size = image.size[::-1]  # PIL gives (W, H); the mask is (H, W)

        blocks, positions = cut_image(image)
        pred_blocks = [predict_block(block, model, device) for block in blocks]
        full_pred_mask = assemble_image(pred_blocks, positions, original_size)

        fig, axes = plt.subplots(1, 3, figsize=(20, 10))
        axes[0].imshow(image)
        axes[0].set_title("Original Image")

        if mask_path:
            with Image.open(mask_path) as mask_img:
                true_mask = np.array(mask_img.convert('L'))  # grayscale mask
            axes[1].imshow(true_mask, cmap='gray')
            axes[1].set_title("Ground Truth Mask")

        axes[2].imshow(full_pred_mask, cmap='gray')
        axes[2].set_title("Predicted Mask")

        for ax in axes:
            ax.axis('off')
        plt.show()
        plt.close(fig)  # release figure memory when many samples are shown


def export_predictions(all_test_images, model, device, export_dir):
    """
    Predict a full-size mask for every image and save each one as a PNG.

    Each output is written to ``export_dir/<image file stem>.png``.

    Args:
        all_test_images (list[str]): Paths of the images to predict.
        model: Segmentation model passed to ``predict_block``.
        device (torch.device): Device the model runs on.
        export_dir (str | os.PathLike): Output directory; created if missing.
    """
    os.makedirs(export_dir, exist_ok=True)
    written = set()
    for image_path in all_test_images:
        # Decode inside a context manager so the source file is closed promptly.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        original_size = image.size[::-1]  # (H, W)
        blocks, positions = cut_image(image)
        pred_blocks = [predict_block(block, model, device) for block in blocks]
        full_pred_mask = assemble_image(pred_blocks, positions, original_size)

        # Save the result under the input's stem with a .png extension.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        save_path = os.path.join(export_dir, f"{stem}.png")
        # Inputs sharing a stem (e.g. a.jpg and a.png, or the same file name in
        # two test paths) map to the same output file; make the overwrite visible.
        if save_path in written:
            print(f"[WARNING] Overwriting {save_path} with prediction for {image_path}")
        written.add(save_path)
        Image.fromarray(full_pred_mask).save(save_path)
        print(f"Saved prediction to {save_path}")


def _list_image_files(directory, extensions=('.png', '.jpg')):
    """Return sorted full paths of regular files in `directory` whose extension
    (compared case-insensitively) is in `extensions`. Raises if `directory`
    does not exist."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(extensions) and os.path.isfile(os.path.join(directory, name))
    )


def get_test_images_and_masks(test_paths, specified_dir=None):
    """
    Collect test image paths and (when available) their mask paths.

    Images and masks are paired purely by sorted position, so the two lists
    returned are either index-aligned or the mask list is empty.

    Args:
        test_paths (list[str]): Default dataset roots. Each root must contain
            ``images/`` and ``masks/`` subdirectories with equal file counts;
            roots that do not are skipped with a warning.
        specified_dir (str | None): Optional user-chosen location that takes
            precedence over ``test_paths``. It is either a dataset root with
            ``images/`` (and optionally ``masks/``), or a flat folder of images.

    Returns:
        tuple[list[str], list[str]]: ``(all_test_images, all_test_masks)``.

    Exits:
        Calls ``sys.exit(1)`` when no test images are found.
    """
    all_test_images = []
    all_test_masks = []

    if specified_dir:
        images_dir = os.path.join(specified_dir, 'images')
        masks_dir = os.path.join(specified_dir, 'masks')

        if os.path.isdir(images_dir):
            all_test_images = _list_image_files(images_dir)
            if os.path.isdir(masks_dir):
                all_test_masks = _list_image_files(masks_dir)
                # A count mismatch would silently mis-pair images and masks
                # (or raise IndexError later), so drop the masks instead.
                if len(all_test_masks) != len(all_test_images):
                    print(f"[WARNING] Number of images and masks do not match in {specified_dir}. Proceeding without masks.")
                    all_test_masks = []
            else:
                print(f"[INFO] No masks directory found in {specified_dir}. Proceeding without masks.")
        else:
            # Treat specified_dir as a flat folder that directly contains images.
            all_test_images = _list_image_files(specified_dir)
            print(f"[INFO] No structured dataset found in {specified_dir}. Using images only.")
    else:
        for test_path in test_paths:
            images_dir = os.path.join(test_path, 'images')
            masks_dir = os.path.join(test_path, 'masks')

            if not (os.path.isdir(images_dir) and os.path.isdir(masks_dir)):
                print(f"[WARNING] Skipping {test_path}: expected 'images' and 'masks' subdirectories.")
                continue

            images = _list_image_files(images_dir)
            masks = _list_image_files(masks_dir)

            # Pairing is positional, so counts must match.
            if len(images) != len(masks):
                print(f"[WARNING] Number of images and masks do not match in {test_path}")
                continue

            all_test_images.extend(images)
            all_test_masks.extend(masks)

    if not all_test_images:
        print("[ERROR] No test images found.")
        sys.exit(1)

    return all_test_images, all_test_masks


def handle_display_and_export(all_test_images, all_test_masks, model, device, args, display=True, export=True, specified_dir=None, num_samples=3, seed=None):
    """
    Optionally preview a few random predictions and/or export all predictions.

    Args:
        all_test_images (list[str]): All test image paths.
        all_test_masks (list[str]): Test mask paths, paired with images by position.
        model: Model object.
        device: Device to run on.
        args: Config object; ``args.model_save_dir`` is the export root.
        display (bool): Whether to show sample predictions.
        export (bool): Whether to export predictions for every image.
        specified_dir (str | None): User-specified test data path. When set,
            its last path component is appended to the export directory.
        num_samples (int): Maximum number of samples to display.
        seed (int | None): Seed for choosing display samples. ``None`` uses the
            global ``random`` state (the previous behaviour); an int makes the
            selection reproducible without disturbing global state.
    """
    if display:
        k = max(0, min(num_samples, len(all_test_images)))
        sampler = random if seed is None else random.Random(seed)
        sample_indices = sampler.sample(range(len(all_test_images)), k)
        display_samples_with_mask(sample_indices, all_test_images, all_test_masks, model, device)

    if export:
        export_dir = Path(args.model_save_dir) / "predictions"
        if specified_dir:
            # Keep predictions for different input folders apart by appending
            # the last component of specified_dir.
            specified_dir_name = os.path.basename(os.path.normpath(specified_dir))
            export_dir = export_dir / specified_dir_name

        export_predictions(all_test_images, model, device, export_dir)



In [17]:
# %%
# Load the experiment configuration (model settings, dataset paths, save directory).
args = dt.utils.get_config('configs/S1/FLaTO.yaml')
# Folder to run inference on; when set it takes precedence over args.ds['test_paths'].
specified_dir = 'datasets/S1/images'
# Uncomment to preview predictions for a few random test images.
# display = True
export = True

# Select the device: GPU if available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model and load its weights.
model = dt.utils.get_model(args, device)
model.load(model_dir=args.model_save_dir, mode=args.load_mode, specified_path=args.weights)

# Collect test image paths (and mask paths, when present).
all_test_images, all_test_masks = get_test_images_and_masks(args.ds['test_paths'], specified_dir=specified_dir)

# Show samples and/or export predicted masks according to the flags above.
handle_display_and_export(all_test_images, all_test_masks, model, device, args, display=display, export=export, specified_dir=specified_dir)


Successfully loaded dataset module: iw_dataset
Successfully loaded color maps and class names for dataset 'iw_dataset'
Loaded and validated config from configs/S1/FLaTO.yaml: {'model_name': 'FLaTO', 'img_size': 224, 'num_classes': 2, 'weights': None, 'transfer_weights': None, 'load_mode': 'best', 'learning_rate': 0.001, 'batch_size': 4, 'num_workers': 4, 'patience': 50, 'delta': 0.005, 'loss': 'ce,focal', 'loss_weights': None, 'num_epochs': 500, 'model_save_dir': './results/S1_100shot', 'metrics': ['mIoU', 'precision', 'recall', 'f1_score'], 'export': False, 'conf_threshold': None, 'save_interval': 5, 'in_channels': 3, 'cache_reset': False, 'ds': {'name': 'iw_dataset', 'train_paths': ['datasets/S1/S1_100shot'], 'test_paths': ['datasets/S1/S1_100shot'], 'color_to_class': {(0, 0, 0): 0, (255, 255, 255): 1}, 'class_to_color': {0: (0, 0, 0), 1: (255, 255, 255)}, 'class_name': ['background', 'internal wave']}}
Loaded custom model 'FLaTO' from 'models/FLaTO.py'.
n_patches: 196, img_size: 14,