Skip to content

Commit

Permalink
move to linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Feb 5, 2019
1 parent 88882e9 commit d990f6b
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 11 deletions.
2 changes: 1 addition & 1 deletion geoopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from . import optim
from . import tensor
from . import samplers
from . import util
from . import linalg

from .tensor import ManifoldParameter, ManifoldTensor
from .manifolds import Stiefel, Euclidean, Sphere
Expand Down
1 change: 1 addition & 0 deletions geoopt/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .batch_linalg import svd, qr, sym, extract_diag, matrix_rank
File renamed without changes.
8 changes: 4 additions & 4 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from .base import Manifold
import geoopt.util.linalg
import geoopt.linalg.batch_linalg

__all__ = [
"Sphere",
Expand Down Expand Up @@ -95,7 +95,7 @@ class SphereSubspaceIntersection(Sphere):

def __init__(self, span):
self._configure_manifold(span)
if (geoopt.util.linalg.matrix_rank(self._projector) == 1).any():
if (geoopt.linalg.batch_linalg.matrix_rank(self._projector) == 1).any():
raise ValueError(
"Manifold only consists of isolated points when "
"subspace is 1-dimensional."
Expand All @@ -119,7 +119,7 @@ def _check_shape(self, x, name):
return ok, reason

def _configure_manifold(self, span):
Q, _ = geoopt.util.linalg.qr(span)
Q, _ = geoopt.linalg.batch_linalg.qr(span)
self._projector = Q @ Q.transpose(-1, -2)

def _project_on_subspace(self, x):
Expand Down Expand Up @@ -150,7 +150,7 @@ class SphereSubspaceComplementIntersection(SphereSubspaceIntersection):
"""

def _configure_manifold(self, span):
Q, _ = geoopt.util.linalg.qr(span)
Q, _ = geoopt.linalg.batch_linalg.qr(span)
P = -Q @ Q.transpose(-1, -2)
P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1
self._projector = P
10 changes: 5 additions & 5 deletions geoopt/manifolds/stiefel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from geoopt import util
from geoopt import linalg
from .base import Manifold


Expand Down Expand Up @@ -74,7 +74,7 @@ def _amat(self, x, u):
return u @ x.transpose(-1, -2) - x @ u.transpose(-1, -2)

def _projx(self, x):
U, d, V = util.linalg.svd(x)
U, d, V = linalg.batch_linalg.svd(x)
return torch.einsum("...ik,...k,...jk->...ij", [U, torch.ones_like(d), V])


Expand Down Expand Up @@ -152,7 +152,7 @@ class EuclideanStiefel(Stiefel):
reversible = False

def _proju(self, x, u):
return u - x @ util.linalg.sym(x.transpose(-1, -2) @ u)
return u - x @ linalg.batch_linalg.sym(x.transpose(-1, -2) @ u)

def _transp_one(self, x, u, t, v, y=None):
if y is None:
Expand All @@ -173,7 +173,7 @@ def _inner(self, x, u, v):
return (u * v).sum([-1, -2])

def _retr(self, x, u, t):
q, r = util.linalg.qr(x + u * t)
unflip = torch.sign(torch.sign(util.linalg.extract_diag(r)) + 0.5)
q, r = linalg.batch_linalg.qr(x + u * t)
unflip = torch.sign(torch.sign(linalg.batch_linalg.extract_diag(r)) + 0.5)
q *= unflip[..., None, :]
return q
1 change: 0 additions & 1 deletion geoopt/util/__init__.py

This file was deleted.

17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch
import numpy as np
import geoopt


@pytest.fixture
def A():
torch.manual_seed(42)
n = 10
a = torch.randn(n, 3, 3)
a[:, 2, :] = 0
return a


def test_svd(A):
u, d, v = geoopt.linalg.batch_linalg.svd(A)

0 comments on commit d990f6b

Please sign in to comment.