In [1]:
import numpy as np
import os
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
import torch
from matplotlib.colors import ListedColormap
import json
#%matplotlib tk  # keep this line commented out when running in headless mode

### Specify the path to the `train` and `val` crops and the type of `center` embedding to train the network for

The train and val images, masks and center images are read from the directory given by `data_dir` and `project_name`.
<a id='center'></a>

In [2]:
data_dir = 'crops'
project_name = 'Mouse-Skull-Nuclei-CBG'
center = 'medoid' # 'centroid', 'medoid'

print("Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(project_name, data_dir))

Project Name chosen as : Mouse-Skull-Nuclei-CBG. 
Train-Val images-masks-center-images will be accessed from : crops
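Before going further, it can help to confirm that the expected crop directories exist. The helper below is a hypothetical sketch (`expected_crop_dirs` and `missing_dirs` are not EmbedSeg functions). The `images` subfolder name matches the log output further down; the `masks` and `center-<center>` subfolder names are assumptions about the layout written by the `01-data` notebook.

```python
import os


def expected_crop_dirs(data_dir, project_name, center):
    """Hypothetical helper: map (split, subfolder) to the directory this notebook expects.

    Assumes the layout data_dir/project_name/<split>/{images, masks, center-<center>}.
    """
    dirs = {}
    for split in ('train', 'val'):
        root = os.path.join(data_dir, project_name, split)
        for sub in ('images', 'masks', 'center-' + center):
            dirs[(split, sub)] = os.path.join(root, sub)
    return dirs


def missing_dirs(dirs):
    """Return the expected directories that do not exist on disk, sorted."""
    return sorted(path for path in dirs.values() if not os.path.isdir(path))


if __name__ == '__main__':
    dirs = expected_crop_dirs('crops', 'Mouse-Skull-Nuclei-CBG', 'medoid')
    for path in missing_dirs(dirs):
        print('missing:', path)
```

A missing `center-medoid` (or `center-centroid`) folder usually means the center images were generated for the other center type.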


In [3]:
# Only 'medoid' and 'centroid' center embeddings are supported
if center not in {'medoid', 'centroid'}:
    raise ValueError('Please specify center as one of : {"medoid", "centroid"}')
print("Spatial Embedding Location chosen as : {}".format(center))

Spatial Embedding Location chosen as : medoid
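To make the two choices concrete: a centroid is the mean coordinate of an object's pixels and can fall outside a non-convex object, whereas a medoid is, under one common definition, the object pixel with the smallest summed distance to all other object pixels, so it always lies on the object. The sketch below uses plain NumPy on a 2-D mask for brevity; it illustrates the idea and is not EmbedSeg's own implementation.

```python
import numpy as np


def centroid(mask):
    """Mean (row, col) coordinate of the foreground pixels; may lie off the object."""
    return np.argwhere(mask).mean(axis=0)


def medoid(mask):
    """Foreground pixel minimising the summed Euclidean distance to all other foreground pixels."""
    coords = np.argwhere(mask).astype(float)
    # Pairwise distance matrix of shape (n_pixels, n_pixels); fine for small objects
    diffs = coords[:, None, :] - coords[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    return coords[dists.sum(axis=1).argmin()]


if __name__ == '__main__':
    # A U-shaped object: its centroid falls in the empty gap, its medoid does not
    mask = np.zeros((5, 5), dtype=bool)
    mask[0:5, 0] = True
    mask[0:5, 4] = True
    mask[4, 0:5] = True
    c, m = centroid(mask), medoid(mask)
    print('centroid', c, 'on object:', bool(mask[tuple(np.round(c).astype(int))]))
    print('medoid', m, 'on object:', bool(mask[tuple(m.astype(int))]))
```

For compact, roughly convex nuclei the two usually land close together; the difference matters most for curved or elongated objects.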


### Obtain properties of the dataset 

Here, we read the `data_properties.json` file prepared previously in the `01-data` notebook.

In [4]:
# Dataset properties saved by the 01-data notebook
if not os.path.isfile('data_properties.json'):
    raise FileNotFoundError("data_properties.json not found; please run the 01-data notebook first")
with open('data_properties.json') as json_file:
    data = json.load(json_file)
one_hot, data_type, foreground_weight = data['one_hot'], data['data_type'], float(data['foreground_weight'])
n_z, n_y, n_x = int(data['n_z']), int(data['n_y']), int(data['n_x'])
pixel_size_z_microns, pixel_size_x_microns = float(data['pixel_size_z_microns']), float(data['pixel_size_x_microns'])

In [5]:
normalization_factor = 65535 if data_type=='16-bit' else 255
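The `normalization_factor` is the largest intensity representable at the stored bit depth, so dividing by it maps raw intensities into [0, 1]. The sketch below shows that arithmetic on small synthetic arrays; `normalize` is a hypothetical helper, and the assumption that the dataloader divides each image by this factor is inferred from the parameter name rather than taken from the library source.

```python
import numpy as np


def normalization_factor_for(data_type):
    """Mirror of the notebook cell above: 65535 for 16-bit data, 255 otherwise."""
    return 65535 if data_type == '16-bit' else 255


def normalize(image, data_type):
    """Hypothetical helper: scale integer intensities into [0, 1] as float32."""
    return image.astype(np.float32) / normalization_factor_for(data_type)


if __name__ == '__main__':
    raw = np.array([[0, 1000], [32768, 65535]], dtype=np.uint16)
    scaled = normalize(raw, '16-bit')
    print(scaled.min(), scaled.max())  # 0.0 1.0
```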

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute sets the number of image-mask pairs the network sees in one epoch. Ideally, this equals the number of `train` image crops. The `Mouse-Skull-Nuclei-CBG` dataset yields ~128 crops; since this is a small number, we set `train_size` to twice that, i.e. 256.
* The effective batch size is the product of `train_batch_size` and `virtual_train_batch_multiplier`. For example, one could set a small `train_batch_size`, say 2 (to fit in GPU memory), and a larger `virtual_train_batch_multiplier`, say 8, to obtain an effective batch size of 16.

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [6]:
train_size = 256
train_batch_size = 2 
virtual_train_batch_multiplier = 8 

### Create the `train_dataset_dict` dictionary  

In [7]:
train_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                         project_name = project_name,  
                                         center = center, 
                                         size = train_size, 
                                         batch_size = train_batch_size, 
                                         virtual_batch_multiplier = virtual_train_batch_multiplier, 
                                         normalization_factor= normalization_factor,
                                         type = 'train',
                                         name = '3d')

`train_dataset_dict` dictionary successfully created with: 
 -- train images accessed from crops/Mouse-Skull-Nuclei-CBG/train/images, 
 -- number of images per epoch equal to 256, 
 -- batch size set at 2, 
 -- virtual batch multiplier set as 8, 
 -- normalization_factor set as 65535, 
 -- one_hot set as False, 
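The `virtual_train_batch_multiplier` hint describes gradient accumulation: gradients from several small batches are combined before a single optimizer step. The NumPy sketch below is not EmbedSeg code; a least-squares linear model stands in for the network. It checks the underlying identity: with a mean-reduced loss, averaging the gradients of 8 micro-batches of size 2 gives exactly the gradient of one batch of 16.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / len(y)


def accumulated_grad(w, x, y, batch_size, multiplier):
    """Sum of per-micro-batch gradients, each scaled by 1/multiplier (gradient accumulation)."""
    total = np.zeros_like(w)
    for k in range(multiplier):
        sl = slice(k * batch_size, (k + 1) * batch_size)
        total += mse_grad(w, x[sl], y[sl]) / multiplier
    return total


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    batch_size, multiplier = 2, 8  # train_batch_size, virtual_train_batch_multiplier
    x = rng.normal(size=(batch_size * multiplier, 3))
    y = rng.normal(size=batch_size * multiplier)
    w = rng.normal(size=3)
    full = mse_grad(w, x, y)  # one real batch of 16
    virtual = accumulated_grad(w, x, y, batch_size, multiplier)
    print(np.allclose(full, virtual))  # True
```

The identity covers the gradient arithmetic only. Layers that compute statistics over the batch, such as batch normalization, still see just 2 samples at a time, so a virtual batch of 16 is not fully equivalent to a real one.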


### Specify validation dataset-related parameters

Some hints:
* The `val_size` attribute sets the number of image-mask pairs the network sees in one validation epoch. Ideally, this equals the total number of validation image crops. The `Mouse-Skull-Nuclei-CBG` dataset has ~22 validation crops; since this is a small number, we set `val_size = 176` (8 × 22), i.e. several passes over the validation crops per epoch.
* The effective batch size is the product of `val_batch_size` and `virtual_val_batch_multiplier`. It is often fine to use a larger effective batch size for the validation dataset than for the train dataset, since evaluation needs no backward pass and therefore consumes less GPU memory.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!

In [8]:
val_size = 176
val_batch_size = 16
virtual_val_batch_multiplier = 1

### Create the `val_dataset_dict` dictionary

In [9]:
val_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                       project_name = project_name, 
                                       center = center, 
                                       size = val_size, 
                                       batch_size = val_batch_size, 
                                       virtual_batch_multiplier = virtual_val_batch_multiplier,
                                       normalization_factor= normalization_factor,
                                       type ='val',
                                       name ='3d')

`val_dataset_dict` dictionary successfully created with: 
 -- val images accessed from crops/Mouse-Skull-Nuclei-CBG/val/images, 
 -- number of images per epoch equal to 176, 
 -- batch size set at 16, 
 -- virtual batch multiplier set as 1, 
 -- normalization_factor set as 65535, 
 -- one_hot set as False, 
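A quick sanity check on these settings is to work out how many batches each epoch contains. The sketch below assumes a PyTorch-style loader that does not drop the last incomplete batch, and that the optimizer steps once per `virtual_train_batch_multiplier` batches; the batch counts agree with the `128/128` and `11/11` progress bars in the training log below.

```python
import math


def batches_per_epoch(size, batch_size):
    """Number of batches for `size` samples when the last partial batch is kept."""
    return math.ceil(size / batch_size)


def effective_batch_size(batch_size, virtual_batch_multiplier):
    """Samples contributing to one optimizer update."""
    return batch_size * virtual_batch_multiplier


if __name__ == '__main__':
    train_batches = batches_per_epoch(256, 2)   # train_size, train_batch_size
    val_batches = batches_per_epoch(176, 16)    # val_size, val_batch_size
    print(train_batches, val_batches)           # 128 11
    print(effective_batch_size(2, 8))           # 16
    # Assuming one optimizer update per virtual_train_batch_multiplier batches
    print(train_batches // 8)                   # 16
```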


### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 
* Set `num_classes = [6, 1]` for `3d` training and `num_classes = [4, 1]` for `2d` training
<br>(here, 6 corresponds to the offsets and bandwidths in the x, y and z dimensions, 4 to the offsets and bandwidths in x and y only, and 1 to the `seediness` value per pixel)

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [10]:
input_channels = 1
num_classes = [6, 1] 

### Create the `model_dict` dictionary

In [11]:
model_dict = create_model_dict(input_channels = input_channels,
                              num_classes = num_classes,
                              name = '3d')

`model_dict` dictionary successfully created with: 
 -- num of classes equal to 1, 
 -- input channels equal to [6, 1], 
 -- name equal to 3d
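In the 3-D case, the first head therefore predicts 6 channels (3 offsets and 3 bandwidths, matching `n_sigma = 3` in the loss below) and the second head predicts 1 seediness channel. The sketch below splits a dummy 7-channel prediction into those parts. The channel order (offsets, then bandwidths, then seediness) and the sigmoid on the seediness channel are assumptions made for illustration; `split_prediction` and `num_classes_for` are hypothetical helpers, not EmbedSeg functions.

```python
import numpy as np


def split_prediction(pred, ndim, n_sigma):
    """Split a (ndim + n_sigma + 1, *spatial) array into offsets, bandwidths and seediness.

    The channel order used here is an assumption for illustration.
    """
    assert pred.shape[0] == ndim + n_sigma + 1, 'unexpected number of channels'
    offsets = pred[:ndim]
    sigma = pred[ndim:ndim + n_sigma]
    seed_logits = pred[ndim + n_sigma:]
    seediness = 1.0 / (1.0 + np.exp(-seed_logits))  # assumed sigmoid: one score in (0, 1) per pixel
    return offsets, sigma, seediness


def num_classes_for(ndim, n_sigma=None):
    """[offset + bandwidth channels, seediness channels]; one bandwidth per axis by default."""
    n_sigma = ndim if n_sigma is None else n_sigma
    return [ndim + n_sigma, 1]


if __name__ == '__main__':
    pred = np.random.default_rng(0).normal(size=(7, 4, 8, 8))  # (channels, z, y, x)
    offsets, sigma, seediness = split_prediction(pred, ndim=3, n_sigma=3)
    print(num_classes_for(3), num_classes_for(2))  # [6, 1] [4, 1]
    print(offsets.shape, sigma.shape, seediness.shape)
```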


### Create the `loss_dict` dictionary

In [12]:
loss_dict = create_loss_dict(n_sigma = 3, foreground_weight = foreground_weight)

`loss_dict` dictionary successfully created with: 
 -- foreground weight equal to 3.174, 
 -- w_inst equal to 1, 
 -- w_var equal to 10, 
 -- w_seed equal to 1


### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long training proceeds. In general, for reasonable results, you should train for more than 50 epochs.
* The `display` attribute, if set to True, allows you to see the network predictions as the training proceeds. 
* The `display_embedding` attribute, if set to True, allows you to see some sample embeddings as the training proceeds. Setting this to False leads to faster training.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* To **resume training** from a previous checkpoint, point the `resume_path` attribute to it. For example, set `resume_path = './experiment/Mouse-Skull-Nuclei-CBG-demo/checkpoint.pth'` to resume training from the last checkpoint saved for this project.


In [13]:
n_epochs = 200
display = False
display_embedding = False
save_dir = os.path.join('experiment', project_name+'-'+'demo')
resume_path = os.path.join(save_dir, 'checkpoint.pth')

In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

### Create the `configs` dictionary

In [14]:
configs = create_configs(n_epochs = n_epochs,
                         display = display, 
                         display_embedding = display_embedding,
                         resume_path = resume_path, 
                         save_dir = save_dir, 
                         n_z = n_z,
                         n_y = n_y, 
                         n_x = n_x,
                         anisotropy_factor = pixel_size_z_microns/pixel_size_x_microns)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 200, 
 -- display equal to False, 
 -- save_dir equal to experiment/Mouse-Skull-Nuclei-CBG-demo, 
 -- n_z equal to 128, 
 -- n_y equal to 512, 
 -- n_x equal to 512, 
 -- one_hot equal to False, 
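Because `resume_path` is always set to the checkpoint inside `save_dir`, it is worth checking whether that file exists before starting, so that a stale checkpoint is not picked up by accident. The hypothetical helper below returns the checkpoint path only when the file exists and `None` otherwise; passing `None` to `create_configs` assumes that the function accepts `None` for `resume_path`, which this notebook does not state.

```python
import os


def resolve_resume_path(save_dir, filename='checkpoint.pth'):
    """Hypothetical helper: checkpoint path if it exists, else None (train from scratch)."""
    path = os.path.join(save_dir, filename)
    return path if os.path.isfile(path) else None


if __name__ == '__main__':
    save_dir = os.path.join('experiment', 'Mouse-Skull-Nuclei-CBG' + '-' + 'demo')
    resume = resolve_resume_path(save_dir)
    if resume is not None:
        print('Resuming from', resume)
    else:
        print('No checkpoint found; training from scratch')
```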


### Begin training!

Executing the `begin_training` cell below starts the training. 

If the `display` attribute was set to `True` above, you will see the network predictions on training and validation images every $n^{th}$ step ($n = 5$ by default). 

Going clockwise from the top-left, the panels show:

* the raw image to be segmented,
* the corresponding ground-truth instance mask,
* the network's predicted instance mask, and
* (if `display_embedding = True`) for each object instance, 5 randomly selected pixels (marked with `+`), their embeddings (marked with `.`), and the predicted margin for that object, visualized as an axis-aligned ellipse centred on the object's ground-truth center (marked with `x`).
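The ellipse in the embedding panel can be understood through the Gaussian clustering kernel of Neven et al. (CVPR 2019), on which this family of spatial-embedding losses builds: a pixel $i$ with embedding $e_i$ is assigned to an instance with center $C$ with probability $\phi(e_i) = \exp\left(-\sum_d \frac{(e_{i,d} - C_d)^2}{2\sigma_d^2}\right)$. Thresholding $\phi$ at 0.5 gives an axis-aligned ellipse with half-axis $\sigma_d\sqrt{2\ln 2}$ along each dimension $d$. The sketch below assumes this parameterization; EmbedSeg may parameterize the bandwidth differently internally, so read it as an illustration of the geometry, not of the plotting code.

```python
import math

import numpy as np


def membership(e, center, sigma):
    """Gaussian clustering kernel: probability that embedding e belongs to the instance."""
    e, center, sigma = (np.asarray(v, dtype=float) for v in (e, center, sigma))
    return float(np.exp(-np.sum((e - center) ** 2 / (2.0 * sigma ** 2))))


def margin_half_axes(sigma, threshold=0.5):
    """Half-axes of the axis-aligned ellipse on which membership equals `threshold`."""
    return np.asarray(sigma, dtype=float) * math.sqrt(-2.0 * math.log(threshold))


if __name__ == '__main__':
    center = np.array([10.0, 20.0, 30.0])  # (z, y, x) ground-truth center
    sigma = np.array([1.5, 3.0, 3.0])      # per-axis bandwidths (illustrative values)
    axes = margin_half_axes(sigma)
    print(axes)
    # A point on the ellipse boundary along y gets membership 0.5
    print(round(membership(center + [0.0, axes[1], 0.0], center, sigma), 6))  # 0.5
```

Embeddings that land inside the drawn ellipse would be clustered with that object; a larger predicted $\sigma_d$ widens the margin along axis $d$.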


In [15]:
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

3-D `train` dataloader created! Accessing data from crops/Mouse-Skull-Nuclei-CBG/train/
Number of images in `train` directory is 128
Number of instances in `train` directory is 128
Number of center images in `train` directory is 128
*************************
3-D `val` dataloader created! Accessing data from crops/Mouse-Skull-Nuclei-CBG/val/
Number of images in `val` directory is 22
Number of instances in `val` directory is 22
Number of center images in `val` directory is 22
*************************
Creating branched erfnet 3d with [6, 1] classes
initialize last layer with size:  torch.Size([16, 6, 2, 2, 2])
Created spatial emb loss function with: n_sigma: 3, foreground_weight: 3.173831196077842
*************************
Created logger with keys:  ('train', 'val', 'iou')
Resuming model from experiment/Mouse-Skull-Nuclei-CBG-demo/checkpoint.pth


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 120
learning rate: 0.00021919164527704348


100%|██████████| 128/128 [02:39<00:00,  1.25s/it]
100%|██████████| 11/11 [00:53<00:00,  4.88s/it]


===> train loss: 0.45
===> val loss: 0.48, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 121
learning rate: 0.0002167241909659091


100%|██████████| 128/128 [02:35<00:00,  1.21s/it]
100%|██████████| 11/11 [00:51<00:00,  4.64s/it]


===> train loss: 0.45
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 122
learning rate: 0.00021425361122954587


100%|██████████| 128/128 [02:34<00:00,  1.21s/it]
100%|██████████| 11/11 [00:50<00:00,  4.62s/it]


===> train loss: 0.45
===> val loss: 0.48, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 123
learning rate: 0.00021177986196077483


100%|██████████| 128/128 [02:32<00:00,  1.19s/it]
100%|██████████| 11/11 [00:50<00:00,  4.58s/it]


===> train loss: 0.43
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 124
learning rate: 0.00020930289784861406


100%|██████████| 128/128 [02:32<00:00,  1.19s/it]
100%|██████████| 11/11 [00:50<00:00,  4.59s/it]


===> train loss: 0.44
===> val loss: 0.49, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 125
learning rate: 0.0002068226723291381


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.71s/it]


===> train loss: 0.44
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 126
learning rate: 0.00020433913753364945


100%|██████████| 128/128 [02:35<00:00,  1.21s/it]
100%|██████████| 11/11 [00:51<00:00,  4.72s/it]


===> train loss: 0.44
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 127
learning rate: 0.00020185224423397574


100%|██████████| 128/128 [02:34<00:00,  1.21s/it]
100%|██████████| 11/11 [00:51<00:00,  4.66s/it]


===> train loss: 0.45
===> val loss: 0.46, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 128
learning rate: 0.0001993619417846922


100%|██████████| 128/128 [02:32<00:00,  1.19s/it]
100%|██████████| 11/11 [00:51<00:00,  4.70s/it]


===> train loss: 0.45
===> val loss: 0.47, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 129
learning rate: 0.0001968681780620511


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:50<00:00,  4.57s/it]


===> train loss: 0.45
===> val loss: 0.48, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 130
learning rate: 0.00019437089939938174


100%|██████████| 128/128 [02:31<00:00,  1.18s/it]
100%|██████████| 11/11 [00:51<00:00,  4.66s/it]


===> train loss: 0.44
===> val loss: 0.48, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 131
learning rate: 0.0001918700505187031


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.71s/it]


===> train loss: 0.45
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 132
learning rate: 0.00018936557445826965


100%|██████████| 128/128 [02:31<00:00,  1.18s/it]
100%|██████████| 11/11 [00:50<00:00,  4.61s/it]


===> train loss: 0.45
===> val loss: 0.48, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 133
learning rate: 0.00018685741249574434


100%|██████████| 128/128 [02:34<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.64s/it]


===> train loss: 0.44
===> val loss: 0.50, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 134
learning rate: 0.00018434550406666597


100%|██████████| 128/128 [02:37<00:00,  1.23s/it]
100%|██████████| 11/11 [00:52<00:00,  4.75s/it]


===> train loss: 0.45
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 135
learning rate: 0.0001818297866778471


100%|██████████| 128/128 [02:34<00:00,  1.20s/it]
100%|██████████| 11/11 [00:50<00:00,  4.63s/it]


===> train loss: 0.45
===> val loss: 0.46, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 136
learning rate: 0.00017931019581530385


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:49<00:00,  4.48s/it]


===> train loss: 0.44
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 137
learning rate: 0.00017678666484628193


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.65s/it]


===> train loss: 0.44
===> val loss: 0.45, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 138
learning rate: 0.0001742591249149002


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:50<00:00,  4.64s/it]


===> train loss: 0.44
===> val loss: 0.48, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 139
learning rate: 0.00017172750483088596


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.66s/it]


===> train loss: 0.44
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 140
learning rate: 0.00016919173095082495


100%|██████████| 128/128 [02:34<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.69s/it]


===> train loss: 0.44
===> val loss: 0.50, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 141
learning rate: 0.00016665172705128707


100%|██████████| 128/128 [02:34<00:00,  1.21s/it]
100%|██████████| 11/11 [00:51<00:00,  4.65s/it]


===> train loss: 0.43
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 142
learning rate: 0.00016410741419312688


100%|██████████| 128/128 [02:32<00:00,  1.19s/it]
100%|██████████| 11/11 [00:50<00:00,  4.57s/it]


===> train loss: 0.43
===> val loss: 0.49, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 143
learning rate: 0.00016155871057618057


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.66s/it]


===> train loss: 0.44
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 144
learning rate: 0.00015900553138349974


100%|██████████| 128/128 [02:32<00:00,  1.19s/it]
100%|██████████| 11/11 [00:51<00:00,  4.64s/it]


===> train loss: 0.42
===> val loss: 0.46, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 145
learning rate: 0.00015644778861416783


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:50<00:00,  4.61s/it]


===> train loss: 0.42
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 146
learning rate: 0.00015388539090363925


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:50<00:00,  4.63s/it]


===> train loss: 0.44
===> val loss: 0.47, val iou: 0.80
=> saving checkpoint
Starting epoch 147
learning rate: 0.00015131824333042122


100%|██████████| 128/128 [02:34<00:00,  1.20s/it]
100%|██████████| 11/11 [00:51<00:00,  4.65s/it]


===> train loss: 0.43
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 148
learning rate: 0.00014874624720778262


100%|██████████| 128/128 [02:33<00:00,  1.20s/it]
100%|██████████| 11/11 [00:50<00:00,  4.59s/it]


===> train loss: 0.43
===> val loss: 0.47, val iou: 0.81
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 149
learning rate: 0.00014616929985901934


100%|██████████| 128/128 [02:34<00:00,  1.21s/it]
100%|██████████| 11/11 [00:51<00:00,  4.69s/it]


===> train loss: 0.43
===> val loss: 0.48, val iou: 0.80
=> saving checkpoint


  0%|          | 0/128 [00:00<?, ?it/s]

Starting epoch 150
learning rate: 0.00014358729437462936


 72%|███████▏  | 92/128 [01:51<00:43,  1.21s/it]


KeyboardInterrupt: 
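The logged learning rates are consistent with a polynomial ("poly") decay, $\mathrm{lr}(e) = \mathrm{lr}_0 \, (1 - e/N)^{0.9}$, with $\mathrm{lr}_0 = 5 \times 10^{-4}$ and $N = 200$ (`n_epochs`). These constants were inferred by matching the log above and are an assumption about the schedule, not a statement about EmbedSeg's source. The schedule also explains why resuming at epoch 120 continues with a learning rate of about 2.19e-4 rather than the initial value.

```python
def poly_lr(epoch, base_lr=5e-4, n_epochs=200, power=0.9):
    """Polynomial learning-rate decay; base_lr and power are inferred from the log, not confirmed."""
    return base_lr * (1.0 - epoch / n_epochs) ** power


if __name__ == '__main__':
    # Learning rates copied from the training output above
    logged = {120: 0.00021919164527704348, 121: 0.0002167241909659091, 149: 0.00014616929985901934}
    for epoch, lr in logged.items():
        rel_err = abs(poly_lr(epoch) - lr) / lr
        print(epoch, poly_lr(epoch), rel_err < 1e-6)
```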

<div class="alert alert-block alert-warning"> 
  Common causes of errors during training include: <br>
    1. Not having <b>center images</b> for <b>both</b> the train and val directories. <br>
    2. A <b>mismatch</b> between the type of center images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center"> center</a></b> parameter in the second code cell of this notebook). <br>
    3. When resuming training from a previous checkpoint, make sure that the model weights are read from the correct location, set via the <b><a href="#resume"> resume_path</a></b> parameter, and that the <b>save_dir</b> parameter, where model weights are saved, points to the intended directory. 
</div>