In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import os
import shutil
from itertools import combinations

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from empatches import EMPatches
from matplotlib.patches import Rectangle
from sklearn.model_selection import StratifiedGroupKFold
from tqdm import tqdm

We set `PATCH_SIZE` to 512 px so that each patch contains enough context to discriminate between flawed and flawless structures. The classes are imbalanced, with positive (flawless) patches being abundant, so we patch the positive images without overlap (`PATCH_OVLP_POS = 0`). For the negative images, which contain flaws, we use an overlap factor of 0.5 (`PATCH_OVLP_NEG`). This increases the number of patches cut from negative images and thus the number of flawed patches, acting as a simple form of data augmentation.

In [3]:
# define patching parameters
# edge length (in px) of the square patches cut from each image
PATCH_SIZE = 512
# overlap factors handed to EMPatches.extract_patches:
PATCH_OVLP_POS = 0    # positive (flawless) images are abundant -> no overlap
PATCH_OVLP_NEG = 0.5  # negative images -> overlapping patches yield more flawed samples

In [4]:
# define input data
# glob order is filesystem-dependent; sort it so that patch numbering and the
# row order of everything derived from these lists are reproducible
data_dir = "data/FlawDetectionTrainingImages"
positive_imgs = sorted(glob.glob(os.path.join(data_dir, "positive", "*jpg")))
negative_imgs = sorted(glob.glob(os.path.join(data_dir, "negative", "*jpg")))

# define output directories (a "patches" subfolder of the input data directory)
output_dir = os.path.join(data_dir, "patches")
positive_output_dir = os.path.join(output_dir, "positive")
negative_output_dir = os.path.join(output_dir, "negative")

In [5]:
# a) split positive images into patches using no overlap

# remove existing patches so that files from a previous run do not survive
shutil.rmtree(positive_output_dir, ignore_errors=True)
os.makedirs(positive_output_dir, exist_ok=True)

# patch new ones
emp = EMPatches()  # created once and reused for every image
n_patch = 0  # running patch counter across ALL positive images (not reset per image)
for img_path in tqdm(positive_imgs):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError(f"Could not read image: {img_path}")
    img_stem = os.path.basename(img_path).split(".jpg")[0]
    # patch indices are only needed for the negative overviews, so they are discarded here
    img_patches, _ = emp.extract_patches(img, patchsize=PATCH_SIZE, overlap=PATCH_OVLP_POS)
    for patch in img_patches:
        out_path = os.path.join(positive_output_dir, f"{img_stem}_{n_patch}.jpg")
        # cv2.imwrite returns False on failure instead of raising
        if not cv2.imwrite(out_path, patch):
            raise OSError(f"Could not write patch: {out_path}")
        n_patch += 1

100%|██████████| 101/101 [00:07<00:00, 12.89it/s]


In [6]:
# b) split negative images into patches using larger overlap

overview_dir = os.path.join(negative_output_dir, "overviews")

# remove existing patches (this also removes the overviews subfolder)
shutil.rmtree(negative_output_dir, ignore_errors=True)
os.makedirs(overview_dir, exist_ok=True)  # also (re)creates negative_output_dir


