Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Minimal prototype for KFAC #43

Merged
merged 8 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from curvlinops.hessian import HessianLinearOperator
from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator
from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator
from curvlinops.kfac import KFACLinearOperator
from curvlinops.papyan2020traces.spectrum import (
LanczosApproximateLogSpectrumCached,
LanczosApproximateSpectrumCached,
Expand All @@ -22,6 +23,7 @@
"GGNLinearOperator",
"EFLinearOperator",
"FisherMCLinearOperator",
"KFACLinearOperator",
"JacobianLinearOperator",
"TransposedJacobianLinearOperator",
"CGInverseLinearOperator",
Expand Down
8 changes: 8 additions & 0 deletions curvlinops/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Contains functionality to analyze Hessian & GGN via matrix-free multiplication."""

from typing import Callable, Iterable, List, Optional, Tuple, Union
from warnings import warn

from backpack.utils.convert_parameters import vector_to_parameter_list
from numpy import (
Expand Down Expand Up @@ -254,6 +255,13 @@ def _preprocess(self, x: ndarray) -> List[Tensor]:
Returns:
Vector in list format.
"""
if x.dtype != self.dtype:
warn(
f"Input vector is {x.dtype}, while linear operator is {self.dtype}. "
+ f"Converting to {self.dtype}."
)
x = x.astype(self.dtype)

x_torch = from_numpy(x).to(self._device)
return vector_to_parameter_list(x_torch, self._params)

Expand Down
4 changes: 3 additions & 1 deletion curvlinops/examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ def report_nonclose(
else:
for a1, a2 in zip(array1.flatten(), array2.flatten()):
if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan):
print(f"{a1} ≠ {a2}")
print(f"{a1} ≠ {a2} (ratio {a1 / a2:.5f})")
print(f"Max: {array1.max():.5f}, {array2.max():.5f}")
print(f"Min: {array1.min():.5f}, {array2.min():.5f}")
raise ValueError("Compared arrays don't match.")
Loading
Loading