In [21]:
from dtw_soft import *

# First tests

In [22]:
x = torch.tensor([2.0, 5.0, 3.0, 6.0], requires_grad=True).unsqueeze(-1)
y = torch.tensor([5.0, 9.0, 2.0], requires_grad=True).unsqueeze(-1)

loss = soft_dtw(x, y)[0]

In [23]:
loss.backward()

In [24]:
x1 = torch.tensor([2.0, 5.0, 3.0, 6.0], requires_grad=True).unsqueeze(-1)
x2 = torch.tensor([2.0, 3.0, 3.0, 5.0], requires_grad=True).unsqueeze(-1)
x = torch.stack([x1, x2])

y1 = torch.tensor([5.0, 9.0, 2.0, 2.0], requires_grad=True).unsqueeze(-1)
y2 = torch.tensor([2.0, 3.0, 3.0, 2.0], requires_grad=True).unsqueeze(-1)
y = torch.stack([y1, y2])

print(x.shape, y.shape)

loss = torch.mean(soft_dtw_batch_same_size(x, y)[0])

torch.Size([2, 4, 1]) torch.Size([2, 4, 1])


In [25]:
loss

tensor(23.8546, grad_fn=<MeanBackward0>)

In [26]:
loss.backward()

# Pytorch Backward

In [103]:
x = torch.tensor([2.0, 5.0, 3.0, 6.0], requires_grad=True).unsqueeze(-1).repeat(1, 3)
y = torch.tensor([5.0, 9.0, 2.0], requires_grad=True).unsqueeze(-1).repeat(1, 3)

# Retaining the gradient for non-leaf tensor
x.retain_grad()
y.retain_grad()

loss = soft_dtw(x, y)[0]

In [104]:
loss.backward()

# Extracting gradients
x_grad = x.grad
y_grad = y.grad

x_grad, y_grad

(tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [8., 8., 8.]]),
 tensor([[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [-8., -8., -8.]]))

# Custom Backward

In [105]:
E = backward_recursion(x, y)
E

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.9999e-01, 1.1705e-05, 2.2001e-17],
        [9.9998e-01, 1.0806e-30, 1.1705e-05],
        [4.7425e-02, 9.9998e-01, 1.0000e+00]], grad_fn=<SliceBackward0>)

In [106]:
def grad_x(x, y, E):
    with torch.no_grad():
        p = len(x[0])
        m = len(y)
        print("p m:", p, m)
        B = E.transpose(-1, -2)

        a = torch.ones(p, m) @ B
        print(a, a.shape)
        a = a * torch.transpose(x, -1, -2)
        print(a, a.shape)
        b = y.transpose(0, 1) @ B
        print(b, b.shape)
        return 2 * (a - b).transpose(-1, -2)

In [107]:
out = grad_x(x, y, E)
print("Out:")
print(out, out.shape)

p m: 3 3
tensor([[1.0000, 1.0000, 1.0000, 2.0474],
        [1.0000, 1.0000, 1.0000, 2.0474],
        [1.0000, 1.0000, 1.0000, 2.0474]]) torch.Size([3, 4])
tensor([[ 2.0000,  5.0000,  3.0000, 12.2845],
        [ 2.0000,  5.0000,  3.0000, 12.2845],
        [ 2.0000,  5.0000,  3.0000, 12.2845]]) torch.Size([3, 4])
tensor([[ 5.0000,  5.0001,  4.9999, 11.2370],
        [ 5.0000,  5.0001,  4.9999, 11.2370],
        [ 5.0000,  5.0001,  4.9999, 11.2370]]) torch.Size([3, 4])
Out:
tensor([[-6.0000e+00, -6.0000e+00, -6.0000e+00],
        [-9.3460e-05, -9.3460e-05, -9.3460e-05],
        [-3.9999e+00, -3.9999e+00, -3.9999e+00],
        [ 2.0949e+00,  2.0949e+00,  2.0949e+00]]) torch.Size([4, 3])


# Test with Batches

In [108]:
print(x.repeat(3, 1, 1).shape)
E = backward_recursion_batch_same_size(x.repeat(3, 1, 1), y.repeat(3, 1, 1))
E

torch.Size([3, 4, 3])


tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.9999e-01, 1.1705e-05, 2.2001e-17],
         [9.9998e-01, 1.0806e-30, 1.1705e-05],
         [4.7425e-02, 9.9998e-01, 1.0000e+00]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.9999e-01, 1.1705e-05, 2.2001e-17],
         [9.9998e-01, 1.0806e-30, 1.1705e-05],
         [4.7425e-02, 9.9998e-01, 1.0000e+00]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.9999e-01, 1.1705e-05, 2.2001e-17],
         [9.9998e-01, 1.0806e-30, 1.1705e-05],
         [4.7425e-02, 9.9998e-01, 1.0000e+00]]], grad_fn=<SliceBackward0>)

In [101]:
# Afair plus tard quand ca march sans batch

out = grad_x(x.repeat(3, 1, 1), y.repeat(3, 1, 1), E)
print(out, out.shape)

p m: 4 3
tensor([[[1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474]],

        [[1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474]],

        [[1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474],
         [1.0000, 1.0000, 1.0000, 2.0474]]]) torch.Size([3, 4, 4])


RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1

# Test with big vectors

In [40]:
x = torch.ones(100, requires_grad=True).unsqueeze(-1)
y = torch.ones(100, requires_grad=True).unsqueeze(-1)

# Retaining the gradient for non-leaf tensor
x.retain_grad()
y.retain_grad()

In [41]:
E = backward_recursion(x, y)