In [4]:
import os
import random
from pathlib import Path
import cv2
import numpy as np

# Dataset root and augmentation output directories
SRC_ROOT = Path("./USA_segmentation")     # root of the original dataset
DST_ROOT = Path("./USA_segmentation_aug") # root for augmented output (original images are never overwritten)

# File extensions
IMAGE_EXT = ".png"                        # change this if the images are .jpg/.tif
MASK_EXT  = ".png"

# Which split files to augment
SPLITS = ["train.txt"]                    # training set only; for everything use ["train.txt","val.txt","test.txt"]

# Random seed: seed every global generator this notebook imports, so that
# augmentations drawing from either `random` or `np.random` are reproducible.
SEED = 2025
random.seed(SEED)
np.random.seed(SEED)

# Source subfolders
SRC_RGB = SRC_ROOT / "RGB_images"
SRC_NIR = SRC_ROOT / "NRG_images"
SRC_MSK = SRC_ROOT / "masks"
SPLIT_DIR = SRC_ROOT / "splits"

# Create output directories (one per source image/mask subfolder; idempotent)
for _subdir in ("RGB_images", "NRG_images", "masks"):
    (DST_ROOT / _subdir).mkdir(parents=True, exist_ok=True)

print(f"Source RGB folder : {SRC_RGB}")
print(f"Source NIR folder : {SRC_NIR}")
print(f"Source masks folder: {SRC_MSK}")
print(f"Aug output folder  : {DST_ROOT}")

Source RGB folder : USA_segmentation\RGB_images
Source NIR folder : USA_segmentation\NRG_images
Source masks folder: USA_segmentation\masks
Aug output folder  : USA_segmentation_aug


In [5]:
def ensure_dir(path: Path):
    """
    Validate that `path` points at an existing directory.

    This is a check only: nothing is created. Returns None when the
    directory exists.

    Raises:
        FileNotFoundError: if `path` is missing or is not a directory.
    """
    if path.is_dir():
        return
    raise FileNotFoundError(f"Directory not found: {path}")


def load_images(folder: Path, ext: str):
    """
    Load every image in `folder` whose filename ends with `ext`.

    Matching is case-insensitive on both sides, so ext=".png" and
    ext=".PNG" select the same files.

    Args:
        folder: Directory to scan (not recursive).
        ext: Filename suffix to match, e.g. ".png".

    Returns:
        List of arrays as returned by cv2.imread(..., cv2.IMREAD_UNCHANGED),
        ordered by sorted filename. Files that cannot be read are skipped
        with a printed warning.

    Raises:
        FileNotFoundError: if `folder` is not an existing directory.
    """
    ensure_dir(folder)
    # Filenames are lower-cased before matching, so the suffix must be too;
    # otherwise an upper-case `ext` would silently match nothing.
    suffix = ext.lower()
    images = []
    for fname in sorted(os.listdir(folder)):
        if not fname.lower().endswith(suffix):
            continue
        img = cv2.imread(str(folder / fname), cv2.IMREAD_UNCHANGED)
        if img is None:
            # Signal the unreadable file but keep loading the rest.
            print(f"Warning: failed to read {fname}")
            continue
        images.append(img)
    return images


def preprocess(img, size=(256,256), normalize=True):
    """
    Resize `img` to `size` and optionally scale it to float32.

    Resizing uses cv2.INTER_AREA. With `normalize=True` the result is cast
    to float32 and divided by 255.0, which maps to [0,1] only when the input
    holds 8-bit values. With `normalize=False` the resized array is returned
    unchanged.
    """
    scaled = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    if not normalize:
        return scaled
    return scaled.astype(np.float32) / 255.0


def read_split(file_path: Path):
    """
    Return the image IDs listed in a split file, one ID per line.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are dropped. The parent directory is validated before
    the file is opened.

    Raises:
        FileNotFoundError: if the parent directory does not exist.
    """
    ensure_dir(file_path.parent)
    ids = []
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if entry:
                ids.append(entry)
    return ids