diff --git a/geoopt/manifolds/symmetric_positive_definite.py b/geoopt/manifolds/symmetric_positive_definite.py index 77738f19..c644d051 100755 --- a/geoopt/manifolds/symmetric_positive_definite.py +++ b/geoopt/manifolds/symmetric_positive_definite.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Optional, Tuple, Union import torch from .base import Manifold @@ -255,5 +254,13 @@ def extra_repr(self) -> str: def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor: inv_sqrt_x, sqrt_x = batch_linalg.sym_inv_sqrtm2(x) - exp_x_y = batch_linalg.sym_expm(0.5 * batch_linalg.sym_logm(inv_sqrt_x @ y @ inv_sqrt_x)) - return sqrt_x @ exp_x_y @ batch_linalg.sym(inv_sqrt_x @ v @ inv_sqrt_x) @ exp_x_y @ sqrt_x \ No newline at end of file + exp_x_y = batch_linalg.sym_expm( + 0.5 * batch_linalg.sym_logm(inv_sqrt_x @ y @ inv_sqrt_x) + ) + return ( + sqrt_x + @ exp_x_y + @ batch_linalg.sym(inv_sqrt_x @ v @ inv_sqrt_x) + @ exp_x_y + @ sqrt_x + )