In [1]:
import json, os
from scipy.special import perm
import numpy as np
from LineageTracer.utils.create_dicts import create_dataset_dict, create_model_dict, create_loss_dict, create_configs
from LineageTracer.train import begin_training

### Specify the path to train, val crops which we would like to train the network for:

The train and val images and masks will be accessed from the path specified by `data_dir` and `project_name`.

In [2]:
# Dataset name used for this run, and the root folder holding its train/val crops.
project_name = 'Fluo-C3DL-MDA231'
data_dir = 'dicts'

# Echo the choices so the reader can confirm them before anything is loaded.
print(
    "Project Name chosen as : {}. \nTrain-Val images-masks crops will be accessed from : {}".format(
        project_name, data_dir
    )
)

Project Name chosen as : Fluo-C3DL-MDA231. 
Train-Val images-masks crops will be accessed from : dicts


### Obtain properties of the dataset



Here, we read the `data_properties.json` file prepared in the 01-data notebook previously.


In [3]:
# Load the dataset statistics from `data_properties.json` (prepared in the
# 01-data notebook). Every later cell relies on the values read here, so a
# missing file is reported immediately instead of being skipped silently and
# surfacing later as an unrelated NameError.
properties_path = 'data_properties.json'
if not os.path.isfile(properties_path):
    raise FileNotFoundError(
        "`{}` was not found in the current working directory. "
        "Run the 01-data notebook first to generate it.".format(properties_path)
    )

with open(properties_path) as json_file:
    data = json.load(json_file)

# The JSON values are truncated to integers because they are used as counts below.
object_size = int(data['min_object_size'])
num_tracklets = int(data['mean_num_tracklets'])
tracklet_length = int(data['mean_length_tracklet'])

# Use the largest per-axis standard deviation as the single spread value.
std_object_size = np.max(np.array([data['std_object_size_x'],
                                   data['std_object_size_y'],
                                   data['std_object_size_z']]))

### Specify training dataset-related parameters


Some hints:

* The `train_size` attribute indicates the number of triplets which the network would see in one complete epoch. 
Here, the triplets would include one anchor, one positive sample taken from the same tracklet as the anchor and one negative sample. 

Although the `train_size` can be explicitly specified, we use a heuristic to assess a rough measure of how many triplets would exist.

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [4]:
# Heuristic epoch size: the number of ordered pairs of entries found in the
# train folder, divided by the mean tracklet count read from the data properties.
train_dir = os.path.join(data_dir, 'train')
num_train_entries = len(os.listdir(train_dir))
train_size = int(perm(num_train_entries, 2) / num_tracklets)

In [5]:
# Bundle the training-split settings (data location, epoch size and the
# statistics read from the data properties) into one dictionary.
train_dataset_dict = create_dataset_dict(
    data_dir=data_dir,
    project_name=project_name,
    size=train_size,
    num_sampled_tracklets=num_tracklets,
    num_fg_points=object_size,
    std_object_size=std_object_size,
    type='train',
)

`train_dataset_dict` dictionary successfully created with: 
 -- train images accessed from dicts/Fluo-C3DL-MDA231/train/images, 
 -- number of images per epoch equal to 36, 
 -- batch size set at 1, 


### Specify validation dataset-related parameters

Some hints:

* The `val_size` attribute indicates the number of triplets which the network would see in one complete epoch. 
Here, the triplets would include one anchor, one positive sample taken from the same tracklet as the anchor and one negative sample. 

Although the `val_size` can be explicitly specified, we use a heuristic to assess a rough measure of how many triplets would exist.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!

In [6]:
# Heuristic epoch size: the number of ordered pairs of entries found in the
# val folder, divided by the mean tracklet count read from the data properties.
val_dir = os.path.join(data_dir, 'val')
num_val_entries = len(os.listdir(val_dir))
val_size = int(perm(num_val_entries, 2) / num_tracklets)

In [7]:
# Bundle the validation-split settings into one dictionary; identical to the
# training call except for the epoch size and the `type` flag.
val_dataset_dict = create_dataset_dict(
    data_dir=data_dir,
    project_name=project_name,
    size=val_size,
    num_sampled_tracklets=num_tracklets,
    num_fg_points=object_size,
    std_object_size=std_object_size,
    type='val',
)

`val_dataset_dict` dictionary successfully created with: 
 -- val images accessed from dicts/Fluo-C3DL-MDA231/val/images, 
 -- number of images per epoch equal to 32, 
 -- batch size set at 1, 


### Specify model-related parameters

Some hints:

* Set the `num_latent_channels` attribute equal to the dimensionality of the latent embedding. Set this to `0` if you are not using latent embeddings.

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!


In [8]:
# Channel counts for the model; they are passed to `create_model_dict` and/or
# written to `train_properties.json` further down in this notebook.
num_offset_channels = 3
num_intensity_channels = 1
# Dimensionality of the latent embedding; 0 because latent embeddings are not used here.
num_latent_channels = 0
num_output_channels = 32

### Create the `model_dict` dictionary

In [9]:
# Assemble the network configuration from the channel counts chosen in the previous cell.
# NOTE(review): `num_intensity_channels` is defined in this notebook but is not
# forwarded here — confirm that create_model_dict's default matches it.
model_dict = create_model_dict(
    num_fg_points=object_size,
    num_offset_channels=num_offset_channels,
    num_latent_channels=num_latent_channels,
    num_output_channels=num_output_channels,
)


`model_dict` dictionary successfully created with: 
 -- number of offset channels equal to 3, 
 -- number of intensity channels equal to 1, 
 -- number of latent channels equal to 0, 
 -- number of output channels equal to 32


### Create the `loss_dict` dictionary


In [10]:
# Margin handed to `create_loss_dict` and later stored in `train_properties.json`.
margin = 0.2

