In [69]:
from dtw_soft import *

In [70]:
import numpy as np
import torch


## Loss Pytorch

In [71]:
def soft_dtw(
    x: torch.Tensor, y: torch.Tensor, gamma: float = 1.0
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Soft Dynamic Time Warping between two sequences.

    Args:
        x (torch.Tensor): length_x x feature. A 1-D tensor is treated as
            length_x x 1.
        y (torch.Tensor): length_y x feature. A 1-D tensor is treated as
            length_y x 1.
        gamma (float, optional): gamma parameter forwarded to ``soft_min``.
            Defaults to 1.0.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(loss, R, cost)``
        where ``loss`` is ``R[-1, -1]``, ``R`` is the
        (length_x + 1) x (length_y + 1) DP matrix and ``cost`` is the
        length_x x length_y matrix of squared Euclidean distances.
    """
    n = x.shape[0]
    m = y.shape[0]

    # torch.cdist only accepts tensors with at least 2 dimensions. Detect the
    # 1-D case explicitly instead of catching every exception, so that genuine
    # errors (e.g. mismatching feature sizes) are not hidden behind the notice.
    if x.dim() == 1 or y.dim() == 1:
        print(
            "Carefull : x and y are not D-dimensional > 1 features : added 2 dimensions"
        )
        if x.dim() == 1:
            x = x.unsqueeze(1)
        if y.dim() == 1:
            y = y.unsqueeze(1)

    cost = torch.cdist(x, y, p=2) ** 2

    # initialize DP matrix on the same device / dtype as the cost matrix
    R = cost.new_zeros((n + 1, m + 1))
    R[0, 1:] = float("inf")
    R[1:, 0] = float("inf")
    R[0, 0] = 0.0

    for j in range(1, m + 1):
        for i in range(1, n + 1):
            # soft minimum over the three predecessor cells
            # (soft_min is not defined in this notebook; presumably it comes
            # from the `dtw_soft` star import)
            _min = soft_min([R[i - 1, j], R[i, j - 1], R[i - 1, j - 1]], gamma)

            # update cell
            R[i, j] = cost[i - 1, j - 1] + _min

    return R[-1, -1], R, cost


def jacobian_product_sq_euc_optimized(X, Y, E):
    """Weight the pairwise differences ``X[i] - Y[j]`` by ``E`` and reduce over ``j``.

    Args:
        X: tensor indexed as [m, d].
        Y: tensor indexed as [n, d].
        E: weights indexed as [m, n].

    Returns:
        Tensor of shape [m, d] with ``G[i] = sum_j E[i, j] * (X[i] - Y[j]) * 2``.
    """
    # Broadcast to every (i, j) pair: [m, 1, d] - [1, n, d] -> [m, n, d]
    pairwise_diff = X[:, None, :] - Y[None, :, :]

    # Trailing axis on E so each weight applies to the whole feature vector
    scaled = E[..., None] * pairwise_diff * 2

    # Collapse the n axis
    return torch.sum(scaled, dim=1)



def backward_recursion(
    x: torch.Tensor, y: torch.Tensor, R, delta, gamma: float = 1.0
) -> torch.Tensor:
    """backward recursion of soft-DTW

    Args:
        x (torch.Tensor): length x feature (only its length is read here)
        y (torch.Tensor): length x feature (only its length is read here)
        R (torch.Tensor): (n + 1) x (m + 1) DP matrix from the forward pass
        delta (torch.Tensor): n x m pairwise cost matrix from the forward pass
        gamma (float, optional): gamma parameter. Used as a divisor, so it
            must be non-zero. Defaults to 1.0.

    Returns:
        torch.Tensor: E matrix (n x m), with the dtype and device of ``R``
    """
    n, m = x.shape[0], y.shape[0]
    neg_inf = -float("inf")

    # Pad R with one extra column and one extra row of -inf. The padding is
    # created from R itself so it shares R's dtype and device (torch.cat
    # returns new tensors, the caller's R is left untouched).
    R = torch.cat((R, R.new_full((n + 1, 1), neg_inf)), dim=1)
    R = torch.cat((R, R.new_full((1, m + 2), neg_inf)), dim=0)
    R[n + 1, m + 1] = R[n, m]

    # Pad delta with a zero column and a zero row
    delta = torch.cat((delta, delta.new_zeros((n, 1))), dim=1)
    delta = torch.cat((delta, delta.new_zeros((1, m + 1))), dim=0)

    # E gets a one-cell border; the bottom-right corner seeds the recursion
    E = R.new_zeros((n + 2, m + 2))
    E[n + 1, m + 1] = 1.0

    # backward recursion
    for j in range(m, 0, -1):  # ranges from m to 1
        for i in range(n, 0, -1):  # ranges from n to 1
            # R is 1-indexed by (i, j) while delta is 0-indexed, hence the
            # shifted delta indices for the three successor cells.
            a = torch.exp((R[i + 1, j] - R[i, j] - delta[i, j - 1]) / gamma)
            b = torch.exp((R[i, j + 1] - R[i, j] - delta[i - 1, j]) / gamma)
            c = torch.exp((R[i + 1, j + 1] - R[i, j] - delta[i, j]) / gamma)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c

    # strip the border
    return E[1:-1, 1:-1]

class SoftDTWFunction(torch.autograd.Function):
    """Soft-DTW loss with a hand-written backward pass.

    The forward pass runs ``soft_dtw``; the backward pass rebuilds the E
    matrix with ``backward_recursion`` and turns it into gradients with
    ``jacobian_product_sq_euc_optimized``.
    """

    @staticmethod
    def forward(ctx, input, target, gamma):
        """Return the soft-DTW loss of ``input`` against ``target``.

        Args:
            ctx: autograd context.
            input (torch.Tensor): length x feature
            target (torch.Tensor): length x feature
            gamma (float): gamma parameter.

        Returns:
            torch.Tensor: the loss value returned by ``soft_dtw``.
        """
        loss, R, delta = soft_dtw(input, target, gamma)

        # save data for backward
        ctx.save_for_backward(input, target, R, delta)
        ctx.gamma = gamma

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``(input, target, gamma)``.

        The local gradients are scaled by ``grad_output`` so the chain rule
        holds when the loss is scaled or combined with other terms downstream.
        ``gamma`` receives no gradient.
        """
        # get value from forward
        x, y, R, delta = ctx.saved_tensors
        E = backward_recursion(x, y, R, delta, ctx.gamma)

        grad_input = grad_target = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * jacobian_product_sq_euc_optimized(x, y, E)
        if ctx.needs_input_grad[1]:
            # The cost is symmetric in (x, y): swapping the roles and
            # transposing E gives sum_i E[i, j] * 2 * (y[j] - x[i]).
            grad_target = grad_output * jacobian_product_sq_euc_optimized(y, x, E.T)
        return grad_input, grad_target, None


class MyDtw(torch.nn.Module):
    """Module wrapper that evaluates ``SoftDTWFunction`` with a stored gamma."""

    def __init__(self, gamma=1):
        """Store the gamma value handed to ``SoftDTWFunction`` on every call."""
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return ``SoftDTWFunction.apply(input, target, self.gamma)``."""
        return SoftDTWFunction.apply(input, target, self.gamma)

## Custom grad

In [72]:
# Toy pair of sequences; after the transpose x1 is (3, 2) and y1 is (5, 2).
x1 = torch.tensor([[1, 1, 56], [2, 8, 0]], dtype=torch.float32,requires_grad=True).T
y1 = torch.tensor([[1, 8, 1, 9, 1], [5, 9, 14, 7, -1]], dtype=torch.float32,requires_grad=True).T

# x1 / y1 are the transposed results, not the tensors that were created with
# requires_grad=True, so ask autograd to keep .grad on them explicitly.
x1.retain_grad()
y1.retain_grad()
print("x1 :",x1.shape,"y1 :",y1.shape)

# soft_dtw returns (loss, R, cost matrix): here `cost` receives the scalar
# loss and `dist` receives the pairwise cost matrix.
cost, R, dist = soft_dtw(x1, y1, gamma=1.0)
cost.item()

x1 : torch.Size([3, 2]) y1 : torch.Size([5, 2])


3186.0

In [73]:
# Loss module built on SoftDTWFunction, using the default gamma=1.
criterion = MyDtw()

In [74]:
# Forward pass through the custom autograd Function.
loss_custom = criterion(x1,y1)
print(loss_custom)

tensor(3186., grad_fn=<SoftDTWFunctionBackward>)


In [75]:
# Backward pass: this goes through SoftDTWFunction.backward, then shows the
# gradient retained on x1.
loss_custom.backward()
print("x",x1.shape)
print(x1.grad)

x torch.Size([3, 2])
tensor([[-1.9950e-20, -6.0000e+00],
        [-3.0000e+01, -1.1999e+01],
        [ 1.1000e+02,  2.0000e+00]])


## Auto Grad

In [76]:
# Earlier 1-feature example, kept for reference:
#x = torch.tensor([2.0,5.0,3.0,6.0],requires_grad=True).unsqueeze(-1)
#y = torch.tensor([5.0,9.0,2.0],requires_grad=True).unsqueeze(-1)
# Fresh tensors with the same values as x1 / y1, so this graph is independent
# of the custom-gradient run.
x2 = torch.tensor([[1, 1, 56], [2, 8, 0]], dtype=torch.float32,requires_grad=True).T
y2 = torch.tensor([[1, 8, 1, 9, 1], [5, 9, 14, 7, -1]], dtype=torch.float32,requires_grad=True).T

x2.retain_grad()
y2.retain_grad()

# Call soft_dtw directly (no custom Function) so autograd differentiates the
# forward recursion itself; element [0] of the returned tuple is the loss.
loss = soft_dtw(x2,y2)[0]


In [77]:
# NOTE(review): the recorded gradient below is zero in the first two rows,
# unlike the custom-backward result above. Presumably soft_min (from dtw_soft)
# does not propagate gradients, leaving only the last cost term - confirm.
loss.backward()
print("x",x2.shape)
print(x2.grad)

x torch.Size([3, 2])
tensor([[  0.,   0.],
        [  0.,   0.],
        [110.,   2.]])


# Batch same size

In [108]:
def soft_min_batch(list_a, gamma):
    """Element-wise softmin over a list of same-shaped tensors.

    Args:
        list_a (list): list of tensors that all share the same shape
        gamma (float): gamma parameter; 0 selects the hard minimum

    Returns:
        torch.Tensor: softmin value, with the same shape as one element of
        ``list_a`` for every value of gamma.
    """
    assert gamma >= 0, "gamma must be greater than or equal to 0"
    # Candidates go on a new leading axis: [k, *shape]
    stacked = torch.stack(list_a)

    if gamma == 0:
        return torch.min(stacked, dim=0)[0]  # hard min along the candidate axis

    # softmin(a) = -gamma * log(sum_k exp(-a_k / gamma)). torch.logsumexp is
    # numerically stable and reduces the candidate axis away, so the result
    # has the same shape as the gamma == 0 branch.
    return -gamma * torch.logsumexp(-stacked / gamma, dim=0)

def soft_dtw_batch_same_size(
    x: torch.Tensor, y: torch.Tensor, gamma: float = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Soft Dynamic Time Warping for a batch of sequence pairs.

    Args:
        x (torch.Tensor): batch x length_x x feature. A 2-D tensor is treated
            as batch x length_x x 1.
        y (torch.Tensor): batch x length_y x feature. A 2-D tensor is treated
            as batch x length_y x 1.
        gamma (float, optional): gamma parameter forwarded to
            ``soft_min_batch``. Defaults to 1.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(loss, R, cost)``
        where ``loss`` is ``R[:, -1, -1]`` (one value per batch element),
        ``R`` is the batch x (length_x + 1) x (length_y + 1) DP tensor and
        ``cost`` is the batch x length_x x length_y tensor of squared
        Euclidean distances.
    """
    # torch.cdist accepts 2-D inputs but would read (batch, length) as
    # (points, features). Add the missing feature axis explicitly rather than
    # relying on an exception that cdist does not raise in that case.
    if x.dim() == 2 or y.dim() == 2:
        print(
            "Carefull : x and y are not D-dimensional > 1 features : added 2 dimensions"
        )
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        if y.dim() == 2:
            y = y.unsqueeze(-1)

    b = x.shape[0]
    n = x.shape[1]
    m = y.shape[1]

    cost = torch.cdist(x, y, p=2) ** 2

    # initialize DP matrix on the same device / dtype as the cost tensor
    R = cost.new_zeros((b, n + 1, m + 1))
    R[:, 0, 1:] = float("inf")
    R[:, 1:, 0] = float("inf")
    R[:, 0, 0] = 0.0

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # soft minimum over the three predecessor cells, for all batch
            # elements at once
            predecessors = [R[:, i - 1, j], R[:, i, j - 1], R[:, i - 1, j - 1]]
            R[:, i, j] = cost[:, i - 1, j - 1] + soft_min_batch(predecessors, gamma)

    return R[:, -1, -1], R, cost


def backward_recursion_batch_same_size(
    x: torch.Tensor, y: torch.Tensor, R, delta, gamma: float = 1.0
) -> torch.Tensor:
    """backward recursion of soft-DTW

    Args:
        x (torch.Tensor): batch x length x feature (only its sizes are read)
        y (torch.Tensor): batch x length x feature (only its sizes are read)
        R (torch.Tensor): batch x (n + 1) x (m + 1) DP tensor from the forward pass
        delta (torch.Tensor): batch x n x m pairwise cost tensor from the forward pass
        gamma (float, optional): gamma parameter. Used as a divisor, so it
            must be non-zero. Defaults to 1.0.

    Returns:
        torch.Tensor: E batch x matrix (batch x n x m), with the dtype and
        device of ``R``
    """
    batch = x.shape[0]
    n, m = x.shape[1], y.shape[1]
    neg_inf = -float("inf")

    # Pad delta with a zero column and a zero row. Padding is created from
    # the tensor it extends so dtype and device always match.
    delta = torch.cat((delta, delta.new_zeros((batch, n, 1))), dim=2)
    delta = torch.cat((delta, delta.new_zeros((batch, 1, m + 1))), dim=1)

    # Pad R with a -inf column and row (torch.cat returns new tensors, the
    # caller's R is left untouched).
    R = torch.cat((R, R.new_full((batch, n + 1, 1), neg_inf)), dim=2)
    R = torch.cat((R, R.new_full((batch, 1, m + 2), neg_inf)), dim=1)
    R[:, n + 1, m + 1] = R[:, n, m]

    # E gets a one-cell border; the bottom-right corner seeds the recursion
    E = R.new_zeros((batch, n + 2, m + 2))
    E[:, n + 1, m + 1] = 1.0

    # backward recursion
    for j in range(m, 0, -1):  # ranges from m to 1
        for i in range(n, 0, -1):  # ranges from n to 1
            # R is 1-indexed by (i, j) while delta is 0-indexed, hence the
            # shifted delta indices for the three successor cells.
            a = torch.exp((R[:, i + 1, j] - R[:, i, j] - delta[:, i, j - 1]) / gamma)
            b = torch.exp((R[:, i, j + 1] - R[:, i, j] - delta[:, i - 1, j]) / gamma)
            c = torch.exp((R[:, i + 1, j + 1] - R[:, i, j] - delta[:, i, j]) / gamma)
            E[:, i, j] = E[:, i + 1, j] * a + E[:, i, j + 1] * b + E[:, i + 1, j + 1] * c

    # strip the border
    return E[:, 1:-1, 1:-1]

def jacobian_product_sq_euc_batch(X, Y, E):
    """Batched weighting of the pairwise differences ``X[b, i] - Y[b, j]`` by ``E``.

    Args:
        X: tensor indexed as [b, m, d].
        Y: tensor indexed as [b, n, d].
        E: weights indexed as [b, m, n].

    Returns:
        Tensor of shape [b, m, d] with
        ``G[b, i] = sum_j E[b, i, j] * (X[b, i] - Y[b, j]) * 2``.
    """
    # Broadcast to every (i, j) pair per batch element:
    # [b, m, 1, d] - [b, 1, n, d] -> [b, m, n, d]
    pairwise_diff = X[:, :, None, :] - Y[:, None, :, :]

    # Trailing axis on E so each weight applies to the whole feature vector
    scaled = E[..., None] * pairwise_diff * 2

    # Collapse the n axis
    return torch.sum(scaled, dim=2)

class SoftDTWFunction_batch_same_size(torch.autograd.Function):
    """Batched soft-DTW loss with a hand-written backward pass.

    The forward pass runs ``soft_dtw_batch_same_size`` and returns one loss
    per batch element; the backward pass rebuilds E with
    ``backward_recursion_batch_same_size`` and turns it into gradients with
    ``jacobian_product_sq_euc_batch``.
    """

    @staticmethod
    def forward(ctx, input, target, gamma):
        """Return the per-sample soft-DTW losses of ``input`` against ``target``.

        Args:
            ctx: autograd context.
            input (torch.Tensor): batch x length x feature
            target (torch.Tensor): batch x length x feature
            gamma (float): gamma parameter.

        Returns:
            torch.Tensor: the loss tensor returned by
            ``soft_dtw_batch_same_size`` (one value per batch element).
        """
        loss, R, delta = soft_dtw_batch_same_size(input, target, gamma)

        # save data for backward
        ctx.save_for_backward(input, target, R, delta)
        ctx.gamma = gamma

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``(input, target, gamma)``.

        Each sample's local gradient is scaled by its entry of ``grad_output``
        rather than by a fixed 1 / batch, so the result is correct for any
        downstream reduction (mean, sum, weighting, ...). ``gamma`` receives
        no gradient.
        """
        # get value from forward
        x, y, R, delta = ctx.saved_tensors
        E = backward_recursion_batch_same_size(x, y, R, delta, ctx.gamma)

        # One upstream gradient per batch element, broadcast over [len, feat]
        scale = grad_output.reshape(-1, 1, 1)

        grad_input = grad_target = None
        if ctx.needs_input_grad[0]:
            grad_input = scale * jacobian_product_sq_euc_batch(x, y, E)
        if ctx.needs_input_grad[1]:
            # The cost is symmetric in (x, y): swapping the roles and
            # transposing E gives sum_i E[b, i, j] * 2 * (y[b, j] - x[b, i]).
            grad_target = scale * jacobian_product_sq_euc_batch(
                y, x, E.transpose(1, 2)
            )
        return grad_input, grad_target, None


class DTWLoss(torch.nn.Module):
    """Soft-DTW loss over a batch of equally sized sequence pairs.

    Args:
        gamma: gamma parameter handed to ``SoftDTWFunction_batch_same_size``.
        reduction: ``'mean'`` or ``'sum'``; how the per-sample losses are
            combined into a single value.
    """

    def __init__(self, gamma=1,reduction='mean'):
        super(DTWLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        """Return the reduced soft-DTW loss of ``input`` against ``target``.

        Raises:
            ValueError: if ``self.reduction`` is neither ``'mean'`` nor ``'sum'``.
        """
        # Validate first: a bare `raise` with no active exception would only
        # produce an unrelated RuntimeError.
        if self.reduction not in ('mean', 'sum'):
            raise ValueError(
                f"Unsupported reduction {self.reduction!r}; expected 'mean' or 'sum'"
            )

        losses = SoftDTWFunction_batch_same_size.apply(input, target, self.gamma)
        if self.reduction == 'mean':
            return torch.mean(losses)
        return torch.sum(losses)


## Custom backward

In [109]:
# Two toy (x, y) pairs with identical shapes, so they can be stacked into a batch.
xs1_1 = torch.tensor([[1, 1, 56], [2, 8, 0]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)
ys1_1 = torch.tensor([[1, 8, 1, 9, 1], [5, 9, 14, 7, -1]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)

xs1_2 = torch.tensor([[1, 0, 56], [2, 0, 0]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)
ys1_2 = torch.tensor([[1, 0, 1, 9, 1], [5, 0, 14, 7, -1]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)

# Batch axis first: xs1 is (2, 3, 2) and ys1 is (2, 5, 2).
xs1 = torch.stack([xs1_1,xs1_2])
ys1 = torch.stack([ys1_1,ys1_2])

# The stacked tensors are results of torch.stack, so .grad must be retained explicitly.
xs1.retain_grad()
ys1.retain_grad()
print("x1 :",xs1.shape,"y1 :",ys1.shape)

# soft_dtw_batch_same_size returns (per-sample loss, R, cost tensor): here
# `cost` receives the losses and `dist` receives the pairwise cost tensor.
cost, R, dist = soft_dtw_batch_same_size(xs1, ys1, gamma=1.0)

x1 : torch.Size([2, 3, 2]) y1 : torch.Size([2, 5, 2])


In [110]:
# Batched loss module with the defaults gamma=1 and reduction='mean'.
# This rebinds `criterion`, which previously held the MyDtw instance.
criterion = DTWLoss()

In [111]:
# Forward pass through the batched custom Function, then the mean reduction.
loss_custom = criterion(xs1,ys1)
loss_custom


tensor(3230.5000, grad_fn=<MeanBackward0>)

In [112]:
# Backward pass: this goes through SoftDTWFunction_batch_same_size.backward,
# then shows the gradient retained on xs1.
loss_custom.backward()
print("x",xs1.shape)
print(xs1.grad)

x torch.Size([2, 3, 2])
tensor([[[-9.9749e-21, -3.0000e+00],
         [-1.5000e+01, -5.9996e+00],
         [ 5.5000e+01,  1.0000e+00]],

        [[-7.0000e+00, -1.8000e+01],
         [-1.0000e+00,  1.0000e+00],
         [ 5.5000e+01,  1.0000e+00]]])


## Auto backward

In [113]:
# Fresh tensors with the same values as the custom-backward example, so this
# autograd run has its own graph.
xs2_1 = torch.tensor([[1, 1, 56], [2, 8, 0]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)
ys2_1 = torch.tensor([[1, 8, 1, 9, 1], [5, 9, 14, 7, -1]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)

xs2_2 = torch.tensor([[1, 0, 56], [2, 0, 0]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)
ys2_2 = torch.tensor([[1, 0, 1, 9, 1], [5, 0, 14, 7, -1]], dtype=torch.float32,requires_grad=True).T#.repeat(1,1,1)

# Stack the tensors created in this cell. The previous version stacked
# xs1_* / ys1_* again, which left xs2_* / ys2_* unused and tied this run to
# the tensors of the custom-backward example.
xs2 = torch.stack([xs2_1,xs2_2])
ys2 = torch.stack([ys2_1,ys2_2])

# The stacked tensors are results of torch.stack, so .grad must be retained explicitly.
xs2.retain_grad()
ys2.retain_grad()
print("x1 :",xs2.shape,"y1 :",ys2.shape)

# Plain autograd through the forward recursion; element [0] is the per-sample loss.
loss = torch.mean(soft_dtw_batch_same_size(xs2,ys2)[0])

x1 : torch.Size([2, 3, 2]) y1 : torch.Size([2, 5, 2])


In [114]:
# Autograd reference gradient; the recorded output below matches the
# custom-backward gradient printed for xs1.
loss.backward()
print("x",xs2.shape)
print(xs2.grad)

x torch.Size([2, 3, 2])
tensor([[[-9.9749e-21, -3.0000e+00],
         [-1.5000e+01, -5.9996e+00],
         [ 5.5000e+01,  1.0000e+00]],

        [[-7.0000e+00, -1.8000e+01],
         [-1.0000e+00,  1.0000e+00],
         [ 5.5000e+01,  1.0000e+00]]])
