In [33]:
import torch
from itertools import product
from functools import partial
from torch.autograd.functional import jacobian

# Seed torch's global RNG. The cells visible here only use hard-coded tensors,
# so the seed matters only for any stochastic code elsewhere in the notebook.
# As the last expression of the cell, the returned Generator is what gets displayed.
torch.manual_seed(42)


<torch._C.Generator at 0x7f6990743330>

In [7]:
# Two rows of example logits. The float entries make torch infer a floating
# dtype, and requires_grad=True lets autograd differentiate with respect to x.
x = torch.tensor(
    [[1, 2, 3], [.6, .28, .9]],
    requires_grad=True,
)


def f(x):
    """Return the softmax of ``x`` taken over its last dimension."""
    probs = x.softmax(dim=-1)
    return probs

# Autograd reference: print the softmax Jacobian of x[0], then of x[1].
# `auto` ends up holding the Jacobian for x[1].
for row_idx in (0, 1):
    auto = jacobian(f, x[row_idx])
    print(auto)

tensor([[ 0.0819, -0.0220, -0.0599],
        [-0.0220,  0.1848, -0.1628],
        [-0.0599, -0.1628,  0.2227]])
tensor([[ 0.2194, -0.0767, -0.1427],
        [-0.0767,  0.1803, -0.1036],
        [-0.1427, -0.1036,  0.2463]])


In [8]:
# Hand-written Jacobian of the row-wise softmax of x, built one scalar at a
# time: manual[i, j, k] = P[i, j] * (delta_jk - P[i, k]).
N, M = x.shape

P = torch.softmax(x, dim=-1)
print(P.shape, P)
# Enumerate (i, j, k) with k varying fastest, then fold back to (N, M, M).
manual = torch.tensor([
    P[i, j] * (int(j == k) - P[i, k])
    for i, j, k in product(range(N), range(M), range(M))
]).reshape(N, M, M)
manual.shape, manual

torch.Size([2, 3]) tensor([[0.0900, 0.2447, 0.6652],
        [0.3251, 0.2361, 0.4388]], grad_fn=<SoftmaxBackward0>)


(torch.Size([2, 3, 3]),
 tensor([[[ 0.0819, -0.0220, -0.0599],
          [-0.0220,  0.1848, -0.1628],
          [-0.0599, -0.1628,  0.2227]],
 
         [[ 0.2194, -0.0767, -0.1427],
          [-0.0767,  0.1803, -0.1036],
          [-0.1427, -0.1036,  0.2463]]]))

In [10]:
def softmax_deriv(P):
    """Jacobian of softmax with respect to its logits, given the probabilities.

    For every probability vector ``p`` along the last dimension of ``P`` this
    builds the matrix ``J[j, k] = p[j] * (delta_jk - p[k])``, which equals
    ``diag(p) - p p^T`` and is therefore symmetric in ``j`` and ``k``.

    Args:
        P: softmax output of shape ``(..., M)``. The 2-D ``(N, M)`` case is the
            one used in this notebook; any number of leading dimensions
            (including none) is accepted.

    Returns:
        Tensor of shape ``(..., M, M)`` with the same dtype and device as ``P``.
    """
    M = P.shape[-1]

    # Kronecker delta as an (M, M) identity; it broadcasts over leading dims.
    delta = torch.eye(M, dtype=P.dtype, device=P.device)

    P_j = P.unsqueeze(-1)  # Shape: (..., M, 1) -- varies along the row index j
    P_k = P.unsqueeze(-2)  # Shape: (..., 1, M) -- varies along the column index k

    # p_j * (delta_jk - p_k), laid out as [..., j, k]
    return P_j * (delta - P_k)

# Show the probabilities, then the shape and values of the Jacobian that
# softmax_deriv builds from them.
print(P.shape, P)
batched_jac = softmax_deriv(P)
batched_jac.shape, batched_jac

torch.Size([2, 3]) tensor([[0.0900, 0.2447, 0.6652],
        [0.3251, 0.2361, 0.4388]], grad_fn=<SoftmaxBackward0>)


(torch.Size([2, 3, 3]),
 tensor([[[ 0.0819, -0.0220, -0.0599],
          [-0.0220,  0.1848, -0.1628],
          [-0.0599, -0.1628,  0.2227]],
 
         [[ 0.2194, -0.0767, -0.1427],
          [-0.0767,  0.1803, -0.1036],
          [-0.1427, -0.1036,  0.2463]]], grad_fn=<MulBackward0>))

