In [1]:
import os
import h5py, tifffile
import numpy as np
import random, cv2

os.chdir("../scripts/data_preparation")  
label_shape = (100, 3)

Setup   

In [2]:
base_folder = "/home/brian/data4/brian/PBnJ/jelly_h5s/RFamide/mixed_datasets"

raw = True # If true, we will make the h5s from raw tif files provided
frames_per_dataset = 100 # How many frames from each dataset we should sample

add_extra = False # Should we add more frames from an nrrd to augment the labeled frames
num_frames_to_add = 500 # If we add extra frames, how many should we add
padding = "single" # "single" or "double" - Do you want to pad on both the top and bottom or just the bottom
side_len = 1024 # What the cropped image size should be (enforcing square) should prob be 1024, 1080, or 1200

augmentation = None # Options are "log" "sqrt" None

assert padding in ["single", "double"], f"Padding must be 'single' or 'double', not {padding}"

## Make the dataset from scratch if needed

Assumes you want all moving to one fixed frame

TODO: Reconsider this

In [5]:
if raw:
    assert frames_per_dataset
    n_frames = 2000 # How many frames are in each tif file
    channel = 1 # Which channel to use (1 in RFa, 0 in normal - check)
    input_files = {"/storage/fs/store1/brian/swimming_videos_RFa/Folder_20250219110008_RFa_swim" : [
                                    "20250219_Experiment_01_0-1999.tif",
                                    "20250219_Experiment_01_2000-3999.tif",
                                    "20250219_Experiment_01_4000-5999.tif",
                                    "20250219_Experiment_01_8000-9999.tif",
                                    "20250219_Experiment_01_6000-7999.tif"],
                    "/storage/fs/store1/brian/swimming_videos_RFa/Folder_20250219120831_RFa": [
                                    "20250219_Experiment_01_0-1999.tif",
                                    "20250219_Experiment_01_2000-3999.tif",
                                    "20250219_Experiment_01_4000-5999.tif",
                                    "20250219_Experiment_01_6000-7999.tif",
                                    "20250219_Experiment_01_8000-9999.tif"],
                    "/storage/fs/store1/brian/swimming_videos_RFa/Folder_20250214153740_RFa": [
                                    "20250214_Experiment_01_0-1999.tif",
                                    "20250214_Experiment_01_2000-3999.tif",
                                    "20250214_Experiment_01_4000-5999.tif"],
                    }

    unlabs = np.full((30, 2), -1)

    with h5py.File(os.path.join(base_folder, "moving_images.h5"), 'w-') as h5m, h5py.File(os.path.join(base_folder, "fixed_images.h5"), 'w-') as h5f, h5py.File(
            os.path.join(base_folder, "moving_labels.h5"), 'w-') as h5ml, h5py.File(os.path.join(base_folder, "fixed_labels.h5"), 'w-') as h5fl:
        dataset_ind = 0
        for dataset_name, file_list in input_files.items():
            file_list = [os.path.join(dataset_name, f) for f in file_list]

            dataset_name = os.path.basename(dataset_name)

            print(f"Processing {dataset_name}...")
            file_list = list(file_list)
            random.shuffle(file_list)

            total_frames = len(file_list) * n_frames
            if frames_per_dataset > total_frames:
                raise ValueError(f"Requested {frames_per_dataset} frames, but only {total_frames} available in {dataset_name}.")

            # Sample global frame indices
            selected_global_indices = sorted(random.sample(range(total_frames), frames_per_dataset))
            with open(os.path.join(base_folder, "frame_log.txt"), 'a') as f: # Save a log of the indices
                f.write(f"# {dataset_name}\n")
                for idx in selected_global_indices:
                    f.write(f"{idx}\n")

            # Iterate and extract frames
            current_global_index = 0
            saved_count = 0
            for tif_path in file_list:
                start, end = map(int, tif_path.split("_")[-1].replace(".tif", "").split("-"))

                with tifffile.TiffFile(tif_path) as tif:
                    arr = tif.asarray()
                    if arr.ndim != 4:
                        raise ValueError(f"Expected shape (T, C, H, W) but got {arr.shape} in {tif_path}")
                    
                    if start == 0:
                        # Save the fixed image
                        h5f.create_dataset(f"{dataset_ind}_0to{dataset_ind}_0", data=arr[0, channel])
                        h5fl.create_dataset(f"{dataset_ind}_0to{dataset_ind}_0", data=unlabs)

                    for idx in selected_global_indices:
                        if idx >= start and idx <= end:
                            local_idx = idx - start
                            frame = arr[local_idx, channel]  # shape (H, W)
                            ds_name = f"{dataset_ind}_{idx}to{dataset_ind}_0" # We're going to add the first number to diferentiate the datasets
                            h5m.create_dataset(ds_name, data=frame)
                            h5ml.create_dataset(ds_name, data=unlabs)

                            saved_count += 1

            dataset_ind += 1
            print(f"Saved {saved_count} frames for {dataset_name}")

            

