#### （已事先將訓練資料集依 8:1:1 拆分為訓練、驗證、測試集，分別存到 "Training_dataset/train"、"Training_dataset/validation"、"Training_dataset/test" 三個資料夾中的 imgs 與 gts 子資料夾）

In [None]:
import os
import random
from PIL import Image
import shutil
from torchvision import transforms

#### 建立新資料夾 "Training_dataset/aug_train"，將原圖及擴增後的圖片存到其中的 aug_imgs，對應的標註（ground truth）則存到 aug_gts


In [None]:
# Filename prefixes of the training images that get augmented, and the image
# file extension. NOTE(review): `suffix` is not referenced by the visible
# cells, which spell out ".jpg" directly.
prefixes = ["TRA_RI_", "TRA_RO_"]
suffix = ".jpg"

_dataset_root = "Training_dataset/"

# Source folders: training-split images and their ground-truth masks.
img_folder = _dataset_root + "train/imgs/"
gt_folder = _dataset_root + "train/gts/"

# Output folders for the originals plus augmented copies (created if missing).
aug_img_folder = _dataset_root + "aug_train/aug_imgs/"
aug_gt_folder = _dataset_root + "aug_train/aug_gts/"
for _out_dir in (aug_img_folder, aug_gt_folder):
    os.makedirs(_out_dir, exist_ok=True)

# Step 1: RandomHorizontalFlip
def random_horizontal_flip(img_path, gt_path):
    """Flip an image and its ground-truth mask left-to-right.

    Despite the name, the flip is not random: the transform is built with
    ``p=1``, so every call flips both inputs, keeping the mask aligned with
    the image.

    Args:
        img_path: Path of the input image.
        gt_path: Path of the matching ground-truth mask.

    Returns:
        A ``(img, gt)`` tuple holding the flipped image and flipped mask.
    """
    transform = transforms.RandomHorizontalFlip(p=1)
    # Close the source files as soon as the flipped copies exist instead of
    # leaving the handles to garbage collection.
    with Image.open(img_path) as src_img, Image.open(gt_path) as src_gt:
        img = transform(src_img)
        gt = transform(src_gt)
    return img, gt

# Step 2: RandomVerticalFlip
def random_vertical_flip(img_path, gt_path):
    """Flip an image and its ground-truth mask top-to-bottom.

    Despite the name, the flip is not random: the transform is built with
    ``p=1``, so every call flips both inputs, keeping the mask aligned with
    the image.

    Args:
        img_path: Path of the input image.
        gt_path: Path of the matching ground-truth mask.

    Returns:
        A ``(img, gt)`` tuple holding the flipped image and flipped mask.
    """
    transform = transforms.RandomVerticalFlip(p=1)
    # Close the source files as soon as the flipped copies exist instead of
    # leaving the handles to garbage collection.
    with Image.open(img_path) as src_img, Image.open(gt_path) as src_gt:
        img = transform(src_img)
        gt = transform(src_gt)
    return img, gt

# Step 3: ColorJitter
def color_jitter(img_path):
    """Return a copy of an image with randomly jittered brightness.

    Only brightness is varied (contrast, saturation and hue are 0), so no
    pixel moves and the original ground-truth mask still lines up.

    Args:
        img_path: Path of the input image.

    Returns:
        The brightness-adjusted image.
    """
    transform = transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0)
    # Close the source file deterministically once the adjusted copy exists.
    with Image.open(img_path) as src_img:
        img = transform(src_img)
    return img

# Step 1: RandomHorizontalFlip
# Save a horizontally flipped copy of each selected image/mask pair; "_200" in
# the filename becomes "_300" to mark it as a horizontal flip.
_matching = [name for name in os.listdir(img_folder) if name.startswith(tuple(prefixes))]
# A name without "_200" would keep the original's name and then be overwritten
# when the un-augmented originals are copied into the same folder, so skip it.
_candidates = [name for name in _matching if "_200" in name]
if len(_candidates) < len(_matching):
    print("Step 1: skipping %d image(s) without '_200' in the name"
          % (len(_matching) - len(_candidates)))
