In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from torch import tensor, Tensor, matmul, diag, ones, zeros, cdist, inner, einsum, tensordot
import torch

In [None]:
my_float = torch.float  # torch.float is the same dtype object as torch.float32

In [None]:
# Hit coordinates, one row per hit; the first coordinate is constant within
# each layer (0, 1 and 2 respectively).
layer1 = tensor([[0, y] for y in (0, 1, 2)], dtype=my_float)
layer2 = tensor([[1, y] for y in (0, 2)], dtype=my_float)
layer3 = tensor([[2, y] for y in (0, 1, 2)], dtype=my_float)
layer1

tensor([[0., 0.],
        [0., 1.],
        [0., 2.]])

In [None]:
torch.cdist(layer1, layer3, p=2.0)  # pairwise Euclidean distances, shape (3, 3)

tensor([[2.0000, 2.2361, 2.8284],
        [2.2361, 2.0000, 2.2361],
        [2.8284, 2.2361, 2.0000]])

In [None]:
torch.cdist(layer2, layer3, p=2.0)  # pairwise Euclidean distances, shape (2, 3)

tensor([[1.0000, 1.4142, 2.2361],
        [2.2361, 1.4142, 1.0000]])

In [None]:
def cosines(a: Tensor, b: Tensor, c: Tensor) -> Tensor:
  """Cosine of the angle between segment a[i]->b[j] and segment b[j]->c[k].

  Args:
    a, b, c: point sets with one row per point and the same number of
      columns, i.e. shapes (I, D), (J, D), (K, D).

  Returns:
    Tensor of shape (I, J, K). Entry [i, j, k] is the dot product of the two
    segment vectors divided by the product of their lengths. Coincident
    consecutive points give a zero length and therefore NaN/inf.
  """
  seg_ab = b.unsqueeze(0) - a.unsqueeze(1)          # (I, J, D)
  seg_bc = c.unsqueeze(0) - b.unsqueeze(1)          # (J, K, D)
  dots = einsum('ijc,jkc->ijk', seg_ab, seg_bc)     # (I, J, K)
  lengths = cdist(a, b).unsqueeze(2) * cdist(b, c).unsqueeze(0)  # (I, J, 1) * (1, J, K)
  return dots / lengths

cosines(a=layer1, b=layer2, c=layer3)  # (3, 2, 3) cosines for every a->b->c path

tensor([[[ 1.0000,  0.7071,  0.4472],
         [-0.6000, -0.3162,  0.4472]],

        [[ 0.7071,  0.0000, -0.3162],
         [-0.3162,  0.0000,  0.7071]],

        [[ 0.4472, -0.3162, -0.6000],
         [ 0.4472,  0.7071,  1.0000]]])

In [None]:
from typing import Callable, Optional

M = 2  # default exponent applied to the cosine in the curvature term

def curvature(a: Tensor, b: Tensor, c: Tensor,
              m: Optional[float] = None) -> Callable[[Tensor, Tensor], Tensor]:
  """Build the curvature energy term for segment pairs a[i]->b[j]->c[k].

  Args:
    a, b, c: point sets with one row per point, shapes (I, D), (J, D), (K, D).
    m: exponent applied to the cosine. Defaults to the global `M`, which is
      read when `curvature` is called.

  Returns:
    A function `curvature_energy(v1, v2)` taking v1 of shape (I, J) and v2 of
    shape (J, K) and returning the scalar
        -1/2 * sum_ijk cos_ijk**m * v1[i, j] * v2[j, k] / (r_ij * r_jk),
    where r_ij, r_jk are segment lengths and cos_ijk is the cosine between
    the two segments. Coincident consecutive points give NaN weights.
  """
  exponent = M if m is None else m
  r1 = cdist(a, b)                       # (I, J) lengths of a->b segments
  r2 = cdist(b, c)                       # (J, K) lengths of b->c segments
  d1 = b[None, :, :] - a[:, None, :]     # (I, J, D) a->b segment vectors
  d2 = c[None, :, :] - b[:, None, :]     # (J, K, D) b->c segment vectors
  rr = r1[:, :, None] * r2[None, :, :]   # (I, J, K) product of lengths
  cos_ijk = einsum('ijc,jkc->ijk', d1, d2) / rr
  # These weights do not depend on the activations, so compute them once
  # here instead of on every call of the returned function.
  weights = torch.pow(cos_ijk, exponent) / rr

  # Named so it does not shadow `torch.inner` or the `cosines` function.
  def curvature_energy(v1: Tensor, v2: Tensor) -> Tensor:
    return -0.5 * (weights * v1[:, :, None] * v2[None, :, :]).sum()

  return curvature_energy

f = curvature(layer1, layer2, layer3)
# Sanity checks: the term is non-zero only when both segment sets are active.
# (No trailing comma after the list: it used to wrap the result in a 1-tuple.)
[
  f(zeros((3, 2)), zeros((2, 3))),
  f(ones((3, 2)), ones((2, 3))),
  f(zeros((3, 2)), ones((2, 3))),
  f(ones((3, 2)), zeros((2, 3))),
]

([tensor(-0.), tensor(-2.0212), tensor(-0.), tensor(-0.)],)

