In [None]:
# default_exp distance

# Distance
> Distance calculation

Wiki on [k-d-tree](https://en.wikipedia.org/wiki/K-d_tree)

In [None]:
#export
import numpy as np
import torch
import plotly.graph_objects as go
import plotly.express as px
import timeit
import itertools

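Reference implementation from Stack Overflow. It returns the matrix of squared Euclidean distances between the rows of `x` and the rows of `y`, computed from Gram matrices: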
```python
def pairwise_dist(x, y):
    xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x, y.t())
    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))
    P = (rx.t() + ry - 2*zz)
    return P
```

In [None]:
# Scratch check of the timeit API: time `1+1` in 3 rounds of 7 calls each.
# `repeat` returns a list with one total duration (in seconds) per round.
timeit.Timer(lambda : 1+1).repeat(repeat=3, number=7)

In [None]:
# Toy fixture: three 2-D points lying on the line y = x.
positions = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.float32)

# Expected squared Euclidean distances for the point pairs
# (0, 1), (0, 2) and (1, 2), in that order.
correct_distances = torch.tensor([2.0, 8.0, 2.0], dtype=torch.float32)

In [None]:
#export
class PairwiseDistance:
    """Callable front-end that dispatches to a named distance implementation.

    Implementations are looked up as attributes called
    ``<method>_pairwise_distance`` and invoked with the input tensor.
    """

    def __call__(self, x:torch.Tensor, method:str='torch'):
        """Run the implementation selected by `method` on `x`.

        Args:
            x: tensor handed unchanged to the selected implementation.
            method: prefix of the attribute to call; the attribute name is
                ``f'{method}_pairwise_distance'``.

        Returns:
            Whatever the selected implementation returns.

        Raises:
            AttributeError: if no attribute with that name exists.
        """
        implementation = getattr(self, f'{method}_pairwise_distance')
        result = implementation(x)
        return result

In [None]:
# Shared dispatcher instance; later cells call it as pdist(x, method=...).
pdist = PairwiseDistance()

Using the Gram matrix to compute pairwise squared Euclidean distances

In [None]:
#export
def pairwise_dist_gram(x:torch.Tensor, y:torch.Tensor,
                       flat:bool=True) -> torch.Tensor:
    """Squared Euclidean distances between the rows of `x` and the rows of `y`.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``, so the only
    matrix product required is the cross term ``x @ y.T``.

    Args:
        x: 2-D tensor of shape ``(nx, d)``, one point per row.
        y: 2-D tensor of shape ``(ny, d)``, one point per row.
        flat: if True (default), return only the entries strictly above the
            main diagonal of the ``(nx, ny)`` distance matrix as a 1-D tensor,
            ordered by row and then by column (for ``y = x`` these are the
            pairs (0,1), (0,2), ..., (n-2,n-1)). If False, return the full
            ``(nx, ny)`` matrix.

    Returns:
        Tensor of squared distances, clamped to be non-negative.
    """
    # Cross term first: torch.mm rejects non-2-D or shape-mismatched inputs.
    xy = torch.mm(x, y.t())

    # Squared row norms. These equal the diagonals of x @ x.T and y @ y.T but
    # avoid materialising the full (nx, nx) and (ny, ny) Gram matrices.
    x_sq = (x * x).sum(dim=1)
    y_sq = (y * y).sum(dim=1)

    # Broadcast to (nx, ny). Floating-point cancellation can leave tiny
    # negative values for (nearly) coincident points, so clamp at zero.
    d = (x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2 * xy).clamp(min=0)

    if flat:
        nx, ny = d.shape
        # Build the indices on the same device as the data.
        ix = torch.triu_indices(nx, ny, offset=1, device=d.device)
        return d[ix[0], ix[1]]
    return d

In [None]:
%%time
pairwise_dist_gram(positions, positions)

In [None]:
%%timeit
pairwise_dist_gram(positions, positions)

In [None]:
# Check the Gram-matrix implementation directly. The 'stackoverflow' method is
# only attached to PairwiseDistance in a later cell, so dispatching through
# pdist here would raise AttributeError on a fresh top-to-bottom run.
assert torch.isclose(correct_distances,
              pairwise_dist_gram(positions, positions)).all()

In [None]:
#export
def stackoverflow_pairwise_distance(self, x:torch.Tensor):
    """Flat pairwise squared distances among the rows of `x`.

    Thin adapter around `pairwise_dist_gram` that compares `x` with itself;
    `self` is accepted only so the function can be attached as a method.
    """
    return pairwise_dist_gram(x, y=x, flat=True)

# Attach as a method so PairwiseDistance()(x, method='stackoverflow') dispatches here.
PairwiseDistance.stackoverflow_pairwise_distance = stackoverflow_pairwise_distance

In [None]:
pdist(positions, method='stackoverflow')

Using pre-existing implementation in pytorch: `torch.nn.PairwiseDistance`

In [None]:
#export
def torch_pairwise_distance(self, x:torch.Tensor):
    """Flat pairwise squared Euclidean distances among the rows of `x`.

    Gathers every index pair ``i < j`` (ordered by row, then column) and
    evaluates `torch.nn.PairwiseDistance` on the two gathered row batches.

    Args:
        self: unused; present so the function can be attached as a method.
        x: 2-D tensor of shape ``(n, d)``, one point per row.

    Returns:
        1-D tensor with ``n * (n - 1) / 2`` squared distances.
    """
    n = x.size()[0]
    # eps=0.0: the module's default eps=1e-6 is added to every coordinate
    # difference, which biases the result (identical points would get a
    # non-zero distance).
    l2_distance = torch.nn.PairwiseDistance(p=2, eps=0.0, keepdim=True)
    # Build the pair indices on the same device as the data.
    ix = torch.triu_indices(n, n, offset=1, device=x.device)
    # The module returns the L2 norm; square it to get squared distances.
    return l2_distance(x[ix[0], :], x[ix[1], :]).squeeze(1) ** 2

# Attach as a method so PairwiseDistance()(x, method='torch') dispatches here.
PairwiseDistance.torch_pairwise_distance = torch_pairwise_distance

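Wow, `torch.nn.PairwiseDistance` seems to be about 10x faster than `pairwise_dist_gram` on the 3-point example timed in this notebook. How the two implementations scale with more points is still untested (see the TODO below).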

In [None]:
pdist(positions, method='torch')

In [None]:
%timeit pdist(positions, method='torch')

In [None]:
assert torch.isclose(correct_distances, 
              pdist(positions, method='torch')).all()

TODO:
- test time scaling behavior of the distance functions
- mapping between flat and square form to answer nearest neighbor queries

In [None]:
timeit.Timer.repeat?