In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity, iterative_delta_hyperbolicity

In [3]:
from manify.manifolds import ProductManifold
import torch

def generate_data(n_dims=2, curvature=-1, n_samples=100):
    pm = ProductManifold(signature=[(curvature,n_dims)])
    X, _ = pm.sample(z_mean=torch.stack([pm.mu0] * n_samples))
    dists = pm.pdist(X)
    return dists

def test_deltas(deltas, dists_max):
    return (deltas <= dists_max).all() and (deltas >= -dists_max).all()

def test_vectorized_deltas(iterative_deltas, vectorized_deltas):
    return torch.allclose(iterative_deltas, vectorized_deltas)

def test_vectorized_deltas_shape(vectorized_deltas, n_samples):
    return vectorized_deltas.shape == (n_samples, n_samples, n_samples)

def test_sampled_deltas(sampled_deltas, vectorized_deltas, indices):
    sampled_vectorized_deltas = vectorized_deltas[indices]
    return torch.allclose(sampled_deltas, sampled_vectorized_deltas)

def test_gromov_products(gromov_products,dists_max):
    return (gromov_products >= 0).all() and (gromov_products <= dists_max).all()

# Maybe testing more one reference value (p) and making sure they are all equal/have the same ratio in them to justify point

In [4]:
dists = generate_data(n_samples=20)
vectorized_deltas = delta_hyperbolicity(dists,full=True,relative=True)
iterative_deltas, gromov_products = iterative_delta_hyperbolicity(dists)

print(f'Testing vectorized delta value ranges: {test_deltas(vectorized_deltas, torch.max(dists))}')
print(f'Testing iterative delta value ranges: {test_deltas(iterative_deltas, torch.max(dists))}')
print(f'Test vectorized against iterative: {test_vectorized_deltas(iterative_deltas, vectorized_deltas)}')
print(f'Test gromov products: {test_gromov_products(gromov_products,torch.max(dists))}')
print(f'Test vectorized delta shapes: {test_vectorized_deltas_shape(vectorized_deltas, 20)}')
print(f'Test Iterative delta shapes: {test_vectorized_deltas_shape(iterative_deltas, 20)}')

Testing vectorized delta value ranges: True
Testing iterative delta value ranges: True
Test vectorized against iterative: True
Test gromov products: True
Test vectorized delta shapes: True
Test Iterative delta shapes: True


In [5]:
print(vectorized_deltas[5][:5])

tensor([[ 0.0000, -0.3802, -0.8795, -0.7351, -0.9693, -1.0536, -0.9955, -0.5677,
         -0.1635, -0.3472, -0.6060, -0.9172, -0.2628, -0.9631, -1.0442, -0.6954,
         -0.4983, -0.6201, -0.8405, -0.5162],
        [ 0.0000,  0.0000, -0.5616, -0.3549, -0.5891, -0.6734, -0.6154, -0.1875,
         -0.1142,  0.0330, -0.2258, -0.6038,  0.1174, -0.6346, -0.6640, -0.3152,
         -0.1182, -0.3860, -0.4603, -0.1361],
        [ 0.0000, -0.0623,  0.0000, -0.1428, -0.1740, -0.1740, -0.1678, -0.0879,
          0.0324, -0.0637, -0.1057, -0.0376, -0.0531, -0.0835, -0.1647, -0.1012,
         -0.0935,  0.1011, -0.1408, -0.0903],
        [ 0.0000,  0.1626, -0.2872,  0.0000, -0.2342, -0.3185, -0.2604,  0.1634,
         -0.0673,  0.1980,  0.1291, -0.3122,  0.1730, -0.3183, -0.3091,  0.0397,
          0.2368, -0.2056, -0.1054,  0.2189],
        [ 0.0000,  0.0655, -0.0843,  0.2342,  0.0000, -0.0843, -0.0262,  0.0886,
         -0.0255,  0.0723,  0.1309, -0.0721,  0.0603, -0.0537, -0.0749,  0.0787,
      

In [6]:
print(iterative_deltas[5][:5])

tensor([[ 0.0000, -0.3802, -0.8795, -0.7351, -0.9693, -1.0536, -0.9955, -0.5677,
         -0.1635, -0.3472, -0.6060, -0.9172, -0.2628, -0.9631, -1.0442, -0.6954,
         -0.4983, -0.6201, -0.8405, -0.5162],
        [ 0.0000,  0.0000, -0.5616, -0.3549, -0.5891, -0.6734, -0.6154, -0.1875,
         -0.1142,  0.0330, -0.2258, -0.6038,  0.1174, -0.6346, -0.6640, -0.3152,
         -0.1182, -0.3860, -0.4603, -0.1361],
        [ 0.0000, -0.0623,  0.0000, -0.1428, -0.1740, -0.1740, -0.1678, -0.0879,
          0.0324, -0.0637, -0.1057, -0.0376, -0.0531, -0.0835, -0.1647, -0.1012,
         -0.0935,  0.1011, -0.1408, -0.0903],
        [ 0.0000,  0.1626, -0.2872,  0.0000, -0.2342, -0.3185, -0.2604,  0.1634,
         -0.0673,  0.1980,  0.1291, -0.3122,  0.1730, -0.3183, -0.3091,  0.0397,
          0.2368, -0.2056, -0.1054,  0.2189],
        [ 0.0000,  0.0655, -0.0843,  0.2342,  0.0000, -0.0843, -0.0262,  0.0886,
         -0.0255,  0.0723,  0.1309, -0.0721,  0.0603, -0.0537, -0.0749,  0.0787,
      

In [7]:
torch.sort(vectorized_deltas.flatten()).values / torch.sort(iterative_deltas.flatten()).values

tensor([1., 1., 1.,  ..., 1., 1., 1.], grad_fn=<DivBackward0>)

In [8]:
torch.sort(iterative_deltas.flatten()).values

tensor([-1.8494, -1.6930, -1.6055,  ...,  0.2895,  0.2975,  0.2975],
       grad_fn=<SortBackward0>)

In [28]:
# Test Gromov Products

def sampled_delta_hyperbolicity(dismat, n_samples=1000):
    # (j,k)_i = .5 (d(i,j) + d(i,k) - d(j,k))
    n = dismat.shape[0]
    # Sample n_samples triplets of points randomly
    indices = torch.randint(0, n, (n_samples, 3))

    # Get gromov products

    x,y,z = indices.T
    w = 0
    # xy_w = _vectorized_gromov(w,x,y,dismat)
    # xz_w = _vectorized_gromov(w,x,z,dismat)
    # yz_w = _vectorized_gromov(w,y,z,dismat)

    xy_w = .5 * (dismat[w,x] + dismat[w,y] - dismat[x,y])
    xz_w = .5 * (dismat[w,x] + dismat[w,z] - dismat[x,z])
    yz_w = .5 * (dismat[w,y] + dismat[w,z] - dismat[y,z])

    deltas = torch.minimum(xy_w,yz_w) - xz_w
    diam = torch.max(dismat)
    rel_deltas = 2 * deltas / diam

    return rel_deltas, indices


In [29]:
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity, iterative_delta_hyperbolicity

sampled_deltas, indices = sampled_delta_hyperbolicity(dists)
iterative_deltas = iterative_delta_hyperbolicity(dists)[0][indices[:,0],indices[:,1],indices[:,2]]

torch.allclose(sampled_deltas, iterative_deltas)

True