# Create training data for CellDiscoveryNet

This notebook creates training data for the CellDiscoveryNet model. It assumes you have already run the `make_AutoCellLabeler_input` notebook to create 4D H5 files. In particular, `θh_pos_is_ventral` must be set correctly in that notebook so that the images are aligned. The H5 files do not need human labels: CellDiscoveryNet is trained with unsupervised learning, so any labels are discarded.

The notebook then aligns those images in two steps: first, manual 180-degree rotations make the head orientation consistent across animals; then, GPU-accelerated Euler registration aligns every pair of images.

The output of this notebook is available in [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&e=1&st=ybsvv0ry&dl=0) in the `CellDiscoveryNet` folder.

In [None]:
ENV["CUDA_VISIBLE_DEVICES"] = "0" # GPU index to use; set before torch is imported further down

using ImageRegistration

using HDF5

using ImageDataIO
using PyCall
using NRRDIO
using Statistics
using ImageRegistration
using FlavellBase
using ProgressMeter
using MultivariateStats
using PyPlot
using ProgressMeter
using Images

using JLD2

## Load data

In [None]:
# Dataset IDs grouped by source project. Each ID names one `<ID>.h5` file that
# `load_images` reads later in this notebook.
datasets_prj_neuropal = ["2022-07-15-06", "2022-07-15-12", "2022-07-20-01", "2022-07-26-01", "2022-08-02-01", "2023-01-23-08", "2023-01-23-15", "2023-01-23-21", "2023-01-19-08", "2023-01-19-22", "2023-01-09-28", "2023-01-17-01", "2023-01-19-15", "2023-01-23-01", "2023-03-07-01", "2022-12-21-06", "2023-01-05-18", "2023-01-06-01", "2023-01-06-08", "2023-01-09-08", "2023-01-09-15", "2023-01-09-22", "2023-01-10-07", "2023-01-10-14", "2023-01-13-07", "2023-01-16-01", "2023-01-16-08", "2023-01-16-15", "2023-01-16-22", "2023-01-17-07", "2023-01-17-14", "2023-01-18-01"]
datasets_prj_rim = ["2023-06-09-01", "2023-07-28-04", "2023-06-24-02", "2023-07-07-11", "2023-08-07-01", "2023-06-24-11", "2023-07-07-18", "2023-08-18-11", "2023-06-24-28", "2023-07-11-02", "2023-08-22-08", "2023-07-12-01", "2023-07-01-09", "2023-07-13-01", "2023-06-09-10", "2023-07-07-01", "2023-08-07-16", "2023-08-22-01", "2023-08-23-23", "2023-08-25-02", "2023-09-15-01", "2023-09-15-08", "2023-08-18-18", "2023-08-19-01", "2023-08-23-09", "2023-08-25-09", "2023-09-01-01", "2023-08-31-03", "2023-07-01-01", "2023-07-01-23"]

# Remaining projects; all five lists are concatenated into `datasets` in the next lines.
datasets_prj_aversion = ["2023-03-30-01", "2023-06-29-01", "2023-06-29-13", "2023-07-14-08", "2023-07-14-14", "2023-07-27-01", "2023-08-08-07", "2023-08-14-01", "2023-08-16-01", "2023-08-21-01", "2023-09-07-01", "2023-09-14-01", "2023-08-15-01", "2023-10-05-01", "2023-06-23-08", "2023-12-11-01", "2023-06-21-01"]
datasets_prj_5ht = ["2022-07-26-31", "2022-07-26-38", "2022-07-27-31", "2022-07-27-38", "2022-07-27-45", "2022-08-02-31", "2022-08-02-38", "2022-08-03-31"]
datasets_prj_starvation = ["2023-05-25-08", "2023-05-26-08", "2023-06-05-10", "2023-06-05-17", "2023-07-24-27", "2023-09-27-14", "2023-05-25-01", "2023-05-26-01", "2023-07-24-12", "2023-07-24-20", "2023-09-12-01", "2023-09-19-01", "2023-09-29-19", "2023-10-09-01", "2023-09-13-02"]

# Master list of every dataset, project by project (untyped vector, filled in place).
datasets = []
for project_datasets in (datasets_prj_neuropal, datasets_prj_rim, datasets_prj_aversion,
                         datasets_prj_5ht, datasets_prj_starvation)
    append!(datasets, project_datasets)
end

# Held-out validation datasets, drawn from every project.
datasets_val = ["2023-06-24-02", "2023-08-07-01", "2023-08-19-01", # RIM datasets
                "2022-07-26-01", "2023-01-23-21", "2023-01-23-01", # NeuroPAL datasets
                "2023-07-14-08", # Aversion datasets
                "2022-08-02-31", # 5-HT datasets
                "2023-07-24-27", "2023-07-24-20"] # Starvation datasets

