In [None]:
import numpy as np
import os
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
import torch
%matplotlib tk

### Specify the path to `train`, `val` crops and the type of `center` embedding which we would like to train the network for:

The train-val images, masks and center-images will be accessed from the path specified by `data_dir` and `project_name`.
<a id='center'></a>
  

In [None]:
# Embedding centre to train for, the project the crops belong to, and the
# root folder from which the train/val crops are read.
center = 'medoid'
project_name = 'bbbc010-2012'
data_dir = 'crops'

# Echo the selection so that a wrong path is spotted before training starts.
print(
    "Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(
        project_name,
        data_dir,
    )
)

In [None]:
# Validate `center` with an explicit check rather than `assert`: assert
# statements are removed when Python runs with -O / PYTHONOPTIMIZE, which
# would let an unsupported value slip through to training unnoticed.
if center not in {'medoid', 'approximate-medoid', 'centroid'}:
    # Report the offending value together with the accepted choices.
    raise ValueError(
        'Unsupported center {!r}. Please specify center as one of : '
        '{{"medoid", "approximate-medoid", "centroid"}}'.format(center)
    )
print("Spatial Embedding Location chosen as : {}".format(center))

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Ideally this should be the number of `train` image crops. For the `bbbc010-2012` dataset, we obtain ~600 crops, so we set `train_size` to 600. 
* The effective batch size is determined as a product of the attributes `train_batch_size` and `virtual_train_batch_multiplier`. For one-hot encoded instances (such as the `bbbc010-2012` dataset), **`train_batch_size` should be set equal to 1**. 
* The `normalization_factor` attribute normalizes the raw images to always be between 0 and 1. For 8-bit images, please set `normalization_factor = 255`, for 16-bit images, please set `normalization_factor = 65535`.
* The `one_hot` attribute should be set to True if the instance image is present in an one-hot encoded style (i.e. object instance is encoded as 1 in its own individual image slice) and False if the instance image is the same dimensions as the raw-image. 

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [None]:
# Number of image-mask pairs the network sees in one training epoch.
train_size = 600

# Effective batch size = train_batch_size * virtual_train_batch_multiplier.
train_batch_size, virtual_train_batch_multiplier = 1, 1

# Divisor that brings raw intensities into [0, 1] (255 for 8-bit images,
# 65535 for 16-bit images).
normalization_factor = 65535

# True when each instance is stored in its own slice of the mask image.
one_hot = True

### Create the `train_dataset_dict` dictionary  

In [None]:
# Bundle the training-set parameters chosen above into a single dictionary.
train_dataset_dict = create_dataset_dict(
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=train_size,
    batch_size=train_batch_size,
    virtual_batch_multiplier=virtual_train_batch_multiplier,
    normalization_factor=normalization_factor,
    one_hot=one_hot,
    type='train',
)

### Specify validation dataset-related parameters

Some hints:
* The size attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Here, it is recommended to set `val_size` equal to an integral multiple of total number of validation image crops. For example, for the `bbbc010-2012` dataset, we notice ~100 validation crops, but we set `val_size = 800` to obtain the optimal results across the 8-fold augmentation of these 100 crops. It would also be fine to set `val_size = 100`.
* The effective batch size is determined as a product of the attributes `val_batch_size` and `virtual_val_batch_multiplier`. For one-hot encoded instances (such as the `bbbc010-2012` dataset), **`val_batch_size` should be set equal to 1**. 
* The `normalization_factor` attribute normalizes the raw images to always be between 0 and 1. For 8-bit images, please set `normalization_factor = 255`, for 16-bit images, please set `normalization_factor = 65535`. Please note that the `normalization_factor` should be the same for both train and validation images.
* The `one_hot` attribute should be set to True if the instance image is present in an one-hot encoded style (i.e. object instance is encoded as 1 in its own individual image slice) and False if the instance image is the same dimensions as the raw-image. 

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!



In [None]:
# Number of image-mask pairs the network sees in one validation pass.
val_size = 800

# Effective batch size = val_batch_size * virtual_val_batch_multiplier.
val_batch_size, virtual_val_batch_multiplier = 1, 1

# Should match the value used for the training images.
normalization_factor = 65535

# True when each instance is stored in its own slice of the mask image.
one_hot = True