Processing Folder_20250219110008_RFa_swim...
Saved 100 frames for Folder_20250219110008_RFa_swim
Processing Folder_20250219120831_RFa...
Saved 100 frames for Folder_20250219120831_RFa
Processing Folder_20250214153740_RFa...
Saved 100 frames for Folder_20250214153740_RFa


# Clone fixed image to train
Start with images (formatted as just 1 fixed image from Weissbourd)

In [8]:
old_fixed_h5 = os.path.join(base_folder, "fixed_images.h5")
new_fixed_h5 = os.path.join(base_folder, "fixed_fixed_images.h5")
moving_h5 = os.path.join(base_folder, "moving_images.h5")
with h5py.File(old_fixed_h5, 'r') as f, h5py.File(moving_h5, 'r') as g, h5py.File(new_fixed_h5, 'w-') as fo:
    if len(f.keys()) == 1:
        img = f[list(f.keys())[0]][:]
        for prob in g.keys():
            fo.create_dataset(prob, data = img)
    else:
        assert len(set([f.split("_")[-1] for f in f.keys()])) == 1, "Expected a file with either one image or one image per dataset"
        base_imgs = {}
        for im in f.keys():
            base_imgs[im.split("_")[0]] = f[im][:]
        for prob in g.keys():
            fo.create_dataset(prob, data = base_imgs[prob.split("_")[0]])

Then do the same for labels

In [9]:
old_fixed_h5 = os.path.join(base_folder, "fixed_labels.h5")
new_fixed_h5 = os.path.join(base_folder, "fixed_fixed_labels.h5")
moving_h5 = os.path.join(base_folder, "moving_labels.h5")

with h5py.File(old_fixed_h5, 'r') as f, h5py.File(moving_h5, 'r') as g, h5py.File(new_fixed_h5, 'w-') as fo:
    if len(f.keys()) == 1:
        labs = f[list(f.keys())[0]][:]
        for prob in g.keys():
            fo.create_dataset(prob, data = labs)
    else:
        assert len(set([f.split("_")[-1] for f in f.keys()])) == 1, "Expected a file with either one image or one image per dataset"
        base_labs = {}
        for im in f.keys():
            base_labs[im.split("_")[0]] = f[im][:]
        for prob in g.keys():
            fo.create_dataset(prob, data = base_labs[prob.split("_")[0]])

Now add a third column so the labels have three dimensions

In [10]:
old_fixed_h5 = os.path.join(base_folder, "fixed_fixed_labels.h5")
new_fixed_h5 = os.path.join(base_folder, "fixed_fixed_fixed_labels.h5")
with h5py.File(old_fixed_h5, 'r') as f:
    assert f[list(f.keys())[0]][:].shape[1] == 2, "This is meant to expand the labels from 2D to 3D"
    with h5py.File(new_fixed_h5, 'w-') as g:
        for prob in f.keys():
            labs = f[prob][:]
            labs = np.concatenate((labs, np.zeros((labs.shape[0], 1))), 1)
            labs = np.pad(labs, [(0,label_shape[0] - labs.shape[0]), (0,0)], "constant", constant_values=-1)
            g.create_dataset(prob, data = labs)

