# Simple training for 3d segmentation


In [2]:
from trainlib.trainer import SegmentationTrainer
from trainlib.utils import load_config

## Config

In `trainlib`, everything is built around a config file. 
This makes training portable, e.g. between a workstation and a computing cluster. 
To change the training, adapt the config. 
Understanding the logic of the config and how it influences the training is therefore crucial. 
As an example, we use a config file from the unit tests. 

In [32]:
config = load_config("../tests/test_config_segm.yaml")

```yaml
data:
  data_dir: ../data/
  train_csv: ../data/test_data_train_3d_segm.csv
  valid_csv: ../data/test_data_valid_3d_segm.csv
  test_csv: ../data/test_data_valid_3d_segm.csv
  image_cols: [image]
  label_cols: [label]
  train: True
  valid: True
  test: False
  dataset_type: persistent
  cache_dir: .monai-cache
  batch_size: 1
```
`trainlib` uses CSV files as the interface for filenames. The key concept is that the CSV provides the filename of each data item, relative to `data_dir`. 
In this example, `data_dir` is `../data/` (absolute paths are recommended) and the first filename in `train_csv` is `images/radiopaedia_10_85902_1.nii.gz`. 
So `trainlib` will try to load the file from `../data/images/radiopaedia_10_85902_1.nii.gz`. 
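
The path resolution described above can be sketched with the standard library alone. This is only an illustration, not `trainlib`'s actual loading code: the CSV content is an in-memory stand-in, and the label filename in it is an assumption.

```python
import csv
import io
from pathlib import PurePosixPath

# In-memory stand-in for `train_csv`. The column names mirror `image_cols`
# and `label_cols` from the config; the label filename is an assumption.
csv_text = (
    "image,label\n"
    "images/radiopaedia_10_85902_1.nii.gz,labels/radiopaedia_10_85902_1.nii.gz\n"
)

data_dir = PurePosixPath("../data/")
image_cols = ["image"]

rows = list(csv.DictReader(io.StringIO(csv_text)))
# Join every image column of every row with `data_dir`.
image_paths = [data_dir / row[col] for row in rows for col in image_cols]

print(image_paths[0])  # ../data/images/radiopaedia_10_85902_1.nii.gz
```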

`image_cols` gives the column name(s) in the respective train/valid/test CSV. Multiple names are possible. `label_cols` is treated the same way.  
`train: True` means that a train dataloader is constructed from `train_csv`.  
`dataset_type: persistent` makes `trainlib` use `monai.data.PersistentDataset`, a dataset that caches data to disk for a significant speedup during training. 
Especially for 3D data, I/O can become a major bottleneck of the training.   
`cache_dir` is the directory `trainlib` caches to. It is not deleted after training and can get quite large!  
`batch_size` is the batch size used during training. 
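
Since the cache is not cleaned up automatically, it can be useful to check its size from time to time. The helper below is a small sketch using only the standard library; `directory_size_mb` is a hypothetical name and not part of `trainlib`.

```python
from pathlib import Path


def directory_size_mb(path):
    """Return the total size of all files below `path` in megabytes."""
    root = Path(path)
    if not root.exists():
        return 0.0
    total_bytes = sum(f.stat().st_size for f in root.rglob("*") if f.is_file())
    return total_bytes / 1e6


# `.monai-cache` is the `cache_dir` from the config above.
print(f"{directory_size_mb('.monai-cache'):.1f} MB")
```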

```yaml
seed: 42
debug: false
device: cuda:0
run_id: runs/heart-devices
overwrite: true
log_dir: logs
out_dir: output
ndim: 2
model_dir: models
```

`seed` is the random seed used throughout `trainlib`.  
`debug` toggles the debug mode that `trainlib` provides.  
`device` sets the hardware accelerator. `cpu` is also possible. Parallel training is not yet supported.   
`run_id` is the ID of the training run. It becomes the folder in which everything is stored. Using `runs/some-name` is recommended, so that `runs` can conveniently be added to `.gitignore`.  
`overwrite` controls whether an existing `run_id` is overwritten or a new one is created for each run.  
`log_dir` writes logs to `run_id/log_dir`. Logs are loss, metrics and snapshots from the training.  
`out_dir` writes output to `run_id/out_dir`. Outputs are metrics. For segmentation, these are Dice coefficient, Hausdorff distance and surface distance.  
`model_dir` is the directory where `trainlib` places checkpoints (and from which it also tries to load them).  
`ndim` is the dimensionality of the data. 2 for 2D and 3 for 3D are supported.   
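
To make the folder layout concrete, the following sketch derives the log and output locations from the values in the config above. It only illustrates the `run_id/log_dir` and `run_id/out_dir` convention stated here; the plain dictionary is a stand-in for the loaded config.

```python
from pathlib import PurePosixPath

# Stand-in for the relevant part of the loaded config.
config_demo = {"run_id": "runs/heart-devices", "log_dir": "logs", "out_dir": "output"}

run_dir = PurePosixPath(config_demo["run_id"])
log_path = run_dir / config_demo["log_dir"]
out_path = run_dir / config_demo["out_dir"]

print(log_path)  # runs/heart-devices/logs
print(out_path)  # runs/heart-devices/output
```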

```yaml
loss:
  DiceLoss:
    include_background: true
    softmax: true
    to_onehot_y: true
optimizer:
  Adam:
    lr: 0.01
    weight_decay: 0.001
lr_scheduler:
  OneCycleLR:
    max_lr: 0.0001
model:
  UNet:
    act: PRELU
    channels: [16, 32]
    dropout: 0.1
    norm: BATCH
    num_res_units: 1
    out_channels: 2 # bg, label
    strides: [2]
```
Loss function and optimizer are taken directly from `monai.losses` / `torch.nn` and `monai.optimizers` / `torch.optim`, respectively.   
Models are loaded from `monai.networks.nets`.  
All arguments for `model` are passed directly to the MONAI class, except `in_channels`, which is derived automatically from the number of input images (`len(data.image_cols)`).  
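
How the model arguments are assembled can be sketched as follows. This is a minimal illustration under the assumption that the config behaves like a nested dictionary; `build_model_kwargs` is a hypothetical helper, not a `trainlib` function.

```python
def build_model_kwargs(config):
    """Combine the user-supplied model arguments with the derived `in_channels`."""
    # The `model` section holds exactly one entry: class name -> arguments.
    model_name, user_kwargs = next(iter(config["model"].items()))
    kwargs = dict(user_kwargs)
    # One input channel per image column.
    kwargs["in_channels"] = len(config["data"]["image_cols"])
    return model_name, kwargs


config_demo = {
    "data": {"image_cols": ["image"]},
    "model": {"UNet": {"channels": [16, 32], "strides": [2], "out_channels": 2}},
}

model_name, model_kwargs = build_model_kwargs(config_demo)
print(model_name, model_kwargs["in_channels"])  # UNet 1
```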

```yaml
training:
  early_stopping_patience: 10
  max_epochs: 25
```
`early_stopping_patience` controls how long `trainlib` tolerates a key metric that does not improve before training is stopped early.  
`max_epochs` is the maximum number of training epochs.  
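
The interplay of the two settings can be illustrated with a short, library-independent sketch. It assumes that patience is counted in epochs and that a higher key metric is better (as for Dice); the exact bookkeeping inside `trainlib` may differ.

```python
def stopping_epoch(metric_per_epoch, patience, max_epochs):
    """Return the (1-based) epoch after which training would stop."""
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, metric in enumerate(metric_per_epoch[:max_epochs], start=1):
        if metric > best:
            best = metric
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return epoch  # stopped early
    return min(len(metric_per_epoch), max_epochs)


# Hypothetical validation Dice values, one per epoch.
dice = [0.10, 0.20, 0.30, 0.30, 0.29, 0.28]
print(stopping_epoch(dice, patience=2, max_epochs=25))  # 5
```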

```yaml
transforms:
  base:
    LoadImaged:
      allow_missing_keys: true
    EnsureChannelFirstd:
      allow_missing_keys: true
    Spacingd:
      pixdim: [2, 2, 2]
      mode: [bilinear, nearest]
  train:
    Identityd:
  valid:
    Identityd:
  postprocessing:
    Identityd:
  prob: 0.1
```
`trainlib` uses different transform pipelines: 

- `base` is always applied to the data. Use this for I/O and normalization. 
- `train` is only applied to the training data. Use this for augmentations. 
- `valid` is only applied to the validation data.  
- `postprocessing` is applied to the labels and predictions, after the loss but before the metrics are calculated. 

`prob` controls the probability with which each transform is applied. 
If not stated explicitly with `key:`, each transform is applied to each item in the data. 
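
Conceptually, the pipelines are chained per data split. The sketch below uses dummy transforms that merely record their name. It assumes that `base` runs before the split-specific transforms; the names `compose` and `step` are illustrative and not part of `trainlib` or MONAI.

```python
def compose(*transforms):
    """Chain transforms so that each one receives the previous output."""
    def pipeline(item):
        for transform in transforms:
            item = transform(item)
        return item
    return pipeline


def step(name):
    """Create a dummy transform that appends its name to the item."""
    return lambda item: item + [name]


base = [step("LoadImaged"), step("EnsureChannelFirstd"), step("Spacingd")]
train = [step("train-augmentation")]  # placeholder for real augmentations
valid = [step("Identityd")]

train_pipeline = compose(*base, *train)
valid_pipeline = compose(*base, *valid)

print(train_pipeline([]))
print(valid_pipeline([]))
```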

```yaml
patch:
  transforms: /path/to/custom/transforms.py
  model: /path/to/custom/model.py
  loss: /path/to/custom/loss.py
  optimizer: /path/to/custom/optimizer.py
```
Sometimes custom code is needed for transforms, models, losses or optimizers. For this, `trainlib` provides a patch functionality. 
`trainlib` will try to load transforms/models/losses/optimizers from the patch file first, then fall back to MONAI/torch. This way, existing transforms can be overridden or custom transforms can be added. 
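
The "patch first, then fall back" lookup can be sketched as a search through an ordered list of namespaces. The namespaces below are simple stand-ins for a patch module and a library module, and `resolve` is a hypothetical helper; the real implementation may work differently.

```python
from types import SimpleNamespace


def resolve(name, namespaces):
    """Return the first attribute called `name` found in `namespaces`."""
    for namespace in namespaces:
        if hasattr(namespace, name):
            return getattr(namespace, name)
    raise AttributeError(f"{name!r} not found in any namespace")


# Stand-ins: a patch file that overrides `DiceLoss` and a library module.
patch = SimpleNamespace(DiceLoss="custom DiceLoss")
library = SimpleNamespace(DiceLoss="library DiceLoss", FocalLoss="library FocalLoss")

print(resolve("DiceLoss", [patch, library]))   # custom DiceLoss
print(resolve("FocalLoss", [patch, library]))  # library FocalLoss
```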


```yaml
pushover_credentials: /path/to/.pushover_credentials.yaml
```
Provide a file with pushover credentials to get updates on your mobile device. 

## Simple training and data inspection
For this project, the `SegmentationTrainer` is used and the class is initialized from the config file. 
It is possible to override arguments in the config before passing it to the trainer. 
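
One way to apply such overrides is a small recursive update, sketched below on a plain nested dictionary. Whether the object returned by `load_config` supports exactly this access pattern is an assumption; `with_overrides` is a hypothetical helper.

```python
import copy


def with_overrides(config, overrides):
    """Return a deep copy of `config` with nested `overrides` applied."""
    result = copy.deepcopy(config)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = with_overrides(result[key], value)
        else:
            result[key] = value
    return result


# Stand-in for a loaded config.
config_demo = {
    "device": "cuda:0",
    "training": {"max_epochs": 25, "early_stopping_patience": 10},
}

# Shorten the run and switch to CPU, e.g. for a quick smoke test.
quick_config = with_overrides(config_demo, {"device": "cpu", "training": {"max_epochs": 1}})
print(quick_config)
```

The original dictionary stays untouched, so the same base config can be reused for several variants.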

In [7]:
trainer = SegmentationTrainer(config=config)

Setuptools is replacing distutils.


2022-11-18 10:21:33,727 - No pushover credentials file submitted, will not try to push trainings progress to pushover device. If you want to receive status updated via pushover, provide the path to a yaml file, containing the `app_token`, `user_key` and `proxies` (optional) in the config at `pushover_credentials`


With `show_batch`, `trainlib` provides a tool for quick visualization of the data. The data shown here is exactly what is passed to the model. 
Because `ipywidgets` are used, the output is interactive and no longer visible once the notebook is shut down. Masks can be toggled on/off and intensities can be changed. 
```python
trainer.data_loader.show_batch()
```
![Show Batch Example](figures/example-show-batch.png)

`show_batch` uses `trainlib.viewer.ListViewer` to visualize the output. This class, as well as related classes, can also be used directly in interactive sessions. 
MONAI re-orients images to comply with the NIfTI header. It might therefore be necessary to re-arrange the array before viewing. Here, this is done by transposing the array. 

In [25]:
from trainlib.viewer import BasicViewer, ListViewer, DicomExplorer
from monai.transforms import LoadImage

reader = LoadImage(image_only=True)
image_1 = reader("../data/images/radiopaedia_10_85902_1.nii.gz").transpose(0, 2)
image_2 = reader("../data/images/radiopaedia_10_85902_3.nii.gz").transpose(0, 2)

label_1 = reader("../data/labels/radiopaedia_10_85902_1.nii.gz").transpose(0, 2)
label_2 = reader("../data/labels/radiopaedia_10_85902_3.nii.gz").transpose(0, 2)

`BasicViewer` can show a single image, with mask overlay or classification label. An optional description can also be added to the plot. 

```python
BasicViewer(image_1, label_1, description="A CT showing COVID Pneumonia").show()
```
![Basic Viewer Example](figures/example-basic-viewer-3d.png)

The `DicomExplorer` class provides additional information about the pixel/voxel values in the image array. 

```python
DicomExplorer(image_1, label_1, description="A CT showing COVID Pneumonia").show()
```
![Dicom Explorer Example](figures/example-dicom-explorer-3d.png)

Lastly, the `ListViewer` class allows viewing multiple arrays at once. 
Here, each display is an individual instance of `BasicViewer`, so it is possible to mix images with/without masks/labels/descriptions and also show 2d and 3d images at the same time. 

```python
ListViewer([image_1, image_2], [label_1, label_2]).show()
```

![List Viewer Example](figures/example-list-viewer-3d.png)

## Last check before training
With `sanity_check`, `trainlib` also provides a tool to test whether all transforms can be applied without errors. Simple summary statistics about the labels are provided after the check. This way, you may catch errors early and not at the end of a two-hour epoch. 

In [31]:
trainer.data_loader.sanity_check()
trainer.evaluator.data_loader.sanity_check()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  4.30it/s]

2022-11-18 10:43:05,396 - Frequency of label values:
2022-11-18 10:43:05,397 - Value 0.0 appears in 3 items in the dataset
2022-11-18 10:43:05,398 - Value 1.0 appears in 3 items in the dataset



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.59it/s]

2022-11-18 10:43:05,961 - Frequency of label values:
2022-11-18 10:43:05,962 - Value 0.0 appears in 2 items in the dataset
2022-11-18 10:43:05,963 - Value 1.0 appears in 2 items in the dataset
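
The label summary in the log counts in how many items each label value occurs. A comparable statistic can be computed with a few lines of plain Python. This is only a sketch with made-up, flattened label arrays, not the code `sanity_check` runs internally.

```python
from collections import Counter


def label_frequency(label_arrays):
    """Count in how many items each label value occurs at least once."""
    counts = Counter()
    for labels in label_arrays:
        counts.update(set(labels))  # each value counts once per item
    return counts


# Three hypothetical items with flattened label values.
items = [[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]

for value, n_items in sorted(label_frequency(items).items()):
    print(f"Value {value} appears in {n_items} items in the dataset")
```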





## Training
After everything has been configured in the YAML file and data was checked, start training simply with:

In [12]:
trainer.run()



2022-11-18 10:24:20,886 - Engine run resuming from iteration 0, epoch 0 until 1 epochs
2022-11-18 10:24:28,525 - Epoch: 1/1, Iter: 1/3 -- train_loss: 0.6676 


[1/3]  33%|###3       [00:00<?]

2022-11-18 10:24:29,236 - Epoch: 1/1, Iter: 2/3 -- train_loss: 0.6667 
2022-11-18 10:24:29,421 - Epoch: 1/1, Iter: 3/3 -- train_loss: 0.6548 
2022-11-18 10:24:29,426 - Engine run resuming from iteration 0, epoch 0 until 1 epochs


[1/2]  50%|#####      [00:00<?]

2022-11-18 10:24:39,147 - Got new best metric of val_mean_dice: 0.008058685809373856
2022-11-18 10:24:39,708 - Epoch[1] Complete. Time taken: 00:00:10.183
2022-11-18 10:24:39,709 - Engine run complete. Time taken: 00:00:10.282
2022-11-18 10:24:49,717 - Key metric: None best value: -1 at epoch: -1
2022-11-18 10:24:49,718 - Key metric: None best value: -1 at epoch: -1
2022-11-18 10:24:49,719 - Epoch[1] Complete. Time taken: 00:00:28.746
2022-11-18 10:24:49,721 - Engine run complete. Time taken: 00:00:28.834


`trainlib` is not primarily designed for notebooks. While data analysis and checking should be done in notebooks, the final training is better carried out using a simple training script. 

## Logging

`trainlib` uses logging, which allows controlling the verbosity of the trainer. By default, the log level is `INFO`. 
The loggers can be accessed at `trainer.logger` and `trainer.evaluator.logger`.
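
Assuming these loggers are standard `logging.Logger` objects, the verbosity can be changed with `setLevel`. The sketch below uses a stand-in logger so that it runs without `trainlib`; with a real trainer, the same call would be made on `trainer.logger`.

```python
import logging

logging.basicConfig(format="%(asctime)s - %(message)s")

# Stand-in for `trainer.logger` (assumed to be a standard `logging.Logger`).
logger = logging.getLogger("trainlib-demo")

# Only show warnings and errors, hide the per-iteration INFO messages.
logger.setLevel(logging.WARNING)

logger.info("Epoch: 1/1, Iter: 1/3")   # suppressed
logger.warning("Something looks off")  # still shown
```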