In [None]:
# Install runtime dependencies into the kernel's own environment (%pip, not !pip).
# NOTE(review): versions are unpinned; pin them (pkg==X.Y.Z) so Restart & Run All is reproducible.
%pip install albumentations opencv-python numpy

In [None]:
#!/usr/bin/env python3
"""
yolo_hit_augment_split.py

Over-sample a single-class ("hit") YOLO dataset up to TARGET_TOTAL images (train+val),
allocating the newly augmented samples between train and val by SPLIT_RATIO (80/20 by default).

Expected input layout (rooted at DATA_DIR):
  data/
    train/
      images/
      labels/
    val/
      images/
      labels/

Outputs into OUTPUT_DIR (original pairs are copied first, then augmented pairs are added):
  augmented/
    train/
      images/
      labels/
    val/
      images/
      labels/
"""

import os
import random
import shutil
from glob import glob

import cv2
import albumentations as A
import numpy as np

# --- USER CONFIG ---
# Single leading slash, matching OUTPUT_DIR: a leading "//" is implementation-defined
# on POSIX and is interpreted as a UNC network path on Windows.
DATA_DIR        = "/datax/scratch/jliang/dataset_final_small"
TRAIN_IMG_DIR   = os.path.join(DATA_DIR, "train", "images")
TRAIN_LBL_DIR   = os.path.join(DATA_DIR, "train", "labels")
VAL_IMG_DIR     = os.path.join(DATA_DIR, "val",   "images")
VAL_LBL_DIR     = os.path.join(DATA_DIR, "val",   "labels")
OUTPUT_DIR      = "/datax/scratch/jliang/augmented"
TARGET_TOTAL    = 2000      # desired total images (train+val) after augmentation
SPLIT_RATIO     = {"train": 0.8, "val": 0.2}
SEED            = 42
# ---------------------

# Only SPLIT_RATIO["train"] is used when allocating augmentations (val gets the
# remainder), so catch an inconsistent config here instead of silently ignoring "val".
assert abs(sum(SPLIT_RATIO.values()) - 1.0) < 1e-9, "SPLIT_RATIO values must sum to 1"

# Seed both global RNGs used by this script (random.choice for sampling; numpy for
# any numpy-based randomness).
random.seed(SEED)
np.random.seed(SEED)

def gather_pairs(img_dir, lbl_dir):
    """Return (image_path, label_path) pairs for every image that has a label file.

    Every regular file directly inside ``img_dir`` is a candidate image; it is
    kept only if ``lbl_dir`` contains ``<stem>.txt``.

    Args:
        img_dir: directory holding the images.
        lbl_dir: directory holding the YOLO ``.txt`` label files.

    Returns:
        List of ``(image_path, label_path)`` tuples, sorted by image path.
    """
    pairs = []
    # sorted(): glob order depends on the filesystem, so without it random.choice
    # under a fixed SEED would not pick the same samples across machines/runs.
    for img_path in sorted(glob(os.path.join(img_dir, "*"))):
        # Skip sub-directories: they cannot be read as images and would make the
        # later shutil.copy of originals fail.
        if not os.path.isfile(img_path):
            continue
        stem = os.path.splitext(os.path.basename(img_path))[0]
        lbl_path = os.path.join(lbl_dir, stem + ".txt")
        if os.path.exists(lbl_path):
            pairs.append((img_path, lbl_path))
    return pairs

# 1) Gather train/val pairs.
# Both splits are gathered here: the output-dir setup below and the augmentation
# cell both use `train_pairs`, so it must be defined on a fresh kernel.
train_pairs = gather_pairs(TRAIN_IMG_DIR, TRAIN_LBL_DIR)
val_pairs   = gather_pairs(VAL_IMG_DIR,   VAL_LBL_DIR)

if not train_pairs:
    raise RuntimeError("No train images+labels found!")
if not val_pairs:
    raise RuntimeError("No val images+labels found!")

n_train_orig = len(train_pairs)
n_val_orig   = len(val_pairs)
# TARGET_TOTAL is documented as train+val, so the budget must count both splits.
n_orig_total = n_train_orig + n_val_orig

print(f"Original counts: train={n_train_orig}, val={n_val_orig}, total={n_orig_total}")

# 2) Compute how many augmentations are needed per split: the shortfall to
#    TARGET_TOTAL is divided by SPLIT_RATIO (val takes the remainder so the sum is exact).
needed_total = max(0, TARGET_TOTAL - n_orig_total)
n_train_aug  = int(round(needed_total * SPLIT_RATIO["train"]))
n_val_aug    = needed_total - n_train_aug

print(f"Will create {n_train_aug} train augmentations and {n_val_aug} val augmentations")

# 3) Prepare output dirs: wipe and recreate images/ and labels/ for each split,
#    then copy the original pairs in so the output is a superset of the input.
for split, pairs in [("train", train_pairs), ("val", val_pairs)]:
    for sub in ("images", "labels"):
        od = os.path.join(OUTPUT_DIR, split, sub)
        if os.path.exists(od):
            shutil.rmtree(od)
        os.makedirs(od)
    # copy originals
    for imgf, lblf in pairs:
        stem, ext = os.path.splitext(os.path.basename(imgf))
        shutil.copy(imgf, os.path.join(OUTPUT_DIR, split, "images", stem + ext))
        shutil.copy(lblf, os.path.join(OUTPUT_DIR, split, "labels", stem + ".txt"))

