In [1]:
# Unit tests for the Multihead_loss module.

In [2]:
import numpy as np
import math

import torch
import torch.nn.functional as F
from torch.autograd import Variable

In [3]:
class Multihead_loss(torch.nn.Module):
    """
    Compute the summed loss of a network with multiple outputs (heads).

    The module has no parameters or state of its own; everything it needs
    is passed to ``forward``.

    Arguments (of ``forward``):
        outputs: List of network outputs, one entry per head.
        target: List of targets where len(outputs) = len(target).
        loss_function: One of
            * a list of loss functions with
              len(loss_function) = len(target) (one per head),
            * a list holding a single loss function, which is then applied
              to every head,
            * a bare callable, treated like a one-element list.

    Returns:
        The sum of ``loss_func(out, gt)`` over all heads. When there are no
        heads the Python float ``0.`` is returned.

    Raises:
        AssertionError: if ``outputs`` and ``target`` differ in length, or
            if ``loss_function`` has a length other than 1 or len(target).
    """
    def __init__(self):
        super(Multihead_loss, self).__init__()

    def forward(self, outputs, target, loss_function):
        # Accept a bare callable as shorthand for a one-element list. The
        # __len__ check keeps sized containers that also happen to be
        # callable (e.g. a torch.nn.ModuleList of criteria) on the list path.
        if callable(loss_function) and not hasattr(loss_function, "__len__"):
            loss_function = [loss_function]

        n_heads = len(target)
        assert len(outputs) == n_heads, (
            "got %d outputs but %d targets" % (len(outputs), n_heads))
        assert len(loss_function) == n_heads or len(loss_function) == 1, (
            "expected 1 or %d loss functions, got %d"
            % (n_heads, len(loss_function)))

        # expand loss_function list if univariate
        if len(loss_function) == 1:
            loss_function = [loss_function[0]] * n_heads

        # compute loss for each head
        total_loss = 0.
        for out, gt, loss_func in zip(outputs, target, loss_function):
            # Out-of-place addition: never mutates the running tensor in
            # place, so the autograd graph of earlier heads is left alone.
            total_loss = total_loss + loss_func(out, gt)
        return total_loss

In [4]:
# Single stateless loss-module instance shared by the test functions below.
multihead_loss = Multihead_loss()

In [5]:
def test_01():
    """Check that a single loss function is broadcast to every head.

    Three heads are compared against all-ones targets with MSE; the
    per-head losses are 1, 2/3 and 0, so the total must be 1 + 2/3.
    """
    # Plain tensors: torch.autograd.Variable is a deprecated no-op wrapper
    # since PyTorch 0.4, the same release that introduced .item().
    outputs = [torch.tensor([0., 0., 0.], dtype=torch.float32),
               torch.tensor([0., 1., 0.], dtype=torch.float32),
               torch.tensor([1., 1., 1.], dtype=torch.float32)]

    target = [torch.tensor([1., 1., 1.], dtype=torch.float32)
              for _ in outputs]

    loss_function = [F.mse_loss]

    total_loss = multihead_loss(outputs, target, loss_function)
    assert(math.isclose(total_loss.item(), 1+(2/3)+0, rel_tol=1e-05))

In [6]:
# Run the check; silent on success, raises AssertionError on failure.
test_01()

In [7]:
def test_02():
    """Check that a per-head list of loss functions is applied head by head.

    Head 1 uses MSE (loss 1), head 2 uses binary cross-entropy and head 3
    uses MSE (loss 0), all against all-ones targets.
    """
    outputs = [torch.tensor([0., 0., 0.], dtype=torch.float32),
               # Probabilities strictly inside (0, 1). An exact 0 with
               # target 1 makes BCE evaluate log(0), and how that is guarded
               # (epsilon vs. clamping the log at -100) has differed between
               # PyTorch releases, so the expected value would not be stable.
               torch.tensor([0.5, 0.25, 0.5], dtype=torch.float32),
               torch.tensor([1., 1., 1.], dtype=torch.float32)]

    target = [torch.tensor([1., 1., 1.], dtype=torch.float32)
              for _ in outputs]

    loss_function = [F.mse_loss, F.binary_cross_entropy, F.mse_loss]

    total_loss = multihead_loss(outputs, target, loss_function)
    # With all-ones targets BCE is mean(-log p):
    # (log 2 + log 4 + log 2) / 3 = (4/3) * log 2.
    expected_bce = (4 / 3) * math.log(2)
    # float32 arithmetic: use the same tolerance as test_01.
    assert(math.isclose(total_loss.item(), 1+expected_bce+0, rel_tol=1e-05))

In [8]:
# Run the check; silent on success, raises AssertionError on failure.
test_02()