In [1]:
import torch

I met a nan problem when backwarding the function below when `x` is lower than threshold $th=\frac{6}{29}^3\approx0.0089$

In [2]:
def lab_compress(x: torch.tensor):
    mask = x > (6./29.)**3.
    return (mask * x) ** (1./3.) + (841./108. * x + 4./29.) * ~mask

In [3]:
x = torch.tensor([x/1000 for x in range(20)])
print(x)
x.requires_grad_(True)
y = lab_compress(x)
y.backward(torch.ones_like(y))
x.grad

tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080,
        0.0090, 0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160, 0.0170,
        0.0180, 0.0190])


tensor([   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        7.7040, 7.1814, 6.7393, 6.3595, 6.0291, 5.7384, 5.4805, 5.2497, 5.0417,
        4.8532, 4.6814])

Check the reason with hook and rewrite the function.

In [4]:
def lab_compress(x: torch.tensor):
    th =  (6./29.)**3.
    mask = x > th
    x1 = mask * x
    x2 = ~mask * x
    x1.register_hook(lambda grad: print('x1.grad:\n', grad))
    x2.register_hook(lambda grad: print('x2.grad:\n', grad))
    return x1 ** (1./3.) + 841./108. * x2 + 4./29. * ~mask

In [5]:
x = torch.tensor([x/1000 for x in range(20)])
print(x)
x.requires_grad_(True)
y = lab_compress(x)
y.backward(torch.ones_like(y))
x.grad

tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080,
        0.0090, 0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160, 0.0170,
        0.0180, 0.0190])
x2.grad:
 tensor([7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870,
        7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870,
        7.7870, 7.7870])
x1.grad:
 tensor([   inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf,
        7.7040, 7.1814, 6.7393, 6.3595, 6.0291, 5.7384, 5.4805, 5.2497, 5.0417,
        4.8532, 4.6814])


tensor([   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        7.7040, 7.1814, 6.7393, 6.3595, 6.0291, 5.7384, 5.4805, 5.2497, 5.0417,
        4.8532, 4.6814])

It is confusing that we got inf in x1.grad. See https://github.com/pytorch/pytorch/issues/4132 for a probable reason.

The first way to deal with it is hook. If you haven't used it before, read https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/.

In [6]:
def lab_compress(x: torch.tensor):
    def hook_fn(grad: torch.tensor):
        grad = torch.where(~torch.isfinite(grad), torch.zeros_like(grad), grad)
        print(grad)
        return

    th =  (6./29.)**3.
    mask = x > th
    x1 = mask * x
    x2 = ~mask * x
    x1.register_hook(hook_fn)
    x2.register_hook(hook_fn)
    return x1 ** (1./3.) + 841./108. * x2 + 4./29. * ~mask

In [7]:
x = torch.tensor([x/1000 for x in range(20)])
print(x)
x.requires_grad_(True)
y = lab_compress(x)
y.backward(torch.ones_like(y))
x.grad

tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080,
        0.0090, 0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160, 0.0170,
        0.0180, 0.0190])
tensor([7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870,
        7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870,
        7.7870, 7.7870])
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        7.7040, 7.1814, 6.7393, 6.3595, 6.0291, 5.7384, 5.4805, 5.2497, 5.0417,
        4.8532, 4.6814])


tensor([   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
        7.7040, 7.1814, 6.7393, 6.3595, 6.0291, 5.7384, 5.4805, 5.2497, 5.0417,
        4.8532, 4.6814])

By resetting infinte values to 0, we got a correct grad.

However, there may be some problems. if we got nan/inf  grad for other reasons, the hook will clear them at the same time. Also, there is no guarantee that high-order derivatives can be obtained correctly.

Another way is using built-in function to solve the problem.(complex, related to the forward function, but safe)

In [8]:
import torch.nn.functional as F

def lab_compress(x: torch.tensor):
    def masking(x: torch.tensor, threshold: float = 0, greater: bool = True):
        return F.relu(x - threshold) + threshold if greater \
            else -F.relu(threshold - x) + threshold

    th = (6./29.)**3.
    mask = x > th
    x1 = masking(x, th, True)
    x2 = masking(x.clone(), th, False)
    return x1 ** (1./3.) + 841./108. * x2 + 4./29. * ~mask

Here we use masking to wrap the operation. Note that we use x.clone() to avoid backward twice.

In [9]:
x = torch.tensor([x/1000 for x in range(20)])
print(x)
x.requires_grad_(True)
y = lab_compress(x)
y.backward(torch.ones_like(y))
x.grad

tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080,
        0.0090, 0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160, 0.0170,
        0.0180, 0.0190])


tensor([7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870, 7.7870,
        7.7040, 7.1814, 6.7393, 6.3595, 6.0291, 5.7384, 5.4805, 5.2497, 5.0417,
        4.8532, 4.6814])

The both tricks above have some drawbacks. A better solution is desired.