Skip to content

Commit

Permalink
refactor code into torch.1.0.0 (#39)
Browse files Browse the repository at this point in the history
* refactor code into torch.1.0.0

* fix typo in a comment

* black

* one more typo in comment spotted

* torch script fix

* fix list add bug

* comment

* move to linalg
  • Loading branch information
ferrine committed Feb 5, 2019
1 parent 8098292 commit 0453504
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 106 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ cache:
- $HOME/miniconda3

env:
- PYTHON_VERSION=3.6 TORCH='pytorch<1.0.0' COVERAGE=''
- PYTHON_VERSION=3.6 TORCH='pytorch>=1.0.0' COVERAGE='--cov geoopt'
- PYTHON_VERSION=3.6 TORCH='pytorch-nightly' COVERAGE=''

Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ New Features
Maintenance
-----------
* Add gitter chat (#31)

* Maintain torch>=1.0.0 only

Deprecations
------------
Expand Down
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Methods”`_ ICLR2019 and more.

Installation
------------
Make sure you have pytorch>=1.0.0 installed

There are two ways to install geoopt:

1. GitHub (preferred so far) due to active development
Expand Down
2 changes: 1 addition & 1 deletion geoopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from . import optim
from . import tensor
from . import samplers
from . import util
from . import linalg

from .tensor import ManifoldParameter, ManifoldTensor
from .manifolds import Stiefel, Euclidean, Sphere
Expand Down
3 changes: 0 additions & 3 deletions geoopt/_compat.py

This file was deleted.

1 change: 1 addition & 0 deletions geoopt/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .batch_linalg import svd, qr, sym, extract_diag, matrix_rank
95 changes: 95 additions & 0 deletions geoopt/linalg/batch_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch

__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank"]


@torch.jit.script
def svd(x):
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
if x.dim() == 2:
# 17 milliseconds on my mac to check that condition, that is low overhead
result = torch.svd(x)
else:
batches = x.shape[:-2]
other = x.shape[-2:]
flat = x.view((-1,) + other)
slices = flat.unbind(0)
U, D, V = [], [], []
# I wish I had a parallel_for
for i in range(flat.shape[0]):
u, d, v = torch.svd(slices[i])
U += [u]
D += [d]
V += [v]
U = torch.stack(U).view(batches + U[0].shape)
D = torch.stack(D).view(batches + D[0].shape)
V = torch.stack(V).view(batches + V[0].shape)
result = U, D, V
return result


@torch.jit.script
def qr(x):
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
if x.dim() == 2:
result = torch.qr(x)
else:
batches = x.shape[:-2]
other = x.shape[-2:]
flat = x.view((-1,) + other)
slices = flat.unbind(0)
Q, R = [], []
# I wish I had a parallel_for
for i in range(flat.shape[0]):
q, r = torch.qr(slices[i])
Q += [q]
R += [r]
Q = torch.stack(Q).view(batches + Q[0].shape)
R = torch.stack(R).view(batches + R[0].shape)
result = Q, R
return result


@torch.jit.script
def sym(x):
return 0.5 * (x.transpose(-1, -2) + x)


@torch.jit.script
def extract_diag(x):
n, m = x.shape[-2:]
batch = x.shape[:-2]
k = n if n < m else m
idx = torch.arange(k, dtype=torch.long, device=x.device)
# torch script does not support Ellipsis indexing
x = x.view(-1, n, m)
return x[:, idx, idx].view(batch + (k,))


@torch.jit.script
def matrix_rank(x):
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
if x.dim() == 2:
result = torch.matrix_rank(x)
else:
batches = x.shape[:-2]
other = x.shape[-2:]
flat = x.view((-1,) + other)
slices = flat.unbind(0)
ranks = []
# I wish I had a parallel_for
for i in range(flat.shape[0]):
r = torch.matrix_rank(slices[i])
# interesting,
# ranks.append(r)
# does not work on pytorch 1.0.0
# but the below code does
ranks += [r]
result = torch.stack(ranks).view(batches)
return result
8 changes: 4 additions & 4 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from .base import Manifold
import geoopt.util.linalg
import geoopt.linalg.batch_linalg

__all__ = [
"Sphere",
Expand Down Expand Up @@ -95,7 +95,7 @@ class SphereSubspaceIntersection(Sphere):

def __init__(self, span):
self._configure_manifold(span)
if (geoopt.util.linalg.matrix_rank(self._projector) == 1).any():
if (geoopt.linalg.batch_linalg.matrix_rank(self._projector) == 1).any():
raise ValueError(
"Manifold only consists of isolated points when "
"subspace is 1-dimensional."
Expand All @@ -119,7 +119,7 @@ def _check_shape(self, x, name):
return ok, reason

def _configure_manifold(self, span):
Q, _ = geoopt.util.linalg.qr(span)
Q, _ = geoopt.linalg.batch_linalg.qr(span)
self._projector = Q @ Q.transpose(-1, -2)

def _project_on_subspace(self, x):
Expand Down Expand Up @@ -150,7 +150,7 @@ class SphereSubspaceComplementIntersection(SphereSubspaceIntersection):
"""

def _configure_manifold(self, span):
Q, _ = geoopt.util.linalg.qr(span)
Q, _ = geoopt.linalg.batch_linalg.qr(span)
P = -Q @ Q.transpose(-1, -2)
P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1
self._projector = P
10 changes: 5 additions & 5 deletions geoopt/manifolds/stiefel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from geoopt import util
from geoopt import linalg
from .base import Manifold


Expand Down Expand Up @@ -74,7 +74,7 @@ def _amat(self, x, u):
return u @ x.transpose(-1, -2) - x @ u.transpose(-1, -2)

def _projx(self, x):
U, d, V = util.linalg.svd(x)
U, d, V = linalg.batch_linalg.svd(x)
return torch.einsum("...ik,...k,...jk->...ij", [U, torch.ones_like(d), V])


Expand Down Expand Up @@ -152,7 +152,7 @@ class EuclideanStiefel(Stiefel):
reversible = False

def _proju(self, x, u):
return u - x @ util.linalg.sym(x.transpose(-1, -2) @ u)
return u - x @ linalg.batch_linalg.sym(x.transpose(-1, -2) @ u)

def _transp_one(self, x, u, t, v, y=None):
if y is None:
Expand All @@ -173,7 +173,7 @@ def _inner(self, x, u, v):
return (u * v).sum([-1, -2])

def _retr(self, x, u, t):
q, r = util.linalg.qr(x + u * t)
unflip = torch.sign(torch.sign(util.linalg.extract_diag(r)) + 0.5)
q, r = linalg.batch_linalg.qr(x + u * t)
unflip = torch.sign(torch.sign(linalg.batch_linalg.extract_diag(r)) + 0.5)
q *= unflip[..., None, :]
return q
13 changes: 2 additions & 11 deletions geoopt/optim/tracing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
import torch.jit
from .._compat import _TORCH_LESS_THAN_ONE


def _compat_trace(fn, args):
if _TORCH_LESS_THAN_ONE:
# torch.jit here does not support inplace ops
return fn
else:
return torch.jit.trace(fn, args)
import torch


def create_traced_update(step, manifold, point, *buffers, **kwargs):
Expand Down Expand Up @@ -38,4 +29,4 @@ def partial(*args):
step(manifold, *args, **kwargs)
return args

return _compat_trace(partial, (point, grad, lr) + buffers)
return torch.jit.trace(partial, (point, grad, lr) + buffers)
1 change: 0 additions & 1 deletion geoopt/util/__init__.py

This file was deleted.

78 changes: 0 additions & 78 deletions geoopt/util/linalg.py

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_version(*path):
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
packages=find_packages(),
install_requires=["torch", "numpy"],
install_requires=["torch>=1.0.0", "numpy"],
version=get_version("geoopt", "__init__.py"),
url="https://github.com/ferrine/geoopt",
python_requires=">=3.6.0",
Expand Down
17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch
import numpy as np
import geoopt


@pytest.fixture
def A():
torch.manual_seed(42)
n = 10
a = torch.randn(n, 3, 3)
a[:, 2, :] = 0
return a


def test_svd(A):
u, d, v = geoopt.linalg.batch_linalg.svd(A)

0 comments on commit 0453504

Please sign in to comment.