
Merge 788bcaf into d3b134f
f-dangel committed Dec 19, 2022
2 parents d3b134f + 788bcaf commit 3b1e9c2
Showing 9 changed files with 76 additions and 23 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/test.yaml
@@ -20,7 +20,13 @@ jobs:
    strategy:
      matrix:
        python-version: [3.7, 3.8, 3.9]
        pytorch-version: [1.9.0, 1.12.0]
        pytorch-version:
          - "==1.9.1"
          - "==1.10.1"
          - "==1.11.0"
          - "==1.12.1"
          - "==1.13.1"
          - "" # latest
    steps:
      - uses: actions/checkout@v1
      - uses: actions/setup-python@v1
@@ -30,7 +36,7 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          make install-test
          pip install torch==${{ matrix.pytorch-version }} torchvision
          pip install torch${{ matrix.pytorch-version }} torchvision
      - name: Run test
        if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref)
        run: |
7 changes: 6 additions & 1 deletion backpack/core/derivatives/conv_transposend.py
@@ -5,16 +5,21 @@
from numpy import prod
from torch import Tensor, einsum
from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
from backpack.utils.conv import get_conv_function
from backpack.utils.conv_transpose import (
get_conv_transpose_function,
unfold_by_conv_transpose,
)
from backpack.utils.subsampling import subsample

if TORCH_VERSION_AT_LEAST_1_13:
    from backpack.utils.conv import _grad_input_padding
else:
    from torch.nn.grad import _grad_input_padding


class ConvTransposeNDDerivatives(BaseParameterDerivatives):
    """Base class for partial derivatives of transpose convolution."""
7 changes: 6 additions & 1 deletion backpack/core/derivatives/convnd.py
@@ -5,13 +5,18 @@
from numpy import prod
from torch import Tensor, einsum
from torch.nn import Conv1d, Conv2d, Conv3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
from backpack.utils.conv import get_conv_function, unfold_by_conv
from backpack.utils.conv_transpose import get_conv_transpose_function
from backpack.utils.subsampling import subsample

if TORCH_VERSION_AT_LEAST_1_13:
    from backpack.utils.conv import _grad_input_padding
else:
    from torch.nn.grad import _grad_input_padding


class weight_jac_t_save_memory:
    """Choose algorithm to apply transposed convolution weight Jacobian."""
1 change: 1 addition & 0 deletions backpack/utils/__init__.py
@@ -4,5 +4,6 @@
TORCH_VERSION = packaging.version.parse(get_distribution("torch").version)
TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1")
TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0")
TORCH_VERSION_AT_LEAST_1_13 = TORCH_VERSION >= packaging.version.parse("1.13")

ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0
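
A quick aside (not part of the commit) on why the new flag compares parsed versions rather than raw strings: packaging orders release segments numerically and ranks a build with a local suffix such as +cu117 at least as high as its plain release, so CUDA builds of torch 1.13 satisfy the check while 1.12.x does not, whereas naive string comparison would misorder them.

from packaging import version

assert version.parse("1.13.0+cu117") >= version.parse("1.13")  # CUDA build of 1.13 passes
assert not version.parse("1.12.1") >= version.parse("1.13")    # older releases do not
assert not ("1.13.0" >= "1.9")  # naive string comparison gets the ordering wrong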
48 changes: 48 additions & 0 deletions backpack/utils/conv.py
@@ -1,6 +1,7 @@
"""Utility functions for convolution layers."""

from typing import Callable, Tuple, Type, Union
from warnings import warn

import torch
from einops import rearrange
@@ -179,3 +180,50 @@ def unfold_by_conv(
        padding=module.padding,
        stride=module.stride,
    )


def _grad_input_padding(
    grad_output, input_size, stride, padding, kernel_size, dilation=None
):
    """Determine padding for the VJP of convolution.

    # noqa: DAR101, DAR201, DAR401

    Note:
        This function was copied from the PyTorch repository (version 1.9).
        It was removed between torch 1.12.1 and torch 1.13.
    """
    if dilation is None:
        # For backward compatibility
        warn(
            "_grad_input_padding 'dilation' argument not provided. Default of 1 is used."
        )
        dilation = [1] * len(stride)

    input_size = list(input_size)
    k = grad_output.dim() - 2

    if len(input_size) == k + 2:
        input_size = input_size[-k:]
    if len(input_size) != k:
        raise ValueError(f"input_size must have {k+2} elements (got {len(input_size)})")

    def dim_size(d):
        return (
            (grad_output.size(d + 2) - 1) * stride[d]
            - 2 * padding[d]
            + 1
            + dilation[d] * (kernel_size[d] - 1)
        )

    min_sizes = [dim_size(d) for d in range(k)]
    max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
    for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
        if size < min_size or size > max_size:
            raise ValueError(
                f"requested an input grad size of {input_size}, but valid sizes range "
                f"from {min_sizes} to {max_sizes} (for a grad_output of "
                f"{grad_output.size()[2:]})"
            )

    return tuple(input_size[d] - min_sizes[d] for d in range(k))
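
For orientation, a minimal usage sketch (not part of the commit) of how a helper like this feeds the VJP of a convolution: the returned tuple is passed as output_padding to the matching transposed convolution so that the input gradient recovers the original spatial size. The layer sizes below are made-up assumptions.

import torch
from torch.nn.functional import conv2d, conv_transpose2d

x = torch.randn(2, 3, 8, 8)                      # assumed input batch
weight = torch.randn(4, 3, 3, 3)                 # assumed Conv2d weight (out, in, kH, kW)
stride, padding, dilation = (2, 2), (1, 1), (1, 1)

out = conv2d(x, weight, stride=stride, padding=padding, dilation=dilation)
grad_out = torch.randn_like(out)                 # incoming vector of the VJP

# Extra padding so the transposed convolution maps back to x's spatial shape.
out_pad = _grad_input_padding(grad_out, x.shape, stride, padding, (3, 3), dilation)
grad_x = conv_transpose2d(
    grad_out, weight, stride=stride, padding=padding,
    output_padding=out_pad, dilation=dilation,
)
assert grad_x.shape == x.shape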
14 changes: 1 addition & 13 deletions backpack/utils/convert_parameters.py
@@ -2,7 +2,7 @@

from typing import Iterable, List

from torch import Tensor, cat, typename
from torch import Tensor, typename


def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[Tensor]:
@@ -51,15 +51,3 @@ def vector_to_parameter_list(vec: Tensor, parameters: Iterable[Tensor]) -> List[
pointer += num_param

return params_new


def tensor_list_to_vector(tensor_list: Iterable[Tensor]) -> Tensor:
"""Convert a list of tensors into a vector by flattening and concatenation.
Args:
tensor_list: List of tensors.
Returns:
Vector containing the flattened and concatenated tensor inputs.
"""
return cat([t.flatten() for t in tensor_list])
9 changes: 4 additions & 5 deletions backpack/utils/examples.py
@@ -3,15 +3,13 @@

from torch import Tensor, stack, zeros
from torch.nn import Module
from torch.nn.utils.convert_parameters import parameters_to_vector
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import (
tensor_list_to_vector,
vector_to_parameter_list,
)
from backpack.utils.convert_parameters import vector_to_parameter_list


def load_mnist_dataset() -> Dataset:
@@ -115,5 +113,6 @@ def _autograd_ggn_exact_columns(
        e_d_list = vector_to_parameter_list(e_d, trainable_parameters)

        ggn_d_list = ggn_vector_product(loss, outputs, model, e_d_list)
        ggn_d_list = [t.contiguous() for t in ggn_d_list]

        yield d, tensor_list_to_vector(ggn_d_list)
        yield d, parameters_to_vector(ggn_d_list)
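
The .contiguous() call added above matters because parameters_to_vector flattens each tensor with view(-1), which fails on non-contiguous inputs in the torch versions tested here. A small sketch of the failure mode; the transposed tensor is only an assumed stand-in for a non-contiguous gradient, not what ggn_vector_product actually returns:

from torch import randn
from torch.nn.utils.convert_parameters import parameters_to_vector

t = randn(4, 3).t()                    # a non-contiguous view
# parameters_to_vector([t])            # may raise: view size is not compatible with input tensor's size and stride
vec = parameters_to_vector([t.contiguous()])  # flattening succeeds after .contiguous()
assert vec.shape == (12,)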
1 change: 1 addition & 0 deletions setup.cfg
@@ -107,6 +107,7 @@ ignore =
W291, # trailing whitespace
W503, # line break before binary operator
W504, # line break after binary operator
B905, # 'zip()' without an explicit 'strict=' parameter
exclude = docs, build, .git, docs_src/rtd, docs_src/rtd_output, .eggs

# Differences with pytorch
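
The new B905 entry silences flake8-bugbear's warning about zip calls without an explicit strict= argument (a Python 3.10+ feature); the _grad_input_padding helper copied above uses zip in exactly that form. A short sketch of what the rule is about:

a, b = [1, 2, 3], [4, 5]

list(zip(a, b))                 # silently truncates to the shorter input -- what B905 warns about
# list(zip(a, b, strict=True))  # Python >= 3.10: raises ValueError on a length mismatch instead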
2 changes: 1 addition & 1 deletion test/core/derivatives/implementation/autograd.py
@@ -248,7 +248,7 @@ def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor:
        for t in tensor.flatten():
            try:
                yield self._hessian(t, x)
            except (RuntimeError, AttributeError):
            except (RuntimeError, AttributeError, TypeError):
                yield zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype)

    def hessian_is_zero(self) -> bool:  # noqa: D102
