-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
22 changed files
with
678 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""Defines base class and extensions for computing the GGN/Fisher matrix square root.""" | ||
|
||
from torch.nn import ( | ||
ELU, | ||
SELU, | ||
AvgPool1d, | ||
AvgPool2d, | ||
AvgPool3d, | ||
Conv1d, | ||
Conv2d, | ||
Conv3d, | ||
ConvTranspose1d, | ||
ConvTranspose2d, | ||
ConvTranspose3d, | ||
CrossEntropyLoss, | ||
Dropout, | ||
Flatten, | ||
LeakyReLU, | ||
Linear, | ||
LogSigmoid, | ||
MaxPool1d, | ||
MaxPool2d, | ||
MaxPool3d, | ||
MSELoss, | ||
ReLU, | ||
Sigmoid, | ||
Tanh, | ||
ZeroPad2d, | ||
) | ||
|
||
from backpack.extensions.backprop_extension import BackpropExtension | ||
from backpack.extensions.secondorder.hbp import LossHessianStrategy | ||
from backpack.extensions.secondorder.sqrt_ggn import ( | ||
activations, | ||
convnd, | ||
convtransposend, | ||
dropout, | ||
flatten, | ||
linear, | ||
losses, | ||
padding, | ||
pooling, | ||
) | ||
|
||
|
||
class SqrtGGN(BackpropExtension):
    """Common parent of the extensions computing the GGN/Fisher matrix square root.

    Remembers the loss Hessian strategy chosen by the concrete extension and
    registers one module extension for every supported ``torch.nn`` layer.
    """

    def __init__(self, loss_hessian_strategy: str, savefield: str):
        """Remember the loss Hessian strategy and register the module extensions.

        Args:
            loss_hessian_strategy: Which approximation is used for the backpropagated
                loss Hessian. Must be ``'exact'`` or ``'sampling'``.
            savefield: Attribute under which the quantity is saved in a parameter.
        """
        # Assigned before the parent constructor is invoked.
        self.loss_hessian_strategy = loss_hessian_strategy

        # Maps each supported layer type to the extension that handles it.
        module_extensions = {
            # losses
            MSELoss: losses.SqrtGGNMSELoss(),
            CrossEntropyLoss: losses.SqrtGGNCrossEntropyLoss(),
            # fully-connected
            Linear: linear.SqrtGGNLinear(),
            # pooling and padding
            MaxPool1d: pooling.SqrtGGNMaxPool1d(),
            MaxPool2d: pooling.SqrtGGNMaxPool2d(),
            AvgPool1d: pooling.SqrtGGNAvgPool1d(),
            MaxPool3d: pooling.SqrtGGNMaxPool3d(),
            AvgPool2d: pooling.SqrtGGNAvgPool2d(),
            AvgPool3d: pooling.SqrtGGNAvgPool3d(),
            ZeroPad2d: padding.SqrtGGNZeroPad2d(),
            # convolutions and transpose convolutions
            Conv1d: convnd.SqrtGGNConv1d(),
            Conv2d: convnd.SqrtGGNConv2d(),
            Conv3d: convnd.SqrtGGNConv3d(),
            ConvTranspose1d: convtransposend.SqrtGGNConvTranspose1d(),
            ConvTranspose2d: convtransposend.SqrtGGNConvTranspose2d(),
            ConvTranspose3d: convtransposend.SqrtGGNConvTranspose3d(),
            # parameter-free layers
            Dropout: dropout.SqrtGGNDropout(),
            Flatten: flatten.SqrtGGNFlatten(),
            # activations
            ReLU: activations.SqrtGGNReLU(),
            Sigmoid: activations.SqrtGGNSigmoid(),
            Tanh: activations.SqrtGGNTanh(),
            LeakyReLU: activations.SqrtGGNLeakyReLU(),
            LogSigmoid: activations.SqrtGGNLogSigmoid(),
            ELU: activations.SqrtGGNELU(),
            SELU: activations.SqrtGGNSELU(),
        }

        super().__init__(
            savefield=savefield,
            fail_mode="ERROR",
            module_exts=module_extensions,
        )

    def get_loss_hessian_strategy(self) -> str:
        """Tell which strategy represents the backpropagated loss Hessian.

        Returns:
            Loss Hessian strategy passed to the constructor.
        """
        return self.loss_hessian_strategy
|
||
|
||
class SqrtGGNExact(SqrtGGN):
    """Exact matrix square root of the generalized Gauss-Newton/Fisher.

    Uses the exact Hessian of the loss w.r.t. the model output.

    Stores the output in :code:`sqrt_ggn_exact`, has shape ``[C, N, param.shape]``,
    where ``C`` is the model output dimension (number of classes for classification
    problems) and ``N`` is the batch size.

    For a faster but less precise alternative, see
    :py:class:`backpack.extensions.SqrtGGNMC`.

    .. note::

        (Relation to the GGN/Fisher) For each parameter, ``param.sqrt_ggn_exact``
        can be viewed as a ``[C * N, param.numel()]`` matrix. Concatenating this
        matrix over all parameters results in a matrix ``Vᵀ``, which
        is the GGN/Fisher's matrix square root, i.e. ``G = V Vᵀ``.
    """

    def __init__(self):
        """Select the exact loss Hessian and save results in ``sqrt_ggn_exact``."""
        super().__init__(
            loss_hessian_strategy=LossHessianStrategy.EXACT,
            savefield="sqrt_ggn_exact",
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Contains extensions for activation layers used by ``SqrtGGN{Exact, MC}``.""" | ||
from backpack.core.derivatives.elu import ELUDerivatives | ||
from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives | ||
from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives | ||
from backpack.core.derivatives.relu import ReLUDerivatives | ||
from backpack.core.derivatives.selu import SELUDerivatives | ||
from backpack.core.derivatives.sigmoid import SigmoidDerivatives | ||
from backpack.core.derivatives.tanh import TanhDerivatives | ||
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule | ||
|
||
|
||
class SqrtGGNReLU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ReLU``."""

    def __init__(self):
        """Hand the ``torch.nn.ReLU`` derivatives to the base module extension."""
        super().__init__(derivatives=ReLUDerivatives())
|
||
|
||
class SqrtGGNSigmoid(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Sigmoid``."""

    def __init__(self):
        """Hand the ``torch.nn.Sigmoid`` derivatives to the base module extension."""
        super().__init__(derivatives=SigmoidDerivatives())
