Skip to content

Commit

Permalink
Merge 650cb18 into 6f1f5e3
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jun 21, 2021
2 parents 6f1f5e3 + 650cb18 commit 0a27ebb
Show file tree
Hide file tree
Showing 22 changed files with 678 additions and 16 deletions.
2 changes: 2 additions & 0 deletions backpack/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DiagGGNExact,
DiagGGNMC,
DiagHessian,
SqrtGGNExact,
)

__all__ = [
Expand All @@ -33,4 +34,5 @@
"BatchDiagGGNMC",
"DiagHessian",
"BatchDiagHessian",
"SqrtGGNExact",
]
16 changes: 13 additions & 3 deletions backpack/extensions/secondorder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,20 @@
:func:`KFRA <backpack.extensions.KFRA>`,
:func:`KFLR <backpack.extensions.KFLR>`.
- The diagonal of the Hessian :func:`DiagHessian <backpack.extensions.DiagHessian>`
- The symmetric (square root) factorization of the GGN/Fisher information,
using exact computation
(:func:`SqrtGGNExact <backpack.extensions.SqrtGGNExact>`)
"""

from .diag_ggn import BatchDiagGGNExact, BatchDiagGGNMC, DiagGGNExact, DiagGGNMC
from .diag_hessian import BatchDiagHessian, DiagHessian
from .hbp import HBP, KFAC, KFLR, KFRA
from backpack.extensions.secondorder.diag_ggn import (
BatchDiagGGNExact,
BatchDiagGGNMC,
DiagGGNExact,
DiagGGNMC,
)
from backpack.extensions.secondorder.diag_hessian import BatchDiagHessian, DiagHessian
from backpack.extensions.secondorder.hbp import HBP, KFAC, KFLR, KFRA
from backpack.extensions.secondorder.sqrt_ggn import SqrtGGNExact

__all__ = [
"DiagGGNExact",
Expand All @@ -34,4 +43,5 @@
"KFLR",
"KFRA",
"HBP",
"SqrtGGNExact",
]
121 changes: 121 additions & 0 deletions backpack/extensions/secondorder/sqrt_ggn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Defines base class and extensions for computing the GGN/Fisher matrix square root."""

from torch.nn import (
ELU,
SELU,
AvgPool1d,
AvgPool2d,
AvgPool3d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
CrossEntropyLoss,
Dropout,
Flatten,
LeakyReLU,
Linear,
LogSigmoid,
MaxPool1d,
MaxPool2d,
MaxPool3d,
MSELoss,
ReLU,
Sigmoid,
Tanh,
ZeroPad2d,
)

from backpack.extensions.backprop_extension import BackpropExtension
from backpack.extensions.secondorder.hbp import LossHessianStrategy
from backpack.extensions.secondorder.sqrt_ggn import (
activations,
convnd,
convtransposend,
dropout,
flatten,
linear,
losses,
padding,
pooling,
)


class SqrtGGN(BackpropExtension):
    """Common base of the extensions computing a GGN/Fisher matrix square root.

    Registers one ``SqrtGGN*`` module extension for every supported ``torch.nn``
    layer and remembers which loss Hessian approximation is backpropagated.
    """

    def __init__(self, loss_hessian_strategy: str, savefield: str):
        """Remember the loss Hessian strategy and register the module extensions.

        Args:
            loss_hessian_strategy: Approximation used for the backpropagated
                loss Hessian. Must be ``'exact'`` or ``'sampling'``.
            savefield: Name of the parameter attribute that receives the result.
        """
        self.loss_hessian_strategy = loss_hessian_strategy

        # Layer class -> module extension instance responsible for that layer.
        layer_extensions = {
            MSELoss: losses.SqrtGGNMSELoss(),
            CrossEntropyLoss: losses.SqrtGGNCrossEntropyLoss(),
            Linear: linear.SqrtGGNLinear(),
            MaxPool1d: pooling.SqrtGGNMaxPool1d(),
            MaxPool2d: pooling.SqrtGGNMaxPool2d(),
            AvgPool1d: pooling.SqrtGGNAvgPool1d(),
            MaxPool3d: pooling.SqrtGGNMaxPool3d(),
            AvgPool2d: pooling.SqrtGGNAvgPool2d(),
            AvgPool3d: pooling.SqrtGGNAvgPool3d(),
            ZeroPad2d: padding.SqrtGGNZeroPad2d(),
            Conv1d: convnd.SqrtGGNConv1d(),
            Conv2d: convnd.SqrtGGNConv2d(),
            Conv3d: convnd.SqrtGGNConv3d(),
            ConvTranspose1d: convtransposend.SqrtGGNConvTranspose1d(),
            ConvTranspose2d: convtransposend.SqrtGGNConvTranspose2d(),
            ConvTranspose3d: convtransposend.SqrtGGNConvTranspose3d(),
            Dropout: dropout.SqrtGGNDropout(),
            Flatten: flatten.SqrtGGNFlatten(),
            ReLU: activations.SqrtGGNReLU(),
            Sigmoid: activations.SqrtGGNSigmoid(),
            Tanh: activations.SqrtGGNTanh(),
            LeakyReLU: activations.SqrtGGNLeakyReLU(),
            LogSigmoid: activations.SqrtGGNLogSigmoid(),
            ELU: activations.SqrtGGNELU(),
            SELU: activations.SqrtGGNSELU(),
        }

        super().__init__(
            savefield=savefield,
            fail_mode="ERROR",
            module_exts=layer_extensions,
        )

    def get_loss_hessian_strategy(self) -> str:
        """Tell which strategy represents the backpropagated loss Hessian.

        Returns:
            The loss Hessian strategy passed at construction.
        """
        return self.loss_hessian_strategy


