In [1]:
import os
import random
import shutil
import json
from pathlib import Path
from collections import defaultdict, Counter

import numpy as np
from PIL import Image
import cv2
import albumentations as A
from sklearn.model_selection import StratifiedGroupKFold


In [2]:
# === Path configuration ===
# Dataset root containing train.json, test.json, train/ and test/.
# The TRASH_DATASET_DIR environment variable overrides the default location so
# the notebook can run on another machine without editing this cell.
BASE_DIR    = Path(os.environ.get(
    "TRASH_DATASET_DIR",
    "/data/ephemeral/home/pro-cv-objectdetection-cv-11-sj/dataset",
))
TRAIN_JSON  = BASE_DIR / "train.json"
TEST_JSON   = BASE_DIR / "test.json"

# JSON files written after filtering / augmentation.
TRAIN_LE35_JSON        = BASE_DIR / "train_le35.json"
TRAIN_WITH_AUG_JSON    = BASE_DIR / "train_le35_with_aug_train_only.json"

# Root of the YOLO-format export: <BASE_DIR>/yolo_dataset2/{images,labels}.
YOLO_ROOT = BASE_DIR / "yolo_dataset2"
IMG_DIR   = YOLO_ROOT / "images"
LBL_DIR   = YOLO_ROOT / "labels"

# === Load the original COCO train.json ===
# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(TRAIN_JSON, "r", encoding="utf-8") as f:
    coco_orig = json.load(f)

images_orig      = coco_orig["images"]
annotations_orig = coco_orig["annotations"]
categories       = coco_orig["categories"]

print("원본 이미지 수:", len(images_orig))
print("원본 어노테이션 수:", len(annotations_orig))


원본 이미지 수: 4883
원본 어노테이션 수: 23144


In [3]:
# Number of annotations attached to each image_id.
ann_count_by_image = Counter(ann["image_id"] for ann in annotations_orig)

# Collect the ids of images that carry more than 35 boxes.
bad_image_ids = set()
for counted_id, box_count in ann_count_by_image.items():
    if box_count > 35:
        bad_image_ids.add(counted_id)
print("bbox > 35인 이미지 수:", len(bad_image_ids))

# Drop those images together with every annotation that points at them.
images_le35 = list(filter(lambda img: img["id"] not in bad_image_ids, images_orig))
annotations_le35 = list(
    filter(lambda ann: ann["image_id"] not in bad_image_ids, annotations_orig)
)

coco_le35 = dict(
    images=images_le35,
    annotations=annotations_le35,
    categories=categories,
)

# Persist the filtered dataset next to the original JSON.
with TRAIN_LE35_JSON.open("w") as f:
    json.dump(coco_le35, f)
print("[저장 완료]", TRAIN_LE35_JSON)
print("필터링 후 이미지 수:", len(images_le35))
print("필터링 후 어노테이션 수:", len(annotations_le35))


bbox > 35인 이미지 수: 22
[저장 완료] /data/ephemeral/home/pro-cv-objectdetection-cv-11-sj/dataset/train_le35.json
필터링 후 이미지 수: 4861
필터링 후 어노테이션 수: 22188


In [4]:
# Split the filtered dataset into train / validation image ids.
train_data = coco_le35
annots = train_data["annotations"]
images_le35 = train_data["images"]
categories = train_data["categories"]

# One (image_id, category_id) pair per annotation: the split is stratified on
# the category and grouped by image, so all boxes of an image share a side.
# Image ids come from the annotations, so an image without any annotation
# ends up in neither the train nor the validation set.
var = [(ann["image_id"], ann["category_id"]) for ann in annots]
X = np.ones((len(var), 1))  # placeholder features; only y and groups matter
y = np.array([v[1] for v in var])       # category_id
groups = np.array([v[0] for v in var])  # image_id

cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=411)

fold_to_use = 0
# Fail early with a clear message instead of leaving train_ids as None.
if not 0 <= fold_to_use < cv.get_n_splits():
    raise ValueError(
        f"fold_to_use={fold_to_use} is outside the valid range 0..{cv.get_n_splits() - 1}"
    )

train_ids = None
val_ids = None

for fold_idx, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    print(f"Fold {fold_idx}: train_idx={len(train_idx)}, val_idx={len(val_idx)}")
    if fold_idx == fold_to_use:
        train_ids = np.unique(groups[train_idx])
        val_ids   = np.unique(groups[val_idx])
        break

print(f"\nUsing fold {fold_to_use}")
print("train image ids:", len(train_ids))
print("val image ids:", len(val_ids))

