[TEST] Reduce run time (#199)
- make automated settings' last linear layers non-differentiable to save time
- update the new test suite, which assumed that all model parameters are differentiable
- update `ggn_vector_product` to work with models that contain non-differentiable
  parameters, and fully document `ggnvp.py`

---

* [DOC] Fully-document GGNVP and ignore non-differentiable parameters

* [TEST] Make automated settings' last linear layers non-differentiable

The final layers can have many parameters, which makes the computation of
second-order quantities expensive. Disabling their `requires_grad`
speeds up the tests (a short sketch follows these notes).

* [TEST] Check non-differentiable parameters while collecting results

* [DOC] Fully document automated test setting helpers
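
For reference, the speed-up mechanism is plain parameter freezing. A minimal
sketch of what the new `set_requires_grad` helper does to a test network's
final linear layer (illustrative, mirroring the diff below):

```
from torch.nn import Linear

# Final classification layer of the automated test CNNs (72 flattened features).
linear = Linear(72, 5)

# Equivalent to set_requires_grad(linear, False): the layer is excluded from
# autodiff, so second-order extensions have no block to compute for it.
for p in linear.parameters():
    p.requires_grad = False
```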
f-dangel committed Jul 7, 2021
1 parent cbee344 commit f329651
Showing 7 changed files with 213 additions and 167 deletions.
80 changes: 47 additions & 33 deletions backpack/hessianfree/ggnvp.py
@@ -1,42 +1,56 @@
-from .hvp import hessian_vector_product
-from .lop import L_op
-from .rop import R_op
+"""Autodiff-only matrix-free multiplication by the generalized Gauss-Newton/Fisher."""
+from typing import List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from backpack.hessianfree.hvp import hessian_vector_product
+from backpack.hessianfree.lop import L_op
+from backpack.hessianfree.rop import R_op
 
 
-def ggn_vector_product(loss, output, model, v):
-    """
-    Multiplies the vector `v` with the Generalized Gauss-Newton,
-    `ggn_v = J.T @ H @ J @ v`
-
-    where `J` is the Jacobian of `output` w.r.t. `model.parameters()`
-    and `H` is the Hessian of `loss` w.r.t. `output`.
-
-    Example usage:
-    ```
-    X, Y = data()
-    model = torch.nn.Linear(784, 10)
-    lossfunc = torch.nn.CrossEntropyLoss()
-
-    output = model(X)
-    loss = lossfunc(output, Y)
-
-    v = list([torch.randn_like(p) for p in model.parameters])
-
-    GGNv = ggn_vector_product(loss, output, model, v)
-    ```
-
-    Parameters:
-    -----------
-        loss: torch.Tensor
-        output: torch.Tensor
-        model: torch.nn.Module
-        v: [torch.Tensor]
-            List of tensors matching the sizes of model.parameters()
+def ggn_vector_product(
+    loss: Tensor, output: Tensor, model: Module, v: List[Tensor]
+) -> Tuple[Tensor]:
+    """Multiply a vector with the generalized Gauss-Newton/Fisher.
+
+    Note:
+        ``G v = J.T @ H @ J @ v`` where ``J`` is the Jacobian of ``output`` w.r.t.
+        ``model``'s trainable parameters and ``H`` is the Hessian of ``loss`` w.r.t.
+        ``output``. ``v`` is the flattened and concatenated version of the passed
+        list of vectors.
+
+    Args:
+        loss: Scalar tensor that represents the loss.
+        output: Model output.
+        model: The model.
+        v: Vector specified as list of tensors matching the trainable parameters.
+
+    Returns:
+        GGN-vector product in list format, i.e. as list that matches the sizes
+        of trainable model parameters.
     """
-    return ggn_vector_product_from_plist(loss, output, list(model.parameters()), v)
+    return ggn_vector_product_from_plist(
+        loss, output, [p for p in model.parameters() if p.requires_grad], v
+    )
 
 
-def ggn_vector_product_from_plist(loss, output, plist, v):
+def ggn_vector_product_from_plist(
+    loss: Tensor, output: Tensor, plist: List[Parameter], v: List[Tensor]
+) -> Tuple[Tensor]:
+    """Multiply a vector with a sub-block of the generalized Gauss-Newton/Fisher.
+
+    Args:
+        loss: Scalar tensor that represents the loss.
+        output: Model output.
+        plist: List of trainable parameters whose GGN block is used for multiplication.
+        v: Vector specified as list of tensors matching the sizes of ``plist``.
+
+    Returns:
+        GGN-vector product in list format, i.e. as list that matches the sizes of
+        ``plist``.
+    """
     Jv = R_op(output, plist, v)
     HJv = hessian_vector_product(loss, output, Jv)
     JTHJv = L_op(output, plist, HJv)
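
To see the updated semantics in action, a minimal, illustrative script (toy
model and data, not part of the commit): `ggn_vector_product` now builds the
GGN block from trainable parameters only, so the vector `v` must skip the
frozen layer.

```
import torch
from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential

from backpack.hessianfree.ggnvp import ggn_vector_product

# Toy classifier whose final layer is frozen, as in the updated test settings.
model = Sequential(Linear(10, 8), ReLU(), Linear(8, 5))
for p in model[2].parameters():
    p.requires_grad = False

X, y = torch.rand(4, 10), torch.randint(5, (4,))
output = model(X)
loss = CrossEntropyLoss()(output, y)

# v matches the trainable parameters only (weight and bias of the first layer).
trainable = [p for p in model.parameters() if p.requires_grad]
v = [torch.randn_like(p) for p in trainable]

GGNv = ggn_vector_product(loss, output, model, v)
assert len(GGNv) == len(trainable)
```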
3 changes: 3 additions & 0 deletions fully_documented.txt
@@ -39,10 +39,13 @@ backpack/extensions/secondorder/diag_hessian/conv2d.py
 backpack/extensions/secondorder/diag_hessian/conv3d.py
 backpack/extensions/secondorder/sqrt_ggn/
 
+backpack/hessianfree/ggnvp.py
+
 backpack/utils/linear.py
 backpack/utils/subsampling.py
 backpack/utils/__init__.py
 
+test/extensions/automated_settings.py
 test/extensions/problem.py
 test/extensions/test_backprop_extension.py
 test/extensions/firstorder/firstorder_settings.py
179 changes: 98 additions & 81 deletions test/extensions/automated_settings.py
@@ -1,126 +1,143 @@
"""Contains helpers to create CNN test cases."""
from test.core.derivatives.utils import classification_targets
from typing import Any, Tuple, Type

import torch
from torch import Tensor, rand
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, Module, ReLU, Sequential

###
# Helpers
###

def set_requires_grad(model: Module, new_requires_grad: bool) -> None:
"""Set the ``requires_grad`` attribute of the model parameters.
def make_simple_act_setting(act_cls, bias):
Args:
model: Network or layer.
new_requires_grad: New value for ``requires_grad``.
"""
input: Activation function & Bias setting
return: simple CNN Network
for p in model.parameters():
p.requires_grad = new_requires_grad

This function is used to automatically create a
simple CNN Network consisting of CNN & Linear layer
for different activation functions.
It is used to test `test.extensions`.

def make_simple_act_setting(act_cls: Type[Module], bias: bool) -> dict:
"""Create a simple CNN with activation as test case dictionary.
Make parameters of final linear layer non-differentiable to save run time.
Args:
act_cls: Class of the activation function.
bias: Use bias in the convolution.
Returns:
Dictionary representation of the simple CNN test case.
"""

def make_simple_cnn(act_cls, bias):
return torch.nn.Sequential(
torch.nn.Conv2d(3, 2, 2, bias=bias),
act_cls(),
torch.nn.Flatten(),
torch.nn.Linear(72, 5),
)
def _make_simple_cnn(act_cls: Type[Module], bias: bool) -> Sequential:
linear = Linear(72, 5)
set_requires_grad(linear, False)

return Sequential(Conv2d(3, 2, 2, bias=bias), act_cls(), Flatten(), linear)

dict_setting = {
"input_fn": lambda: torch.rand(3, 3, 7, 7),
"module_fn": lambda: make_simple_cnn(act_cls, bias),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(),
"input_fn": lambda: rand(3, 3, 7, 7),
"module_fn": lambda: _make_simple_cnn(act_cls, bias),
"loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 5),
"id_prefix": "automated-simple-cnn-act",
}

return dict_setting


def make_simple_cnn_setting(input_size, conv_class, conv_params):
"""
input_size: tuple of input size of (N*C*Image Size)
conv_class: convolutional class
conv_params: configurations for convolutional class
return: simple CNN Network
This function is used to automatically create a
simple CNN Network consisting of CNN & Linear layer
for different convolutional layers.
It is used to test `test.extensions`.
def make_simple_cnn_setting(
input_size: Tuple[int], conv_cls: Type[Module], conv_params: Tuple[Any]
) -> dict:
"""Create ReLU CNN with convolution hyperparameters as test case dictionary.
Make parameters of final linear layer non-differentiable to save run time.
Args:
input_size: Input shape ``[N, C_in, ...]``.
conv_cls: Class of convolution layer.
conv_params: Convolution hyperparameters.
Returns:
Dictionary representation of the test case.
"""

def make_cnn(conv_class, output_size, conv_params):
"""Note: output class size is assumed to be 5"""
return torch.nn.Sequential(
conv_class(*conv_params),
torch.nn.ReLU(),
torch.nn.Flatten(),
torch.nn.Linear(output_size, 5),
)
def _make_cnn(
conv_cls: Type[Module], output_dim: int, conv_params: Tuple
) -> Sequential:
linear = Linear(output_dim, 5)
set_requires_grad(linear, False)

def get_output_shape(module, module_params, input):
"""Returns the output shape for a given layer."""
output = module(*module_params)(input)
return output.numel() // output.shape[0]
return Sequential(conv_cls(*conv_params), ReLU(), Flatten(), linear)

input = torch.rand(input_size)
output_size = get_output_shape(conv_class, conv_params, input)
input = rand(input_size)
output_dim = _get_output_dim(conv_cls(*conv_params), input)

dict_setting = {
"input_fn": lambda: torch.rand(input_size),
"module_fn": lambda: make_cnn(conv_class, output_size, conv_params),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"input_fn": lambda: rand(input_size),
"module_fn": lambda: _make_cnn(conv_cls, output_dim, conv_params),
"loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
"id_prefix": "automated-simple-cnn",
}

return dict_setting


def make_simple_pooling_setting(input_size, conv_class, pool_cls, pool_params):
"""
input_size: tuple of input size of (N*C*Image Size)
conv_class: convolutional class
conv_params: configurations for convolutional class
return: simple CNN Network
This function is used to automatically create a
simple CNN Network consisting of CNN & Linear layer
for different convolutional layers.
It is used to test `test.extensions`.
def make_simple_pooling_setting(
input_size: Tuple[int],
conv_cls: Type[Module],
pool_cls: Type[Module],
pool_params: Tuple[Any],
) -> dict:
"""Create CNN with convolution and pooling layer as test case dictionary.
Make parameters of final linear layer non-differentiable to save run time.
Args:
input_size: Input shape ``[N, C_in, ...]``.
conv_cls: Class of convolution layer.
pool_cls: Class of pooling layer.
pool_params: Pooling hyperparameters.
Returns:
Dictionary representation of the test case.
"""

def make_cnn(conv_class, output_size, conv_params, pool_cls, pool_params):
"""Note: output class size is assumed to be 5"""
return torch.nn.Sequential(
conv_class(*conv_params),
torch.nn.ReLU(),
pool_cls(*pool_params),
torch.nn.Flatten(),
torch.nn.Linear(output_size, 5),
def _make_cnn(
conv_cls: Type[Module],
output_size: int,
conv_params: Tuple[Any],
pool_cls: Type[Module],
pool_params: Tuple[Any],
) -> Sequential:
linear = Linear(output_size, 5)
set_requires_grad(linear, False)

return Sequential(
conv_cls(*conv_params), ReLU(), pool_cls(*pool_params), Flatten(), linear
)

def get_output_shape(module, module_params, input, pool, pool_params):
"""Returns the output shape for a given layer."""
output_1 = module(*module_params)(input)
output = pool_cls(*pool_params)(output_1)
return output.numel() // output.shape[0]

conv_params = (3, 2, 2)
input = torch.rand(input_size)
output_size = get_output_shape(
conv_class, conv_params, input, pool_cls, pool_params
input = rand(input_size)
output_dim = _get_output_dim(
Sequential(conv_cls(*conv_params), pool_cls(*pool_params)), input
)

dict_setting = {
"input_fn": lambda: torch.rand(input_size),
"module_fn": lambda: make_cnn(
conv_class, output_size, conv_params, pool_cls, pool_params
"input_fn": lambda: rand(input_size),
"module_fn": lambda: _make_cnn(
conv_cls, output_dim, conv_params, pool_cls, pool_params
),
"loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"),
"loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"),
"target_fn": lambda: classification_targets((3,), 5),
"id_prefix": "automated-simple-cnn",
}

return dict_setting


def _get_output_dim(module: Module, input: Tensor) -> int:
output = module(input)
return output.numel() // output.shape[0]
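
As a usage sketch (hypothetical driver code; the test suite actually consumes
these dictionaries through `test/extensions/problem.py`), one of the helpers
above could be exercised like this:

```
from test.extensions.automated_settings import make_simple_cnn_setting

from torch.nn import Conv2d

# Conv2d(3, 2, 2): 3 input channels, 2 output channels, 2x2 kernel,
# applied to a batch of three 7x7 RGB images (input size (3, 3, 7, 7)).
setting = make_simple_cnn_setting((3, 3, 7, 7), Conv2d, (3, 2, 2))

model = setting["module_fn"]()
X, y = setting["input_fn"](), setting["target_fn"]()
loss = setting["loss_function_fn"]()(model(X), y)

# The final Linear layer is frozen, so second-order tests skip its GGN block.
assert not any(p.requires_grad for p in model[3].parameters())
```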