In [11]:
old_moving_h5 = os.path.join(base_folder, "moving_labels.h5")
new_moving_h5 = os.path.join(base_folder, "fixed_moving_labels.h5")
with h5py.File(old_moving_h5, 'r') as f:
    assert f[list(f.keys())[0]][:].shape[1] == 2, "This is meant to expand the labels from 2D to 3D"
    with h5py.File(new_moving_h5, 'w-') as g:
        for prob in f.keys():
            labs = f[prob][:]
            labs = np.concatenate((labs, np.zeros((labs.shape[0], 1))), 1)
            labs = np.pad(labs, [(0,label_shape[0] - labs.shape[0]), (0,0)], "constant", constant_values=-1)
            g.create_dataset(prob, data = labs)

And add a third dim for the images

In [12]:
old_fixed_h5 = os.path.join(base_folder, "fixed_fixed_images.h5")
new_fixed_h5 = os.path.join(base_folder, "fixed_fixed_fixed_images.h5")
with h5py.File(old_fixed_h5, 'r') as f:
    assert len(f[list(f.keys())[0]][:].shape) == 2, "This is meant to expand the images from 2D to 3D"
    with h5py.File(new_fixed_h5, 'w-') as g:
        for prob in f.keys():
            img = f[prob][:]
            img = np.expand_dims(img, 2)
            g.create_dataset(prob, data = img)

In [13]:
old_moving_h5 = os.path.join(base_folder, "moving_images.h5")
new_moving_h5 = os.path.join(base_folder, "fixed_moving_images.h5")
with h5py.File(old_moving_h5, 'r') as f:
    assert len(f[list(f.keys())[0]][:].shape) == 2, "This is meant to expand the images from 2D to 3D"
    with h5py.File(new_moving_h5, 'w-') as g:
        for prob in f.keys():
            img = f[prob][:]
            img = np.expand_dims(img, 2)
            g.create_dataset(prob, data = img)

# Take unlabeled frames and add to a labeled h5

