# In [None]:
#!/usr/bin/env python
"""
Auto-generate YOLO dataset from video using pseudo-labeling and synthetic augmentation.

Author: Никита
Date: 2025-11-06
"""

import os
import cv2
import random
import shutil
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
from ultralytics import YOLO
from tqdm import tqdm
import albumentations as A


# ==================== CONFIGURATION ====================
class Config:
    """Default settings for the video-to-YOLO dataset pipeline.

    The values are plain class attributes and serve as the default
    arguments of ``create_dataset``.
    """

    # --- input / output ---
    # Video file the frames are read from.
    VIDEO_PATH: str = "crowd.mp4"
    # Root directory that receives images/, labels/ and data.yaml.
    OUTPUT_DIR: str = "dataset"

    # --- frame sampling and pseudo-labeling ---
    # Keep every N-th frame of the video (1 keeps every frame).
    FRAME_SKIP: int = 1
    # Confidence threshold handed to the detector when pseudo-labeling.
    MIN_CONF: float = 0.4

    # --- train / val split ---
    # Fraction of the labeled frames that goes into the training split.
    TRAIN_RATIO: float = 0.8

    # --- synthetic data ---
    # Number of synthetic images as a fraction of the train split size
    # (0.5 adds 50% more images to train).
    SYNTHETIC_RATIO: float = 0.5
    # Upper bound on pasted crops per synthetic image.
    MAX_COPY_PASTE: int = 2

    # --- reproducibility ---
    # Seed for Python's ``random`` module.
    SEED: int = 42


# ==================== UTILITIES ====================
def setup_directories(output_dir: str) -> Dict[str, str]:
    """
    Build the ``images/{train,val}`` and ``labels/{train,val}`` folders.

    Directories that already exist are left untouched.

    Args:
        output_dir (str): Root dataset directory.

    Returns:
        dict: Mapping with the keys ``img_train``, ``img_val``,
        ``lbl_train`` and ``lbl_val`` pointing at the created directories.
    """
    # (result key, top-level folder, split name)
    layout = (
        ("img_train", "images", "train"),
        ("img_val", "images", "val"),
        ("lbl_train", "labels", "train"),
        ("lbl_val", "labels", "val"),
    )

    paths = {}
    for key, kind, split in layout:
        directory = os.path.join(output_dir, kind, split)
        os.makedirs(directory, exist_ok=True)
        paths[key] = directory
    return paths


def extract_frames(video_path: str, frame_skip: int = 1) -> List[Tuple[int, np.ndarray]]:
    """
    Read a video and keep every ``frame_skip``-th frame in memory.

    Args:
        video_path (str): Path to input video.
        frame_skip (int): Keep every Nth frame (1 keeps all). Must be >= 1.

    Returns:
        List[Tuple[int, np.ndarray]]: ``(index, frame)`` pairs. ``index`` is
        the sequential position among the kept frames (0, 1, 2, ...), not the
        frame number inside the source video.

    Raises:
        ValueError: If ``frame_skip`` is smaller than 1.
        FileNotFoundError: If ``video_path`` does not exist.
        IOError: If OpenCV cannot open the video.
    """
    # Reject bad step sizes up front: 0 would otherwise surface as a
    # ZeroDivisionError from the modulo below, in the middle of decoding.
    if frame_skip < 1:
        raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

    if not Path(video_path).exists():
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps if fps > 0 else 0

        print(f"Video: {total_frames} frames, {duration:.1f}s, {fps:.1f} FPS")

        frames = []
        frame_idx = 0
        saved_idx = 0

        print("Extracting frames...")
        # The reported frame count may be missing or wrong; let tqdm run
        # without a total rather than show a bogus one.
        progress_total = total_frames if total_frames > 0 else None
        with tqdm(total=progress_total, unit="frame", colour="blue") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % frame_skip == 0:
                    frames.append((saved_idx, frame.copy()))
                    saved_idx += 1
                frame_idx += 1
                pbar.update(1)
    finally:
        # Release the capture handle even when decoding fails part-way.
        cap.release()

    print(f"Extracted {len(frames)} frames")
    return frames


