In [1]:
import os
import shutil
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm

## Organize the provided images into one folder per split and class

In [2]:
provided_image_dir = "provided_screw_images"
output_image_dir = "provided_screw_datasets"

# Pre-create <output_image_dir>/<split>/<class>/ for every split/class pair so the
# copy step can write straight into them.
split_names = ('train', 'validation', 'test')
class_names = ('hex', 'irex', 'sq', 'tri', 'minus', 'plus', 'others')
class_dirs = [
    os.path.join(output_image_dir, split, cls)
    for split in split_names
    for cls in class_names
]
for class_dir in class_dirs:
    os.makedirs(class_dir, exist_ok=True)

In [3]:
# For each split, read the one-hot label CSV and copy every image into the folder
# of each class whose column is 1 for it (an image with several 1s is therefore
# copied into several class folders).
for s in ['train', 'validation', 'test']:
    split_dir = os.path.join(provided_image_dir, s)
    image_fps = glob(os.path.join(split_dir, "*.jpg"))
    label_df = pd.read_csv(os.path.join(split_dir, "_classes.csv"))
    print(f"Num {s} images: {len(image_fps)} (images) | Num available labels: {len(label_df)} (rows)")
    # Every column except 'filename' is a class indicator. Iterate the real columns
    # instead of a hard-coded range(7), so a CSV with a different number of classes
    # neither raises IndexError nor silently drops a class.
    label_cols = [col for col in label_df.columns if col != 'filename']
    n_missing = 0
    for col in label_cols:
        # Normalise the header (case, '-', surrounding whitespace) to a folder name.
        class_name = col.lower().replace('-', '').strip()
        class_dir = os.path.join(output_image_dir, s, class_name)
        # Create the folder if needed, so a header that normalises to an unexpected
        # name cannot make shutil.copy fail with FileNotFoundError.
        os.makedirs(class_dir, exist_ok=True)
        for filename in label_df.loc[label_df[col] == 1, 'filename']:
            fp = os.path.join(split_dir, filename)
            if os.path.isfile(fp):
                shutil.copy(fp, os.path.join(class_dir, filename))
            else:
                # Previously skipped silently; count so the gap is visible.
                n_missing += 1
    if n_missing:
        print(f"  {s}: skipped {n_missing} label entries whose image file is missing from {split_dir}")

Num train images: 669 (images) | Num available labels: 669 (rows)
Num validation images: 63 (images) | Num available labels: 63 (rows)
Num test images: 33 (images) | Num available labels: 33 (rows)


## Build a mixed irex-vs-others dataset, augmenting the captured screw images

In [4]:
output_mixed_dir = "mixed_datasets"
os.makedirs(output_mixed_dir, exist_ok=True)

# Binary task layout: <output_mixed_dir>/<split>/{irex,others}/
for split in ('train', 'validation', 'test'):
    for label_dir in ("irex", "others"):
        os.makedirs(os.path.join(output_mixed_dir, split, label_dir), exist_ok=True)

In [5]:
# Gather every image from both sources (<root>/<split>/<class>/<file>.jpg).
# glob's order depends on the filesystem, so sort to make the downstream
# processing order (copies, augmentations, CSV rows) reproducible.
provided_fps = sorted(glob(os.path.join(output_image_dir, "*/*/*.jpg")))
captured_fps = sorted(glob(os.path.join("captured_screw_datasets", "*/*/*.jpg")))
len(provided_fps), len(captured_fps)

(762, 541)

### Copy the provided screw images into the two final class folders (irex / others)

In [6]:
# Flatten the 7-class provided dataset into the binary irex/others layout of
# output_mixed_dir, keeping each image in its original split.
n_overwritten = 0
written_fps = set()
for fp in tqdm(provided_fps):
    # fp is <root>/<split>/<class>/<file>.jpg (see the glob above). Match on the
    # split/class directory names only, never on the filename, so a file whose
    # name happens to contain "train", "validation" or "irex" is not misrouted.
    class_dir = os.path.dirname(fp)
    split_dir = os.path.basename(os.path.dirname(class_dir))
    if 'train' in split_dir:
        set_name = "train"
    elif 'validation' in split_dir:
        set_name = "validation"
    else:
        set_name = "test"
    class_name = "irex" if 'irex' in os.path.basename(class_dir) else "others"

    output_fp = os.path.join(output_mixed_dir, set_name, class_name, os.path.basename(fp))
    # Several source files can map to the same output path, e.g. an image that the
    # previous step copied into two non-irex class folders. The later copy
    # overwrites the earlier one, so count these to explain any shortfall vs.
    # len(provided_fps).
    if output_fp in written_fps:
        n_overwritten += 1
    written_fps.add(output_fp)
    shutil.copy(fp, output_fp)

if n_overwritten:
    print(f"{n_overwritten} image(s) overwrote an earlier copy with the same split/label/filename")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 762/762 [00:05<00:00, 138.99it/s]


### Copy the captured screw images into the two final class folders (with augmentation)