train_id_set = set(train_ids.tolist())
val_id_set   = set(val_ids.tolist())

# Grouping by image must keep every image on exactly one side of the split.
assert train_id_set.isdisjoint(val_id_set), "train/val image ids overlap"


Fold 0: train_idx=17619, val_idx=4569

Using fold 0
train image ids: 3884
val image ids: 977


In [5]:
# === Bounding-box settings shared by every pipeline ===
bbox_params = A.BboxParams(
    format="coco",          # [x_min, y_min, w, h]
    label_fields=["category_ids"],
    min_visibility=0.2,
)


def _build_bbox_augmentation(
    crop_scale, crop_ratio, crop_p,
    vflip_p,
    noise_var, noise_p,
    jitter_limit, jitter_p,
    hsv_shifts, hsv_p,
):
    """Assemble the crop -> flips -> noise -> colour pipeline with the given strengths.

    `jitter_limit` is used for both the brightness and the contrast limit;
    `hsv_shifts` is the (hue, saturation, value) shift-limit triple.
    The returned pipeline uses the shared `bbox_params`.
    """
    hue_limit, sat_limit, val_limit = hsv_shifts
    steps = [
        A.RandomResizedCrop(
            height=1024, width=1024,
            scale=crop_scale,
            ratio=crop_ratio,
            p=crop_p,
        ),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=vflip_p),
        A.GaussNoise(var_limit=noise_var, p=noise_p),
        A.RandomBrightnessContrast(
            brightness_limit=jitter_limit,
            contrast_limit=jitter_limit,
            p=jitter_p,
        ),
        A.HueSaturationValue(
            hue_shift_limit=hue_limit,
            sat_shift_limit=sat_limit,
            val_shift_limit=val_limit,
            p=hsv_p,
        ),
    ]
    return A.Compose(steps, bbox_params=bbox_params)


# === Default (mild) augmentation ===
transform_base = _build_bbox_augmentation(
    crop_scale=(0.9, 1.0), crop_ratio=(0.9, 1.1), crop_p=0.5,
    vflip_p=0.2,
    noise_var=(5.0, 25.0), noise_p=0.1,
    jitter_limit=0.15, jitter_p=0.2,
    hsv_shifts=(5, 20, 10), hsv_p=0.2,
)

# === Stronger augmentation for images that contain "General trash" ===
transform_gt = _build_bbox_augmentation(
    crop_scale=(0.75, 1.0), crop_ratio=(0.85, 1.15), crop_p=0.8,
    vflip_p=0.5,
    noise_var=(10.0, 50.0), noise_p=0.3,
    jitter_limit=0.3, jitter_p=0.6,
    hsv_shifts=(10, 40, 20), hsv_p=0.6,
)


In [6]:
# Lookup table: image_id -> image record.
images_by_id_le35 = dict((img["id"], img) for img in images_le35)

# Lookup table: image_id -> list of its annotations.
anns_by_image_le35 = defaultdict(list)
for ann in annotations_le35:
    owner_id = ann["image_id"]
    anns_by_image_le35[owner_id].append(ann)

# Id of the first category named "General trash" (None when there is no such category).
gt_category_id = next(
    (cat["id"] for cat in categories if cat["name"] == "General trash"),
    None,
)

if gt_category_id is None:
    raise ValueError("'General trash' 카테고리를 찾을 수 없습니다.")

print("General trash category_id:", gt_category_id)


General trash category_id: 0


In [7]:
# Offline augmentation: write augmented copies of the train-split images to
# disk and collect matching COCO image / annotation records.
aug_images = []
aug_annotations = []

# Seed the global RNGs so the generated augmentation set is the same on every run.
# NOTE(review): this assumes the installed albumentations release samples from the
# global `random` / `numpy.random` state; newer releases may need
# `A.Compose(..., seed=...)` instead — confirm against the installed version.
AUG_SEED = 411
random.seed(AUG_SEED)
np.random.seed(AUG_SEED)

# New image / annotation ids continue after the largest existing id.
next_image_id = max(img["id"] for img in images_le35) + 1
next_ann_id   = max(ann["id"] for ann in annotations_le35) + 1

num_aug_base = 1   # augmented copies per ordinary image
num_aug_gt   = 3   # augmented copies per image containing "General trash"

