Skip to content

Commit

Permalink
Merge pull request #93 from naefjo/feature/online-learning-improvements
Browse files Browse the repository at this point in the history
linear_operator cat_rows performance improvement
  • Loading branch information
Balandat committed Mar 18, 2024
2 parents eb28640 + 0a94f8b commit a0a9c42
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
10 changes: 5 additions & 5 deletions linear_operator/operators/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,17 +1267,17 @@ def cat_rows(
R = self.root_inv_decomposition().root.to_dense() # RR^T = A^{-1} (this is fast if L is triangular)
lower_left = B_ @ R # F = BR
schur = D - lower_left.matmul(lower_left.mT) # GG^T = new_mat - FF^T
schur_root = to_linear_operator(schur).root_decomposition().root.to_dense() # G = (new_mat - FF^T)^{1/2}
schur_root = to_linear_operator(schur).root_decomposition().root # G = (new_mat - FF^T)^{1/2}

# Form new root matrix
num_fant = schur_root.size(-2)
new_root = torch.zeros(*batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype)
new_root[..., :m, :n] = E.to_dense()
new_root[..., m:, : lower_left.shape[-1]] = lower_left
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root.to_dense()
if generate_inv_roots:
if isinstance(E, TriangularLinearOperator) and isinstance(schur_root, TriangularLinearOperator):
# make sure these are actually upper triangular
# make sure these are actually lower triangular
if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
raise NotImplementedError
# in this case we know new_root is triangular as well
Expand Down Expand Up @@ -2207,8 +2207,8 @@ def root_inv_decomposition(
:param method: Root decomposition method to use (symeig, diagonalization, lanczos, or cholesky).
:return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`.
"""
from linear_operator.operators.dense_linear_operator import to_linear_operator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator

if not self.is_square:
raise RuntimeError(
Expand All @@ -2229,7 +2229,7 @@ def root_inv_decomposition(
# we don't need the batch shape here, thanks to broadcasting
Eye = torch.eye(L.shape[-2], device=L.device, dtype=L.dtype)
Linv = torch.linalg.solve_triangular(L, Eye, upper=False)
res = to_linear_operator(Linv.mT)
res = TriangularLinearOperator(Linv.mT, upper=True)
inv_root = res
elif method == "lanczos":
if initial_vectors is not None:
Expand Down
10 changes: 10 additions & 0 deletions linear_operator/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,13 @@ class verbose_linalg(_feature_flag):
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
logger.addHandler(ch)


class stable_qr_cpu_threshold(_value_context):
    """
    Size threshold that decides when `stable_qr` runs the QR factorization on the CPU.

    `stable_qr` compares the size of the input's last dimension against this value
    (``mat.shape[-1] <= stable_qr_cpu_threshold.value()``); inputs at or below the
    threshold are moved to the CPU and factorized there with `torch.linalg.qr`, as a
    workaround for slow batched QR in PyTorch (pytorch/pytorch#22573).
    (Default: 128)
    """

    # NOTE(review): the hard-coded check this setting replaces in `stable_qr` used 2048;
    # confirm that the lower default of 128 is intended.
    _global_value = 128
4 changes: 3 additions & 1 deletion linear_operator/utils/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from linear_operator.settings import stable_qr_cpu_threshold


def stable_qr(mat):
"""
Expand All @@ -11,7 +13,7 @@ def stable_qr(mat):
1. slow batched QR in pytorch (pytorch/pytorch#22573)
2. possible singularity in R
"""
if mat.shape[-1] <= 2048:
if mat.shape[-1] <= stable_qr_cpu_threshold.value():
# Dispatch to CPU so long as pytorch/pytorch#22573 is not fixed
device = mat.device
Q, R = torch.linalg.qr(mat.cpu())
Expand Down

0 comments on commit a0a9c42

Please sign in to comment.