In [1]:
from pathlib import Path
import shutil
import random
import itertools
from tqdm.notebook import tqdm

In [2]:
# Root of the crowd-sourced cut images; each group subdirectory ('00'..'09')
# contains a `passed` directory holding per-label image directories.
passed_root = Path(r'D:\data\SJJ\SingleOCR\for_crowd_sourcing\all_labeled_cut_images')
groups = [f'{i:02d}' for i in range(10)]  # ['00', '01', ..., '09']
passed_dirs = [passed_root / group / 'passed' for group in groups]
# Per-label directories of filtered images returned from crowd sourcing.
filtered_dir = Path(r'D:/data/SJJ/SingleOCR/from_crowd_sourcing/210819')
# Output root; must already exist (per-label subdirectories are created later).
dst_dir = Path(r'D:\data\SJJ\SingleOCR\mixed_passed11000_filtered1000')

# Report exactly which inputs are missing instead of a bare AssertionError.
missing_passed_dirs = [d for d in passed_dirs if not d.is_dir()]
assert not missing_passed_dirs, f'missing passed directories: {missing_passed_dirs}'
assert filtered_dir.is_dir(), f'missing filtered directory: {filtered_dir}'
assert dst_dir.is_dir(), f'missing destination directory: {dst_dir}'

In [3]:
# Target number of images per output label directory (filtered + passed combined).
NB_CLASS_IMAGES = 12_000
# Upper bound on filtered images taken per label; passed images fill the remainder.
MAX_FILTERED_IMAGES = 1_000
# Seed for reproducible image sampling.
RANDOM_SEED = 10

In [4]:
def fetch_file_paths(src_dir: 'Path | list[Path]', max_samples: int, ext: str = 'jpg', seed=RANDOM_SEED) -> list:
    """
    Collect file paths with the given extension and return at most `max_samples` of them.

    If the directories contain `max_samples` matching files or fewer, all of them
    are returned; otherwise a reproducible random sample of size `max_samples` is drawn.

    Args:
        src_dir: A directory, or a list/tuple of directories, searched non-recursively.
        max_samples: Maximum number of paths to return; must be non-negative.
        ext: File extension to match, with or without a leading dot ('jpg' or '.jpg').
        seed: Seed for the sampling RNG; the same files and seed give the same sample.

    Returns:
        A sorted list of `Path` objects.

    Raises:
        ValueError: If `src_dir` is neither a Path nor a list/tuple of Paths,
            or if `max_samples` is negative.
        AssertionError: If `ext` is not a non-empty string.
    """
    assert isinstance(ext, str) and len(ext) > 0
    # Accept the extension with or without its leading dot.
    if ext.startswith('.'):
        ext = ext[1:]
    pattern = '*.' + ext

    # A single directory, or every directory in a list/tuple.
    if isinstance(src_dir, Path):
        file_path_list = list(src_dir.glob(pattern))
    elif isinstance(src_dir, (list, tuple)) and all(isinstance(d, Path) for d in src_dir):
        file_path_list = list(itertools.chain.from_iterable(d.glob(pattern) for d in src_dir))
    else:
        raise ValueError(f'`src_dir` should be Path or [Path] object, but `{type(src_dir)}`.')

    if max_samples < 0:
        raise ValueError(f'`max_samples` should be non-negative, but `{max_samples}`.')

    # glob() yields entries in filesystem-dependent order. Sort before sampling so a
    # given seed selects the same files everywhere, and so both return paths are sorted.
    file_path_list.sort()
    if max_samples >= len(file_path_list):
        return file_path_list
    # Private RNG: reproducible sampling without reseeding the global `random` module.
    rng = random.Random(seed)
    return sorted(rng.sample(file_path_list, max_samples))

In [8]:
# For every label directory in filtered_dir, copy up to NB_CLASS_IMAGES images into
# dst_dir/<label>: at most MAX_FILTERED_IMAGES filtered images, with passed images
# (sampled across all passed group directories) filling the remainder.
assert MAX_FILTERED_IMAGES <= NB_CLASS_IMAGES  # loop-invariant; checked once up front
# Only subdirectories are labels; collect them once, sorted for a deterministic order.
f_label_dirs = sorted(d for d in filtered_dir.iterdir() if d.is_dir())

for f_label_dir in tqdm(f_label_dirs, desc='label directory'):
    label = f_label_dir.name
    p_label_dirs = [d / label for d in passed_dirs]
    missing_p_label_dirs = [d for d in p_label_dirs if not d.is_dir()]
    assert not missing_p_label_dirs, f'label `{label}` is missing passed directories: {missing_p_label_dirs}'

    # Sample paths so the total per label is at most NB_CLASS_IMAGES.
    f_paths = fetch_file_paths(f_label_dir, max_samples=MAX_FILTERED_IMAGES)
    max_p_images = NB_CLASS_IMAGES - len(f_paths)
    p_paths = fetch_file_paths(p_label_dirs, max_samples=max_p_images)
    src_paths = f_paths + p_paths
    nb_images = len(src_paths)
    assert nb_images <= NB_CLASS_IMAGES

    # shutil.copy into a directory keeps the basename, so duplicate names across source
    # directories would silently overwrite each other and leave fewer files than sampled.
    nb_unique_names = len({p.name for p in src_paths})
    assert nb_unique_names == nb_images, (
        f'label `{label}`: {nb_images - nb_unique_names} duplicate file names would overwrite each other')

    # Copy the sampled images into the label directory under dst_dir.
    dst_label_dir = dst_dir / label
    dst_label_dir.mkdir(exist_ok=True)
    for img_path in tqdm(src_paths, desc='copy images', leave=False):
        shutil.copy(img_path, dst_label_dir)

In [6]:
# Image count check: print the number of .jpg files in each entry of dst_dir.
for label_dir in dst_dir.iterdir():
    nb_jpgs = len(list(label_dir.glob("*.jpg")))
    print(f'label: {label_dir.name}\timages: {nb_jpgs}')

label: 00	images: 12000
label: 01	images: 12000
label: 02	images: 12000
label: 03	images: 12000
label: 04	images: 12000
label: 05	images: 12000
label: 06	images: 12000
label: 07	images: 12000
label: 08	images: 12000
label: 09	images: 12000
label: 10	images: 12000
label: 11	images: 12000
label: 12	images: 507
label: 13	images: 12000
label: 14	images: 12000
label: 15	images: 12000
label: 16	images: 2554
