In [1]:
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, NLLLoss, LogSoftmax, KLDivLoss, MSELoss, L1Loss

In [68]:
# Loss modules compared throughout this notebook.
# `cel` has to be instantiated here: a later cell calls `cel(...)`, which raises
# NameError on a fresh kernel if this assignment is left commented out.
cel = CrossEntropyLoss()
# 'batchmean' divides the summed KL divergence by the batch size (dim 0).
kll = KLDivLoss(reduction='batchmean')
msl = MSELoss()
l1l = L1Loss()

def earth_mover_distance(y_true, y_pred):
    """Squared distance between the running sums of two tensors.

    Both inputs are cumulatively summed along their last dimension; the
    element-wise difference of those running sums is squared, summed over
    every element, and divided by the size of ``y_true``'s first dimension.

    Args:
        y_true: reference tensor.
        y_pred: tensor compared against ``y_true``.

    Returns:
        A 0-dim tensor holding the summed squared difference divided by
        ``y_true.size(0)``.
    """
    running_true = torch.cumsum(y_true, dim=-1)
    running_pred = torch.cumsum(y_pred, dim=-1)
    leading_size = y_true.size(0)
    return torch.square(running_true - running_pred).sum() / leading_size

# Normalisers applied along dim=1 (each row is treated as one distribution).
logSoftmax = LogSoftmax(dim=1)
softmax = torch.nn.Softmax(dim=1)

# Tile a four-value pattern 256 times per row and 512 times down the batch,
# giving two (512, 1024) tensors.
preds = torch.tensor([[0.1, 0.1, 0.3, 0.7]]).repeat(512, 256)
targets = torch.tensor([[0.2, 0.2, 0.9, 0.4]]).repeat(512, 256)
# Perturb only the first and the last column of row 0 of the targets.
targets[0, 0].add_(0.1)
targets[0, 1023].add_(0.1)

# Report each loss on the same preds/targets pair.
# print('CE:', cel(preds, softmax(targets)))
print('KLD:', kll(logSoftmax(preds), softmax(targets)))  # log-probs vs probs
print('L1:', l1l(preds, targets))
print('L2:', msl(preds, targets))
print('EMD:', earth_mover_distance(targets, preds))  # raw values, not normalised

KLD: tensor(0.0617)
L1: tensor(0.2750)
L2: tensor(0.1175)
EMD: tensor(5612170.)


In [131]:
# Two independent standard-normal (512, 512) samples; preds is drawn first.
preds, targets = torch.randn(512, 512), torch.randn(512, 512)

# earth_mover_distance already reduces to a single value, so this mean is a no-op.
earth_mover_distance(preds, targets).mean()

tensor(507.8565)

In [61]:
import torch.nn.functional as F

In [62]:
# KL divergence via the functional API: log-probabilities of `preds` as input,
# probabilities of `targets` as target, both normalised along dim=1.
# `kll`, `preds` and `targets` come from earlier cells.
kll(F.log_softmax(preds, dim=1), F.softmax(targets, dim=1))

tensor(-2.1623e-07)

In [43]:
# Sanity check of Tensor.repeat: a (1, 3) row repeated (1, 4) times yields the
# three-value pattern four times along dim 1.
torch.tensor([[0.1, 0.2, 0.3]]).repeat(1, 4)

tensor([[0.1000, 0.2000, 0.3000, 0.1000, 0.2000, 0.3000, 0.1000, 0.2000, 0.3000,
         0.1000, 0.2000, 0.3000]])

In [28]:
# KL divergence using the module instances (`logSoftmax`, `softmax`, `kll`)
# on whatever `preds` / `targets` currently hold in the kernel.
kll(logSoftmax(preds), softmax(targets))

tensor(0.)

In [20]:
# Cross-entropy between `preds` and the softmax of `targets`.
# NOTE(review): `cel` is not assigned in this cell; it must already exist in the
# kernel namespace for this call to work.
# NOTE(review): CrossEntropyLoss presumably applies log-softmax to its input
# itself, so pre-applying `logSoftmax` looks redundant — confirm intent.
cel(logSoftmax(preds), softmax(targets))

tensor(1.0953)

In [82]:
# Inspect entry 1 of `ground_truth`.
# NOTE(review): `ground_truth` is not defined in any visible cell — it has to
# come from kernel state or a part of the notebook not shown here.
ground_truth[1]

tensor(172)

In [50]:
# Inspect entry 1 of `mask`.
# NOTE(review): `mask` is not defined in any visible cell — it has to come from
# kernel state or a part of the notebook not shown here.
mask[1]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 