# Held-out test datasets, drawn from every project.
datasets_test = ["2023-08-22-01", "2023-07-07-18", "2023-07-01-23", # RIM datasets
                 "2023-01-06-01", "2023-01-10-07", "2023-01-17-07", # Neuropal datasets
                 "2023-08-21-01", "2023-06-23-08", # Aversion datasets
                 "2022-07-27-38", # 5-HT datasets
                 "2023-10-09-01", "2023-09-13-02"] # Starvation datasets

# Training set = everything not held out, in master-list order.
datasets_train = [ds for ds in datasets if ds ∉ datasets_val && ds ∉ datasets_test]

# Canonical ordering (train, then val, then test) that every later index refers to.
datasets_ = vcat(datasets_train, datasets_val, datasets_test);

In [None]:
"""
    load_images(datasets, base_path; key="raw")

Read the HDF5 dataset `key` from the file `<base_path>/<dataset>.h5` for every entry of
`datasets` and return the arrays in the same order.

`base_path` may be given with or without a trailing path separator.

The result is intentionally an untyped `Vector{Any}`. Later cells store arrays of a
different element type into copies of it (e.g. `Float64` images after normalization). A
concretely typed vector would try to convert those back to the original element type.
"""
function load_images(datasets, base_path; key="raw")
    imgs = Any[]
    for dataset in datasets
        # joinpath inserts the separator when needed. Plain string concatenation turned a
        # base path without a trailing "/" into e.g. ".../train2023-01-01-01.h5".
        path_img = joinpath(base_path, dataset * ".h5")
        # The do-block's value (the array read from `key`) is returned by h5open, which
        # also closes the file.
        img = h5open(path_img, "r") do f
            read(f[key])
        end
        push!(imgs, img)
    end
    return imgs
end

"""
    rotate_image_180_xy(img)

Rotate every xy slice of `img` (indexed `img[x, y, z, channel]`) by 180 degrees.

The rotation happens **in place**: `img` itself is overwritten and the same array is
returned. A 3D array is also accepted and treated as a single channel, because
`size(img, 4) == 1` and a trailing index of 1 is valid.
"""
function rotate_image_180_xy(img)
    for ch in 1:size(img, 4), z in 1:size(img, 3)
        # rot180 returns a new matrix, which is written back into the same slice.
        img[:, :, z, ch] = rot180(img[:, :, z, ch])
    end
    return img
end



## Manually align the images

This notebook assumes the images are already consistently oriented in the xz plane. The `make_AutoCellLabeler_input` notebook does this automatically using the `θh_pos_is_ventral` parameter. The xy orientation, however, may still differ between animals, so we fix it manually by rotating images about the z axis. To do this for your data, check which way each animal's head points. Then set `rotate_idx_train` to the indices (into `datasets_train`) of the images with the less common head orientation; those images will be rotated by 180 degrees in the xy plane.

In [None]:
# Input directories each hold one `<dataset>.h5` per dataset ID. Every directory path ends
# in "/" (as roi_crop/ always did) so the file name lands inside the directory rather
# than being glued onto the directory name.
imgs_train = load_images(datasets_train, "/path/to/your/input/data/train/")
imgs_val = load_images(datasets_val, "/path/to/your/input/data/val/")
imgs_test = load_images(datasets_test, "/path/to/your/input/data/test/")
# ROI volumes for all datasets, in datasets_ order (train, then val, then test).
imgs_roi = load_images(datasets_, "/path/to/your/input/data/roi_crop/", key="roi");

root_path = "/path/to/your/output/data/CellDiscoveryNet"

create_dir(root_path)
# 1-based indices into datasets_train of the animals with the less common head orientation.
# Since datasets_ begins with datasets_train in the same order, these indices also select
# the matching entries of imgs_roi.
rotate_idx_train = [5, 62];

In [None]:
# Flip each listed training volume by 180 degrees in xy so all head orientations agree.
# (Not idempotent: re-running this cell flips them back.)
foreach(rotate_idx_train) do idx
    imgs_train[idx] = rotate_image_180_xy(imgs_train[idx])
end

In [None]:
# All images in datasets_ order (train, then val, then test), as a Vector{Any}.
# A shallow concatenation is enough: the normalization step replaces every entry with a
# freshly allocated Float64 array instead of mutating it. Deep-copying would duplicate
# every training volume in memory for nothing.
imgs_all = vcat(imgs_train, imgs_val, imgs_test);

