# Segmentation Network Training

The original Segmentation 3D-UNet was trained and validated on twelve images (nine for training, three for validation) from the non-NeuroPAL strains SWF358, SWF359, SWF360, and SWF366. This notebook takes the raw images and labels as input, then crops, resizes, and formats them to be compatible with the segmentation network.

Each input is expected to consist of an image file and a label file of the same shape and voxel size. The exact shape and voxel size do not matter, as the notebook can automatically crop and resample the images as appropriate. Each voxel of the label image should take an integer value from 0 to 3, as follows:

- **0** represents unlabeled pixels. These will be assigned a weight of 0 during training, effectively causing them to be ignored.
- **1** represents foreground pixels (aka neurons). They are assigned a high weight during training.
- **2** represents background pixels (aka pixels that do not belong to any neuron). They are assigned a low weight during training, with the weight decreasing the further away they are from neurons.
- **3** represents gap pixels. These are pixels between two adjacent neurons. In order for the Segmentation Net to learn to perform instance segmentation (splitting adjacent neurons into two separate objects), it is imperative that pixels between two adjacent neurons be labeled as background. Thus, those pixels are labeled with this special label rather than the typical background label 2, and when the network input data is created, they are assigned a very high weight.
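
As a rough illustration of the weighting scheme described above, the following sketch builds a weight array from a label array. It is not the code used by `make_unet_input_h5`: the function name `sketch_weights`, the `bkg` parameter, and the `bkg / distance` decay are assumptions made for illustration, and the scaling of gap weights by the number of adjacent neurons is omitted.

```julia
# Simplified, hypothetical weighting scheme for illustration only.
# The real weights are computed inside `make_unet_input_h5`.
function sketch_weights(label::AbstractArray{<:Integer,3}; fg=5, gap=20, bkg=1.0)
    fg_idx = findall(==(1), label)          # positions of all foreground voxels
    weight = zeros(Float64, size(label))    # label 0 keeps weight 0
    for I in CartesianIndices(label)
        l = label[I]
        if l == 1
            weight[I] = fg                  # foreground: high weight
        elseif l == 3
            weight[I] = gap                 # gap: very high weight
        elseif l == 2
            # Background: weight decays with taxicab distance to the nearest
            # foreground voxel (brute force; assumed decay law).
            d = isempty(fg_idx) ? 1 : minimum(sum(abs, Tuple(I - J)) for J in fg_idx)
            weight[I] = bkg / d
        end
    end
    return weight
end

# Tiny 4 x 1 x 1 example: foreground, gap, background, unlabeled
label = reshape([1, 3, 2, 0], 4, 1, 1)
w = sketch_weights(label)
# vec(w) == [5.0, 20.0, 0.5, 0.0]
```

The brute-force distance search is only practical for small arrays; it is used here to keep the sketch free of extra dependencies.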

To label an image, load it in ITK-SNAP from its MHD file via File -> Open Main Image, then load the labels from their NRRD file via Segmentation -> Load Segmentation. You can then edit the segmentation in ITK-SNAP and save it to an NRRD file. Alternatively, to create an entirely new segmentation, start drawing on the image using the label values described above, and then save the result as an NRRD file. Remember to draw in all three dimensions!

It is not necessary to label every pixel; in the initial twelve training and validation images, only about 1% of pixels had a nonzero label. Focus on labeling difficult pixels where the Segmentation Net is struggling, such as very dim neurons or a tight ball of neurons.
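
Before passing a newly labeled image to the rest of the notebook, it can help to confirm that it meets the expectations above. The sketch below assumes the image and label have already been loaded as arrays; `check_labels` is a hypothetical helper, not part of `SegmentationTools.jl`.

```julia
# Hypothetical helper: validates a label array against its image and
# returns the fraction of voxels that carry a nonzero label.
function check_labels(img::AbstractArray, label::AbstractArray)
    size(img) == size(label) ||
        error("image size $(size(img)) does not match label size $(size(label))")
    all(in(0:3), label) || error("label values must be integers from 0 to 3")
    return count(!iszero, label) / length(label)
end

# Synthetic example data standing in for a real image/label pair
img = rand(Float32, 10, 10, 10)
label = zeros(UInt8, 10, 10, 10)
label[5, 5, 5] = 1   # one foreground voxel
label[5, 5, 6] = 3   # one gap voxel
frac = check_labels(img, label)   # 2 labeled voxels out of 1000
```

A labeled fraction far above the roughly 1% seen in the original dataset is not an error, but a fraction of exactly zero usually means the wrong file was loaded.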

The input images had the following original properties:
- Images 1 through 14 have voxel size $0.36 \times 0.36 \times 1.0$. They were not used during training because a $z$-spacing of 1.0 is too coarse to label gaps reliably in the $z$-dimension. They are from a variety of different strains.
- Images 15 through 19 have voxel size $0.36 \times 0.36 \times 0.2$. They are from the SWF358-360 strains. Their SNR is unusually high, so it is recommended to reduce it before adding them to the training data. This notebook does that automatically via the `SN_reduction_factor` argument.
- Images 20 through 22 have voxel size $0.36 \times 0.36 \times 0.36$. They were not used during training, for reasons that are not recorded. Their strain information is unknown.
- Images 23 through 29 have voxel size $0.54 \times 0.54 \times 0.54$, the current standard in our lab. They are from the SWF366 strain.

All of the images in the directories associated with this notebook have **already been pre-processed** to voxel size $0.54 \times 0.54 \times 0.54$ for training (see the `bin_images.ipynb` notebook for more details).
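
The relationship between the original voxel sizes and the target size can be checked with a little arithmetic. The sketch below assumes that a bin scale is the target voxel size divided by the original voxel size along each axis; `target_voxel` and `bin_scale_for` are hypothetical names. For images 1-14 and 20-22 the result agrees with the `bin_scale` values that appear in the formatting cell of this notebook.

```julia
# Assumption: bin scale = target voxel size ./ original voxel size
target_voxel = [0.54, 0.54, 0.54]
bin_scale_for(original; target=target_voxel) = target ./ original

scale_1_14  = bin_scale_for([0.36, 0.36, 1.0])    # approximately [1.5, 1.5, 0.54]
scale_20_22 = bin_scale_for([0.36, 0.36, 0.36])   # approximately [1.5, 1.5, 1.5]
scale_23_29 = bin_scale_for([0.54, 0.54, 0.54])   # [1.0, 1.0, 1.0], no rescaling
```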

Images 15, 16, 17, 18, 23, 25, 26, 27, and 29 were used for training, and images 19, 24, and 28 were used for validation.
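
If you change the split, a quick sanity check along these lines may help catch an image that is accidentally placed in both sets or left out. The variable names `train_ids`, `val_ids`, and `all_ids` are hypothetical.

```julia
train_ids = [15, 16, 17, 18, 23, 25, 26, 27, 29]
val_ids   = [19, 24, 28]
all_ids   = [15, 16, 17, 18, 19, 23, 24, 25, 26, 27, 28, 29]

# No image may appear in both sets
split_is_disjoint = isempty(intersect(train_ids, val_ids))
# Every processed image must be assigned to exactly one set
split_is_complete = sort(vcat(train_ids, val_ids)) == sort(all_ids)
```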

Eric later added images from four NeuroPAL animals (three for training and one for validation) to create the second version of the Segmentation Net. Their labels were created by manual curation of the output of the first version of the Segmentation Net, but they were not passed through this notebook and had all pixels weighted identically.

The original segmentation network crop size was $210\times96\times51$. This was the largest size that would fit on the 8 GB GPUs available at the time, and some images had an $x$-dimension of only 210. These images have since been padded, so larger crop sizes are now possible.

For more information, see https://github.com/flavell-lab/SegmentationTools.jl and https://github.com/flavell-lab/pytorch-3dunet

In [None]:
using ImageDataIO, SegmentationTools, ProgressMeter, FlavellBase, MHDIO, WormFeatureDetector, HDF5, FileIO, NRRDIO, Statistics, StatsBase, Distributions
using Plots, PyPlot

## Set training and validation dataset paths

Here, you can set which images to process and the relevant paths.

The data is available in [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&e=2&st=ybsvv0ry&dl=0) under `SegmentationNet`. Set `rootpath` to the location of this data on your local machine. Please copy the data from the Dropbox before running this notebook to avoid modifying the contents of the Dropbox.

This notebook only deals with the non-NeuroPAL images. The NeuroPAL ones were processed separately and their H5 files are available in `SegmentationNet/hdf5_train` and `SegmentationNet/hdf5_val` with non-integer filenames.

In [None]:
imgs = [15,16,17,18,19,23,24,25,26,27,28,29] # which datasets to use during training/validation
rootpath="/store1/adam/test" # root path to all input and output data.
label_dir = "label_binned_uncropped" # subpath containing labels
raw_dir = "img_binned_uncropped" # subpath containing images
hdf5_dir = "hdf5"; # subpath where the `pytorch-3dunet`-compatible H5 files will be written
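
Before running the formatting step, it may be worth confirming that every expected input file is present. The sketch below uses the same file naming as the formatting cell (`<i>_img.mhd` and `<i>_label.nrrd`); `missing_inputs` is a hypothetical helper, not part of any of the packages loaded above.

```julia
# Hypothetical helper: returns the expected input paths that do not exist.
function missing_inputs(rootpath, raw_dir, label_dir, imgs)
    missing_files = String[]
    for i in imgs
        expected = (joinpath(rootpath, raw_dir, "$(i)_img.mhd"),
                    joinpath(rootpath, label_dir, "$(i)_label.nrrd"))
        for p in expected
            isfile(p) || push!(missing_files, p)
        end
    end
    return missing_files
end

# Usage with the variables defined in the cell above:
# missing_inputs(rootpath, raw_dir, label_dir, imgs)
```

An empty result means all image and label files were found.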

## Set parameters

In [None]:
fg=5 # weight of foreground (neuron) labels
gap=20 # base weight of background-gap labels
scale_bkg_gap=true; # whether to increase the background-gap weight based on the number of adjacent neurons

You may want to update these cropping parameters to your desired crop size. Each entry maps an image number to the index ranges to keep along each of the three image axes.

In [None]:
crop_dict = Dict(
    15 => [67:279, 139:234, 21:71],
    16 => [33:245, 71:166, 21:71],
    17 => [19:231, 147:242, 1:51],
    18 => [141:236, 125:337, 21:71],
    19 => [117:212, 83:295, 11:61],
    20 => [195:338, 135:454, 1:77],
    21 => [175:318, 135:454, 1:77],
    22 => [185:328, 135:454, 1:77],
    23 => [100:195, 57:266, 32:82],
    24 => [61:270, 70:165, 21:71],
    25 => [61:270, 60:155, 34:84],
    26 => [51:260, 40:135, 28:78],
    27 => [51:260, 40:135, 36:86],
    28 => [51:260, 55:150, 40:90],
    29 => [105:200, 57:266, 40:90]
); # crop size parameters for all binned datasets
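
To see what shape a given crop produces, the ranges can be measured directly. The sketch below assumes that the `transpose` option used in the formatting step swaps the first two axes after cropping; that behavior, and the helper names `crop_shape` and `final_shape`, are assumptions rather than documented API.

```julia
# Number of voxels kept along each axis
crop_shape(crop) = Tuple(length.(crop))

# Shape after the optional transpose (assumed to swap the first two axes)
function final_shape(crop, transpose::Bool)
    s = crop_shape(crop)
    return transpose ? (s[2], s[1], s[3]) : s
end

s15 = final_shape([67:279, 139:234, 21:71], false)   # image 15: (213, 96, 51)
s23 = final_shape([100:195, 57:266, 32:82], true)    # image 23: (210, 96, 51)
```

With the values above, the long axis comes out as 213 for image 15 and 210 for image 23, so it is worth checking each entry if the network requires a single fixed input size.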

## Format data

This code crops the images and labels, and generates `pytorch-3dunet`-compatible training and validation datasets with image, label, and weight data.

In [None]:
@showprogress for i in imgs
    crop = crop_dict[i]

    # Defaults: image is already at the target voxel size and needs no SNR reduction
    bin_scale = [1, 1, 1]
    reduction_factor = 1
    if i in 15:19
        # the binning already happened, but these images have unusually high SNR
        reduction_factor = 2
    elseif i in 20:22
        # original voxel size 0.36 x 0.36 x 0.36
        bin_scale = [1.5, 1.5, 1.5]
    elseif i < 15
        # original voxel size 0.36 x 0.36 x 1.0
        bin_scale = [1.5, 1.5, 0.54]
    end

    # whether this image is transposed when the network input is created
    transpose = (i in [10,11,12,13,14,18,19,20,21,22,23,29])

    # load the raw image (MHD) and its labels (NRRD)
    img = MHDIO.read_img(MHD(joinpath(rootpath, raw_dir, "$(i)_img.mhd")))
    label = NRRDIO.read_img(NRRD(joinpath(rootpath, label_dir, "$(i)_label.nrrd")))

    # output directory for the H5 files
    h5_dir = joinpath(rootpath, hdf5_dir)
    create_dir(h5_dir)

    # write image, label, and weight data to a `pytorch-3dunet`-compatible H5 file
    make_unet_input_h5(img, label, joinpath(h5_dir, string(i, pad=2)*".h5"), scale_xy=0.36, scale_z=1,
        scale_bkg_gap=scale_bkg_gap, crop=crop, transpose=transpose, weight_foreground=fg, weight_bkg_gap=gap, bin_scale=bin_scale,
        SN_reduction_factor=reduction_factor
    )
end

### Copy files into training and validation directories

In [None]:
let
    train_dir = joinpath(rootpath, "hdf5_train")
    create_dir(train_dir)
    for i in [15,16,17,18,23,25,26,27,29] # training datasets
        fname = string(i, pad=2) * ".h5" # same file naming as in the formatting cell
        # force=true overwrites an existing copy so that this cell can be re-run
        cp(joinpath(rootpath, hdf5_dir, fname), joinpath(train_dir, fname), force=true)
    end
    val_dir = joinpath(rootpath, "hdf5_val")
    create_dir(val_dir)
    for i in [19,24,28] # validation datasets
        fname = string(i, pad=2) * ".h5"
        cp(joinpath(rootpath, hdf5_dir, fname), joinpath(val_dir, fname), force=true)
    end
end