In [1]:
import json, os
from scipy.special import perm
import numpy as np
from LineageTracer.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
from LineageTracer.train import begin_training

### Specify the path to the train and val crops used to train the network

The train and val images and masks are accessed from the path specified by `data_dir` and `project_name`.

In [2]:
data_dir = 'dicts'
project_name = 'Fluo-N2DH-GOWT1'

print("Project Name chosen as : {}. \nTrain-Val images-masks crops will be accessed from : {}".format(project_name, data_dir))

Project Name chosen as : Fluo-N2DH-GOWT1. 
Train-Val images-masks crops will be accessed from : dicts


### Obtain properties of the dataset



Here, we read the `data_properties.json` file prepared earlier in the 01-data notebook.


In [3]:
if os.path.isfile('data_properties.json'): 
    with open('data_properties.json') as json_file:
        data = json.load(json_file)
        object_size = int(data['min_object_size'])
        num_tracklets = int(data['mean_num_tracklets'])
        tracklet_length = int(data['mean_length_tracklet'])
        std_object_size = np.maximum(data['std_object_size_x'], data['std_object_size_y'])
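
Note that if `data_properties.json` is missing, the cell above skips the assignments silently and later cells fail with a `NameError`. A minimal sketch of a stricter loader is given below. `load_data_properties` is a hypothetical helper, not part of LineageTracer, and it assumes the same keys that the cell above reads.

```python
import json
import os

import numpy as np


def load_data_properties(path='data_properties.json'):
    """Hypothetical helper: read the properties file or fail early."""
    if not os.path.isfile(path):
        raise FileNotFoundError(
            "{} not found; run the 01-data notebook first".format(path))
    with open(path) as json_file:
        data = json.load(json_file)
    # Same keys and conversions as in the notebook cell above.
    return {
        'object_size': int(data['min_object_size']),
        'num_tracklets': int(data['mean_num_tracklets']),
        'tracklet_length': int(data['mean_length_tracklet']),
        'std_object_size': float(np.maximum(data['std_object_size_x'],
                                            data['std_object_size_y'])),
    }
```

Returning a dictionary instead of setting module-level variables is a design choice of this sketch; the rest of the notebook expects the plain variable names.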

### Specify training dataset-related parameters


Some hints:

* The `train_size` attribute sets the number of triplets that the network sees in one complete epoch. 
Each triplet consists of one anchor, one positive sample taken from the same tracklet as the anchor, and one negative sample. 

Although `train_size` can be specified explicitly, we use a heuristic to roughly estimate how many triplets exist.

In the cells that follow, `train_size` is computed and a `train_dataset_dict` dictionary is generated from the parameters specified here.

In [4]:
train_size = int(perm(len(os.listdir(os.path.join(data_dir, 'train'))), 2)/num_tracklets)
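
As a rough worked example of this heuristic: the training log in this notebook reports 27 tracklets in the `train` directory and 31 triplets per epoch. The sketch below reproduces that count under the assumption of a mean of 22 sampled tracklets. That value is inferred, not stated in the notebook, and is used only because it matches the logged numbers. The sketch also swaps in `math.perm` from the standard library (Python 3.8+) for `scipy.special.perm`, and `estimate_num_triplets` is a hypothetical name.

```python
from math import perm  # standard-library counterpart of scipy.special.perm


def estimate_num_triplets(num_items, num_tracklets):
    """Ordered pairs of items, divided by the mean number of tracklets."""
    return int(perm(num_items, 2) / num_tracklets)


# 27 entries in the train directory (from the training log);
# 22 tracklets is an assumed value, not taken from the notebook.
train_estimate = estimate_num_triplets(27, 22)  # 27 * 26 / 22, truncated
print(train_estimate)
```

With these inputs, 27 * 26 = 702 ordered pairs divided by 22 gives about 31.9, which `int` truncates to 31.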

In [5]:
train_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                         project_name = project_name,  
                                         size = train_size, 
                                         num_sampled_tracklets = num_tracklets,
                                         num_fg_points = object_size,
                                         std_object_size = std_object_size,
                                         type = 'train')

`train_dataset_dict` dictionary successfully created with: 
 -- train images accessed from dicts/Fluo-N2DH-GOWT1/train/images, 
 -- number of images per epoch equal to 31, 
 -- batch size set at 1, 


### Specify validation dataset-related parameters

Some hints:

* The `val_size` attribute sets the number of triplets that the network sees in one complete validation pass. 
Each triplet consists of one anchor, one positive sample taken from the same tracklet as the anchor, and one negative sample. 

Although `val_size` can be specified explicitly, we use a heuristic to roughly estimate how many triplets exist.

In the cells that follow, `val_size` is computed and a `val_dataset_dict` dictionary is generated from the parameters specified here.

In [6]:
val_size = int(perm(len(os.listdir(os.path.join(data_dir, 'val'))), 2)/num_tracklets)

In [7]:
val_dataset_dict = create_dataset_dict(data_dir = data_dir, 
                                       project_name = project_name,  
                                       size = val_size, 
                                       num_sampled_tracklets = num_tracklets,
                                       num_fg_points = object_size,
                                       std_object_size = std_object_size,
                                       type = 'val')

`val_dataset_dict` dictionary successfully created with: 
 -- val images accessed from dicts/Fluo-N2DH-GOWT1/val/images, 
 -- number of images per epoch equal to 21, 
 -- batch size set at 1, 


### Specify model-related parameters

Some hints:

* Set the `num_latent_channels` attribute equal to the dimensionality of the latent embedding. Set it to `0` if you are not using latent embeddings.

The `model_dict` dictionary is generated further below from the parameters specified in the next cell.


In [8]:
num_offset_channels = 2
num_intensity_channels = 1
num_latent_channels = 0
num_output_channels = 32

### Create the `model_dict` dictionary

In [13]:
model_dict = create_model_dict(num_fg_points = object_size, 
                               num_latent_channels = num_latent_channels, 
                               num_intensity_channels = num_intensity_channels, 
                               num_output_channels = num_output_channels)


`model_dict` dictionary successfully created with: 
 -- number of offset channels equal to 2, 
 -- number of intensity channels equal to 1, 
 -- number of latent channels equal to 0, 
 -- number of output channels equal to 32


### Create the `loss_dict` dictionary


In [15]:
margin = 0.2

In [16]:
loss_dict = create_loss_dict(margin=margin)

`loss_dict` dictionary successfully created with: 
 -- margin equal to 0.200
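
To give an intuition for what `margin` controls, here is a minimal sketch of a standard triplet margin loss on a single anchor, positive and negative embedding. It assumes Euclidean distance and is an illustration only; the loss implemented in LineageTracer may differ in how it computes distances and selects triplets. `triplet_margin_loss` is a hypothetical name.

```python
import numpy as np


def triplet_margin_loss(anchor, positive, negative, margin=0.2):
    """Standard triplet margin loss on single embedding vectors (illustrative)."""
    d_pos = np.linalg.norm(anchor - positive)  # anchor to positive sample
    d_neg = np.linalg.norm(anchor - negative)  # anchor to negative sample
    # The loss is zero once the negative is farther away than the
    # positive by at least `margin`.
    return max(d_pos - d_neg + margin, 0.0)


anchor = np.array([0.0, 0.0])
positive = np.array([0.3, 0.4])  # distance 0.5 from the anchor
negative = np.array([0.6, 0.0])  # distance 0.6 from the anchor
loss = triplet_margin_loss(anchor, positive, negative, margin=0.2)
```

Here the negative is only 0.1 farther from the anchor than the positive, which is less than the margin of 0.2, so the loss is about 0.1. A larger margin demands a wider separation before a triplet stops contributing to the loss.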


### Create the `configs` dictionary

Some hints:

* The `n_epochs` attribute sets the number of epochs for which training runs.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. To resume training from a previous checkpoint, point the `resume_path` attribute to that checkpoint. For example, set `resume_path = './experiment/Fluo-N2DH-GOWT1-demo/checkpoint.pth'` to resume training from the last checkpoint.

In the cell after this one, a `configs` dictionary is generated from the parameters specified here! 
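
One way to avoid pointing `resume_path` at a file that does not exist is to check for the checkpoint first. The sketch below is an assumption about how one might do this, not something LineageTracer provides. `pick_resume_path` is a hypothetical helper, and the `checkpoint.pth` filename is taken from the example path above.

```python
import os


def pick_resume_path(save_dir, filename='checkpoint.pth'):
    """Hypothetical helper: return the checkpoint path if it exists, else None."""
    candidate = os.path.join(save_dir, filename)
    return candidate if os.path.isfile(candidate) else None


# Resumes only if a checkpoint from an earlier run is present.
resume_path = pick_resume_path(os.path.join('experiment', 'Fluo-N2DH-GOWT1-demo'))
```

On a first run this evaluates to `None`, which matches the default used in the next cell.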

In [17]:
n_epochs = 5
save_dir = os.path.join('experiment', project_name+'-'+'demo')
resume_path  = None

In [18]:
configs = create_configs(n_epochs = n_epochs,
                         resume_path = resume_path, 
                         save_dir = save_dir, 
                         )

`configs` dictionary successfully created with: 
 -- n_epochs equal to 5, 
 -- save_dir equal to experiment/Fluo-N2DH-GOWT1-demo, 


Save the properties of the model in a `json` file. This will be accessed later during inference.

In [19]:
train_properties_dir = {}
train_properties_dir['num_offset_channels']=num_offset_channels
train_properties_dir['num_intensity_channels']=num_intensity_channels
train_properties_dir['num_latent_channels'] =num_latent_channels
train_properties_dir['num_output_channels'] =num_output_channels
train_properties_dir['margin']= margin
                             
with open('train_properties.json', 'w') as outfile:
    json.dump(train_properties_dir, outfile)
    print("Train config properties of the `{}` dataset is saved to `train_properties.json`".format(project_name))

Train config properties of the `Fluo-N2DH-GOWT1` dataset is saved to `train_properties.json`
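
At inference time, the saved file can be read back with `json.load`. The following is a minimal sketch under the assumption that the file contains exactly the keys written above. `load_train_properties` is a hypothetical helper and is not part of LineageTracer.

```python
import json


def load_train_properties(path='train_properties.json'):
    """Hypothetical helper: read back the properties saved during training."""
    with open(path) as infile:
        props = json.load(infile)
    # Keys written by the training notebook cell above.
    required = ('num_offset_channels', 'num_intensity_channels',
                'num_latent_channels', 'num_output_channels', 'margin')
    missing = [key for key in required if key not in props]
    if missing:
        raise KeyError("missing keys in {}: {}".format(path, missing))
    return props
```

Checking for the expected keys up front makes a stale or hand-edited file fail at load time instead of later, when the model is being built.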


### Begin Training

In [20]:
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

`train` dataloader created! Accessing data from dicts/train/
Number of tracklets in `train` directory is 27
`val` dataloader created! Accessing data from dicts/val/
Number of tracklets in `val` directory is 22


  0%|                                                    | 0/31 [00:00<?, ?it/s]

Created  loss function with: margin: 0.2
*************************
Created logger with keys:  ('train', 'val')
Starting epoch 0
learning rate: 0.0005


	addmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)
Consider using one of the following signatures instead:
	addmm_(Tensor mat1, Tensor mat2, *, Number beta, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484772347/work/torch/csrc/utils/python_arg_parser.cpp:1174.)
  embeddings.t())  # stands for add matrix multiplication in place -2*dist + embeddings*embeddings_tranposed
100%|███████████████████████████████████████████| 31/31 [00:11<00:00,  2.81it/s]
100%|███████████████████████████████████████████| 21/21 [00:02<00:00,  8.07it/s]


===> train loss: 0.073
===> val loss: 0.001


  0%|                                                    | 0/31 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 1
learning rate: 0.0004977494364660346


100%|███████████████████████████████████████████| 31/31 [00:09<00:00,  3.36it/s]
100%|███████████████████████████████████████████| 21/21 [00:02<00:00,  9.46it/s]


===> train loss: 0.015
===> val loss: 0.000


  0%|                                                    | 0/31 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 2
learning rate: 0.0004954977417064171


100%|███████████████████████████████████████████| 31/31 [00:09<00:00,  3.39it/s]
100%|███████████████████████████████████████████| 21/21 [00:02<00:00,  8.98it/s]


===> train loss: 0.007
===> val loss: 0.000


  0%|                                                    | 0/31 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 3
learning rate: 0.0004932449094349202


100%|███████████████████████████████████████████| 31/31 [00:09<00:00,  3.22it/s]
100%|███████████████████████████████████████████| 21/21 [00:02<00:00,  9.37it/s]


===> train loss: 0.004
===> val loss: 0.000


  0%|                                                    | 0/31 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 4
learning rate: 0.0004909909332982877


100%|███████████████████████████████████████████| 31/31 [00:09<00:00,  3.41it/s]
100%|███████████████████████████████████████████| 21/21 [00:02<00:00,  9.43it/s]


===> train loss: 0.005
===> val loss: 0.000
=> saving checkpoint

