In [1]:
import itertools
import os
import pickle
import random

import numpy as np
from PIL import Image

In [2]:
# Build every unordered pair of images (candidate fixed/moving pairs), shuffled.
def make_combinations(files, rng=None):
    """Return all unordered pairs of non-hidden file names, in random order.

    Names starting with "." (e.g. ``.DS_Store``) are ignored. Within each pair
    the two names keep their relative order from *files*: a pair is
    ``(files[i], files[j])`` with ``i < j``.

    Args:
        files: Sequence of file names, e.g. the result of ``os.listdir``.
        rng: Optional ``random.Random`` instance used for the shuffle. When
            ``None`` (the default) the module-level ``random`` generator is
            used, exactly as before. Pass a seeded instance to get a
            reproducible ordering independent of global RNG state.

    Returns:
        A list of ``(first, second)`` tuples. It has ``k * (k - 1) // 2``
        entries for ``k`` visible names, so it is empty when ``k < 2``.
    """
    visible = [name for name in files if not name.startswith(".")]
    # combinations() yields (visible[i], visible[j]) with i < j. That is the
    # same pairs, in the same order, as the original nested index loops.
    file_pairs = list(itertools.combinations(visible, 2))
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(file_pairs)
    return file_pairs

# Save an image pair (RGB + grayscale arrays) as a pickle file at pkl_path.
def make_pickle(file1, file2, pkl_path):
    """Load two images and pickle their RGB and grayscale pixel arrays.

    The pickle contains the tuple ``(img1_rgb, img2_rgb, img1_gray, img2_gray)``.
    All four arrays are ``float32`` built from Pillow images with no scaling,
    so values stay in the 0-255 range of 8-bit pixels. The RGB arrays come
    from mode ``"RGB"`` images and the grayscale arrays from the same images
    converted to mode ``"L"``.

    Args:
        file1: Path of the first image of the pair.
        file2: Path of the second image of the pair.
        pkl_path: Destination path. An existing file is overwritten.
    """
    img1_rgb, img1_gray = _load_rgb_and_gray(file1)
    img2_rgb, img2_gray = _load_rgb_and_gray(file2)

    data = (img1_rgb, img2_rgb, img1_gray, img2_gray)

    with open(pkl_path, "wb") as f:
        pickle.dump(data, f)


def _load_rgb_and_gray(path):
    """Return ``(rgb, gray)`` float32 arrays for the image stored at *path*."""
    # The context manager releases the file handle as soon as the pixels are
    # read. convert() loads the data, so rgb_img is usable after the with-block.
    with Image.open(path) as img:
        rgb_img = img.convert("RGB")
    rgb = np.array(rgb_img, dtype=np.float32)
    # Grayscale is derived from the RGB-converted image (Pillow "L" mode).
    gray = np.array(rgb_img.convert("L"), dtype=np.float32)
    return rgb, gray

In [16]:
# Dataset locations, relative to the working directory.
# UPDATE these to point at your local copies of the datasets.
mnist_path = "registrationDataset/mnist"
quick_draw_path = "registrationDataset/google_quickdraw"
data_path = "dataset"

# Seed Python's global RNG. Pair shuffling, and therefore the train/val
# split, is then repeatable for the same directory listing.
SEED = 42
random.seed(SEED)

# Output sub-directories for the train / validation / test pickles.
train_path = os.path.join(data_path, "train")
val_path = os.path.join(data_path, "val")
test_path = os.path.join(data_path, "test")

# os.makedirs also creates missing parents, so data_path is created here too.
for split_dir in (train_path, val_path, test_path):
    os.makedirs(split_dir, exist_ok=True)

# Running counter used to name output pickles ("<n>.pkl").
# NOTE: later cells increment this in place. Re-run this cell before re-running
# them, otherwise numbering continues from the previous run.
globalCounter = 0

In [17]:
# Make training and validation data from MNIST.
# Each <digit>/<type> folder is paired exhaustively, shuffled, and split into
# train/val by TRAIN_FRACTION. Pickles are numbered by the shared globalCounter.
# NOTE(review): the split is over *pairs*, not images, so one image can occur
# in both a training pair and a validation pair. Confirm this is intended.
TRAIN_FRACTION = 0.8

for digit_folder in sorted(os.listdir(mnist_path)):
    digit_folder_path = os.path.join(mnist_path, digit_folder)
    # Skip hidden entries and stray files that cannot be listed as folders.
    if digit_folder.startswith(".") or not os.path.isdir(digit_folder_path):
        continue
    for type_folder in sorted(os.listdir(digit_folder_path)):
        type_folder_path = os.path.join(digit_folder_path, type_folder)
        if type_folder.startswith(".") or not os.path.isdir(type_folder_path):
            continue
        # os.listdir order is filesystem-dependent. Sorting gives the shuffle
        # the same input everywhere, so a seeded run is reproducible.
        files = sorted(os.listdir(type_folder_path))
        file_pairs = make_combinations(files)
        sep = int(len(file_pairs) * TRAIN_FRACTION)
        splits = ((train_path, file_pairs[:sep]), (val_path, file_pairs[sep:]))

        for out_dir, pairs in splits:
            for f1, f2 in pairs:
                globalCounter += 1
                make_pickle(
                    os.path.join(type_folder_path, f1),
                    os.path.join(type_folder_path, f2),
                    os.path.join(out_dir, f"{globalCounter}.pkl"),
                )

In [9]:
# Make testing data from Google QuickDraw.
# Every drawing folder is paired exhaustively and all of its pairs go to the
# test split. Numbering continues from the shared globalCounter, so the
# resulting file names depend on which cells ran before this one.
for drawing_folder in sorted(os.listdir(quick_draw_path)):
    drawing_path = os.path.join(quick_draw_path, drawing_folder)
    # Skip hidden entries and stray files that cannot be listed as folders.
    if drawing_folder.startswith(".") or not os.path.isdir(drawing_path):
        continue
    # Sorted so pair order does not depend on filesystem listing order.
    test_files = sorted(os.listdir(drawing_path))
    for f1, f2 in make_combinations(test_files):
        globalCounter += 1
        make_pickle(
            os.path.join(drawing_path, f1),
            os.path.join(drawing_path, f2),
            os.path.join(test_path, f"{globalCounter}.pkl"),
        )