# Segmentation Network Training

The original Segmentation 3D-UNet was trained and validated on twelve images from the non-NeuroPAL strains SWF358, SWF359, SWF360, and SWF366. This notebook takes the raw images and labels as input, then crops, resizes, and formats them to be compatible with the segmentation network.

Each input is expected to consist of an image file and a label file of the same shape and voxel size. The exact shape and voxel size do not matter, as the notebook can automatically crop and resample the images as appropriate. Each pixel of a label image should take an integer value between 0 and 3, as follows:

- **0** represents unlabeled pixels. These will be assigned a weight of 0 during training, effectively causing them to be ignored.
- **1** represents foreground pixels (aka neurons). They are assigned a high weight during training.
- **2** represents background pixels (aka pixels that do not belong to any neuron). They are assigned a low weight during training, with the weight decreasing the further away they are from neurons.
- **3** represents gap pixels. These are pixels between two adjacent neurons. In order for the Segmentation Net to learn to perform instance segmentation (splitting adjacent neurons into two separate objects), it is imperative that pixels between two adjacent neurons be labeled as background. Thus, those pixels are labeled with this special label rather than the typical background label 2, and when the network input data is created, they are assigned a very high weight.
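
The label-to-weight scheme described above can be sketched as follows. This is a minimal illustration, not the code in `SegmentationTools.jl`: `base_weight` is a hypothetical helper, the background weight of `1.0` is an assumption, and the distance-dependent decay of background weights is left out. The foreground and gap defaults mirror the `fg` and `gap` parameters set later in this notebook.

```julia
"""Map one label value (0-3) to a base training weight (hypothetical helper)."""
function base_weight(label::Integer; fg=5.0, bkg=1.0, gap=20.0)
    label == 0 && return 0.0   # unlabeled: ignored during training
    label == 1 && return fg    # foreground (neuron)
    label == 2 && return bkg   # background (assumed value, no distance decay here)
    label == 3 && return gap   # gap between two adjacent neurons
    throw(ArgumentError("label must be between 0 and 3, got $label"))
end

# Toy 1x6 label row: unlabeled, two neurons separated by one gap pixel, then background
labels = [0 1 1 3 1 2]
weights = base_weight.(labels)
println(weights)
```

The gap pixel between the two neurons receives the largest weight, which is what pushes the network to separate adjacent neurons.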

To label an image, you can load the image in ITK-SNAP from its MHD file via File -> Open Main Image. Then you can load the labels from their NRRD files via Segmentation -> Load Segmentation. You can then edit the segmentation in ITK-SNAP and save it to an NRRD file. Alternatively, to create an entirely new segmentation, you can simply start drawing on the image using the various label markers (as described above), and then save the result as an NRRD file. Remember to draw in all three dimensions! It is not necessary to label every pixel - indeed, in the initial twelve training and validation images, only about 1% of pixels had a nonzero label. Focus on labeling difficult pixels where the Segmentation Net is struggling - for example, very dim neurons, or a tight ball of neurons.
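
Before adding a newly labeled image to the training data, it can help to check that every label value is valid and to see how much of the volume is actually labeled. The sketch below uses a toy array in place of a real label file (in this notebook, labels are loaded with `NRRDIO.read_img`); the array contents are made up for illustration.

```julia
# Toy label volume: mostly unlabeled (0), with a few labeled voxels
label = zeros(UInt8, 10, 10, 10)
label[5, 5, 5] = 1           # one foreground voxel
label[5, 6, 5] = 3           # one gap voxel
label[1:2, 1:2, 1] .= 2      # four background voxels

# Every voxel must carry one of the four label values
@assert all(in(0:3), label) "labels must be integers between 0 and 3"

# Fraction of voxels with a nonzero label, and the count of each label value
labeled_fraction = count(!iszero, label) / length(label)
counts = Dict(v => count(==(v), label) for v in 0:3)
println("labeled fraction: ", labeled_fraction)
println("voxels per label: ", counts)
```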

The input images had the following original properties:
- Images 1 through 14 have voxel size $0.36 \times 0.36 \times 1.0$. They were not used during training because a z-spacing of 1.0 is too coarse to label gaps in the z-dimension reliably. They are from a variety of strains.
- Images 15 through 19 have voxel size $0.36 \times 0.36 \times 0.2$. They are from the SWF358-360 strains. They have unusually high SNR, so it is recommended to reduce their SNR before adding them to training data. This notebook does that automatically.
- Images 20 through 22 have voxel size $0.36 \times 0.36 \times 0.36$. They were not used during training for unknown reasons, and their strain is unknown.
- Images 23 through 29 have voxel size $0.54 \times 0.54 \times 0.54$, the current standard in our lab. They are from the SWF366 strain.

All of the images in the directories associated with this notebook have already been pre-processed to voxel size $0.54 \times 0.54 \times 0.54$ for training.
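
The binning factors that appear later in this notebook follow from these voxel sizes: each factor is the target voxel size divided by the original voxel size along that axis. A small sketch of that arithmetic, where the dictionary keys are just illustrative group names:

```julia
target = 0.54  # target voxel size along every axis

# Original voxel sizes (x, y, z) for each group of images listed above
voxel_sizes = Dict(
    "1-14"  => (0.36, 0.36, 1.0),
    "15-19" => (0.36, 0.36, 0.2),
    "20-22" => (0.36, 0.36, 0.36),
    "23-29" => (0.54, 0.54, 0.54),
)

# Binning factor per axis = target size / original size.
# Rounding hides floating-point noise in the division.
bin_scales = Dict(k => round.(target ./ v; digits=2) for (k, v) in voxel_sizes)

for k in sort(collect(keys(bin_scales)))
    println(k, " => ", bin_scales[k])
end
```

A factor above 1 means voxels are merged along that axis, while a factor below 1 (the z-axis of images 1 through 14) means the image must be upsampled.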

Images 15, 16, 17, 18, 23, 25, 26, 27, and 29 were used for training, while images 19, 24, and 28 were used for validation.
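
If you change the split, a quick sanity check along these lines can confirm that no image ends up in both sets and that every processed image is assigned to one of them. The variable names `train` and `val` are illustrative only.

```julia
imgs  = [15, 16, 17, 18, 19, 23, 24, 25, 26, 27, 28, 29]  # all processed images
train = [15, 16, 17, 18, 23, 25, 26, 27, 29]              # training set
val   = [19, 24, 28]                                      # validation set

# No image may appear in both sets
@assert isempty(intersect(train, val))

# Together the two sets must cover exactly the processed images
@assert sort(union(train, val)) == sort(imgs)

println(length(train), " training / ", length(val), " validation")
```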

Eric later added four NeuroPAL animals to the training data to create the second version of the Segmentation Net, contributing four additional images (three training and one validation). Their labels were created by manually curating the output of the first version of the Segmentation Net, but they were not passed through this notebook, and all of their pixels were weighted identically.

The original segmentation network crop size was $210\times96\times51$, chosen because it was the largest size that fit on the 8GB GPUs we had available at the time and because some images had an $x$-dimension of only 210. These images have since been padded, so larger crop sizes are now possible.

For more information, see https://github.com/flavell-lab/SegmentationTools.jl and https://github.com/flavell-lab/pytorch-3dunet

In [None]:
using ImageDataIO, SegmentationTools, ProgressMeter, FlavellBase, MHDIO, WormFeatureDetector, HDF5, FileIO, NRRDIO, Statistics, StatsBase, Distributions
using Plots, PyPlot

## Set training and validation dataset paths

Here, you can set which images to process and the relevant paths.

The data is available in [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&e=2&st=ybsvv0ry&dl=0) under `SegmentationNet`. Set `rootpath` to the location of this data on your local machine. Please copy the data from the Dropbox before running this notebook to avoid modifying the contents of the Dropbox.

This notebook only deals with the non-NeuroPAL images. The NeuroPAL ones were processed separately and their H5 files are available in `SegmentationNet/hdf5_train` and `SegmentationNet/hdf5_val` with non-integer filenames.

In [None]:
imgs = [15,16,17,18,19,23,24,25,26,27,28,29] # which datasets to use during training/validation
rootpath="/store1/adam/test" # root path to all input and output data.
label_dir = "label_binned_uncropped" # subpath containing labels
raw_dir = "img_binned_uncropped" # subpath containing images
hdf5_dir = "hdf5"; # subpath where `pytorch-3dunet`-compatible H5 files will be written

## Set parameters

In [None]:
fg=5 # weight of foreground (neuron) labels
gap=20 # base weight of background-gap labels
scale_bkg_gap=true; # whether to increase the background-gap weight based on the number of adjacent neurons

You may want to update these cropping parameters to your desired crop size.

In [None]:
crop_dict = Dict(
    15 => [67:279, 139:234, 21:71],
    16 => [33:245, 71:166, 21:71],
    17 => [19:231, 147:242, 1:51],
    18 => [141:236, 125:337, 21:71],
    19 => [117:212, 83:295, 11:61],
    20 => [195:338, 135:454, 1:77],
    21 => [175:318, 135:454, 1:77],
    22 => [185:328, 135:454, 1:77],
    23 => [100:195, 57:266, 32:82],
    24 => [61:270, 70:165, 21:71],
    25 => [61:270, 60:155, 34:84],
    26 => [51:260, 40:135, 28:78],
    27 => [51:260, 40:135, 36:86],
    28 => [51:260, 55:150, 40:90],
    29 => [105:200, 57:266, 40:90]
); # crop size parameters for all binned datasets
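
When editing the crop ranges, it is easy to end up with crops of inconsistent size. The sketch below computes the shape of a few of the crops above; the name `crops` and the choice of four entries are only for illustration.

```julia
# A subset of the crop ranges defined above
crops = Dict(
    15 => [67:279, 139:234, 21:71],
    18 => [141:236, 125:337, 21:71],
    23 => [100:195, 57:266, 32:82],
    24 => [61:270, 70:165, 21:71],
)

# Shape (x, y, z) of each crop, taken from the length of each range
crop_shapes = Dict(i => Tuple(length.(c)) for (i, c) in crops)

for i in sort(collect(keys(crop_shapes)))
    println(i, " => ", crop_shapes[i])
end
```

Images 18 and 23 have their long axis in y rather than x; both are among the datasets flagged for transposition in the formatting loop.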

## Format data

This code crops the images and labels, and generates `pytorch-3dunet`-compatible training and validation datasets with image, label, and weight data.

In [None]:
@showprogress for i in imgs
    crop = crop_dict[i]

    bin_scale=[1,1,1]
    reduction_factor=1
    if i in 15:19
        bin_scale = [1,1,1] # the binning already happened
        reduction_factor = 2
    elseif i in 20:22
        bin_scale = [1.5, 1.5, 1.5]
        reduction_factor = 1
    elseif i < 15
        bin_scale = [1.5, 1.5, 0.54]
        reduction_factor = 1
    end
    
    transpose = (i in [10,11,12,13,14,18,19,20,21,22,23,29])

    img = MHDIO.read_img(MHD(joinpath(rootpath, raw_dir, "$(i)_img.mhd")))
    label = NRRDIO.read_img(NRRD(joinpath(rootpath, label_dir, "$(i)_label.nrrd")))


    h5_dir = joinpath(rootpath, hdf5_dir)
    create_dir(h5_dir)
    
    make_unet_input_h5(img, label, joinpath(h5_dir, string(i, pad=2)*".h5"), scale_xy=0.36, scale_z=1,
        scale_bkg_gap=scale_bkg_gap, crop=crop, transpose=transpose, weight_foreground=fg, weight_bkg_gap=gap, bin_scale=bin_scale,
        SN_reduction_factor=reduction_factor
    )
end

### Copy files into training and validation directories

In [None]:
let
    train_dir = joinpath(rootpath, "hdf5_train")
    create_dir(train_dir)
    for i in [15,16,17,18,23,25,26,27,29] # training datasets
        cp(joinpath(rootpath, hdf5_dir, "$(i).h5"), joinpath(train_dir, "$(i).h5"))
    end
    val_dir = joinpath(rootpath, "hdf5_val")
    create_dir(val_dir)
    for i in [19,24,28] # validation datasets
        cp(joinpath(rootpath, hdf5_dir, "$(i).h5"), joinpath(val_dir, "$(i).h5"))
    end
end

## Fix binning, padding, and cropping of various datasets

You should not need to interact with this code directly, but in case you need to deal with images with unusual voxel sizes, this code will help you rebin, pad, and crop them as necessary. This may become relevant for trying to use other labs' data in the Segmentation Net.

The original images 15-19 had an unusual voxel size ($0.36 \times 0.36 \times 0.2$), and their initial labels were drawn on images that had already been binned and cropped. To change the crop size, we need to rebin the images and then un-crop the labels.
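
Un-cropping only works if the start of the original crop falls exactly on a voxel boundary of the rebinned image. One way to guarantee this, sketched here for a single axis, is to trim a few voxels from the front of the unbinned image until the crop start maps to an integer index. `aligned_offset` is a hypothetical helper written for this illustration; the numbers correspond to the y-axis of image 15 (original crop `210:353`, binning factor 1.5).

```julia
"""
Smallest 1-based trim offset such that `crop_start` lands on an integer
index after binning by `bin_scale`. Hypothetical helper for illustration.
"""
function aligned_offset(crop_start::Integer, bin_scale::Real)
    offset = 1
    # An offset equal to crop_start always passes, so the loop terminates
    while !isinteger((crop_start - offset) / bin_scale)
        offset += 1
    end
    return offset
end

offset = aligned_offset(210, 1.5)               # trim the unbinned axis to offset:end
binned_start = Int((210 - offset) / 1.5) + 1    # where the crop begins after binning
binned_len = floor(Int, length(210:353) / 1.5)  # crop length after binning
println(offset, " ", binned_start:(binned_start + binned_len - 1))
```

Because the test is an exact floating-point comparison, a factor such as 2.7 that has no exact binary representation may cause some mathematically valid offsets to be skipped; the search then simply continues to a later offset.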

Additionally, original images 23 and 29 were transposed. They need to be padded in the y-dimension to update the crop size to a larger value.
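
Padding the front of an axis shifts every index along that axis, so crop ranges defined on the unpadded image must be shifted by the same amount. A minimal sketch, assuming a hypothetical image extent of 322 voxels in x and 210 in y (the real extents are not stated here; these values were picked so that the shift reproduces the change from `1:210` to `57:266` seen for images 23 and 29 in the two crop dictionaries):

```julia
size_x, size_y = 322, 210               # hypothetical image extent; y is the short axis
pad_before = div(size_x - size_y, 2)    # voxels added in front of the y-axis
old_crop_y = 1:210                      # crop range before padding
new_crop_y = old_crop_y .+ pad_before   # same voxels, indexed in the padded image
println(pad_before, " ", new_crop_y)
```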

In [None]:
crop_dict_orig = Dict(
    15 => [100:419, 210:353, 85:224],
    16 => [51:370, 108:251, 55:194],
    17 => [30:349, 220:363, 21:160],
    18 => [212:355, 189:508, 56:195],
    19 => [175:318, 125:444, 34:173],
    20 => [195:338, 135:454, 1:77],
    21 => [175:318, 135:454, 1:77],
    22 => [185:328, 135:454, 1:77],
    23 => [100:195, 1:210, 32:82],
    24 => [61:270, 70:165, 21:71],
    25 => [61:270, 60:155, 34:84],
    26 => [51:260, 40:135, 28:78],
    27 => [51:260, 40:135, 36:86],
    28 => [51:260, 55:150, 40:90],
    29 => [105:200, 1:210, 40:90]
); # original crop size parameters for all datasets

In [None]:
new_crop_dict = Dict(); # per-image [x, y, z] start indices used to trim each unbinned image before binning

In [None]:
input_label_dir = "label_cropped" # subpath containing labels
input_raw_dir = "img_uncropped" # subpath containing images

output_img_dir = joinpath(rootpath, "img_binned_uncropped")
output_label_dir = joinpath(rootpath, "NEW_label_uncropped");

### Bin images 15-19 and uncrop their labels

In [None]:
for i in 15:19
    mhd_img = MHD(joinpath(rootpath, input_raw_dir, "$(i)_img.mhd"))
    img = MHDIO.read_img(mhd_img) # UNBINNED, UNCROPPED image

    nrrd_label = NRRD(joinpath(rootpath, input_label_dir, "$(i)_label.nrrd"))
    label = NRRDIO.read_img(nrrd_label) # BINNED, CROPPED label

    # only the LABEL will be (un-)transposed
    if i in 18:19
        label = permutedims(label, [2,1,3])
    end
    
    bin_scale = [1.5,1.5,2.7]
    crop = crop_dict_orig[i]


    # starting points of crop on the new image
    new_crop = [1,1,1]
    for j=1:3
        is_int = false
        while !is_int
            resampled_crop = (crop[j][1]-new_crop[j])/bin_scale[j]+1 # starting point of crop
            if resampled_crop == floor(resampled_crop) # check if int
                is_int = true
            else
                new_crop[j] += 1
            end
        end
    end

    new_crop_dict[i] = new_crop

    img = img[new_crop[1]:end, new_crop[2]:end, new_crop[3]:end]; # crop image to ensure integer binning offset

    resampled_img = resample_img(img, bin_scale) # resample image to match label size

    # Extract start and end indices from crop
    crop_start = [first(crop[j]) for j in 1:3]
    crop_end = [last(crop[j]) for j in 1:3]

    # Calculate the corresponding uncropped region in the resampled image
    label_uncropped_start = Int.((crop_start .- new_crop) ./ bin_scale .+ 1)
    label_uncropped_end = label_uncropped_start .+ Int.(floor.((crop_end .- crop_start .+ 1) ./ bin_scale)) .- 1

    # Initialize the uncropped label matrix with the size of the resampled image
    label_uncropped = zeros(size(resampled_img))

    # Place the cropped label into the correct uncropped position
    label_uncropped[label_uncropped_start[1]:label_uncropped_end[1], 
                label_uncropped_start[2]:label_uncropped_end[2], 
                label_uncropped_start[3]:label_uncropped_end[3]] .= label

    # Save the resampled image and the uncropped label
    MHDIO.write_raw(joinpath(output_img_dir, "$(i)_img.raw"), resampled_img)
    MHDIO.write_MHD_spec(joinpath(output_img_dir, "$(i)_img.mhd"), 0.54, 0.54, size(resampled_img)..., "$(i)_img.raw")
    NRRDIO.write_nrrd(joinpath(output_label_dir, "$(i)_label.nrrd"), label_uncropped, (0.54, 0.54, 0.54))
end


In [None]:
resampled_crop_dict = Dict()

for i=15:19
    crop = crop_dict_orig[i]
    new_crop = new_crop_dict[i]

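    # `crop_dict_orig` is in image orientation for every dataset, including the
    # transposed images 18 and 19, so `crop` is used as-is, matching the binning loop above.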
    
    bin_scale = [1.5,1.5,2.7]

    crop_start = [first(crop[j]) for j in 1:3]
    crop_end = [last(crop[j]) for j in 1:3]

    # Calculate the corresponding uncropped region in the resampled image
    label_uncropped_start = Int.((crop_start .- new_crop) ./ bin_scale .+ 1)
    label_uncropped_end = label_uncropped_start .+ Int.(floor.((crop_end .- crop_start .+ 1) ./ bin_scale)) .- 1

    resampled_crop_dict[i] = [label_uncropped_start[j]:label_uncropped_end[j] for j in 1:3]
end

### Pad images 23 and 29

In [None]:
"""
    pad_to_square(image)

Pad the shorter of the x- and y-axes of a 3D `image` with the image's median value so
that the x- and y-sizes match. Returns the padded image and the number of voxels added
in front of each axis.
"""
function pad_to_square(image::Array{T, 3}) where T
    size_x, size_y, size_z = size(image)

    # Fill value: median of the entire image, cast to the original data type
    median_value = T(floor(median(image)))

    # Defaults for an image that is already square in x and y
    padding_front = (0, 0, 0)
    padded_image = image

    if size_x < size_y
        # Pad along the x-axis, splitting the padding as evenly as possible
        pad_before = div(size_y - size_x, 2)
        padded_image = fill(median_value, size_y, size_y, size_z)
        padded_image[pad_before+1:pad_before+size_x, :, :] .= image
        padding_front = (pad_before, 0, 0)
    elseif size_y < size_x
        # Pad along the y-axis, splitting the padding as evenly as possible
        pad_before = div(size_x - size_y, 2)
        padded_image = fill(median_value, size_x, size_x, size_z)
        padded_image[:, pad_before+1:pad_before+size_y, :] .= image
        padding_front = (0, pad_before, 0)
    end

    return padded_image, padding_front
end


In [None]:
for i in [23, 29]
    img = MHDIO.read_img(MHD(joinpath(rootpath, input_raw_dir, "$(i)_img.mhd")))
    label = NRRDIO.read_img(NRRD(joinpath(rootpath, input_label_dir, "$(i)_label.nrrd")))

    img, pad1 = pad_to_square(img)
    label, pad2 = pad_to_square(label)

    println("$(i): $(pad1)")
    println("$(i): $(pad2)")


    MHDIO.write_raw(joinpath(output_img_dir, "$(i)_img.raw"), img)
    MHDIO.write_MHD_spec(joinpath(output_img_dir, "$(i)_img.mhd"), 0.54, 0.54, size(img)..., "$(i)_img.raw")
    NRRDIO.write_nrrd(joinpath(output_label_dir, "$(i)_label.nrrd"), label, (0.54, 0.54, 0.54))
end