In [14]:
if add_extra:

    old_img_fixed_h5 = "/home/brian/data4/brian/PBnJ/jelly_h5s/full_lab_movds/fixed_images.h5"
    new_img_fixed_h5 = "/home/brian/data4/brian/PBnJ/jelly_processed_data/mixed_lab_h5/fixed_images.h5"

    old_label_fixed_h5 = "/home/brian/data4/brian/PBnJ/jelly_h5s/full_lab_movds/fixed_labels.h5"
    new_label_fixed_h5 = "/home/brian/data4/brian/PBnJ/jelly_processed_data/mixed_lab_h5/fixed_labels.h5"

    old_img_moving_h5 = "/home/brian/data4/brian/PBnJ/jelly_h5s/full_lab_movds/moving_images.h5"
    new_img_moving_h5 = "/home/brian/data4/brian/PBnJ/jelly_processed_data/mixed_lab_h5/moving_images.h5"

    old_label_moving_h5 = "/home/brian/data4/brian/PBnJ/jelly_h5s/full_lab_movds/moving_labels.h5"
    new_label_moving_h5 = "/home/brian/data4/brian/PBnJ/jelly_processed_data/mixed_lab_h5/moving_labels.h5"

    frame_folder = "/home/brian/data4/brian/PBnJ/jelly_centroid_prep/zoomed_in_vid"


    frame_names = [
        os.path.splitext(p)[0] for p in os.listdir(frame_folder)
        if os.path.splitext(p)[-1] in [".jpg"]
    ]

    probs = []

    for i in range(num_frames_to_add):
        probs.append(tuple(random.sample(frame_names, 2)))

    # old_mean = 0
    old_max = 0
    # num = 0

    with h5py.File(new_img_fixed_h5, 'w-') as nif:
        with h5py.File(new_label_fixed_h5, 'w-') as nlf:
            with h5py.File(old_img_fixed_h5, 'r') as oif:
                with h5py.File(old_label_fixed_h5, 'r') as olf:
                    for prob in oif.keys():
                        img = oif[prob][:]
                        labs = olf[prob][:]

                        # old_mean += np.mean(img)
                        old_max = max(old_max, np.max(img))
                        # num += 1

                        nif.create_dataset(prob, data = img, dtype=float)
                        nlf.create_dataset(prob, data = labs, dtype=np.float32)
            
            # old_mean = old_mean / num
            scale_factor = old_max / 255 # Photo max val
            
            unlabeled_labs = np.ones_like(labs, dtype=np.float32) * -1
            for fixed, moving in probs:
                prob = f"{moving}to{fixed}"
                img = cv2.imread(os.path.join(frame_folder, fixed + ".jpg"))[:,:,0:1]
                nif.create_dataset(prob, data = img * scale_factor, dtype=float)
                nlf.create_dataset(prob, data = unlabeled_labs, dtype=np.float32)





    with h5py.File(new_img_moving_h5, 'w-') as nim:
        with h5py.File(new_label_moving_h5, 'w-') as nlm:
            with h5py.File(old_img_moving_h5, 'r') as oim:
                with h5py.File(old_label_moving_h5, 'r') as olm:
                    for prob in oim.keys():
                        img = oim[prob][:]
                        labs = olm[prob][:]
                        nim.create_dataset(prob, data = img, dtype=float)
                        nlm.create_dataset(prob, data = labs, dtype=np.float32)

            for fixed, moving in probs:
                prob = f"{moving}to{fixed}"
                img = cv2.imread(os.path.join(frame_folder, moving + ".jpg"))[:,:,0:1]
                nim.create_dataset(prob, data = img * scale_factor, dtype=float)
                nlm.create_dataset(prob, data = unlabeled_labs, dtype=np.float32)

## Crop

In [15]:
in_dir = base_folder
out_dir = os.path.join(base_folder, "cropped")

crop_shape = np.array((side_len, side_len, 1))

os.makedirs(out_dir, exist_ok=True)