In [None]:
# Normalize every image per channel. Each image is converted to a new Float64 array, and
# each of its 4 channels is scaled so that channel's maximum becomes 1. The entry in
# imgs_all is then replaced by that new array.
for i in eachindex(imgs_all)
    img = Float64.(imgs_all[i])
    for ch in 1:4
        channel = @view img[:, :, :, ch] # view: scale in place without copying the channel
        peak = maximum(channel)
        # An all-zero channel would become all NaN (0/0). Those NaNs would then flow into the
        # registration and saved outputs, so such a channel is left as zeros instead.
        if !iszero(peak)
            channel ./= peak
        end
    end
    imgs_all[i] = img
end

In [None]:
# Apply the same flips to the ROI volumes. imgs_roi follows datasets_ order, whose leading
# entries are datasets_train, so the training indices point at the matching ROI volumes.
foreach(rotate_idx_train) do idx
    imgs_roi[idx] = rotate_image_180_xy(imgs_roi[idx])
end

In [None]:
# Split name => dataset IDs in that split; used to decide which split file a registered pair goes to.
datasets_component = Dict("train" => datasets_train, "val" => datasets_val, "test" => datasets_test);

## Euler register images

This registers every pair of images in the xy plane using GPU-accelerated Euler registration, then estimates a z offset for each pair and applies it to all channels and to the ROI volume.



In [None]:
# Registration results for every registered pair: (i, j) => (Euler parameters, z shift).
euler_parameters = Dict()
np = pyimport("numpy")

# Candidate x/y translations and xy rotations for the GPU grid search. Translations are
# sampled densely near zero and sparsely further out.
euler_x_translation_range_1 = np.sort(np.concatenate((
    np.linspace(-0.24, 0.24, 49),
    np.linspace(-0.46, -0.25, 8),
    np.linspace(0.25, 0.46, 8),
    np.linspace(0.5, 1, 3),
    np.linspace(-1, -0.5, 3))))

euler_y_translation_range_1 = np.sort(np.concatenate((
    np.linspace(-0.28, 0.28, 29),
    np.linspace(-0.54, -0.3, 5),
    np.linspace(0.3, 0.54, 5),
    np.linspace(0.6, 1.4, 3),
    np.linspace(-1.4, -0.6, 3))))

euler_theta_rotation_range_xy = np.sort(np.concatenate((
    np.linspace(0, 19, 20),
    np.linspace(341, 359, 19)))) # disallow 180 degree rotations since we manually handled those earlier

# GPU workspaces, created lazily on the first registered pair and reused afterwards.
memory_dict = nothing
memory_dict_3d = nothing

euler_gpu = pyimport("euler_gpu")
pytorch = pyimport("torch")

device = pytorch.device("cuda:0")

BATCH_SIZE = 5000 # reduce this if you run out of GPU memory

# Candidate z shifts passed to euler_gpu.translate_along_z (identical for every pair).
shift_range = collect(-40:40)

"""
    euler_transform_volume(volume, best_transformation, memory_dict_3d; kwargs...)

Apply the Euler transform `best_transformation` (as returned by `euler_gpu.grid_search`)
to the 3D array `volume` (x, y, z) on the GPU. The result is returned as a Julia array in
the same (x, y, z) layout.

The volume is sent as a (z, 1, x, y) tensor, and each transform parameter is repeated
once per z-slice. Extra keyword arguments (e.g. `interpolation="nearest"`) are forwarded
unchanged to `euler_gpu.transform_image`.
"""
function euler_transform_volume(volume, best_transformation, memory_dict_3d; kwargs...)
    z_dim = size(volume, 3)
    volume_tensor = pytorch.tensor(permutedims(volume, [3, 1, 2]), device=device, dtype=pytorch.float32).unsqueeze(1).repeat(1,1,1,1)
    transformed = euler_gpu.transform_image(volume_tensor, best_transformation[1].repeat(z_dim), best_transformation[2].repeat(z_dim), best_transformation[3].repeat(z_dim), memory_dict_3d; kwargs...)
    return permutedims(dropdims(transformed.cpu().numpy(), dims=2), [2, 3, 1])
end

# Output directories do not depend on the pair, so create them once.
for subdir in ("img_fixed", "roi_fixed", "euler_tfm_moving", "euler_tfm_moving_roi")
    create_dir(joinpath(root_path, subdir))
end

