In [None]:
#default_exp distance

# Distance
> Distance calculation

Wiki on [k-d-tree](https://en.wikipedia.org/wiki/K-d_tree)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#export
import numpy as np
import torch
import plotly.graph_objects as go
import plotly.express as px
import timeit
import typing
import itertools
import pandas as pd
import math

## Direct pairwise distances

Example positions and distances

In [3]:
positions = torch.tensor([[0,0],
                          [1,1],
                          [2,2]], dtype=torch.float)
correct_distances = torch.tensor([2., 8., 2.], dtype=torch.float)

In [4]:
#export
class PairwiseDistance:
    def __call__(self, x:torch.Tensor, method:str='torch'):
        return getattr(self, f'{method}_pairwise_distance')(x)

In [5]:
pdist = PairwiseDistance()

Using the Gram matrix to compute pairwise distances: [stackexchange](https://math.stackexchange.com/questions/2240429/pairwise-distance-matrix)

In [6]:
#export
def pairwise_dist_gram(x:torch.Tensor, y:torch.Tensor,
                       flat:bool=True) -> torch.Tensor:
    
    nx = x.size()[0]
    ny = y.size()[0]
    
    x2 = torch.mm(x, x.t())
    y2 = torch.mm(y, y.t())
    xy = torch.mm(x, y.t())
    
    x2 = x2.diag().unsqueeze(0).expand(ny,-1)
    y2 = y2.diag().unsqueeze(0).expand(nx,-1)
    d = x2.t() + y2 - 2*xy
    if flat:
        ix = torch.triu_indices(nx,ny,offset=1) #list(zip(*list(itertools.combinations(range(len(positions)), r=2))))
        return d[ix[0],ix[1]]
    return d

In [None]:
%%time
pairwise_dist_gram(positions, positions)

In [None]:
%%timeit
pairwise_dist_gram(positions, positions)

In [7]:
#export
def stackoverflow_pairwise_distance(self, x:torch.Tensor):
    return pairwise_dist_gram(x, x)

PairwiseDistance.stackoverflow_pairwise_distance = stackoverflow_pairwise_distance

In [None]:
pdist(positions, method='stackoverflow')

In [None]:
%timeit pdist(positions, method='stackoverflow')

In [None]:
assert torch.isclose(correct_distances, 
              pdist(positions, method='stackoverflow')).all()

Using pre-existing implementation in pytorch: `torch.nn.PairwiseDistance`

In [8]:
#export
def torch_pairwise_distance(self, x:torch.Tensor):
    nx = ny = x.size()[0]
    pdist = torch.nn.PairwiseDistance(p=2, keepdim=True)
    ix = torch.triu_indices(nx,ny,offset=1)
    return pdist(x[ix[0],:], x[ix[1],:]).squeeze(1) ** 2

PairwiseDistance.torch_pairwise_distance = torch_pairwise_distance

In [None]:
pdist(positions, method='torch')

In [None]:
%timeit pdist(positions, method='torch')

In [None]:
assert torch.isclose(correct_distances, 
              pdist(positions, method='torch')).all()

**Testing the scaling behavior**

In [9]:
#export
def measure_execution_time(fun:typing.Callable, args, kwargs, 
                           repetitions:int=3, number:int=21):
    return (timeit.Timer(lambda: fun(*args, **kwargs))
            .repeat(repeat=repetitions, number=number))


def get_time_stats(fun:typing.Callable, n:int, d:int, 
                   methods:typing.List[str]=['torch', 'stackoverflow'],
                   reps:int=7, number:int=21):
    
    timing_stats = []
    for method, _n, _d in itertools.product(methods, n, d):
        positions = torch.randn(_n, _d)
        timing_stats.append({
            'n':_n,
            'd':_d,
            'method': method,
            'ts': measure_execution_time(fun, args=(positions,), 
                                         kwargs=dict(method=method),
                                         repetitions=reps,
                                         number=number)
        })
    return pd.DataFrame(timing_stats)

In [None]:
# number of positions
n = [10, 100, 1000]
# number of dimensions
d = [2, 3, 10]

In [None]:
%%time
timing_stats = get_time_stats(pdist, n, d)

In [None]:
timing_stats.head()

In [None]:
timing_stats_e = timing_stats.explode('ts')

In [None]:
n_filter = 1000
d_filter = 3

In [None]:
mask_n = timing_stats['n'] == n_filter
mask_d = timing_stats['d'] == d_filter

In [None]:
px.scatter(timing_stats_e.loc[mask_d], x='n', y='ts', color='method',
           title=f'Performance @ {d_filter} dims')

In [None]:
px.scatter(timing_stats_e.loc[mask_n], x='d', y='ts', color='method',
           title=f'Performance @ {n_filter} samples')

**Mapping indices between square distance matrix and flat vector**

The following is based on stackexchange:
- [answer1](https://math.stackexchange.com/questions/2134011/conversion-of-upper-triangle-linear-index-from-index-on-symmetrical-array)
- [answer2](https://stackoverflow.com/questions/19143657/linear-indexing-in-symmetric-matrices)

In [10]:
#export
class DistanceMatrixIndexMapper:
    
    ix_fun = lambda self, i, j, n: int(n*i -.5*i*(i+1) + j)
    i_fun = lambda self, x, n: int(math.floor((2*n+1 - math.sqrt((2*n+1)**2 - 8*x))*.5))
    j_fun = lambda self, x, i, n: int(x - (n*i -.5*i*(i+1)))

    def __init__(self,num_pos:int):
        self.num_pos = num_pos
    
    def __call__(self,i:int, direction:str, j:int=None):
        return getattr(self, direction)(i,j)
    
    def brute_force_square2flat(self, i:int, j:int):
        if i > j: i, j = j, i
        self._check_ix_map()
        return self.ix_map[i,j]
        
    def _check_ix_map(self):
        if not hasattr(self, 'ix_map'):
            self.ix_map = {ix: i for i, ix in enumerate(itertools.combinations_with_replacement(range(self.num_pos), r=2))}
        
    def brute_force_flat2square(self, ix:int, j:int=None):
        self._check_ix_map_inv()
        return self.ix_map_inv[ix]
    
    def _check_ix_map_inv(self):
        self._check_ix_map()
        if not hasattr(self, 'ix_map_inv'):
            self.ix_map_inv = {i: ix for ix, i in self.ix_map.items()}
    
    def analytical_square2flat(self, i:int, j:int):
        if i > j: i, j = j, i
        return self.ix_fun(i, j, self.num_pos)
    
    def analytical_flat2square(self, ix:int, j:int=None):
        return (i := self.i_fun(ix, self.num_pos)), (self.j_fun(ix, i, self.num_pos))


In [None]:
i = 0
j = 1
ix = 1

num_pos = len(positions)

In [None]:
dmap = DistanceMatrixIndexMapper(num_pos)

Square $(i,j) \rightarrow ix$ flat

In [None]:
dmap(i,'brute_force_square2flat', j)

In [None]:
dmap(i,'analytical_square2flat', j)

In [None]:
#hide
%timeit dmap(i,'brute_force_square2flat', j)

In [None]:
#hide
%timeit dmap(i,'analytical_square2flat', j)

Flat $ix \rightarrow (i,j)$ square

In [None]:
dmap(ix,'brute_force_flat2square')

In [None]:
dmap(ix,'analytical_flat2square')

In [None]:
#hide
%timeit dmap(ix,'brute_force_flat2square')

In [None]:
#hide
%timeit dmap(ix,'analytical_flat2square')

So the analytical version is quite a bit slower, but on the other hand does not have a memory scaling to consider, in contrast to the brute force variant

Performing sanity checks

In [None]:
#hide
%time
df = pd.DataFrame([[ix, i, j] for ix, (i,j) in enumerate(zip(*np.triu_indices(num_pos, k=0)))],
                  columns=['ix', 'i', 'j'])

for method in ['brute_force', 'analytical']:
    df_test = pd.DataFrame(columns=['ix','i','j'], dtype=int)
    df_test['i'], df_test['j'] = list(zip(*df.apply(lambda x: dmap(x['ix'], 'analytical_flat2square'), axis=1).values))
    df_test['ix'] = df.apply(lambda x: dmap(x['i'], 'analytical_square2flat', x['j']), axis=1).values

    assert df_test.equals(df), f'{method}: failed! `df` {df}\n!=\n`df_test` {df_test}'

Visualising the correct and the calculated (`-fun`) maps

In [None]:
go.Figure(data=[
    go.Scatter(x=df['ix'], y=df['ix'], mode='markers', name='ix'),
    go.Scatter(x=df_test['ix'], y=df_test['ix'], mode='markers', 
               marker_symbol='x-open', name='ix-fun'),
],
         layout=go.Layout(title=f'Positions: {num_pos}', 
                          xaxis_title='ix', yaxis_title='calculated'))

In [None]:
go.Figure(data=[
    go.Scatter(x=df['ix'], y=df['i'], mode='markers', name='i'),
    go.Scatter(x=df['ix'], y=df['j'], mode='markers', name='j'),
    go.Scatter(x=df_test['ix'], y=df_test['i'], mode='markers', 
               marker_symbol='x-open', name='i-fun'),
    go.Scatter(x=df_test['ix'], y=df_test['j'], mode='markers', 
               marker_symbol='x-open', name='j-fun'),
],
         layout=go.Layout(title=f'Positions: {num_pos}', 
                          xaxis_title='ix', yaxis_title='calculated'))

## Distance vectors

In [11]:
i, j = 0, 1

In [None]:
positions

In [None]:
positions[i] - positions[j]

In [None]:
dvec = positions[[i,i]] - positions[[i,j]]; dvec

In [None]:
torch.Tensor([1,2,3]).size()

In [25]:
#export
class Hull:
    unit_vectors:torch.Tensor = None
    magnitudes:torch.Tensor = None
    vecs:torch.Tensor = None
        
    @classmethod
    def cuboid(self, magnitudes:torch.Tensor):
        d = magnitudes.size()[0]
        box = self()
        box.unit_vectors = torch.diag(torch.ones(d))
        box.magnitudes = magnitudes
        box.vecs = box.unit_vectors * magnitudes
        return box
    
    def to_2dpositions(self): pass

In [26]:
box_lengths = torch.Tensor([2., 3.])
box = Hull.cuboid(box_lengths)
box.unit_vectors, box.magnitudes, box.vecs

(tensor([[1., 0.],
         [0., 1.]]),
 tensor([2., 3.]),
 tensor([[2., 0.],
         [0., 3.]]))

In [27]:
#hide
assert torch.Tensor([[1., 0.],[0., 1.]]).isclose(box.unit_vectors).all()
assert torch.Tensor([[2., 0.],[0., 3.]]).isclose(box.vecs).all()
assert torch.Tensor([2., 3.]).isclose(box.magnitudes).all()

In [28]:
#export
def to_2dpositions(self, d0:int, d1:int):
    to_np = lambda x: x.clone().numpy()
    x = to_np(self.vecs[d0])
    y = to_np(self.vecs[d1])
    xy = to_np(self.vecs[:,d0] + self.vecs[:,d1])
    
    # adding the origin and the combination of the base vectors
    # the order may look weird but is to finish a whole circle
    x = np.array([0, x[d0], xy[d0], x[d1], 0]) 
    y = np.array([0, y[d0], xy[d1], y[d1], 0]) 
    return {'x': x, 'y': y}

Hull.to_2dpositions = to_2dpositions

In [29]:
d0, d1 = 0, 1
box_pos = box.to_2dpositions(d0, d1)
box_pos

{'x': array([0., 2., 2., 0., 0.]), 'y': array([0., 0., 3., 3., 0.])}

In [None]:
#hide
assert np.allclose(np.array([2.,0.,0.,2.]),box_pos['x'])
assert np.allclose(np.array([0.,3.,0.,3.]),box_pos['y'])

In [37]:
#export
def plot_atoms_and_hull(positions:torch.Tensor,
                        hull:Hull, i:int=0, j:int=1,
                        d0:int=0, d1:int=1):
    
    d0, d1 = sorted([d0, d1])
    box_pos = hull.to_2dpositions(d0,d1)
    
    fig = go.Figure(data=[
        go.Scatter(x=box_pos['x'], y=box_pos['y'], name='box', 
                   fill='toself', mode='lines', opacity=.5),
        go.Scatter(x=[positions[i,d0]], y=[positions[i,d1]], name='atom i',
                   mode='markers'),
        go.Scatter(x=[positions[j,d0]], y=[positions[j,d1]], name='atom j',
                   mode='markers'),
    ], layout=go.Layout(xaxis_title_text=f'd{d0}', yaxis_title_text=f'd{d1}',
                        title='Atoms in a box'))

    return fig

In [38]:
i, j = 1, 0
d0, d1 = 0, 1
plot_atoms_and_hull(positions, box, i, j, d0, d1)

In [85]:
from ipywidgets import widgets

In [95]:
box_lengths = torch.Tensor([2., 3.])
box = Hull.cuboid(box_lengths)

positions = torch.tensor([[0,0],
                          [1,1],
                          [2,2],
                          [1,2]], dtype=torch.float)

In [99]:
i_field = widgets.BoundedIntText(
    value=0,
    min=0,
    max=positions.size()[0],
    step=1,
    description='i:',
)
j_field = widgets.BoundedIntText(
    value=1,
    min=0,
    max=positions.size()[0],
    step=1,
    description='j:',
)
d0_field = widgets.BoundedIntText(
    value=0,
    min=0,
    max=box.magnitudes.size()[0],
    step=1,
    description='d0:',
)
d1_field = widgets.BoundedIntText(
    value=1,
    min=0,
    max=box.magnitudes.size()[0],
    step=1,
    description='d1:',
)
btn_plot = widgets.Button(description='Plot')
out_pl = widgets.Output()

In [100]:
def on_click_plot(change):
    
    i, j = i_field.value, j_field.value
    d0, d1 = d0_field.value, d1_field.value
    
    out_pl.clear_output()
    with out_pl:    
        fig = plot_atoms_and_hull(positions, box, i, j, d0, d1)
        display(fig)

In [101]:
btn_plot.on_click(on_click_plot)

In [102]:
widgets.VBox([widgets.HBox([i_field, j_field]), 
              widgets.HBox([d0_field, d1_field]), 
              btn_plot, out_pl])

VBox(children=(HBox(children=(BoundedIntText(value=0, description='i:', max=4), BoundedIntText(value=1, descri…

In [103]:
positions

tensor([[0., 0.],
        [1., 1.],
        [2., 2.],
        [1., 2.]])

## Distances in periodic boxes

[wiki](https://en.wikipedia.org/wiki/Periodic_boundary_conditions)

TODO: visualize distance vectors with dynamic abilities to change the box properties