In [10]:
import numpy as np
import pandas as pd
import torch

%matplotlib inline
from IPython import display
from IPython.core.pylabtools import figsize
from matplotlib import pyplot as plt

In [68]:
def torch_or_numpy(fn_or_module_cls):
    """Wrap a torch callable so it also accepts array-likes and returns NumPy.

    Args:
        fn_or_module_cls: A class that can be instantiated with no arguments
            and whose instances are callable on torch tensors.

    Returns:
        A function ``apply_fn(*args)``. If the first argument is a
        ``torch.Tensor`` the arguments are forwarded untouched and the torch
        result is returned as-is. Otherwise every argument is converted to a
        float64 tensor, and the result is detached and returned as a
        ``numpy.ndarray``.
    """
    torch_fn = fn_or_module_cls()

    def apply_fn(*args):
        # Dispatch is decided by the first argument only.
        if isinstance(args[0], torch.Tensor):
            return torch_fn(*args)

        # ``np.float`` (an alias of the builtin ``float``, i.e. float64) was
        # removed in NumPy 1.24; name the dtype explicitly instead.
        tensors = [
            torch.tensor(np.array(arg, dtype=np.float64))
            for arg in args
        ]
        return torch_fn(*tensors).detach().numpy()

    return apply_fn


class EuclideanMagnitude(torch.nn.Module):
    """Euclidean (L2) norm taken over the last axis of a tensor.

    Implemented as an ``nn.Module`` rather than a legacy-style
    ``torch.autograd.Function`` (instance-method ``forward``): calling an
    instance of the latter raises ``RuntimeError`` on current PyTorch.
    Instances are still created with no arguments and called directly.
    """

    def forward(self, x):
        """Return ``sqrt(sum(x ** 2))`` reduced over the last axis of ``x``."""
        return torch.sqrt(torch.sum(x ** 2, -1))

# Norm over the last axis; takes either torch tensors or array-likes
# (see torch_or_numpy for the dispatch rule).
mag = torch_or_numpy(EuclideanMagnitude)


class Dot(torch.nn.Module):
    """Dot product of two tensors taken over their last axis.

    Implemented as an ``nn.Module`` rather than a legacy-style
    ``torch.autograd.Function`` (instance-method ``forward``): calling an
    instance of the latter raises ``RuntimeError`` on current PyTorch.
    Instances are still created with no arguments and called directly.
    """

    def forward(self, x, y):
        """Return ``sum(x * y)`` reduced over the last axis."""
        return torch.sum(x * y, -1)
    
# Dot product over the last axis; takes either torch tensors or array-likes
# (see torch_or_numpy for the dispatch rule).
dot = torch_or_numpy(Dot)

<img src="https://wikimedia.org/api/rest_v1/media/math/render/svg/aad0269ecc2abf70d7f2df4f2c1c9a4d33790583">

In [72]:
class PoincareDistance(torch.nn.Module):
    """Distance between points of the Poincare ball, over the last axis.

    Evaluates::

        d(u, v) = 2 * ln( (|u - v| + sqrt(|u|^2 |v|^2 - 2 u.v + 1))
                          / sqrt((1 - |u|^2) (1 - |v|^2)) )

    Implemented as an ``nn.Module`` rather than a legacy-style
    ``torch.autograd.Function`` (instance-method ``forward``): calling an
    instance of the latter raises ``RuntimeError`` on current PyTorch.
    Instances are still created with no arguments and called directly.
    """

    def forward(self, u, v):
        """Return the distance for each pair of vectors along the last axis.

        The denominator is zero when either point has norm exactly 1, and the
        square root receives a negative argument when exactly one point has
        norm above 1, so the result is only finite for points strictly inside
        the unit ball.
        """
        # Squared norms are summed directly instead of squaring a square
        # root, which avoids an unnecessary rounding step.
        u_sq = torch.sum(u ** 2, -1)
        v_sq = torch.sum(v ** 2, -1)
        u_dot_v = torch.sum(u * v, -1)
        separation = torch.sqrt(torch.sum((u - v) ** 2, -1))

        numerator = separation + torch.sqrt(u_sq * v_sq - 2.0 * u_dot_v + 1.0)
        denominator = torch.sqrt((1.0 - u_sq) * (1.0 - v_sq))

        return 2 * torch.log(numerator / denominator)

# Poincare distance over the last axis; takes either torch tensors or
# array-likes (see torch_or_numpy for the dispatch rule).
pdist = torch_or_numpy(PoincareDistance)

In [112]:
# Distance from the origin to points (a, a) as a approaches 1/sqrt(2),
# i.e. as the point's Euclidean norm approaches 1.
pdist(
    [[0.0, 0.0],
     [0.0, 0.0],
     [0.0, 0.0],
     [0.0, 0.0],
     [0.0, 0.0],
     [0.0, 0.0],
     [0.0, 0.0]],
    [[0.1, 0.1],
     [0.2, 0.2],
     [0.4, 0.4],
     [0.7, 0.7],
     [0.703, 0.703],
     [0.705, 0.705],
     [0.70710678118654, 0.70710678118654]])

array([  0.28475129,   0.58153858,   1.28230988,   5.28824152,
         5.83878116,   6.50767676,  32.86559956])

In [104]:
# Coordinate a for which the 2-D point (a, a) has Euclidean norm exactly 1.
1 / np.sqrt(2)

0.70710678118654746