@showprogress for (j, fixed_dataset) in enumerate(datasets_)
    fixed_image = imgs_all[j]
    # Files are written with the axis order reversed, e.g. (channel, z, y, x) for images.
    transposed_fixed_image = permutedims(fixed_image, (4, 3, 2, 1))

    h5write(joinpath(root_path, "img_fixed/$(j).h5"), "raw", transposed_fixed_image)
    h5open(joinpath(root_path, "roi_fixed", "$(j).h5"), "w") do f
        write(f, "roi", permutedims(imgs_roi[j], (3,2,1))) # make dimensions consistent with images
    end

    # Register each unordered pair once: every dataset earlier in datasets_ moves onto j.
    for i in 1:(j - 1)
        moving_dataset = datasets_[i]
        moving_image = imgs_all[i]

        # Both volumes must share one grid with even x/y extents before downsampling.
        if !(iseven(size(moving_image, 1)) && iseven(size(moving_image, 2)) && size(moving_image) == size(fixed_image))
            throw(DimensionMismatch("cannot register dataset $(i) onto $(j): sizes $(size(moving_image)) and $(size(fixed_image)) must match and have even x/y extents"))
        end

        # Coarse xy grid search on downsampled max-intensity projections of channel 4.
        moving_image_downsampled = euler_gpu.max_intensity_projection_and_downsample(moving_image[:,:,:,4], 4, 2)
        fixed_image_downsampled = euler_gpu.max_intensity_projection_and_downsample(fixed_image[:,:,:,4], 4, 2)

        if isnothing(memory_dict)
            memory_dict = euler_gpu.initialize(fixed_image_downsampled, moving_image_downsampled, euler_x_translation_range_1, euler_y_translation_range_1, euler_theta_rotation_range_xy, BATCH_SIZE, device)
        else
            # Reuse the workspace; only the image pair changes between iterations.
            memory_dict["moving_images_repeated"] = pytorch.tensor(moving_image_downsampled, device=device, dtype=pytorch.float32).unsqueeze(0).repeat(BATCH_SIZE,1,1,1)
            memory_dict["fixed_images_repeated"] = pytorch.tensor(fixed_image_downsampled, device=device, dtype=pytorch.float32).unsqueeze(0).repeat(BATCH_SIZE,1,1,1)
        end

        best_score, best_transformation = euler_gpu.grid_search(memory_dict)

        z_dim = size(fixed_image, 3)
        if isnothing(memory_dict_3d)
            memory_dict_3d = euler_gpu.initialize(fixed_image[:,:,1,4], moving_image[:,:,1,4], zeros(z_dim), zeros(z_dim), zeros(z_dim), z_dim, device)
        end

        # Estimate the z offset. Both volumes use channel 4, the channel the xy search used,
        # so the correlation compares the same channel on both sides.
        moving_reg_channel_xy = euler_transform_volume(moving_image[:,:,:,4], best_transformation, memory_dict_3d)
        dz, _, _ = euler_gpu.translate_along_z(shift_range, fixed_image[:,:,:,4], moving_reg_channel_xy, 0)

        euler_parameters[(i, j)] = ([tfm.cpu().numpy() for tfm in best_transformation], dz)

        # Apply the xy transform and the z shift to all four channels.
        transformed_moving_image = zeros(size(moving_image)...)
        for ch in 1:4
            transformed_moving_image[:,:,:,ch] .= translate_z(euler_transform_volume(moving_image[:,:,:,ch], best_transformation, memory_dict_3d), dz, 0.0)
        end
        transposed_moving_image = permutedims(transformed_moving_image, (4, 3, 2, 1))

        for component in ["train", "val", "test"]
            # note that in ANTSUN 2U we use all registrations, including ones not localized to one component
            if moving_dataset in datasets_component[component] && fixed_dataset in datasets_component[component]
                create_dir(joinpath(root_path, "$(component)"))
                h5write(joinpath(root_path, "$(component)/moving_images.h5"), "dataset_$(i)_$(j)", transposed_moving_image)
                h5write(joinpath(root_path, "$(component)/fixed_images.h5"), "dataset_$(i)_$(j)", transposed_fixed_image)
                break
            end
        end

        h5write(joinpath(root_path, "euler_tfm_moving/$(i)_$(j).h5"), "raw", transposed_moving_image)

        # Same transform for the ROI label volume. Nearest-neighbour interpolation is used
        # so only existing label values appear in the output.
        moving_roi_image = np.array(imgs_roi[i], dtype=np.int32)
        moving_roi_transformed = euler_transform_volume(moving_roi_image, best_transformation, memory_dict_3d; interpolation="nearest")
        moving_roi_image_transformed_z = translate_z(moving_roi_transformed, dz, 0.0)
        h5write(joinpath(root_path, "euler_tfm_moving_roi/$(i)_to_$(j).h5"), "roi", permutedims(moving_roi_image_transformed_z, (3, 2, 1)))
    end
end

### Save registration parameters

In [None]:
# Persist the per-pair registration results under the name "euler_parameters".
JLD2.jldopen(joinpath(root_path, "euler_parameters.jld2"), "w") do file
    file["euler_parameters"] = euler_parameters
end;

### Load saved parameters

In [None]:
# Reload the saved per-pair registration results without re-running the registration.
euler_parameters = JLD2.jldopen(joinpath(root_path, "euler_parameters.jld2"), "r") do file
    file["euler_parameters"]
end