In [None]:
import os
import shutil

def rearrange_images_by_class(images_dir, labels_dir, output_dir, num_classes=4):
    """Copy images into per-class sub-folders of *output_dir* based on label files.

    For every file ``<name>.<ext>`` in *images_dir*, the label is read from
    ``<name>.txt`` in *labels_dir*. The label file must contain a single
    integer in ``[0, num_classes)``. The image is then copied (not moved) to
    ``output_dir/class_<label>/``. Images whose label file is missing,
    unparsable, or out of range are reported with a printed warning and
    skipped. Sub-directories of *images_dir* are ignored.

    Args:
        images_dir: Folder containing the images.
        labels_dir: Folder containing one ``.txt`` label file per image.
        output_dir: Destination root. It and ``class_0`` ..
            ``class_<num_classes - 1>`` are created if they do not exist.
        num_classes: Number of classes; valid labels are
            ``0 .. num_classes - 1``. Defaults to 4 (labels 0-3).
    """
    # Create the destination root and every class folder up front, so that
    # classes with no images still get an (empty) directory.
    os.makedirs(output_dir, exist_ok=True)
    for class_id in range(num_classes):
        os.makedirs(os.path.join(output_dir, f'class_{class_id}'), exist_ok=True)

    # os.listdir order is arbitrary; sort for a reproducible processing and
    # warning order.
    for image_file in sorted(os.listdir(images_dir)):
        image_path = os.path.join(images_dir, image_file)
        # shutil.copy cannot copy a directory, so only regular files are handled.
        if not os.path.isfile(image_path):
            continue

        # The label file shares the image's base name with a .txt extension.
        # splitext swaps only the final extension (str.replace would rewrite
        # every '.png' occurrence inside the name).
        label_file = os.path.splitext(image_file)[0] + '.txt'
        label_path = os.path.join(labels_dir, label_file)

        if not os.path.exists(label_path):
            print(f"Warning: Label file for {image_file} not found. Skipping...")
            continue

        with open(label_path, 'r', encoding='utf-8') as f:
            raw_label = f.read().strip()
        try:
            label = int(raw_label)
        except ValueError:
            # One malformed label file should not abort the whole run.
            print(f"Warning: Could not parse label {raw_label!r} in {label_file}. Skipping...")
            continue

        if not 0 <= label < num_classes:
            print(f"Warning: Invalid label {label} in {label_file}. Skipping...")
            continue

        # shutil.copy leaves the original image in images_dir.
        shutil.copy(image_path, os.path.join(output_dir, f'class_{label}'))

# Dataset locations, all derived from one base folder under ~/git.
home_dir = os.path.expanduser("~/git")
dataset_for_training = "dataset9"
_dataset_root = f"{home_dir}/ml_project/datasets/symbols/{dataset_for_training}"

img_dir = f"{_dataset_root}/data"
gt_dir = f"{_dataset_root}/gt"
checkpoint_filepath = f"{_dataset_root}/chpt/"  # note: keeps the trailing slash
tf_dir = f"{_dataset_root}_tf"  # sibling of the dataset folder, not inside it

# Echo the three folders used below, one per line.
print(img_dir, gt_dir, tf_dir, sep="\n")

if __name__ == "__main__":
    # Point these at your own 'data', 'gt' and output folders if needed.
    images_dir, labels_dir, output_dir = img_dir, gt_dir, tf_dir

    rearrange_images_by_class(
        images_dir=images_dir,
        labels_dir=labels_dir,
        output_dir=output_dir,
    )
