In [6]:
import os
import numpy as np
import pandas as pd
import random
from pathlib import Path
from sklearn.model_selection import train_test_split

In [7]:
np.random.seed(0)

In [8]:
def partition(num_partitions, image_paths, label_map):
    # Split into `num_partitions` non-overlapping partitions
    partitions = [[] for _ in range(num_partitions)]
    labels = [[] for _ in range(num_partitions)]

    for i, image_path in enumerate(image_paths):
        partitions[i % num_partitions].append(image_path)
        labels[i % num_partitions].append(label_map[Path(image_path).parent.name])

    # Create a dictionary of DataFrames
    partition_dfs = {i: pd.DataFrame({"image_path" : partitions[i], "label" : labels[i]}) for i in range(num_partitions)}
    return partition_dfs

In [9]:
val_ratio = 0.2
min_n = 500

label_map = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
label_map = {name: index for index, name in enumerate(label_map)}

image_dirs = [Path('data/PACS/art_painting'), 
              Path('data/PACS/cartoon'),
              Path('data/PACS/photo'),
              Path('data/PACS/sketch')]

c_id = 0
for dist_num, image_dir in enumerate(image_dirs):

    # Get all png files in the directory
    image_paths = [str(p) for p in image_dir.glob("**/*.jpg")] + [str(p) for p in image_dir.glob("**/*.png")]  
    num_partitions = max(1,len(image_paths) // min_n)

    print(image_paths[0])
    # Shuffle the image paths
    random.shuffle(image_paths)

    test_len = int(val_ratio * len(image_paths))
    test_paths = image_paths[-test_len:]
    image_paths = image_paths[:-test_len]  

    partition_dfs = partition(num_partitions, image_paths, label_map)
    test_partition_dfs = partition(num_partitions, test_paths, label_map)

    for i in range(num_partitions):
        center_num = c_id
        train_paths, val_paths = train_test_split(partition_dfs[i], test_size=val_ratio)
        test_paths = test_partition_dfs[i]

        train_df = pd.DataFrame(train_paths, columns=["image_path","label"])
        val_df = pd.DataFrame(val_paths, columns=["image_path","label"])
        test_df = pd.DataFrame(test_paths, columns=["image_path","label"])

        train_df.to_csv(f"data/PACS/pacs_{min_n}_train_{center_num}.csv", index=False)
        val_df.to_csv(f"data/PACS/pacs_{min_n}_val_{center_num}.csv", index=False)
        test_df.to_csv(f"data/PACS/pacs_{min_n}_test_{center_num}.csv", index=False)

        print("train", len(train_df), "val", len(val_df), "test", len(test_df))
        c_id += 1

data/PACS/art_painting/dog/pic_078.jpg
train 328 val 82 test 103
train 328 val 82 test 102
train 328 val 82 test 102
train 327 val 82 test 102
data/PACS/cartoon/dog/pic_078.jpg
train 375 val 94 test 117
train 375 val 94 test 117
train 375 val 94 test 117
train 375 val 94 test 117
data/PACS/photo/dog/056_0089.jpg
train 356 val 90 test 112
train 356 val 89 test 111
train 356 val 89 test 111
data/PACS/sketch/dog/n02103406_13049-6.png
train 360 val 90 test 113
train 359 val 90 test 112
train 359 val 90 test 112
train 359 val 90 test 112
train 359 val 90 test 112
train 359 val 90 test 112
train 359 val 90 test 112


In [10]:
image_paths

['data/PACS/sketch/giraffe/n02439033_10130-4.png',
 'data/PACS/sketch/elephant/n02503517_12688-1.png',
 'data/PACS/sketch/elephant/n02503517_10300-5.png',
 'data/PACS/sketch/horse/n02374451_12225-3.png',
 'data/PACS/sketch/guitar/n03272010_7000-5.png',
 'data/PACS/sketch/dog/n02103406_6274-1.png',
 'data/PACS/sketch/guitar/n03272010_446-4.png',
 'data/PACS/sketch/dog/n02109525_3782-2.png',
 'data/PACS/sketch/giraffe/n02439033_12939-4.png',
 'data/PACS/sketch/giraffe/n02439033_11894-9.png',
 'data/PACS/sketch/horse/n02374451_6207-4.png',
 'data/PACS/sketch/dog/n02106662_7885-3.png',
 'data/PACS/sketch/guitar/n03272010_392-4.png',
 'data/PACS/sketch/horse/n02374451_597-2.png',
 'data/PACS/sketch/horse/n02374451_12418-9.png',
 'data/PACS/sketch/horse/n02374451_4553-4.png',
 'data/PACS/sketch/guitar/n03272010_12151-1.png',
 'data/PACS/sketch/elephant/n02503517_5267-2.png',
 'data/PACS/sketch/dog/5359.png',
 'data/PACS/sketch/guitar/n02676566_8268-4.png',
 'data/PACS/sketch/guitar/n03467517