In [None]:
# default_exp loss

# Loss function

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export
from fastai.basics import *

In [None]:
#export

class CombinedLoss:
    """Applies loss functions to multiple model outputs and sums them. If applicable, it can decode and compute activations for each model output.

    Args:
        *loss_funcs: one loss callable per model output; the i-th one is called as
            ``loss_funcs[i](outs[i], targets[i])``.
        weight: optional sized or iterable collection with one multiplier per loss
            function. Defaults to ``1.`` for every loss.

    Raises:
        ValueError: if ``weight`` does not have exactly one entry per loss function.
    """
    def __init__(self, *loss_funcs, weight=None):
        if weight is None:
            weight = [1.]*len(loss_funcs)
        elif not hasattr(weight, '__len__'):
            # Materialize one-shot iterables (e.g. generators); otherwise zip() would
            # exhaust them on the first call and every later loss would be 0.
            weight = list(weight)
        if len(weight) != len(loss_funcs):
            # zip() would otherwise silently drop the losses without a weight.
            raise ValueError(
                f"Expected {len(loss_funcs)} weights (one per loss function), got {len(weight)}")
        self.weight = weight
        self.loss_funcs = loss_funcs

    def __call__(self, outs, *targets, **kwargs):
        """Return ``sum(weight[i] * loss_funcs[i](outs[i], targets[i]))`` over all loss functions.

        Extra outputs or targets beyond the number of loss functions are ignored
        (they may be meant for metrics). Fewer than that raises ``ValueError``,
        because a loss would otherwise be silently skipped.
        NOTE(review): ``kwargs`` (e.g. a ``reduction`` override) are accepted but
        ignored, so each sub-loss uses its own configured reduction.
        """
        n_losses = len(self.loss_funcs)
        if len(outs) < n_losses or len(targets) < n_losses:
            raise ValueError(
                f"Expected at least {n_losses} outputs and targets, "
                f"got {len(outs)} outputs and {len(targets)} targets")
        return sum(
            w*loss_func(out, target)
            for loss_func, w, out, target in zip(self.loss_funcs, self.weight, outs, targets)
        )

    def activation(self, outs):
        """Apply each loss function's ``activation`` (identity via ``noop`` if absent) to its output."""
        return [getattr(loss_func, 'activation', noop)(out) for loss_func, out in zip(self.loss_funcs, outs)]

    def decodes(self, outs):
        """Apply each loss function's ``decodes`` (identity via ``noop`` if absent) to its output.

        NOTE(review): fastai's built-in losses appear to expect *activated* outputs here
        (e.g. ``BCEWithLogitsLossFlat.decodes`` thresholds at 0.5), so pass the result
        of ``activation`` rather than raw logits.
        """
        return [getattr(loss_func, 'decodes', noop)(out) for loss_func, out in zip(self.loss_funcs, outs)]
    

Assume that a multi-task learning model produces two outputs:
1. Logits for multi-class, single-label classification, for which we want to use cross-entropy loss and softmax activation.
2. A single logit for binary classification, for which we want to use binary cross-entropy loss and sigmoid activation.

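`CombinedLoss` applies the corresponding loss function to each model output and sums the (optionally weighted) results. When a loss function defines an activation or a decoding step, `CombinedLoss` also exposes it for the matching output.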

In [None]:
from fastai.vision.all import *

# Fix the RNG so the random targets/logits (and the outputs shown below) are reproducible across runs.
torch.manual_seed(42)

ce = CrossEntropyLossFlat()
bce = BCEWithLogitsLossFlat()
comb_loss = CombinedLoss(ce, bce)

bs = 8
n_classes = 5
target1, output1 = torch.randint(0, n_classes, (bs,)), torch.randn(bs, n_classes)
target2, output2 = torch.randint(0, 2, (bs,), dtype=float), torch.randn(bs)*10
actual = comb_loss((output1, output2), target1, target2)

loss1 = ce(output1, target1)
loss2 = bce(output2, target2)
expected = loss1 + loss2
test_close(expected, actual)

# Per-output weights scale each loss before summing.
weighted_loss = CombinedLoss(ce, bce, weight=[0.3, 0.7])
test_close(weighted_loss((output1, output2), target1, target2), 0.3*loss1 + 0.7*loss2)

Here are the raw model outputs (logits):

In [None]:
# Raw (pre-activation) outputs, one per task.
list((output1, output2))

[tensor([[-0.1412,  0.5191, -0.8597,  0.8974,  0.9123],
         [ 0.5627, -0.3377, -0.1815, -0.5530,  0.0656],
         [-2.3806, -0.1824,  0.0982,  0.3367, -0.4743],
         [ 1.8762, -0.6613,  0.9129,  0.4719, -0.8340],
         [ 1.3965,  1.0323,  0.5446,  0.6935,  1.0835],
         [-2.0293,  0.5273, -0.7488, -1.1144, -0.7592],
         [ 0.2765, -1.1856, -0.1731,  0.8288,  0.1402],
         [-1.3383, -1.0835, -0.8836, -0.3483, -0.1096]]),
 tensor([ -8.0780,   3.7517, -10.2364, -11.9647,  -5.8247,  -5.0219,  -0.8134,
         -10.2135])]

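When applicable, it can also compute activations and decode the model outputs. Note that fastai's `decodes` expects activated outputs (for example, `BCEWithLogitsLossFlat.decodes` thresholds probabilities at 0.5), so activations should be applied before decoding. For instance, let's decode the outputs into class indices and binary labels.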

In [None]:
# fastai's `decodes` expects activated outputs (BCEWithLogitsLossFlat thresholds
# probabilities at 0.5), so apply the activations before decoding.
comb_loss.decodes(comb_loss.activation([output1, output2]))

[tensor([4, 0, 3, 0, 0, 1, 3, 4]),
 tensor([False,  True, False, False, False, False, False, False])]

Similarly, here are the activations for each model output.

In [None]:
# Per-output activations: softmax for the cross-entropy head, sigmoid for the BCE head.
comb_loss.activation((output1, output2))

[tensor([[0.1097, 0.2123, 0.0535, 0.3099, 0.3146],
         [0.3549, 0.1442, 0.1686, 0.1163, 0.2159],
         [0.0228, 0.2057, 0.2723, 0.3456, 0.1536],
         [0.5641, 0.0446, 0.2153, 0.1385, 0.0375],
         [0.2987, 0.2075, 0.1274, 0.1479, 0.2184],
         [0.0425, 0.5475, 0.1528, 0.1060, 0.1512],
         [0.2232, 0.0517, 0.1424, 0.3878, 0.1948],
         [0.1003, 0.1294, 0.1580, 0.2698, 0.3426]]),
 tensor([3.1019e-04, 9.7706e-01, 3.5840e-05, 6.3648e-06, 2.9451e-03, 6.5491e-03,
         3.0717e-01, 3.6672e-05])]