From e3039d425b4f9f86a815347b6d692421d33b0971 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 15 Jan 2019 02:42:55 +0300 Subject: [PATCH 1/8] add some not working code --- geoopt/manifolds/__init__.py | 2 +- geoopt/manifolds/sphere.py | 87 ++++++++++++++++++++++++++++++++++++ geoopt/util/linalg.py | 74 ++++++++++++++++++++---------- tests/test_manifold.py | 55 +++++++++++++++-------- 4 files changed, 174 insertions(+), 44 deletions(-) diff --git a/geoopt/manifolds/__init__.py b/geoopt/manifolds/__init__.py index 8741d288..ac117de7 100644 --- a/geoopt/manifolds/__init__.py +++ b/geoopt/manifolds/__init__.py @@ -1,4 +1,4 @@ from .base import Manifold from .euclidean import Euclidean from .stiefel import Stiefel, EuclideanStiefel, CanonicalStiefel -from .sphere import Sphere +from .sphere import Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection diff --git a/geoopt/manifolds/sphere.py b/geoopt/manifolds/sphere.py index fa0250f7..f2b56984 100644 --- a/geoopt/manifolds/sphere.py +++ b/geoopt/manifolds/sphere.py @@ -1,6 +1,13 @@ import torch from .base import Manifold +import geoopt.util.linalg + +__all__ = [ + "Sphere", + "SphereSubspaceIntersection", + "SphereSubspaceComplementIntersection" +] class Sphere(Manifold): @@ -67,3 +74,83 @@ def _retr_transp(self, x, u, t, v, *more): y = self._retr(x, u, t) vs = self._transp_many(x, u, t, v, *more, y=y) return (y,) + vs + + +class SphereSubspaceIntersection(Sphere): + """ + Sphere manifold induced by the following constraint + + .. math:: + + \|x\|=1\\ + x \in \mathbb{span}(U) + + Parameters + ---------- + span : matrix + the subspace to intersect with + """ + + name = "SphereSubspace" + + def __init__(self, span): + self._configure_manifold(span) + if (geoopt.util.linalg.matrix_rank(self._projector) == 1).any(): + raise ValueError( + "Manifold only consists of isolated points when " + "subspace is 1-dimensional." + ) + + def _check_shape(self, x, name): + ok, reason = super()._check_shape(x, name) + if ok: + + ok = x.shape[-1] == self._projector.shape[-2] + if not ok: + reason = "The leftmost shape of `span` does not match `x`: {}, {}".format( + x.shape[-1], self._projector.shape[-1] + ) + elif x.dim() < (self._projector.dim() - 1): + reason = "`x` should have at least {} dimensions but has {}".format( + self._projector.dim() - 1, x.dim() + ) + else: + reason = None + return ok, reason + + def _configure_manifold(self, span): + Q, _ = geoopt.util.linalg.qr(span) + self._projector = Q @ Q.transpose(-1, -2) + + def _project_on_subspace(self, x): + return x @ self._projector.transpose(-1, -2) + + def _proju(self, x, u): + u = super()._proju(x, u) + return self._project_on_subspace(u) + + def _projx(self, x): + x = self._project_on_subspace(x) + return super()._projx(x) + + +class SphereSubspaceComplementIntersection(SphereSubspaceIntersection): + """ + Sphere manifold induced by the following constraint + + .. math:: + + \|x\|=1\\ + x \in \mathbb{span}(U) + + Parameters + ---------- + span : matrix + the subspace to compliment (being orthogonal to) + """ + + def _configure_manifold(self, span): + Q, _ = geoopt.util.linalg.qr(span) + P = -Q @ Q.transpose(-1, -2) + P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1 + self._projector = P diff --git a/geoopt/util/linalg.py b/geoopt/util/linalg.py index e08b87fe..e5d0c06f 100644 --- a/geoopt/util/linalg.py +++ b/geoopt/util/linalg.py @@ -1,38 +1,52 @@ import torch import itertools +import warnings -__all__ = ["svd", "qr", "sym", "extract_diag"] +__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank"] + + +def _warn_lost_grad(x, op): + if torch.is_grad_enabled() and x.requires_grad: + warnings.warn( + "Gradient for operation {0}(...) is lost, please use the pytorch native analog as {0} this " + "is more optimized for internal purposes which do not require gradients but vectorization".format( + op + ) + ) def svd(x): # https://discuss.pytorch.org/t/multidimensional-svd/4366/2 - batches = x.shape[:-2] - if batches: - # in most cases we do not require gradients when applying svd (e.g. in projection) - assert not x.requires_grad - n, m = x.shape[-2:] - k = min(n, m) - U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k) - for idx in itertools.product(*map(range, batches)): - U[idx], d[idx], V[idx] = torch.svd(x[idx]) - return U, d, V - else: - return torch.svd(x) + _warn_lost_grad(x, "geoopt.utils.linalg.svd") + with torch.no_grad(): + batches = x.shape[:-2] + if batches: + # in most cases we do not require gradients when applying svd (e.g. in projection) + n, m = x.shape[-2:] + k = min(n, m) + U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k) + for idx in itertools.product(*map(range, batches)): + U[idx], d[idx], V[idx] = torch.svd(x[idx]) + return U, d, V + else: + return torch.svd(x) def qr(x): # vectorized version as svd - batches = x.shape[:-2] - if batches: - # in most cases we do not require gradients when applying qr (e.g. in retraction) - assert not x.requires_grad - n, m = x.shape[-2:] - Q, R = x.new(*batches, n, m), x.new(*batches, m, m) - for idx in itertools.product(*map(range, batches)): - Q[idx], R[idx] = torch.qr(x[idx]) - return Q, R - else: - return torch.qr(x) + _warn_lost_grad(x, "geoopt.utils.linalg.qr") + with torch.no_grad(): + batches = x.shape[:-2] + if batches: + # in most cases we do not require gradients when applying qr (e.g. in retraction) + assert not x.requires_grad + n, m = x.shape[-2:] + Q, R = x.new(*batches, n, m), x.new(*batches, m, m) + for idx in itertools.product(*map(range, batches)): + Q[idx], R[idx] = torch.qr(x[idx]) + return Q, R + else: + return torch.qr(x) def sym(x): @@ -43,3 +57,15 @@ def extract_diag(x): n, m = x.shape[-2:] k = min(n, m) return x[..., torch.arange(k), torch.arange(k)] + + +def matrix_rank(x): + with torch.no_grad(): + batches = x.shape[:-2] + if batches: + out = x.new(*batches) + for idx in itertools.product(*map(range, batches)): + out[idx] = torch.matrix_rank(x[idx]) + return out + else: + return torch.matrix_rank(x) diff --git a/tests/test_manifold.py b/tests/test_manifold.py index 4d0de31f..377dfdd4 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -15,6 +15,14 @@ functools.partial(geoopt.manifolds.Stiefel, canonical=True), geoopt.manifolds.Euclidean, geoopt.manifolds.Sphere, + functools.partial( + geoopt.manifolds.SphereSubspaceIntersection, + torch.arange(30).reshape(10, 3).double(), + ), + functools.partial( + geoopt.manifolds.SphereSubspaceComplementIntersection, + torch.arange(30).reshape(10, 3).double(), + ), ], ) def manifold(request): @@ -26,6 +34,13 @@ def manifold(request): geoopt.manifolds.CanonicalStiefel: pymanopt.manifolds.Stiefel, geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean, geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere, + geoopt.manifolds.SphereSubspaceIntersection: functools.partial( + pymanopt.manifolds.SphereSubspaceIntersection, U=np.arange(30).reshape(10, 3) + ), + geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial( + pymanopt.manifolds.SphereSubspaceComplementIntersection, + U=np.arange(30).reshape(10, 3), + ), } # shapes to verify unary element implementation @@ -34,6 +49,8 @@ def manifold(request): geoopt.manifolds.CanonicalStiefel: (10, 5), geoopt.manifolds.Euclidean: (1,), geoopt.manifolds.Sphere: (10,), + geoopt.manifolds.SphereSubspaceIntersection: (10,), + geoopt.manifolds.SphereSubspaceComplementIntersection: (10,), } UnaryCase = collections.namedtuple( @@ -46,7 +63,7 @@ def unary_case(manifold): shape = shapes[type(manifold)] manopt_manifold = mannopt[type(manifold)](*shape) np.random.seed(42) - rand = manopt_manifold.rand() + rand = manopt_manifold.rand().astype('float64') x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold) torch.manual_seed(43) ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold) @@ -117,7 +134,7 @@ def test_transport(unary_case): def test_broadcast_projx(unary_case): torch.manual_seed(43) - X = torch.randn(4, *unary_case.shape) + X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) pX = unary_case.manifold.projx(X) unary_case.manifold.assert_check_point_on_manifold(pX) for px in pX: @@ -129,8 +146,8 @@ def test_broadcast_projx(unary_case): def test_broadcast_proju(unary_case): torch.manual_seed(43) - X = torch.randn(4, *unary_case.shape) - U = torch.randn(4, *unary_case.shape) + X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) pX = unary_case.manifold.projx(X) pU = unary_case.manifold.proju(pX, U) unary_case.manifold.assert_check_vector_on_tangent(pX, pU) @@ -143,8 +160,8 @@ def test_broadcast_proju(unary_case): def test_broadcast_retr(unary_case): torch.manual_seed(43) - X = torch.randn(4, *unary_case.shape) - U = torch.randn(4, *unary_case.shape) + X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) pX = unary_case.manifold.projx(X) pU = unary_case.manifold.proju(pX, U) Y = unary_case.manifold.retr(pX, pU, 1.0) @@ -158,9 +175,9 @@ def test_broadcast_retr(unary_case): def test_broadcast_transp(unary_case): torch.manual_seed(43) - X = torch.randn(4, *unary_case.shape) - U = torch.randn(4, *unary_case.shape) - V = torch.randn(4, *unary_case.shape) + X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + V = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) pX = unary_case.manifold.projx(X) pU = unary_case.manifold.proju(pX, U) pV = unary_case.manifold.proju(pX, V) @@ -176,10 +193,10 @@ def test_broadcast_transp(unary_case): def test_broadcast_transp_many(unary_case): torch.manual_seed(43) - X = torch.randn(4, *unary_case.shape) - U = torch.randn(4, *unary_case.shape) - V = torch.randn(4, *unary_case.shape) - F = torch.randn(4, *unary_case.shape) + X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + V = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + F = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) pX = unary_case.manifold.projx(X) pU = unary_case.manifold.proju(pX, U) pV = unary_case.manifold.proju(pX, V) @@ -199,10 +216,10 @@ def test_broadcast_transp_many(unary_case): def test_broadcast_retr_transp_many(unary_case): torch.manual_seed(43) - X = torch.randn(4, *unary_case.shape) - U = torch.randn(4, *unary_case.shape) - V = torch.randn(4, *unary_case.shape) - F = torch.randn(4, *unary_case.shape) + X = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + V = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) + F = torch.randn(4, *unary_case.shape, dtype=unary_case.x.dtype) pX = unary_case.manifold.projx(X) pU = unary_case.manifold.proju(pX, U) pV = unary_case.manifold.proju(pX, V) @@ -224,8 +241,8 @@ def test_broadcast_retr_transp_many(unary_case): def test_reversibility(unary_case): torch.manual_seed(43) - X = torch.randn(*unary_case.shape) - U = torch.randn(*unary_case.shape) + X = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) X = unary_case.manifold.projx(X) U = unary_case.manifold.proju(X, U) Z, Q = unary_case.manifold.retr_transp(X, U, 1.0, U) From 5ba6328b31843e71d932fc4c2add11ad49bb8814 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 01:24:06 +0300 Subject: [PATCH 2/8] fix tests due to qr behaviour on a specific case --- tests/test_manifold.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_manifold.py b/tests/test_manifold.py index 377dfdd4..39315e22 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -17,11 +17,11 @@ geoopt.manifolds.Sphere, functools.partial( geoopt.manifolds.SphereSubspaceIntersection, - torch.arange(30).reshape(10, 3).double(), + torch.from_numpy(np.random.RandomState(42).randn(10, 3)), ), functools.partial( geoopt.manifolds.SphereSubspaceComplementIntersection, - torch.arange(30).reshape(10, 3).double(), + torch.from_numpy(np.random.RandomState(42).randn(10, 3)), ), ], ) @@ -35,11 +35,11 @@ def manifold(request): geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean, geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere, geoopt.manifolds.SphereSubspaceIntersection: functools.partial( - pymanopt.manifolds.SphereSubspaceIntersection, U=np.arange(30).reshape(10, 3) + pymanopt.manifolds.SphereSubspaceIntersection, U=np.random.RandomState(42).randn(10, 3) ), geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial( pymanopt.manifolds.SphereSubspaceComplementIntersection, - U=np.arange(30).reshape(10, 3), + U=np.random.RandomState(42).randn(10, 3), ), } From 76d9e105829bfda956c611bab59db783040ce0be Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 01:48:03 +0300 Subject: [PATCH 3/8] black --- tests/test_manifold.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_manifold.py b/tests/test_manifold.py index 39315e22..61547f5b 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -35,7 +35,8 @@ def manifold(request): geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean, geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere, geoopt.manifolds.SphereSubspaceIntersection: functools.partial( - pymanopt.manifolds.SphereSubspaceIntersection, U=np.random.RandomState(42).randn(10, 3) + pymanopt.manifolds.SphereSubspaceIntersection, + U=np.random.RandomState(42).randn(10, 3), ), geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial( pymanopt.manifolds.SphereSubspaceComplementIntersection, @@ -63,7 +64,7 @@ def unary_case(manifold): shape = shapes[type(manifold)] manopt_manifold = mannopt[type(manifold)](*shape) np.random.seed(42) - rand = manopt_manifold.rand().astype('float64') + rand = manopt_manifold.rand().astype("float64") x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold) torch.manual_seed(43) ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold) From baadfbba34f614ce88d7a7a5694339a8f07ac4a2 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 01:49:44 +0300 Subject: [PATCH 4/8] add Sphere to docs --- docs/manifolds.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/manifolds.rst b/docs/manifolds.rst index 942555b0..ae2b6ebf 100644 --- a/docs/manifolds.rst +++ b/docs/manifolds.rst @@ -4,7 +4,7 @@ Manifolds .. currentmodule:: geoopt.manifolds .. automodule:: geoopt.manifolds - :members: Euclidean, Stiefel, Sphere + :members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection Extending ``geoopt`` -------------------- From 367a0ad99ddb840c4b81eb9c55b6a646cd52a5db Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 02:03:06 +0300 Subject: [PATCH 5/8] black --- geoopt/manifolds/__init__.py | 6 +++++- geoopt/manifolds/sphere.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/geoopt/manifolds/__init__.py b/geoopt/manifolds/__init__.py index ac117de7..a54c88b6 100644 --- a/geoopt/manifolds/__init__.py +++ b/geoopt/manifolds/__init__.py @@ -1,4 +1,8 @@ from .base import Manifold from .euclidean import Euclidean from .stiefel import Stiefel, EuclideanStiefel, CanonicalStiefel -from .sphere import Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection +from .sphere import ( + Sphere, + SphereSubspaceComplementIntersection, + SphereSubspaceIntersection, +) diff --git a/geoopt/manifolds/sphere.py b/geoopt/manifolds/sphere.py index f2b56984..cc3aa679 100644 --- a/geoopt/manifolds/sphere.py +++ b/geoopt/manifolds/sphere.py @@ -6,7 +6,7 @@ __all__ = [ "Sphere", "SphereSubspaceIntersection", - "SphereSubspaceComplementIntersection" + "SphereSubspaceComplementIntersection", ] From 16aecf6fd6485430db43df1c2e9deb83b958d43d Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 02:33:56 +0300 Subject: [PATCH 6/8] matrix rank no out --- geoopt/_compat.py | 3 +++ geoopt/optim/tracing.py | 3 ++- geoopt/util/linalg.py | 11 +++++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 geoopt/_compat.py diff --git a/geoopt/_compat.py b/geoopt/_compat.py new file mode 100644 index 00000000..f9dba0d3 --- /dev/null +++ b/geoopt/_compat.py @@ -0,0 +1,3 @@ +import torch + +_TORCH_LESS_THAN_ONE = tuple(map(int, torch.__version__.split(".")[:2])) < (1, 0) diff --git a/geoopt/optim/tracing.py b/geoopt/optim/tracing.py index 0d654483..5cbdd1c5 100644 --- a/geoopt/optim/tracing.py +++ b/geoopt/optim/tracing.py @@ -1,8 +1,9 @@ import torch.jit +from .._compat import _TORCH_LESS_THAN_ONE def _compat_trace(fn, args): - if tuple(map(int, torch.__version__.split(".")[:2])) < (1, 0): + if _TORCH_LESS_THAN_ONE: # torch.jit here does not support inplace ops return fn else: diff --git a/geoopt/util/linalg.py b/geoopt/util/linalg.py index e5d0c06f..35b26f5a 100644 --- a/geoopt/util/linalg.py +++ b/geoopt/util/linalg.py @@ -1,6 +1,7 @@ import torch import itertools import warnings +from .._compat import _TORCH_LESS_THAN_ONE __all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank"] @@ -26,7 +27,7 @@ def svd(x): k = min(n, m) U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k) for idx in itertools.product(*map(range, batches)): - U[idx], d[idx], V[idx] = torch.svd(x[idx]) + torch.svd(x[idx], out=(U[idx], d[idx], V[idx])) return U, d, V else: return torch.svd(x) @@ -43,7 +44,7 @@ def qr(x): n, m = x.shape[-2:] Q, R = x.new(*batches, n, m), x.new(*batches, m, m) for idx in itertools.product(*map(range, batches)): - Q[idx], R[idx] = torch.qr(x[idx]) + torch.qr(x[idx], out=(Q[idx], R[idx])) return Q, R else: return torch.qr(x) @@ -60,6 +61,12 @@ def extract_diag(x): def matrix_rank(x): + if _TORCH_LESS_THAN_ONE: + import numpy as np + + return torch.from_numpy( + np.linalg.matrix_rank(x.detach().cpu().numpy()) + ).type_as(x) with torch.no_grad(): batches = x.shape[:-2] if batches: From e561c583746e22ea5bf53e94539befd201b167d3 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 02:48:50 +0300 Subject: [PATCH 7/8] asarray check --- geoopt/util/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geoopt/util/linalg.py b/geoopt/util/linalg.py index 35b26f5a..cb678479 100644 --- a/geoopt/util/linalg.py +++ b/geoopt/util/linalg.py @@ -65,7 +65,7 @@ def matrix_rank(x): import numpy as np return torch.from_numpy( - np.linalg.matrix_rank(x.detach().cpu().numpy()) + np.asarray(np.linalg.matrix_rank(x.detach().cpu().numpy())) ).type_as(x) with torch.no_grad(): batches = x.shape[:-2] From 3d0791dba18fce59c80ae4cd22ce5d9f6ce190cc Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 27 Jan 2019 03:14:09 +0300 Subject: [PATCH 8/8] back to assing implementation --- geoopt/manifolds/sphere.py | 6 +++--- geoopt/util/linalg.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/geoopt/manifolds/sphere.py b/geoopt/manifolds/sphere.py index cc3aa679..f7a0c2b6 100644 --- a/geoopt/manifolds/sphere.py +++ b/geoopt/manifolds/sphere.py @@ -11,7 +11,7 @@ class Sphere(Manifold): - """ + r""" Sphere manifold induced by the following constraint .. math:: @@ -77,7 +77,7 @@ def _retr_transp(self, x, u, t, v, *more): class SphereSubspaceIntersection(Sphere): - """ + r""" Sphere manifold induced by the following constraint .. math:: @@ -135,7 +135,7 @@ def _projx(self, x): class SphereSubspaceComplementIntersection(SphereSubspaceIntersection): - """ + r""" Sphere manifold induced by the following constraint .. math:: diff --git a/geoopt/util/linalg.py b/geoopt/util/linalg.py index cb678479..e0f8c61e 100644 --- a/geoopt/util/linalg.py +++ b/geoopt/util/linalg.py @@ -27,7 +27,7 @@ def svd(x): k = min(n, m) U, d, V = x.new(*batches, n, k), x.new(*batches, k), x.new(*batches, m, k) for idx in itertools.product(*map(range, batches)): - torch.svd(x[idx], out=(U[idx], d[idx], V[idx])) + U[idx], d[idx], V[idx] = torch.svd(x[idx]) return U, d, V else: return torch.svd(x) @@ -44,7 +44,7 @@ def qr(x): n, m = x.shape[-2:] Q, R = x.new(*batches, n, m), x.new(*batches, m, m) for idx in itertools.product(*map(range, batches)): - torch.qr(x[idx], out=(Q[idx], R[idx])) + Q[idx], R[idx] = torch.qr(x[idx]) return Q, R else: return torch.qr(x)