Linajea Tracking Example
=====================

This example shows all steps necessary to generate the final tracks: from training the network, to finding the optimal ILP weights on the validation data, to computing the tracks on the test data.

- train network
- predict on validation data
- grid search weights for ILP
  - solve once per set of weights
  - evaluate once per set of weights
  - select set with fewest errors
- predict on test data
- solve on test data with optimal weights
- evaluate on test data

In [None]:
%load_ext autoreload
%autoreload 2
import logging
import multiprocessing
import os
import shutil
import sys
import time
import types

import numpy as np
import pandas as pd

from linajea.config import (dump_config,
                            maybe_fix_config_paths_to_machine_and_load,
                            SolveParametersConfig,
                            TrackingConfig)
from linajea.utils import getNextInferenceData
import linajea.evaluation
from linajea.process_blockwise import (extract_edges_blockwise,
                                       predict_blockwise,
                                       solve_blockwise)
from linajea.training import train
import linajea.utils

In [None]:
logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(name)s %(levelname)-8s %(message)s')

Experiment
-----------------

To start a new experiment, create a new folder and copy the configuration file(s) you want to use into this folder.
For this example we have already done this for you (`example_advanced`). Then change the current working directory to that folder.
Make sure that the file paths contained in the configuration files point to the correct destinations (i.e. that they are adapted to your directory structure) and that `config.general.setup_dir` is set to the folder you just created.
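To catch wrong paths early, the parsed configuration can be scanned for path-like entries that do not exist. This is a minimal sketch on a plain nested dict (e.g. as returned by `tomllib.load`), not linajea's own validation. Treating keys that end in `_dir`, `_file` or `filename` as paths is a hypothetical heuristic, and relative paths are resolved against the current working directory, so run it after changing into the experiment folder.

```python
import os
from typing import Any, Dict, List


def find_missing_paths(section: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return dotted keys whose value looks like a path but does not exist.

    Hypothetical heuristic: a value counts as a path if its key ends in
    "_dir", "_file" or "filename".
    """
    missing = []
    for key, value in section.items():
        dotted = f"{prefix}{key}"
        if isinstance(value, dict):
            # recurse into nested tables such as [general] or [train]
            missing.extend(find_missing_paths(value, prefix=dotted + "."))
        elif isinstance(value, list):
            # arrays of tables, e.g. several samples per dataset
            for i, item in enumerate(value):
                if isinstance(item, dict):
                    missing.extend(
                        find_missing_paths(item, prefix=f"{dotted}[{i}]."))
        elif (isinstance(value, str)
              and key.endswith(("_dir", "_file", "filename"))
              and not os.path.exists(value)):
            missing.append(dotted)
    return missing
```

For example, `with open("config.toml", "rb") as f: print(find_missing_paths(tomllib.load(f)))` would list the suspicious entries of the example configuration.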

In [None]:
setup_dir = "example_advanced"
os.chdir(setup_dir)

Setup
--------

Make sure that the `linajea` package is installed and that the correct kernel is selected in the Jupyter notebook.


### Data

Set `download_data` to `True` and execute the next cell to download a subset of: *3D+time nuclei tracking dataset of confocal fluorescence microscopy time series of C. elegans embryos* (https://zenodo.org/record/6460303).

For each step (train, validation, test) we now use multiple samples (for demonstration purposes, each 'sample' is a different set of frames from the whole dataset). For training we use frames 20-40 and frames 50-90 (excluding frames 65-75) from the `mskcc_emb1` volume, for validation we use frames 30-45 and frames 50-65 from the `mskcc_emb2` volume, and for testing we use frames 50-65 from the `mskcc_emb3` volume.

You can of course use your own data; it has to be in a compatible format (see the main README file for more information).

In [None]:
download_data = False
if download_data:
    !wget -c https://figshare.com/ndownloader/files/36792993?private_link=a5aff1bc78b005be7ea6 -O mskcc_emb1.zip -q
    !unzip mskcc_emb1.zip
    !wget -c https://figshare.com/ndownloader/files/36793002?private_link=a5aff1bc78b005be7ea6 -O mskcc_emb2.zip -q
    !unzip mskcc_emb2.zip
    !wget -c https://figshare.com/ndownloader/files/36793017?private_link=a5aff1bc78b005be7ea6 -O mskcc_emb3.zip -q
    !unzip mskcc_emb3.zip

### Database

MongoDB is used to store the computed results. A `mongod` server has to be running before the remaining cells are executed.
See https://www.mongodb.com/docs/manual/administration/install-community/ for a guide on how to install it (Linux/Windows/macOS).
Alternatively, you might want to create a Singularity image (https://github.com/singularityhub/mongo). This can be used locally, too, but is necessary if you want to run the code on an HPC cluster that does not already have a server installed.

Set `setup_databases` to `True` to add the ground truth tracks to the database server. Make sure to set `db_host` to the correct server (if you run it locally you can usually just set it to `"localhost"`). If you use different data you also have to adapt `csv_tracks_file` and `db_name`.
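Before running the cells below, it can help to check that the MongoDB server is reachable. The following is a minimal sketch using only the Python standard library; it assumes `mongod` listens on its default port 27017 and only checks that a TCP connection can be opened, not that the server is healthy. The function name `mongod_reachable` is hypothetical. (If `pymongo` is installed, `pymongo.MongoClient(db_host, serverSelectionTimeoutMS=2000).admin.command("ping")` performs a real round trip to the server.)

```python
import socket


def mongod_reachable(host: str, port: int = 27017, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port can be opened."""
    try:
        # create_connection resolves the host name and tries each address
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # covers refused connections, unknown host names and timeouts
        return False
```

After executing the next cell, `mongod_reachable(db_host)` should return `True` if the server is up.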

In [None]:
setup_databases = False
# host name of the MongoDB server; adapt to your setup (e.g. "localhost")
db_host = "h09u15"
if setup_databases:
    # (csv file with the ground truth tracks, name of the database to create)
    gt_tracks = [
        ("mskcc_emb1_fr20-40_tracks.csv",
         "linajea_mskcc_emb1_fr20-40_gt"),
        ("mskcc_emb1_fr50-90_tracks.csv",
         "linajea_mskcc_emb1_fr50-60_fr65-75_fr80-90_gt"),
        ("mskcc_emb2_fr30-45_tracks.csv",
         "linajea_mskcc_emb2_fr30-45_gt"),
        ("mskcc_emb2_fr50-65_tracks.csv",
         "linajea_mskcc_emb2_fr50-65_gt"),
        ("mskcc_emb3_fr50-65_tracks.csv",
         "linajea_mskcc_emb3_fr50-65_gt"),
    ]
    for csv_tracks_file, db_name in gt_tracks:
        linajea.utils.add_tracks_to_database(
            csv_tracks_file,
            db_name,
            db_host)

Configuration
--------------------

All parameters to control the pipeline (e.g. model architecture, data augmentation, training parameters, ILP weights) are contained in a configuration file (in the TOML format, https://toml.io).

You can use a single monolithic configuration file or separate configuration files for a subset of the steps of the pipeline, as long as the parameters required for the respective steps are there.

Familiarize yourself with the example configuration files and have a look at the documentation for the configuration to see what is needed. Most parameters have sensible defaults; usually setting the correct paths and the data configuration is all that is needed to start. See `run_simple.ipynb` for a simpler example setup that can only handle one sample/volume per dataset and that requires manual selection of the data used in the individual steps.

In this setup, `train_data`, `val_data` and `test_data` only have to be set once; depending on the processing step, the correct data is selected automatically.

In [None]:
config_file = "config.toml"
config = TrackingConfig.from_file(config_file)

Training
------------

To start training, simply pass the configuration object to the `train` function. Make sure that the training data and parameters such as the number of iterations/steps are set correctly.

Training until convergence takes from several hours to multiple days.

In [None]:
# done in child process to automatically free cuda resources
p = multiprocessing.Process(target=train, args=(config,))
p.start()
p.join()

As training until convergence takes a while, we provide a pretrained model that can be used to test the following steps of the tracking pipeline.
To use the pretrained model, set `use_pretrained` to `True` and execute the next cell:

In [None]:
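use_pretrained = False
if use_pretrained:
    # download the pretrained weights and store them as a checkpoint named
    # after the configured number of training iterations
    !wget https://figshare.com/ndownloader/files/36793554 -N -q
    shutil.copy2("36793554", f"train_net_checkpoint_{config.train.max_iterations}")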
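If several checkpoints accumulate in the experiment folder (e.g. from your own training runs), a small helper can list them. This is a minimal sketch, assuming checkpoints are stored in the current folder under exactly the `train_net_checkpoint_<iteration>` naming used in the cell above; the helper `list_checkpoints` is hypothetical.

```python
import os
import re
from typing import List

# naming scheme from the cell above: train_net_checkpoint_<iteration>
_CHECKPOINT_RE = re.compile(r"^train_net_checkpoint_(\d+)$")


def list_checkpoints(directory: str = ".") -> List[int]:
    """Return the iterations of all checkpoints found in directory, ascending."""
    iterations = []
    for name in os.listdir(directory):
        match = _CHECKPOINT_RE.match(name)
        if match:
            iterations.append(int(match.group(1)))
    return sorted(iterations)
```

Since the working directory is the experiment folder, `list_checkpoints()` would show which checkpoints are available there.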

Validation
--------------

After the training is completed we first have to determine the optimal ILP weights.
This is achieved by first creating the prediction on the validation data and then performing a grid search by solving the ILP and evaluating the results repeatedly.

`getNextInferenceData` can be used to loop over all samples in a dataset; it returns a generator.
If `validation` is set to `True` in `args`, the validation data is used, otherwise the test data. Other details (e.g. which training checkpoint to use and which database to store the results in) are determined automatically based on the configuration file. Internally it adds an `inference_data` entry that is used by the postprocessing functions such as `*_blockwise` and `evaluate_setup`. This entry is updated automatically in each iteration to point to the correct sample (and to use the correct checkpoint).
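The pattern behind such a generator can be sketched as follows. `iterate_inference_data` is a hypothetical stand-in, not linajea's implementation: it yields one copy of the configuration per sample, with an `inference_data` entry pointing at that sample, so the loop body only ever sees the current sample.

```python
import copy
import types
from typing import Iterator, List


def iterate_inference_data(config: types.SimpleNamespace,
                           samples: List[str]) -> Iterator[types.SimpleNamespace]:
    """Yield one copy of config per sample with inference_data filled in.

    Hypothetical stand-in for getNextInferenceData, for illustration only.
    """
    for sample in samples:
        # copy so that the original configuration stays untouched
        inf_config = copy.deepcopy(config)
        # point the downstream steps at the current sample
        inf_config.inference_data = types.SimpleNamespace(sample=sample)
        yield inf_config
```

Because a generator is exhausted after a single pass, each of the following cells calls `getNextInferenceData` again instead of reusing an earlier result.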

In [None]:
val_args = types.SimpleNamespace(
    config=config_file, validation=True)

### Predict Validation Data

First we predict the `cell_indicator` and `movement_vectors` on the validation data. Make sure that `args.validation` is set to `True`, then execute the next cell. The extracted maxima of the `cell_indicator` map correspond to potential cells in our candidate graph.

This command starts a number of workers (`predict.job.num_workers`) in the background; each worker tries to access a GPU, so do not start more workers than there are GPUs available. By default the workers are started locally. If you are working on a compute cluster (`lsf` supported, `slurm` and `gridengine` experimental), set `predict.job.run_on` to the respective string value; the code will then communicate with the cluster scheduler and allocate the appropriate jobs.
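For local runs, the rule "no more workers than GPUs" can be enforced programmatically. A minimal sketch, assuming the GPUs are exposed through the `CUDA_VISIBLE_DEVICES` environment variable; the helper names are hypothetical, and special values such as `-1` are not interpreted. On a cluster the scheduler allocates the GPUs, so this only makes sense with the default local setting.

```python
import os
from typing import Mapping, Optional


def count_visible_gpus(env: Optional[Mapping[str, str]] = None) -> Optional[int]:
    """Count the GPUs listed in CUDA_VISIBLE_DEVICES; None if it is unset."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    # e.g. "0,1" -> 2 devices; empty entries are ignored
    return len([d for d in value.split(",") if d.strip()])


def capped_num_workers(requested: int,
                       env: Optional[Mapping[str, str]] = None) -> int:
    """Limit the number of prediction workers to the number of visible GPUs."""
    num_gpus = count_visible_gpus(env)
    if num_gpus is None:
        # GPU count unknown: keep the requested value
        return requested
    # a result of 0 means that no GPU is visible at all
    return min(requested, num_gpus)
```

Setting `config.predict.job.num_workers = capped_num_workers(config.predict.job.num_workers)` followed by `val_args.config = dump_config(config)` would apply the limit before prediction.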

Depending on the number of workers used (see config file) and the size of the data, this can take a while. If there is no progress for some time, check the log files in `<setup_dir>/daisy_logs/linajea_prediction`!
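To find the relevant log files quickly, you can list the most recently modified ones. A minimal sketch, assuming the log files sit directly in `daisy_logs/linajea_prediction` (relative to the experiment folder, which is the current working directory here); the helper `most_recent_logs` is hypothetical.

```python
import os
from typing import List


def most_recent_logs(log_dir: str, n: int = 3) -> List[str]:
    """Return the paths of the n most recently modified files in log_dir."""
    paths = [os.path.join(log_dir, name) for name in os.listdir(log_dir)]
    # skip subdirectories, only regular files are considered
    files = [p for p in paths if os.path.isfile(p)]
    # newest first, based on the modification time
    files.sort(key=os.path.getmtime, reverse=True)
    return files[:n]
```

For example, `for path in most_recent_logs("daisy_logs/linajea_prediction"): print(path)` prints the newest log files first.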

In [None]:
for inf_config in linajea.utils.getNextInferenceData(val_args):
    predict_blockwise(inf_config)

### Extract Edges Validation Data

In the next step we extract potential edges for our candidate graph. For each cell candidate, we look for neighboring cells in the next time frame and insert an edge candidate for each of them into the database.
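The core idea of this step can be sketched with NumPy on in-memory data. This is a minimal sketch, assuming cell candidates are given as one array of positions per frame and that a plain Euclidean distance threshold decides which pairs become edge candidates; `candidate_edges` and `max_distance` are hypothetical, and the actual `extract_edges_blockwise` works block-wise on the database and may use other criteria (e.g. the predicted movement vectors).

```python
from typing import Dict, List, Tuple

import numpy as np


def candidate_edges(cells: Dict[int, np.ndarray],
                    max_distance: float) -> List[Tuple[int, int, int, int]]:
    """Return (t, i, t + 1, j) for every cell i in frame t and every cell j
    in frame t + 1 that are at most max_distance apart.

    cells maps a frame index to an array of shape (num_cells, ndim).
    """
    edges = []
    for t in sorted(cells):
        if t + 1 not in cells:
            continue
        current, following = cells[t], cells[t + 1]
        # pairwise distances, shape (len(current), len(following))
        dists = np.linalg.norm(
            current[:, None, :] - following[None, :, :], axis=-1)
        for i, j in zip(*np.nonzero(dists <= max_distance)):
            edges.append((t, int(i), t + 1, int(j)))
    return edges
```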

In [None]:
for inf_config in linajea.utils.getNextInferenceData(val_args):
    extract_edges_blockwise(inf_config)

### ILP Weights Grid Search

Together, the cell (node) and edge candidates form our candidate graph. By solving the ILP we extract tracks from this graph. However, the ILP is parameterized by a set of weights, so we first have to find the optimal values for these weights. To achieve this we perform a grid search over a predefined search space: for each set of parameter candidates we solve the ILP once on the validation data.


#### Solve on Validation Data

Make sure that `solve.grid_search` is set to `True` and that the search space (`solve.parameters_search_grid`) is defined. The parameter sets to try are then generated automatically. (Alternatively, you can set `solve.parameters` to a list of parameter sets to try manually, as in `example_basic`.)
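Expanding a search grid into parameter sets amounts to taking the Cartesian product of the value lists. The sketch below shows this with `itertools.product`; it assumes the grid is a mapping from parameter name to candidate values, uses names borrowed from the results table further below with made-up values, and is not necessarily how linajea generates the sets internally. `expand_search_grid` is hypothetical.

```python
import itertools
from typing import Any, Dict, List


def expand_search_grid(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Return one parameter dict per combination of the listed values."""
    keys = list(grid)
    return [dict(zip(keys, values))
            for values in itertools.product(*(grid[k] for k in keys))]


# made-up example values
example_grid = {
    "weight_node_score": [-0.5, -1.0],
    "track_cost": [5.0, 10.0],
    "weight_division": [-0.1],
}
parameter_sets = expand_search_grid(example_grid)
print(len(parameter_sets))  # 4
```

The number of sets is the product of the list lengths (here 2 × 2 × 1 = 4), and each set is solved and evaluated on every validation sample, so every additional value multiplies the cost of the grid search.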

In [None]:
config.solve.grid_search = True
config.solve.parameters = None
val_args.config = dump_config(config)

In [None]:
for inf_config in linajea.utils.getNextInferenceData(val_args, is_solve=True):
    linajea.process_blockwise.solve_blockwise(inf_config)

#### Evaluate on Validation Data

And as a last validation step we evaluate the performance for each set of parameter candidates.

In [None]:
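for inf_config in linajea.utils.getNextInferenceData(val_args, is_evaluate=True):
    linajea.evaluation.evaluate_setup(inf_config)
    # keep a parameter set object as a template; its values are
    # overwritten with the best ones found below
    parameters = inf_config.solve.parameters[0]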

#### Determine best ILP weights

The set of weights/parameters resulting in the best performance (fewest errors) is then used to compute the performance on the test set.
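The selection rule can be sketched without pandas as follows. It assumes each evaluation result is a dict with a `param_id` and the error counts, and that the total number of errors is the plain sum of the error columns used in the next cell (here the variant without `fp_edges`, i.e. for sparse ground truth). As in that cell, parameter sets that were not evaluated on every sample are ignored. `select_best` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Tuple

# error columns as listed in the next cell (sparse ground truth)
SCORE_COLUMNS = ["fn_edges", "identity_switches", "fp_divisions", "fn_divisions"]


def select_best(results: List[dict], num_samples: int) -> Tuple[Hashable, int]:
    """Return (param_id, total_errors) of the parameter set with the fewest
    errors summed over all samples; raises ValueError if no set is complete."""
    totals: Dict[Hashable, int] = defaultdict(int)
    counts: Dict[Hashable, int] = defaultdict(int)
    for row in results:
        totals[row["param_id"]] += sum(row[c] for c in SCORE_COLUMNS)
        counts[row["param_id"]] += 1
    # keep only sets that have a result for every sample
    complete = {p: e for p, e in totals.items() if counts[p] == num_samples}
    best = min(complete, key=complete.get)
    return best, complete[best]
```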

In [None]:
config.solve.grid_search = False
val_args.config = dump_config(config)

score_columns = ['fn_edges', 'identity_switches',
                 'fp_divisions', 'fn_divisions']
if not config.general.sparse:
    score_columns = ['fp_edges'] + score_columns

sort_by = "sum_errors"
results = {}

samples = set()
for inf_config in linajea.utils.getNextInferenceData(val_args):
    sample = inf_config.inference_data.data_source.datafile.filename
    checkpoint = inf_config.inference_data.checkpoint
    cell_score_threshold = inf_config.inference_data.cell_score_threshold
    samples.add(sample)
    print("getting results for:", sample, checkpoint, cell_score_threshold)
    res = linajea.evaluation.get_results_sorted(
        inf_config,
        filter_params={"val": True},
        score_columns=score_columns,
        sort_by=sort_by)

    res = res.assign(checkpoint=checkpoint).assign(cell_score_threshold=cell_score_threshold)
    results[(os.path.basename(sample), checkpoint, cell_score_threshold)] = res.reset_index()

results = pd.concat(list(results.values())).reset_index()
del results['param_id']
del results['_id']

by = [
    "matching_threshold",
    "weight_node_score",
    "selection_constant",
    "track_cost",
    "weight_division",
    "division_constant",
    "weight_child",
    "weight_continuation",
    "weight_edge_score",
    "checkpoint",
    "cell_score_threshold"
]
if "cell_cycle_key" in results:
    by.append("cell_cycle_key")

results = results.groupby(by, dropna=False, as_index=False).agg(
    lambda x: -1 if len(x) != len(samples) else sum(x))
results = results[results.sum_errors != -1]
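# sort by the total number of errors and reset the index so that the
# label 0 used below refers to the best parameter set
results = results.sort_values(sort_by, ascending=True).reset_index(drop=True)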

parameters.weight_node_score = float(results.at[0, 'weight_node_score'])
parameters.selection_constant = float(results.at[0, 'selection_constant'])
parameters.track_cost = float(results.at[0, 'track_cost'])
parameters.weight_edge_score = float(results.at[0, 'weight_edge_score'])
parameters.weight_division = float(results.at[0, 'weight_division'])
parameters.weight_child = float(results.at[0, 'weight_child'])
parameters.weight_continuation = float(results.at[0, 'weight_continuation'])

cell_score_threshold = float(results.at[0, 'cell_score_threshold'])
checkpoint = int(results.at[0, 'checkpoint'])
print("Best parameters:\n", parameters)
print("Best model checkpoint:\n", checkpoint)
print(f"(used cell_score_threshold : {cell_score_threshold})\n")

Test
------

Now that we know which ILP weights to use, we can create the candidate graph on the test data and compute the tracks.

First, update the configuration with the previously determined values (alternatively, set the values manually in the configuration file).
Make sure that `args.validation` is set to `False` and that `solve.grid_search` and `solve.random_search` are not set or set to `False`.

In [None]:
config.solve.parameters = [parameters]
config.solve.grid_search = False
config.solve.random_search = False

config.test_data.cell_score_threshold = cell_score_threshold
config.test_data.checkpoint = checkpoint

test_args = types.SimpleNamespace()
test_args.config = dump_config(config)
test_args.validation = False

### Predict Test Data

Next we predict the `cell_indicator` and `movement_vectors` on the test data.
If there is no progress for some time, check the log files in `<setup_dir>/daisy_logs/linajea_prediction`!

In [None]:
for inf_config in linajea.utils.getNextInferenceData(test_args):
    predict_blockwise(inf_config)

### Extract Edges on Test Data

Next, we again extract the potential edges for our candidate graph.

In [None]:
for inf_config in linajea.utils.getNextInferenceData(test_args):
    extract_edges_blockwise(inf_config)

### Solve on Test Data

Then we solve the ILP on the test data, using the weights that resulted in the lowest overall number of errors on the validation data.

In [None]:
for inf_config in linajea.utils.getNextInferenceData(test_args, is_solve=True):
    solve_blockwise(inf_config)

### Evaluate on Test Data

And finally we can evaluate the performance of our tracks.

In [None]:
for inf_config in linajea.utils.getNextInferenceData(test_args, is_evaluate=True):
    report = linajea.evaluation.evaluate_setup(inf_config)
    for k, v in report.get_short_report().items():
        print(f"\t{k: <32}: {v}")