with h5py.File(os.path.join(in_dir, "fixed_moving_images.h5"), 'r') as imv,  h5py.File(os.path.join(in_dir, "fixed_moving_labels.h5"), 'r') as lmv, h5py.File(
    os.path.join(out_dir, "moving_images.h5"), 'w-') as oimv,  h5py.File(os.path.join(out_dir, "moving_labels.h5"), 'w-') as olmv:
    with h5py.File(os.path.join(in_dir, "fixed_fixed_fixed_images.h5"), 'r') as ifx,  h5py.File(os.path.join(in_dir, "fixed_fixed_fixed_labels.h5"), 'r') as lfx, h5py.File(
        os.path.join(out_dir, "fixed_images.h5"), 'w-') as oifx,  h5py.File(os.path.join(out_dir, "fixed_labels.h5"), 'w-') as olfx:
        for prob in imv.keys():
            img = imv[prob][:]
            imgF = ifx[prob][:]

            crop_offset = (img.shape - crop_shape) / 2
            assert np.all(crop_offset == crop_offset.astype(int))
            crop_offset = crop_offset.astype(int)

            ## Crop the images
            img = img[crop_offset[0]:crop_offset[0] + crop_shape[0], crop_offset[1]:crop_offset[1] + crop_shape[1], crop_offset[2]:crop_offset[2] + crop_shape[2]]
            assert np.all(img.shape == crop_shape)

            imgF = imgF[crop_offset[0]:crop_offset[0] + crop_shape[0], crop_offset[1]:crop_offset[1] + crop_shape[1], crop_offset[2]:crop_offset[2] + crop_shape[2]]
            assert np.all(imgF.shape == crop_shape), imgF.shape

            ## Adjust the labels
            labs = lmv[prob][:]
            neg_ones = labs < 0
            labs = labs - crop_offset 

            labsF = lfx[prob][:]
            neg_onesF = labsF < 0
            labsF = labsF - crop_offset
            assert np.all(neg_ones == neg_onesF), "Fixed and Moving labels have a different number of non negative 1 labels"
            
            # Remove centroids that are cropped out of the image
            crop = np.logical_or(labs < 0, labs >= crop_shape)
            crop = np.max(crop, axis=-1)
            cropF = np.logical_or(labsF < 0, labsF >= crop_shape)
            cropF = np.max(cropF, axis=-1)
            crop = np.logical_or(crop, cropF) # If either the fixed or moving are out of bounds then exclude both

            labs[crop] = -1
            labsF[crop] = -1

            labs[neg_ones] = -1 # Retain -1s
            labsF[neg_ones] = -1 # Retain -1s

            # Double check that all of the out of frame centroids are gone
            assert np.all(np.logical_or(labs >= 0, labs == -1)), "The crop results in moving centroids that are under bounds"
            assert np.all(labs < crop_shape[0]), "The crop results in moving centroids that are over bounds"
            assert np.all(np.logical_or(labsF >= 0, labsF == -1)), "The crop results in fixed centroids that are under bounds"
            assert np.all(labsF < crop_shape[0]), "The crop results in fixed centroids that are over bounds"

            assert np.all((labs == -1) == (labsF == -1)), "The labels that are excluded are not the same between moving and fixed"


            oimv.create_dataset(prob, data = img)
            olmv.create_dataset(prob, data = labs)
            oifx.create_dataset(prob, data = imgF)
            olfx.create_dataset(prob, data = labsF)


## Add Padding

In [16]:
in_dir = os.path.join(base_folder, "cropped")
out_dir = os.path.join(base_folder, "cropped", "padded")

os.makedirs(out_dir, exist_ok=True)

if padding == "single":
    padding = np.array([[0,0],[0,0],[1,0]])
else:
    padding = np.array([[0,0],[0,0],[1,1]])

with h5py.File(os.path.join(in_dir, "moving_images.h5"), 'r') as imv,  h5py.File(os.path.join(in_dir, "moving_labels.h5"), 'r') as lmv:
    with h5py.File(os.path.join(out_dir, "moving_images.h5"), 'w-') as oimv,  h5py.File(os.path.join(out_dir, "moving_labels.h5"), 'w-') as olmv:
        for prob in imv.keys():
            img = imv[prob][:]
            # img = np.pad(img, padding, "constant", constant_values=0)
            img = np.pad(img, padding, "constant", constant_values=np.min(img))

            labs = lmv[prob][:]
            neg_ones = labs < 0
            labs = labs + padding[:, 0] 
            labs[neg_ones] = -1 # Retain -1s

            oimv.create_dataset(prob, data = img)
            olmv.create_dataset(prob, data = labs)



with h5py.File(os.path.join(in_dir, "fixed_images.h5"), 'r') as ifx,  h5py.File(os.path.join(in_dir, "fixed_labels.h5"), 'r') as lfx:
    with h5py.File(os.path.join(out_dir, "fixed_images.h5"), 'w-') as oifx,  h5py.File(os.path.join(out_dir, "fixed_labels.h5"), 'w-') as olfx:
        for prob in ifx.keys():
            img = ifx[prob][:]
            # img = np.pad(img, padding, "constant", constant_values=0)
            img = np.pad(img, padding, "constant", constant_values=np.min(img))

            labs = lfx[prob][:]
            neg_ones = labs < 0
            labs = labs + padding[:, 0] 
            labs[neg_ones] = -1 # Retain -1s
            
            oifx.create_dataset(prob, data = img)
            olfx.create_dataset(prob, data = labs)


## Ceiling
#### log