# Filtering before sampling means every sampled file is augmented, and capping
# the count keeps random.sample from raising ValueError on a smaller folder.
# NOTE(review): 3456 looks like the full training-split size, in which case
# every candidate is selected (in random order) — confirm.
selected_imgs = random.sample(_candidates, min(3456, len(_candidates)))
for img_name in selected_imgs:
    stem = os.path.splitext(img_name)[0]
    img_path = os.path.join(img_folder, img_name)
    # Masks share the image's stem and are stored as .png.
    gt_path = os.path.join(gt_folder, stem + ".png")
    new_img_name = img_name.replace("_200", "_300")
    new_gt_name = stem.replace("_200", "_300") + ".png"
    img, gt = random_horizontal_flip(img_path, gt_path)
    img.save(os.path.join(aug_img_folder, new_img_name))
    gt.save(os.path.join(aug_gt_folder, new_gt_name))

# Step 2: RandomVerticalFlip
# Save a vertically flipped copy of each selected image/mask pair; "_200" in
# the filename becomes "_400" to mark it as a vertical flip.
_matching = [name for name in os.listdir(img_folder) if name.startswith(tuple(prefixes))]
# A name without "_200" would keep the original's name and then be overwritten
# when the un-augmented originals are copied into the same folder, so skip it.
_candidates = [name for name in _matching if "_200" in name]
if len(_candidates) < len(_matching):
    print("Step 2: skipping %d image(s) without '_200' in the name"
          % (len(_matching) - len(_candidates)))
# Filtering before sampling means every sampled file is augmented, and capping
# the count keeps random.sample from raising ValueError on a smaller folder.
# NOTE(review): 3456 looks like the full training-split size, in which case
# every candidate is selected (in random order) — confirm.
selected_imgs = random.sample(_candidates, min(3456, len(_candidates)))
for img_name in selected_imgs:
    stem = os.path.splitext(img_name)[0]
    img_path = os.path.join(img_folder, img_name)
    # Masks share the image's stem and are stored as .png.
    gt_path = os.path.join(gt_folder, stem + ".png")
    new_img_name = img_name.replace("_200", "_400")
    new_gt_name = stem.replace("_200", "_400") + ".png"
    img, gt = random_vertical_flip(img_path, gt_path)
    img.save(os.path.join(aug_img_folder, new_img_name))
    gt.save(os.path.join(aug_gt_folder, new_gt_name))

# Step 3: ColorJitter
# Save a brightness-jittered copy of each selected image; "_200" in the filename
# becomes "_500" to mark it as brightness-adjusted. A brightness change moves no
# pixels, so the original mask is reused unchanged under the "_500" name.
_matching = [name for name in os.listdir(img_folder) if name.startswith(tuple(prefixes))]
# A name without "_200" would keep the original's name and then be overwritten
# when the un-augmented originals are copied into the same folder, so skip it.
_candidates = [name for name in _matching if "_200" in name]
if len(_candidates) < len(_matching):
    print("Step 3: skipping %d image(s) without '_200' in the name"
          % (len(_matching) - len(_candidates)))
# Filtering before sampling means every sampled file is augmented, and capping
# the count keeps random.sample from raising ValueError on a smaller folder.
# NOTE(review): 3456 looks like the full training-split size, in which case
# every candidate is selected (in random order) — confirm.
selected_imgs = random.sample(_candidates, min(3456, len(_candidates)))
for img_name in selected_imgs:
    stem = os.path.splitext(img_name)[0]
    img = color_jitter(os.path.join(img_folder, img_name))
    img.save(os.path.join(aug_img_folder, img_name.replace("_200", "_500")))
    # Copy the mask straight to its final name. Copying it under the "_200"
    # name and then renaming every "_200" file in aug_gt_folder would also
    # catch original masks left there by an earlier run (possibly producing
    # "_500" masks with no matching image), and os.rename raises
    # FileExistsError on Windows when the target already exists.
    shutil.copy(
        os.path.join(gt_folder, stem + ".png"),
        os.path.join(aug_gt_folder, stem.replace("_200", "_500") + ".png"),
    )

# Mirror the un-augmented originals into the aug folders so the augmented
# training set holds the originals alongside their augmented copies.
# All images are copied first, then all masks, each under its own name.
for src_dir, dst_dir in ((img_folder, aug_img_folder), (gt_folder, aug_gt_folder)):
    for entry in os.listdir(src_dir):
        shutil.copy(os.path.join(src_dir, entry), os.path.join(dst_dir, entry))