In [51]:
import sys
from collections import OrderedDict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorflow as tf

# Record the library versions this notebook was run with (TensorFlow first, then PyTorch).
for _lib in (tf, torch):
    print(_lib.__version__)


2.4.1
1.9.0+cpu


In [55]:
def cross_entropy_loss(logits, positive):
    """Two-class cross-entropy reduced over the trailing dims.

    Args:
        logits: tensor whose dim 0 holds class logits; index 0 is treated as
            the negative class and index 1 as the positive class. The caller
            in this notebook passes shape (2, batch, H, W).
        positive: probability of the positive class, broadcastable against
            ``logits[1]`` (the caller passes a (batch, H, W) target).

    Returns:
        The per-element loss reduced with ``.mean(2).mean(1)``; for
        (batch, H, W) inputs this is one value per sample.
    """
    nll = -F.log_softmax(logits, dim=0)
    nll_neg, nll_pos = nll[0], nll[1]
    loss_map = positive * nll_pos + (1 - positive) * nll_neg
    return loss_map.mean(dim=2).mean(dim=1)


In [56]:
def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
    """L1 distance between ``sigmoid(logits) + offset`` and ``target``.

    Args:
        logits: raw predictions; the loss is reduced over dims 2 and 1, so a
            (batch, H, W) layout is assumed.
        target: values to compare against, broadcastable to ``logits``.
        offset: constant added after the sigmoid (the caller uses -0.5 to map
            predictions into [-0.5, 0.5]).
        mask: optional per-element weights. They are divided by their mean over
            dims 1 and 2, so each sample's weights average to 1; an all-zero
            mask uses a divisor of 1 instead of dividing by zero.

    Returns:
        The (optionally weighted) absolute error reduced with
        ``.mean(2).mean(1)``.
    """
    prediction = logits.sigmoid() + offset
    abs_err = (prediction - target).abs()
    if mask is not None:
        mask_mean = mask.mean(2, True).mean(1, True)
        mask_mean[mask_mean == 0] = 1  # avoid 0/0 for samples with an empty mask
        abs_err = abs_err * (mask / mask_mean)
    return abs_err.mean(2).mean(1)

In [71]:
class MultitaskLearner(nn.Module):
    """Multi-task head: splits stacked backbone outputs into junction-map
    (jmap), line-map (lmap) and junction-offset (joff) predictions and computes
    weighted per-sample losses for every stack.

    In this notebook the backbone call is disabled: ``forward`` receives the
    backbone ``outputs`` and ``feature`` directly so the head/loss logic can be
    exercised on synthetic tensors.
    """

    def __init__(self, backbone, head_size=None, loss_weight=None, verbose=True):
        """
        Args:
            backbone: stored as ``self.backbone``; ``forward`` does not call it.
            head_size: channel layout of the (jmap, lmap, joff) heads as nested
                lists. Defaults to ``[[2], [1], [2]]`` (stand-in for
                ``M.head_size``).
            loss_weight: mapping from loss name to multiplier; must contain at
                least "jmap", "lmap" and "joff". Defaults to the values that were
                previously hard-coded in ``forward`` (stand-in for
                ``M.loss_weight``).
            verbose: print intermediate shapes during ``forward``. Defaults to
                True, which reproduces the original debugging output.

        Raises:
            ValueError: if ``head_size`` does not describe exactly three heads.
        """
        super(MultitaskLearner, self).__init__()
        self.backbone = backbone
        if head_size is None:
            head_size = [[2], [1], [2]]
        if len(head_size) != 3:
            raise ValueError(
                f"head_size must describe three heads (jmap, lmap, joff), got {head_size!r}"
            )
        self.num_class = sum(sum(head_size, []))
        # Cumulative channel end of each head, e.g. [2 3 5] for the default.
        self.head_off = np.cumsum([sum(h) for h in head_size])
        if loss_weight is None:
            loss_weight = {
                'jmap': 8.0,
                'lmap': 0.5,
                'joff': 0.25,
                'lpos': 1,
                'lneg': 1,
            }
        # Copy so later edits to the caller's dict do not leak into the module.
        self.loss_weight = dict(loss_weight)
        self.verbose = verbose

    def _log(self, *args):
        """Print ``args`` only when ``self.verbose`` is set (shape debugging)."""
        if self.verbose:
            print(*args)

    def _check_layout(self, channel, n_jtyp):
        """Raise ValueError if ``channel`` outputs cannot be split into the heads."""
        off = self.head_off
        if channel != off[-1]:
            raise ValueError(
                f"expected {off[-1]} output channels for head_off {off}, got {channel}"
            )
        # jmap and joff are reshaped to (n_jtyp, 2, ...), so each head needs 2 * n_jtyp channels.
        if off[0] != 2 * n_jtyp or off[2] - off[1] != 2 * n_jtyp:
            raise ValueError(
                f"jmap/joff heads need {2 * n_jtyp} channels each for "
                f"{n_jtyp} junction type(s); head_off is {off}"
            )
        # lmap is squeezed to (batch, H, W) to match its target.
        if off[1] - off[0] != 1:
            raise ValueError(f"lmap head must have exactly 1 channel; head_off is {off}")

    def forward(self, input_dict, outputs, feature):
        """Split each stacked output into heads and compute per-stack losses.

        Args:
            input_dict: dict with "target" (tensors "jmap" of shape
                (batch, n_jtyp, H, W), "joff" of shape (batch, n_jtyp, 2, H, W),
                "lmap" of shape (batch, H, W)) and "mode"; when mode is
                "testing" only the stack-0 predictions are returned.
            outputs: sequence of per-stack tensors, each unpacked as
                ``batch, channel, row, col``.
            feature: passed through unchanged as ``result["feature"]``.

        Returns:
            dict with "feature", "preds" (built from stack 0) and, unless in
            testing mode, "losses": one OrderedDict of weighted per-sample
            losses ("jmap", "lmap", "joff") per stack.

        Raises:
            ValueError: if the channel count or number of junction types does
                not match ``head_size``.
        """
        # outputs, feature = self.backbone(input_dict["image"])
        result = {"feature": feature}
        batch, channel, row, col = outputs[0].shape

        # Shallow copy so the permuted views below do not replace the caller's tensors.
        T = input_dict["target"].copy()
        n_jtyp = T["jmap"].shape[1]
        self._check_layout(channel, n_jtyp)

        # switch targets to channel-first layout
        T["jmap"] = T["jmap"].permute(1, 0, 2, 3)
        T["joff"] = T["joff"].permute(1, 2, 0, 3, 4)

        offset = self.head_off  # [2 3 5] with the default head_size
        self._log("offset", offset)
        losses = []
        for stack, output in enumerate(outputs):
            self._log(f'stack: {stack}')
            self._log('output shape', output.shape)
            # N x C x H x W -> C x N x H x W
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            self._log('output shape after transpose + reshape', output.shape)
            jmap = output[0 : offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0] : offset[1]].squeeze(0)
            self._log('joff shape', output[offset[1] : offset[2]].shape)
            joff = output[offset[1] : offset[2]].reshape(n_jtyp, 2, batch, row, col)

            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                if input_dict["mode"] == "testing":
                    return result

            L = OrderedDict()
            L["jmap"] = sum(
                cross_entropy_loss(jmap[i], T["jmap"][i]) for i in range(n_jtyp)
            )
            L["lmap"] = (
                F.binary_cross_entropy_with_logits(lmap, T["lmap"], reduction="none")
                .mean(2)
                .mean(1)
            )
            L["joff"] = sum(
                sigmoid_l1_loss(joff[i, j], T["joff"][i, j], -0.5, T["jmap"][i])
                for i in range(n_jtyp)
                for j in range(2)
            )
            # Out-of-place weighting: avoids in-place ops on autograd graph outputs.
            L = OrderedDict(
                (name, value * self.loss_weight[name]) for name, value in L.items()
            )
            losses.append(L)
        result["losses"] = losses
        return result