In [7]:
# Augmentation pipeline for the captured (non-test) screw images.
# torchvision's random transforms draw from torch's global RNG, so seed it to make
# the generated _v1/_v2 images reproducible across Restart & Run All.
AUGMENT_SEED = 42
torch.manual_seed(AUGMENT_SEED)
transform = T.Compose([
    # Rotate by an angle drawn uniformly from [-25, 25] degrees, keeping the canvas
    # size. Exposed corners are filled with this RGB colour (presumably the captured
    # images' background colour — TODO confirm).
    T.RandomRotation(degrees=25, expand=False, fill=[222, 162, 112]),
    # Scale brightness by a factor drawn uniformly from [0.65, 1.35] (i.e. ±35%).
    T.ColorJitter(brightness=0.35),
])

In [8]:
# Copy the captured images into the binary irex/others layout. For every image
# outside the test split, also save two independently augmented variants named
# <name>_v1<ext> and <name>_v2<ext> next to the copy.
# NOTE(review): validation images are augmented as well, so the validation split
# contains synthetic copies of its own images — confirm this is intended.
for fp in tqdm(captured_fps):
    # fp is <root>/<split>/<class>/<file>.jpg (see the glob above). Match on the
    # split/class directory names only, never on the filename, so a file whose
    # name happens to contain "train", "test" or "irex" is not misrouted.
    class_dir = os.path.dirname(fp)
    split_dir = os.path.basename(os.path.dirname(class_dir))
    if 'train' in split_dir:
        set_name = "train"
    elif 'validation' in split_dir:
        set_name = "validation"
    else:
        set_name = "test"
    class_name = "irex" if 'irex' in os.path.basename(class_dir) else "others"

    output_fp = os.path.join(output_mixed_dir, set_name, class_name, os.path.basename(fp))
    shutil.copy(fp, output_fp)
    # Keep the test split free of synthetic images. Reusing set_name keeps the
    # skip decision consistent with the routing above.
    if set_name == "test":
        continue

    # splitext only touches the real extension. str.replace(".jpg", ...) failed on
    # e.g. ".JPG", which made both variants overwrite the plain copy.
    stem, ext = os.path.splitext(output_fp)
    # `with` releases the file handle that Image.open keeps open for lazy loading.
    with Image.open(fp) as image:
        augmented_image_v1 = transform(image)
        augmented_image_v2 = transform(image)
        augmented_image_v1.save(f"{stem}_v1{ext}")
        augmented_image_v2.save(f"{stem}_v2{ext}")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 541/541 [00:02<00:00, 196.00it/s]


### Collect the image filenames, splits and labels into a DataFrame

In [9]:
# Build one record per image in the mixed dataset, with four fields:
#   - filename
#   - split
#   - whether it is an augmented variant
#   - the binary label (1 = irex, 0 = others)
# Sorting the glob makes the row order (and hence the saved CSV) independent of
# filesystem listing order.
mixed_dataset_dict = {"filename": [], "set": [], "is_augmented": [], "label": []}
for fp in tqdm(sorted(glob(os.path.join(output_mixed_dir, "*/*/*.jpg")))):
    # fp is <output_mixed_dir>/<split>/<class>/<file>.jpg. Take the components via
    # os.path rather than splitting on a literal backslash: that only worked on
    # Windows and raised IndexError where the separator is "/".
    class_dir = os.path.dirname(fp)
    set_name = os.path.basename(os.path.dirname(class_dir))
    filename = os.path.basename(fp)
    label = 1 if os.path.basename(class_dir) == "irex" else 0
    # Augmented copies are saved as <stem>_v1 / <stem>_v2, so check the suffix only.
    # A plain substring test would also flag originals whose names contain "v1"/"v2".
    is_augmented = 1 if os.path.splitext(filename)[0].endswith(("_v1", "_v2")) else 0
    mixed_dataset_dict["filename"].append(filename)
    mixed_dataset_dict["set"].append(set_name)
    mixed_dataset_dict["is_augmented"].append(is_augmented)
    mixed_dataset_dict["label"].append(label)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2147/2147 [00:00<00:00, 178961.64it/s]


In [10]:
# One row per image file (columns: filename, set, is_augmented, label).
mixed_dataset_df = pd.DataFrame(mixed_dataset_dict)
mixed_dataset_df.head()

Unnamed: 0,filename,set,is_augmented,label
0,22_png.rf.31843c8a3d74795b58f1e718b5eed556.jpg,test,0,1
1,pose07_1_2_3.jpg,test,0,1
2,pose07_1_3_2.jpg,test,0,1
3,pose07_1_3_3.jpg,test,0,1
4,pose07_2_2_2.jpg,test,0,1


In [18]:
# Overall shape, then the shape of the original (0) and augmented (1) subsets.
aug_flags = mixed_dataset_df["is_augmented"]
(mixed_dataset_df.shape,) + tuple(mixed_dataset_df.loc[aug_flags == flag].shape for flag in (0, 1))

((2147, 4), (1301, 4), (846, 4))

In [15]:
# Shape of the train, validation and test subsets, in that order.
split_col = mixed_dataset_df["set"]
tuple(mixed_dataset_df.loc[split_col == split].shape for split in ("train", "validation", "test"))

((1533, 4), (462, 4), (152, 4))

In [20]:
# Persist the dataset index without the pandas row index. The image path can be
# rebuilt as <output_mixed_dir>/<set>/<irex|others>/<filename> from these columns.
mixed_csv_fp = "mixed_dataset_v1.csv"
mixed_dataset_df.to_csv(mixed_csv_fp, index=False)