def pseudo_label_frames(
    frames: List[Tuple[int, np.ndarray]],
    model: YOLO,
    min_conf: float = 0.4
) -> List[Tuple[int, np.ndarray, List[str]]]:
    """
    Run YOLO inference on each frame and emit YOLO-format label lines.

    Only detections of class id 0 are requested from the model, and every
    label line is written with class id ``0``.

    Args:
        frames: List of (idx, frame).
        model: Loaded YOLO model.
        min_conf: Minimum confidence, passed to the model as ``conf``.

    Returns:
        List of (idx, frame, labels_list). ``idx`` is the index supplied in
        ``frames``; ``labels_list`` holds one
        ``"0 x_center y_center width height"`` string per detection, with all
        four values normalized by the frame size. The list is empty when
        nothing was detected.
    """
    print("Pseudo-labeling frames with YOLO...")
    labeled = []

    # Carry the caller's frame index through instead of renumbering, so the
    # output stays aligned with the input even for a pre-filtered frame list.
    for idx, frame in tqdm(frames, unit="frame"):
        h, w = frame.shape[:2]
        results = model(frame, conf=min_conf, classes=[0], verbose=False)[0]

        # Convert all boxes of the frame at once rather than one tensor
        # round-trip per detection.
        boxes_xyxy = results.boxes.xyxy.cpu().numpy()

        labels = []
        for x1, y1, x2, y2 in boxes_xyxy:
            # YOLO format: class x_center y_center width height (normalized)
            x_center = (x1 + x2) / 2 / w
            y_center = (y1 + y2) / 2 / h
            bw = (x2 - x1) / w
            bh = (y2 - y1) / h
            labels.append(f"0 {x_center:.6f} {y_center:.6f} {bw:.6f} {bh:.6f}")

        labeled.append((idx, frame, labels))

    return labeled


def split_and_save(
    labeled_frames: List[Tuple[int, np.ndarray, List[str]]],
    paths: Dict[str, str],
    train_ratio: float
) -> Tuple[List[int], List[int]]:
    """
    Shuffle the labeled frames, split them into train/val and write them out.

    Images are written as ``frame_XXXXXX.jpg`` and labels as
    ``frame_XXXXXX.txt``. Numbering starts at 0 for the train split and
    continues after it for the val split, so file names never collide. A
    frame without detections gets an empty label file.

    Args:
        labeled_frames: List of (idx, frame, labels). The list itself is not
            modified.
        paths: Directory paths with the keys ``img_train``, ``lbl_train``,
            ``img_val`` and ``lbl_val``.
        train_ratio: Fraction for training.

    Returns:
        train_indices, val_indices: the file numbers written to each split.

    Raises:
        IOError: If an image cannot be written.
    """
    random.seed(Config.SEED)
    # Shuffle a copy so the caller's list keeps its order.
    shuffled = list(labeled_frames)
    random.shuffle(shuffled)
    split_idx = int(len(shuffled) * train_ratio)

    train_data = shuffled[:split_idx]
    val_data = shuffled[split_idx:]

    def save_set(data, img_dir, lbl_dir, start_idx):
        """Write one split to disk and return the file numbers used."""
        indices = []
        for global_idx, (_, frame, labels) in enumerate(data, start=start_idx):
            stem = f"frame_{global_idx:06d}"
            img_path = os.path.join(img_dir, f"{stem}.jpg")

            # cv2.imwrite signals failure through its return value only.
            if not cv2.imwrite(img_path, frame):
                raise IOError(f"Failed to write image: {img_path}")

            with open(os.path.join(lbl_dir, f"{stem}.txt"), "w", encoding="utf-8") as f:
                # An empty label list yields an empty file.
                f.write("\n".join(labels))
            indices.append(global_idx)
        return indices

    print("Saving base dataset...")
    train_indices = save_set(train_data, paths["img_train"], paths["lbl_train"], 0)
    val_indices = save_set(val_data, paths["img_val"], paths["lbl_val"], len(train_data))

    return train_indices, val_indices