In [None]:
# All-ones activations of shapes (3, 2) and (2, 3), tracked by autograd.
v1 = torch.ones((3, 2), requires_grad=True)
v2 = torch.ones((2, 3), requires_grad=True)

In [None]:
# Clear gradients from any earlier run: backward() accumulates into .grad, so
# re-executing this cell would otherwise double the values displayed below.
v1.grad = None
v2.grad = None
curvature(layer1, layer2, layer3)(v1, v2).backward()

In [None]:
# Gradient of the curvature term w.r.t. v1, populated by the backward() call above.
v1.grad

tensor([[-0.7215, -0.0965],
        [-0.1926, -0.1926],
        [-0.0965, -0.7215]])

In [None]:
# Gradient of the curvature term w.r.t. v2, populated by the same backward() call.
v2.grad

tensor([[-0.7215, -0.1926, -0.0965],
        [-0.0965, -0.1926, -0.7215]])

In [None]:
beta = 3.
def T2(a, b, c):
  N = len(a) + len(b) + len(c)
  def inner(v1, v2):
    return beta * torch.square(0.5 * (v1.sum() + v2.sum() - N))
  return inner
T2(layer1, layer2, layer3)(v1, v2)  # constraint term; T1 is not defined until the next cell

tensor(1.1309, grad_fn=<MulBackward0>)

In [None]:
alpha = 0.2
def T1(a, b, c):
  """Build the pair penalty term.

  For each activation matrix v, this sums v[p] * v[q] over all ordered pairs
  of distinct entries p != q that share a row, plus those that share a
  column. The layer arguments a, b, c are not used.

  Returns:
    A function of (v1, v2) giving alpha / 2 times that sum over v1 and v2.
    The global `alpha` is read each time the returned function is called.
  """
  def penalty(v1, v2):
    total = 0.
    for v in (v1, v2):
      diagonal = tensordot(v, v, 2)  # sum of v**2: the p == q pairs to exclude
      for axis in (0, 1):
        # Contracting `axis` pairs up entries that share the index along it.
        total = total + tensordot(v, v, [[axis], [axis]]).sum() - diagonal
    return alpha / 2 * total
  return penalty
T1(a=layer1, b=layer2, c=layer3)(v1, v2)  # pair penalty for the current activations

tensor(1.1309, grad_fn=<MulBackward0>)

In [None]:
# Fresh all-ones leaf tensors, then the gradient of the T1 penalty w.r.t. v2.
v1 = torch.ones(3, 2, requires_grad=True)
v2 = torch.ones(2, 3, requires_grad=True)
T1(a=layer1, b=layer2, c=layer3)(v1, v2).backward()
v2.grad

tensor([[0.6000, 0.6000, 0.6000],
        [0.6000, 0.6000, 0.6000]])

In [None]:
def energy(a, b, c):
  """Build the total energy: curvature term + T1 pair penalty + T2 constraint.

  The three terms are constructed once, in that order, for the given layers.
  The returned function adds their values for activations (v1, v2).
  """
  curvature_term = curvature(a, b, c)
  pair_term = T1(a, b, c)
  count_term = T2(a, b, c)

  def total(v1, v2):
    return curvature_term(v1, v2) + pair_term(v1, v2) + count_term(v1, v2)

  return total

In [None]:
E = energy(layer1, layer2, layer3)
E(v1, v2)  # last expression is displayed by the notebook; no print() needed

tensor(13.5788, grad_fn=<AddBackward0>)


In [None]:
# Iterative update: every activation is replaced by sigmoid(-dE/dv / T),
# starting from 0.5. All activations are updated simultaneously; the energies
# printed below end up alternating between two values rather than settling.
T = 3          # temperature of the sigmoid update
N_STEPS = 20   # number of update sweeps

v1 = torch.full((3, 2), 0.5, requires_grad=True)
v2 = torch.full((2, 3), 0.5, requires_grad=True)

for _ in range(N_STEPS):
  e = E(v1, v2)
  print(e.item())
  # autograd.grad returns dE/dv1 and dE/dv2 directly instead of writing .grad.
  grad1, grad2 = torch.autograd.grad(e, (v1, v2))
  # The sigmoid result is a fresh tensor with no autograd history, so it can
  # become the next leaf directly (no clone()/detach() needed).
  v1 = torch.sigmoid(-grad1 / T).requires_grad_(True)
  v2 = torch.sigmoid(-grad2 / T).requires_grad_(True)


3.39469051361084
1.1352876424789429
7.777202129364014
3.7711129188537598
17.53969383239746
8.032204627990723
26.500661849975586
10.12683391571045
29.549264907836914
10.608461380004883
30.16474723815918
10.695420265197754
30.272871017456055
10.710372924804688
30.291379928588867
10.712919235229492
30.29452133178711
10.713351249694824
30.295055389404297
10.71342658996582


In [None]:
# v1 (shape (3, 2)) as left by the last step of the update loop above.
v1

tensor([[0.1521, 0.1281],
        [0.1316, 0.1316],
        [0.1281, 0.1521]], requires_grad=True)

In [None]:
# v2 (shape (2, 3)) as left by the last step of the update loop above.
v2

tensor([[0.1521, 0.1316, 0.1281],
        [0.1281, 0.1316, 0.1521]], requires_grad=True)