In [11]:
import torch.distributions as distributions
import torch

# Draw standard-normal values into a (1, 4, 6) tensor; the bare name on the
# last line makes the notebook display it.
logits = torch.randn((1, 4, 6))
logits

tensor([[[-0.1916,  0.2263,  0.6764,  0.3331,  0.2804,  0.1842],
         [-1.0985, -0.7799, -0.2834, -1.5441, -0.9372,  0.7191],
         [-0.3629,  0.5302,  1.9323, -0.6160,  0.5157,  2.6066],
         [-0.4029,  0.8228,  0.2439,  0.8623,  0.0858,  0.0284]]])

In [12]:
# Build a categorical distribution parameterised by `logits`.
# Note: despite its name, `sample` holds the distribution object, not a draw.
sample = distributions.Categorical(logits=logits)
print(sample)
# A random draw from the distribution ...
print(sample.sample())
# ... versus the index of the highest-probability class along the last axis.
print(torch.argmax(sample.probs, dim=-1))

Categorical(logits: torch.Size([1, 4, 6]))
tensor([[2, 5, 5, 2]])
tensor([[2, 5, 5, 3]])


In [17]:
# Integer tensor holding the value 5.
S = torch.tensor([5])
print(S)
# Element-wise maximum against a float tensor of ones; the printed result
# below shows the integer operand is promoted to float.
norm_ratio = torch.maximum(torch.ones(1), S)
print(norm_ratio)

tensor([5])
tensor([5.])


In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def symlog(x):
    """Symmetric logarithm: ``sign(x) * log(1 + |x|)``.

    Uses ``torch.log1p`` rather than ``torch.log(1 + ...)`` so that inputs
    with a very small magnitude are not rounded to exactly zero when
    ``1 + |x|`` is not representable in the tensor's floating-point dtype.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape as ``x``. Because the function runs under
        ``torch.no_grad()``, the result does not track gradients.
    """
    return torch.sign(x) * torch.log1p(torch.abs(x))


@torch.no_grad()
def symexp(x):
    """Symmetric exponential: ``sign(x) * (exp(|x|) - 1)``.

    Uses ``torch.expm1`` rather than ``torch.exp(...) - 1`` so that inputs
    with a very small magnitude are not cancelled to exactly zero when
    ``exp(|x|)`` rounds to 1 in the tensor's floating-point dtype.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape as ``x``. Because the function runs under
        ``torch.no_grad()``, the result does not track gradients.
    """
    return torch.sign(x) * torch.expm1(torch.abs(x))


class SymLogLoss(nn.Module):
    """Half mean-squared error between ``output`` and the symlog of ``target``.

    The module holds no parameters or buffers, so the default
    ``nn.Module`` constructor is sufficient.
    """

    def forward(self, output, target):
        """Return ``0.5 * mse(output, symlog(target))`` as a scalar tensor."""
        squashed_target = symlog(target)
        mse = F.mse_loss(output, squashed_target)
        return 0.5 * mse


