In [1]:
from pathlib import Path
import shutil
import random
import itertools

In [2]:
# Every dataset lives under one root; derive all paths from it so relocating the data is a one-line edit.
DATA_ROOT = Path(r'D:\data\SJJ\SingleOCR')
# Reviewed ("passed") images; each sub-directory is treated as one label.
passed_dir = DATA_ROOT / 'for_crowd_sourcing' / 'labeled_cut_images' / 'passed'
# Crowd-sourced images (directory 210819); each sub-directory is treated as one label.
filtered_dir = DATA_ROOT / 'from_crowd_sourcing' / '210819'
# Output directory for the mixed dataset. It must already exist; the label sub-directories are
# created by the copy cell. The name appears to encode the intended per-label mix
# (7000 passed + 1000 filtered) -- keep it in sync with NB_CLASS_IMAGES / MAX_FILTERED_IMAGES.
dst_dir = DATA_ROOT / 'mixed_passed7000_filtered1000'
for required_dir in (passed_dir, filtered_dir, dst_dir):
    assert required_dir.is_dir(), f'directory not found: {required_dir}'

In [3]:
# Target number of images per label in the mixed dataset (filtered + passed combined).
NB_CLASS_IMAGES = 8_000
# Upper bound on images drawn from the filtered set per label; the remainder comes from `passed`.
MAX_FILTERED_IMAGES = 1_000
# Default seed used when sampling image paths, so the selection is repeatable.
RANDOM_SEED = 10

In [4]:
def display_each_labeled_images(target_dir: Path, ext: str = 'jpg'):
    """Print how many `*.{ext}` images each label sub-directory of `target_dir` contains.

    Label directories are visited in sorted name order so the report does not depend on the
    filesystem's directory-listing order. Plain files directly inside `target_dir` are skipped
    because they are not label directories.

    Args:
        target_dir: directory whose immediate sub-directories are label directories.
        ext: image file extension to count, with or without a leading dot (default 'jpg').
    """
    assert target_dir.is_dir()
    # Accept both 'jpg' and '.jpg'.
    if ext.startswith('.'):
        ext = ext[1:]
    label_dirs = sorted(d for d in target_dir.iterdir() if d.is_dir())
    for label_dir in label_dirs:
        nb_images = sum(1 for _ in label_dir.glob('*.' + ext))
        print(f'label {label_dir.name} has {nb_images} images.')

In [5]:
# Per-label image counts in the reviewed ("passed") source.
display_each_labeled_images(target_dir=passed_dir)

label 00 has 38222 images.
label 01 has 56339 images.
label 02 has 32402 images.
label 03 has 20501 images.
label 04 has 15015 images.
label 05 has 17094 images.
label 06 has 7643 images.
label 07 has 13745 images.
label 08 has 7951 images.
label 09 has 4793 images.
label 10 has 301075 images.
label 11 has 17270 images.
label 12 has 8 images.
label 13 has 3094 images.
label 14 has 4326 images.
label 15 has 12172 images.
label 16 has 192 images.
label 90 has 0 images.


In [6]:
# Per-label image counts in the crowd-sourced ("filtered") source.
display_each_labeled_images(target_dir=filtered_dir)

label 00 has 1856 images.
label 01 has 6790 images.
label 02 has 1863 images.
label 03 has 934 images.
label 04 has 953 images.
label 05 has 834 images.
label 06 has 266 images.
label 07 has 1862 images.
label 08 has 406 images.
label 09 has 137 images.
label 10 has 44249 images.
label 11 has 1796 images.
label 12 has 439 images.
label 13 has 229 images.
label 14 has 7675 images.
label 15 has 5838 images.
label 16 has 197 images.


