In [None]:
# default_exp distance

# Distance
> Distance calculation

Wiki on [k-d-tree](https://en.wikipedia.org/wiki/K-d_tree)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
import numpy as np
import torch
import plotly.graph_objects as go
import plotly.express as px
import timeit
import itertools
import pandas as pd

## Direct pairwise distances

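Example positions and their expected pairwise **squared** Euclidean distances (upper triangle, row-major order: pairs (0,1), (0,2), (1,2))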

In [None]:
# Three collinear 2-D points used as a small worked example.
positions = torch.tensor([[0., 0.], [1., 1.], [2., 2.]], dtype=torch.float32)
# Expected squared Euclidean distances for the pairs (0,1), (0,2), (1,2).
correct_distances = torch.tensor([2.0, 8.0, 2.0], dtype=torch.float32)

In [None]:
#export
class PairwiseDistance:
    """Callable that dispatches to a named pairwise-distance implementation.

    Implementations are attributes named `<method>_pairwise_distance` that
    accept the input tensor as their only argument besides `self`.
    """

    def __call__(self, x:torch.Tensor, method:str='torch'):
        """Run the implementation selected by `method` on `x`.

        Raises `AttributeError` if no attribute named
        `<method>_pairwise_distance` exists on this object.
        """
        implementation = getattr(self, f'{method}_pairwise_distance')
        result = implementation(x)
        return result

In [None]:
# Dispatcher instance; the implementation is selected per call via `method=...`.
pdist = PairwiseDistance()

Using the Gram matrix to compute pairwise squared distances, via $\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\,x_i \cdot y_j$: [stackexchange](https://math.stackexchange.com/questions/2240429/pairwise-distance-matrix)

In [None]:
#export
def pairwise_dist_gram(x:torch.Tensor, y:torch.Tensor,
                       flat:bool=True) -> torch.Tensor:
    """Squared Euclidean distances between the rows of `x` and the rows of `y`.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>, so only
    the cross Gram matrix `x @ y.T` has to be formed.

    Args:
        x: 2-D tensor, one point per row (`torch.mm` requires 2-D inputs).
        y: 2-D tensor with the same number of columns as `x`.
        flat: if True, return only the entries strictly above the main
            diagonal in row-major order (the unique pairs when `y` is `x`);
            otherwise return the full `(len(x), len(y))` matrix.

    Returns:
        The squared distances: 1-D when `flat` is True, 2-D otherwise.
    """
    nx, ny = x.shape[0], y.shape[0]

    # Row-wise squared norms. This avoids building the full x @ x.T and
    # y @ y.T Gram matrices only to read their diagonals.
    x_sq = (x * x).sum(dim=1)
    y_sq = (y * y).sum(dim=1)
    xy = torch.mm(x, y.t())

    # Broadcast (nx, 1) + (1, ny) -> (nx, ny).
    d = x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2 * xy
    # Floating-point cancellation can leave tiny negative values; a squared
    # distance is never negative, so clip them to zero.
    d = torch.clamp(d, min=0)

    if flat:
        ix = torch.triu_indices(nx, ny, offset=1)
        return d[ix[0], ix[1]]
    return d

In [None]:
%%time
pairwise_dist_gram(positions, positions)

In [None]:
%%timeit
pairwise_dist_gram(positions, positions)

In [None]:
#export
def stackoverflow_pairwise_distance(self, x:torch.Tensor):
    """Pairwise distances among the rows of `x` via `pairwise_dist_gram`.

    Passes `x` as both arguments and returns the flattened result.
    """
    flat_distances = pairwise_dist_gram(x, x, flat=True)
    return flat_distances

# Attach as a method so that `pdist(x, method='stackoverflow')` resolves to it.
PairwiseDistance.stackoverflow_pairwise_distance = stackoverflow_pairwise_distance

In [None]:
pdist(positions, method='stackoverflow')

In [None]:
%timeit pdist(positions, method='stackoverflow')

In [None]:
# The Gram-matrix implementation must reproduce the worked example.
stackoverflow_distances = pdist(positions, method='stackoverflow')
assert torch.allclose(correct_distances, stackoverflow_distances)

Using the pre-existing implementation in PyTorch: `torch.nn.PairwiseDistance`

In [None]:
#export
def torch_pairwise_distance(self, x:torch.Tensor):
    """Squared Euclidean distances between all unique pairs of rows of `x`.

    Args:
        x: 2-D tensor, one point per row.

    Returns:
        1-D tensor with one squared distance per pair `(i, j)`, `i < j`, in
        the row-major order produced by `torch.triu_indices`.
    """
    n_points = x.size()[0]
    # eps=0.0: the module's default eps (1e-6) is added to the coordinate
    # differences, which biases every result and gives identical points a
    # non-zero distance.
    row_distance = torch.nn.PairwiseDistance(p=2, eps=0.0)
    ix = torch.triu_indices(n_points, n_points, offset=1)
    # The module returns the (unsquared) 2-norm per row pair; square it.
    return row_distance(x[ix[0], :], x[ix[1], :]) ** 2

# Attach as a method so that `pdist(x, method='torch')` resolves to it;
# 'torch' is also the default `method` of `PairwiseDistance.__call__`.
PairwiseDistance.torch_pairwise_distance = torch_pairwise_distance

In [None]:
pdist(positions, method='torch')

In [None]:
%timeit pdist(positions, method='torch')

In [None]:
# The torch-based implementation must reproduce the worked example.
torch_distances = pdist(positions, method='torch')
assert torch.allclose(correct_distances, torch_distances)

**Testing the scaling behavior**

In [None]:
# number of positions
n = [5, 10, 50, 100, 500, 1000]
# number of dimensions
d = [2, 3, 4, 5, 6, 7, 8, 9, 10]

In [None]:
# Raw `timeit` usage: 3 repeats, each timing 7 consecutive calls.
timeit.Timer(lambda: pdist(positions, method='torch')).repeat(repeat=3, number=7)

In [None]:
#export
def measure_execution_time(fun, args=(), kwargs=None, repetitions:int=3,
                           number:int=21):
    """Time repeated calls of `fun(*args, **kwargs)` using `timeit`.

    Args:
        fun: callable to benchmark.
        args: positional arguments passed to `fun` on every call.
        kwargs: keyword arguments passed to `fun` on every call; `None` means
            no keyword arguments.
        repetitions: how many timing measurements to take.
        number: how many consecutive calls make up one measurement.

    Returns:
        A list of `repetitions` floats. Each is the total time of `number`
        consecutive calls, not the time per call.
    """
    call_kwargs = {} if kwargs is None else kwargs
    timer = timeit.Timer(lambda: fun(*args, **call_kwargs))
    return timer.repeat(repeat=repetitions, number=number)


In [None]:
# timeit settings for the scaling benchmark:
# `reps` timing measurements per configuration, `number` calls per measurement.
reps = 7
number = 21

In [None]:
%%time
# One record per (method, n, d) combination; `ts` holds the list of timings.
timing_records = []
for method, _n, _d in itertools.product(['torch', 'stackoverflow'], n, d):
    # Use a dedicated name so the example `positions` tensor defined earlier
    # is not overwritten; the assert cells above stay re-runnable.
    random_positions = torch.randn(_n, _d)
    timing_records.append({
        'n':_n,
        'd':_d,
        'method': method,
        'ts': measure_execution_time(pdist, args=(random_positions,),
                                     kwargs=dict(method=method),
                                     repetitions=reps,
                                     number=number)
    })

timing_stats = pd.DataFrame(timing_records)

In [None]:
timing_stats.head()

In [None]:
# One row per individual timing. `explode` leaves `ts` as object dtype, so
# cast it back to float for plotting and numeric operations.
timing_stats_e = timing_stats.explode('ts').astype({'ts': float})

In [None]:
# Fixed values used to slice the timing results for the plots below.
n_filter = 100  # number of positions held constant
d_filter = 3    # number of dimensions held constant

In [None]:
# Build the masks on the exploded frame they are applied to, so the boolean
# index matches it row-for-row instead of relying on implicit index alignment.
mask_n = timing_stats_e['n'] == n_filter
mask_d = timing_stats_e['d'] == d_filter

In [None]:
# Timings versus number of positions, at a fixed number of dimensions.
timings_at_fixed_d = timing_stats_e.loc[mask_d]
fig_by_n = px.scatter(timings_at_fixed_d, x='n', y='ts', color='method',
                      title=f'Performance @ {d_filter} dims')
fig_by_n

In [None]:
# Timings versus number of dimensions, at a fixed number of positions.
timings_at_fixed_n = timing_stats_e.loc[mask_n]
fig_by_d = px.scatter(timings_at_fixed_n, x='d', y='ts', color='method',
                      title=f'Performance @ {n_filter} samples')
fig_by_d

TODO:
- mapping between flat and square form to answer nearest neighbor queries