In [1]:
import numpy as np
import os
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
import torch
%matplotlib tk

In [2]:
from datetime import date
today = date.today()
today = today.strftime("%b-%d-%Y")

### Specify the path to the `train` and `val` crops and the type of `center` embedding to train the network for:

The train and val images, masks and center images are read from the directory specified by `data_dir` and `project_name`.

In [3]:
data_dir = 'crops'
project_name = 'basel-2020'
center = 'centroid'

print("Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(project_name, data_dir))

Project Name chosen as : basel-2020. 
Train-Val images-masks-center-images will be accessed from : crops
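The dataloader messages in the training log (`crops/basel-2020/train/images`, `crops/basel-2020/val/images`) suggest the crops are laid out as `<data_dir>/<project_name>/<split>/images`. A quick pre-flight check before training could look like the sketch below. The helpers `split_image_dir` and `count_files` are hypothetical, not part of EmbedSeg, and only the `images` sub-folder is assumed, since the names of the mask and center-image folders are not shown in this notebook.

```python
import os


def split_image_dir(data_dir: str, project_name: str, split: str) -> str:
    """Return <data_dir>/<project_name>/<split>/images, the layout seen in the logs."""
    return os.path.join(data_dir, project_name, split, "images")


def count_files(directory: str) -> int:
    """Count regular files in `directory`, returning 0 if it does not exist."""
    if not os.path.isdir(directory):
        return 0
    return sum(
        1 for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


for split in ("train", "val"):
    images = split_image_dir("crops", "basel-2020", split)
    print(f"{split}: {images} ({count_files(images)} files)")
```

A count of 0 for either split usually means `data_dir` or `project_name` points at the wrong place.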


In [4]:
try:
    assert center in {'medoid', 'approximate-medoid', 'centroid'}
    print("Spatial Embedding Location chosen as : {}".format(center))
except AssertionError as e:
    e.args += ('Please specify center as one of : {"medoid", "approximate-medoid", "centroid"}', 42)
    raise

Spatial Embedding Location chosen as : centroid
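To make the `center` choice concrete: the centroid of an object is the mean of its pixel coordinates and can fall outside a non-convex object, whereas a medoid, taken here as the object pixel with the smallest summed distance to all other object pixels, always lies on the object. The sketch below computes both with NumPy for one binary mask. It illustrates the two definitions only and is not EmbedSeg's own implementation, which may differ; `approximate-medoid` is not sketched here, and `centroid` and `medoid` are hypothetical helper names.

```python
import numpy as np


def centroid(mask: np.ndarray) -> np.ndarray:
    """Mean (y, x) coordinate of the foreground pixels; may lie off the object."""
    return np.argwhere(mask).mean(axis=0)


def medoid(mask: np.ndarray) -> np.ndarray:
    """Foreground pixel with the smallest summed Euclidean distance to all others."""
    coords = np.argwhere(mask).astype(float)
    # Pairwise distance matrix: O(N^2) memory, fine for a single small object.
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    return coords[dists.sum(axis=1).argmin()]


# An L-shaped object: its centroid falls in the empty corner, its medoid does not.
mask = np.zeros((5, 5), dtype=bool)
mask[:, 0] = True
mask[4, :] = True
print("centroid:", centroid(mask), "medoid:", medoid(mask))
```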


### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask pairs which the network sees in one complete epoch. Ideally, this should equal the number of `train` image crops. For the `basel-2020` dataset we obtain ~67000 crops, which makes each epoch slow. Hence, we go for something quicker and set `train_size` to 20000. 
* The effective batch size is the product of the attributes `train_batch_size` and `virtual_train_batch_multiplier`. For example, one could set a small `train_batch_size`, say 2 (to fit in one's GPU memory), and a large `virtual_train_batch_multiplier`, say 8, to get an effective batch size of 16. 
* The `normalization_factor` attribute rescales the raw images so that their intensities lie between 0 and 1. Since the phase-contrast image intensities are already between 0 and 1, we set `normalization_factor = 1`.
* The `one_hot` attribute should be set to `True` if the instance masks are one-hot encoded (i.e. each object instance is encoded as 1 in its own individual image slice) and `False` if the instance mask has the same dimensions as the raw image. 

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [5]:
train_size = 20000
train_batch_size = 128 
virtual_train_batch_multiplier = 1 
normalization_factor = 1
one_hot = False

### Create the `train_dataset_dict` dictionary  

In [6]:
train_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                         project_name = project_name,  
                                         center = center, 
                                         size = train_size, 
                                         batch_size = train_batch_size, 
                                         virtual_batch_multiplier = virtual_train_batch_multiplier, 
                                         normalization_factor= normalization_factor,
                                         one_hot = one_hot,
                                         type = 'train')

`train_dataset_dict` dictionary successfully created with: 
 -- train images accessed from crops/basel-2020/train/images, 
 -- number of images per epoch equal to 20000, 
 -- batch size set at 128, 
 -- virtual batch multiplier set as 1, 
 -- normalization_factor set as 1, 
 -- one_hot set as False, 
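The batch-related hints above reduce to simple arithmetic. This is a minimal sketch, assuming the dataloader drops the incomplete last batch (consistent with the `156/156` and `78/78` progress bars in the training log) and assuming the dataset simply divides raw intensities by `normalization_factor`. All helper names are hypothetical.

```python
import numpy as np


def effective_batch_size(batch_size: int, virtual_batch_multiplier: int) -> int:
    """Samples contributing to one optimizer update, per the hint above."""
    return batch_size * virtual_batch_multiplier


def iterations_per_epoch(size: int, batch_size: int) -> int:
    """Dataloader steps per epoch, assuming the incomplete last batch is dropped."""
    return size // batch_size


def normalize(image: np.ndarray, normalization_factor: float) -> np.ndarray:
    """Assumed normalization: e.g. 255 for 8-bit images, 1 for data already in [0, 1]."""
    return image.astype(np.float32) / normalization_factor


print(effective_batch_size(2, 8))        # → 16
print(iterations_per_epoch(20000, 128))  # → 156 (train_size, train_batch_size)
print(iterations_per_epoch(10000, 128))  # → 78  (val_size, val_batch_size)
```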


### Specify validation dataset-related parameters

Some hints:
* The `val_size` attribute indicates the number of image-mask pairs which the network sees in one complete validation epoch. Here, it is recommended to set `val_size` equal to the total number of validation image crops. For example, for the `basel-2020` dataset we have ~10000 validation crops, hence we set `val_size = 10000`.
* The effective batch size is the product of the attributes `val_batch_size` and `virtual_val_batch_multiplier`. It is often fine to use a larger effective batch size for the validation dataset than for the training dataset, since evaluating on validation data consumes less GPU memory.
* The `normalization_factor` attribute rescales the raw images so that their intensities lie between 0 and 1. 
* The `one_hot` attribute should be set to `True` if the instance masks are one-hot encoded (i.e. each object instance is encoded as 1 in its own individual image slice) and `False` if the instance mask has the same dimensions as the raw image. 

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!



In [7]:
val_size = 10000
val_batch_size = 128
virtual_val_batch_multiplier = 1
normalization_factor = 1
one_hot = False

### Create the `val_dataset_dict` dictionary

In [8]:
val_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                       project_name = project_name, 
                                       center = center, 
                                       size = val_size, 
                                       batch_size = val_batch_size, 
                                       virtual_batch_multiplier = virtual_val_batch_multiplier,
                                       normalization_factor= normalization_factor,
                                       one_hot = one_hot,
                                       type ='val',)

`val_dataset_dict` dictionary successfully created with: 
 -- val images accessed from crops/basel-2020/val/images, 
 -- number of images per epoch equal to 10000, 
 -- batch size set at 128, 
 -- virtual batch multiplier set as 1, 
 -- normalization_factor set as 1, 
 -- one_hot set as False, 
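If your instance masks are stored one-hot (`one_hot = True`), each object occupies its own binary slice; otherwise a single label image holds a distinct integer id per object. The hypothetical helpers below convert between the two layouts with NumPy, which can help when checking which setting matches your data. They are a sketch and not part of EmbedSeg.

```python
import numpy as np


def one_hot_to_label(stack: np.ndarray) -> np.ndarray:
    """(N, H, W) binary stack -> (H, W) label image; slice i gets id i + 1.

    Where slices overlap, the later slice wins.
    """
    label = np.zeros(stack.shape[1:], dtype=np.uint16)
    for i, instance in enumerate(stack):
        label[instance > 0] = i + 1
    return label


def label_to_one_hot(label: np.ndarray) -> np.ndarray:
    """(H, W) label image -> (N, H, W) binary stack, one slice per non-zero id.

    Ids are taken in sorted order, so a round trip renumbers them 1..N.
    """
    ids = np.unique(label)
    ids = ids[ids != 0]
    return (label[np.newaxis] == ids[:, np.newaxis, np.newaxis]).astype(np.uint8)


label = np.array([[0, 1, 1],
                  [0, 0, 2],
                  [3, 3, 2]])
stack = label_to_one_hot(label)
print(stack.shape)  # → (3, 3, 3)
```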


### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to 1 for gray-scale and 3 for `RGB` input images. 
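`input_channels` is the number of channels per input image. A minimal sketch, assuming channel-first arrays as PyTorch convolutions expect: a gray-scale `(H, W)` image gains a singleton channel axis, and an RGB `(H, W, 3)` image has its color axis moved to the front. The helper `to_channel_first` is hypothetical; whether EmbedSeg performs this step internally is not shown here.

```python
import numpy as np


def to_channel_first(image: np.ndarray) -> np.ndarray:
    """Gray (H, W) -> (1, H, W); RGB (H, W, 3) -> (3, H, W)."""
    if image.ndim == 2:
        return image[np.newaxis]
    if image.ndim == 3 and image.shape[-1] == 3:
        return np.moveaxis(image, -1, 0)
    raise ValueError(f"unexpected image shape {image.shape}")


gray = np.zeros((64, 64), dtype=np.float32)
rgb = np.zeros((64, 64, 3), dtype=np.float32)
# The leading axis is the value input_channels should match.
print(to_channel_first(gray).shape[0], to_channel_first(rgb).shape[0])  # → 1 3
```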

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [9]:
input_channels = 1

### Create the `model_dict` dictionary

In [10]:
model_dict = create_model_dict(input_channels = input_channels)

`model_dict` dictionary successfully created with: 
 -- num of classes equal to 1, 
 -- input channels equal to [4, 1], 
 -- name equal to branched_erfnet
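The `[4, 1]` in the output above is the number of output channels of the two decoder branches of the branched ERFNet. Assuming the spatial-embedding layout of Neven et al., on which EmbedSeg builds, the first branch predicts 2 offset channels plus `n_sigma = 2` margin channels (the training log reports `n_sigma: 2`), and the second branch predicts a single seediness channel. The sketch below splits a concatenated prediction accordingly; the channel order, the `tanh`/sigmoid activations and the helper name `split_prediction` are assumptions.

```python
import numpy as np


def split_prediction(pred: np.ndarray, n_sigma: int = 2):
    """Split a (2 + n_sigma + 1, H, W) prediction into offsets, margins and seediness."""
    expected = 2 + n_sigma + 1
    if pred.shape[0] != expected:
        raise ValueError(f"expected {expected} channels, got {pred.shape[0]}")
    offsets = np.tanh(pred[:2])                        # (2, H, W), squashed to (-1, 1)
    sigma = pred[2:2 + n_sigma]                        # (n_sigma, H, W) margin terms
    seed = 1.0 / (1.0 + np.exp(-pred[2 + n_sigma:]))   # (1, H, W), in (0, 1)
    return offsets, sigma, seed


pred = np.zeros((4 + 1, 64, 64), dtype=np.float32)  # 4 + 1 channels, as in [4, 1]
offsets, sigma, seed = split_prediction(pred)
print(offsets.shape, sigma.shape, seed.shape)
```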


### Create the `loss_dict` dictionary

In [11]:
loss_dict = create_loss_dict()

`loss_dict` dictionary successfully created with: 
 -- foreground weight equal to 10, 
 -- w_inst equal to 1, 
 -- w_var equal to 10, 
 -- w_seed equal to 1
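A minimal sketch of how these weights are assumed to combine, following the spatial-embedding loss of Neven et al.: the total loss is a weighted sum of an instance term, a variance term and a seed term, and `foreground_weight` up-weights foreground pixels inside the seed term. The instance and variance terms themselves are not reproduced, and normalization by pixel counts is omitted; `seed_term` and `total_loss` are hypothetical helpers operating on already-computed values.

```python
import numpy as np

# Default weights as printed by create_loss_dict() above.
W_INST, W_VAR, W_SEED, FOREGROUND_WEIGHT = 1.0, 10.0, 1.0, 10.0


def seed_term(seed_pred, seed_target, fg_mask, foreground_weight=FOREGROUND_WEIGHT):
    """Squared error: background pixels regress to 0, foreground pixels to seed_target."""
    bg = np.sum(seed_pred[~fg_mask] ** 2)
    fg = np.sum((seed_pred[fg_mask] - seed_target[fg_mask]) ** 2)
    return bg + foreground_weight * fg


def total_loss(instance_term, variance_term, seed, w_inst=W_INST, w_var=W_VAR, w_seed=W_SEED):
    """Weighted sum of the three loss terms."""
    return w_inst * instance_term + w_var * variance_term + w_seed * seed


fg_mask = np.array([True, False])
pred = np.array([0.5, 0.5])
target = np.array([1.0, 0.0])
print(seed_term(pred, target, fg_mask))  # → 2.75 (0.25 background + 10 * 0.25 foreground)
```

With `w_var = 10`, the variance term dominates unless it is already small, which is one reading of why it carries the largest default weight.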


### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long training proceeds. In general, for reasonable results, you should train for more than 25 epochs.
* The `display` attribute, if set to `True`, shows the network predictions as training proceeds. If you would prefer the display not to be `inline`, you could add a new cell after this one containing `%matplotlib tk`, which pops the visualization up in a separate window. 
* The `display_embedding` attribute, if set to `True`, shows sample embeddings as training proceeds. Setting this to `False` leads to faster training.
* The `save_dir` attribute identifies the location where the checkpoints and loss-curve details are saved. 
* To **resume training** from a previous checkpoint, point the `resume_path` attribute at it. For example, one could set `resume_path = './experiment/basel-2020-Jan-03-2021/checkpoint.pth'` to resume from the last checkpoint of an earlier run. 
* The `grid_y` and `grid_x` attributes should be set equal to or larger than the dimensions of the largest evaluation image one wishes to test the trained model on. (Here, we assume that the pixel sizes in the height and width dimensions are equal.) If you are unsure of the evaluation image sizes at this stage, it is best to leave these attributes as they are.  
* The `one_hot` attribute should be set to `True` if the instance masks are one-hot encoded (i.e. each object instance is encoded as 1 in its own individual image slice) and `False` if the instance mask has the same dimensions as the raw image. 



In the cell after this one, a `configs` dictionary is generated from the parameters specified here!

In [12]:
n_epochs = 200
one_hot = False
display = False
display_embedding = False
save_dir = os.path.join('experiment', project_name+'-'+today)
resume_path  = None
grid_y = 1024
grid_x = 1024

### Create the `configs` dictionary

In [13]:
configs = create_configs(n_epochs = n_epochs,
                         one_hot = one_hot,
                         display = display, 
                         display_embedding = display_embedding,
                         resume_path = resume_path, 
                         save_dir = save_dir, 
                         grid_y = grid_y, 
                         grid_x = grid_x,)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 200, 
 -- display equal to False, 
 -- save_dir equal to experiment/basel-2020-Jan-03-2021, 
 -- grid_y equal to 1024, 
 -- grid_x equal to 1024, 
 -- one_hot equal to False, 
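Why `grid_y` and `grid_x` must cover the largest evaluation image: assuming, as in spatial-embedding implementations, that a pixel-coordinate map of size `grid_y` x `grid_x` is built once and cropped to each image, an image larger than the grid has no coordinates for its outer pixels. In the sketch below, `pixel_y` and `pixel_x` are hypothetical scale parameters not set anywhere in this notebook (kept equal, per the hint above), and the helper names are illustrative.

```python
import numpy as np


def coordinate_grid(grid_y: int, grid_x: int, pixel_y: float = 1.0, pixel_x: float = 1.0) -> np.ndarray:
    """Return a (2, grid_y, grid_x) array of (x, y) coordinates spanning [0, pixel_x] x [0, pixel_y]."""
    xm = np.broadcast_to(np.linspace(0, pixel_x, grid_x)[np.newaxis, :], (grid_y, grid_x))
    ym = np.broadcast_to(np.linspace(0, pixel_y, grid_y)[:, np.newaxis], (grid_y, grid_x))
    return np.stack([xm, ym])


def coords_for_image(xym: np.ndarray, height: int, width: int) -> np.ndarray:
    """Crop the precomputed grid to an image; fail if the image exceeds the grid."""
    if height > xym.shape[1] or width > xym.shape[2]:
        raise ValueError(f"image {height}x{width} exceeds grid {xym.shape[1]}x{xym.shape[2]}")
    return xym[:, :height, :width]


xym = coordinate_grid(1024, 1024)             # grid_y, grid_x from the cell above
print(coords_for_image(xym, 512, 768).shape)  # → (2, 512, 768)
```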


### Begin training!

Executing the next cell would begin the training. 

If the `display` attribute was set to `True` above, you will see the network predictions every $n^{th}$ step ($n = 5$ by default) on training and validation images. 

Going clockwise from the top-left, the panels show:

* the raw image which needs to be segmented,
* the corresponding ground-truth instance mask,
* the network-predicted instance mask, and
* (if `display_embedding = True`) for each object instance, 5 randomly selected pixels (indicated with `+`), their embeddings (indicated with `.`), and the predicted margin for that object, visualized as an axis-aligned ellipse centred on the ground-truth center (indicated with `x`).


<div class="alert alert-warning">
Training up to 200 epochs takes a while: in the run logged below, each epoch (training plus validation) took roughly 10 to 15 minutes. <br>
One can get decent results within 10 epochs, though!
</div>
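The time estimate above can be checked against the training log that follows. The sketch below copies the elapsed times from the `156/156` (train) and `78/78` (val) progress bars of epochs 0 to 7 and extrapolates to 200 epochs; it assumes later epochs run at a similar speed on the same hardware, which the log itself shows can vary by about 50%.

```python
# (train, val) elapsed times copied from the tqdm bars of epochs 0-7 in the log below.
EPOCH_TIMES = [
    ("07:40", "01:55"), ("09:30", "02:25"), ("11:56", "02:50"), ("08:36", "01:58"),
    ("08:43", "01:57"), ("08:58", "01:57"), ("09:06", "02:15"), ("11:37", "02:51"),
]


def to_seconds(mm_ss: str) -> int:
    """Convert a tqdm 'MM:SS' elapsed string to seconds."""
    minutes, seconds = mm_ss.split(":")
    return 60 * int(minutes) + int(seconds)


per_epoch = [to_seconds(train) + to_seconds(val) for train, val in EPOCH_TIMES]
mean_s = sum(per_epoch) / len(per_epoch)
print(f"mean epoch: {mean_s / 60:.1f} min, 200 epochs: ~{200 * mean_s / 3600:.0f} h")
# → mean epoch: 11.8 min, 200 epochs: ~39 h
```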

In [None]:
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

2-D `train` dataloader created! Accessing data from crops/basel-2020/train/
Number of images in `train` directory is 67142
Number of instances in `train` directory is 67142
Number of center images in `train` directory is 67142
*************************
2-D `val` dataloader created! Accessing data from crops/basel-2020/val/
Number of images in `val` directory is 10410
Number of instances in `val` directory is 10410
Number of center images in `val` directory is 10410
*************************
Creating branched erfnet with [4, 1] classes
Initialize last layer with size:  torch.Size([16, 4, 2, 2])
*************************


  0%|          | 0/156 [00:00<?, ?it/s]

Created spatial emb loss function with: n_sigma: 2, foreground_weight: 10
*************************
Created logger with keys:  ('train', 'val', 'iou')
Starting epoch 0
learning rate: 0.0005


100%|██████████| 156/156 [07:40<00:00,  2.95s/it]
100%|██████████| 78/78 [01:55<00:00,  1.48s/it]


===> train loss: 0.85
===> val loss: 0.63, val iou: 0.82
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 1
learning rate: 0.0004977494364660346


100%|██████████| 156/156 [09:30<00:00,  3.66s/it]
100%|██████████| 78/78 [02:25<00:00,  1.87s/it]


===> train loss: 0.44
===> val loss: 0.42, val iou: 0.89
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 2
learning rate: 0.0004954977417064171


100%|██████████| 156/156 [11:56<00:00,  4.60s/it]
100%|██████████| 78/78 [02:50<00:00,  2.19s/it]


===> train loss: 0.32
===> val loss: 0.31, val iou: 0.90
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 3
learning rate: 0.0004932449094349202


100%|██████████| 156/156 [08:36<00:00,  3.31s/it]
100%|██████████| 78/78 [01:58<00:00,  1.52s/it]


===> train loss: 0.28
===> val loss: 0.29, val iou: 0.90
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 4
learning rate: 0.0004909909332982877


100%|██████████| 156/156 [08:43<00:00,  3.36s/it]
100%|██████████| 78/78 [01:57<00:00,  1.51s/it]


===> train loss: 0.25
===> val loss: 0.25, val iou: 0.91
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 5
learning rate: 0.0004887358068751748


100%|██████████| 156/156 [08:58<00:00,  3.45s/it]
100%|██████████| 78/78 [01:57<00:00,  1.51s/it]


===> train loss: 0.24
===> val loss: 0.23, val iou: 0.92
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 6
learning rate: 0.0004864795236750653


100%|██████████| 156/156 [09:06<00:00,  3.50s/it]
100%|██████████| 78/78 [02:15<00:00,  1.73s/it]


===> train loss: 0.23
===> val loss: 0.23, val iou: 0.92
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 7
learning rate: 0.00048422207713716544


100%|██████████| 156/156 [11:37<00:00,  4.47s/it]
100%|██████████| 78/78 [02:51<00:00,  2.20s/it]


===> train loss: 0.21
===> val loss: 0.21, val iou: 0.92
=> saving checkpoint


  0%|          | 0/156 [00:00<?, ?it/s]

Starting epoch 8
learning rate: 0.00048196346062927547


 35%|███▌      | 55/156 [03:28<07:23,  4.39s/it]
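The logged learning rates fit a polynomial ("poly") decay, `base_lr * (1 - epoch / n_epochs) ** 0.9` with `base_lr = 0.0005` and `n_epochs = 200`, which reproduces the printed values for epochs 1, 2 and 8. The exponent 0.9 is inferred from the log, not read from EmbedSeg's source, and `poly_lr` is a hypothetical helper.

```python
def poly_lr(epoch: int, base_lr: float = 5e-4, n_epochs: int = 200, power: float = 0.9) -> float:
    """Learning rate after `epoch` epochs of polynomial decay (schedule inferred from the log)."""
    return base_lr * (1 - epoch / n_epochs) ** power


for epoch in (0, 1, 2, 8):
    print(f"epoch {epoch}: {poly_lr(epoch)}")
```

In PyTorch, the same schedule could be expressed with `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: (1 - epoch / 200) ** 0.9)`, stepping the scheduler once per epoch. The rate decays slowly at first and reaches 0 at the final epoch.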