In [11]:
# Build the loss configuration from the margin chosen in the previous cell.
loss_dict = create_loss_dict(margin=margin)

`loss_dict` dictionary successfully created with: 
 -- margin equal to 0.200


### Create the `configs` dictionary

Some hints:

* The `n_epochs` attribute determines how long the training should proceed.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. To resume training from a previous checkpoint, set the `resume_path` attribute accordingly. For example, set `resume_path = './experiment/Fluo-C3DL-MDA231-demo/checkpoint.pth'` to resume training from the last checkpoint.

In the cell after this one, a `configs` dictionary is generated from the parameters specified here! 

In [12]:
# How many epochs the training should run for.
n_epochs = 20

# None here; point it at a saved checkpoint file to resume a previous run.
resume_path = None

# Checkpoints and loss-curve details for this run are written under this folder.
save_dir = os.path.join('experiment', project_name + '-' + 'demo')

In [13]:
# Collect the run-level settings (epoch budget, resume and save locations) into `configs`.
configs = create_configs(
    n_epochs=n_epochs,
    resume_path=resume_path,
    save_dir=save_dir,
)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 20, 
 -- save_dir equal to experiment/Fluo-C3DL-MDA231-demo, 


Save the properties of the model in a `json` file. This will be accessed later during inference.

In [14]:
# Gather the model and loss hyper-parameters so they can be read back later
# during inference.
train_properties_dir = {
    'num_offset_channels': num_offset_channels,
    'num_intensity_channels': num_intensity_channels,
    'num_latent_channels': num_latent_channels,
    'num_output_channels': num_output_channels,
    'margin': margin,
}

# Persist them as JSON in the working directory and confirm to the reader.
with open('train_properties.json', 'w') as outfile:
    json.dump(train_properties_dir, outfile)
    print("Train config properties of the `{}` dataset is saved to `train_properties.json`".format(project_name))

Train config properties of the `Fluo-C3DL-MDA231` dataset is saved to `train_properties.json`


### Begin Training

In [15]:
# Launch training with the dataset, model, loss and run configurations built in the cells above.
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

`train` dataloader created! Accessing data from dicts/train/
Number of tracklets in `train` directory is 33
`val` dataloader created! Accessing data from dicts/val/
Number of tracklets in `val` directory is 31


  0%|                                                    | 0/36 [00:00<?, ?it/s]

Created  loss function with: margin: 0.2
*************************
Created logger with keys:  ('train', 'val')
Starting epoch 0
learning rate: 0.0005


	addmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)
Consider using one of the following signatures instead:
	addmm_(Tensor mat1, Tensor mat2, *, Number beta, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484772347/work/torch/csrc/utils/python_arg_parser.cpp:1174.)
  embeddings.t())  # stands for add matrix multiplication in place -2*dist + embeddings*embeddings_tranposed
100%|███████████████████████████████████████████| 36/36 [00:12<00:00,  2.85it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.27it/s]


===> train loss: 0.076
===> val loss: 0.009


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 1
learning rate: 0.0004977494364660346


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.21it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.26it/s]


===> train loss: 0.006
===> val loss: 0.005


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 2
learning rate: 0.0004954977417064171


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.19it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.34it/s]


===> train loss: 0.001
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 3
learning rate: 0.0004932449094349202


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.18it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.35it/s]


===> train loss: 0.001
===> val loss: 0.006


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 4
learning rate: 0.0004909909332982877


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.17it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.44it/s]


===> train loss: 0.001
===> val loss: 0.003


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 5
learning rate: 0.0004887358068751748


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.18it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.33it/s]


===> train loss: 0.000
===> val loss: 0.004


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 6
learning rate: 0.0004864795236750653


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.25it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.34it/s]


===> train loss: 0.000
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 7
learning rate: 0.00048422207713716544


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.16it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.45it/s]


===> train loss: 0.000
===> val loss: 0.003


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 8
learning rate: 0.00048196346062927547


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.12it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  9.98it/s]


===> train loss: 0.000
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 9
learning rate: 0.00047970366744663594


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.26it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  9.89it/s]


===> train loss: 0.000
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 10
learning rate: 0.00047744269081074987


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.18it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.21it/s]


===> train loss: 0.000
===> val loss: 0.001


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 11
learning rate: 0.0004751805238681794


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.07it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.03it/s]


===> train loss: 0.000
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 12
learning rate: 0.000472917159689316


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.18it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.03it/s]


===> train loss: 0.000
===> val loss: 0.001


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 13
learning rate: 0.00047065259126712457


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.23it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.19it/s]


===> train loss: 0.000
===> val loss: 0.003


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 14
learning rate: 0.00046838681151585874


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.16it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  9.84it/s]


===> train loss: 0.000
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 15
learning rate: 0.0004661198132697498


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.18it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.02it/s]


===> train loss: 0.000
===> val loss: 0.004


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 16
learning rate: 0.0004638515892816641


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.02it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00, 10.14it/s]


===> train loss: 0.000
===> val loss: 0.002


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 17
learning rate: 0.00046158213222173284


100%|███████████████████████████████████████████| 36/36 [00:12<00:00,  2.97it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  9.73it/s]


===> train loss: 0.000
===> val loss: 0.003


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 18
learning rate: 0.0004593114346759497


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.11it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  9.79it/s]


===> train loss: 0.000
===> val loss: 0.003


  0%|                                                    | 0/36 [00:00<?, ?it/s]

=> saving checkpoint
Starting epoch 19
learning rate: 0.00045703948914473726


100%|███████████████████████████████████████████| 36/36 [00:11<00:00,  3.21it/s]
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  9.85it/s]


===> train loss: 0.000
===> val loss: 0.004
=> saving checkpoint


<Figure size 640x480 with 0 Axes>