In [None]:
# default_exp loss

# Loss

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export
from fastai.basics import *

In [None]:
#export

class CombinedLoss():
    """Applies loss functions to multiple model outputs and sums them.
    If applicable, it can decode and compute activations for each model output.

    The i-th loss function is applied to the i-th model output and the i-th
    target; the total loss is `sum(weight[i] * loss_i)`.

    Args:
        *loss_funcs: one loss function per model output.
        weight: optional iterable of per-loss multipliers, one per loss function.
            Defaults to `1.` for every loss function.

    Raises:
        ValueError: if the number of weights differs from the number of loss functions.
    """

    def __init__(self, *loss_funcs, weight=None):
        if weight is None:
            weight = [1.]*len(loss_funcs)
        # Materialize once so a one-shot iterable (e.g. a generator) is not
        # exhausted after the first call to `__call__`.
        weight = list(weight)
        if len(weight) != len(loss_funcs):
            raise ValueError(
                f"Expected {len(loss_funcs)} weights (one per loss function), got {len(weight)}")
        self.weight = weight
        self.loss_funcs = loss_funcs

    def __call__(self, outs, *targets, **kwargs):
        """Return the weighted sum of the per-output losses.

        `outs` holds the model outputs and `targets` the matching targets, both in
        the same order as `loss_funcs`. Extra keyword arguments are accepted but
        are NOT forwarded to the individual loss functions.

        Raises:
            ValueError: if the number of outputs or targets differs from the
                number of loss functions.
        """
        # NOTE(review): fastai may call a loss without a `reduction` attribute
        # with `reduction='none'` (e.g. in `get_preds(with_loss=True)`); that
        # keyword is ignored here, so the returned loss stays reduced.
        n = len(self.loss_funcs)
        # `zip` would silently truncate on a count mismatch and drop losses.
        if len(outs) != n or len(targets) != n:
            raise ValueError(
                f"Expected {n} model outputs and {n} targets, got {len(outs)} and {len(targets)}")
        return sum([
            w*loss_func(out, target)
            for loss_func, w, out, target in zip(self.loss_funcs, self.weight, outs, targets)
        ])

    def activation(self, outs):
        """Apply each loss function's `activation` (identity if it has none) to its output."""
        return [getattr(loss_func, 'activation', noop)(out) for loss_func, out in zip(self.loss_funcs, outs)]

    def decodes(self, outs):
        """Apply each loss function's `decodes` (identity if it has none) to its output."""
        return [getattr(loss_func, 'decodes', noop)(out) for loss_func, out in zip(self.loss_funcs, outs)]
    

Assume that a multi-task learning model produces two outputs:
1. The logits for multi-class single-label classification, for which we want to use cross-entropy loss and softmax activation
2. A single logit for binary classification, for which we want to use binary cross-entropy loss and sigmoid activation

`CombinedLoss` lets you apply the corresponding loss function and activation function to each model output.

In [None]:
from fastai.vision.all import *

ce = CrossEntropyLossFlat()
bce = BCEWithLogitsLossFlat()
comb_loss = CombinedLoss(ce, bce)

bs = 8
target1, output1 = torch.randint(0, 5, (bs,)), torch.randn(bs, 5) # 5 classes
target2, output2 = torch.randint(0, 2, (bs,), dtype=float), torch.randn(bs)*10 # binary targets
actual = comb_loss((output1, output2), target1, target2)

# with default weights the combined loss is the plain sum of the individual losses
loss1 = ce(output1, target1)
loss2 = bce(output2, target2)
expected = loss1 + loss2
test_close(expected, actual)

# with explicit weights each loss is scaled before summing
weighted_loss = CombinedLoss(ce, bce, weight=[0.3, 0.7])
test_close(0.3*loss1 + 0.7*loss2, weighted_loss((output1, output2), target1, target2))

# activations are floating point (softmax / sigmoid): compare approximately
actual_acts_output1, actual_acts_output2 = comb_loss.activation([output1, output2])
expected_acts_output1, expected_acts_output2 = ce.activation(output1), bce.activation(output2)
test_close(expected_acts_output1, actual_acts_output1)
test_close(expected_acts_output2, actual_acts_output2)

# decodings are discrete (class indices / booleans): compare exactly
actual_decoded_output1, actual_decoded_output2 = comb_loss.decodes([output1, output2])
expected_decoded_output1, expected_decoded_output2 = ce.decodes(output1), bce.decodes(output2)
test_eq(expected_decoded_output1, actual_decoded_output1)
test_eq(expected_decoded_output2, actual_decoded_output2)


Here are the raw model outputs (logits):

In [None]:
# raw logits of both heads: a (bs, 5) tensor and a (bs,) tensor
[output1, output2]

[tensor([[-1.1966, -0.9368, -1.4137, -0.2984,  0.4431],
         [ 0.0891, -1.4017, -0.6367,  0.2432, -0.5874],
         [ 0.3650, -0.3818,  0.9312, -1.9950,  0.2156],
         [ 0.3676, -1.1119, -1.0802, -1.3735, -0.4440],
         [ 0.9762, -0.8892,  1.0861, -0.4404, -0.2750],
         [-0.3843, -0.0126,  0.7231,  1.7333,  0.3376],
         [ 0.7361,  0.9656, -1.4666, -2.3691, -0.9199],
         [-0.0581, -0.6878, -1.9764, -0.9484, -0.6963]]),
 tensor([  1.9764,  21.9196, -11.3594,  11.4162,  -9.3616,   1.2402,   7.7606,
         -14.2065])]

When applicable, it can decode the raw model outputs and compute activations. For instance, let's decode the logits into class indices and binary labels.

In [None]:
# decode each output with its own loss function's `decodes`
comb_loss.decodes([output1, output2])

[tensor([4, 3, 2, 0, 2, 3, 1, 0]),
 tensor([ True,  True, False,  True, False,  True,  True, False])]

Similarly, here are the activations for each model output.

In [None]:
# compute each output's activation with its own loss function's `activation`
comb_loss.activation([output1, output2])

[tensor([[0.0934, 0.1211, 0.0752, 0.2292, 0.4812],
         [0.2955, 0.0665, 0.1430, 0.3447, 0.1502],
         [0.2386, 0.1131, 0.4203, 0.0225, 0.2055],
         [0.4802, 0.1094, 0.1129, 0.0842, 0.2133],
         [0.3572, 0.0553, 0.3987, 0.0866, 0.1022],
         [0.0631, 0.0915, 0.1910, 0.5245, 0.1299],
         [0.3840, 0.4830, 0.0424, 0.0172, 0.0733],
         [0.3819, 0.2035, 0.0561, 0.1568, 0.2017]]),
 tensor([8.7830e-01, 1.0000e+00, 1.1660e-05, 9.9999e-01, 8.5953e-05, 7.7560e-01,
         9.9957e-01, 6.7640e-07])]