# 1. Using `nets.py`

`nets` is a library of architectures from the 3D U-Net family. You can use the library without modifying `nets.py`, but adding further network improvements requires editing that file.

## Choosing models 

The table below lists the U-Net architectures available in `nets.py`:

| Model | Function | Description |
|-------|----------|------------|
|U-Net|`unet`| 3D U-Net architecture with kernel sizes of 3x3x3 |
|U-Net 2D|`unet2d`| 3D U-Net architecture with kernel sizes of 3x3x1 |
|U-Net++|`unetpp`| 3D U-Net++ architecture with kernel sizes of 3x3x3 |
|U-Net w/ scSE|`scSEunet`| 3D U-Net architecture with kernel sizes of 3x3x3, and <br>Spatial and Channel-wise Squeeze and Excitation (scSE)<br>[[View the Squeeze-and-Excitation paper](https://arxiv.org/abs/1709.01507)]|
|U-Net 2D w/ scSE|`scSEunet2d`|3D U-Net architecture with kernel sizes of 3x3x1, and <br>Spatial and Channel-wise Squeeze and Excitation (scSE)|
|U-Net++ w/ scSE|`scSEunetpp`|3D U-Net++ architecture with kernel sizes of 3x3x3, and <br>Spatial and Channel-wise Squeeze and Excitation (scSE)|

To use the functions listed in the table above, import them from `nets`. The following example imports U-Net++ w/ Squeeze and Excitation blocks (this model is also referred to as GlobalSegNet):

```python
from nets import scSEunetpp
```

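If you pick the architecture from a configuration file or a command-line flag, a small lookup table keeps that choice in one place. The sketch below is hypothetical: `build_model`, `MODEL_FUNCTIONS`, and `fake_nets` are not part of `nets.py`, and it assumes every function in the table accepts `(W, H, D, C)` as described in the next section. The module is passed in as an argument so the helper can be tried with a stand-in object instead of TensorFlow.

```python
import types

# Hypothetical mapping from the table's model names to function names in nets.py
MODEL_FUNCTIONS = {
    "U-Net": "unet",
    "U-Net 2D": "unet2d",
    "U-Net++": "unetpp",
    "U-Net w/ scSE": "scSEunet",
    "U-Net 2D w/ scSE": "scSEunet2d",
    "U-Net++ w/ scSE": "scSEunetpp",
}


def build_model(name, W, H, D, C, module):
    """Look up a model constructor by its table name and call it with (W, H, D, C).

    `module` would normally be the imported `nets` module; any object that
    exposes the same function names works.
    """
    if name not in MODEL_FUNCTIONS:
        raise ValueError(f"Unknown model {name!r}; choose from {sorted(MODEL_FUNCTIONS)}")
    return getattr(module, MODEL_FUNCTIONS[name])(W, H, D, C)


# Stand-in for `nets` (assumption, for illustration only): each "model" echoes its name and arguments
fake_nets = types.SimpleNamespace(
    **{fn: (lambda fn: lambda *args: (fn, args))(fn) for fn in MODEL_FUNCTIONS.values()}
)
print(build_model("U-Net++ w/ scSE", 128, 128, 64, 1, module=fake_nets))
# ('scSEunetpp', (128, 128, 64, 1))
```

With the real library this would become `import nets` followed by `build_model("U-Net++ w/ scSE", 128, 128, 64, 1, module=nets)`.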


## Initializing models

As can be seen in `train.py`, all models take the parameters `(W, H, D, C)`: width, height, depth, and number of input channels, respectively. To continue the U-Net++ example, the following code initializes the network:


```python
model = scSEunetpp(128, 128, 64, 1)
```

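U-Net style networks halve the spatial size at each pooling step and later concatenate encoder and decoder features of matching size, so each pooled dimension generally has to be divisible by `2**k` for `k` pooling steps. This section does not state how many pooling steps the networks in `nets.py` use, so the sketch below assumes 4, and lets you mark depth as unpooled (which may apply to the 2D variants with 3x3x1 kernels). Both settings are assumptions to check against `nets.py`; `check_input_dims` is a hypothetical helper.

```python
def check_input_dims(W, H, D, num_pools=4, pool_depth=True):
    """Return a list of problems with (W, H, D) for a U-Net with `num_pools` 2x poolings.

    num_pools and pool_depth are assumptions; read the actual values from nets.py.
    """
    factor = 2 ** num_pools
    dims = {"W": W, "H": H}
    if pool_depth:
        dims["D"] = D
    return [
        f"{name}={size} is not divisible by {factor}"
        for name, size in dims.items()
        if size % factor != 0
    ]


print(check_input_dims(128, 128, 64))                    # []
print(check_input_dims(128, 128, 40))                    # ['D=40 is not divisible by 16']
print(check_input_dims(128, 128, 40, pool_depth=False))  # []
```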


## Preparing data

All networks in `nets.py` expect an input of shape `(N, W, H, D, C_in)` and produce an output of shape `(N, W, H, D, C_out)`, where:
- `N`: dataset size
- `W`: input/output image width
- `H`: input/output image height
- `D`: input/output image depth
- `C_in`: input number of channels
- `C_out`: output number of channels
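Real scans are usually stored one volume at a time as `(W, H, D)` arrays without a channel axis. A minimal sketch, assuming each volume already has the size the model was built for and a single input channel, of stacking such volumes into the `(N, W, H, D, C_in)` layout above (`to_batch` is a hypothetical helper):

```python
import numpy as np


def to_batch(volumes):
    """Stack a list of (W, H, D) arrays into an (N, W, H, D, 1) float32 array."""
    batch = np.stack(volumes, axis=0)                  # (N, W, H, D)
    return batch[..., np.newaxis].astype(np.float32)   # add the channel axis: (N, W, H, D, 1)


vols = [np.zeros((128, 128, 64)) for _ in range(3)]
print(to_batch(vols).shape)  # (3, 128, 128, 64, 1)
```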

For the sake of this tutorial, we use `numpy` to create random tensors of these shapes, with `C_in = 1` and `C_out = 6`.


```python
import numpy as np

# Training data
X = np.random.rand(10, 128, 128, 64, 1)
Y = np.random.rand(10, 128, 128, 64, 6)

# Validation data
Xv = np.random.rand(10, 128, 128, 64, 1)
Yv = np.random.rand(10, 128, 128, 64, 6)
```
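Random values in `[0, 1)` are enough to exercise the code, but segmentation targets are usually one-hot encodings of a per-voxel label map. Assuming the 6 output channels correspond to 6 classes (this tutorial does not say what they represent), a hypothetical integer label map can be converted as follows. Small shapes are used here only to keep the example light.

```python
import numpy as np

num_classes = 6  # assumption: matches the last dimension of Y above

# Hypothetical integer label map, one class index per voxel: (N, W, H, D)
labels = np.random.randint(0, num_classes, size=(2, 16, 16, 8))

# Indexing the identity matrix turns each label into a one-hot vector: (N, W, H, D, num_classes)
Y_onehot = np.eye(num_classes, dtype=np.float32)[labels]

print(Y_onehot.shape)                          # (2, 16, 16, 8, 6)
print(np.allclose(Y_onehot.sum(axis=-1), 1))   # True: exactly one class per voxel
```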


### U-Net++ with multiple outputs
U-Net++ is deeply supervised: it produces an output at each semantic level ([Refer to this paper](https://arxiv.org/abs/1807.10165)), so the model expects a target for every output. The following code builds dictionaries that map each output name to the same target (`Y` for training, `Yv` for validation) so they can be used to fit the model:

```python
# One target per deeply supervised output, keyed by output layer name
Y  = {'out_{}'.format(o): Y  for o in range(len(model.outputs))}
Yv = {'out_{}'.format(o): Yv for o in range(len(model.outputs))}
```
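Keras matches the keys of a target dictionary against the names of the model's output layers, so the `'out_{}'` keys above only work if `nets.py` names its outputs `out_0`, `out_1`, and so on. The hypothetical helper below compares the two before calling `fit`; in `tf.keras` the real names are available as `model.output_names`. The four output names used here are a stand-in, not the actual output count of the model.

```python
def check_target_keys(output_names, targets):
    """Raise KeyError if the target dict keys do not exactly match the model's output names."""
    missing = set(output_names) - set(targets)
    extra = set(targets) - set(output_names)
    if missing or extra:
        raise KeyError(f"missing targets for {sorted(missing)}; unexpected keys {sorted(extra)}")


# Stand-in for model.output_names (hypothetical: 4 deeply supervised outputs)
output_names = ["out_0", "out_1", "out_2", "out_3"]
targets = {"out_{}".format(o): None for o in range(len(output_names))}
check_target_keys(output_names, targets)  # passes silently when the keys match
```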

## Fitting a model
To fit the model, we use the `model.fit` method. The most basic call is shown below. For demonstration, training runs for only one epoch (`epochs=1`); in a real scenario, use a considerably larger number of epochs:

```python
model.fit(X, Y, batch_size=1, validation_data=(Xv, Yv), epochs=1)
```

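Memory is one reason to keep `batch_size=1` for 3D volumes. As a rough worked example that counts only the input and target arrays (not the much larger intermediate activations inside the network), the sizes at this tutorial's shapes are computed below; `array_bytes` is a hypothetical helper.

```python
def array_bytes(shape, itemsize):
    """Number of bytes needed for a dense array of `shape` with `itemsize`-byte elements."""
    n = 1
    for dim in shape:
        n *= dim
    return n * itemsize


MiB = 1024 ** 2
# One float32 input sample and one 6-channel float32 target sample at 128 x 128 x 64
print(array_bytes((1, 128, 128, 64, 1), 4) / MiB)   # 4.0
print(array_bytes((1, 128, 128, 64, 6), 4) / MiB)   # 24.0
# np.random.rand returns float64 (8 bytes), so the full training array X occupies:
print(array_bytes((10, 128, 128, 64, 1), 8) / MiB)  # 80.0
```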

## Checkpoints during training (saving the model)
To save the model, we use a Keras callback that, during training, saves the model with the lowest validation loss. The callback is defined as follows:

```python
import tensorflow as tf

# Save the model to an HDF5 file whenever val_loss reaches a new minimum
checkpointer = tf.keras.callbacks.ModelCheckpoint('model.h5', monitor='val_loss', save_best_only=True)
```
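With `save_best_only=True`, `ModelCheckpoint` overwrites the file only when the monitored quantity (here `val_loss`, where lower is better) improves on the best value seen so far. The sketch below simulates that decision rule in plain Python for a made-up sequence of validation losses; it is an illustration of the behaviour, not Keras's actual implementation.

```python
def checkpoint_epochs(val_losses):
    """Return the 0-indexed epochs at which a save-best-only checkpoint on val_loss would save."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # lower val_loss is better
            best = loss
            saved.append(epoch)
    return saved


print(checkpoint_epochs([1.0, 0.8, 0.9, 0.85, 0.7, 0.75]))  # [0, 1, 4]
```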

Additionally, you can add an **early stopping** condition that stops training if the validation loss does not improve for `n` consecutive epochs:


```python
earlystopper = tf.keras.callbacks.EarlyStopping(patience=20, monitor='val_loss')
```

Here `n` is set to 20 through `patience=20`.
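The patience rule can be sketched in plain Python: a counter tracks epochs since the last improvement in `val_loss`, resets on any improvement, and training stops once it reaches `patience`. This is a simplified simulation that assumes no minimum improvement threshold; it is not Keras's own code, and the loss values are made up.

```python
def early_stop_epoch(val_losses, patience):
    """Return the 0-indexed epoch after which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.85, 0.7, 0.75, 0.76, 0.77]
print(early_stop_epoch(losses, patience=3))   # 7
print(early_stop_epoch(losses, patience=20))  # None
```

The second call also shows that with `patience=20`, early stopping can only trigger if training runs for more than 20 epochs.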

## Training with `ModelCheckpoint` and `EarlyStopping`

```python
# epochs defaults to 1 in Keras; set it higher in practice so EarlyStopping can take effect
model.fit(X, Y, batch_size=1, validation_data=(Xv, Yv), callbacks=[checkpointer, earlystopper])
```

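`model.fit` returns a `History` object whose `history` attribute is a dictionary mapping each logged metric name (for example `'loss'` and `'val_loss'`) to a list of per-epoch values. The sketch below finds the epoch whose weights a `val_loss` checkpoint would have kept; `best_epoch` is a hypothetical helper, and `history_dict` is a made-up stand-in for `history.history`.

```python
def best_epoch(history_dict, monitor="val_loss"):
    """Return (epoch, value) for the lowest `monitor` value; epochs are 0-indexed."""
    values = history_dict[monitor]
    epoch = min(range(len(values)), key=values.__getitem__)
    return epoch, values[epoch]


# Hypothetical stand-in for: history = model.fit(...); history.history
history_dict = {"loss": [0.9, 0.7, 0.6], "val_loss": [0.95, 0.80, 0.82]}
print(best_epoch(history_dict))  # (1, 0.8)
```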

## Congratulations
Congratulations! You have finished the tutorial, and this should be all you need to work with this code. Please refer to `train.py` to see how everything fits together.