for img in images_le35:
    img_id = img["id"]

    # === Only images of the train split are augmented ===
    if img_id not in train_id_set:
        continue

    file_name = img.get("file_name", img.get("filename"))
    if file_name is None:
        continue

    img_path = BASE_DIR / file_name
    if not img_path.exists():
        # file_name may be a bare name such as '0000.jpg'; retry under train/.
        img_path = BASE_DIR / "train" / Path(file_name).name
    if not img_path.exists():
        print("이미지를 찾을 수 없습니다:", file_name)
        continue

    img_anns = anns_by_image_le35[img_id]
    if len(img_anns) == 0:
        continue

    # Load the image; OpenCV reads BGR, the pipelines are fed RGB.
    image = cv2.imread(str(img_path))
    if image is None:
        print("이미지를 읽을 수 없습니다:", img_path)
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    bboxes = [ann["bbox"] for ann in img_anns]
    category_ids = [ann["category_id"] for ann in img_anns]

    has_gt = any(cat_id == gt_category_id for cat_id in category_ids)

    if has_gt:
        transform_to_use = transform_gt
        num_aug = num_aug_gt
    else:
        transform_to_use = transform_base
        num_aug = num_aug_base

    # Path parts are the same for every augmented copy of this image.
    orig_stem = Path(file_name).stem    # e.g. '0000'
    orig_dir  = Path(file_name).parent  # e.g. 'train'

    for n in range(num_aug):
        augmented = transform_to_use(
            image=image,
            bboxes=bboxes,
            category_ids=category_ids
        )

        aug_img = augmented["image"]
        aug_bboxes = augmented["bboxes"]
        aug_category_ids = augmented["category_ids"]

        # Discard the sample when no box survived the transform.
        if len(aug_bboxes) == 0:
            continue

        # File name that stays distinguishable from the original image.
        aug_file_name = f"{orig_stem}_aug_cls_{n+1}.jpg"
        aug_rel_path  = orig_dir / aug_file_name

        # Save the augmented image (RGB -> BGR for OpenCV).
        save_path = BASE_DIR / aug_rel_path
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_img_bgr = cv2.cvtColor(aug_img, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure through its return value; do not register
        # an image that was never written to disk.
        if not cv2.imwrite(str(save_path), save_img_bgr):
            print("증강 이미지를 저장할 수 없습니다:", save_path)
            continue

        # New COCO image record.
        new_image = {
            "id": next_image_id,
            "width": aug_img.shape[1],
            "height": aug_img.shape[0],
            "file_name": str(aug_rel_path).replace("\\", "/"),
        }
        aug_images.append(new_image)

        # New COCO annotation records.
        for bbox, cat_id in zip(aug_bboxes, aug_category_ids):
            # Normalise to plain Python numbers so the records are JSON-serialisable
            # and category ids stay integers whatever type the library returns.
            x_min, y_min, box_w, box_h = (float(v) for v in bbox[:4])
            new_ann = {
                "id": next_ann_id,
                "image_id": next_image_id,
                "category_id": int(cat_id),
                "bbox": [x_min, y_min, box_w, box_h],
                "area": box_w * box_h,
                "iscrowd": 0,
            }
            aug_annotations.append(new_ann)
            next_ann_id += 1

        next_image_id += 1

print("증강된 이미지 수:", len(aug_images))
print("증강된 어노테이션 수:", len(aug_annotations))


증강된 이미지 수: 7226
증강된 어노테이션 수: 38736


In [8]:
# Original (filtered) records first, augmented records appended after them.
images_all      = [*images_le35, *aug_images]
annotations_all = [*annotations_le35, *aug_annotations]

coco_with_aug = dict(
    images=images_all,
    annotations=annotations_all,
    categories=categories,
)

# Write the combined dataset as a single COCO JSON file.
with TRAIN_WITH_AUG_JSON.open("w") as f:
    json.dump(coco_with_aug, f)
print("[저장 완료]", TRAIN_WITH_AUG_JSON)
print("전체 이미지 수 (원본+aug):", len(images_all))
print("전체 어노테이션 수 (원본+aug):", len(annotations_all))


[저장 완료] /data/ephemeral/home/pro-cv-objectdetection-cv-11-sj/dataset/train_le35_with_aug_train_only.json
전체 이미지 수 (원본+aug): 12087
전체 어노테이션 수 (원본+aug): 60924


In [9]:
# Create the YOLO directory layout (the test split has images only, no labels).
for p in [
    IMG_DIR / "train",
    IMG_DIR / "val",
    IMG_DIR / "test",
    LBL_DIR / "train",
    LBL_DIR / "val"
]:
    p.mkdir(parents=True, exist_ok=True)
    # Nothing below removes existing files, so leftovers from an earlier run
    # (for example a different fold) would silently remain part of this split.
    if any(p.iterdir()):
        print(f"WARNING: {p} is not empty; files from an earlier run will remain in this split")

print("YOLO_ROOT:", YOLO_ROOT)


YOLO_ROOT: /data/ephemeral/home/pro-cv-objectdetection-cv-11-sj/dataset/yolo_dataset2


In [10]:
# The YOLO conversion works on the combined "original + augmented" dataset.
train_data_all = coco_with_aug

images_all      = train_data_all["images"]
annotations_all = train_data_all["annotations"]
categories      = train_data_all["categories"]

# Lookup table: image_id -> image record.
images_by_id = dict((img["id"], img) for img in images_all)

# Lookup table: image_id -> list of its annotations.
annotations_by_image = defaultdict(list)
for ann in annotations_all:
    owner_id = ann["image_id"]
    annotations_by_image[owner_id].append(ann)

# YOLO class order; a name's position in this list is its class index.
names = [
    "General trash",
    "Paper",
    "Paper pack",
    "Metal",
    "Glass",
    "Plastic",
    "Styrofoam",
    "Plastic bag",
    "Battery",
    "Clothing",
]

name_to_idx = dict(zip(names, range(len(names))))

# COCO category_id -> YOLO class index, matched through the category name.
cat_id_to_idx = {}
for cat in categories:
    yolo_idx = name_to_idx.get(cat["name"])
    if yolo_idx is None:
        raise ValueError(f"카테고리 이름이 names 리스트에 없음: {cat['name']}")
    cat_id_to_idx[cat["id"]] = yolo_idx

print("cat_id_to_idx:", cat_id_to_idx)


cat_id_to_idx: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}


