In [1]:
import torch

In [2]:
import torch.nn as nn

In [33]:
class BCEWithLogitsLossMasked:
    def __init__(self):
        self.secret_criterion = nn.BCEWithLogitsLoss(reduction='none')
    
    def __call__(self, x, y, m):
        assert x.shape[:2] == y.shape[:2] == m.shape
        assert m.dtype == torch.int64
        
        loss_not_reduced = self.secret_criterion(x, y)
        loss_masked_not_reduced = loss_not_reduced * m.unsqueeze(dim=-1)
        loss_masked_not_reduced_scalled = loss_masked_not_reduced / l.unsqueeze(dim=-1)
        loss_masked_not_reduced_scalled.sum() / len(loss_masked_not_reduced_scalled)
        
        return loss_masked_not_reduced_scalled.sum() / len(loss_masked_not_reduced_scalled)

In [34]:
criterion = nn.BCEWithLogitsLoss()
criterion_red = nn.BCEWithLogitsLoss(reduction='none')
criterion_fancy = BCEWithLogitsLossMasked()

# Inputs

In [6]:
# batch, seq, n_out
x = torch.randn(2, 4, 5)
y = torch.randn(2, 4, 5)

In [7]:
x

tensor([[[ 0.9224,  0.1275, -0.4445, -2.4378, -0.1389],
         [-0.7623, -0.2174, -0.1883, -0.8485,  0.0264],
         [ 1.2433,  0.4055, -1.0698, -1.5804,  1.2499],
         [-1.9781, -1.3375, -1.2968, -0.8545,  0.8955]],

        [[ 0.3711, -0.8387, -1.0524, -0.1452, -0.3665],
         [-0.1056, -1.4996,  0.2971,  1.0016, -0.2294],
         [-1.4290,  0.0747,  1.4130,  0.3132,  1.3201],
         [ 0.1978,  3.0526, -0.2047,  0.0086, -0.7705]]])

In [8]:
y

tensor([[[ 1.2679,  0.6888, -0.8276,  0.8544,  0.2605],
         [-0.4449, -0.2812, -1.0358, -0.4625, -0.1208],
         [ 1.0388, -0.3966, -0.2596,  1.9262, -0.1836],
         [-1.9481,  0.8885,  0.0687, -1.2673,  0.1870]],

        [[-0.9446,  1.1808, -2.2222,  0.0285, -0.0578],
         [-0.4722,  1.8635,  0.5248,  0.6397,  0.9245],
         [ 2.1397,  0.0330,  1.1433,  0.0392, -0.6117],
         [ 0.9815,  1.5426, -1.9470, -0.3060,  0.6797]]])

In [9]:
m = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])

In [10]:
l = torch.tensor([[2], [3]])

# Loss by fancy

In [35]:
criterion_fancy(x, y, m)

tensor(3.4548)

# Loss by loop

In [14]:
loss = 0
for i in range(len(x)):
    x_clip = x[i,-l[i]:]
    y_clip = y[i,-l[i]:]
    print(x_clip)
    print(y_clip)
    loss_a = criterion(x_clip, y_clip)
    loss_b = criterion_red(x_clip, y_clip).sum() / 2
    print(loss_a)
    print(loss_b)
    loss += loss_a
    
    print('--')
print('==')
print(loss / len(x))

tensor([[ 1.2433,  0.4055, -1.0698, -1.5804,  1.2499],
        [-1.9781, -1.3375, -1.2968, -0.8545,  0.8955]])
tensor([[ 1.0388, -0.3966, -0.2596,  1.9262, -0.1836],
        [-1.9481,  0.8885,  0.0687, -1.2673,  0.1870]])
tensor(0.4633)
tensor(2.3163)
--
tensor([[-0.1056, -1.4996,  0.2971,  1.0016, -0.2294],
        [-1.4290,  0.0747,  1.4130,  0.3132,  1.3201],
        [ 0.1978,  3.0526, -0.2047,  0.0086, -0.7705]])
tensor([[-0.4722,  1.8635,  0.5248,  0.6397,  0.9245],
        [ 2.1397,  0.0330,  1.1433,  0.0392, -0.6117],
        [ 0.9815,  1.5426, -1.9470, -0.3060,  0.6797]])
tensor(0.9187)
tensor(6.8900)
--
==
tensor(0.6910)


# Loss vectorized

In [15]:
# loss, but not reduce
loss_not_reduced = criterion_red(x, y)
loss_not_reduced

tensor([[[ 0.0876,  0.6711,  0.1275,  2.1665,  0.6623],
         [ 0.0439,  0.5292,  0.4084, -0.0361,  0.7096],
         [ 0.2052,  1.0771,  0.0172,  3.2313,  1.7313],
         [-3.7240,  1.4215,  0.3307, -0.7283,  1.0705]],

        [[ 1.2463,  1.3496, -2.0391,  0.6273,  0.5054],
         [ 0.5918,  2.9960,  0.6968,  0.6737,  0.7971],
         [ 3.2723,  0.7288,  0.0153,  0.8497,  2.3643],
         [ 0.6028, -1.6102,  0.1974,  0.7001,  0.9041]]])

In [21]:
# apply mask
loss_masked_not_reduced = loss_not_reduced * m.unsqueeze(dim=-1)
loss_masked_not_reduced

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
         [ 0.2052,  1.0771,  0.0172,  3.2313,  1.7313],
         [-3.7240,  1.4215,  0.3307, -0.7283,  1.0705]],

        [[ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
         [ 0.5918,  2.9960,  0.6968,  0.6737,  0.7971],
         [ 3.2723,  0.7288,  0.0153,  0.8497,  2.3643],
         [ 0.6028, -1.6102,  0.1974,  0.7001,  0.9041]]])

In [26]:
# scale rows by num not masked
loss_masked_not_reduced_scalled = loss_masked_not_reduced / l.unsqueeze(dim=-1)
loss_masked_not_reduced_scalled

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
         [ 0.1026,  0.5386,  0.0086,  1.6157,  0.8657],
         [-1.8620,  0.7108,  0.1654, -0.3642,  0.5353]],

        [[ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
         [ 0.1973,  0.9987,  0.2323,  0.2246,  0.2657],
         [ 1.0908,  0.2429,  0.0051,  0.2832,  0.7881],
         [ 0.2009, -0.5367,  0.0658,  0.2334,  0.3014]]])

In [27]:
# quick test?
aa = loss_masked_not_reduced[0].sum() / l[0]
bb = loss_masked_not_reduced[1].sum() / l[1]
(aa + bb) / 2

tensor([3.4548])

In [28]:
# sum and normalize by batch size
loss_masked_not_reduced_scalled.sum() / len(loss_masked_not_reduced_scalled)

tensor(3.4548)

In [137]:
loss_not_reduced = criterion_red(x, y)
loss_masked_not_reduced = loss_not_reduced * m
loss_masked_not_reduced_scalled = loss_masked_not_reduced / l
loss_masked_not_reduced_scalled.sum() / len(loss_masked_not_reduced_scalled)
loss_masked_not_reduced_scalled

tensor([[-0.0000,  0.0000,  0.5391, -1.1868],
        [ 0.0000, -0.2401,  0.2313,  0.2348]])

# Loss vectorised

In [138]:
x.shape == y.shape

True

In [139]:
x.shape

torch.Size([2, 4])

In [140]:
m.dtype

torch.int64