In [17]:
if augmentation == "sqrt":

    in_dir = os.path.join(base_folder, "cropped", "padded")
    out_dir = os.path.join(base_folder, "cropped", "padded", "log_scaled")

    os.makedirs(out_dir, exist_ok=True)

    with h5py.File(os.path.join(in_dir, "moving_images.h5"), 'r') as imv,  h5py.File(os.path.join(in_dir, "moving_labels.h5"), 'r') as lmv:
        with h5py.File(os.path.join(out_dir, "moving_images.h5"), 'w-') as oimv,  h5py.File(os.path.join(out_dir, "moving_labels.h5"), 'w-') as olmv:
            for prob in imv.keys():
                img = imv[prob][:]
                labs = lmv[prob][:]
                
                img = np.log2(img + 1, dtype=np.float32)
                # img = np.log(img + 1, dtype=np.float32)

                oimv.create_dataset(prob, data = img)
                olmv.create_dataset(prob, data = labs)



    with h5py.File(os.path.join(in_dir, "fixed_images.h5"), 'r') as ifx,  h5py.File(os.path.join(in_dir, "fixed_labels.h5"), 'r') as lfx:
        with h5py.File(os.path.join(out_dir, "fixed_images.h5"), 'w-') as oifx,  h5py.File(os.path.join(out_dir, "fixed_labels.h5"), 'w-') as olfx:
            for prob in ifx.keys():
                img = ifx[prob][:]
                labs = lfx[prob][:]
                
                img = np.log2(img + 1, dtype=np.float32)
                # img = np.log(img + 1, dtype=np.float32)
                
                oifx.create_dataset(prob, data = img)
                olfx.create_dataset(prob, data = labs)


#### sqrt

In [18]:
if augmentation == "log":

    in_dir = os.path.join(base_folder, "cropped", "padded")
    out_dir = os.path.join(base_folder, "cropped", "padded", "sqrt_scaled")



    with h5py.File(os.path.join(in_dir, "moving_images.h5"), 'r') as imv,  h5py.File(os.path.join(in_dir, "moving_labels.h5"), 'r') as lmv:
        with h5py.File(os.path.join(out_dir, "moving_images.h5"), 'w-') as oimv,  h5py.File(os.path.join(out_dir, "moving_labels.h5"), 'w-') as olmv:
            for prob in imv.keys():
                img = imv[prob][:]
                labs = lmv[prob][:]
                
                img = np.sqrt(img, dtype=np.float32)

                oimv.create_dataset(prob, data = img)
                olmv.create_dataset(prob, data = labs)



    with h5py.File(os.path.join(in_dir, "fixed_images.h5"), 'r') as ifx,  h5py.File(os.path.join(in_dir, "fixed_labels.h5"), 'r') as lfx:
        with h5py.File(os.path.join(out_dir, "fixed_images.h5"), 'w-') as oifx,  h5py.File(os.path.join(out_dir, "fixed_labels.h5"), 'w-') as olfx:
            for prob in ifx.keys():
                img = ifx[prob][:]
                labs = lfx[prob][:]
                
                img = np.sqrt(img, dtype=np.float32)
                
                oifx.create_dataset(prob, data = img)
                olfx.create_dataset(prob, data = labs)


In [19]:
with h5py.File(new_fixed_h5, 'r') as f:
    probs = list(f.keys())
    print(probs)
    print(len(probs))

