Skip to content

Commit

Permalink
[ADD] state_dict functionality to KFACLinearOperator and `KFACInverseLinearOperator` (#114)
Browse files Browse the repository at this point in the history

* Add tests for state dict functionality

* Add state dict functionality to (inverse) KFAC linear operator

* Fix tests

* Address review comments on tests

* Test torch.save/load as well and fix order of equivalence checks

* Check if covariance and mapping keys match when loading state dict

* Use compare_state_dicts everywhere
  • Loading branch information
runame committed May 23, 2024
1 parent e382e59 commit d509cd6
Show file tree
Hide file tree
Showing 5 changed files with 464 additions and 6 deletions.
67 changes: 66 additions & 1 deletion curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange
Expand Down Expand Up @@ -695,3 +695,68 @@ def _matmat(self, M: ndarray) -> ndarray:
M_torch = self._A._preprocess(M)
M_torch = self.torch_matmat(M_torch)
return self._A._postprocess(M_torch)

def state_dict(self) -> Dict[str, Any]:
    """Collect the state of the inverse KFAC linear operator.

    The wrapped operator contributes its own state dictionary under the key
    ``"A"``. All other entries are the values currently stored on this
    instance; they are returned as-is, without copying.

    Returns:
        State dictionary.
    """
    # Attributes
    setting_keys = (
        "damping",
        "use_heuristic_damping",
        "min_damping",
        "use_exact_damping",
        "cache",
        "retry_double_precision",
    )
    # Inverse Kronecker factors (if computed and cached)
    factor_keys = (
        "inverse_input_covariances",
        "inverse_gradient_covariances",
    )

    state: Dict[str, Any] = {"A": self._A.state_dict()}
    # Every entry is stored on the instance under the same name with a
    # leading underscore.
    state.update((key, getattr(self, "_" + key)) for key in setting_keys + factor_keys)
    return state

def load_state_dict(self, state_dict: Dict[str, Any]):
    """Restore the inverse KFAC linear operator from a state dictionary.

    The entry ``"A"`` is forwarded to the wrapped operator's own
    ``load_state_dict``; every other entry overwrites the attribute of the
    same name (with a leading underscore) on this instance.

    Args:
        state_dict: State dictionary.
    """
    # The wrapped operator goes first: if it rejects its part of the state,
    # none of the attributes below are touched.
    self._A.load_state_dict(state_dict["A"])

    restored_keys = (
        # Attributes
        "damping",
        "use_heuristic_damping",
        "min_damping",
        "use_exact_damping",
        "cache",
        "retry_double_precision",
        # Inverse Kronecker factors (if computed and cached)
        "inverse_input_covariances",
        "inverse_gradient_covariances",
    )
    for key in restored_keys:
        setattr(self, f"_{key}", state_dict[key])

@classmethod
def from_state_dict(
    cls, state_dict: Dict[str, Any], A: KFACLinearOperator
) -> "KFACInverseLinearOperator":
    """Build an inverse KFAC linear operator from a state dictionary.

    The constructor is called with the settings stored in ``state_dict``;
    afterwards the full state dictionary is loaded into the new instance via
    ``load_state_dict``.

    Args:
        state_dict: State dictionary.
        A: ``KFACLinearOperator`` whose inverse is formed.

    Returns:
        Linear operator of inverse KFAC approximation.
    """
    constructor_keys = (
        "damping",
        "use_heuristic_damping",
        "min_damping",
        "use_exact_damping",
        "cache",
        "retry_double_precision",
    )
    init_kwargs = {key: state_dict[key] for key in constructor_keys}

    restored = cls(A, **init_kwargs)
    restored.load_state_dict(state_dict)
    return restored
174 changes: 172 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from collections.abc import MutableMapping
from functools import partial
from math import sqrt
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from einops import einsum, rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -111,7 +111,7 @@ class KFACLinearOperator(_LinearOperator):
def __init__( # noqa: C901
self,
model_func: Module,
loss_func: MSELoss,
loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
progressbar: bool = False,
Expand Down Expand Up @@ -1070,3 +1070,173 @@ def frobenius_norm(self) -> Tensor:
)
self._frobenius_norm.sqrt_()
return self._frobenius_norm

def state_dict(self) -> Dict[str, Any]:
    """Return the state of the KFAC linear operator.

    The loss function is stored by name and reduction only. The name is
    resolved with ``isinstance`` so that subclasses of a supported loss are
    recorded under their supported base class, which is the same test
    ``load_state_dict`` applies when verifying the loss function.

    Returns:
        State dictionary.

    Raises:
        KeyError: If the loss function is not an instance of ``MSELoss``,
            ``CrossEntropyLoss``, or ``BCEWithLogitsLoss``.
    """
    loss_func = self._loss_func
    supported_losses = (
        ("MSELoss", MSELoss),
        ("CrossEntropyLoss", CrossEntropyLoss),
        ("BCEWithLogitsLoss", BCEWithLogitsLoss),
    )
    loss_type = next(
        (name for name, loss_cls in supported_losses if isinstance(loss_func, loss_cls)),
        None,
    )
    if loss_type is None:
        # Keep the exception a plain dictionary lookup on the type would raise.
        raise KeyError(type(loss_func))

    return {
        # Model and loss function
        "model_func_state_dict": self._model_func.state_dict(),
        "loss_type": loss_type,
        "loss_reduction": loss_func.reduction,
        # Attributes
        "progressbar": self._progressbar,
        "shape": self.shape,
        "seed": self._seed,
        "fisher_type": self._fisher_type,
        "mc_samples": self._mc_samples,
        "kfac_approx": self._kfac_approx,
        "loss_average": self._loss_average,
        "num_per_example_loss_terms": self._num_per_example_loss_terms,
        "separate_weight_and_bias": self._separate_weight_and_bias,
        "num_data": self._N_data,
        # Kronecker factors (if computed)
        "input_covariances": self._input_covariances,
        "gradient_covariances": self._gradient_covariances,
        # Properties (not necessarily computed)
        "trace": self._trace,
        "det": self._det,
        "logdet": self._logdet,
        "frobenius_norm": self._frobenius_norm,
    }

def load_state_dict(self, state_dict: Dict[str, Any]):
    """Load the state of the KFAC linear operator.

    All compatibility checks run before anything is modified, so a rejected
    state dictionary leaves both the model and the operator untouched.

    Args:
        state_dict: State dictionary.

    Raises:
        ValueError: If the loss function does not match the state dict.
        ValueError: If the loss function reduction does not match the state dict.
        ValueError: If the Kronecker factors in the state dict contain keys
            that are not mapping keys of this linear operator.
    """
    # Verify that the loss function and its reduction match the state dict
    loss_func_type = {
        "MSELoss": MSELoss,
        "CrossEntropyLoss": CrossEntropyLoss,
        "BCEWithLogitsLoss": BCEWithLogitsLoss,
    }[state_dict["loss_type"]]
    if not isinstance(self._loss_func, loss_func_type):
        raise ValueError(
            f"Loss function mismatch: {loss_func_type} != {type(self._loss_func)}."
        )
    if state_dict["loss_reduction"] != self._loss_func.reduction:
        raise ValueError(
            "Loss function reduction mismatch: "
            f"{state_dict['loss_reduction']} != {self._loss_func.reduction}."
        )

    # Verify the Kronecker factors that are about to be loaded (not the ones
    # currently stored on the operator) against the operator's mapping.
    # Empty dictionaries (factors not computed) trivially pass.
    # NOTE(review): only keys unknown to the mapping are rejected. Whether
    # every mapped module must own both an input and a gradient factor is
    # not visible from this method; tighten to equality if that holds.
    input_covariances = state_dict["input_covariances"]
    gradient_covariances = state_dict["gradient_covariances"]
    mapping_keys = set(self._mapping.keys())
    unknown_input_keys = set(input_covariances.keys()) - mapping_keys
    unknown_gradient_keys = set(gradient_covariances.keys()) - mapping_keys
    if unknown_input_keys or unknown_gradient_keys:
        raise ValueError(
            "Input or gradient covariance keys in state dict do not match "
            "mapping keys of linear operator. "
            "Difference between input covariance and mapping keys: "
            f"{unknown_input_keys}. "
            "Difference between gradient covariance and mapping keys: "
            f"{unknown_gradient_keys}."
        )

    # Validation passed: from here on the operator is modified.
    self._model_func.load_state_dict(state_dict["model_func_state_dict"])

    # Set attributes
    self._progressbar = state_dict["progressbar"]
    self.shape = state_dict["shape"]
    self._seed = state_dict["seed"]
    self._fisher_type = state_dict["fisher_type"]
    self._mc_samples = state_dict["mc_samples"]
    self._kfac_approx = state_dict["kfac_approx"]
    self._loss_average = state_dict["loss_average"]
    self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"]
    self._separate_weight_and_bias = state_dict["separate_weight_and_bias"]
    self._N_data = state_dict["num_data"]

    # Set Kronecker factors (if computed)
    self._input_covariances = input_covariances
    self._gradient_covariances = gradient_covariances

    # Set properties (not necessarily computed)
    self._trace = state_dict["trace"]
    self._det = state_dict["det"]
    self._logdet = state_dict["logdet"]
    self._frobenius_norm = state_dict["frobenius_norm"]

@classmethod
def from_state_dict(
    cls,
    state_dict: Dict[str, Any],
    model_func: Module,
    params: List[Parameter],
    data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
    check_deterministic: bool = True,
    batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
) -> KFACLinearOperator:
    """Load a KFAC linear operator from a state dictionary.

    The loss function is re-created from the type name and reduction stored
    in ``state_dict``. The operator is constructed with its own determinism
    check disabled; if requested, the check runs only after the state has
    been loaded.

    Args:
        state_dict: State dictionary.
        model_func: The model function.
        params: The model's parameters that KFAC is computed for.
        data: A data loader containing the data of the Fisher/GGN.
        check_deterministic: Whether to check that the linear operator is
            deterministic. Defaults to ``True``.
        batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
            needs to be specified. The intended behavior is to consume the first
            entry of the iterates from ``data`` and return their batch size.

    Returns:
        Linear operator of KFAC approximation.

    Raises:
        RuntimeError: If the check for deterministic behavior fails.
    """
    loss_cls = {
        "MSELoss": MSELoss,
        "CrossEntropyLoss": CrossEntropyLoss,
        "BCEWithLogitsLoss": BCEWithLogitsLoss,
    }[state_dict["loss_type"]]
    loss_func = loss_cls(reduction=state_dict["loss_reduction"])

    kfac = cls(
        model_func,
        loss_func,
        params,
        data,
        batch_size_fn=batch_size_fn,
        check_deterministic=False,
        progressbar=state_dict["progressbar"],
        shape=state_dict["shape"],
        seed=state_dict["seed"],
        fisher_type=state_dict["fisher_type"],
        mc_samples=state_dict["mc_samples"],
        kfac_approx=state_dict["kfac_approx"],
        loss_average=state_dict["loss_average"],
        num_per_example_loss_terms=state_dict["num_per_example_loss_terms"],
        separate_weight_and_bias=state_dict["separate_weight_and_bias"],
        num_data=state_dict["num_data"],
    )
    kfac.load_state_dict(state_dict)

    # Potentially call `check_deterministic` after the state dict is loaded
    if check_deterministic:
        old_device = kfac._device
        kfac.to_device(device("cpu"))
        try:
            kfac._check_deterministic()
        finally:
            # Move back to the original device whether or not the check
            # raised; a RuntimeError from the check propagates unchanged.
            kfac.to_device(old_device)

    return kfac
82 changes: 81 additions & 1 deletion test/test_inverse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Contains tests for ``curvlinops/inverse``."""

import os
from math import sqrt
from test.utils import cast_input
from test.utils import cast_input, compare_state_dicts
from typing import Iterable, List, Tuple, Union

import torch
Expand Down Expand Up @@ -654,3 +655,82 @@ def test_KFAC_inverse_damped_torch_matvec(

# Test against _matmat
report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy())


def test_KFAC_inverse_save_and_load_state_dict():
    """Test that KFACInverseLinearOperator can be saved and loaded from state dict."""
    torch.manual_seed(0)
    batch_size, D_in, D_out = 4, 3, 2
    X = torch.rand(batch_size, D_in)
    y = torch.rand(batch_size, D_out)
    model = torch.nn.Linear(D_in, D_out)

    params = list(model.parameters())
    # create and compute KFAC
    kfac = KFACLinearOperator(
        model,
        MSELoss(reduction="sum"),
        params,
        [(X, y)],
        loss_average=None,
    )

    # create inverse KFAC
    inv_kfac = KFACInverseLinearOperator(
        kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False
    )
    _ = inv_kfac @ eye(kfac.shape[1])  # to trigger inverse computation

    # save state dict
    state_dict = inv_kfac.state_dict()
    state_dict_path = "inv_kfac_state_dict.pt"
    torch.save(state_dict, state_dict_path)

    # Everything that needs the file runs inside ``try`` so that the file is
    # removed even if one of the steps fails.
    try:
        # create new inverse KFAC with different linop input and try to load
        # state dict
        wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)])
        inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac)
        with raises(ValueError, match="mismatch"):
            inv_kfac_wrong.load_state_dict(torch.load(state_dict_path))

        # create new inverse KFAC and load state dict
        inv_kfac_new = KFACInverseLinearOperator(kfac)
        inv_kfac_new.load_state_dict(torch.load(state_dict_path))
    finally:
        # clean up
        os.remove(state_dict_path)

    # check that the two inverse KFACs are equal
    compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
    test_vec = torch.rand(inv_kfac.shape[1])
    report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec)


def test_KFAC_inverse_from_state_dict():
    """Test that KFACInverseLinearOperator can be created from state dict."""
    torch.manual_seed(0)
    num_samples, in_features, out_features = 4, 3, 2
    inputs = torch.rand(num_samples, in_features)
    targets = torch.rand(num_samples, out_features)
    net = torch.nn.Linear(in_features, out_features)
    net_params = list(net.parameters())

    # Reference pair: a KFAC operator and its damped inverse.
    kfac = KFACLinearOperator(
        net,
        MSELoss(reduction="sum"),
        net_params,
        [(inputs, targets)],
        loss_average=None,
    )
    inverse_settings = {
        "damping": 1e-2,
        "use_heuristic_damping": True,
        "retry_double_precision": False,
    }
    inv_kfac = KFACInverseLinearOperator(kfac, **inverse_settings)

    # Rebuild a second inverse purely from the reference's state dict.
    inv_kfac_new = KFACInverseLinearOperator.from_state_dict(
        inv_kfac.state_dict(), kfac
    )

    # Both inverses must expose the same state and act identically on a vector.
    compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
    probe = torch.rand(kfac.shape[1])
    report_nonclose(inv_kfac @ probe, inv_kfac_new @ probe)
Loading

0 comments on commit d509cd6

Please sign in to comment.