In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import kwcoco
import os
from pathlib import Path
import shutil
import numpy as np
from tqdm import tqdm

def coco2yolo(
                coco:kwcoco.coco_dataset.CocoDataset, 
                output_dir:str    = ".", 
                dataset_name:str  = None,
                config_name:str   = "config.yaml",
                cat_file_name:str = "categories.txt",
                exists_ok:bool    = True,
                task:str          = None # None or "BBox" or "Segmentation" or "Keypoint"
              ):

    # 前準備
    if dataset_name is None or not isinstance(dataset_name, str):
        dataset_name = str(Path(coco.fpath).name.split(".")[0])

    base_save_dir = Path(output_dir) / Path(dataset_name)
    save_images_dir = base_save_dir / Path("images")
    save_labels_dir = base_save_dir / Path("labels")

    if Path(output_dir).exists() and exists_ok:
        shutil.rmtree(Path(output_dir))

    save_images_dir.mkdir(parents=True, exist_ok=True)
    save_labels_dir.mkdir(parents=True, exist_ok=True)

    config_path = Path(output_dir)/config_name
    cat_file_path = Path(output_dir)/cat_file_name

    # カテゴリーを書き出す
    if not cat_file_path.exists():
        with open(str(cat_file_path), mode = "w") as f:
            for i in range(1, max(list(coco.cats.keys()))+1):
                try:
                    f.write(f"{i}: {coco.cats[i]['name']}\n")
                except:
                    f.write(f"{i}:\n")

    # YOLOv5学習用configを書き出す
    if not config_path.exists():
        with open(str(config_path), mode = "w") as f:
            f.write(f"path:  {Path(output_dir).name}\n")
            f.write(f"train: train/images\n")
            f.write(f"val:   valid/images\n")
            f.write(f"test:  test/images\n")
            f.write(f"\n")
            f.write(f"nc: {len(list(coco.cats.keys()))+1}\n")
            f.write(f"\n")
            f.write(f"nkpt:\n")
            for i in range(0, max(list(coco.cats.keys()))+1):
                try:
                    print(coco.cats[i])
                    f.write(f"  {i}: {len(coco.cats[i]['keypoints'])}\n")
                except:
                    f.write(f"  {i}: 0\n")
            
            f.write(f"\n")
            f.write(f"names: \n")
            for i in range(0, max(list(coco.cats.keys()))+1):
                try:
                    f.write(f"  {i}: {coco.cats[i]['name']}\n")
                except:
                    f.write(f"  {i}:\n")

    # アノテーションを書き出す
    gids = list(coco.imgs.keys())
    for i, gid in enumerate(tqdm(gids, desc="coco2yolo")):
        img_src_path = Path(coco.get_image_fpath(gid))
        img_dst_path = save_images_dir/img_src_path.name
        shutil.copyfile(img_src_path, img_dst_path)

        img = coco.load_image(gid)
        with open(f'{str(save_labels_dir)}/{img_dst_path.name.split(".")[0]}.txt', mode = "w") as f:
            aids = coco.gid_to_aids[gid]
            for j, aid in enumerate(aids):
                cls = coco.index.anns[aid]["category_id"]
                dh, dw = (1/img.shape[0], 1/img.shape[1])

                if (task is None) or (task == "BBox") or (task == "Keypoint"):
                    x, y, w, h = coco.index.anns[aid]["bbox"]

                    center_x      = (x + w / 2) * dw
                    center_y      = (y + h / 2) * dh
                    width         = w * dw
                    height        = h * dh

                    center_x = max(0, min(center_x, 1.0))
                    center_y = max(0, min(center_y, 1.0))
                    width    = max(0, min(width, 1.0))
                    height   = max(0, min(height, 1.0))

                    annotations = [cls, center_x, center_y, width, height]

                if "keypoints" in coco.index.anns[aid] and (task == "Keypoint"):
                    kpts = coco.index.anns[aid]["keypoints"]
                    kpts = [kpts[idx:idx + 3] for idx in range(0, len(kpts), 3)]
                    kpts = [[kp[0]*dw, kp[1]*dh, kp[2]] for kp in kpts]
                    kpts = [i for kp in kpts for i in kp]
                    annotations += kpts

                if "segmentation" in coco.index.anns[aid] and (task == "Segmentation"):
                    annotations = [cls]
                    masks = coco.index.anns[aid]["segmentation"]
                    for mask in masks:
                        mask_n = [mask[idx:idx + 2] for idx in range(0, len(mask), 2)]
                        mask_n = [[m[0]*dw, m[1]*dh] for m in mask_n]
                        mask_n = [i for m in mask_n for i in m]

                        annotations += mask_n
                        print(*annotations, file=f)

                    continue

                print(*annotations, file=f)

    return None

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
datasets = [
    # {
    #     "json_path" : "../dataset/sentan_dataset_coco/sentan_dataset/sentan_dataset.json",
    #     "img_root"  : "../dataset/sentan_dataset_coco/sentan_dataset/",
    # },
    {
        "json_path" : "../dataset/nakane_grape1/nakane_grape1.json",
        "img_root"  : "../dataset/nakane_grape1",
    },
]

for dataset in datasets:
    coco = kwcoco.CocoDataset(data=dataset["json_path"])
    coco.img_root = dataset["img_root"]

    coco2yolo(  coco         = coco, 
                output_dir   = f"../yolo_datasets/{Path(dataset['json_path']).stem}", 
                dataset_name = "train",
                exists_ok    = True,
                task="BBox",
                )

    coco2yolo(  coco         = coco, 
                output_dir   = f"../yolo_datasets/{Path(dataset['json_path']).stem}", 
                dataset_name = "valid",
                exists_ok    = False,
                task="BBox",
                )


{'id': 3, 'name': 'cluster', 'supercategory': '', 'color': '#02fd05', 'metadata': {}, 'creator': 'onozaka', 'keypoint_colors': []}


coco2yolo:   0%|          | 0/96 [00:00<?, ?it/s]

coco2yolo: 100%|██████████| 96/96 [00:04<00:00, 21.97it/s]
coco2yolo: 100%|██████████| 96/96 [00:04<00:00, 22.76it/s]