In [72]:
# Synthetic stand-ins for backbone outputs, feature map and targets.
# A fixed seed makes the loss values printed below reproducible.
SEED = 0
rng = np.random.default_rng(SEED)

# one batch x channel x height x width array per stack
outputs = [rng.uniform(size=[2, 5, 128, 128]) for _ in range(2)]
outputs_torch = [torch.from_numpy(x) for x in outputs]

feature = rng.uniform(size=[2, 256, 128, 128])
feature_torch = torch.from_numpy(feature)

input_dict = {
    'image': rng.uniform(size=[2, 3, 512, 512]),
    'target': {
        'jmap': rng.uniform(size=[2, 1, 128, 128]),
        'joff': rng.uniform(size=[2, 1, 2, 128, 128]),
        'lmap': rng.uniform(size=[2, 128, 128]),
    }
}

# torch.from_numpy shares memory with the numpy arrays above.
input_dict_torch = {
    'image': torch.from_numpy(input_dict['image']),
    'target': {name: torch.from_numpy(arr) for name, arr in input_dict['target'].items()},
    'mode': 'training'
}

In [73]:
# Build the head without a backbone (outputs/feature are fed in directly) and
# run one training-mode forward pass. Note: `losses` holds the full result
# dict ("feature", "preds", "losses"), not just the losses.
torch_model = MultitaskLearner(backbone=None).train()  # .train() returns the module
losses = torch_model(input_dict_torch, outputs_torch, feature_torch)

offset [2 3 5]
stack: 0
output shape torch.Size([2, 5, 128, 128])
output shape after transpose + reshape torch.Size([5, 2, 128, 128])
joff shape torch.Size([2, 2, 128, 128])
stack: 1
output shape torch.Size([2, 5, 128, 128])
output shape after transpose + reshape torch.Size([5, 2, 128, 128])
joff shape torch.Size([2, 2, 128, 128])


In [74]:
# Print every weighted per-sample loss, grouped by stack.
for stack_idx, stack_losses in enumerate(losses['losses']):
    print(f'loss stack {stack_idx}')
    for loss_name, loss_value in stack_losses.items():
        print(f'{loss_name}: {loss_value}')

loss stack 0
jmap: tensor([5.7048, 5.7063], dtype=torch.float64)
lmap: tensor([0.3667, 0.3664], dtype=torch.float64)
joff: tensor([0.1984, 0.1989], dtype=torch.float64)
loss stack 1
jmap: tensor([5.7127, 5.7047], dtype=torch.float64)
lmap: tensor([0.3669, 0.3674], dtype=torch.float64)
joff: tensor([0.1984, 0.1982], dtype=torch.float64)
