In [1]:
import sys
import geoopt
import torch
import numpy as np
import pytest

import matplotlib.pyplot as plt

Use `geoopt` to do line search optimization where parameters are constrained to manifolds. 

As example we will try to find the basis of a matrix $A$ in which it has the smallest 1-norm by minimizing $\|XAY\|_1$ where $X,Y$ are orthogonal

In [2]:
sphere = geoopt.manifolds.Sphere()
torus = geoopt.manifolds.ProductManifold(*[(sphere,10)]*2)
point = torus.random(20)
point

Tensor on (Sphere)x(Sphere) manifold containing:
tensor([-0.2364, -0.3048,  0.1211, -0.4428,  0.2482, -0.2867,  0.2788, -0.5560,
         0.3234,  0.0728, -0.0279, -0.0252,  0.4324, -0.2188, -0.2871, -0.1346,
        -0.3842, -0.3572,  0.1025, -0.6144])

In [3]:
(n,m) = (100,200)
stiefel = geoopt.manifolds.Stiefel()
stief_prod = geoopt.manifolds.ProductManifold((stiefel,(n,n)),(stiefel,(m,m)))
XY = stief_prod.random(n**2+m**2)
stief_prod.unpack_tensor(XY)

(tensor([[-0.1333,  0.1254, -0.0490,  ..., -0.1619, -0.0184,  0.0007],
         [-0.0771,  0.0013,  0.0826,  ...,  0.0609,  0.0841, -0.0999],
         [ 0.1675,  0.0737, -0.0271,  ..., -0.0092,  0.1793,  0.1258],
         ...,
         [ 0.0209, -0.0217, -0.1054,  ..., -0.1240, -0.1194,  0.0341],
         [-0.0329, -0.0586,  0.1153,  ..., -0.0039,  0.1052, -0.0061],
         [ 0.0348, -0.1333, -0.0477,  ...,  0.1258, -0.0130,  0.1702]]),
 tensor([[-0.0387, -0.1010,  0.1340,  ...,  0.0038, -0.0154,  0.0201],
         [ 0.1047, -0.0535, -0.0156,  ..., -0.1084, -0.0014,  0.0946],
         [-0.0749, -0.0755, -0.0089,  ...,  0.0665,  0.0425,  0.0003],
         ...,
         [-0.0294, -0.0646,  0.0375,  ..., -0.0364,  0.0114,  0.0098],
         [-0.0093, -0.1048,  0.0372,  ...,  0.0916, -0.1114,  0.0966],
         [ 0.0858,  0.0218,  0.0418,  ..., -0.0280, -0.0259,  0.0879]]))

In [7]:
(n,m) = (100,200)
A = torch.randn(n,m)
stiefel = geoopt.manifolds.Stiefel()

X = stiefel.random((n,n))
X.requires_grad=True
Y = stiefel.random((m,m))
Y.requires_grad=True

(n,m) = (100,200)
stiefel = geoopt.manifolds.Stiefel()
stief_prod = geoopt.manifolds.ProductManifold((stiefel,(n,n)),(stiefel,(m,m)))
XY = stief_prod.random(n**2+m**2)
XY.requires_grad=True


def closure():
    optim.zero_grad()
    #X,Y = stief_prod.unpack_tensor(XY)
    loss = (X@A@Y).norm(p=1)
    loss.backward()

    return loss.item()

optim = geoopt.optim.RiemannianLineSearch([{'a':[X]},{'b':[Y]}],stabilize=2)
losses = []
losses2 = []
for i in range(500):
    losses.append(optim.step(closure))


plt.plot(losses)
plt.figure()
log_step_sizes = (np.log10([x if (x is not None) else 1 for x in optim.step_size_history]))
plt.plot(log_step_sizes)
plt.show()

KeyError: 'params'