In [9]:
import os
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import torch

# ========= 0. Global parameters =========
bean_type = "honduras_natural"  # ← change this to process a different bean variety
use_noback = False  # True → write to the *Noback output folder and exclude the "back" class

# Output sub-folder name.
# NOTE(review): "corp" looks like a typo of "crop", but it is the folder name on
# disk — renaming it here would change where outputs are written.
if use_noback:
    subfolder = "corp_augmented_dataNoback"
else:
    subfolder = "corp_augmented_data"

# Source images live in one sub-folder per class under classByhands/.
dataset_path = f"coffee_beans_data/{bean_type}/crop/classByhands"
input_dir = dataset_path  # dataset_path is used directly as the input folder
output_dir = f"coffee_beans_data/{bean_type}/{subfolder}"  # adjust the output location as needed

# Number of augmented copies generated per source image.
augment_per_image = 2
# Side length in pixels of the square images produced by the augmentation pipeline.
IMG_SIZE = 128

# Safe pad-to-square: keep the aspect ratio and pad the shorter side so the image becomes square.
def pad_to_square(img, fill=0):
    """Pad ``img`` onto a square canvas without resizing its content.

    The shorter dimension is padded on both sides. When the size difference
    is odd, the extra pixel goes to the bottom (landscape input) or the
    right (portrait input).

    Args:
        img: image exposing ``.size`` as ``(width, height)``. It is assumed
            to be a PIL image, which is what the augmentation pipeline passes in.
        fill: value used for the padded border pixels.

    Returns:
        ``img`` itself if it is already square; otherwise the padded result
        of ``TF.pad``.
    """
    width, height = img.size
    if width == height:
        return img

    gap = abs(height - width)
    before = gap // 2
    after = gap - before

    # TF.pad takes a 4-tuple as (left, top, right, bottom).
    if height < width:
        # Landscape: add rows above and below.
        padding = (0, before, 0, after)
    else:
        # Portrait: add columns left and right.
        padding = (before, 0, after, 0)
    return TF.pad(img, padding, fill=fill)

# Clamp transform: keeps tensor values within a fixed range (used here to stay inside 0~1).
class Clamp:
    """Callable transform that clamps every element of a tensor to ``[min_val, max_val]``.

    Args:
        min_val: lower bound passed to ``torch.clamp`` (stored as ``self.min``).
        max_val: upper bound passed to ``torch.clamp`` (stored as ``self.max``).
    """

    def __init__(self, min_val, max_val):
        self.min, self.max = min_val, max_val

    def __call__(self, tensor):
        """Return a clamped copy of ``tensor``; the input tensor is not modified."""
        return torch.clamp(tensor, min=self.min, max=self.max)

# Data augmentation pipeline (no shape distortion: only padding, resizing and photometric changes).
def get_safe_augment_transform():
    """Build the per-image augmentation pipeline.

    Stages, applied in order:
      1. Pad to a square canvas (border value 0, i.e. black for RGB input)
         so the subsequent resize keeps the aspect ratio.
      2. Resize to ``IMG_SIZE`` x ``IMG_SIZE``.
      3. Apply random photometric changes: colour jitter, occasional
         grayscale, occasional Gaussian blur, and occasional sharpening.
      4. Convert to a tensor, add faint Gaussian noise, clamp back to
         [0, 1], and convert back to a PIL image.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to an augmented PIL image.
    """
    def square_pad(img):
        return pad_to_square(img, fill=0)

    def add_faint_noise(x):
        # Small zero-mean Gaussian noise (std 0.01) on the [0, 1] tensor from ToTensor.
        return x + 0.01 * torch.randn_like(x)

    geometry = [
        transforms.Lambda(square_pad),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
    ]
    photometric = [
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.RandomGrayscale(p=0.1),
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
    ]
    tensor_noise = [
        transforms.ToTensor(),
        transforms.Lambda(add_faint_noise),
        Clamp(0, 1),  # noise can push values outside [0, 1]
        transforms.ToPILImage(),
    ]
    return transforms.Compose(geometry + photometric + tensor_noise)

# Create the output folders and run the augmentation.
# Classes to process; the "back" class is excluded when use_noback is True.
target_classes = ["good", "bad"] if use_noback else ["good", "bad", "back"]

# The pipeline keeps no per-call state (randomness is drawn on each call),
# so build it once instead of once per augmented copy.
transform = get_safe_augment_transform()

for class_name in target_classes:
    class_path = os.path.join(input_dir, class_name)
    if not os.path.isdir(class_path):
        # Report missing class folders instead of skipping them silently.
        print(f"⚠️ 找不到類別資料夾，已略過：{class_path}")
        continue

    output_class_path = os.path.join(output_dir, class_name)
    os.makedirs(output_class_path, exist_ok=True)

    # sorted() makes the processing order deterministic across platforms.
    for img_name in sorted(os.listdir(class_path)):
        img_path = os.path.join(class_path, img_name)
        if not os.path.isfile(img_path):
            continue
        if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        try:
            # The with-block closes the file handle; convert() loads the pixels
            # and returns a separate image, so `image` stays valid afterwards.
            with Image.open(img_path) as src:
                image = src.convert("RGB")
        except Exception as err:
            # Best-effort: skip any unreadable/corrupt file, but unlike a bare
            # `except:` this lets KeyboardInterrupt/SystemExit through, and it
            # reports why the file failed.
            print(f"❌ 無法開啟圖檔：{img_path}（{err}）")
            continue

        # NOTE(review): every output is named "<stem>_aug<i>.jpg", so source files
        # that differ only by extension (e.g. a.jpg and a.png) overwrite each
        # other's augmentations.
        stem = os.path.splitext(img_name)[0]
        for i in range(augment_per_image):
            augmented = transform(image)
            augmented.save(os.path.join(output_class_path, f"{stem}_aug{i}.jpg"))

print(f"✅ 擴增完成！圖片已儲存在 {output_dir} 中")

✅ 擴增完成！圖片已儲存在 coffee_beans_data/honduras_natural/corp_augmented_data 中
