Skip to content

Commit

Permalink
Merge 5edbe61 into 8098292
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Feb 5, 2019
2 parents 8098292 + 5edbe61 commit 550ae53
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 71 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ cache:
- $HOME/miniconda3

env:
- PYTHON_VERSION=3.6 TORCH='pytorch<1.0.0' COVERAGE=''
- PYTHON_VERSION=3.6 TORCH='pytorch>=1.0.0' COVERAGE='--cov geoopt'
- PYTHON_VERSION=3.6 TORCH='pytorch-nightly' COVERAGE=''

Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ New Features
Maintenance
-----------
* Add gitter chat (#31)

* Maintain torch>=1.0.0 only

Deprecations
------------
Expand Down
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Methods”`_ ICLR2019 and more.

Installation
------------
Make sure you have pytorch>=1.0.0 installed

There are two ways to install geoopt:

1. GitHub (preferred so far) due to active development
Expand Down
3 changes: 0 additions & 3 deletions geoopt/_compat.py

This file was deleted.

13 changes: 2 additions & 11 deletions geoopt/optim/tracing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
import torch.jit
from .._compat import _TORCH_LESS_THAN_ONE


def _compat_trace(fn, args):
if _TORCH_LESS_THAN_ONE:
# torch.jit here does not support inplace ops
return fn
else:
return torch.jit.trace(fn, args)
import torch


def create_traced_update(step, manifold, point, *buffers, **kwargs):
Expand Down Expand Up @@ -38,4 +29,4 @@ def partial(*args):
step(manifold, *args, **kwargs)
return args

return _compat_trace(partial, (point, grad, lr) + buffers)
return torch.jit.trace(partial, (point, grad, lr) + buffers)
120 changes: 66 additions & 54 deletions geoopt/util/linalg.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,90 @@
import torch
import itertools
import warnings
from .._compat import _TORCH_LESS_THAN_ONE

__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank"]


def _warn_lost_grad(x, op):
if torch.is_grad_enabled() and x.requires_grad:
warnings.warn(
"Gradient for operation {0}(...) is lost, please use the pytorch native analog as {0} this "
"is more optimized for internal purposes which do not require gradients but vectorization".format(
op
)
)


@torch.jit.script
def svd(x):
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
_warn_lost_grad(x, "geoopt.utils.linalg.svd")
with torch.no_grad():
# prolonged here:
if x.dim() == 2:
# 17 milliseconds on my mac to check that condition, that is low overhead
result = torch.svd(x)
else:
batches = x.shape[:-2]
if batches:
# in most cases we do not require gradients when applying svd (e.g. in projection)
n, m = x.shape[-2:]
k = min(n, m)
U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k)
for idx in itertools.product(*map(range, batches)):
U[idx], d[idx], V[idx] = torch.svd(x[idx])
return U, d, V
else:
return torch.svd(x)
other = x.shape[-2:]
flat = x.view((-1,) + other)
slices = flat.unbind(0)
U, D, V = [], [], []
# I wish I had a parallel_for
for i in range(flat.shape[0]):
u, d, v = torch.svd(slices[i])
U.append(u)
D.append(d)
V.append(v)
U = torch.stack(U).view(batches + U[0].shape)
D = torch.stack(D).view(batches + D[0].shape)
V = torch.stack(V).view(batches + V[0].shape)
result = U, D, V
return result


@torch.jit.script
def qr(x):
# vectorized version as svd
_warn_lost_grad(x, "geoopt.utils.linalg.qr")
with torch.no_grad():
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
if x.dim() == 2:
result = torch.qr(x)
else:
batches = x.shape[:-2]
if batches:
# in most cases we do not require gradients when applying qr (e.g. in retraction)
assert not x.requires_grad
n, m = x.shape[-2:]
Q, R = x.new(*batches, n, m), x.new(*batches, m, m)
for idx in itertools.product(*map(range, batches)):
Q[idx], R[idx] = torch.qr(x[idx])
return Q, R
else:
return torch.qr(x)
other = x.shape[-2:]
flat = x.view((-1,) + other)
slices = flat.unbind(0)
Q, R = [], []
# I wish I had a parallel_for
for i in range(flat.shape[0]):
q, r = torch.qr(slices[i])
Q.append(q)
R.append(r)
Q = torch.stack(Q).view(batches + Q[0].shape)
R = torch.stack(R).view(batches + R[0].shape)
result = Q, R
return result


@torch.jit.script
def sym(x):
return 0.5 * (x.transpose(-1, -2) + x)


@torch.jit.script
def extract_diag(x):
n, m = x.shape[-2:]
k = min(n, m)
return x[..., torch.arange(k), torch.arange(k)]
batch = x.shape[:-2]
k = n if n < m else m
idx = torch.arange(k, dtype=torch.long, device=x.device)
x = x.view(-1, n, m)
return x[:, idx, idx].view(batch + (k,))


@torch.jit.script
def matrix_rank(x):
if _TORCH_LESS_THAN_ONE:
import numpy as np

return torch.from_numpy(
np.asarray(np.linalg.matrix_rank(x.detach().cpu().numpy()))
).type_as(x)
with torch.no_grad():
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
if x.dim() == 2:
ranks = torch.matrix_rank(x)
else:
batches = x.shape[:-2]
if batches:
out = x.new(*batches)
for idx in itertools.product(*map(range, batches)):
out[idx] = torch.matrix_rank(x[idx])
return out
else:
return torch.matrix_rank(x)
other = x.shape[-2:]
flat = x.view((-1,) + other)
slices = flat.unbind(0)
ranks = []
# I wish I had a parallel_for
for i in range(flat.shape[0]):
r = torch.matrix_rank(slices[i])
ranks.append(r)
ranks = torch.stack(ranks).view(batches)
return ranks
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_version(*path):
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
packages=find_packages(),
install_requires=["torch", "numpy"],
install_requires=["torch>=1.0.0", "numpy"],
version=get_version("geoopt", "__init__.py"),
url="https://github.com/ferrine/geoopt",
python_requires=">=3.6.0",
Expand Down

0 comments on commit 550ae53

Please sign in to comment.