def _copy_paste_people(src_img, src_labels, dst_img, max_pastes=2):
    """
    Paste randomly chosen labeled crops from ``src_img`` onto ``dst_img``.

    Args:
        src_img: Image the crops are cut from.
        src_labels: YOLO label lines (``class cx cy w h``, normalized) that
            describe the boxes in ``src_img``. Lines that do not have
            exactly five fields are ignored.
        dst_img: Image that receives the crops; it is modified in place.
        max_pastes: Upper bound on the number of crops pasted.

    Returns:
        (dst_img, new_labels): the modified image and one YOLO label line
        (class ``0``) per pasted crop, normalized by the size of ``dst_img``.
    """
    src_h, src_w = src_img.shape[:2]
    crops = []
    for label in src_labels:
        parts = label.split()
        if len(parts) != 5:
            continue
        cx, cy, bw, bh = map(float, parts[1:])
        # Clamp to the image: a box that overshoots the border would
        # otherwise yield a negative index, which NumPy treats as
        # "count from the end" and returns the wrong region.
        x1 = max(0, int((cx - bw / 2) * src_w))
        y1 = max(0, int((cy - bh / 2) * src_h))
        x2 = min(src_w, int((cx + bw / 2) * src_w))
        y2 = min(src_h, int((cy + bh / 2) * src_h))
        if x1 < x2 and y1 < y2:
            crops.append(src_img[y1:y2, x1:x2])

    new_labels = []
    dst_h, dst_w = dst_img.shape[:2]
    n_pastes = max(0, min(len(crops), max_pastes))
    for crop in random.sample(crops, k=n_pastes):
        ph, pw = crop.shape[:2]
        # Cap the scale so the resized patch always fits inside dst_img;
        # an oversized patch cannot be assigned to the clipped slice.
        scale = min(random.uniform(0.6, 1.4), dst_w / pw, dst_h / ph)
        nw = min(int(pw * scale), dst_w)
        nh = min(int(ph * scale), dst_h)
        if nw <= 0 or nh <= 0:
            continue
        patch = cv2.resize(crop, (nw, nh))

        # Inclusive upper bounds keep the whole patch inside the image.
        x = random.randint(0, dst_w - nw)
        y = random.randint(0, dst_h - nh)
        dst_img[y:y + nh, x:x + nw] = patch

        new_cx = (x + nw / 2) / dst_w
        new_cy = (y + nh / 2) / dst_h
        new_bw = nw / dst_w
        new_bh = nh / dst_h
        new_labels.append(f"0 {new_cx:.6f} {new_cy:.6f} {new_bw:.6f} {new_bh:.6f}")

    return dst_img, new_labels