class SymLogTwoHotLoss(nn.Module):
    """Cross-entropy against a two-hot encoding of the symlog-transformed target.

    The interval ``[lower_bound, upper_bound]`` is covered by ``num_classes``
    evenly spaced bin centres. A target is first passed through ``symlog`` and
    then represented as a probability vector whose mass is split between the
    two neighbouring bins, weighted by how close the value is to each.

    Args:
        num_classes: Number of bins (size of the last axis of ``output``).
        lower_bound: Value of the first bin, in symlog space.
        upper_bound: Value of the last bin, in symlog space.
    """

    def __init__(self, num_classes, lower_bound, upper_bound):
        super().__init__()
        self.num_classes = num_classes
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # Distance between two consecutive bin centres.
        self.bin_length = (upper_bound - lower_bound) / (num_classes-1)

        # use register buffer so that bins move with .cuda() automatically
        # The bins are built from the constructor bounds (not fixed constants)
        # so they always agree with `bin_length` and the range check in forward.
        self.bins: torch.Tensor
        self.register_buffer(
            'bins',
            torch.linspace(lower_bound, upper_bound, num_classes),
            persistent=False)

    def _two_hot(self, target):
        """Encode values already in symlog space as two-hot probability vectors.

        Args:
            target: Tensor of values inside ``[lower_bound, upper_bound]``.

        Returns:
            Tensor of shape ``target.shape + (num_classes,)`` whose last axis
            sums to one, with mass on at most two adjacent bins.
        """
        # bucketize returns the first index i with target <= bins[i]; that is 0
        # when target equals the lowest bin. Clamping keeps `upper - 1` a valid,
        # non-negative class index so one_hot does not fail at the boundary.
        upper = torch.bucketize(target, self.bins)
        upper = torch.clamp(upper, 1, self.num_classes - 1)
        lower = upper - 1

        # Fraction of the way from the lower bin to the upper bin.
        weight = (target - self.bins[lower]) / self.bin_length
        weight = torch.clamp(weight, 0, 1)
        weight = weight.unsqueeze(-1)

        return ((1-weight)*F.one_hot(lower, self.num_classes)
                + weight*F.one_hot(upper, self.num_classes))

    def forward(self, output, target):
        """Return the mean cross-entropy between ``output`` and the two-hot target.

        Args:
            output: Unnormalised logits; the last axis has ``num_classes``
                entries and the leading axes are expected to line up with
                ``target`` (assumption - confirm against callers).
            target: Raw (un-transformed) target values.

        Raises:
            AssertionError: If ``symlog(target)`` leaves the configured bounds.
        """
        target = symlog(target)
        assert target.min() >= self.lower_bound and target.max() <= self.upper_bound, \
            "symlog(target) is outside [lower_bound, upper_bound]"

        target_prob = self._two_hot(target)

        loss = -target_prob * F.log_softmax(output, dim=-1)
        loss = loss.sum(dim=-1)
        return loss.mean()

    def decode(self, output):
        """Map logits back to a scalar: symexp of the softmax-weighted bin values."""
        return symexp(F.softmax(output, dim=-1) @ self.bins)


# Smoke test: 255 bins spanning [-20, 20] in symlog space.
loss_func = SymLogTwoHotLoss(num_classes=255, lower_bound=-20, upper_bound=20)
print(loss_func)

# Random logits of shape (1, 1, 255) that track gradients.
output = torch.randn(1, 1, 255, requires_grad=True)
print(output)

# A single raw target value of 0.1 with shape (1, 1).
target = torch.ones(1, 1) * 0.1
print(target)

loss = loss_func(output, target)
print(loss)

SymLogTwoHotLoss()
tensor([[[ 1.8980e+00, -1.6085e+00,  8.4599e-01, -7.6679e-01,  6.5921e-01,
          -3.8726e-01,  8.4523e-01, -3.9936e-01,  1.8783e+00,  1.0451e-01,
           1.9588e+00, -1.3000e-01, -6.2812e-01,  1.0949e+00,  9.0249e-01,
           5.6656e-01, -1.4412e-01, -2.9450e+00, -5.7787e-01,  1.2009e+00,
          -6.8237e-01, -1.2234e+00, -1.1486e+00,  1.0034e+00,  4.2528e-01,
          -3.3263e-01, -3.2194e-01,  1.0251e+00, -8.3089e-01, -1.3858e-01,
           3.0711e-01, -1.2917e+00, -1.0172e+00, -3.0163e+00,  4.1699e-01,
          -1.5841e-01,  2.3525e-01,  5.2670e-01,  1.2528e+00,  8.9015e-01,
           6.1383e-01, -4.1630e-01, -1.0373e+00, -1.2937e+00,  6.4527e-02,
           1.6181e+00, -5.3967e-01, -1.6379e+00,  4.4835e-01,  7.4905e-03,
          -4.0924e-01, -1.5125e-01,  1.1645e+00, -1.1510e+00, -1.9637e+00,
          -3.1762e-01,  9.6048e-01, -7.2908e-01, -6.4671e-01, -3.0405e-01,
           8.4101e-01, -1.1952e+00, -2.2104e-01,  2.2338e+00,  7.8695e-01,
      