-
-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
72 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,90 @@ | ||
import torch | ||
import itertools | ||
import warnings | ||
from .._compat import _TORCH_LESS_THAN_ONE | ||
|
||
__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank"] | ||
|
||
|
||
def _warn_lost_grad(x, op):
    """Warn that gradients through ``x`` will be discarded by ``op``.

    The vectorized helpers in this module run under ``torch.no_grad`` for
    speed; when the caller passed a tensor that requires grad, tell them to
    use the native pytorch op instead.  ``op`` is the fully qualified name
    used in the warning message.  Returns None.
    """
    # Only warn when a gradient would actually have been recorded.
    if torch.is_grad_enabled() and x.requires_grad:
        warnings.warn(
            "Gradient for operation {0}(...) is lost, please use the pytorch native analog as {0} this "
            "is more optimized for internal purposes which do not require gradients but vectorization".format(
                op
            )
        )
|
||
|
||
def svd(x):
    """Singular value decomposition supporting batched input.

    ``x`` is a tensor whose last two dimensions are the matrix dimensions.
    For a 2D input this defers straight to ``torch.svd``; for higher-rank
    input every trailing matrix is factorized separately and the factors are
    stacked back into the batch shape, returning a ``(U, S, V)`` tuple.
    Runs under ``torch.no_grad`` — the gradient graph is deliberately
    discarded (``_warn_lost_grad`` emits a warning when that matters).
    """
    # inspired by
    # https://discuss.pytorch.org/t/multidimensional-svd/4366/2
    # NOTE(review): reconstructed from a diff dump; ``@torch.jit.script`` was
    # dropped because this body (Python warnings, Size/tuple arithmetic,
    # list accumulation) is not TorchScript-compilable — confirm vs history.
    _warn_lost_grad(x, "geoopt.utils.linalg.svd")
    with torch.no_grad():
        if x.dim() == 2:
            # plain matrix: no batching needed
            result = torch.svd(x)
        else:
            batches = x.shape[:-2]
            other = x.shape[-2:]
            # collapse all leading (batch) dimensions into one
            flat = x.view((-1,) + other)
            slices = flat.unbind(0)
            U, D, V = [], [], []
            # I wish I had a parallel_for
            for i in range(flat.shape[0]):
                u, d, v = torch.svd(slices[i])
                U.append(u)
                D.append(d)
                V.append(v)
            # restore the original batch shape on each factor
            U = torch.stack(U).view(batches + U[0].shape)
            D = torch.stack(D).view(batches + D[0].shape)
            V = torch.stack(V).view(batches + V[0].shape)
            result = U, D, V
        return result
|
||
|
||
def qr(x):
    """QR decomposition supporting batched input (vectorized like ``svd``).

    ``x`` is a tensor whose last two dimensions are the matrix dimensions.
    A 2D input defers straight to ``torch.qr``; otherwise each trailing
    matrix is decomposed separately and the ``Q``/``R`` factors are stacked
    back into the batch shape, returning a ``(Q, R)`` tuple.  Runs under
    ``torch.no_grad`` — gradients are deliberately discarded
    (``_warn_lost_grad`` emits a warning when that matters).
    """
    # inspired by
    # https://discuss.pytorch.org/t/multidimensional-svd/4366/2
    # NOTE(review): reconstructed from a diff dump; ``@torch.jit.script`` was
    # dropped because this body (Python warnings, Size/tuple arithmetic,
    # list accumulation) is not TorchScript-compilable — confirm vs history.
    _warn_lost_grad(x, "geoopt.utils.linalg.qr")
    with torch.no_grad():
        if x.dim() == 2:
            # plain matrix: no batching needed
            result = torch.qr(x)
        else:
            batches = x.shape[:-2]
            other = x.shape[-2:]
            # collapse all leading (batch) dimensions into one
            flat = x.view((-1,) + other)
            slices = flat.unbind(0)
            Q, R = [], []
            # I wish I had a parallel_for
            for i in range(flat.shape[0]):
                q, r = torch.qr(slices[i])
                Q.append(q)
                R.append(r)
            # restore the original batch shape on each factor
            Q = torch.stack(Q).view(batches + Q[0].shape)
            R = torch.stack(R).view(batches + R[0].shape)
            result = Q, R
        return result
|
||
|
||
@torch.jit.script
def sym(x):
    """Return the symmetric part of ``x`` over its last two dimensions,
    i.e. ``0.5 * (x^T + x)`` applied batch-wise."""
    return 0.5 * (x.transpose(-1, -2) + x)
|
||
|
||
def extract_diag(x):
    """Extract the main diagonal of each trailing matrix of ``x``.

    For input of shape ``(*batch, n, m)`` returns a tensor of shape
    ``(*batch, k)`` with ``k = min(n, m)``; a plain 2D input yields a 1D
    diagonal.
    """
    # NOTE(review): reconstructed from a diff dump; ``@torch.jit.script`` was
    # dropped because the Size/tuple concatenation below is not scriptable —
    # confirm vs history.
    n, m = x.shape[-2:]
    batch = x.shape[:-2]
    k = n if n < m else m
    idx = torch.arange(k, dtype=torch.long, device=x.device)
    # collapse batch dims so one advanced-indexing gather grabs all diagonals
    x = x.view(-1, n, m)
    return x[:, idx, idx].view(batch + (k,))
|
||
|
||
def matrix_rank(x):
    """Rank of each trailing matrix of ``x``, stacked into the batch shape.

    A 2D input defers straight to ``torch.matrix_rank``; otherwise each
    trailing matrix gets its rank computed separately.  On old torch
    (``_TORCH_LESS_THAN_ONE`` from ``.._compat`` — presumably torch < 1.0,
    TODO confirm) falls back to ``numpy.linalg.matrix_rank``.
    Runs under ``torch.no_grad``; ranks carry no gradient anyway.
    """
    # NOTE(review): reconstructed from a diff dump; ``@torch.jit.script`` was
    # dropped because this body (conditional numpy import, Size/tuple
    # arithmetic, list accumulation) is not scriptable — confirm vs history.
    if _TORCH_LESS_THAN_ONE:
        import numpy as np

        # numpy handles batched input natively; convert back to x's dtype
        return torch.from_numpy(
            np.asarray(np.linalg.matrix_rank(x.detach().cpu().numpy()))
        ).type_as(x)
    with torch.no_grad():
        # inspired by
        # https://discuss.pytorch.org/t/multidimensional-svd/4366/2
        if x.dim() == 2:
            ranks = torch.matrix_rank(x)
        else:
            batches = x.shape[:-2]
            other = x.shape[-2:]
            # collapse all leading (batch) dimensions into one
            flat = x.view((-1,) + other)
            slices = flat.unbind(0)
            ranks = []
            # I wish I had a parallel_for
            for i in range(flat.shape[0]):
                ranks.append(torch.matrix_rank(slices[i]))
            ranks = torch.stack(ranks).view(batches)
        return ranks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters