Skip to content

Commit

Permalink
make inner and _stein_metric support keepdim
Browse files Browse the repository at this point in the history
  • Loading branch information
tao-harald committed Nov 8, 2020
1 parent f6405be commit a74c455
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
19 changes: 19 additions & 0 deletions geoopt/linalg/batch_linalg.py
@@ -1,3 +1,5 @@
from typing import List, Callable, Tuple
import torch
import torch.jit
from . import _expm

Expand Down Expand Up @@ -91,6 +93,23 @@ def block_matrix(blocks: List[List[torch.Tensor]], dim0: int = -2, dim1: int = -
return torch.cat(hblocks, dim=dim0)


@torch.jit.script
def trace(x: torch.Tensor) -> torch.Tensor:
r"""self-implemented matrix trace, since `torch.trace` only support 2-d input.
Parameters
----------
x : torch.Tensor
input matrix
Returns
-------
torch.Tensor
:math:`\operationname{Tr}(x)`
"""
return torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)


def sym_funcm(
x: torch.Tensor, func: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
Expand Down
12 changes: 7 additions & 5 deletions geoopt/manifolds/symmetric_positive_definite.py
Expand Up @@ -99,9 +99,10 @@ def _stein_metric(
def log_det(tensor: torch.Tensor) -> torch.Tensor:
return torch.log(torch.det(tensor))

ret = log_det((x + y) * 0.5) - 0.5 * log_det(x @ y)
if keepdim:
raise ValueError("`torch.det` doesn't support keepdim.")
return log_det((x + y) * 0.5) - 0.5 * log_det(x @ y)
return torch.unsqueeze(torch.unsqueeze(ret, -1), -1)
return ret

def _log_eucliden_metric(
self, x: torch.Tensor, y: torch.Tensor, keepdim=False
Expand Down Expand Up @@ -231,10 +232,11 @@ def inner(
"""
if v is None:
v = u
if keepdim:
raise ValueError("`torch.trace` doesn't support keepdim.")
inv_x = batch_linalg.sym_invm(x)
return torch.trace(inv_x @ u @ inv_x @ v)
ret = batch_linalg.trace(inv_x @ u @ inv_x @ v)
if keepdim:
return torch.unsqueeze(torch.unsqueeze(ret, -1), -1)
return ret

def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
inv_x = batch_linalg.sym_invm(x)
Expand Down

0 comments on commit a74c455

Please sign in to comment.