In [4]:
# Select the logit columns whose index differs from the ground-truth class.
# NOTE(review): `logits` is only assigned further down (from the loaded sample)
# and `ground_truth` is not defined in any visible cell, so this cell relies on
# existing kernel state.
non_ground_truth = torch.arange(logits.shape[1])  # candidate column indices 0 .. logits.shape[1]-1
# Keep the indices that differ from `ground_truth`.
# NOTE(review): the comparison broadcasts against `ground_truth`; this presumably
# only behaves as intended when `ground_truth` is a single value — TODO confirm.
non_ground_truth = non_ground_truth[non_ground_truth != ground_truth]
non_gt_logits = logits[:, non_ground_truth]  # every row, only the kept columns

tensor([ 55, 172,  10, 106, 260, 325,   5, 281,   3, 511,  63,  27, 413, 333,
        456, 264, 399, 209, 454, 179, 108, 365, 249, 150,  84,  17, 450,  42,
        163,   8, 446, 270, 484, 509, 479, 226, 486, 445, 172, 463, 140, 277,
        267,  43, 153, 386, 168,  93, 402,  26, 333, 403,  74, 307, 415, 497,
         69, 250, 172,  76, 322, 170, 490, 353, 271,   5, 504, 511, 300, 185,
        158,  41, 458, 409, 164, 115, 403, 279,  79, 404, 258, 367,  40, 126,
        370, 227, 270, 445, 326, 402, 184,  81, 490, 321, 359, 285, 476,  89,
        450,  42,  34,  25,  94, 406, 268, 394, 492, 480, 398, 176, 498, 258,
         26, 237, 508,  18, 207, 496, 497,  89, 458, 341,  85, 258, 161, 166,
        339,  70, 275, 458, 270, 437, 172, 268, 318, 425, 363, 182, 109, 285,
        399, 140, 396, 407, 136, 161, 238, 308, 377,  58, 461, 379,   5, 504,
         29, 141, 328,  85, 201, 141,  30, 130, 255, 411, 253, 253, 182, 430,
        463,  63, 184, 116, 205, 142, 198, 475, 417, 361, 121, 4

In [19]:
# Toy check: drop one value from an index range with a boolean mask.
value = 5
x = torch.arange(10)
x = x[x.ne(value)]  # keep every entry that is not `value`

In [20]:
# Show `x` after filtering: the entry equal to `value` (5) is gone.
x

tensor([0, 1, 2, 3, 4, 6, 7, 8, 9])

In [22]:
# Length of the boolean mask `x != value`; this equals len(x), not the number
# of True entries (x was already filtered above, hence 9).
len(x!=value)

9

In [4]:
# 513 evenly spaced values from -1 to 1, copied into each of 512 rows.
targets = torch.linspace(start=-1, end=1, steps=513)[None, :].repeat(512, 1)
print(targets.size(), targets.dtype)


torch.Size([512, 513]) torch.float32


In [5]:
# Independent random column order for every row: argsort of uniform noise
# along dim=1 gives one permutation of the column indices per row.
permutation = torch.argsort(torch.rand(*targets.shape[:2]), dim=1)
print(permutation.size(), permutation.dtype)

torch.Size([512, 513]) torch.int64


In [6]:
# Shuffle each row with its own permutation:
# new_targets[i, j] = targets[i, permutation[i, j]].
# Rebinds `targets`, so re-running this cell shuffles the already-shuffled rows.
targets = torch.gather(targets, dim=1, index=permutation)
print(targets.size(), targets.dtype)

torch.Size([512, 513]) torch.float32


In [10]:
# Undo the per-row shuffle: argsort of a permutation along the last dim is its
# inverse, so gathering with it restores the original column order.
unscrambled_targets = torch.gather(targets, dim=1, index=torch.argsort(permutation, dim=-1))
print(unscrambled_targets.size(), unscrambled_targets.dtype)

torch.Size([512, 513]) torch.float32


In [11]:
# First row after unscrambling; it should read as the ascending
# linspace(-1, 1, 513) again if the inverse permutation is correct.
unscrambled_targets[0]

tensor([-1.0000, -0.9961, -0.9922, -0.9883, -0.9844, -0.9805, -0.9766, -0.9727,
        -0.9688, -0.9648, -0.9609, -0.9570, -0.9531, -0.9492, -0.9453, -0.9414,
        -0.9375, -0.9336, -0.9297, -0.9258, -0.9219, -0.9180, -0.9141, -0.9102,
        -0.9062, -0.9023, -0.8984, -0.8945, -0.8906, -0.8867, -0.8828, -0.8789,
        -0.8750, -0.8711, -0.8672, -0.8633, -0.8594, -0.8555, -0.8516, -0.8477,
        -0.8438, -0.8398, -0.8359, -0.8320, -0.8281, -0.8242, -0.8203, -0.8164,
        -0.8125, -0.8086, -0.8047, -0.8008, -0.7969, -0.7930, -0.7891, -0.7852,
        -0.7812, -0.7773, -0.7734, -0.7695, -0.7656, -0.7617, -0.7578, -0.7539,
        -0.7500, -0.7461, -0.7422, -0.7383, -0.7344, -0.7305, -0.7266, -0.7227,
        -0.7188, -0.7148, -0.7109, -0.7070, -0.7031, -0.6992, -0.6953, -0.6914,
        -0.6875, -0.6836, -0.6797, -0.6758, -0.6719, -0.6680, -0.6641, -0.6602,
        -0.6562, -0.6523, -0.6484, -0.6445, -0.6406, -0.6367, -0.6328, -0.6289,
        -0.6250, -0.6211, -0.6172, -0.61

In [7]:
# Display the current `targets`.
# NOTE(review): `targets` is rebound by several cells in this notebook, so what
# is shown here depends on which of them ran last.
targets

tensor([-0.2344, -0.1953, -0.9883, -0.3711,  0.3477, -0.4062,  0.2188,  0.1094,
         0.1680, -0.9375, -0.0742, -0.4375, -0.0664,  0.1562,  0.7266,  0.3164,
         0.3789, -0.7930,  0.6367, -0.7852,  0.1523, -0.0352,  0.4844, -0.2812,
         0.4180, -0.3047,  0.9375, -0.5039,  0.1250, -0.5273,  0.8359, -0.4492,
        -0.8203,  0.7031,  0.3359,  0.7305,  0.1289, -0.9727,  0.0234,  0.4023,
         0.8086,  0.9688,  0.0000,  0.6055,  0.0703, -0.0156,  0.4102, -0.5664,
        -0.1992, -0.7656, -0.0859, -0.7969, -0.5078,  0.2578, -0.9609, -0.9531,
        -0.0820, -0.3633, -0.0469,  0.8242,  0.7344, -0.8477,  0.7461,  0.0586,
         0.6445, -0.8945, -0.3555, -0.6914,  0.7930, -0.6289,  0.0273, -0.9766,
         0.3672,  0.5469, -0.7812,  0.2344,  0.8047,  0.7109,  0.4688,  0.3281,
         0.6758, -0.7461,  0.8320, -0.8398, -0.8594, -0.8164,  0.5547, -0.9297,
        -0.3320,  0.1719, -0.2617,  0.3320,  0.2773,  0.5273, -0.9922,  0.5664,
        -0.4336,  0.8711,  0.9180, -0.14

In [6]:
import torch.nn.functional as F

def earth_mover_distance(input: Tensor, target: Tensor, convert_to_prob_distr=False) -> Tensor:
    '''Mean squared difference between the running sums of ``input`` and ``target``.

    Adapted from:
    https://discuss.pytorch.org/t/implementation-of-squared-earth-movers-distance-loss-function-for-ordinal-scale/107927/2

    Args:
        input: tensor compared against ``target``.
        target: reference tensor.
        convert_to_prob_distr: when true, both tensors are passed through a
            softmax over the last dimension before the running sums are taken.

    Returns:
        A 0-dim tensor: the mean, over every element, of the squared
        difference between the two cumulative sums along the last dimension.
    '''
    if convert_to_prob_distr:
        # Normalise each slice along the last dimension so it sums to one.
        input, target = F.softmax(input, dim=-1), F.softmax(target, dim=-1)

    running_gap = input.cumsum(dim=-1) - target.cumsum(dim=-1)
    return running_gap.square().mean()

In [2]:
# step2 is a copy of step1 whose first column is shifted by +0.5; every other
# column is left identical.
step1 = torch.randn(512, 1024)
step2 = step1.clone()
step2[:, 0].add_(0.5)

print(step1, step2, sep='\n')

tensor([[ 0.2488,  1.2290,  0.7107,  ...,  0.4301, -1.0904,  0.9434],
        [-1.8892, -0.7786, -0.3803,  ..., -1.5747, -1.6249, -0.4611],
        [ 2.6743,  1.2802,  0.8269,  ...,  0.9371,  2.3421,  0.1913],
        ...,
        [ 0.5909, -0.6343, -0.2968,  ...,  0.3054,  0.7932,  1.0691],
        [-1.3597, -1.1690, -0.1044,  ...,  0.4661,  0.1656,  1.2770],
        [ 1.5634, -1.6232,  0.4847,  ..., -0.5152,  0.7526, -0.3520]])
tensor([[ 0.7488,  1.2290,  0.7107,  ...,  0.4301, -1.0904,  0.9434],
        [-1.3892, -0.7786, -0.3803,  ..., -1.5747, -1.6249, -0.4611],
        [ 3.1743,  1.2802,  0.8269,  ...,  0.9371,  2.3421,  0.1913],
        ...,
        [ 1.0909, -0.6343, -0.2968,  ...,  0.3054,  0.7932,  1.0691],
        [-0.8597, -1.1690, -0.1044,  ...,  0.4661,  0.1656,  1.2770],
        [ 2.0634, -1.6232,  0.4847,  ..., -0.5152,  0.7526, -0.3520]])


In [7]:
# Distance between step1 and step2 with softmax normalisation switched on
# (the third positional argument is convert_to_prob_distr).
earth_mover_distance(step1, step2, True)

tensor(4.7848e-07)

In [9]:
# Running sum of `step1` along the last dimension — the quantity
# earth_mover_distance compares when no softmax is applied.
# NOTE(review): the saved output is 2x3, so it was presumably produced with a
# different `step1` than the (512, 1024) tensor created in this notebook.
torch.cumsum(step1, dim=-1)

tensor([[-0.6995,  0.1876,  1.5932],
        [ 1.4654,  1.3840,  0.4774]])

In [8]:
# Load a saved loss sample for inspection.
# NOTE(review): absolute, machine-specific path — it will not resolve on other machines.
# NOTE(review): torch.load unpickles the file; only load files from a trusted source.
# NOTE(review): the tensors display as device='cuda:0', so loading presumably
# needs a CUDA device unless `map_location` is supplied — confirm.
loss_sample = torch.load('/home/nikita/e2e-driving/sample-1658673603.5867388-9368.62598.pt')

In [9]:
# Pull the four saved entries out of the loaded dict, in a fixed order.
logits, odd, even, loss = (
    loss_sample[key] for key in ('logits', 'odd', 'even', 'loss')
)

# One "<name>: <shape>" line per entry.
print('\n'.join(
    f'{label}: {entry.shape}'
    for label, entry in (('logits', logits), ('odd', odd), ('even', even), ('loss', loss))
))

logits: torch.Size([512, 513])
odd: torch.Size([256, 513])
even: torch.Size([256, 513])
loss: torch.Size([])


In [10]:
# Display the `logits` entry of the loaded sample.
logits

tensor([[ 0.0000, -0.7793, -0.7754,  ..., -0.3466, -0.3476, -0.3487],
        [ 0.0000, -0.7824, -0.7784,  ..., -0.3469, -0.3480, -0.3491],
        [ 0.0000, -0.7029, -0.6990,  ..., -0.3349, -0.3360, -0.3370],
        ...,
        [ 0.0000, -1.0124, -1.0089,  ..., -0.2926, -0.2936, -0.2946],
        [ 0.0000, -1.1167, -1.1133,  ..., -0.3901, -0.3911, -0.3922],
        [ 0.0000, -1.1110, -1.1076,  ..., -0.3902, -0.3913, -0.3924]],
       device='cuda:0', requires_grad=True)

In [11]:
# Display the `odd` entry of the loaded sample.
# NOTE(review): judging by the displayed rows, `odd` looks like every other row
# of `logits` starting at row 0 — TODO confirm against the code that saved it.
odd

tensor([[ 0.0000, -0.7793, -0.7754,  ..., -0.3466, -0.3476, -0.3487],
        [ 0.0000, -0.7029, -0.6990,  ..., -0.3349, -0.3360, -0.3370],
        [ 0.0000, -1.2032, -1.1998,  ..., -0.1765, -0.1899, -0.2033],
        ...,
        [ 0.0000, -0.5586, -0.5366,  ..., -0.3946, -0.3956, -0.3966],
        [ 0.0000, -1.0001, -0.9966,  ..., -0.2945, -0.2956, -0.2966],
        [ 0.0000, -1.1167, -1.1133,  ..., -0.3901, -0.3911, -0.3922]],
       device='cuda:0', requires_grad=True)

In [12]:
# Display the `even` entry of the loaded sample.
# NOTE(review): judging by the displayed rows, `even` looks like every other row
# of `logits` starting at row 1 — TODO confirm against the code that saved it.
even

tensor([[ 0.0000, -0.7824, -0.7784,  ..., -0.3469, -0.3480, -0.3491],
        [ 0.0000, -0.6881, -0.6842,  ..., -0.3355, -0.3365, -0.3376],
        [ 0.0000, -1.2157, -1.2123,  ..., -0.1670, -0.1801, -0.1932],
        ...,
        [ 0.0000, -0.6408, -0.6185,  ..., -0.3954, -0.3965, -0.3976],
        [ 0.0000, -1.0124, -1.0089,  ..., -0.2926, -0.2936, -0.2946],
        [ 0.0000, -1.1110, -1.1076,  ..., -0.3902, -0.3913, -0.3924]],
       device='cuda:0', requires_grad=True)