# 4) Albumentations pipeline. Each step fires independently with probability 0.5;
#    BboxParams makes the geometric steps (flip / scale / rotate) move the YOLO boxes
#    too, and `class_labels` is the per-box label list passed alongside them.
transform = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.RandomScale(p=0.5, scale_limit=0.2),
        A.Rotate(p=0.5, limit=15, border_mode=cv2.BORDER_REPLICATE),
        A.HueSaturationValue(p=0.5),
    ],
    bbox_params=A.BboxParams(label_fields=['class_labels'], format='yolo'),
)

def _clip_yolo_box(xc, yc, w, h):
    """Clip a normalized YOLO box to the unit square; return None if it collapses."""
    # Clip the corners, not the centre/size independently: clamping (xc, yc, w, h)
    # separately can still leave a box that extends past the image edge.
    x1 = max(0.0, xc - w / 2.0)
    y1 = max(0.0, yc - h / 2.0)
    x2 = min(1.0, xc + w / 2.0)
    y2 = min(1.0, yc + h / 2.0)
    if x2 <= x1 or y2 <= y1:
        return None
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0, x2 - x1, y2 - y1)


def _read_yolo_boxes(lbl_path):
    """Parse a YOLO label file into ``(bboxes, class_labels)`` lists.

    Lines without exactly five fields, or with non-numeric coordinates, are
    skipped. The class column is ignored and every box is labelled 0 because this
    dataset is single-class ("hit"). Boxes are clipped to the unit square up front
    so slightly out-of-range annotations are not rejected by Albumentations.
    """
    bboxes, labels = [], []
    with open(lbl_path) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 5:
                continue
            try:
                coords = [float(v) for v in parts[1:]]
            except ValueError:
                continue  # malformed line
            box = _clip_yolo_box(*coords)
            if box is None:
                continue
            bboxes.append(list(box))
            labels.append(0)
    return bboxes, labels


def augment_split(pairs, n_aug, out_dir, max_attempts_per_aug=50, pipeline=None):
    """Write ``n_aug`` randomly augmented copies of samples from ``pairs`` into ``out_dir``.

    Each attempt picks a random ``(image, label)`` pair and runs it through the
    bbox-aware pipeline. It writes ``<stem>_aug_<i><ext>`` to ``out_dir/images``
    and ``<stem>_aug_<i>.txt`` to ``out_dir/labels``. Attempts that fail are
    retried with another sample: unreadable image, no usable boxes, the pipeline
    raising ValueError, every box leaving the frame, or ``cv2.imwrite`` failing.

    Args:
        pairs: list of ``(image_path, label_path)`` tuples to sample from.
        n_aug: number of augmented samples to create; ``<= 0`` does nothing.
        out_dir: split output dir containing ``images/`` and ``labels/``.
        max_attempts_per_aug: retry budget per requested sample. This stops the
            loop from spinning forever when no sample can be augmented.
        pipeline: callable with the Albumentations ``Compose`` call signature.
            Defaults to the module-level ``transform``.

    Returns:
        The number of augmented samples actually written. It is less than
        ``n_aug`` only if the attempt budget was exhausted (a warning is printed).

    Raises:
        ValueError: if ``n_aug > 0`` but ``pairs`` is empty.
    """
    if n_aug <= 0:
        return 0
    if not pairs:
        raise ValueError(f"cannot create {n_aug} augmentations in {out_dir}: no image/label pairs")
    if pipeline is None:
        pipeline = transform

    max_attempts = n_aug * max_attempts_per_aug
    attempts = 0
    i = 0
    while i < n_aug:
        if attempts >= max_attempts:
            print(f" ! giving up in {os.path.basename(out_dir)} after {attempts} attempts: "
                  f"created {i}/{n_aug}")
            break
        attempts += 1

        # --- pick a random sample + load image ---
        img_path, lbl_path = random.choice(pairs)
        img = cv2.imread(img_path)
        if img is None:
            continue

        bboxes, labels = _read_yolo_boxes(lbl_path)
        if not bboxes:
            continue

        # --- apply augmentation (skip this sample if Albumentations rejects it) ---
        try:
            aug = pipeline(image=img, bboxes=bboxes, class_labels=labels)
        except ValueError:
            continue

        # --- clip boxes the transform pushed past the border; drop collapsed ones ---
        clipped = []
        for (xc, yc, w, h), lbl in zip(aug["bboxes"], aug["class_labels"]):
            box = _clip_yolo_box(xc, yc, w, h)
            if box is not None:
                clipped.append((*box, lbl))
        if not clipped:
            continue

        # --- write out the augmented image & label file ---
        stem    = os.path.splitext(os.path.basename(img_path))[0] + f"_aug_{i}"
        out_img = os.path.join(out_dir, "images", stem + os.path.splitext(img_path)[1])
        out_lbl = os.path.join(out_dir, "labels", stem + ".txt")

        # imwrite returns False (no exception) on failure; don't leave an orphan label.
        if not cv2.imwrite(out_img, aug["image"]):
            continue
        with open(out_lbl, "w") as fw:
            for xc, yc, w, h, lbl in clipped:
                fw.write(f"{lbl} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}\n")

        i += 1
        if i % 50 == 0 or i == n_aug:
            print(f" → {i}/{n_aug} aug in {os.path.basename(out_dir)} done")

    return i


In [None]:
# 5) Run augmentation on each split, train first and then val, writing into
#    OUTPUT_DIR/<split>/{images,labels}.
for split, pairs in (("train", train_pairs), ("val", val_pairs)):
    augment_split(pairs,
                  n_train_aug if split == "train" else n_val_aug,
                  os.path.join(OUTPUT_DIR, split))

print("All done ", OUTPUT_DIR)