class SqrtGGNExact(SqrtGGN):
    """Exact matrix square root of the generalized Gauss-Newton/Fisher.

    The loss Hessian w.r.t. the model output is used exactly.

    The result is stored in :code:`sqrt_ggn_exact` and has shape
    ``[C, N, *param.shape]``, with ``C`` the model output dimension (number of
    classes for classification problems) and ``N`` the batch size.

    For a faster but less precise alternative, see
    :py:meth:`backpack.extensions.SqrtGGNMC`.

    .. note::

        (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_exact``
        can be viewed as a ``[C * N, param.numel()]`` matrix. Concatenating this
        matrix over all parameters results in a matrix ``Vᵀ``, which
        is the GGN/Fisher's matrix square root, i.e. ``G = V Vᵀ``.
    """

    def __init__(self):
        """Select the exact loss Hessian and save under ``sqrt_ggn_exact``."""
        super().__init__(
            loss_hessian_strategy=LossHessianStrategy.EXACT,
            savefield="sqrt_ggn_exact",
        )
65 changes: 65 additions & 0 deletions backpack/extensions/secondorder/sqrt_ggn/activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Contains extensions for activation layers used by ``SqrtGGN{Exact, MC}``."""
from backpack.core.derivatives.elu import ELUDerivatives
from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives
from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives
from backpack.core.derivatives.relu import ReLUDerivatives
from backpack.core.derivatives.selu import SELUDerivatives
from backpack.core.derivatives.sigmoid import SigmoidDerivatives
from backpack.core.derivatives.tanh import TanhDerivatives
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule


class SqrtGGNReLU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ReLU``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.ReLU`` derivatives."""
        super().__init__(derivatives=ReLUDerivatives())


