# Segmentation Example
> Train a U-Net for pixelwise segmentation of the prostate

In [1]:
import monai
import ignite
import yaml

from prostate158.utils import load_config
from prostate158.data import segmentation_dataloaders
from prostate158.train import SegmentationTrainer
from prostate158.report import ReportGenerator
from prostate158.viewer import ListViewer

All parameters needed for training and evaluation are set in `anatomy.yaml` file. 

In [2]:
config = load_config('anatomy.yaml') # change to 'tumor.yaml' for tumor segmentation
monai.utils.set_determinism(seed=config.seed)

Create supervised trainer for segmentation task

In [3]:
trainer=SegmentationTrainer(
    progress_bar=True, 
    early_stopping = True, 
    metrics = ["MeanDice", "HausdorffDistance", "SurfaceDistance"],
    save_latest_metrics = True,
    config=config
)

In [4]:
test_dl = segmentation_dataloaders(config=config, train=False, valid=False, test=True)
trainer.evaluate(
    checkpoint='models/anatomy.pt',
    dataloader=test_dl
)

2022-12-06 21:36:33,475 - Engine run resuming from iteration 0, epoch 0 until 1 epochs
2022-12-06 21:36:54,706 - Current run is terminating due to exception: y_pred and y should have same shapes, got (1, 2, 384, 384, 163) and (1, 2, 139, 122, 148).
2022-12-06 21:36:54,713 - Exception: y_pred and y should have same shapes, got (1, 2, 384, 384, 163) and (1, 2, 139, 122, 148).
Traceback (most recent call last):
  File "/home/cosminciausu/miniconda3/envs/prostate158/lib/python3.8/site-packages/ignite/engine/engine.py", line 1069, in _run_once_on_dataset_as_gen
    self._fire_event(Events.ITERATION_COMPLETED)
  File "/home/cosminciausu/miniconda3/envs/prostate158/lib/python3.8/site-packages/ignite/engine/engine.py", line 425, in _fire_event
    func(*first, *(event_args + others), **kwargs)
  File "/home/cosminciausu/miniconda3/envs/prostate158/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/cosminciausu/m

ValueError: y_pred and y should have same shapes, got (1, 2, 384, 384, 163) and (1, 2, 139, 122, 148).

Adding a learning rate scheduler for one-cylce policy. 

In [None]:
trainer.fit_one_cycle()

Let's train. This can take several hours. 

In [None]:
trainer.run()

Finish the training with final evaluation of the best model. To allow visualization of all outputs, add OutputStore handler first. Otherwise only output form the last epoch will be accessible. 

In [None]:
eos_handler = ignite.handlers.EpochOutputStore()
eos_handler.attach(trainer.evaluator, 'output')

In [None]:
trainer.test(checkpoint='models/anatomy.pt')

Generate a markdown document with segmentation results

In [None]:
report_generator=ReportGenerator(
    config.run_id, 
    config.out_dir, 
    config.log_dir
)
report_generator.generate_report()

Have a look at some outputs

In [None]:
output = trainer.evaluator.state.output
keys = ['image', 'label', 'pred']
outputs = {k : [o[0][k].detach().cpu().squeeze() for o in output] for k in keys}

In [None]:
ListViewer(
    [o.transpose(0,2).flip(-2) for o in outputs['image'][0:3]] + 
    [o.argmax(0).transpose(0,2).flip(-2).float() for o in outputs['label'][0:3]] + 
    [o.argmax(0).transpose(0,2).flip(-2).float() for o in outputs['pred'][0:3]] 
).show()