In [None]:
import numpy as np
import os
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
import torch
import json
%matplotlib tk

### Specify the path to `train`, `val` crops and the type of `center` embedding which we would like to train the network for:

The train-val images, masks and center-images will be read from the path specified by `data_dir` and `project_name`.
<a id='center'></a>

In [None]:
# Which embedding location to train for; must be 'centroid' or 'medoid'.
center = 'medoid'

# Crops are read from `<data_dir>/<project_name>/...`.
project_name = 'Mouse-Organoid-Cells-CBG'
data_dir = 'crops'

print(
    "Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(
        project_name, data_dir
    )
)

In [None]:
# Validate the chosen center type before any dataset dictionaries are built.
# An explicit check is used instead of `assert`: assert statements are removed
# when Python runs with optimizations enabled (`python -O`), which would let an
# invalid value through silently.
if center not in {'medoid', 'centroid'}:
    raise ValueError(
        'Please specify center as one of : {"medoid", "centroid"} (got '
        + repr(center)
        + ')'
    )
print("Spatial Embedding Location chosen as : {}".format(center))

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Ideally this should be the number of `train` image crops. For the `Mouse-Organoid-Cells-CBG` dataset, we obtain ~ 600 crops, hence we set `train_size` to 600. 
* The effective batch size is the product of the attributes `train_batch_size` and `virtual_train_batch_multiplier`. For example, one could set a small `train_batch_size`, say 2 (to fit in GPU memory), and a larger `virtual_train_batch_multiplier`, say 8, to get an effective batch size of 16. 
* The `normalization_factor` attribute normalizes the raw images to always be between 0 and 1. For 8-bit images, please set `normalization_factor = 255`; for 16-bit images, please set `normalization_factor = 65535`.

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [None]:
# Training-set sampling and batching parameters (see the hints above).
train_size = 600  # image-mask pairs drawn per epoch

# Effective batch size = 2 * 8 = 16.
train_batch_size, virtual_train_batch_multiplier = 2, 8

normalization_factor = 65535  # 16-bit raw images

### Create the `train_dataset_dict` dictionary  

In [None]:
# Assemble the training-dataset configuration ('train' split, '3d' crops).
train_dataset_dict = create_dataset_dict(
    type='train',
    name='3d',
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=train_size,
    batch_size=train_batch_size,
    virtual_batch_multiplier=virtual_train_batch_multiplier,
    normalization_factor=normalization_factor,
)

### Specify validation dataset-related parameters

