In [1]:
import os
import glob

In [2]:
# One entry per dataset sub-directory under ./dataset.
# Sorted because os.listdir order is arbitrary: the per-dataset random split
# below consumes the RNG in this order, so a fixed order is needed for
# reproducible splits. Plain files (e.g. stray metadata) are skipped.
datasets = sorted(
    entry
    for entry in os.listdir("dataset")
    if os.path.isdir(os.path.join("dataset", entry))
)

In [3]:
import random

def split_list(data, train_ratio=0.7, val_ratio=0.1, test_ratio=0.2, seed=None):
    """Randomly partition ``data`` into train / validation / test lists.

    Args:
        data: Iterable of items to split. It is copied before shuffling, so the
            caller's list is left in its original order.
        train_ratio: Fraction of items assigned to the training split.
        val_ratio: Fraction of items assigned to the validation split.
        test_ratio: Fraction intended for the test split. It is only used to
            check that the ratios sum to 1; the test split receives every item
            left over after the train and validation sizes are floored.
        seed: Optional seed for a private ``random.Random`` generator. When
            None (the default) the module-level ``random`` generator is used,
            so a prior ``random.seed(...)`` still controls the result.

    Returns:
        Tuple ``(train, val, test)`` of new lists.

    Raises:
        ValueError: If the ratios do not sum to 1 (within 1e-6) or any ratio
            is negative.
    """
    # Written as `not ... < ...` so a NaN ratio is also rejected.
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError("Ratios must sum to 1.")
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("Ratios must be non-negative.")

    # Shuffle a copy: shuffling `data` itself would reorder the caller's list.
    shuffled = list(data)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)

    total = len(shuffled)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]

# Example usage on a toy list of 100 integers -- swap in your actual data.
my_list = list(range(100))
train, val, test = split_list(my_list)

# Report how many items landed in each split.
for split_name, split_items in (("Train", train), ("Val", val), ("Test", test)):
    print(f"{split_name} size: {len(split_items)}")

Train size: 70
Val size: 10
Test size: 20


In [4]:
# Accumulators for image paths across all datasets, one list per split.
train_sets, val_sets, test_sets = [], [], []

In [5]:
# Seed the global RNG right before splitting so the split is reproducible, and
# rebuild the accumulators so re-running this cell does not append duplicates.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
train_sets, val_sets, test_sets = [], [], []

for dataset in datasets:
    dataset_path = os.path.join("dataset", dataset, "cropped_images")
    # glob returns paths that already start with dataset_path, so they must
    # NOT be joined onto dataset_path again (that duplicated the prefix).
    # Sorted because glob order is filesystem-dependent.
    current_images = sorted(glob.glob(os.path.join(dataset_path, "*.jpg")))
    current_train, current_val, current_test = split_list(current_images)
    # Normalise Windows separators to forward slashes for the CSV output.
    train_sets.extend(path.replace("\\", "/") for path in current_train)
    val_sets.extend(path.replace("\\", "/") for path in current_val)
    test_sets.extend(path.replace("\\", "/") for path in current_test)

In [6]:
# Show how many image paths ended up in each split.
for split_name, split_paths in (
    ("train_sets", train_sets),
    ("val_sets", val_sets),
    ("test_sets", test_sets),
):
    print(f"len({split_name}):", len(split_paths))

len(train_sets): 64226
len(val_sets): 9175
len(test_sets): 18352


In [7]:
def get_img_label(file):
    """Parse the digit label encoded at the end of an image filename.

    The label text is the part of the basename after the last underscore, up to
    the first dot, e.g. ``".../img_123456.jpg"`` -> ``[1, 2, 3, 4, 5, 6]``.

    Args:
        file: Path to the image; only its basename is inspected.

    Returns:
        List of ints, one per character of the label text. If that text is
        empty or contains a non-digit character, an error line is printed and
        the placeholder ``[0, 0, 0, 0, 0, 0]`` is returned instead.
    """
    label_text = os.path.basename(file).split("_")[-1].split(".")[0]
    # Only ValueError (int() on a non-digit) is expected; a bare `except`
    # would also swallow KeyboardInterrupt and unrelated bugs.
    try:
        label = [int(c) for c in label_text]
    except ValueError:
        label = []
    if not label:
        # NOTE(review): this placeholder is indistinguishable from a genuine
        # all-zero label in the output CSV; consider dropping such files.
        print("Error in file:", file)
        return [0, 0, 0, 0, 0, 0]
    return label

In [8]:
# For each image in the dataset, save the image path and its label to a CSV file
import pandas as pd
import csv
import numpy as np
def save_to_csv(image_paths, output_csv):
    """Write an image manifest CSV with columns ``image_path`` and ``label``.

    Args:
        image_paths: List of image file paths; one output row per path.
        output_csv: Destination path handed to ``DataFrame.to_csv``.

    The ``label`` column holds the list returned by ``get_img_label`` for each
    path. pandas writes it as the list's text form, so readers must parse it
    back into a list. No index column is written.
    """
    manifest = pd.DataFrame(image_paths, columns=['image_path'])
    manifest = manifest.assign(label=manifest['image_path'].apply(get_img_label))
    manifest.to_csv(output_csv, index=False)

# Output manifest locations, one CSV per split.
train_csv_path, val_csv_path, test_csv_path = "train.csv", "val.csv", "test.csv"

# Write each split's manifest, in train -> val -> test order.
for split_paths, csv_path in (
    (train_sets, train_csv_path),
    (val_sets, val_csv_path),
    (test_sets, test_csv_path),
):
    save_to_csv(split_paths, csv_path)