def generate_synthetic(
    train_indices: List[int],
    paths: Dict[str, str],
    model: YOLO,
    synth_count: int
) -> None:
    """
    Generate synthetic training images using augmentation + copy-paste.

    For each synthetic image a random training frame is loaded, run through
    the augmentation pipeline (only when it has labels), and then receives up
    to ``Config.MAX_COPY_PASTE`` crops cut from the un-augmented frame. The
    results are written next to the training data as ``synth_XXXXXX.jpg`` /
    ``synth_XXXXXX.txt``.

    Args:
        train_indices: File numbers of the training images
            (``frame_XXXXXX``) to sample from.
        paths: Dataset paths; ``img_train`` and ``lbl_train`` are used.
        model: YOLO model. Currently unused; kept for interface
            compatibility.
        synth_count: Number of synthetic images to generate. Nothing is done
            when it is not positive or when ``train_indices`` is empty.

    Raises:
        IOError: If a source image cannot be read or an output image cannot
            be written.
    """
    if synth_count <= 0 or not train_indices:
        return

    # Augmentation pipeline.
    # NOTE(review): the keyword names below (var_limit, fog_coef_lower/upper)
    # follow the albumentations release this script was written against;
    # newer releases appear to have renamed them - confirm when upgrading.
    aug = A.Compose([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
        A.GaussNoise(var_limit=(10, 50), p=0.3),
        A.MotionBlur(blur_limit=3, p=0.3),
        A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, p=0.2),
        A.RandomRain(p=0.2),
        A.HorizontalFlip(p=0.5),
        A.RandomScale(scale_limit=0.2, p=0.5),
        A.Rotate(limit=15, p=0.5),
    ], bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

    print(f"Generating {synth_count} synthetic images...")
    # Output numbering starts after the train-set size; the "synth_" prefix
    # keeps the names distinct from the "frame_" files.
    synth_idx = len(train_indices)

    for _ in tqdm(range(synth_count), unit="img"):
        src_idx = random.choice(train_indices)
        src_stem = f"frame_{src_idx:06d}"
        src_img_path = os.path.join(paths["img_train"], f"{src_stem}.jpg")
        src_lbl_path = os.path.join(paths["lbl_train"], f"{src_stem}.txt")

        src_img = cv2.imread(src_img_path)
        # cv2.imread returns None instead of raising on failure.
        if src_img is None:
            raise IOError(f"Cannot read training image: {src_img_path}")
        with open(src_lbl_path, "r", encoding="utf-8") as f:
            src_labels = [line.strip() for line in f if line.strip()]

        # Augmentation (skipped for frames without labels)
        bboxes = [list(map(float, line.split()[1:])) for line in src_labels]
        class_labels = ["person"] * len(bboxes)

        if bboxes:
            augmented = aug(image=src_img, bboxes=bboxes, class_labels=class_labels)
            aug_img = augmented["image"]
            aug_bboxes = augmented["bboxes"]
        else:
            aug_img = src_img
            aug_bboxes = []

        # Copy-paste: crops come from the original frame and are pasted onto
        # a copy of the augmented one.
        final_img, paste_labels = _copy_paste_people(
            src_img, src_labels, aug_img.copy(), Config.MAX_COPY_PASTE
        )

        # Combine labels
        final_labels = [f"0 {x:.6f} {y:.6f} {w:.6f} {h:.6f}" for x, y, w, h in aug_bboxes]
        final_labels.extend(paste_labels)

        # Save
        synth_name = f"synth_{synth_idx:06d}"
        out_img_path = os.path.join(paths["img_train"], f"{synth_name}.jpg")
        if not cv2.imwrite(out_img_path, final_img):
            raise IOError(f"Failed to write image: {out_img_path}")
        with open(os.path.join(paths["lbl_train"], f"{synth_name}.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(final_labels))
        synth_idx += 1

    print(f"Added {synth_count} synthetic images to train set")


def write_data_yaml(output_dir: str) -> None:
    """
    Write ``data.yaml`` into the dataset root for YOLO training.

    The file records the absolute dataset path, the relative train/val image
    folders and a single class named ``person``.

    Args:
        output_dir (str): Dataset root.
    """
    dataset_root = Path(output_dir).resolve()
    lines = [
        f"path: {dataset_root}",
        "train: images/train",
        "val: images/val",
        "",
        "nc: 1",
        "names: ['person']",
        "",  # keeps the trailing newline
    ]

    yaml_path = os.path.join(output_dir, "data.yaml")
    with open(yaml_path, "w", encoding="utf-8") as handle:
        handle.write("\n".join(lines))
    print(f"data.yaml created: {yaml_path}")


# ==================== MAIN FUNCTION ====================
def create_dataset(
    video_path: str = Config.VIDEO_PATH,
    output_dir: str = Config.OUTPUT_DIR,
    frame_skip: int = Config.FRAME_SKIP,
    min_conf: float = Config.MIN_CONF,
    train_ratio: float = Config.TRAIN_RATIO,
    synthetic_ratio: float = Config.SYNTHETIC_RATIO
) -> str:
    """
    Main function: create full YOLO dataset from video.

    Steps: create the directory layout, load ``yolov8n.pt``, extract frames,
    pseudo-label them, split and save train/val, add synthetic training
    images, and write ``data.yaml``.

    Args:
        video_path: Input video.
        output_dir: Dataset root; created (including missing parents) if
            needed.
        frame_skip: Keep every Nth frame of the video.
        min_conf: Minimum detection confidence for pseudo-labels.
        train_ratio: Fraction of labeled frames used for training.
        synthetic_ratio: Number of synthetic images as a fraction of the
            train split size.

    Returns:
        str: Path to generated data.yaml
    """
    random.seed(Config.SEED)
    output_path = Path(output_dir)
    # parents=True so a nested output_dir such as "runs/exp1/dataset" works.
    output_path.mkdir(parents=True, exist_ok=True)

    # 1. Setup
    paths = setup_directories(output_dir)
    model = YOLO("yolov8n.pt")

    # 2. Extract frames
    frames = extract_frames(video_path, frame_skip)

    # 3. Pseudo-label
    labeled = pseudo_label_frames(frames, model, min_conf)

    # 4. Split & save (the val indices are not needed afterwards)
    train_indices, _ = split_and_save(labeled, paths, train_ratio)

    # 5. Synthetic data
    synth_count = int(len(train_indices) * synthetic_ratio)
    generate_synthetic(train_indices, paths, model, synth_count)

    # 6. data.yaml
    write_data_yaml(output_dir)

    # Summary: counts every file present in the image folders, which
    # includes the synthetic images and anything left from earlier runs.
    train_imgs = len(os.listdir(paths["img_train"]))
    val_imgs = len(os.listdir(paths["img_val"]))
    print("\n" + "="*50)
    print("DATASET CREATED SUCCESSFULLY!")
    print(f"Dataset: {output_dir}")
    print(f"Train: {train_imgs} images")
    print(f"Val: {val_imgs} images")
    print(f"data.yaml: {os.path.join(output_dir, 'data.yaml')}")
    print("="*50)

    return str(output_path / "data.yaml")


# ==================== ENTRY POINT ====================
def main() -> None:
    """Build the dataset with default settings and print training hints."""
    data_yaml = create_dataset()

    # User-facing follow-up instructions, printed one line at a time.
    hints = (
        "\nГотово! Запустите обучение:",
        "   python train.py  # (из предыдущего задания)",
        "   или:",
        f"   yolo train data={data_yaml} model=yolov8s.pt epochs=100 imgsz=640",
    )
    for line in hints:
        print(line)

# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()