# Data extraction and assembly

This example shows how to use the `pymia.data` package to extract chunks of data from a dataset and feed them
to a deep neural network. It also shows how the predicted chunks are assembled back into full-image predictions.

The extraction-assembly principle is essential for large three-dimensional images that do not fit entirely into GPU memory
and thus require a patch-based approach.

For simplicity, we use slice-wise extraction in this example, meaning that two-dimensional slices are extracted
from the three-dimensional images.
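
The idea behind slice-wise indexing can be sketched without pymia: every two-dimensional slice of every subject becomes one sample. The helper `slice_indices` and the volume shapes below are hypothetical and only illustrate the principle; they are not how `extr.SliceIndexing` is implemented.

```python
from typing import Iterator, List, Tuple


def slice_indices(volume_shapes: List[Tuple[int, int, int]], axis: int = 0) -> Iterator[Tuple[int, int]]:
    """Yield one (subject_index, slice_index) pair per 2-D slice along `axis`."""
    for subject_index, shape in enumerate(volume_shapes):
        for slice_index in range(shape[axis]):
            yield subject_index, slice_index


# Two hypothetical subjects with 3 and 2 slices along the first axis
shapes = [(3, 4, 4), (2, 4, 4)]
indices = list(slice_indices(shapes))
print(len(indices))  # total number of 2-D chunks over all subjects
print(indices)
```

Each pair identifies one chunk, so the dataset length equals the total number of slices over all subjects.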

The example uses PyTorch as a deep learning (DL) framework. The minimal adaptations needed for TensorFlow are shown at the end.

The Jupyter notebook can be found at [./examples/data/extraction_assembling.ipynb](https://github.com/rundherum/pymia/blob/master/examples/data/extraction_assembling.ipynb).

<div class="alert alert-info">

Note

To be able to run this example:

- Get the example data by executing [./examples/example-data/pull_example_data.py](https://github.com/rundherum/pymia/blob/master/examples/example-data/pull_example_data.py).
- Install torch (`pip install torch`).

</div>


Import the required modules.


In [1]:
import pymia.data.assembler as assm
import pymia.data.transformation as tfm
import pymia.data.definition as defs
import pymia.data.extraction as extr
import pymia.data.backends.pytorch as pymia_torch

First, we create the access to the .h5 dataset by defining: (i) the indexing strategy (`indexing_strategy`)
that defines the chunks of data to be retrieved, (ii) the information to be extracted (`extractor`), and (iii)
the transformation (`transform`) to be applied after extraction.

The permutation transform is required since the channels (here _T1_, _T2_) are stored in the last dimension in the .h5 dataset
but PyTorch requires channel-first format.

In [2]:
hdf_file = '../example-data/example-dataset.h5'

# Data extractor for extracting the "images" entries
extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES,))
# Permutation transform to go from HWC to CHW.
transform = tfm.Permute(permutation=(2, 0, 1), entries=(defs.KEY_IMAGES,))
# Indexing defining a slice-wise extraction of the data
indexing_strategy = extr.SliceIndexing()

dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, transform)
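
To make the effect of the permutation `(2, 0, 1)` concrete, here is a minimal sketch that uses nested lists instead of arrays. The helpers `shape_of` and `permute` are hypothetical and written only for this illustration; `permute` follows the same axis convention as `numpy.transpose`.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a regular nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def permute(image, permutation):
    """Reorder the axes of a 3-D nested list: new axis i is old axis permutation[i]."""
    old_shape = shape_of(image)
    new_shape = tuple(old_shape[axis] for axis in permutation)
    out = [[[None] * new_shape[2] for _ in range(new_shape[1])] for _ in range(new_shape[0])]
    for new_index in product(*(range(n) for n in new_shape)):
        old_index = [0, 0, 0]
        for new_axis, old_axis in enumerate(permutation):
            old_index[old_axis] = new_index[new_axis]
        i, j, k = new_index
        a, b, c = old_index
        out[i][j][k] = image[a][b][c]
    return out


# Each element stores its own (h, w, c) position so the reordering is easy to follow
height, width, channels = 2, 3, 2
hwc = [[[(h, w, c) for c in range(channels)] for w in range(width)] for h in range(height)]

chw = permute(hwc, (2, 0, 1))       # channel-last -> channel-first
restored = permute(chw, (1, 2, 0))  # and back again
print(shape_of(hwc), shape_of(chw))
```

In the prediction loop further down, the transpose `(0, 2, 3, 1)` plays the same inverse role as `(1, 2, 0)` here, with an additional leading batch axis.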


Next, we define an assembler that puts the data/image chunks back together after the input chunks have been predicted. This is
required to evaluate entire subjects and for any further processing, such as saving the predictions.

We also define extractors for information that is required after prediction. This information does not need
to be chunked (indexed/sliced) and does not need to interact with the DL framework. Thus, it can be extracted
directly from the dataset.


In [3]:
assembler = assm.SubjectAssembler(dataset)

direct_extractor = extr.ComposeExtractor([
    extr.ImagePropertiesExtractor(),  # Extraction of image properties (origin, spacing, etc.) for storage
    extr.DataExtractor(categories=(defs.KEY_LABELS,))  # Extraction of "labels" entries for evaluation
])
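
The assembling principle can be illustrated with a small, framework-free sketch. `ToySliceAssembler` is a hypothetical class written for this illustration and is not the implementation of `assm.SubjectAssembler`: it assumes explicit `(subject, slice)` keys and a known slice count per subject, whereas pymia works with sample indices and a last-batch flag.

```python
from typing import Dict, List, Tuple


class ToySliceAssembler:
    """Collects predicted slices per subject and reports subjects that are complete."""

    def __init__(self, slices_per_subject: List[int]):
        self.slices_per_subject = slices_per_subject
        self._partial: Dict[int, Dict[int, object]] = {}
        self._ready: Dict[int, list] = {}

    def add_batch(self, predictions: list, sample_keys: List[Tuple[int, int]]) -> None:
        for prediction, (subject, slice_index) in zip(predictions, sample_keys):
            slices = self._partial.setdefault(subject, {})
            slices[slice_index] = prediction
            if len(slices) == self.slices_per_subject[subject]:
                # All slices have arrived: stack them in slice order
                self._ready[subject] = [slices[i] for i in range(len(slices))]
                del self._partial[subject]

    @property
    def subjects_ready(self) -> List[int]:
        # Return a copy so callers can retrieve subjects while iterating
        return list(self._ready)

    def get_assembled_subject(self, subject: int) -> list:
        return self._ready.pop(subject)


# Two hypothetical subjects with 3 and 2 slices, predicted in batches of 2
toy_assembler = ToySliceAssembler(slices_per_subject=[3, 2])
keys = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
batch_size = 2
completed = []
for start in range(0, len(keys), batch_size):
    batch_keys = keys[start:start + batch_size]
    batch_predictions = [f"pred-{s}-{z}" for s, z in batch_keys]  # stand-in for network output
    toy_assembler.add_batch(batch_predictions, batch_keys)
    for subject in toy_assembler.subjects_ready:
        completed.append((subject, toy_assembler.get_assembled_subject(subject)))
print(completed)
```

Note that a batch may contain slices of two different subjects, which is why the assembler, and not the batch loop, decides when a subject is complete.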

**PyTorch-specific**
The loop over the batches and the neural network architecture are framework dependent. We show here the PyTorch case, but the
TensorFlow equivalent can be found at the end of this notebook.

Basically, all we have to do is wrap our dataset as a PyTorch dataset, build a PyTorch data loader, and create or load a
network.


In [4]:
import torch
import torch.nn as nn
import torch.utils.data as torch_data

# Wrap the pymia datasource
pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
loader = torch_data.DataLoader(pytorch_dataset, batch_size=2, shuffle=False)

# Dummy network representing a placeholder for a trained network
dummy_network = nn.Sequential(
    nn.Conv2d(in_channels=2, out_channels=8, kernel_size=3, padding=1),
    nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1),
    nn.Sigmoid()
).eval()
torch.set_grad_enabled(False)  # no gradients needed for testing

nb_batches = len(loader)

We are now ready to loop over batches of data chunks. After the usual prediction of the network, the predicted data is
provided to the assembler, which takes care of putting chunks back together. Once some subjects are assembled
(`subjects_ready`) we extract the data required for evaluation and storing.


In [5]:
for i, batch in enumerate(loader):

    # Get data from batch and predict
    x, sample_indices = batch[defs.KEY_IMAGES], batch[defs.KEY_SAMPLE_INDEX]
    prediction = dummy_network(x)

    # translate the prediction to numpy and back to (B)HWC (channel last)
    numpy_prediction = prediction.numpy().transpose((0, 2, 3, 1))

    # add the batch prediction to the assembler
    is_last = i == nb_batches - 1
    assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

    # Process the subjects/images that are fully assembled
    for subject_index in assembler.subjects_ready:
        subject_prediction = assembler.get_assembled_subject(subject_index)

        # Extract the target and image properties via direct extract
        direct_sample = dataset.direct_extract(direct_extractor, subject_index)
        target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

        # # Do whatever you desire...
        # do_eval()
        # do_save()
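
The `is_last` flag in the loop above depends on the number of batches. For a PyTorch `DataLoader` with the default `drop_last=False`, `len(loader)` is the sample count divided by the batch size, rounded up. The following sketch reproduces this arithmetic; `count_batches` is a hypothetical helper and the sample count is made up for illustration.

```python
import math


def count_batches(nb_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a loader yields for `nb_samples` samples."""
    if drop_last:
        return nb_samples // batch_size  # an incomplete final batch is discarded
    return math.ceil(nb_samples / batch_size)  # an incomplete final batch is kept


nb_samples, batch_size = 5, 2
nb_batches = count_batches(nb_samples, batch_size)

# The flag is only true for the final batch, including a partial one
is_last_flags = [i == nb_batches - 1 for i in range(nb_batches)]
print(nb_batches, is_last_flags)
```

If the batch count were rounded down while the final partial batch is still yielded, the flag would be raised one batch too early.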

**TensorFlow adaptations**
For the presented data handling to work with the TensorFlow framework, only minor modifications are required:
(1) modified imports, (2) different framework-specific code, and (3) no permutations.

```python
# 1)
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import pymia.data.backends.tensorflow as pymia_tf

# 2)
gen_fn = pymia_tf.get_tf_generator(dataset)
tf_dataset = tf.data.Dataset.from_generator(generator=gen_fn,
                                            output_types={defs.KEY_IMAGES: tf.float32,
                                                          defs.KEY_SAMPLE_INDEX: tf.int64})
loader = tf_dataset.batch(2)

dummy_network = keras.Sequential([
    layers.Conv2D(8, kernel_size=3, padding='same'),
    layers.Conv2D(2, kernel_size=3, padding='same', activation='sigmoid')]
)
nb_batches = (len(dataset) + 1) // 2  # batch(2) keeps a final partial batch, so round up

# 3)
# No permutation transform needed. Thus the lines
#   transform = tfm.Permute(permutation=(2, 0, 1), entries=(defs.KEY_IMAGES,))
#   numpy_prediction = prediction.numpy().transpose((0, 2, 3, 1))
# become
transform = None
numpy_prediction = prediction.numpy()
```