def _save_patch_overview(img_rgb, indices, out_path):
    """Save a figure of `img_rgb` with every patch outlined and labelled by its index.

    Args:
        img_rgb: image array in RGB channel order, as expected by ``imshow``.
        indices: patch indices as returned by ``EMPatches.extract_patches``; entry
            [0] is used as the top (row) offset and entry [2] as the left (column)
            offset of each patch. Every patch is drawn PATCH_SIZE x PATCH_SIZE.
        out_path: file path the overview figure is written to.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(img_rgb)
    for i, idxs in enumerate(indices):
        top, left = idxs[0], idxs[2]
        ax.add_patch(Rectangle(
            (left, top),
            PATCH_SIZE,
            PATCH_SIZE,
            linewidth=1,
            edgecolor='lightblue',
            facecolor='none',
        ))
        # label = patch index within this image, identical to the suffix of the patch file name
        ax.text(
            left + PATCH_SIZE // 2,
            top + PATCH_SIZE // 2,
            str(i),
            color='blue',
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=6,
        )
    ax.set_axis_off()
    fig.savefig(out_path, bbox_inches='tight', dpi=600)
    plt.close(fig)  # many figures are created in the loop below; free each one


# patch new ones
emp = EMPatches()  # created once and reused for every image
for img_path in tqdm(negative_imgs):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError(f"Could not read image: {img_path}")
    img_name = os.path.basename(img_path)
    img_stem = img_name.split(".jpg")[0]
    # the marked copy is expected in a sibling "negative_marked" folder under the same
    # file name; only the folder is swapped, so "negative" inside the file name is kept
    marked_path = os.path.join(os.path.dirname(os.path.dirname(img_path)), "negative_marked", img_name)
    img_marked = cv2.imread(marked_path)
    if img_marked is None:
        raise OSError(f"Could not read marked image: {marked_path}")
    img_patches, indices = emp.extract_patches(img, patchsize=PATCH_SIZE, overlap=PATCH_OVLP_NEG)
    # patch index restarts at 0 for every image so that file names match the overview labels
    for i, patch in enumerate(img_patches):
        out_path = os.path.join(negative_output_dir, f"{img_stem}_{i}.jpg")
        # cv2.imwrite returns False on failure instead of raising
        if not cv2.imwrite(out_path, patch):
            raise OSError(f"Could not write patch: {out_path}")
    # overview of the marked image; flip channel axis BGR (OpenCV) -> RGB (matplotlib)
    _save_patch_overview(np.flip(img_marked, 2), indices, os.path.join(overview_dir, f"{img_stem}.jpg"))

100%|██████████| 36/36 [00:52<00:00,  1.47s/it]


After this patching step, we manually inspect the resulting negative patches and split them into flawed and flawless ones, because a patch cut from a negative image is not necessarily flawed. The generated overviews make it easy to look up the indices of the flawed patches. We then create a new directory, copy all generated patches into it, and move the flawless patches from the `negative` to the `positive` subfolder. The resulting directory, `.../patches_v1`, is the input for creating the cross-validation splits below.

In [7]:
# create cross-validation splits, stratified by label and grouped by source image:
# all patches of a given image end up in exactly one of train/val/test

# define input data
patch_dir = "data/FlawDetectionTrainingImages/patches_v1"
# sorted so that the row order of the written CSVs (and anything sampled from them) is reproducible
pos_paths = sorted(glob.glob(os.path.join(patch_dir, "positive", "*jpg")))
neg_paths = sorted(glob.glob(os.path.join(patch_dir, "negative", "*jpg")))
# named `patch_paths` (not `patches`) so the `matplotlib.patches` import is not shadowed
patch_paths = [*pos_paths, *neg_paths]
assert patch_paths, f"no patches found in {patch_dir}"

# create dataframe; the label comes from the source folder, not from a substring
# search on the full path: 0 = positive (flawless), 1 = negative (flawed)
df = pd.DataFrame({
    "path": patch_paths,
    "label": [0] * len(pos_paths) + [1] * len(neg_paths),
})
# group = source image stem, i.e. the patch file name without the trailing "_<patch index>"
df["group"] = df["path"].map(lambda p: os.path.basename(p).rsplit("_", 1)[0])


def _n_flaws(split_df):
    """Return the number of flawed (label 1) patches in `split_df` (0 if there are none)."""
    return int((split_df["label"] == 1).sum())


# StratifiedGroupKFold - stratified split according to label, grouped by image
n_splits = 5
seed = 12
sgkf_l1 = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)

for fold, (trainval_idx, test_idx) in enumerate(sgkf_l1.split(df, df["label"], df["group"])):

    # outer level: hold out this fold as test set
    _train_df = df.iloc[trainval_idx].reset_index(drop=True)
    test_df = df.iloc[test_idx]

    # inner level: split the remaining data into train and val with the same scheme.
    # Only one inner split is needed; the LAST of the n_splits inner folds is taken,
    # which keeps the resulting splits identical to those produced so far.
    sgkf_l2 = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    *_, (train_idx, val_idx) = sgkf_l2.split(_train_df, _train_df["label"], _train_df["group"])
    train_df = _train_df.iloc[train_idx]
    val_df = _train_df.iloc[val_idx]

    # sanity check - no image may appear in more than one of train/val/test
    train_groups, val_groups, test_groups = (set(d["group"]) for d in (train_df, val_df, test_df))
    assert not train_groups & val_groups
    assert not train_groups & test_groups
    assert not val_groups & test_groups

    # print stats
    print(f"Fold {fold}:")
    print(f"  Train size: {len(train_df)}, Flaws: {_n_flaws(train_df)}")
    print(f"  Val size: {len(val_df)}, Flaws: {_n_flaws(val_df)}")
    print(f"  Test size: {len(test_df)}, Flaws: {_n_flaws(test_df)}")

    # merge train, val and test set into a single df with a train/val/test "split" column
    split_df = pd.concat(
        [
            train_df.assign(split="train"),
            val_df.assign(split="val"),
            test_df.assign(split="test"),
        ],
        ignore_index=True,
    )
    split_df.to_csv(os.path.join(patch_dir, f"split_{fold}.csv"), index=False)

Fold 0:
  Train size: 4032, Flaws: 116
  Val size: 1038, Flaws: 28
  Test size: 1128, Flaws: 41
Fold 1:
  Train size: 4032, Flaws: 118
  Val size: 920, Flaws: 29
  Test size: 1246, Flaws: 38
Fold 2:
  Train size: 4032, Flaws: 130
  Val size: 1008, Flaws: 24
  Test size: 1158, Flaws: 31
Fold 3:
  Train size: 3798, Flaws: 103
  Val size: 1096, Flaws: 46
  Test size: 1304, Flaws: 36
Fold 4:
  Train size: 3798, Flaws: 113
  Val size: 1038, Flaws: 33
  Test size: 1362, Flaws: 39


In [8]:
# create almost balanced cross-validation splits
# to see if network can be trained properly under these simplified conditions

FLAWLESS_PER_FLAW = 2  # number of flawless (label 0) patches kept per flawed (label 1) patch

for fold in range(n_splits):

    # subsample & balance patches (ratio FLAWLESS_PER_FLAW:1)
    # NOTE(review): sampling is done over the whole fold, not per split, so the class
    # ratio varies between train/val/test (see the printed guess accuracies);
    # sample per split if identical ratios are required.
    df = pd.read_csv(os.path.join(patch_dir, f"split_{fold}.csv"))
    flawless_df = df[df["label"] == 0]
    flawed_df = df[df["label"] == 1]
    # cap at the number available so that .sample() cannot fail on small folds
    n_flawless = min(FLAWLESS_PER_FLAW * len(flawed_df), len(flawless_df))
    df = pd.concat([
        flawless_df.sample(n_flawless, random_state=42),
        flawed_df,
    ]).reset_index(drop=True)

    # sanity check: ensure image groups are non-overlapping between splits
    imgs_per_split = df.groupby("split")["group"].unique()
    for split_a, split_b in combinations(["train", "val", "test"], 2):
        shared = set(imgs_per_split.get(split_a, [])) & set(imgs_per_split.get(split_b, []))
        assert not shared, f"fold {fold}: images shared between {split_a} and {split_b}: {shared}"

    # write to disk patches
    df.to_csv(
        os.path.join(patch_dir, f"split_balanced_{fold}.csv"),
        index=False
    )

    # print stats
    print(f"Fold {fold}:")
    for split in ["train", "val", "test"]:
        split_labels = df.loc[df["split"] == split, "label"]
        size = len(split_labels)
        n_flaws = int((split_labels == 1).sum())
        # accuracy of always predicting the majority class of this split (NaN if the split is empty)
        guess_acc = max(n_flaws / size, 1 - n_flaws / size) if size else float("nan")
        print(f"  {split}: total size: {size}, flaws: {n_flaws}, guess accuracy: {guess_acc:.2f}")

Fold 0:
  train: total size: 368, flaws: 116, guess accuracy: 0.68
  val: total size: 88, flaws: 28, guess accuracy: 0.68
  test: total size: 99, flaws: 41, guess accuracy: 0.59
Fold 1:
  train: total size: 370, flaws: 118, guess accuracy: 0.68
  val: total size: 76, flaws: 29, guess accuracy: 0.62
  test: total size: 109, flaws: 38, guess accuracy: 0.65
Fold 2:
  train: total size: 382, flaws: 130, guess accuracy: 0.66
  val: total size: 80, flaws: 24, guess accuracy: 0.70
  test: total size: 93, flaws: 31, guess accuracy: 0.67
Fold 3:
  train: total size: 342, flaws: 103, guess accuracy: 0.70
  val: total size: 105, flaws: 46, guess accuracy: 0.56
  test: total size: 108, flaws: 36, guess accuracy: 0.67
Fold 4:
  train: total size: 349, flaws: 113, guess accuracy: 0.68
  val: total size: 92, flaws: 33, guess accuracy: 0.64
  test: total size: 114, flaws: 39, guess accuracy: 0.66