Some hints:
* The `val_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Here, it is recommended to set `val_size` equal to the total number of validation image crops. For example, for the `Mouse-Organoid-Cells-CBG` dataset, we notice ~120 validation crops, hence we set `val_size = 120`.
* The effective batch size is the product of the attributes `val_batch_size` and `virtual_val_batch_multiplier`. It is often fine to use a larger effective batch size for the validation dataset than for the train dataset, since evaluating on validation data consumes less GPU memory.
* The `normalization_factor` attribute normalizes the raw images to always be between 0 and 1. For 8-bit images, please set `normalization_factor = 255`; for 16-bit images, please set `normalization_factor = 65535`. Please note that the `normalization_factor` should be the same for both train and validation images.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!



In [None]:
# Validation-set sampling and batching parameters (see the hints above).
val_size = 120  # ~ number of validation crops
val_batch_size = 16
virtual_val_batch_multiplier = 1
# `normalization_factor` is intentionally NOT redefined here. Validation images
# must be normalized exactly like the training images, so the value set in the
# training-parameters cell is reused by the `val_dataset_dict` cell. A second,
# independent definition could silently drift out of sync with it.

### Create the `val_dataset_dict` dictionary

In [None]:
# Assemble the validation-dataset configuration ('val' split, '3d' crops).
# It reuses the same normalization factor as the training set.
val_dataset_dict = create_dataset_dict(
    type='val',
    name='3d',
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=val_size,
    batch_size=val_batch_size,
    virtual_batch_multiplier=virtual_val_batch_multiplier,
    normalization_factor=normalization_factor,
)

### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 
* Set the `num_classes = [6, 1]` for `3d` training and `num_classes = [4, 1]` for `2d` training
<br>(here, 6 implies the offsets and bandwidths in x, y and z dimensions and 1 implies the `seediness` value per pixel)
* Set `name = 'branched_erfnet_3d'` to use a network employing 3d convolutions, and `name = 'branched_erfnet'` to use a network employing 2d convolutions.

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [None]:
# Model hyper-parameters (see the hints above):
#   input_channels -> single-channel raw images
#   num_classes    -> [6, 1] for 3d: offsets + bandwidths in x, y, z, plus seediness
#   name           -> network variant built from 3d convolutions
input_channels, num_classes, name = 1, [6, 1], 'branched_erfnet_3d'

### Create the `model_dict` dictionary

In [None]:
# Assemble the model configuration from the parameters above.
model_dict = create_model_dict(
    name=name,
    num_classes=num_classes,
    input_channels=input_channels,
)

### Create the `loss_dict` dictionary

In [None]:
# Fallback used when no data-properties file (or no stored weight) is available.
DEFAULT_FOREGROUND_WEIGHT = 10.0


def load_foreground_weight(path='data_properties.json', default=DEFAULT_FOREGROUND_WEIGHT):
    """Return the foreground weight stored in a data-properties JSON file.

    Parameters
    ----------
    path : str
        Path to a JSON file whose top-level object may contain a
        ``'foreground_weight'`` key (presumably written by the data-preparation
        notebook -- confirm).
    default : float
        Value returned when ``path`` is not an existing file or when the key
        is absent.

    Returns
    -------
    The stored ``'foreground_weight'`` value, or ``default``.
    """
    if not os.path.isfile(path):
        return default
    # JSON is UTF-8 by specification; be explicit rather than platform-dependent.
    with open(path, encoding='utf-8') as json_file:
        data_properties = json.load(json_file)
    # A file without the key falls back to the same default as a missing file,
    # instead of raising KeyError.
    return data_properties.get('foreground_weight', default)


foreground_weight = load_foreground_weight()

In [None]:
# Loss configuration. n_sigma=3 presumably means one bandwidth per spatial axis
# (x, y, z) for 3d training -- confirm against create_loss_dict.
loss_dict = create_loss_dict(
    foreground_weight=foreground_weight,
    n_sigma=3,
)

### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long the training should proceed. In general, for reasonable results, you should train for longer than 25 epochs.
* The `display` attribute, if set to True, allows you to see the network predictions as the training proceeds. 
* The `display_embedding` attribute, if set to True, allows you to see some sample embedding as the training proceeds. Setting this to False leads to faster training times.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* If one wishes to **resume training** from a previous checkpoint, they could point `resume_path` attribute appropriately. For example, one could set `resume_path = './experiment/Mouse-Organoid-Cells-CBG-demo/checkpoint.pth'` to resume training from the last checkpoint. 
* The `anisotropy_factor` attribute should be set equal to the ratio of the voxel size in z to the voxel size in x (or y). Here, we assume equal resolution of the image in x and y, but the code can accommodate unequal resolutions in the x and y dimensions as well.



In [None]:
# Training-loop configuration (see the hints above).
n_epochs = 25

# Live visualization is off by default; enabling the embedding plot slows training.
display, display_embedding = False, False

# Checkpoints and loss curves are written under experiment/<project_name>-demo.
save_dir = os.path.join('experiment', project_name + '-' + 'demo')

# Set to a checkpoint path to resume a previous run.
resume_path = None

# Ratio of the voxel size in z to the voxel size in x (or y).
anisotropy_factor = 5.88

* The `n_z`, `n_y` and `n_x` attributes essentially allow the possibility of training on a low resolution, downsampled image but later evaluating on the full image. Check out <b>[this](https://github.com/juglab/EmbedSeg/wiki/Use-Labkit-to-prepare-instance-masks)</b> page on how to set these values.

In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

In [None]:
# Spatial sizes (z, y, x) used for training; see the linked wiki page above.
n_z, n_y, n_x = 72, 408, 408

### Create the  `configs` dictionary 

In [None]:
# Bundle the training-loop, output and geometry settings into one config dict.
configs = create_configs(
    save_dir=save_dir,
    resume_path=resume_path,
    n_epochs=n_epochs,
    display=display,
    display_embedding=display_embedding,
    anisotropy_factor=anisotropy_factor,
    n_z=n_z,
    n_y=n_y,
    n_x=n_x,
)

### Begin training!

Executing the next cell would begin the training. 

If `display` attribute was set to `True` above, then you would see the network predictions at every $n^{th}$ step (equals 5, by default) on training and validation images. 

Going clockwise from the top-left, the panels show:

* the raw image which needs to be segmented,
* the corresponding ground-truth instance mask,
* the network-predicted instance mask, and
* (if `display_embedding = True`) for each object instance, 5 randomly selected pixels (indicated with `+`), their embeddings (indicated with `.`), and the predicted margin for that object, visualized as an axis-aligned ellipse centred on the ground-truth center (indicated with `x`) of that object.


In [None]:
# Start training with the dataset, model, loss and run configurations assembled
# in the cells above (checkpoints go to the `save_dir` passed into `configs`).
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

<div class="alert alert-block alert-warning"> 
  Common causes of errors during training include: <br>
    1. Not having <b>center images</b> for  <b>both</b> train and val directories  <br>
    2. <b>Mismatch</b> between the type of center-images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center"> center</a></b> parameter, set in the same cell as <code>data_dir</code> and <code>project_name</code>)   <br>
    3. In case of resuming training from a previous checkpoint, please ensure that the model weights are read from the correct directory, using the <b><a href="#resume"> resume_path</a></b> parameter. Additionally, please ensure that the <b>save_dir</b> parameter for saving the model weights points to a relevant directory. 
</div>