In [57]:
# z holds M=4 points and x holds B=2 points, both of dimension D=3.
z = torch.tensor([
    [1, 2, 3],
    [1.5, 2.2, 3.8],
    [-.5, 1.2, 0.7],
    [-.2, 1.3, 0.9],
], requires_grad=True)
x = torch.tensor([
    [0.9, 1.9, 2.9],
    [0.1, 1.2, -2.9],
], requires_grad=True)

# B x M x D: every row of x paired with every row of z
x_expanded = x.unsqueeze(1).expand(-1, z.shape[0], -1)
diff = x_expanded - z
distances = torch.linalg.vector_norm(diff, ord=2, dim=-1)  # B x M

# W_xz[b, m] = softmax over m of -||x_b - z_m||
softmax_distances = torch.softmax(-distances, dim=-1)
W_xz = softmax_distances

# Diagonal softmax term s * (1 - s). Deliberately NOT called `softmax_deriv`:
# that name is a function defined in another cell of this notebook, and
# rebinding it to a tensor would break any later call to that function.
softmax_diag_deriv = softmax_distances * (1 - softmax_distances)

# Chain rule for the diagonal blocks:
#   d W_xz[b, m] / d z[m] = s (1 - s) * (x_b - z_m) / ||x_b - z_m||
# NOTE: the division is undefined if some x_b coincides exactly with a z_m.
# Transposed to B x D x M so the reshaped rows stack under W_xz below.
W_xz_deriv = (softmax_diag_deriv.unsqueeze(-1) * diff / distances.unsqueeze(-1)).transpose(2, 1)
torch.cat([W_xz, W_xz_deriv.reshape(-1, z.shape[0])], dim=0)

tensor([[ 0.6332,  0.2451,  0.0506,  0.0711],
        [ 0.0471,  0.0192,  0.5049,  0.4288],
        [-0.1341, -0.0989,  0.0249,  0.0308],
        [-0.1341, -0.0494,  0.0125,  0.0168],
        [-0.1341, -0.1483,  0.0391,  0.0560],
        [-0.0067, -0.0038,  0.0411,  0.0193],
        [-0.0060, -0.0027,  0.0000, -0.0064],
        [-0.0440, -0.0183, -0.2466, -0.2441]], grad_fn=<CatBackward0>)

In [56]:
def f(x, z):
    """Softmax over the negative Euclidean distances from ``x`` to each row of ``z``.

    Note: the notebook also defines a one-argument ``f(x)`` (plain softmax) in
    another cell; whichever of the two cells ran last owns the name ``f``.

    Args:
        x: tensor of shape ``(D,)``.
        z: tensor of shape ``(M, D)``; ``M`` may be any positive size.

    Returns:
        Tensor of shape ``(M,)`` whose entry ``m`` is
        ``softmax_m(-||x - z[m]||_2)``. Dtype and device follow the inputs.
    """
    # Vectorised over the rows of z (x broadcasts against every row), so the
    # number of rows, the dtype and the device are all taken from the inputs.
    distances = torch.linalg.vector_norm(x - z, ord=2, dim=-1)
    return torch.softmax(-distances, dim=-1)

# print(f(x[0], z))
# print(f(x[1], z))
# print(jacobian(partial(f, x[0]), z).diagonal(dim1=0, dim2=1))
# print(jacobian(partial(f, x[1]), z).diagonal(dim1=0, dim2=1))

# Autograd cross-check: f evaluated for x[0] and x[1] (one row each), followed
# by the dim-0/dim-1 diagonal of each Jacobian with respect to z, all stacked
# along dim 0.
torch.cat(
    [f(x[i], z).unsqueeze(0) for i in (0, 1)]
    + [jacobian(partial(f, x[i]), z).diagonal(dim1=0, dim2=1) for i in (0, 1)],
    dim=0,
)

tensor([[ 0.6332,  0.2451,  0.0506,  0.0711],
        [ 0.0471,  0.0192,  0.5049,  0.4288],
        [-0.1341, -0.0989,  0.0249,  0.0308],
        [-0.1341, -0.0494,  0.0125,  0.0168],
        [-0.1341, -0.1483,  0.0391,  0.0560],
        [-0.0067, -0.0038,  0.0411,  0.0193],
        [-0.0060, -0.0027,  0.0000, -0.0064],
        [-0.0440, -0.0183, -0.2466, -0.2441]], grad_fn=<CatBackward0>)