In [1]:
import glob
import os
import cv2
import shutil
from empatches import EMPatches
from tqdm import tqdm

In [2]:
# define input data
data_dir = "data/FlawDetectionTrainingImages"
# "*.jpg" (with the dot) so files whose names merely end in "jpg" are not picked up.
# sorted() because glob returns files in arbitrary, filesystem-dependent order; the
# patch files written later are numbered/named following this order, so sorting
# keeps the generated dataset reproducible across machines.
positive_imgs = sorted(glob.glob(os.path.join(data_dir, "positive", "*.jpg")))
negative_imgs = sorted(glob.glob(os.path.join(data_dir, "negative", "*.jpg")))

# define output directories (patches are stored under data_dir/patches/<class>)
output_dir = os.path.join(data_dir, "patches")
positive_output_dir = os.path.join(output_dir, "positive")
negative_output_dir = os.path.join(output_dir, "negative")

In [3]:
# a) split positive images into non-overlapping square patches
pos_patch_size = 256  # patch edge length in pixels

# remove existing patches
shutil.rmtree(positive_output_dir, ignore_errors=True)
os.makedirs(positive_output_dir, exist_ok=True)

# patch new ones; files are numbered consecutively across all images (0.jpg, 1.jpg, ...)
emp = EMPatches()  # one extractor instance is reused for all images
n_img = 0
for img_path in tqdm(positive_imgs):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising on missing/unreadable files
        raise FileNotFoundError(f"could not read image: {img_path}")
    img_patches, _ = emp.extract_patches(img, patchsize=pos_patch_size, overlap=0)
    for patch in img_patches:
        out_path = os.path.join(positive_output_dir, f"{n_img}.jpg")
        # cv2.imwrite reports failure via its return value, not an exception
        if not cv2.imwrite(out_path, patch):
            raise OSError(f"could not write patch: {out_path}")
        n_img += 1

100%|██████████| 101/101 [00:10<00:00,  9.53it/s]


In [4]:
# b) split negative images into patches using larger overlap
# indicate original position in image
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches

# Patch geometry and overview resolution for the negative images.
neg_patch_size = 256  # patch edge length in pixels
neg_overlap = 0.5     # fractional overlap between neighbouring patches
overview_dpi = 600    # 10 x 10 in figure at 600 dpi -> ~6000 x 6000 px per overview
overview_dir = os.path.join(negative_output_dir, "overviews")
# NOTE(review): overviews are written *inside* negative_output_dir. A loader that walks
# class folders recursively would also pick them up as negative samples — confirm the
# downstream dataset code only reads the top level of this directory.

# remove existing patches (the nested overviews directory is removed with them)
shutil.rmtree(negative_output_dir, ignore_errors=True)
os.makedirs(overview_dir, exist_ok=True)  # also recreates negative_output_dir

# patch new ones
emp = EMPatches()  # one extractor instance is reused for all images
for img_path in tqdm(negative_imgs):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising on missing/unreadable files
        raise FileNotFoundError(f"could not read image: {img_path}")
    img_stem = os.path.splitext(os.path.basename(img_path))[0]
    img_patches, indices = emp.extract_patches(img, patchsize=neg_patch_size, overlap=neg_overlap)

    # Save each patch as "<image stem>_<patch index>.jpg"; the index restarts at 0 for
    # every image and matches the label drawn on the overview below.
    for i, patch in enumerate(img_patches):
        out_path = os.path.join(negative_output_dir, f"{img_stem}_{i}.jpg")
        # cv2.imwrite reports failure via its return value, not an exception
        if not cv2.imwrite(out_path, patch):
            raise OSError(f"could not write patch: {out_path}")

    # Overview: the full image with every patch outlined and labelled with its index.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(np.flip(img, 2))  # OpenCV loads BGR, matplotlib expects RGB
        for i, (patch, idxs) in enumerate(zip(img_patches, indices)):
            # idxs[0] / idxs[2] are taken as the patch's top row / left column
            # (assumed EMPatches index order — verify); the rectangle's extent is
            # read from the patch array itself instead of a hard-coded size.
            top, left = idxs[0], idxs[2]
            height, width = patch.shape[:2]
            ax.add_patch(
                patches.Rectangle(
                    (left, top),
                    width,
                    height,
                    linewidth=1,
                    edgecolor='lightblue',
                    facecolor='none',
                )
            )
            # label the patch with its index at the patch centre
            ax.text(
                left + width / 2,
                top + height / 2,
                str(i),
                color='blue',
                horizontalalignment='center',
                verticalalignment='center',
                fontsize=6,
            )
        ax.set_axis_off()
        fig.savefig(
            os.path.join(overview_dir, f"{img_stem}.jpg"),
            bbox_inches='tight',
            dpi=overview_dpi,
        )
    finally:
        # always release the figure so a failure does not leak memory across iterations
        plt.close(fig)

100%|██████████| 36/36 [01:15<00:00,  2.11s/it]
