Skip to content

Commit

Permalink
Merge branch 'main' into kfac-state-dict
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 23, 2024
2 parents 1cf85bd + e382e59 commit fa119b0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 7 deletions.
12 changes: 12 additions & 0 deletions curvlinops/submatrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Implements slices of linear operators."""

from __future__ import annotations

from typing import List

from numpy import column_stack, ndarray, zeros
Expand Down Expand Up @@ -78,3 +80,13 @@ def _matmat(self, X: ndarray) -> ndarray:
``A[row_idxs, :][:, col_idxs] @ x``. Has shape ``[len(row_idxs), N]``.
"""
return column_stack([self @ col for col in X.T])

def _adjoint(self) -> SubmatrixLinearOperator:
"""Return the adjoint of the sub-matrix.
For that, we need to take the adjoint operator, and swap row and column indices.
Returns:
The linear operator for the adjoint sub-matrix.
"""
return type(self)(self._A.adjoint(), self._col_idxs, self._row_idxs)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ lint =

# Dependencies needed to build/view the documentation (semicolon/line-separated)
docs =
setuptools==69.5.1 # RTD fails with setuptools>=70, see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/15863
transformers
datasets
matplotlib
Expand Down
45 changes: 38 additions & 7 deletions test/test_submatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Tuple

from numpy import eye, ndarray, random
from pytest import fixture, raises
from pytest import fixture, mark, raises
from scipy.sparse.linalg import aslinearoperator

from curvlinops.examples.utils import report_nonclose
Expand Down Expand Up @@ -34,29 +34,60 @@ def submatrix_case(request) -> Tuple[ndarray, List[int], List[int]]:
return case["A_fn"](), case["row_idxs_fn"](), case["col_idxs_fn"]()


def test_SubmatrixLinearOperator__matvec(submatrix_case):
@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"])
def test_SubmatrixLinearOperator__matvec(
submatrix_case: Tuple[ndarray, List[int], List[int]], adjoint: bool
):
"""Test the matrix-vector multiplication of a submatrix linear operator.
Args:
submatrix_case: A tuple with a random matrix and two index lists.
adjoint: Whether to take the operator's adjoint before multiplying.
"""
A, row_idxs, col_idxs = submatrix_case

A_sub = A[row_idxs, :][:, col_idxs]
A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs)

x = random.rand(len(col_idxs))
if adjoint:
A_sub = A_sub.conj().T
A_sub_linop = A_sub_linop.adjoint()

x = random.rand(A_sub.shape[1])
A_sub_linop_x = A_sub_linop @ x

assert A_sub_linop_x.shape == (len(row_idxs),)
assert A_sub_linop_x.shape == ((len(col_idxs),) if adjoint else (len(row_idxs),))
report_nonclose(A_sub @ x, A_sub_linop_x)


def test_SubmatrixLinearOperator__matmat(submatrix_case, num_vecs: int = 3):
@mark.parametrize("adjoint", [False, True], ids=["", "adjoint"])
def test_SubmatrixLinearOperator__matmat(
submatrix_case: Tuple[ndarray, List[int], List[int]],
adjoint: bool,
num_vecs: int = 3,
):
"""Test the matrix-matrix multiplication of a submatrix linear operator.
Args:
submatrix_case: A tuple with a random matrix and two index lists.
adjoint: Whether to take the operator's adjoint before multiplying.
num_vecs: The number of vectors to multiply. Default: ``3``.
"""
A, row_idxs, col_idxs = submatrix_case

A_sub = A[row_idxs, :][:, col_idxs]
A_sub_linop = SubmatrixLinearOperator(aslinearoperator(A), row_idxs, col_idxs)

X = random.rand(len(col_idxs), num_vecs)
if adjoint:
A_sub = A_sub.conj().T
A_sub_linop = A_sub_linop.adjoint()

X = random.rand(A_sub.shape[1], num_vecs)
A_sub_linop_X = A_sub_linop @ X

assert A_sub_linop_X.shape == (len(row_idxs), num_vecs)
assert A_sub_linop_X.shape == (
(len(col_idxs), num_vecs) if adjoint else (len(row_idxs), num_vecs)
)
report_nonclose(A_sub @ X, A_sub_linop_X)


Expand Down

0 comments on commit fa119b0

Please sign in to comment.