In [None]:
# default_exp loss

# Model and loss function

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export
from fastai.basics import *

In [None]:
#export

class CombinedLoss():
    """Weighted sum of several loss functions, one per (output, target) pair.

    The i-th loss function is applied to the i-th model output and the i-th
    target, scaled by the i-th weight, and all terms are summed.

    Args:
        *loss_funcs: callables ``loss_func(out, target)``.
        weight: optional iterable of per-loss multipliers; defaults to 1. for
            every loss. Must contain exactly one entry per loss function.

    Raises:
        ValueError: if the number of weights differs from the number of loss
            functions.
    """
    def __init__(self, *loss_funcs, weight=None):
        if weight is None:
            weight = [1.]*len(loss_funcs)
        # Materialise so generators/tuples are accepted and the length can be checked;
        # zip() would otherwise silently drop loss terms that have no weight.
        weight = list(weight)
        if len(weight) != len(loss_funcs):
            raise ValueError(
                f"CombinedLoss got {len(loss_funcs)} loss functions but {len(weight)} weights"
            )
        self.weight = weight
        self.loss_funcs = loss_funcs

    def __call__(self, outs, *targets, **kwargs):
        """Return ``sum(w_i * loss_func_i(outs[i], targets[i]))``.

        ``outs`` is a sequence of model outputs and ``targets`` the matching
        targets, both in the same order as the loss functions. Extra keyword
        arguments are accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if the number of outputs or targets differs from the
                number of loss functions (zip() would silently drop terms).
        """
        n_losses = len(self.loss_funcs)
        if len(outs) != n_losses or len(targets) != n_losses:
            raise ValueError(
                f"CombinedLoss expects {n_losses} outputs and {n_losses} targets, "
                f"got {len(outs)} outputs and {len(targets)} targets"
            )
        return sum(
            w*loss_func(out, target)
            for loss_func, w, out, target in zip(self.loss_funcs, self.weight, outs, targets)
        )

    def activation(self, outs):
        "Apply each loss function's `activation` (or `noop` if it has none) to its output."
        return [getattr(loss_func, 'activation', noop)(out) for loss_func, out in zip(self.loss_funcs, outs)]

    def decodes(self, outs):
        "Apply each loss function's `decodes` (or `noop` if it has none) to its output."
        return [getattr(loss_func, 'decodes', noop)(out) for loss_func, out in zip(self.loss_funcs, outs)]
    

In [None]:
#hide
from fastai.vision.all import *

# Two heads: a 10-class head scored with cross-entropy and a 4-class head scored with focal loss.
ce = CrossEntropyLossFlat()
fl = FocalLossFlat()
comb_loss = CombinedLoss(ce, fl)

# Outputs of shape (8, 5, n_classes) with integer targets of shape (8, 5).
target1, output1 = torch.randint(0, 10, (8,5)), torch.randn(8, 5, 10)
target2, output2 = torch.randint(0, 4, (8,5)), torch.randn(8, 5, 4)
actual = comb_loss((output1, output2), target1, target2)

# With default weights the combined loss is the plain sum of the individual losses.
loss1 = ce(output1, target1)
loss2 = fl(output2, target2)
expected = loss1 + loss2
test_close(expected, actual)

# Explicit weights scale each term before summing.
weighted_loss = CombinedLoss(ce, fl, weight=[0.3, 2.])
test_close(weighted_loss((output1, output2), target1, target2), 0.3*loss1 + 2.*loss2)

# decodes dispatches output-by-output to each loss function's own decodes.
test_eq(comb_loss.decodes((output1, output2))[0], ce.decodes(output1))