
How To: Integrate your own Computation Backend

Author: Justus Schock

Date: 15.05.2019

This how-to will take you on a trip through the delira internals, where we will see how to add a custom computation backend, using the torch.jit (TorchScript) backend as an example.

Model Definitions

In order to implement a network, we first have to define the network itself. For each backend, delira contains a single backend-specific implementation of an abstract network class. These interface classes are all based on the AbstractNetwork class, which defines the major API.

So let's have a look at this class to see what we will have to implement for our own backend.

Of course, we will have to implement an __init__ defining our class. The __init__ of AbstractNetwork (which should be called during the __init__ of our subclass) accepts a number of kwargs and simply registers them as init_kwargs, so there is nothing we have to take care of here.

The next function to inspect is __call__, which makes the class callable; its docstring indicates that it should take care of our model's forward-pass.

After __call__, we have the closure function, which defines a single training step (including, but not limited to, the forward-pass, the calculation of losses and train metrics, the backward-pass and the optimization step).

The last method to implement is the prepare_batch function which converts the input to a suitable format and the correct data-type and device.
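To make these four methods concrete, here is a minimal, purely illustrative sketch of such an interface class using only the standard library. The names and signatures are simplified assumptions for illustration, not delira's actual code:

```python
import abc


class AbstractNetworkSketch(abc.ABC):
    """Simplified stand-in for an AbstractNetwork-like interface (illustrative only)."""

    @abc.abstractmethod
    def __init__(self, **kwargs):
        # register all keyword arguments so the model can be re-created later
        self._init_kwargs = kwargs

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        # should run the model's forward-pass
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def closure(model, data_dict, optimizers, losses, metrics, fold, **kwargs):
        # a single training step: forward-pass, losses/metrics, backward-pass, optimization
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def prepare_batch(batch, input_device, output_device):
        # convert the inputs to a suitable format, type and device
        raise NotImplementedError
```

A subclass that implements all four methods can be instantiated; instantiating the abstract class itself raises a TypeError, which is exactly how the interface enforces its API.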

TorchScript Limitations

Since we want to implement an abstract network class for this specific backend, we should have a look at how models are generally implemented in it.

According to the PyTorch docs, this works as follows:

You can write TorchScript code directly using Python syntax. You do this using the torch.jit.script decorator (for functions) or torch.jit.script_method decorator (for methods) on subclasses of ScriptModule. With this decorator the body of the annotated function is directly translated into TorchScript. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations.

Since our use-case is to implement the interface class for networks, we want to subclass torch.jit.ScriptModule, implement its forward method and decorate it with torch.jit.script_method.

The example given in the very same docs for this case is:

import torch
class MyScriptModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        return self.weight.mv(input)

my_script_module = MyScriptModule(5, 3)
input_tensor = torch.rand(3)
my_script_module(input_tensor)
tensor([0.4997, 0.2955, 0.1588, 0.1873, 0.4753], grad_fn=<MvBackward>)

Merging TorchScript into our Abstract Class

This little example tells us a few things we have to do for a successful definition of our base class:

1.) Our class has to subclass both the AbstractNetwork and torch.jit.ScriptModule classes.

2.) We need to implement a forward method, which takes care of the forward-pass (as its name indicates).

3.) We don't have to take care of the backward-pass, thanks to PyTorch's and TorchScript's autograd (a framework for automatic differentiation).

4.) Since torch.jit.ScriptModule is callable (as seen in the example), it already implements a __call__ method, and we may simply use this one.

5.) The closure is completely network-dependent and thus has to remain an abstract method here.

6.) The prepare_batch function also depends on the combination of network, inputs and loss functions, but we can at least provide a prototype of such a function, which handles the devices correctly and converts everything to float.

Actual Implementation

Now let's start with the actual implementation, going through the functions one by one while keeping in mind the things we just discovered.

Class Signature and __init__-Method

Since we subclass two classes, we cannot use the simple super().__init__ approach; we have to initialize both parent classes explicitly, so we write

class AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):

    @abc.abstractmethod
    def __init__(self, optimize=True, **kwargs):
        """

        Parameters
        ----------
        optimize : bool
            whether to optimize the network graph or not; default: True
        **kwargs :
            additional keyword arguments (passed to :class:`AbstractNetwork`)
        """
        torch.jit.ScriptModule.__init__(self, optimize=optimize)
        AbstractNetwork.__init__(self, **kwargs)

instead. This ensures that all parent classes are initialized correctly.
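The reason for the explicit calls can be demonstrated with plain Python classes: when two base classes do not cooperate via super(), each parent's __init__ has to be invoked by name (an illustrative sketch, not delira code):

```python
class BaseA:
    def __init__(self, a):
        # this base only knows about one positional argument
        self.a = a


class BaseB:
    def __init__(self, **kwargs):
        # this base registers arbitrary keyword arguments
        self.kwargs = kwargs


class Child(BaseA, BaseB):
    def __init__(self, a, **kwargs):
        # call each parent's __init__ explicitly instead of relying on a
        # single super().__init__ chain, because the bases don't cooperate
        BaseA.__init__(self, a)
        BaseB.__init__(self, **kwargs)


c = Child(1, mode="train")
```

After construction, the state set up by both parents is present on the child instance.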

__call__-Method

As mentioned above, the __call__ method is very easy to implement, because we can simply use the implementation of our TorchScript base class like this:

def __call__(self, *args, **kwargs):
    """
    Calls Forward method

    Parameters
    ----------
    *args :
        positional arguments (passed to `forward`)
    **kwargs :
        keyword arguments (passed to `forward`)

    Returns
    -------
    Any
        result: module results of arbitrary type and number

    """
    return torch.jit.ScriptModule.__call__(self, *args, **kwargs)

This also ensures that we can pass an arbitrary number of positional and keyword arguments of arbitrary types (which are all passed to the forward function). The advantage over calling forward directly is that ScriptModule.__call__ already handles forward-pre-hooks, forward-hooks and backward-hooks.

closure-Method

Since this method is highly model-dependent, we simply don't implement it, which forces the user to do so (it is marked as an abstractmethod in AbstractNetwork).

prepare_batch-Method

The prototype mentioned above, which pushes everything to the correct device and converts it to float, looks like this:

@staticmethod
def prepare_batch(batch: dict, input_device, output_device):
    """
    Helper Function to prepare Network Inputs and Labels (convert them to
    correct type and shape and push them to correct devices)

    Parameters
    ----------
    batch : dict
        dictionary containing all the data
    input_device : torch.device
        device for network inputs
    output_device : torch.device
        device for network outputs

    Returns
    -------
    dict
        dictionary containing data in correct type and shape and on correct
        device

    """
    return_dict = {"data": torch.from_numpy(batch.pop("data")).to(
        input_device).to(torch.float)}

    for key, vals in batch.items():
        return_dict[key] = torch.from_numpy(vals).to(output_device).to(
            torch.float)

    return return_dict

Since we don't want to use any of the model's attributes here (and for conformity with the AbstractNetwork class), this method is defined as a staticmethod, meaning it is bound to neither an instance nor the class. The closure method has to be a staticmethod, too.
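The same pop-convert-loop pattern can be illustrated without torch, using plain lists in place of tensors and strings in place of devices (a purely illustrative sketch; the real implementation uses torch.from_numpy and .to):

```python
def prepare_batch_sketch(batch: dict, input_device, output_device):
    # the "data" entry is popped and "converted" for the input device ...
    converted = {"data": {"values": [float(v) for v in batch.pop("data")],
                          "device": input_device}}

    # ... while everything else (labels etc.) goes to the output device
    for key, vals in batch.items():
        converted[key] = {"values": [float(v) for v in vals],
                          "device": output_device}

    return converted
```

The key point of the pattern is the asymmetry: network inputs end up on the input device, while everything the losses and metrics consume ends up on the output device.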

forward-Method

The only thing left now is the forward method, which is internally called by ScriptModule.__call__. The bad news is: we currently can't implement it. Subclassing a ScriptModule to overwrite a function decorated with torch.jit.script_method is not (yet) supported, but will be soon, once this PR is merged and released.

For now, you simply have to implement this method in your own network, despite the absence of an abstract interface method.

Putting it all together

If we combine all the function implementations into one class, it looks like this:

class AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):

    """
    Abstract Interface Class for TorchScript Networks. For more information
    have a look at https://pytorch.org/docs/stable/jit.html#torchscript

    Warnings
    --------
    In addition to the here defined API, a forward function must be
    implemented and decorated with ``@torch.jit.script_method``

    """
    @abc.abstractmethod
    def __init__(self, optimize=True, **kwargs):
        """

        Parameters
        ----------
        optimize : bool
            whether to optimize the network graph or not; default: True
        **kwargs :
            additional keyword arguments (passed to :class:`AbstractNetwork`)
        """
        torch.jit.ScriptModule.__init__(self, optimize=optimize)
        AbstractNetwork.__init__(self, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Calls Forward method

        Parameters
        ----------
        *args :
            positional arguments (passed to `forward`)
        **kwargs :
            keyword arguments (passed to `forward`)

        Returns
        -------
        Any
            result: module results of arbitrary type and number

        """
        return torch.jit.ScriptModule.__call__(self, *args, **kwargs)

    @staticmethod
    def prepare_batch(batch: dict, input_device, output_device):
        """
        Helper Function to prepare Network Inputs and Labels (convert them to
        correct type and shape and push them to correct devices)

        Parameters
        ----------
        batch : dict
            dictionary containing all the data
        input_device : torch.device
            device for network inputs
        output_device : torch.device
            device for network outputs

        Returns
        -------
        dict
            dictionary containing data in correct type and shape and on correct
            device

        """
        return_dict = {"data": torch.from_numpy(batch.pop("data")).to(
            input_device).to(torch.float)}

        for key, vals in batch.items():
            return_dict[key] = torch.from_numpy(vals).to(output_device).to(
                torch.float)

        return return_dict

Saving and loading

Now that we have the ability to implement delira-suitable TorchScript models, we want to store them on disk and load them again, so that we don't have to retrain them every time we want to use them. These I/O functions are usually located in delira.io.

Saving

Our saving function utilizes two functions: torch.jit.save to save the model (including its graph) and the save_checkpoint_torch function from the PyTorch backend to store the trainer state, since TorchScript allows us to use plain PyTorch optimizers.

The implementation of the function looks like this:

def save_checkpoint_torchscript(file: str, model=None, optimizers={},
                                epoch=None, **kwargs):
    """
    Save current checkpoint to two different files:
        1.) ``file + "_model.ptj"``: Will include the state of the model
            (including the graph; this is the opposite to
            :func:`save_checkpoint`)
        2.) ``file + "_trainer_state.pt"``: Will include the states of all
            optimizers and the current epoch (if given)

    Parameters
    ----------
    file : str
        filepath the model should be saved to
    model : AbstractTorchScriptNetwork or None
        the model which should be saved
        if None: empty dict will be saved as state dict
    optimizers : dict
        dictionary containing all optimizers
    epoch : int
        current epoch (will also be pickled)

    """

    # remove file extension if given
    if any([file.endswith(ext) for ext in [".pth", ".pt", ".ptj"]]):
        file = file.rsplit(".", 1)[0]

    if isinstance(model, AbstractTorchScriptNetwork):
        torch.jit.save(model, file + "_model.ptj")

    if optimizers or epoch is not None:
        save_checkpoint_torch(file + "_trainer_state.pt", None,
                              optimizers=optimizers, epoch=epoch, **kwargs)
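The extension handling at the top of the function (strip a known suffix once, then append the role-specific suffixes) can be exercised in isolation:

```python
def strip_checkpoint_extension(file: str) -> str:
    # remove a trailing .pth/.pt/.ptj extension, if present, so that the
    # caller can safely append "_model.ptj" or "_trainer_state.pt"
    if any(file.endswith(ext) for ext in (".pth", ".pt", ".ptj")):
        file = file.rsplit(".", 1)[0]
    return file
```

Note that only the last dot-separated component is stripped, so dots inside the base name survive.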

Loading

To load a model which has been saved to disk by this function, we have to revert each part of it. We do this by using torch.jit.load for the model (and its graph) and load_checkpoint_torch from the PyTorch backend. The actual implementation is given here:

def load_checkpoint_torchscript(file: str, **kwargs):
    """
    Loads a saved checkpoint consisting of 2 files
    (see :func:`save_checkpoint_torchscript` for details)

    Parameters
    ----------
    file : str
        filepath to a file containing a saved model
    **kwargs:
        Additional keyword arguments (passed to torch.load)
        Especially "map_location" is important to change the device the
        state_dict should be loaded to

    Returns
    -------
    OrderedDict
        checkpoint state_dict

    """
    # remove file extensions
    if any([file.endswith(ext) for ext in [".pth", ".pt", ".ptj"]]):
        file = file.rsplit(".", 1)[0]

    # load model
    if os.path.isfile(file + ".ptj"):
        model_file = file
    elif os.path.isfile(file + "_model.ptj"):
        model_file = file + "_model.ptj"
    else:
        raise ValueError("No Model File found for %s" % file)

    # load trainer state (if possible)
    trainer_file = model_file.replace("_model.ptj", "_trainer_state.pt")
    if os.path.isfile(trainer_file):
        trainer_state = load_checkpoint_torch(trainer_file, **kwargs)

    else:
        trainer_state = {"optimizer": {},
                         "epoch": None}

    trainer_state.update({"model": torch.jit.load(model_file)})

    return trainer_state

A Trainer to train

Now that we can define and save/load our models, we want to train them. Luckily, delira already provides a very modular, backend-agnostic trainer (the BaseNetworkTrainer) and, built upon it, a PyTorchNetworkTrainer. Since the training process in PyTorch and TorchScript is nearly the same, we can just extend the PyTorchNetworkTrainer. Usually one would have to extend the BaseNetworkTrainer to provide backend-specific functions (like necessary initializations, optimizer setup, seeding etc.). To see how this is done, you could have a look at either the PyTorchNetworkTrainer or the TfNetworkTrainer for TensorFlow, both of which follow this principle. Usually the only things to change completely are the loading/saving behavior and the _setup function, which defines the backend-specific initialization. Some other functions may have to be extended (by implementing the extension and calling the parent class's function).
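The "extend and call the parent" pattern described above can be sketched with plain Python classes (the names here are hypothetical, not delira's):

```python
class BaseTrainerSketch:
    def _setup(self, config):
        # shared, backend-agnostic initialization
        self.config = dict(config)
        return self


class BackendTrainerSketch(BaseTrainerSketch):
    def _setup(self, config):
        # backend-specific additions first ...
        self.backend_ready = True
        # ... then delegate the shared work to the parent class
        return super()._setup(config)
```

This way the subclass only pays for what actually differs, while the shared setup logic lives in a single place.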

Things to change:

By subclassing the PyTorchNetworkTrainer, we have to change the following things:

  • The trainer's default arguments
  • The behavior for trying to resume a previous training
  • The saving, loading and updating behavior

We will address these one by one:

The Default Arguments

We want to use AbstractTorchScriptNetworks instead of AbstractPyTorchNetworks here, and we have to change the behavior when multiple GPUs are passed, because multi-GPU training is currently not supported by TorchScript.

To do this, we implement __init__, apply our changes and forward everything else to the base class's __init__ like this (docstrings omitted for brevity):

class TorchScriptNetworkTrainer(PyTorchNetworkTrainer):
    def __init__(self,
                 network: AbstractTorchScriptNetwork,
                 save_path: str,
                 key_mapping,
                 losses=None,
                 optimizer_cls=None,
                 optimizer_params={},
                 train_metrics={},
                 val_metrics={},
                 lr_scheduler_cls=None,
                 lr_scheduler_params={},
                 gpu_ids=[],
                 save_freq=1,
                 optim_fn=create_optims_default,
                 logging_type="tensorboardx",
                 logging_kwargs={},
                 fold=0,
                 callbacks=[],
                 start_epoch=1,
                 metric_keys=None,
                 convert_batch_to_npy_fn=convert_torch_tensor_to_npy,
                 criterions=None,
                 val_freq=1,
                 **kwargs):

        if len(gpu_ids) > 1:
            # only use first GPU due to
            # https://github.com/pytorch/pytorch/issues/15421
            gpu_ids = [gpu_ids[0]]
            logging.warning("Multiple GPUs specified. Torch JIT currently "
                            "supports only single-GPU training. "
                            "Switching to use only the first GPU for now...")

        super().__init__(network=network, save_path=save_path,
                         key_mapping=key_mapping, losses=losses,
                         optimizer_cls=optimizer_cls,
                         optimizer_params=optimizer_params,
                         train_metrics=train_metrics,
                         val_metrics=val_metrics,
                         lr_scheduler_cls=lr_scheduler_cls,
                         lr_scheduler_params=lr_scheduler_params,
                         gpu_ids=gpu_ids, save_freq=save_freq,
                         optim_fn=optim_fn, logging_type=logging_type,
                         logging_kwargs=logging_kwargs, fold=fold,
                         callbacks=callbacks,
                         start_epoch=start_epoch, metric_keys=metric_keys,
                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,
                         mixed_precision=False, mixed_precision_kwargs={},
                         criterions=criterions, val_freq=val_freq, **kwargs
                         )

Resuming Training

For resuming the training, we have to change the try_resume_training function completely and cannot reuse the parent's implementation. Thus, we don't call super().try_resume_training here, but reimplement it from scratch:

def try_resume_training(self):
    """
    Load the latest state of a previous training if possible

    """
    # Load latest epoch file if available
    if os.path.isdir(self.save_path):
        # check all files in directory starting with "checkpoint" and
        # not ending with "_best.ptj"
        files = [x for x in os.listdir(self.save_path)
                 if os.path.isfile(os.path.join(self.save_path, x))
                 and x.startswith("checkpoint")
                 and not x.endswith("_best.ptj")
                 ]

        # if list is not empty: load previous state
        if files:

            latest_epoch = max([
                int(x.rsplit("_", 1)[-1].rsplit(".", 1)[0])
                for x in files])

            latest_state_path = os.path.join(self.save_path,
                                             "checkpoint_epoch_%d.ptj"
                                             % latest_epoch)

            # if the ".ptj" file does not exist, fall back to the ".pt" file
            if not os.path.isfile(latest_state_path):
                latest_state_path = latest_state_path[:-1]

            logger.info("Attempting to load state from previous "
                        "training from %s" % latest_state_path)
            try:
                self.update_state(latest_state_path)
            except KeyError:
                logger.warning("Previous State could not be loaded, "
                               "although it exists. Training will be "
                               "restarted")

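The epoch-parsing expression used above can be checked on its own:

```python
def parse_epoch(filename: str) -> int:
    # "checkpoint_epoch_17.ptj" -> 17: take the part after the last "_",
    # then drop the file extension before converting to int
    return int(filename.rsplit("_", 1)[-1].rsplit(".", 1)[0])
```

Taking the maximum of these values over all matching checkpoint files yields the latest saved epoch.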
Saving and Loading

Now we need to change the saving and loading behavior. As always we try to reuse as much code as possible to avoid code duplication.

Saving

To save the current training state, we simply call the save_checkpoint_torchscript function:

def save_state(self, file_name, epoch, **kwargs):
    """
    saves the current state via
    :func:`delira.io.torch.save_checkpoint_torchscript`

    Parameters
    ----------
    file_name : str
        filename to save the state to
    epoch : int
        current epoch (will be saved for mapping back)
    **kwargs :
        keyword arguments

    """
    if file_name.endswith(".pt") or file_name.endswith(".pth"):
        file_name = file_name.rsplit(".", 1)[0]

    save_checkpoint_torchscript(file_name, self.module, self.optimizers,
                                **kwargs)

Loading

To load the training state, we simply return the state loaded by load_checkpoint_torchscript. Since we don't use any arguments of the trainer itself here, the function is a staticmethod:

@staticmethod
def load_state(file_name, **kwargs):
    """
    Loads the new state from file via
    :func:`delira.io.torch.load_checkpoint_torchscript`

    Parameters
    ----------
    file_name : str
        the file to load the state from
    **kwargs : keyword arguments

    Returns
    -------
    dict
        new state

    """
    return load_checkpoint_torchscript(file_name, **kwargs)

Updating

After loading the new state, we need to update the trainer's internal state with it.

We do this by directly assigning the model here (since the graph was stored and loaded too) instead of only updating the state_dict, and calling the parent class's method afterwards:

def _update_state(self, new_state):
    """
    Update the state from a given new state

    Parameters
    ----------
    new_state : dict
        new state to update internal state from

    Returns
    -------
    :class:`TorchScriptNetworkTrainer`
        the trainer with a modified state

    """
    if "model" in new_state:
        self.module = new_state.pop("model").to(self.input_device)

    return super()._update_state(new_state)

A Whole Trainer

After combining all the changes above, we finally get our new trainer as:

class TorchScriptNetworkTrainer(PyTorchNetworkTrainer):
    def __init__(self,
                 network: AbstractTorchScriptNetwork,
                 save_path: str,
                 key_mapping,
                 losses=None,
                 optimizer_cls=None,
                 optimizer_params={},
                 train_metrics={},
                 val_metrics={},
                 lr_scheduler_cls=None,
                 lr_scheduler_params={},
                 gpu_ids=[],
                 save_freq=1,
                 optim_fn=create_optims_default,
                 logging_type="tensorboardx",
                 logging_kwargs={},
                 fold=0,
                 callbacks=[],
                 start_epoch=1,
                 metric_keys=None,
                 convert_batch_to_npy_fn=convert_torch_tensor_to_npy,
                 criterions=None,
                 val_freq=1,
                 **kwargs):
        """

        Parameters
        ----------
        network : :class:`AbstractTorchScriptNetwork`
            the network to train
        save_path : str
            path to save networks to
        key_mapping : dict
            a dictionary containing the mapping from the ``data_dict`` to
            the actual model's inputs.
            E.g. if a model accepts one input named 'x' and the data_dict
            contains one entry named 'data' this argument would have to
            be ``{'x': 'data'}``
        losses : dict
            dictionary containing the training losses
        optimizer_cls : subclass of torch.optim.Optimizer
            optimizer class implementing the optimization algorithm of
            choice
        optimizer_params : dict
            keyword arguments passed to optimizer during construction
        train_metrics : dict, optional
            metrics, which will be evaluated during train phase
            (should work on framework's tensor types)
        val_metrics : dict, optional
            metrics, which will be evaluated during test phase
            (should work on numpy arrays)
        lr_scheduler_cls : Any
            learning rate schedule class: must implement step() method
        lr_scheduler_params : dict
            keyword arguments passed to lr scheduler during construction
        gpu_ids : list
            list containing ids of GPUs to use; if empty: use cpu instead
            Currently ``torch.jit`` only supports single GPU-Training,
            thus only the first GPU will be used if multiple GPUs are passed
        save_freq : int
            integer specifying how often to save the current model's state.
            State is saved every state_freq epochs
        optim_fn : function
            creates a dictionary containing all necessary optimizers
        logging_type : str or callable
            the type of logging. If string: it must be one of
            ["visdom", "tensorboardx"]
            If callable: it must be a logging handler class
        logging_kwargs : dict
            dictionary containing all logging keyword arguments
        fold : int
            current cross validation fold (0 per default)
        callbacks : list
            initial callbacks to register
        start_epoch : int
            epoch to start training at
        metric_keys : dict
            dict specifying which batch_dict entry to use for which metric as
            target; default: None, which will result in key "label" for all
            metrics
        convert_batch_to_npy_fn : type, optional
            function converting a batch-tensor to numpy, per default this is
            a function, which detaches the tensor, moves it to cpu and the
            calls ``.numpy()`` on it
        val_freq : int
            validation frequency specifying how often to validate the trained
            model (a value of 1 denotes validating every epoch,
            a value of 2 denotes validating every second epoch etc.);
            defaults to 1
        **kwargs :
            additional keyword arguments

        """

        if len(gpu_ids) > 1:
            # only use first GPU due to
            # https://github.com/pytorch/pytorch/issues/15421
            gpu_ids = [gpu_ids[0]]
            logging.warning("Multiple GPUs specified. Torch JIT currently "
                            "supports only single-GPU training. "
                            "Switching to use only the first GPU for now...")

        super().__init__(network=network, save_path=save_path,
                         key_mapping=key_mapping, losses=losses,
                         optimizer_cls=optimizer_cls,
                         optimizer_params=optimizer_params,
                         train_metrics=train_metrics,
                         val_metrics=val_metrics,
                         lr_scheduler_cls=lr_scheduler_cls,
                         lr_scheduler_params=lr_scheduler_params,
                         gpu_ids=gpu_ids, save_freq=save_freq,
                         optim_fn=optim_fn, logging_type=logging_type,
                         logging_kwargs=logging_kwargs, fold=fold,
                         callbacks=callbacks,
                         start_epoch=start_epoch, metric_keys=metric_keys,
                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,
                         mixed_precision=False, mixed_precision_kwargs={},
                         criterions=criterions, val_freq=val_freq, **kwargs
                         )

    def try_resume_training(self):
        """
        Load the latest state of a previous training if possible

        """
        # Load latest epoch file if available
        if os.path.isdir(self.save_path):
            # check all files in directory starting with "checkpoint" and
            # not ending with "_best.ptj"
            files = [x for x in os.listdir(self.save_path)
                     if os.path.isfile(os.path.join(self.save_path, x))
                     and x.startswith("checkpoint")
                     and not x.endswith("_best.ptj")
                     ]

            # if list is not empty: load previous state
            if files:

                latest_epoch = max([
                    int(x.rsplit("_", 1)[-1].rsplit(".", 1)[0])
                    for x in files])

                latest_state_path = os.path.join(self.save_path,
                                                 "checkpoint_epoch_%d.ptj"
                                                 % latest_epoch)

                # if the ".ptj" file does not exist, fall back to the ".pt" file
                if not os.path.isfile(latest_state_path):
                    latest_state_path = latest_state_path[:-1]

                logger.info("Attempting to load state from previous "
                            "training from %s" % latest_state_path)
                try:
                    self.update_state(latest_state_path)
                except KeyError:
                    logger.warning("Previous State could not be loaded, "
                                   "although it exists. Training will be "
                                   "restarted")

    def save_state(self, file_name, epoch, **kwargs):
        """
        saves the current state via
        :func:`delira.io.torch.save_checkpoint_torchscript`

        Parameters
        ----------
        file_name : str
            filename to save the state to
        epoch : int
            current epoch (will be saved for mapping back)
        **kwargs :
            keyword arguments

        """
        if file_name.endswith(".pt") or file_name.endswith(".pth"):
            file_name = file_name.rsplit(".", 1)[0]

        save_checkpoint_torchscript(file_name, self.module, self.optimizers,
                                    **kwargs)

    @staticmethod
    def load_state(file_name, **kwargs):
        """
        Loads the new state from file via
        :func:`delira.io.torch.load_checkpoint_torchscript`

        Parameters
        ----------
        file_name : str
            the file to load the state from
        **kwargs : keyword arguments

        Returns
        -------
        dict
            new state

        """
        return load_checkpoint_torchscript(file_name, **kwargs)

    def _update_state(self, new_state):
        """
        Update the state from a given new state

        Parameters
        ----------
        new_state : dict
            new state to update internal state from

        Returns
        -------
        :class:`TorchScriptNetworkTrainer`
            the trainer with a modified state

        """
        if "model" in new_state:
            self.module = new_state.pop("model").to(self.input_device)

        return super()._update_state(new_state)

Wrapping it all in an Experiment

To have access to features like a k-fold cross-validation (and the not-yet-finished hyperparameter tuning), we need to wrap the trainer in an Experiment. We will use the same approach as for the trainer: extending an already provided class.

This time we extend the PyTorchExperiment which itself extends the BaseExperiment by some backend-specific defaults, types and seeds.

Our whole class definition just changes the default arguments of the PyTorchExperiment, and thus we only have to implement its __init__:

class TorchScriptExperiment(PyTorchExperiment):
    def __init__(self,
                 params: typing.Union[str, Parameters],
                 model_cls: AbstractTorchScriptNetwork, # not AbstractPyTorchNetwork anymore
                 n_epochs=None,
                 name=None,
                 save_path=None,
                 key_mapping=None,
                 val_score_key=None,
                 optim_builder=create_optims_default_pytorch,
                 checkpoint_freq=1,
                 trainer_cls=TorchScriptNetworkTrainer, # not PyTorchNetworkTrainer anymore
                 **kwargs):
        """

        Parameters
        ----------
        params : :class:`Parameters` or str
            the training parameters, if string is passed,
            it is treated as a path to a pickle file, where the
            parameters are loaded from
        model_cls : Subclass of :class:`AbstractTorchScriptNetwork`
            the class implementing the model to train
        n_epochs : int or None
            the number of epochs to train, if None: can be specified later
            during actual training
        name : str or None
            the Experiment's name
        save_path : str or None
            the path to save the results and checkpoints to.
            if None: Current working directory will be used
        key_mapping : dict
            mapping between data_dict and model inputs (necessary for
            prediction with :class:`Predictor`-API), if no keymapping is
            given, a default key_mapping of {"x": "data"} will be used here
        val_score_key : str or None
            key defining which metric to use for validation (determining
            best model and scheduling lr); if None: No validation-based
            operations will be done (model might still get validated,
            but validation metrics can only be logged and not used further)
        optim_builder : function
            Function returning a dict of backend-specific optimizers.
            defaults to :func:`create_optims_default_pytorch`
        checkpoint_freq : int
            frequency of saving checkpoints (1 denotes saving every epoch,
            2 denotes saving every second epoch etc.); default: 1
        trainer_cls : subclass of :class:`TorchScriptNetworkTrainer`
            the trainer class to use for training the model, defaults to
            :class:`TorchScriptNetworkTrainer`
        **kwargs :
            additional keyword arguments

        """
        super().__init__(params=params, model_cls=model_cls,
                         n_epochs=n_epochs, name=name, save_path=save_path,
                         key_mapping=key_mapping,
                         val_score_key=val_score_key,
                         optim_builder=optim_builder,
                         checkpoint_freq=checkpoint_freq,
                         trainer_cls=trainer_cls,
                         **kwargs)

Testing it

Now that we have finished the implementation of the backend (the Experiment being its outermost wrapper; congratulations!), we can test it. We'll use a very simple network together with dummy data. We only test the run and test functionality of our experiment, since everything else either sets up internal state or is a composition of these two methods and is already tested elsewhere. Let's define our dataset, instantiate it three times (for training, validation and testing) and wrap each instance into a DataManager:

import numpy as np

from delira.data_loading import AbstractDataset
from delira.data_loading import DataManager


class DummyDataset(AbstractDataset):
    def __init__(self, length):
        super().__init__(None, None)
        self.length = length

    def __getitem__(self, index):
        return {"data": np.random.rand(32),
                "label": np.random.randint(0, 1, 1)}

    def __len__(self):
        return self.length

    def get_sample_from_index(self, index):
        return self.__getitem__(index)

dset_train = DummyDataset(500)
dset_val = DummyDataset(50)
dset_test = DummyDataset(10)

# training, validation and testing with
# a batchsize of 16, 1 loading thread and no transformations.
dmgr_train = DataManager(dset_train, 16, 1, None)
dmgr_val = DataManager(dset_val, 16, 1, None)
dmgr_test = DataManager(dset_test, 16, 1, None)

Now that we have created three datasets, we need to define our small dummy network. We do this by subclassing delira.models.AbstractTorchScriptNetwork (which is exactly the implementation given above, but we need to use the internal one, because there are some type checks against it).

from delira.models import AbstractTorchScriptNetwork
import torch


class DummyNetworkTorchScript(AbstractTorchScriptNetwork):
    __constants__ = ["module"]

    def __init__(self):
        super().__init__()
        self.module = self._build_model(32, 1)

    @torch.jit.script_method
    def forward(self, x):
        return {"pred": self.module(x)}

    @staticmethod
    def prepare_batch(batch_dict, input_device, output_device):
        return {"data": torch.from_numpy(batch_dict["data"]
                                         ).to(input_device,
                                              torch.float),
                "label": torch.from_numpy(batch_dict["label"]
                                          ).to(output_device,
                                               torch.float)}

    @staticmethod
    def closure(model: AbstractTorchScriptNetwork, data_dict: dict,
                optimizers: dict, losses={}, metrics={},
                fold=0, **kwargs):
        """
        closure method to do a single backpropagation step


        Parameters
        ----------
        model :
            trainable model
        data_dict : dict
            dictionary containing the data
        optimizers : dict
            dictionary of optimizers to optimize model's parameters
        losses : dict
            dict holding the losses to calculate errors
            (gradients from different losses will be accumulated)
        metrics : dict
            dict holding the metrics to calculate
        fold : int
            Current Fold in Crossvalidation (default: 0)
        **kwargs:
            additional keyword arguments

        Returns
        -------
        dict
            Metric values (with same keys as input dict metrics)
        dict
            Loss values (with same keys as input dict losses)
        list
            Arbitrary number of predictions as torch.Tensor

        Raises
        ------
        AssertionError
            if optimizers or losses are empty or the optimizers are not
            specified

        """

        assert (optimizers and losses) or not optimizers, \
            "Criterion dict cannot be empty, if optimizers are passed"

        loss_vals = {}
        metric_vals = {}
        total_loss = 0

        # choose suitable context manager:
        if optimizers:
            context_man = torch.enable_grad

        else:
            context_man = torch.no_grad

        with context_man():

            inputs = data_dict.pop("data")
            preds = model(inputs)

            if data_dict:

                for key, crit_fn in losses.items():
                    _loss_val = crit_fn(preds["pred"], *data_dict.values())
                    loss_vals[key] = _loss_val.item()
                    total_loss += _loss_val

                with torch.no_grad():
                    for key, metric_fn in metrics.items():
                        metric_vals[key] = metric_fn(
                            preds["pred"], *data_dict.values()).item()

        if optimizers:
            optimizers['default'].zero_grad()
            # perform loss scaling via apex if half precision is enabled
            with optimizers["default"].scale_loss(total_loss) as scaled_loss:
                scaled_loss.backward()
            optimizers['default'].step()

        else:

            # add prefix "val" in validation mode
            eval_loss_vals, eval_metrics_vals = {}, {}
            for key in loss_vals.keys():
                eval_loss_vals["val_" + str(key)] = loss_vals[key]

            for key in metric_vals:
                eval_metrics_vals["val_" + str(key)] = metric_vals[key]

            loss_vals = eval_loss_vals
            metric_vals = eval_metrics_vals

        return metric_vals, loss_vals, {k: v.detach()
                                        for k, v in preds.items()}

    @staticmethod
    def _build_model(in_channels, n_outputs):
        return torch.nn.Sequential(
            torch.nn.Linear(in_channels, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, n_outputs)
        )

Now that we have defined our model, let's test whether we can really forward some tensors through it. We will use some random torch.Tensors (created by torch.rand). Since our model accepts 1d inputs of length 32, we need to pass 2d tensors to it (the additional dimension is the batch dimension).

input_tensor_single = torch.rand(1, 32) # use a single-sample batch (batchsize=1) here
input_tensor_batched = torch.rand(4, 32) # use a batch with batchsize 4 here

# create model instance
model = DummyNetworkTorchScript()

outputs = {"single": model(input_tensor_single)["pred"], "batched": model(input_tensor_batched)["pred"]}
outputs
{'single': tensor([[-0.1934]], grad_fn=<DifferentiableGraphBackward>),
 'batched': tensor([[-0.0525],
         [-0.0884],
         [-0.1492],
         [-0.0431]], grad_fn=<DifferentiableGraphBackward>)}
Having verified the forward pass, we can set up the training Parameters, bundling the losses, optimizer, validation metrics and learning-rate scheduling:

from sklearn.metrics import mean_absolute_error
from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch
from delira.training import Parameters
params = Parameters(fixed_params={
                    "model": {},
                    "training": {
                        "losses": {"CE": torch.nn.BCEWithLogitsLoss()},
                        "optimizer_cls": torch.optim.Adam,
                        "optimizer_params": {"lr": 1e-3},
                        "num_epochs": 2,
                        "val_metrics": {"mae": mean_absolute_error},
                        "lr_sched_cls": ReduceLROnPlateauCallbackPyTorch,
                        "lr_sched_params": {"mode": "min"}
                    }
                }
          )
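Conceptually, Parameters behaves like a nested dict, and nested_get (used below when calling exp.test) searches all nesting levels for a key. A rough, hypothetical stdlib sketch of that lookup, not delira's actual code:

```python
def nested_get(tree, key):
    """Depth-first search of a nested dict; return the first value
    stored under ``key`` at any level, else raise KeyError."""
    if key in tree:
        return tree[key]
    for value in tree.values():
        if isinstance(value, dict):
            try:
                return nested_get(value, key)
            except KeyError:
                pass
    raise KeyError(key)


params = {"fixed_params": {"training": {"num_epochs": 2,
                                        "val_metrics": {"mae": "..."}}}}
print(nested_get(params, "num_epochs"))  # -> 2
```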

from delira.training import TorchScriptExperiment

exp = TorchScriptExperiment(params, DummyNetworkTorchScript,
                            key_mapping={"x": "data"},
                            val_score_key="mae",
                            val_score_mode="min")

trained_model = exp.run(dmgr_train, dmgr_val)
exp.test(trained_model, dmgr_test, params.nested_get("val_metrics"))

Congratulations, you have implemented your first fully working delira backend. That wasn't so hard, was it?

Before you start implementing backends for all the other frameworks out there, let me just give you some advice:

  • You should test everything you implement or extend
  • Make sure to keep your backend's specifics and limitations in mind
  • Always follow the API of the already existing backends. If this is not possible, test your deviations extensively
  • If you extend another backend (like we did here by extending the PyTorch backend for TorchScript), make sure that the base backend is always installed (ideally they can only be installed together)
  • If you have questions regarding the implementation, don't hesitate to contact us.
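To make the first point concrete, here is a minimal, dependency-free sketch of how such a test could look. The dataset here is a stdlib stand-in for the numpy-based DummyDataset above, so the test runs without any backend installed:

```python
import random
import unittest


class DummyDataset:
    """Stdlib stand-in for the numpy-based DummyDataset above."""

    def __init__(self, length):
        self.length = length

    def __getitem__(self, index):
        return {"data": [random.random() for _ in range(32)],
                "label": [random.randint(0, 1)]}

    def __len__(self):
        return self.length


class TestDummyDataset(unittest.TestCase):
    def test_len(self):
        self.assertEqual(len(DummyDataset(500)), 500)

    def test_sample_format(self):
        sample = DummyDataset(10)[0]
        self.assertEqual(set(sample), {"data", "label"})
        self.assertEqual(len(sample["data"]), 32)
        self.assertIn(sample["label"][0], (0, 1))


# run the suite programmatically (equivalent to ``python -m unittest``)
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestDummyDataset)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```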