# Custom Loss Wrapper

In this example we will see how to create a custom loss wrapper. This loss wrapper calculates the sum of two losses.

#### Import Statements

In [None]:
from pytorch_wrapper.loss_wrappers import AbstractLossWrapper


#### Loss Wrapper Definition


In order to create a custom loss wrapper we need to inherit from `pytorch_wrapper.loss_wrappers.AbstractLossWrapper` and implement the `calculate_loss` method. This method takes as input the output of the model, the current batch, the training context, and the last activation, and must return the resulting loss. It is called for each batch during training.

In [None]:
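class SumLossWrapper(AbstractLossWrapper):
    """Returns the sum of two losses computed on the same model output and targets."""

    def __init__(self, loss1, loss2, model_output_key=None, batch_target_key='target'):
        super(SumLossWrapper, self).__init__()
        self._loss1 = loss1
        self._loss2 = loss2
        # Key holding the predictions when the model returns a dict (None if it returns a tensor).
        self._model_output_key = model_output_key
        # Key of the batch dict that holds the targets.
        self._batch_target_key = batch_target_key

    def calculate_loss(self, output, batch, training_context, last_activation=None):
        # Select the relevant entry when the model output is a dict.
        if self._model_output_key is not None:
            output = output[self._model_output_key]

        # Move the targets to the same device as the model output.
        batch_targets = batch[self._batch_target_key].to(output.device)

        # training_context and last_activation are not needed to compute this loss.
        return self._loss1(output, batch_targets) + self._loss2(output, batch_targets)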


Now we can create this loss wrapper and pass it to the `train` method of a `System` object.
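Before wiring the wrapper into training, it can help to sanity-check it by calling `calculate_loss` directly on a dummy batch. The sketch below assumes PyTorch is installed; `nn.MSELoss` and `nn.L1Loss` are only example losses, the tensors are made-up values, `'logits'` is a hypothetical output key, and an empty dict stands in for `training_context`, which `SumLossWrapper` never reads. If `pytorch_wrapper` cannot be imported, a hypothetical minimal stand-in for `AbstractLossWrapper` is defined so the snippet still runs.

```python
import torch
import torch.nn as nn

try:
    from pytorch_wrapper.loss_wrappers import AbstractLossWrapper
except ImportError:
    # Hypothetical stand-in used only when pytorch_wrapper is not installed.
    from abc import ABC, abstractmethod

    class AbstractLossWrapper(ABC):
        @abstractmethod
        def calculate_loss(self, output, batch, training_context, last_activation=None):
            ...


class SumLossWrapper(AbstractLossWrapper):
    """Same wrapper as in the notebook: sum of two losses on the same output/targets."""

    def __init__(self, loss1, loss2, model_output_key=None, batch_target_key='target'):
        super().__init__()
        self._loss1 = loss1
        self._loss2 = loss2
        self._model_output_key = model_output_key
        self._batch_target_key = batch_target_key

    def calculate_loss(self, output, batch, training_context, last_activation=None):
        if self._model_output_key is not None:
            output = output[self._model_output_key]
        batch_targets = batch[self._batch_target_key].to(output.device)
        return self._loss1(output, batch_targets) + self._loss2(output, batch_targets)


# Dummy predictions and targets (illustrative values only).
output = torch.tensor([1.0, 2.0, 3.0])
batch = {'target': torch.tensor([1.0, 1.0, 1.0])}

# Model returns a plain tensor: MSE = (0 + 1 + 4) / 3, L1 = (0 + 1 + 2) / 3.
loss_wrapper = SumLossWrapper(nn.MSELoss(), nn.L1Loss())
loss = loss_wrapper.calculate_loss(output, batch, training_context={})
print(loss.item())

# Model returns a dict: the wrapper picks out the entry under model_output_key.
dict_wrapper = SumLossWrapper(nn.MSELoss(), nn.L1Loss(), model_output_key='logits')
dict_loss = dict_wrapper.calculate_loss({'logits': output}, batch, training_context={})
print(dict_loss.item())
```

Both calls should print roughly 2.6667 (5/3 from the MSE term plus 1 from the L1 term), since the dict variant only changes where the wrapper looks for the predictions.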