|
||
|
||
class SqrtGGNTanh(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Tanh``."""

    def __init__(self):
        """Hand the ``torch.nn.Tanh`` derivatives to the base module extension."""
        super().__init__(derivatives=TanhDerivatives())
|
||
|
||
class SqrtGGNELU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ELU``."""

    def __init__(self):
        """Hand the ``torch.nn.ELU`` derivatives to the base module extension."""
        super().__init__(derivatives=ELUDerivatives())
|
||
|
||
class SqrtGGNSELU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.SELU``."""

    def __init__(self):
        """Hand the ``torch.nn.SELU`` derivatives to the base module extension."""
        super().__init__(derivatives=SELUDerivatives())
|
||
|
||
class SqrtGGNLeakyReLU(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.LeakyReLU``."""

    def __init__(self):
        """Hand the ``torch.nn.LeakyReLU`` derivatives to the base module extension."""
        super().__init__(derivatives=LeakyReLUDerivatives())
|
||
|
||
class SqrtGGNLogSigmoid(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.LogSigmoid``."""

    def __init__(self):
        """Hand the ``torch.nn.LogSigmoid`` derivatives to the base module extension."""
        super().__init__(derivatives=LogSigmoidDerivatives())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Contains base class for ``SqrtGGN{Exact, MC}`` module extensions.""" | ||
from typing import Any, Callable, List, Tuple, Union | ||
|
||
from torch import Tensor | ||
from torch.nn import Module | ||
|
||
from backpack.core.derivatives.basederivatives import ( | ||
BaseDerivatives, | ||
BaseParameterDerivatives, | ||
) | ||
from backpack.extensions.mat_to_mat_jac_base import MatToJacMat | ||
|
||
|
||
class SqrtGGNBaseModule(MatToJacMat):
    """Base module extension for ``SqrtGGN{Exact, MC}``."""

    def __init__(
        self,
        derivatives: Union[BaseParameterDerivatives, BaseDerivatives],
        params: Union[List[str], None] = None,
    ):
        """Store parameter names and derivatives.

        Sets up methods that extract the GGN/Fisher matrix square root for the
        passed parameters, unless these methods are overwritten by a child class.

        Args:
            derivatives: derivatives object.
            params: List of parameter names. Defaults to None, in which case no
                extraction methods are created.
        """
        if params is not None:
            for param_str in params:
                # A child class may already define its own method for this
                # parameter; leave that one in place.
                if hasattr(self, param_str):
                    continue
                setattr(self, param_str, self._make_param_function(param_str))

        super().__init__(derivatives, params=params)

    # TODO Replace Any with Union[SqrtGGNExact, SqrtGGNMC]
    # WAITING Deprecation of python3.6 (cyclic imports caused by annotations)
    def _make_param_function(
        self, param_str: str
    ) -> Callable[[Any, Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor]:
        """Create a function that computes the GGN/Fisher square root for a parameter.

        Args:
            param_str: name of parameter

        Returns:
            Function that computes the GGN/Fisher matrix square root.
        """
        # The method name depends only on ``param_str``; build it once.
        jac_t_mat_prod_name = f"{param_str}_jac_t_mat_prod"

        def param_function(
            ext: Any,
            module: Module,
            g_inp: Tuple[Tensor],
            g_out: Tuple[Tensor],
            backproped: Tensor,
        ) -> Tensor:
            """Calculate the GGN/Fisher matrix square root with the derivatives object.

            Args:
                ext: extension that is used
                module: module that performed forward pass
                g_inp: input gradient tensors
                g_out: output gradient tensors
                backproped: Backpropagated quantities from second-order extension.

            Returns:
                GGN/Fisher matrix square root.
            """
            # ``self.derivatives`` is resolved at call time, not when the closure
            # is created: this class never assigns it, so it is presumably set by
            # the parent constructor, which runs after the closures are built.
            jac_t_mat_prod = getattr(self.derivatives, jac_t_mat_prod_name)
            return jac_t_mat_prod(module, g_inp, g_out, backproped, sum_batch=False)

        return param_function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``.""" | ||
from backpack.core.derivatives.conv1d import Conv1DDerivatives | ||
from backpack.core.derivatives.conv2d import Conv2DDerivatives | ||
from backpack.core.derivatives.conv3d import Conv3DDerivatives | ||
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule | ||
|
||
|
||
class SqrtGGNConv1d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Conv1d``."""

    def __init__(self):
        """Register ``torch.nn.Conv1d`` derivatives for ``bias`` and ``weight``."""
        super().__init__(
            derivatives=Conv1DDerivatives(),
            params=["bias", "weight"],
        )
|
||
|
||
class SqrtGGNConv2d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Conv2d``."""

    def __init__(self):
        """Register ``torch.nn.Conv2d`` derivatives for ``bias`` and ``weight``."""
        super().__init__(
            derivatives=Conv2DDerivatives(),
            params=["bias", "weight"],
        )
|
||
|
||
class SqrtGGNConv3d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Conv3d``."""

    def __init__(self):
        """Register ``torch.nn.Conv3d`` derivatives for ``bias`` and ``weight``."""
        super().__init__(
            derivatives=Conv3DDerivatives(),
            params=["bias", "weight"],
        )
29 changes: 29 additions & 0 deletions
29
backpack/extensions/secondorder/sqrt_ggn/convtransposend.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``.""" | ||
from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives | ||
from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives | ||
from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives | ||
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule | ||
|
||
|
||
class SqrtGGNConvTranspose1d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ConvTranspose1d``."""

    def __init__(self):
        """Register ``torch.nn.ConvTranspose1d`` derivatives for ``bias``/``weight``."""
        super().__init__(
            derivatives=ConvTranspose1DDerivatives(),
            params=["bias", "weight"],
        )
|
||
|
||
class SqrtGGNConvTranspose2d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ConvTranspose2d``."""

    def __init__(self):
        """Register ``torch.nn.ConvTranspose2d`` derivatives for ``bias``/``weight``."""
        super().__init__(
            derivatives=ConvTranspose2DDerivatives(),
            params=["bias", "weight"],
        )
|
||
|
||
class SqrtGGNConvTranspose3d(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.ConvTranspose3d``."""

    def __init__(self):
        """Register ``torch.nn.ConvTranspose3d`` derivatives for ``bias``/``weight``."""
        super().__init__(
            derivatives=ConvTranspose3DDerivatives(),
            params=["bias", "weight"],
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``.""" | ||
from backpack.core.derivatives.dropout import DropoutDerivatives | ||
from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule | ||
|
||
|
||
class SqrtGGNDropout(SqrtGGNBaseModule):
    """Module extension of ``SqrtGGN{Exact, MC}`` for ``torch.nn.Dropout``."""

    def __init__(self):
        """Hand the ``torch.nn.Dropout`` derivatives to the base module extension."""
        super().__init__(derivatives=DropoutDerivatives())
Oops, something went wrong.