In [13]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [14]:
# A tiny 3-point cloud in 2-D; the second print shows how indexing with
# None prepends a batch axis of size 1: (3, 2) -> (1, 3, 2).
x = torch.tensor([[0., 0.], [0., 1.], [1., 1.]])
print(x)
print(x.unsqueeze(0))  # equivalent to x[None, :, :]

tensor([[0., 0.],
        [0., 1.],
        [1., 1.]])
tensor([[[0., 0.],
         [0., 1.],
         [1., 1.]]])


In [15]:
def method1(x):
    """Full (n, n) Euclidean distance matrix built from ``F.pdist``.

    ``F.pdist`` returns only the condensed upper triangle (pairs i < j in
    row-major order, the same order produced by ``torch.triu_indices`` with
    ``offset=1``). Those values are scattered into both triangles of a zero
    matrix, so the diagonal stays 0.

    Args:
        x: 2-D tensor of shape (n, d), one point per row (``F.pdist`` only
            accepts 2-D input).

    Returns:
        Symmetric (n, n) tensor with the same dtype and device as ``x``.
    """
    n = x.shape[0]
    condensed = F.pdist(x)
    # Allocate with pdist's dtype/device: index assignment requires matching
    # dtypes, so a default float32 CPU buffer breaks float64 or CUDA input.
    dist = torch.zeros((n, n), dtype=condensed.dtype, device=x.device)
    rows, cols = torch.triu_indices(row=n, col=n, offset=1, device=x.device)
    dist[rows, cols] = condensed
    dist[cols, rows] = condensed  # mirror into the lower triangle
    return dist


def method2(x):
    """Pairwise Euclidean distances via broadcasting.

    Materializes every difference ``x[i] - x[j]`` as an (n, n, d) tensor and
    reduces the last axis with an L2 norm, so memory grows as O(n^2 * d).

    Args:
        x: tensor whose first axis indexes points; expected shape (n, d).

    Returns:
        (n, n) tensor of distances.
    """
    diffs = x.unsqueeze(1) - x.unsqueeze(0)
    return torch.linalg.vector_norm(diffs, ord=2, dim=2)


def method3(x):
    """Batched pairwise Euclidean distances via broadcasting.

    Args:
        x: tensor of shape (..., n, d). Any number of leading batch axes,
            including none, is accepted; the original (b, n, d) layout
            behaves exactly as before.

    Returns:
        Tensor of shape (..., n, n) where
        ``out[..., i, j] = ||x[..., i, :] - x[..., j, :]||_2``.
        Memory for the intermediate differences is O(n^2 * d) per batch item.
    """
    # (..., n, 1, d) - (..., 1, n, d) -> (..., n, n, d)
    diffs = x[..., :, None, :] - x[..., None, :, :]
    return torch.linalg.vector_norm(diffs, ord=2, dim=-1)

In [16]:
# Seed the global RNG so the benchmark points (and the random datasets drawn
# in later cells) are reproducible on Restart & Run All.
torch.manual_seed(0)
x = torch.rand((100, 2))  # 100 points drawn uniformly from [0, 1)^2

In [17]:
# Benchmark method1 (F.pdist + scatter into a full (n, n) matrix) on x.
%timeit method1(x)


107 µs ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [18]:
# Benchmark method2 (broadcast all pairwise differences, then L2 norm) on x.
%timeit method2(x)

1.74 ms ± 3.73 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [19]:
# Benchmark method3; x[None, :, :] adds a leading batch axis of size 1 so the
# input matches the batched (b, n, d) layout that method3 indexes.
%timeit method3(x[None,:,:])

1.75 ms ± 4.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
# Sanity check: the three constructions should agree, so each printed norm of
# a pairwise difference is expected to be (near) zero.
dist_pdist = method1(x)
dist_bcast = method2(x)
dist_batched = method3(x[None, :, :])  # (1, n, n); broadcasts against (n, n)
print(torch.norm(dist_pdist - dist_bcast))
print(torch.norm(dist_pdist - dist_batched))
print(torch.norm(dist_bcast - dist_batched))

tensor(0.)
tensor(0.)
tensor(0.)


In [None]:
def count_trainable(module):
    """Return the number of scalar parameters of ``module`` that require grad."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Embedding(100, 128): a 100 x 128 lookup table and no bias (12800 parameters).
model = nn.Embedding(100, 128)
print(count_trainable(model))

# Linear(100, 128): a 128 x 100 weight plus a 128-element bias (12928 parameters).
model = nn.Linear(100, 128)
print(count_trainable(model))

12800
12928


In [None]:
# Check that tour lengths computed from a pairwise distance matrix (method1)
# match tour lengths computed directly from the reordered coordinates.
# Each row of `pi` is a tour (a permutation of node indices) for one instance.
pi = torch.tensor([
    [0, 2, 1],
    [2, 1, 0]
])
N_TRIALS = 10_000

n_instances, n_nodes = pi.shape
batch_idx = torch.arange(n_instances)[:, None]  # (B, 1), broadcasts against pi
next_node = pi.roll(-1, dims=1)  # successor of each tour position (wraps around)

for _ in range(N_TRIALS):
    dataset = torch.rand((n_instances, n_nodes, 2))

    # Direct method: reorder points along the tour, sum consecutive edge lengths,
    # then add the closing edge from the last node back to the first.
    d = dataset.gather(1, pi.unsqueeze(-1).expand_as(dataset))
    res = (d[:, 1:] - d[:, :-1]).norm(p=2, dim=2).sum(1) + (d[:, 0] - d[:, -1]).norm(p=2, dim=1)

    # Distance-matrix method: look up dist[pi[k], pi[k+1]] for every tour edge.
    # Must index with `pi`: with only 3 nodes every tour has the same length,
    # so indexing the identity order 0->1->...->0 would pass by coincidence.
    dataset_dist = torch.stack([method1(inst) for inst in dataset], 0)
    res_dist = dataset_dist[batch_idx, pi, next_node].sum(1)

    assert torch.norm(res - res_dist) <= 1e-6, f'{res} {res_dist}'

In [27]:
# Compare inserting a singleton axis at position 1 vs. position 2 of a
# (2, 4, 2) integer tensor.
x = torch.tensor([
    [[1, 1], [2, 2], [3, 3], [4, 4]],
    [[5, 5], [6, 6], [7, 7], [8, 8]],
])
as_rows = x.unsqueeze(1)  # same as x[:, None, :, :] -> (2, 1, 4, 2)
as_cols = x.unsqueeze(2)  # same as x[:, :, None, :] -> (2, 4, 1, 2)
print(as_rows.shape, as_rows)
print(as_cols.shape, as_cols)

torch.Size([2, 1, 4, 2]) tensor([[[[1, 1],
          [2, 2],
          [3, 3],
          [4, 4]]],


        [[[5, 5],
          [6, 6],
          [7, 7],
          [8, 8]]]])
torch.Size([2, 4, 1, 2]) tensor([[[[1, 1]],

         [[2, 2]],

         [[3, 3]],

         [[4, 4]]],


        [[[5, 5]],

         [[6, 6]],

         [[7, 7]],

         [[8, 8]]]])
