In [1]:
import numpy as np
import os
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
import torch
from matplotlib.colors import ListedColormap
import json
# comment the following line, if running in the headless mode
#%matplotlib tk 

### Specify the path to `train`, `val` crops and the type of `center` embedding which we would like to train the network for:

The train-val images, masks and center-images will be accessed from the path specified by `data_dir` and `project_name`.
<a id='center'></a>

In [2]:
data_dir = 'crops'
project_name = 'Platynereis-ISH-Nuclei-CBG'
center = 'medoid' # 'centroid', 'approximate-medoid', 'medoid'

print("Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(project_name, data_dir))

Project Name chosen as : Platynereis-ISH-Nuclei-CBG. 
Train-Val images-masks-center-images will be accessed from : crops


In [3]:
try:
    assert center in {'medoid', 'approximate-medoid', 'centroid'}
    print("Spatial Embedding Location chosen as : {}".format(center))
except AssertionError as e:
    e.args += ('Please specify center as one of : {"medoid", "approximate-medoid", "centroid"}', 42)
    raise

Spatial Embedding Location chosen as : medoid


### Obtain properties of the dataset 

Here, we read the `data_properties.json` file prepared previously in the `01-data` notebook.

In [4]:
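if os.path.isfile('data_properties.json'):
    with open('data_properties.json') as json_file:
        data = json.load(json_file)
    # unpack the properties saved by the `01-data` notebook
    one_hot, data_type, foreground_weight, n_y, n_x = data['one_hot'], data['data_type'], int(data['foreground_weight']), int(data['n_y']), int(data['n_x'])
else:
    # fail early rather than hitting a NameError in a later cell
    raise FileNotFoundError('data_properties.json not found; run the 01-data notebook first')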

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask pairs which the network sees in one complete epoch. Ideally, this should equal the number of `train` image crops. 
* The effective batch size is the product of the attributes `train_batch_size` and `virtual_train_batch_multiplier`. For example, one could set a small `train_batch_size`, say 2 (to fit in one's GPU memory), and a large `virtual_train_batch_multiplier`, say 8, to get an effective batch size of 16. 

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!
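
As a quick check of the arithmetic in the hints above, the following sketch computes the effective batch size and the number of iterations per epoch. The helper names `effective_batch_size` and `iterations_per_epoch` are hypothetical (they are not part of EmbedSeg), and dropping the last incomplete batch is an assumption, although it is consistent with the 132 iterations per epoch reported in the training log further down for 2127 crops and a batch size of 16.

```python
def effective_batch_size(batch_size, virtual_batch_multiplier):
    """Number of samples contributing to one optimizer update."""
    return batch_size * virtual_batch_multiplier


def iterations_per_epoch(size, batch_size, drop_last=True):
    """Number of batches drawn in one epoch.

    drop_last=True (an assumption) discards the final incomplete batch.
    """
    if drop_last:
        return size // batch_size
    return -(-size // batch_size)  # ceiling division


# Example from the hint above: small batch, large multiplier
print(effective_batch_size(2, 8))        # 16
# Values used in this notebook
print(effective_batch_size(16, 1))       # 16
print(iterations_per_epoch(2127, 16))    # 132
```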

In [5]:
train_size = len(os.listdir(os.path.join(data_dir, project_name, 'train', 'images')))
train_batch_size = 16 
virtual_train_batch_multiplier = 1 

### Create the `train_dataset_dict` dictionary  

In [6]:
train_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                         project_name = project_name,  
                                         center = center, 
                                         size = train_size, 
                                         batch_size = train_batch_size, 
                                         virtual_batch_multiplier = virtual_train_batch_multiplier, 
                                         one_hot = one_hot,
                                         type = 'train')

`train_dataset_dict` dictionary successfully created with: 
 -- train images accessed from crops/Platynereis-ISH-Nuclei-CBG/train/images, 
 -- number of images per epoch equal to 2127, 
 -- batch size set at 16, 
 -- virtual batch multiplier set as 1, 
 -- one_hot set as False, 


### Specify validation dataset-related parameters

Some hints:
* The `val_size` attribute indicates the number of image-mask pairs which the network sees in one complete validation epoch. It is recommended to set `val_size` equal to the total number of validation image crops. For example, for the `dsb-2018` dataset, we notice ~2600 validation crops.
* The effective batch size is the product of the attributes `val_batch_size` and `virtual_val_batch_multiplier`. It is often fine to set a higher effective batch size for the validation dataset than for the train dataset, since evaluating on validation data consumes less GPU memory.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!



In [7]:
val_size = len(os.listdir(os.path.join(data_dir, project_name, 'val', 'images')))
val_batch_size = 16
virtual_val_batch_multiplier = 1

### Create the `val_dataset_dict` dictionary

In [8]:
val_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                       project_name = project_name, 
                                       center = center, 
                                       size = val_size, 
                                       batch_size = val_batch_size, 
                                       virtual_batch_multiplier = virtual_val_batch_multiplier,
                                       one_hot = one_hot,
                                       type ='val',)

`val_dataset_dict` dictionary successfully created with: 
 -- val images accessed from crops/Platynereis-ISH-Nuclei-CBG/val/images, 
 -- number of images per epoch equal to 375, 
 -- batch size set at 16, 
 -- virtual batch multiplier set as 1, 
 -- one_hot set as False, 


### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [9]:
input_channels = 1

### Create the `model_dict` dictionary

In [10]:
model_dict = create_model_dict(input_channels = input_channels)

`model_dict` dictionary successfully created with: 
 -- num of classes equal to 1, 
 -- input channels equal to [4, 1], 
 -- name equal to branched_erfnet


### Create the `loss_dict` dictionary

In [11]:
loss_dict = create_loss_dict()

`loss_dict` dictionary successfully created with: 
 -- foreground weight equal to 10.000, 
 -- w_inst equal to 1, 
 -- w_var equal to 10, 
 -- w_seed equal to 1


### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long the training should proceed. In general, for reasonable results, you should train for more than 50 epochs.
* The `display` attribute, if set to True, allows you to see the network predictions as the training proceeds. 
* The `display_embedding` attribute, if set to True, allows you to see some sample embedding as the training proceeds. Setting this to False leads to faster training times.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* To **resume training** from a previous checkpoint, point the `resume_path` attribute to that checkpoint. For example, one could set `resume_path = './experiment/dsb-2018-demo/checkpoint.pth'` to resume training from the last checkpoint. 

In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

In [12]:
n_epochs = 200
display = False
display_embedding = False
save_dir = os.path.join('experiment', project_name+'-'+'demo')
resume_path  = None
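
Before resuming, it can help to confirm that the checkpoint file actually exists. The following is a minimal sketch, assuming the checkpoint is stored as `checkpoint.pth` directly inside `save_dir` (as in the example path in the hints above). `find_checkpoint` is a hypothetical helper, not an EmbedSeg function.

```python
import os


def find_checkpoint(save_dir, filename='checkpoint.pth'):
    """Return the checkpoint path if it exists, otherwise None (train from scratch)."""
    candidate = os.path.join(save_dir, filename)
    return candidate if os.path.isfile(candidate) else None


# Falls back to None when no checkpoint has been written yet
resume_path = find_checkpoint(os.path.join('experiment', 'Platynereis-ISH-Nuclei-CBG-demo'))
print(resume_path)
```

Returning `None` when the file is missing keeps the behaviour identical to the default `resume_path = None` above, so a first run and a resumed run can share the same cell.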

### Create the  `configs` dictionary 

In [13]:
configs = create_configs(n_epochs = n_epochs,
                         one_hot = one_hot,
                         display = display, 
                         display_embedding = display_embedding,
                         resume_path = resume_path, 
                         save_dir = save_dir, 
                         n_y = n_y, 
                         n_x = n_x,)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 200, 
 -- display equal to False, 
 -- save_dir equal to experiment/Platynereis-ISH-Nuclei-CBG-demo, 
 -- n_z equal to None, 
 -- n_y equal to 648, 
 -- n_x equal to 648, 
 -- one_hot equal to False, 


### Choose a `color map`

Here, we load a `glasbey`-style color map. Other color maps such as `viridis` or `magma` would work equally well.

In [14]:
new_cmap = np.load('../../../cmaps/cmap_60.npy')
new_cmap = ListedColormap(new_cmap) # new_cmap = 'magma' would also work! 

### Begin training!

Executing the next cell will begin the training. 

If the `display` attribute was set to `True` above, you will see the network predictions at every $n^{th}$ step ($n$ equals 5, by default) on training and validation images. 

Going clockwise from top-left, the panels show:

* the raw image which needs to be segmented, 
* the corresponding ground truth instance mask, 
* the network predicted instance mask, and 
* (if `display_embedding = True`) from each object instance, 5 pixels are randomly selected (indicated with `+`), their embeddings are plotted (indicated with `.`), and the predicted margin for that object is visualized as an axis-aligned ellipse centred on the ground-truth center (indicated with `x`) of that object.


In [15]:
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs, color_map=new_cmap)

2-D `train` dataloader created! Accessing data from crops/Platynereis-ISH-Nuclei-CBG/train/
Number of images in `train` directory is 2127
Number of instances in `train` directory is 2127
Number of center images in `train` directory is 2127
*************************
2-D `val` dataloader created! Accessing data from crops/Platynereis-ISH-Nuclei-CBG/val/
Number of images in `val` directory is 375
Number of instances in `val` directory is 375
Number of center images in `val` directory is 375
*************************
Creating Branched Erfnet with [4, 1] outputs
Initialize last layer with size:  torch.Size([16, 4, 2, 2])
*************************


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Created spatial emb loss function with: n_sigma: 2, foreground_weight: 10
*************************
Created logger with keys:  ('train', 'val', 'iou')
Starting epoch 0
learning rate: 0.0005


100%|█████████████████████████████████████████| 132/132 [03:17<00:00,  1.50s/it]
100%|███████████████████████████████████████████| 23/23 [00:25<00:00,  1.09s/it]


===> train loss: 1.11
===> val loss: 0.96, val iou: 0.58
=> saving checkpoint


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Starting epoch 1
learning rate: 0.0004977494364660346


100%|█████████████████████████████████████████| 132/132 [06:03<00:00,  2.76s/it]
100%|███████████████████████████████████████████| 23/23 [00:35<00:00,  1.53s/it]


===> train loss: 0.88
===> val loss: 0.84, val iou: 0.63
=> saving checkpoint


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Starting epoch 2
learning rate: 0.0004954977417064171


100%|█████████████████████████████████████████| 132/132 [06:40<00:00,  3.04s/it]
100%|███████████████████████████████████████████| 23/23 [00:47<00:00,  2.06s/it]


===> train loss: 0.80
===> val loss: 0.78, val iou: 0.65
=> saving checkpoint


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Starting epoch 3
learning rate: 0.0004932449094349202


100%|█████████████████████████████████████████| 132/132 [08:29<00:00,  3.86s/it]
100%|███████████████████████████████████████████| 23/23 [00:47<00:00,  2.08s/it]


===> train loss: 0.75
===> val loss: 0.82, val iou: 0.62
=> saving checkpoint


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Starting epoch 4
learning rate: 0.0004909909332982877


100%|█████████████████████████████████████████| 132/132 [08:38<00:00,  3.93s/it]
100%|███████████████████████████████████████████| 23/23 [00:48<00:00,  2.10s/it]


===> train loss: 0.71
===> val loss: 0.72, val iou: 0.68
=> saving checkpoint


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Starting epoch 5
learning rate: 0.0004887358068751748


100%|█████████████████████████████████████████| 132/132 [09:01<00:00,  4.11s/it]
100%|███████████████████████████████████████████| 23/23 [00:50<00:00,  2.20s/it]


===> train loss: 0.70
===> val loss: 0.70, val iou: 0.68
=> saving checkpoint


  0%|                                                   | 0/132 [00:00<?, ?it/s]

Starting epoch 6
learning rate: 0.0004864795236750653


  0%|                                                   | 0/132 [00:01<?, ?it/s]


KeyboardInterrupt: 
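
The learning rates printed in the log above are consistent with a polynomial decay of the form `base_lr * (1 - epoch / total_epochs) ** 0.9`, with `base_lr = 0.0005` and a horizon of 200 epochs (which equals `n_epochs` here). This is inferred from the logged numbers only; the notebook does not state which scheduler is used, so treat the sketch below as an illustration. `poly_lr` is a hypothetical helper.

```python
def poly_lr(epoch, base_lr=0.0005, total_epochs=200, power=0.9):
    """Polynomially decayed learning rate (assumed form, inferred from the log)."""
    return base_lr * (1 - epoch / total_epochs) ** power


# Compare against the `learning rate:` lines printed for epochs 0 to 3
for epoch in range(4):
    print(epoch, poly_lr(epoch))
```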

<div class="alert alert-block alert-warning"> 
  Common causes of errors during training include: <br>
    1. Not having <b>center images</b> for <b>both</b> the train and val directories <br>
    2. A <b>mismatch</b> between the type of center images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center">center</a></b> parameter in the second code cell of this notebook) <br>
    3. When resuming training from a previous checkpoint, please ensure that the model weights are read from the correct location, using the <b><a href="#resume">resume_path</a></b> parameter. Additionally, please ensure that the <b>save_dir</b> parameter for saving the model weights points to a relevant directory. 
</div>
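
The first cause in the list above can be checked before training starts. The following is a minimal sketch, assuming each split directory contains the sub-folders `images`, `masks` and `center-<center>` (for example `center-medoid`). This layout and the helper `check_split` are assumptions, so adjust the folder names to match what `01-data.ipynb` actually wrote.

```python
import os


def check_split(data_dir, project_name, split, center):
    """Count files per sub-folder of one split and report whether the counts agree."""
    counts = {}
    for sub in ('images', 'masks', 'center-' + center):  # assumed folder names
        path = os.path.join(data_dir, project_name, split, sub)
        counts[sub] = len(os.listdir(path)) if os.path.isdir(path) else 0
    # consistent only if all folders hold the same, non-zero number of files
    consistent = len(set(counts.values())) == 1 and counts['images'] > 0
    return counts, consistent


for split in ('train', 'val'):
    counts, ok = check_split('crops', 'Platynereis-ISH-Nuclei-CBG', split, 'medoid')
    print(split, counts, 'OK' if ok else 'MISMATCH or missing folder')
```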