In [None]:
# default_exp loss

# Loss

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export
from fastai.basics import *

In [None]:
#export

from typing import Callable
from dataclasses import dataclass

@dataclass
class LossRouting:
    """Describes how one loss function is wired into a `CombinedLoss`.

    `CombinedLoss.__call__` evaluates each routing as
    `weight * loss_func(outs[pred_idx], targets[target_idx])`.
    """
    # Callable invoked as `loss_func(prediction, target)`.
    loss_func: Callable
    # Position of the prediction inside the model outputs.
    pred_idx: int
    # Position of the target among the targets passed to the combined loss.
    target_idx: int
    # Multiplier applied to this loss before the losses are summed.
    weight: float = 1.0

class CombinedLoss:
    """Weighted sum of several loss functions, each routed to one model output.

    Every routing names the loss to apply, the index of the prediction in the
    model outputs, the index of the target, and a multiplicative weight.
    `activation` and `decodes` delegate to the same-named attribute of each
    routed loss function and fall back to `noop` when it is missing.
    """

    def __init__(self, *loss_routings):
        # Stored as the tuple captured by *args, in the order given.
        self.loss_routings = loss_routings

    @staticmethod
    def _weighted_term(routing, outs, targets):
        "Loss of a single routing, scaled by its weight."
        return routing.weight * routing.loss_func(outs[routing.pred_idx], targets[routing.target_idx])

    def _delegate(self, hook_name, outs):
        "Call `hook_name` of every routed loss on its prediction; `noop` if the loss lacks it."
        results = []
        for routing in self.loss_routings:
            hook = getattr(routing.loss_func, hook_name, noop)
            results.append(hook(outs[routing.pred_idx]))
        return results

    def __call__(self, outs, *targets, **kwargs):
        "Sum of `weight * loss_func(pred, target)` over all routings; `kwargs` are accepted but unused."
        terms = [self._weighted_term(routing, outs, targets) for routing in self.loss_routings]
        return sum(terms)

    def activation(self, outs):
        "List with one entry per routing: the routed loss's `activation` applied to its prediction."
        return self._delegate('activation', outs)

    def decodes(self, outs):
        "List with one entry per routing: the routed loss's `decodes` applied to its prediction."
        return self._delegate('decodes', outs)

    @classmethod
    def from_one_to_one_routing(cls, *loss_funcs):
        "Build a combined loss where the i-th loss uses the i-th output and i-th target with weight 1.0."
        routings = (
            LossRouting(fn, pred_idx=pos, target_idx=pos, weight=1.0)
            for pos, fn in enumerate(loss_funcs)
        )
        return cls(*routings)


Assume that a multi-task learning model produces two outputs:
1. The logits for multi-class single-label classification, for which we want to use cross-entropy loss and softmax activation
2. A single logit for binary classification, for which we want to use binary cross-entropy loss and sigmoid activation

`CombinedLoss` enables using the corresponding loss function and its activation function for each model output.

In [None]:
from fastai.vision.all import *

# One-to-one routing: `ce` is applied to output/target 0 and `bce` to output/target 1,
# both with weight 1.0.
ce = CrossEntropyLossFlat()
bce = BCEWithLogitsLossFlat()
comb_loss = CombinedLoss.from_one_to_one_routing(ce, bce)

# Random batch: integer class targets with (bs, 5) logits for the first task,
# and 0/1 float targets with one logit per sample (scaled by 10) for the second.
bs = 8
target1, output1 = torch.randint(0, 5, (bs,)), torch.randn(bs, 5) # 5 classes
target2, output2 = torch.randint(0, 2, (bs,), dtype=float), torch.randn(bs)*10
actual = comb_loss((output1, output2), target1, target2)

# Because every weight is 1.0, the combined loss should match the plain sum of the two losses.
loss1 = ce(output1, target1)
loss2 = bce(output2, target2)
expected = loss1 + loss2
test_close(expected, actual)

# activations: `comb_loss.activation` returns one result per routing, in routing order,
# and each should match the activation of the corresponding individual loss.
actual_acts_output1, actual_acts_output2 = comb_loss.activation([output1, output2])
expected_acts_output1, expected_acts_output2 = ce.activation(output1), bce.activation(output2)
test_close(expected_acts_output1, actual_acts_output1)
test_eq(expected_acts_output2, actual_acts_output2)

# decoding: same check for `decodes`, again one result per routing.
actual_decoded_output1, actual_decoded_output2 = comb_loss.decodes([output1, output2])
expected_decoded_output1, expected_decoded_output2 = ce.decodes(output1), bce.decodes(output2)
test_close(expected_decoded_output1, actual_decoded_output1)
test_eq(expected_decoded_output2, actual_decoded_output2)


Here are raw model outputs (logits):

In [None]:
# Display the two raw outputs: the (bs, 5) class logits and the per-sample binary logits.
[output1, output2]

[tensor([[ 0.1850, -1.1469,  1.2270,  0.8119, -1.4351],
         [-1.0543,  1.5350, -1.0777,  0.4074,  0.2599],
         [ 0.1193,  2.9036,  0.1536,  1.0384, -1.2029],
         [-0.8484,  1.8163,  0.6120,  1.1216,  0.4150],
         [ 1.3586, -1.7410,  1.6133, -1.1984,  0.9948],
         [ 0.8181, -1.2083, -1.2216, -0.7513,  0.2764],
         [ 1.2261, -1.1696, -1.1735,  1.0326, -0.0525],
         [ 0.2331,  0.1555,  1.1627, -1.2526, -0.0451]]),
 tensor([  3.2092,  -6.4007,  -0.3515,  -3.4715,  -1.5865, -16.1386,   9.8419,
         -13.8744])]

When the underlying loss functions support it, `CombinedLoss` can decode the raw model outputs and compute activations. For instance, let's decode the logits into class label indices and binary classes.

In [None]:
# Each output is decoded by the `decodes` of the loss function routed to it.
comb_loss.decodes([output1, output2])

[tensor([2, 1, 1, 1, 2, 0, 0, 2]),
 tensor([ True, False, False, False, False, False,  True, False])]

Similarly, here are the activations for each model output.

In [None]:
# Each output is passed through the `activation` of the loss function routed to it.
comb_loss.activation([output1, output2])

[tensor([[0.1621, 0.0428, 0.4596, 0.3034, 0.0321],
         [0.0429, 0.5709, 0.0419, 0.1849, 0.1595],
         [0.0476, 0.7710, 0.0493, 0.1194, 0.0127],
         [0.0329, 0.4728, 0.1418, 0.2360, 0.1164],
         [0.3218, 0.0145, 0.4151, 0.0250, 0.2236],
         [0.4874, 0.0642, 0.0634, 0.1015, 0.2835],
         [0.4378, 0.0399, 0.0397, 0.3607, 0.1219],
         [0.1837, 0.1700, 0.4655, 0.0416, 0.1391]]),
 tensor([9.6118e-01, 1.6577e-03, 4.1302e-01, 3.0135e-02, 1.6988e-01, 9.7967e-08,
         9.9995e-01, 9.4281e-07])]