[FMT] Apply black and isort
f-dangel committed Mar 1, 2024
1 parent 62ca347 commit 525123b
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions docs_src/examples/use_cases/example_custom_module.py
@@ -72,7 +72,7 @@ def forward(self, input):
# how individual gradients are extracted with respect to ``ScaleModule``'s parameter.
#
# The module extension must implement methods named after the parameters passed to the
# constructor, in this case ``weight``. For a module with additional parameters, e.g. a ``bias``, additional methods
# named after these parameters have to be added. For the parameter ``bias``, a method ``bias`` is implemented.
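# As a purely schematic sketch of this naming convention (an illustration added here, not
# the implementation of this example, which follows below), such an extension could look
# like this. The base class is imported further down in this file; only the method names
# and signatures matter for the convention.

from backpack.extensions.module_extension import ModuleExtension


class ScaleModuleExtensionSketch(ModuleExtension):
    """Illustrates the naming convention only: one method per parameter name."""

    def weight(self, ext, module, g_inp, g_out, bpQuantities):
        """Would compute the quantity for the ``weight`` parameter."""
        raise NotImplementedError

    def bias(self, ext, module, g_inp, g_out, bpQuantities):
        """Would only be needed if the module also had a ``bias`` parameter."""
        raise NotImplementedError


# %%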
#
# Here it goes.
@@ -209,7 +209,7 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# Second-order Extension
# ----------------------
# Next, we focus on `BackPACK's second-order extensions <https://docs.backpack.pt/en/master/extensions.html#second-order-extensions>`_.
# They backpropagate additional information and thus require more functionality to be implemented, a more in-depth
# understanding of BackPACK's internals, and an expert understanding of the metric.
#
# Let's make BackPACK support computing the exact diagonal of the Gauss-Newton matrix for ``ScaleModule``.
@@ -219,7 +219,7 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# case for your module you can wrap it in a :py:class:`torch.nn.Module <torch.nn.Module>`.
#
# The second step is to write a module extension that implements how the exact diagonal of the Gauss-Newton matrix is
# computed for ``ScaleModule``.
#
# To do this, we need to understand the following about the extension:
# 1. The GGN is calculated by multiplying the Jacobian (w.r.t. the parameters) with the Hessian of the loss function.
@@ -236,14 +236,14 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# ^^^^^^^^^^^^^^^
# First, the definition of the GGN:
# The GGN is calculated by multiplying the Jacobian (w.r.t. the parameters) with the Hessian of the loss function.
#
# .. math::
# \mathbf{G}(\theta) = (\mathbf{J}_\theta f_\theta(x))^T \; \nabla^2_{f_\theta(x^{(0)})} \ell (f_\theta(x^{(0)}), y) \; (\mathbf{J}_\theta f_\theta(x))
#
# The Jacobian (left & right of RHS) is the matrix of all first-order derivatives of the function (neural network) w.r.t. the parameters.
# The Hessian (center of RHS) is the matrix of all second-order derivatives of the loss function w.r.t. the neural network output.
# The GGN (LHS) is a matrix of dimension :math:`p \times p`, where :math:`p` is the number of parameters. It is calculated
# w.r.t. the parameters of the network. In the implementation, we will have to split the computation for each named
# parameter, e.g. ``weight``, ``bias``, etc.
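# As a small numerical sketch of this definition (an illustration added here, not part of
# the original example): for a toy model ``f(theta) = theta * x`` with mean squared error,
# ``J_toy`` and ``H_toy`` below are the Jacobian and the loss Hessian from the formula
# above, obtained with plain autograd.

import torch
from torch.autograd.functional import hessian, jacobian

x_toy = torch.tensor([1.0, 2.0])
y_toy = torch.tensor([0.5, -1.0])
theta_toy = torch.tensor([0.7])


def f_toy(theta):
    return theta * x_toy


def loss_wrt_output_toy(output):
    return ((output - y_toy) ** 2).mean()


J_toy = jacobian(f_toy, theta_toy)  # shape [2, 1]: Jacobian of the output w.r.t. the parameter
H_toy = hessian(loss_wrt_output_toy, f_toy(theta_toy))  # shape [2, 2]: Hessian w.r.t. the output
G_toy = J_toy.T @ H_toy @ J_toy  # the 1 x 1 GGN for this single parameter

# %%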
#
# If the loss function is convex, which is the case for many losses in ML, the following holds:
@@ -253,7 +253,7 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# \exists S \in \mathbb{R}^{p \times p} \text{ s.t. } SS^T=\nabla^2_{f_\theta(x^{(0)})} \ell (f_\theta(x^{(0)}), y)
#
# There exists a decomposition of the Hessian into a multiplication of :math:`S` with its transpose.
# A corollary of this is that the GGN can be decomposed into a multiplication
# of :math:`V=(\mathbf{J}_\theta f_\theta(x))^T\;S` with its transpose:
#
# .. math::
@@ -264,18 +264,18 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# Calculations by Chain Rule
# ^^^^^^^^^^^^^^^^^^^^^^^^^^
# The Hessian and the required Jacobians are computed during the backward pass of the autograd engine using the chain rule.
# When using ANNs, the autograd engine builds a representation of the network as a composition of "atomic" operations.
# This is called the computation graph. Consider the computation graph:
#
# .. image:: ../../images/comp_graph.jpg
# :width: 75%
# :align: center
#
# Each node in the graph represents a tensor. The arrows represent the flow of information and the computation associated
# with the incoming and outgoing tensors: :math:`f_{\theta^{(k)}}^{(k)}(x^{(k)}) = x^{(k+1)}`. The information is
# computed by the function---i.e. the neural network layer---at the node.
#
# The parameter vector :math:`\theta` contains all parameters of the ANN and is composed of the stacked parameters of
# each layer of the neural network.
#
# .. math::
@@ -288,14 +288,14 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# .. math::
# \mathbf{J}_\theta f_\theta(x) = (\mathbf{J}_{\theta^{(1)}} f_{\theta}(x^{(0)}), \mathbf{J}_{\theta^{(2)}} f_{\theta}(x^{(0)}), \dots, \mathbf{J}_{\theta^{(l)}} f_\theta(x^{(0)}))
#
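# As a brief sketch of this per-parameter splitting (an illustration added here, not part of
# the original example): autograd returns one Jacobian block per parameter when the
# parameters are passed as a tuple.

import torch
from torch.autograd.functional import jacobian

x_split = torch.tensor([2.0, -1.0])


def f_split(w, b):
    return w * x_split + b


J_w, J_b = jacobian(f_split, (torch.tensor([1.5]), torch.tensor([0.3])))
# J_w and J_b are the Jacobian blocks for ``w`` and ``b``, each of shape [2, 1].

# %%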
# Due to the structure of the computation graph and the chain rule, each Jacobian can be computed by multiplying the
# local Jacobians encountered along the corresponding path in the computation graph. For the path of interest:
#
# .. math::
# p^{(k)} = ((\theta^{(k)} \rightarrow x^{(k)}), (x^{(k)} \rightarrow x^{(k+1)}), (x^{(k+1)} \rightarrow x^{(k+2)}),\dots, (x^{(l-1)} \rightarrow x^{(l)}))
#
# The Jacobian of this path is computed by chaining the local Jacobian of each computation:
#
# .. math::
#     \mathbf{J}_{\theta^{(k)}} f_{\theta}(x^{(0)}) = (\mathbf{J}_{x^{(l-1)}} f_\theta(x^{(0)}))\;\dots\;(\mathbf{J}_{x^{(k+1)}} x^{(k+2)})\;(\mathbf{J}_{x^{(k)}} x^{(k+1)})\;(\mathbf{J}_{\theta^{(k)}} x^{(k)})
#
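# As a minimal sketch of this chaining with plain autograd (an illustration added here, not
# part of the original example): ``torch.autograd.grad`` multiplies a vector with the
# transposed local Jacobian, which is exactly the operation that is chained backwards node
# by node.

import torch

x0_toy = torch.tensor([1.0, -2.0])
w1_toy = torch.tensor([0.5], requires_grad=True)
w2_toy = torch.tensor([3.0], requires_grad=True)

x1_toy = w1_toy * x0_toy  # node k
x2_toy = w2_toy * x1_toy  # node k + 1, the network output

v_toy = torch.tensor([1.0, 0.0])  # vector arriving at the output
(v_at_x1,) = torch.autograd.grad(x2_toy, x1_toy, grad_outputs=v_toy, retain_graph=True)
(v_at_w1,) = torch.autograd.grad(x1_toy, w1_toy, grad_outputs=v_at_x1)
# v_at_w1 equals (J_{w1} x2)^T v, obtained by backpropagating v through the graph.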
@@ -305,12 +305,12 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
#     (\mathbf{J}_{\theta^{(k)}} f_{\theta}(x^{(0)}))^T = (\mathbf{J}_{\theta^{(k)}} x^{(k)})^T\;(\mathbf{J}_{x^{(k)}} x^{(k+1)})^T\;(\mathbf{J}_{x^{(k+1)}} x^{(k+2)})^T\;\dots\;(\mathbf{J}_{x^{(l-1)}} f_\theta(x^{(0)}))^T
#
# If we assume that we receive the Jacobian :math:`\mathbf{J}_{x^{(k)}} f_\theta (x^{(0)})` from the previous node in the graph, we can focus the computation on the local Jacobians
# :math:`\mathbf{J}_{x^{(k-1)}} x^{(k)}` and :math:`\mathbf{J}_{\theta^{(k)}} x^{(k)}`. The Jacobian backpropagated by the current node is then given by
#
# .. math::
#     (\mathbf{J}_{x^{(k-1)}} f_\theta (x^{(0)}))^T=(\mathbf{J}_{x^{(k-1)}} x^{(k)})^T\;(\mathbf{J}_{x^{(k)}} f_\theta (x^{(0)}))^T
#
# and the matrix :math:`V` is given by:
#
# .. math::
# V=(\mathbf{J}_{\theta^{(k)}} x^{(k)})^T\;(\mathbf{J}_{x^{(k)}} f_\theta (x^{(0)}))^T
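# As a small numerical sketch of this decomposition (an illustration added here, not part of
# the original example): for a linear model ``f(theta) = X @ theta`` with mean squared
# error, the loss Hessian w.r.t. the output is ``2 / N * I``, so ``S = sqrt(2 / N) * I`` and
# ``V = J^T S``. The exact GGN diagonal is then the row-wise sum of squares of ``V``.

import torch

N_toy, p_toy = 5, 3
X_toy = torch.randn(N_toy, p_toy)  # for f(theta) = X @ theta, the Jacobian w.r.t. theta is X

J_lin = X_toy
H_lin = 2 / N_toy * torch.eye(N_toy)  # Hessian of the MSE loss w.r.t. the model output
S_lin = (2 / N_toy) ** 0.5 * torch.eye(N_toy)  # satisfies S @ S.T == H
V_lin = J_lin.T @ S_lin

assert torch.allclose(J_lin.T @ H_lin @ J_lin, V_lin @ V_lin.T)  # G == V V^T
diag_ggn_manual = (V_lin**2).sum(dim=1)  # diagonal of G, one entry per parameter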
@@ -321,21 +321,21 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
# ^^^^^^^^^^^^^^
#

# %%
# First, some additional imports.
from backpack.extensions.module_extension import ModuleExtension
from backpack.extensions.secondorder.diag_ggn import DiagGGNExact
from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import vector_to_parameter_list
from torch.nn.utils.convert_parameters import parameters_to_vector



# %%
# The module extension must implement methods named after the parameters that are passed to the
# constructor. This is similar to the first-order extension. In addition, it is necessary to implement the ``backpropagate`` function. This
# function is called by BackPACK during the backward pass and used to feed the Jacobians to later computations.


class ScaleModuleDiagGGNExact(ModuleExtension):
"""Extract diagonal of the Gauss-Newton matrix for ``ScaleModule``."""

@@ -445,15 +445,15 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
return result


# %%
# After we have implemented the module extension, we need to register the mapping between the layer (``ScaleModule``) and the
# layer extension (``ScaleModuleDiagGGNExact``) in an instance of :py:class:`DiagGGNExact <backpack.extensions.DiagGGNExact>`.

extension = DiagGGNExact()
extension.set_module_extension(ScaleModule, ScaleModuleDiagGGNExact())
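
# %%
# A schematic usage sketch (an illustration added here; it assumes ``my_module``,
# ``lossfunc``, ``inputs`` and ``targets`` from the surrounding example and that module and
# loss have already been extended with :py:func:`backpack.extend <backpack.extend>`): inside
# the ``with backpack(extension)`` context, BackPACK invokes ``ScaleModuleDiagGGNExact``
# during ``backward()`` and stores the result in the ``diag_ggn_exact`` attribute of each
# parameter.

from backpack import backpack

with backpack(extension):
    loss = lossfunc(my_module(inputs), targets)
    loss.backward()

for param in my_module.parameters():
    print(param.diag_ggn_exact.shape)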


# %%
# Testing the extension
# ^^^^^^^^^^^^^^^^^^^^^
# Here, we verify the custom module extension on a small net with random inputs, as we did before.
@@ -466,10 +466,11 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
params = list(my_module.parameters())

ggn_dim = sum(p.numel() for p in params)
diag_ggn_flat = torch.zeros(
    batch_size * ggn_dim, device=inputs.device, dtype=inputs.dtype
)
# looping explicitly over the batch dimension
for b in range(batch_size):
    outputs = my_module(inputs[b])
loss = lossfunc(outputs, targets[b])

@@ -513,4 +513,3 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
"exact GGN diagonal does not match:"
+ f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}"
)
