In [1]:
%load_ext autoreload
%autoreload 2

Some notes:
- Are we 100% positive their graph distances are Floyd-Warshall?

In [2]:
import manify

In [3]:
D, y, A = manify.utils.dataloaders.load("cs_phds")

Top CC has 1025 nodes; original graph has 1025 nodes.


In [4]:
D

tensor([[ 0.,  1., 13.,  ..., 16., 12., 13.],
        [ 1.,  0., 12.,  ..., 15., 11., 12.],
        [13., 12.,  0.,  ..., 17.,  3., 14.],
        ...,
        [16., 15., 17.,  ...,  0., 18., 11.],
        [12., 11.,  3.,  ..., 18.,  0., 15.],
        [13., 12., 14.,  ..., 11., 15.,  0.]], dtype=torch.float64)

In [48]:
import geoopt
import torch
from tqdm.notebook import tqdm


def loss(D_est, D_true):
    idx = torch.triu_indices(D_est.shape[0], D_est.shape[1], offset=1)
    D_est = D_est[idx[0], idx[1]]
    D_true = D_true[idx[0], idx[1]]
    return torch.sum(torch.abs((D_est / D_true) ** 2 - 1))


def D_avg(D_est, D_true):
    idx = torch.triu_indices(D_est.shape[0], D_est.shape[1], offset=1)
    D_est = D_est[idx[0], idx[1]]
    D_true = D_true[idx[0], idx[1]]
    return (torch.abs(D_est - D_true) / D_true).mean().item()


def train_embedding(manifold, dim=2, lr=1e-2, normalize_dists=True):
    X = torch.randn(D.shape[0], dim)
    X = manifold.projx(X)
    X = geoopt.ManifoldParameter(X, manifold=manifold)
    optimizer = geoopt.optim.RiemannianAdam(params=[X], lr=lr)

    my_tqdm = tqdm(total=2000)
    print("Begin")
    D_true = D.clone() / D.max() if normalize_dists else D.clone()
    losses = []
    d_avgs = []
    for i in range(2000):
        optimizer.zero_grad()
        D_est = manifold.dist(X[:, None, :], X[None, :, :])
        l = loss(D_est, D_true)
        l.backward()
        optimizer.step()
        my_tqdm.update(1)
        my_tqdm.set_postfix(loss=l.item(), d_avg=D_avg(D_est, D))
        losses.append(l.item())
        d_avgs.append(D_avg(D_est, D))
    my_tqdm.close()
    return X.detach().cpu().numpy(), losses, d_avgs


train_embedding(geoopt.manifolds.PoincareBall(), lr=1, dim=10, normalize_dists=False)

  0%|          | 0/2000 [00:00<?, ?it/s]

Begin


(array([[-0.41266134,  0.04390064,  0.21087684, ...,  0.5866078 ,
          0.13129652, -0.419038  ],
        [-0.40916002,  0.0446205 ,  0.20808618, ...,  0.58192086,
          0.13010074, -0.41489148],
        [-0.4579871 , -0.08479324,  0.33678293, ..., -0.2594849 ,
         -0.15139666,  0.6791879 ],
        ...,
        [-0.07043439,  0.4154565 ,  0.35482875, ...,  0.1395706 ,
          0.26730317, -0.08804966],
        [-0.46073776, -0.08991821,  0.33055612, ..., -0.26560038,
         -0.16061793,  0.6790575 ],
        [ 0.47857633,  0.09600827,  0.10971424, ..., -0.37737453,
         -0.36313415, -0.08973501]], dtype=float32),
 [876499.8999500896,
  697155.8543704948,
  571681.6703396515,
  495040.98600926914,
  455719.236558,
  441023.8648388904,
  441721.2300390648,
  450188.5855244417,
  461044.04521805065,
  468613.8636491372,
  470770.4316042352,
  468517.4702401043,
  463086.8171283598,
  457389.7432038771,
  451598.4064156881,
  446196.24245001323,
  441542.2276338823,
  

In [22]:
manifold = geoopt.manifolds.Euclidean()

_x = torch.randn(10, 2)
manifold.dist2(_x[:, None, :], _x[None, :, :]).sum(axis=-1) ** 0.5

tensor([[0.0000, 1.2204, 1.5106, 0.8963, 1.3816, 3.3374, 1.1489, 2.8645, 0.4374,
         2.9514],
        [1.2204, 0.0000, 0.3209, 0.3546, 1.3069, 2.3802, 0.7575, 1.6540, 1.5099,
         2.2168],
        [1.5106, 0.3209, 0.0000, 0.6738, 1.3213, 2.3179, 1.0294, 1.4146, 1.8224,
         2.2566],
        [0.8963, 0.3546, 0.6738, 0.0000, 1.3328, 2.5399, 0.5714, 1.9693, 1.1582,
         2.2751],
        [1.3816, 1.3069, 1.3213, 1.3328, 0.0000, 3.6386, 1.9023, 2.6129, 1.8182,
         3.5234],
        [3.3374, 2.3802, 2.3179, 2.5399, 3.6386, 0.0000, 2.1976, 1.4558, 3.3876,
         0.7388],
        [1.1489, 0.7575, 1.0294, 0.5714, 1.9023, 2.1976, 0.0000, 1.9795, 1.2078,
         1.8116],
        [2.8645, 1.6540, 1.4146, 1.9693, 2.6129, 1.4558, 1.9795, 0.0000, 3.0977,
         1.8592],
        [0.4374, 1.5099, 1.8224, 1.1582, 1.8182, 3.3876, 1.2078, 3.0977, 0.0000,
         2.9178],
        [2.9514, 2.2168, 2.2566, 2.2751, 3.5234, 0.7388, 1.8116, 1.8592, 2.9178,
         0.0000]])

In [23]:
torch.linalg.norm(_x[:, None, :] - _x[None, :, :], dim=-1)

tensor([[0.0000, 1.2204, 1.5106, 0.8963, 1.3816, 3.3374, 1.1489, 2.8645, 0.4374,
         2.9514],
        [1.2204, 0.0000, 0.3209, 0.3546, 1.3069, 2.3802, 0.7575, 1.6540, 1.5099,
         2.2168],
        [1.5106, 0.3209, 0.0000, 0.6738, 1.3213, 2.3179, 1.0294, 1.4146, 1.8224,
         2.2566],
        [0.8963, 0.3546, 0.6738, 0.0000, 1.3328, 2.5399, 0.5714, 1.9693, 1.1582,
         2.2751],
        [1.3816, 1.3069, 1.3213, 1.3328, 0.0000, 3.6386, 1.9023, 2.6129, 1.8182,
         3.5234],
        [3.3374, 2.3802, 2.3179, 2.5399, 3.6386, 0.0000, 2.1976, 1.4558, 3.3876,
         0.7388],
        [1.1489, 0.7575, 1.0294, 0.5714, 1.9023, 2.1976, 0.0000, 1.9795, 1.2078,
         1.8116],
        [2.8645, 1.6540, 1.4146, 1.9693, 2.6129, 1.4558, 1.9795, 0.0000, 3.0977,
         1.8592],
        [0.4374, 1.5099, 1.8224, 1.1582, 1.8182, 3.3876, 1.2078, 3.0977, 0.0000,
         2.9178],
        [2.9514, 2.2168, 2.2566, 2.2751, 3.5234, 0.7388, 1.8116, 1.8592, 2.9178,
         0.0000]])