### Create the `val_dataset_dict` dictionary

In [None]:
# Bundle the validation-set parameters chosen above into a single dictionary.
val_dataset_dict = create_dataset_dict(
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=val_size,
    batch_size=val_batch_size,
    virtual_batch_multiplier=virtual_val_batch_multiplier,
    normalization_factor=normalization_factor,
    one_hot=one_hot,
    type='val',
)

### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [None]:
# Number of channels in the input images; passed to `create_model_dict` below.
input_channels = 1

### Create the `model_dict` dictionary

In [None]:
# Build the model configuration for the chosen number of input channels.
model_dict = create_model_dict(input_channels=input_channels)

### Create the `loss_dict` dictionary

In [None]:
# Build the loss configuration; no arguments are passed, so whatever defaults
# `create_loss_dict` defines are used.
loss_dict = create_loss_dict()

### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long the training should proceed. In general, for good results on the `bbbc010-2012` dataset with the configurations above, you should train for longer than 100 epochs.
* The `display` attribute, if set to True, allows you to see the network predictions as the training proceeds. 
* The `display_embedding` attribute, if set to True, allows you to see some sample embedding as the training proceeds. Setting this to False leads to faster training times.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* If one wishes to **resume training** from a previous checkpoint, they could point `resume_path` attribute appropriately. For example, one could set `resume_path = './experiment/Jan-24-2021-bbbc010-2012/checkpoint.pth'` to resume training from the last checkpoint. 
* The `grid_y` and `grid_x` attributes should be set to equal or more than the dimensions of the largest evaluation image one wishes to test the trained model on. (Here, we assume that the pixel sizes in the height and width dimension are equal). If you are unsure of the evaluation image sizes at this stage, best to leave these attributes as they are.  
* The `one_hot` attribute should be set to True if the instance image is present in an one-hot encoded style (i.e. object instance is encoded as 1 in its own individual image slice) and False if the instance image is the same dimensions as the raw-image. 



In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

In [None]:
# How long to train.
n_epochs = 200

# True when each instance is stored in its own slice of the mask image.
one_hot = True

# Live visualisation of predictions; embeddings are left off because the
# accompanying notes say disabling them makes training faster.
display = True
display_embedding = False

# Checkpoint handling: start from scratch (no checkpoint to resume from) and
# write checkpoints / loss-curve details under `save_dir`.
resume_path = None
save_dir = os.path.join('experiment', project_name + '-' + 'demo')

# Grid extent; should be at least as large as the largest evaluation image.
grid_y = grid_x = 1024

### Create the  `configs` dictionary 

In [None]:
# Bundle the run-level settings chosen above into a single dictionary.
configs = create_configs(
    n_epochs=n_epochs,
    one_hot=one_hot,
    display=display,
    display_embedding=display_embedding,
    resume_path=resume_path,
    save_dir=save_dir,
    grid_y=grid_y,
    grid_x=grid_x,
)

### Begin training!

Executing the next cell would begin the training. 

If `display` attribute was set to `True` above, then you would see the network predictions at every $n^{th}$ step (equals 5, by default) on training and validation images. 

Going clockwise from top-left is 

    * the raw-image which needs to be segmented, 
    * the corresponding ground truth instance mask, 
    * the network predicted instance mask, and 
    * (if display_embedding = True) from each object instance, 5 pixels are randomly selected (indicated with `+`), their embeddings are plotted (indicated with `.`) and the predicted margin for that object is visualized as an axis-aligned ellipse centred on the ground-truth - center (indicated with `x`)  for that object


In [None]:
# Start training with the dataset, model, loss and run-level dictionaries
# created in the preceding cells. Arguments are positional, so their order
# must stay exactly as written here.
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

<div class="alert alert-block alert-warning"> 
  Common causes for errors during training, may include : <br>
    1. Not having <b>center images</b> for  <b>both</b> train and val directories  <br>
    2. <b>Mismatch</b> between type of center-images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center"> center</a></b> parameter in the second code cell in this notebook)   <br>
    3. In case of resuming training from a previous checkpoint, please ensure that the model weights are read from the correct directory, using the <b><a href="#resume"> resume_path</a></b> parameter. Additionally, please ensure that the <b>save_dir</b> parameter for saving the model weights points to a relevant directory. 
</div>