['0_1087to0_0', '0_1197to0_0', '0_1226to0_0', '0_1299to0_0', '0_1372to0_0', '0_1399to0_0', '0_142to0_0', '0_1451to0_0', '0_1575to0_0', '0_1579to0_0', '0_1599to0_0', '0_1613to0_0', '0_1624to0_0', '0_1630to0_0', '0_1649to0_0', '0_1893to0_0', '0_2003to0_0', '0_2177to0_0', '0_2224to0_0', '0_2238to0_0', '0_224to0_0', '0_2413to0_0', '0_2475to0_0', '0_2676to0_0', '0_2703to0_0', '0_2795to0_0', '0_2963to0_0', '0_2984to0_0', '0_3465to0_0', '0_3487to0_0', '0_3660to0_0', '0_3696to0_0', '0_3738to0_0', '0_3879to0_0', '0_4013to0_0', '0_4077to0_0', '0_407to0_0', '0_4255to0_0', '0_4493to0_0', '0_4575to0_0', '0_4599to0_0', '0_474to0_0', '0_4852to0_0', '0_4908to0_0', '0_4988to0_0', '0_4991to0_0', '0_5055to0_0', '0_5226to0_0', '0_5342to0_0', '0_5481to0_0', '0_5487to0_0', '0_5505to0_0', '0_5541to0_0', '0_5558to0_0', '0_5677to0_0', '0_5757to0_0', '0_5891to0_0', '0_591to0_0', '0_5933to0_0', '0_5985to0_0', '0_6012to0_0', '0_6122to0_0', '0_6123to0_0', '0_6215to0_0', '0_653to0_0', '0_6716to0_0', '0_680to0_0', '

# Split into training and val

In [20]:
val_probs = random.sample(probs, 10)
print(val_probs)

['1_2442to1_0', '1_8141to1_0', '2_3471to2_0', '2_4169to2_0', '1_8850to1_0', '0_9583to0_0', '1_953to1_0', '1_1044to1_0', '2_5891to2_0', '2_94to2_0']


In [21]:
# val_probs = [
#             "3563to15011",
#             "5004to15011",
#             "6739to15011",
#             "8110to15011",
#             "985to15011"
#         ]
# val_probs = ['4175to0000', '8140to0000', '5347to0000', '5683to0000', '5127to0000', '4480to0000', '5002to0000', '5417to0000', '8261to0000', '4215to0000', '5242to0000', '5753to0000', '5067to0000', '4455to0000', '4981to0000', '5082to0000', '5868to0000', '5843to0000', '5437to0000', '5948to0000', '8014to0000', '4871to0000', '5152to0000', '4691to0000', '8351to0000', '5748to0000', '5643to0000', '2841to0000', '4200to0000', '4035to0000', '8289to0000', '5728to0000', '5037to0000', '4606to0000', '4821to0000', '5563to0000', '5693to0000', '8002to0000', '5057to0000', '8273to0000', '5462to0000', '8203to0000', '4831to0000', '8297to0000', '4956to0000', '4450to0000', '5302to0000', '8022to0000', '4656to0000', '3999to0000', '5593to0000', '5653to0000', '4936to0000', '4561to0000', '5192to0000', '4060to0000', '5257to0000', '8367to0000', '5117to0000', '8257to0000', '4811to0000', '7683to0000', '5327to0000', '5402to0000', '8196to0000', '2105to0000', '4651to0000', '4566to0000', '4766to0000']
# val_probs = ['2984to0000', '3838to0000', '3315to0000', '0765to0000', '2572to0000', '3235to0000', '5215to0000', '0738to0000', '2683to0000', '5295to0000', '1662to0000', '4472to0000', '2773to0000', '5727to0000', '3305to0000', '1389to0000', '5396to0000', '4010to0000', '4883to0000', '4190to0000', '1688to0000', '1125to0000', '1468to0000', '2833to0000', '3657to0000', '1512to0000', '0712to0000', '1037to0000', '0413to0000', '2210to0000', '3567to0000', '0923to0000', '2522to0000', '3546to0000', '0263to0000', '5125to0000', '0395to0000', '0536to0000', '3205to0000', '5999to0000', '2552to0000', '1116to0000', '1407to0000', '0571to0000', '5506to0000', '5235to0000', '0114to0000', '3737to0000', '3587to0000']

if augmentation == None:
    base_dir = os.path.join(base_folder, "cropped", "padded")
elif augmentation == "sqrt":
    base_dir = os.path.join(base_folder, "cropped", "padded", "sqrt_scaled")    
elif augmentation == "log":
    base_dir = os.path.join(base_folder, "cropped", "padded", "log_scaled") 
else:
    raise ValueError(f"Unknown augmentation {augmentation}")

os.mkdir(os.path.join(base_dir, "train"))
os.mkdir(os.path.join(base_dir, "val"))

old_img_fixed_h5 = f"{base_dir}/fixed_images.h5"
train_img_fixed_h5 = f"{base_dir}/train/fixed_images.h5"
val_img_fixed_h5 = f"{base_dir}/val/fixed_images.h5"

old_label_fixed_h5 = f"{base_dir}/fixed_labels.h5"
train_label_fixed_h5 = f"{base_dir}/train/fixed_labels.h5"
val_label_fixed_h5 = f"{base_dir}/val/fixed_labels.h5"

old_img_moving_h5 = f"{base_dir}/moving_images.h5"
train_img_moving_h5 = f"{base_dir}/train/moving_images.h5"
val_img_moving_h5 = f"{base_dir}/val/moving_images.h5"

old_label_moving_h5 = f"{base_dir}/moving_labels.h5"
train_label_moving_h5 = f"{base_dir}/train/moving_labels.h5"
val_label_moving_h5 = f"{base_dir}/val/moving_labels.h5"


with h5py.File(train_img_fixed_h5, 'w-') as tif,  h5py.File(val_img_fixed_h5, 'w-') as vif:
    with h5py.File(train_label_fixed_h5, 'w-') as tlf,  h5py.File(val_label_fixed_h5, 'w-') as vlf:
        with h5py.File(old_img_fixed_h5, 'r') as oif:
            with h5py.File(old_label_fixed_h5, 'r') as olf:
                for prob in oif.keys():
                    img = oif[prob][:]
                    labs = olf[prob][:]
                    if prob in val_probs:
                        vif.create_dataset(prob, data = img)
                        vlf.create_dataset(prob, data = labs)
                    else:
                        tif.create_dataset(prob, data = img)
                        tlf.create_dataset(prob, data = labs)
        

with h5py.File(train_img_moving_h5, 'w-') as tim,  h5py.File(val_img_moving_h5, 'w-') as vim:
    with h5py.File(train_label_moving_h5, 'w-') as tlm,  h5py.File(val_label_moving_h5, 'w-') as vlm:
        with h5py.File(old_img_moving_h5, 'r') as oim:
            with h5py.File(old_label_moving_h5, 'r') as olm:
                for prob in oim.keys():
                    img = oim[prob][:]
                    labs = olm[prob][:]
                    if prob in val_probs:
                        vim.create_dataset(prob, data = img)
                        vlm.create_dataset(prob, data = labs)
                    else:
                        tim.create_dataset(prob, data = img)
                        tlm.create_dataset(prob, data = labs)

### Create empty ROIs

In [22]:
old_img_moving_h5 = f"{base_dir}/moving_images.h5"

train_roi_moving_h5 = f"{base_dir}/train/moving_rois.h5"
val_roi_moving_h5 = f"{base_dir}/val/moving_rois.h5"
train_roi_fixed_h5 = f"{base_dir}/train/fixed_rois.h5"
val_roi_fixed_h5 = f"{base_dir}/val/fixed_rois.h5"


with h5py.File(val_roi_fixed_h5, 'w-') as vrf,  h5py.File(val_roi_moving_h5, 'w-') as vrm:
# with h5py.File(train_roi_fixed_h5, 'w-') as trf,  h5py.File(train_roi_moving_h5, 'w-') as trm:
    with h5py.File(old_img_moving_h5, 'r') as oim:
        for prob in oim.keys():
            blank = np.zeros_like(oim[prob][:])
            vrf.create_dataset(prob, data = blank)
            vrm.create_dataset(prob, data = blank)
            # trf.create_dataset(prob, data = blank)
            # trm.create_dataset(prob, data = blank)