In [11]:
def link_or_copy(src: Path, dst: Path):
    """Make `dst` refer to the file at `src`, by symlink when possible, else by copy.

    The parent directory of `dst` is created when missing. An existing `dst`
    is left untouched. A dangling symlink at `dst` is removed and recreated:
    `exists()` is False for it, yet it would make `os.symlink` fail and the
    copy fallback write through the dead link.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        return
    if dst.is_symlink():
        # Dangling link (its target no longer exists): replace it.
        dst.unlink()

    # A relative link target is resolved relative to the link's own directory,
    # not the current working directory, so always link to an absolute path.
    src_abs = os.path.abspath(src)
    try:
        os.symlink(src_abs, dst)
    except (AttributeError, OSError, NotImplementedError):
        # Symlinks unavailable or not permitted here: fall back to a real copy.
        shutil.copy2(src_abs, dst)

def _coco_bbox_to_yolo(bbox, width, height):
    """Convert a COCO [x_min, y_min, w, h] pixel box to normalised YOLO (x_c, y_c, w, h).

    The box is clipped to the image rectangle first, so the centre and size
    describe the visible part of the box. Returns None when the clipped box
    is degenerate (normalised width or height below 1e-6).
    """
    x_min, y_min, box_w, box_h = bbox

    # Clip the corners, not the centre/size: clamping centre and size
    # independently distorts boxes that cross the image border.
    x0 = min(max(x_min, 0.0), width)
    y0 = min(max(y_min, 0.0), height)
    x1 = min(max(x_min + box_w, 0.0), width)
    y1 = min(max(y_min + box_h, 0.0), height)

    w_norm = (x1 - x0) / width
    h_norm = (y1 - y0) / height
    if w_norm < 1e-6 or h_norm < 1e-6:
        return None

    x_c = (x0 + x1) / 2.0 / width
    y_c = (y0 + y1) / 2.0 / height
    return x_c, y_c, w_norm, h_norm


def convert_split(split_name: str, image_ids):
    """Export one split as YOLO image links and label files.

    split_name: "train" or "val"; selects the sub-directory under IMG_DIR / LBL_DIR.
    image_ids : iterable of image ids belonging to this split.

    For every id the image is linked (or copied) into the image directory and,
    when at least one valid box remains, a label file with one
    "<class> <x_c> <y_c> <w> <h>" line per box is written. Images whose source
    file cannot be found are reported and skipped.

    Raises KeyError when an id is unknown to `images_by_id`, when the image
    record has no file name, or when a category id is missing from `cat_id_to_idx`.
    """
    img_split_dir = IMG_DIR / split_name
    lbl_split_dir = LBL_DIR / split_name
    img_split_dir.mkdir(parents=True, exist_ok=True)
    lbl_split_dir.mkdir(parents=True, exist_ok=True)

    n_images = 0
    n_labels = 0

    for image_id in image_ids:
        info = images_by_id[image_id]

        file_name = info.get("file_name", info.get("filename"))
        if file_name is None:
            raise KeyError("image entry missing 'file_name'/'filename'")

        # Source path of the original or augmented image; retry under train/
        # when file_name is a bare name.
        src_img = BASE_DIR / file_name
        if not src_img.exists():
            src_img = BASE_DIR / "train" / Path(file_name).name
        if not src_img.exists():
            print(f"[{split_name}] WARNING: image not found for id {image_id}: {file_name}")
            continue

        base_name = Path(file_name).name   # '0000.jpg' or '0000_aug_cls_1.jpg'
        dst_img = img_split_dir / base_name
        link_or_copy(src_img, dst_img)
        n_images += 1

        # Normalise with the size of the file on disk, not the JSON metadata.
        with Image.open(src_img) as im:
            width, height = im.size

        anns = annotations_by_image.get(image_id, [])
        if not anns:
            # No objects: no label file is written for this image.
            continue

        lines = []
        for ann in anns:
            cls_idx = cat_id_to_idx[ann["category_id"]]   # YOLO class index
            yolo_box = _coco_bbox_to_yolo(ann["bbox"], width, height)
            if yolo_box is None:
                # Box is empty after clipping to the image.
                continue
            x_c, y_c, w, h = yolo_box
            lines.append(f"{cls_idx} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")

        if not lines:
            # No usable box left for this image: skip the label file.
            continue

        label_path = lbl_split_dir / (Path(base_name).stem + ".txt")
        with open(label_path, "w") as f:
            f.write("\n".join(lines))
        n_labels += 1

    print(f"[{split_name}] wrote {n_images} images and {n_labels} label files")


In [12]:
# Every augmented image was generated from a train-split image, so all of
# them are added to the train split.
aug_image_ids = [aug_record["id"] for aug_record in aug_images]

train_ids_full = [*train_ids, *aug_image_ids]
val_ids_full   = [*val_ids]   # the validation split gets no augmented images

for split_name, split_ids in (("train", train_ids_full), ("val", val_ids_full)):
    convert_split(split_name, split_ids)


[train] wrote 11110 images and 11110 label files
[val] wrote 977 images and 977 label files


In [13]:
# === Load test.json and place the test images under images/test (no labels) ===
with TEST_JSON.open("r") as f:
    test_data = json.load(f)

images_test = test_data["images"]
img_test_dir = IMG_DIR / "test"
img_test_dir.mkdir(parents=True, exist_ok=True)

# file_name (as written in test.json) -> image id, for images that were found.
test_filename_to_id = {}

for img in images_test:
    file_name = img.get("file_name", img.get("filename"))
    if file_name is None:
        continue

    # Try the path as given first, then the bare file name under test/.
    candidates = (
        BASE_DIR / file_name,                     # ex) dataset/test/3105.jpg
        BASE_DIR / "test" / Path(file_name).name,
    )
    src_img = next((path for path in candidates if path.exists()), None)
    if src_img is None:
        print(f"[test] WARNING: image not found: {file_name}")
        continue

    link_or_copy(src_img, img_test_dir / Path(file_name).name)
    test_filename_to_id[file_name] = img["id"]

print("Copied test images:", len(list(img_test_dir.glob('*.*'))))

# === Write the YOLO data.yaml ===
data_yaml_path = YOLO_ROOT / "trash10.yaml"

yaml_lines = [
    f"path: {YOLO_ROOT}\n",
    "train: images/train\n",
    "val: images/val\n",
    "test: images/test\n",
    "names:\n",
]
# One "  <index>: <name>" entry per class, in `names` order.
yaml_lines.extend(f"  {i}: {name}\n" for i, name in enumerate(names))

with open(data_yaml_path, "w") as f:
    f.writelines(yaml_lines)

print("=== trash10.yaml ===")
print(data_yaml_path.read_text())


Copied test images: 4871
=== trash10.yaml ===
path: /data/ephemeral/home/pro-cv-objectdetection-cv-11-sj/dataset/yolo_dataset2
train: images/train
val: images/val
test: images/test
names:
  0: General trash
  1: Paper
  2: Paper pack
  3: Metal
  4: Glass
  5: Plastic
  6: Styrofoam
  7: Plastic bag
  8: Battery
  9: Clothing

