In [None]:
# Importing the packages 
import pandas as pd 
import os 
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Name of the dataset folder to analyse (it also names the output folder)
dir_to_analyse = "dataset"

In [None]:
# Defining the directories for the images and their label files.
# Both live under the dataset selected by dir_to_analyse, so changing that
# setting moves the inputs together with the output folder.
image_dir = os.path.join(os.getcwd(), '..', dir_to_analyse, 'images', 'temp')
label_dir = os.path.join(os.getcwd(), '..', dir_to_analyse, 'labels', 'temp')

In [None]:
# Class id -> class name. Keys are strings because extract_label
# stringifies the numeric id it reads from the label file before lookup.
class_dict = dict(zip(('0', '1'), ('closed', 'opened')))

In [None]:
# Creating the output directory for the extracted crops, with one
# sub-directory per class name. exist_ok=True makes this cell safe to re-run
# and avoids the check-then-create race of os.path.exists + os.makedirs.
output_label_dir = os.path.join(os.getcwd(), 'label_inference', dir_to_analyse)
os.makedirs(output_label_dir, exist_ok=True)

for class_name in class_dict.values():
    os.makedirs(os.path.join(output_label_dir, class_name), exist_ok=True)

def extract_label(image_path: str, label: list, label_index: int,
                  output_dir=None, classes=None):
    """Crop one bounding box out of an image and save it as a JPEG.

    The crop is written to ``<output_dir>/<class_name>/`` under the name
    ``<class_name>_<image stem>_<label_index>.jpg``.

    Parameters
    ----------
    image_path : str
        Path of the source image, read with ``plt.imread``.
    label : list
        One label row: ``label[0]`` is the class id and ``label[1:5]`` are the
        box centre x, centre y, width and height as fractions of the image size.
    label_index : int
        Position of the box in the image's label file; it keeps several crops
        from the same image from overwriting each other.
    output_dir : str, optional
        Root folder for the crops. Defaults to the global ``output_label_dir``.
        The per-class sub-folder is not created here.
    classes : dict, optional
        Map from class id string to class name. Defaults to the global ``class_dict``.

    Raises
    ------
    KeyError
        If the class id is not a key of ``classes``.
    ValueError
        If the box does not overlap the image, so the crop would be empty.
    """
    if output_dir is None:
        output_dir = output_label_dir
    if classes is None:
        classes = class_dict

    # Reading the image; axis 0 is the height and axis 1 the width
    image = plt.imread(image_path)
    image_height, image_width = image.shape[0], image.shape[1]

    # Converting the fractional box to pixel values
    x = int(label[1] * image_width)
    y = int(label[2] * image_height)
    width = int(label[3] * image_width)
    height = int(label[4] * image_height)

    # Calculating the corners, clamped to the image. Without the clamp, a box
    # crossing the left/top edge gets a negative start index. Numpy slicing
    # treats that as "count from the end", silently giving an empty or wrong crop.
    x1 = max(int(x - width / 2), 0)
    y1 = max(int(y - height / 2), 0)
    x2 = min(int(x + width / 2), image_width)
    y2 = min(int(y + height / 2), image_height)
    if x2 <= x1 or y2 <= y1:
        raise ValueError(
            f"Box {list(label)} does not overlap image {image_path} "
            f"({image_width}x{image_height})"
        )

    # Getting the class name
    class_name = classes[str(int(label[0]))]

    # Extracting the crop
    label_image = image[y1:y2, x1:x2]

    # os.path.splitext keeps dots inside the stem ('img.v2.jpg' -> 'img.v2')
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    output_name = f"{class_name}_{image_stem}_{label_index}.jpg"

    # cmap only applies to 2-D (grayscale) arrays. Without it, matplotlib
    # would colour-map them with its default colormap instead of saving gray.
    plt.imsave(os.path.join(output_dir, class_name, output_name), label_image, cmap='gray')

In [None]:
# Listing the images to process. Sorting makes the processing order (and
# therefore the output) reproducible, since os.listdir order is arbitrary.
# Sub-directories and hidden files (e.g. .DS_Store) are skipped.
images = sorted(
    entry for entry in os.listdir(image_dir)
    if not entry.startswith('.') and os.path.isfile(os.path.join(image_dir, entry))
)
len(images)

In [None]:
# Extracting every labelled box from every image listed in the previous cell.
# Errors are collected instead of printed, so they don't break the progress
# bar, and are displayed as a table once the loop is done.
failures = []

for image in tqdm(images):
    # os.path.splitext keeps dots inside the stem ('img.v2.jpg' -> 'img.v2')
    label_path = os.path.join(label_dir, os.path.splitext(image)[0] + '.txt')

    # Reading the label file: one box per row, whitespace-separated.
    # r'\s+' tolerates repeated spaces; sep=' ' reads them as empty fields
    # and shifts the columns.
    try:
        label = pd.read_csv(label_path, sep=r'\s+', header=None)
    except pd.errors.EmptyDataError:
        # An empty label file means the image has no boxes: nothing to extract
        continue
    except Exception as e:
        failures.append((image, None, f"{type(e).__name__}: {e}"))
        continue

    # Iterating over the rows (one crop per box)
    for index, row in label.iterrows():
        try:
            extract_label(os.path.join(image_dir, image), row.values, index)
        except Exception as e:
            failures.append((image, index, f"{type(e).__name__}: {e}"))

pd.DataFrame(failures, columns=['image', 'row', 'error'])