class SqrtGGNSigmoid(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Sigmoid``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.Sigmoid`` derivatives."""
        super().__init__(derivatives=SigmoidDerivatives())


class SqrtGGNTanh(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Tanh``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.Tanh`` derivatives."""
        super().__init__(derivatives=TanhDerivatives())


class SqrtGGNELU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ELU``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.ELU`` derivatives."""
        super().__init__(derivatives=ELUDerivatives())


class SqrtGGNSELU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.SELU``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.SELU`` derivatives."""
        super().__init__(derivatives=SELUDerivatives())


class SqrtGGNLeakyReLU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.LeakyReLU``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.LeakyReLU`` derivatives."""
        super().__init__(derivatives=LeakyReLUDerivatives())


class SqrtGGNLogSigmoid(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.LogSigmoid``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.LogSigmoid`` derivatives."""
        super().__init__(derivatives=LogSigmoidDerivatives())
76 changes: 76 additions & 0 deletions backpack/extensions/secondorder/sqrt_ggn/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Contains base class for ``SqrtGGN{Exact, MC}`` module extensions."""
from typing import Any, Callable, List, Tuple, Union

from torch import Tensor
from torch.nn import Module

from backpack.core.derivatives.basederivatives import (
BaseDerivatives,
BaseParameterDerivatives,
)
from backpack.extensions.mat_to_mat_jac_base import MatToJacMat


class SqrtGGNBaseModule(MatToJacMat):
    """Base module extension for ``SqrtGGN{Exact, MC}``."""

    def __init__(
        self,
        derivatives: Union[BaseParameterDerivatives, BaseDerivatives],
        params: Union[List[str], None] = None,
    ):
        """Store parameter names and derivatives.

        Sets up methods that extract the GGN/Fisher matrix square root for the
        passed parameters, unless these methods are overwritten by a child class.

        Args:
            derivatives: derivatives object.
            params: List of parameter names. Defaults to None, in which case no
                parameter methods are created.
        """
        if params is not None:
            for param_str in params:
                # A child class may already define a method named after the
                # parameter; only fill in the ones that are missing.
                if not hasattr(self, param_str):
                    setattr(self, param_str, self._make_param_function(param_str))

        # The parameter methods are attached before the parent initializer runs.
        super().__init__(derivatives, params=params)

    # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC]
    # WAITING Deprecation of python3.6 (cyclic imports caused by annotations)
    def _make_param_function(
        self, param_str: str
    ) -> Callable[[Any, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor]:
        """Create a function that computes the GGN/Fisher square root for a parameter.

        Args:
            param_str: name of parameter

        Returns:
            Function that computes the GGN/Fisher matrix square root.
        """
        # Name of the derivatives method looked up when the function is called.
        jac_t_method_name = f"{param_str}_jac_t_mat_prod"

        # ``ext`` is annotated ``Any`` for the cyclic-import reason in the TODO
        # above ``_make_param_function``.
        def param_function(
            ext: Any,
            module: Module,
            g_inp: Tuple[Tensor],
            g_out: Tuple[Tensor],
            backproped: Tensor,
        ) -> Tensor:
            """Calculate the GGN/Fisher matrix square root with the derivatives object.

            Args:
                ext: extension that is used
                module: module that performed forward pass
                g_inp: input gradient tensors
                g_out: output gradient tensors
                backproped: Backpropagated quantities from second-order extension.

            Returns:
                GGN/Fisher matrix square root.
            """
            # ``self.derivatives`` is resolved at call time, not at creation time.
            # sum_batch=False: presumably retains the batch axis — confirm
            # against the derivatives API.
            jac_t_mat_prod = getattr(self.derivatives, jac_t_method_name)
            return jac_t_mat_prod(module, g_inp, g_out, backproped, sum_batch=False)

        return param_function
29 changes: 29 additions & 0 deletions backpack/extensions/secondorder/sqrt_ggn/convnd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``."""
from backpack.core.derivatives.conv1d import Conv1DDerivatives
from backpack.core.derivatives.conv2d import Conv2DDerivatives
from backpack.core.derivatives.conv3d import Conv3DDerivatives
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule


class SqrtGGNConv1d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Conv1d``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.Conv1d`` derivatives."""
        super().__init__(
            derivatives=Conv1DDerivatives(),
            params=["bias", "weight"],
        )


class SqrtGGNConv2d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Conv2d``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.Conv2d`` derivatives."""
        super().__init__(
            derivatives=Conv2DDerivatives(),
            params=["bias", "weight"],
        )


class SqrtGGNConv3d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Conv3d``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.Conv3d`` derivatives."""
        super().__init__(
            derivatives=Conv3DDerivatives(),
            params=["bias", "weight"],
        )
29 changes: 29 additions & 0 deletions backpack/extensions/secondorder/sqrt_ggn/convtransposend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``."""
from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives
from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives
from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule


class SqrtGGNConvTranspose1d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ConvTranspose1d``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.ConvTranspose1d`` derivatives."""
        super().__init__(
            derivatives=ConvTranspose1DDerivatives(),
            params=["bias", "weight"],
        )


class SqrtGGNConvTranspose2d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ConvTranspose2d``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.ConvTranspose2d`` derivatives."""
        super().__init__(
            derivatives=ConvTranspose2DDerivatives(),
            params=["bias", "weight"],
        )


class SqrtGGNConvTranspose3d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ConvTranspose3d``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.ConvTranspose3d`` derivatives."""
        super().__init__(
            derivatives=ConvTranspose3DDerivatives(),
            params=["bias", "weight"],
        )
11 changes: 11 additions & 0 deletions backpack/extensions/secondorder/sqrt_ggn/dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``."""
from backpack.core.derivatives.dropout import DropoutDerivatives
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule


class SqrtGGNDropout(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Dropout``."""

    def __init__(self):
        """Set up the extension with the ``torch.nn.Dropout`` derivatives."""
        super().__init__(derivatives=DropoutDerivatives())

0 comments on commit 0a27ebb

Please sign in to comment.