In [6]:
def fetch_file_paths(src_dir: Path, max_samples: int, ext: str = 'jpg', seed=RANDOM_SEED) -> list:
    """
    Return up to `max_samples` paths of `*.{ext}` files directly inside `src_dir`.

    If the directory holds `max_samples` files or fewer, every matching path is returned;
    otherwise `max_samples` paths are drawn without replacement using `seed`. The result is
    always sorted by path, in both cases.

    Args:
        src_dir: directory to search (non-recursive).
        max_samples: maximum number of paths to return; must be non-negative.
        ext: file extension, with or without a leading dot.
        seed: seed for the sampling RNG; the same directory contents and seed give the same sample.

    Returns:
        Sorted list of `Path` objects.

    Raises:
        ValueError: if `max_samples` is negative.
    """
    assert isinstance(ext, str) and len(ext) > 0
    # Accept both 'jpg' and '.jpg'.
    if ext.startswith('.'):
        ext = ext[1:]
    if max_samples < 0:
        raise ValueError(f'max_samples must be non-negative, got {max_samples}')
    # Sort before sampling: glob order follows the filesystem's listing order, so sampling from it
    # directly would not be reproducible across machines even with a fixed seed.
    file_path_list = sorted(src_dir.glob('*.' + ext))
    if max_samples >= len(file_path_list):
        return file_path_list
    # A private RNG: the sample depends only on `seed`, and the global `random` state is untouched.
    rng = random.Random(seed)
    return sorted(rng.sample(file_path_list, max_samples))

In [7]:
# For every label in the filtered set, take up to MAX_FILTERED_IMAGES filtered images and top up
# with passed images so the label holds NB_CLASS_IMAGES in total (fewer if the sources run out).
assert MAX_FILTERED_IMAGES <= NB_CLASS_IMAGES  # loop-invariant config check
for f_label_dir in sorted(d for d in filtered_dir.iterdir() if d.is_dir()):
    label = f_label_dir.name
    p_label_dir = passed_dir / label
    assert p_label_dir.is_dir(), f'label {label} exists in {filtered_dir} but not in {passed_dir}'
    # Sample image paths so that the per-label total is at most NB_CLASS_IMAGES.
    f_paths = fetch_file_paths(f_label_dir, max_samples=MAX_FILTERED_IMAGES)
    max_p_images = NB_CLASS_IMAGES - len(f_paths)
    p_paths = fetch_file_paths(p_label_dir, max_samples=max_p_images)
    nb_images = len(f_paths) + len(p_paths)
    assert nb_images <= NB_CLASS_IMAGES

    # Both sources are copied flat into one directory, so a file name present in both would
    # silently overwrite the other copy; refuse to copy instead of losing images.
    duplicated_names = {p.name for p in f_paths} & {p.name for p in p_paths}
    assert not duplicated_names, (
        f'label {label}: {len(duplicated_names)} file names exist in both sources, '
        f'e.g. {sorted(duplicated_names)[:5]}'
    )

    # Copy the sampled images into dst_dir/<label>.
    dst_label_dir = dst_dir / label
    dst_label_dir.mkdir(exist_ok=True)
    for img_path in itertools.chain(f_paths, p_paths):
        shutil.copy(img_path, dst_label_dir)

    # Files left over from an earlier run with different settings would make the count wrong.
    nb_copied = sum(1 for _ in dst_label_dir.glob('*.jpg'))
    assert nb_copied == nb_images, (
        f'label {label}: expected {nb_images} images in {dst_label_dir}, found {nb_copied}; '
        'remove stale files and re-run'
    )

In [8]:
# 画像数のチェック
for label_dir in dst_dir.iterdir():
    print(f'label: {label_dir.name}\timages: {sum(1 for _ in label_dir.glob("*.jpg"))}')

label: 00	images: 8000
label: 01	images: 8000
label: 02	images: 8000
label: 03	images: 8000
label: 04	images: 8000
label: 05	images: 8000
label: 06	images: 7909
label: 07	images: 8000
label: 08	images: 8000
label: 09	images: 4930
label: 10	images: 8000
label: 11	images: 8000
label: 12	images: 447
label: 13	images: 3323
label: 14	images: 5326
label: 15	images: 8000
label: 16	images: 389
