Skip to content

Commit

Permalink
Add tests for ConvTranspose2d bias_jac_t_mat_prod
Browse files Browse the repository at this point in the history
  • Loading branch information
F. Dangel committed Apr 26, 2020
1 parent e6d1382 commit 1ac99a3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
18 changes: 17 additions & 1 deletion test/test_bias_jac_t_mat_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import pytest
import torch
from torch.nn import Conv2d, Linear
from torch.nn import Conv2d, ConvTranspose2d, Linear

from backpack import extend
from backpack.hessianfree.lop import transposed_jacobian_vector_product
Expand Down Expand Up @@ -33,6 +33,14 @@ def make_id(layer, input_shape, sum_batch):
(3, 2, 11, 13),
True,
],
[ConvTranspose2d(2, 3, kernel_size=2), (3, 2, 11, 13), True],
[ConvTranspose2d(2, 3, kernel_size=2, padding=1), (3, 2, 11, 13), True],
[ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2), (3, 2, 11, 13), True],
[
ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2, dilation=2),
(3, 2, 11, 13),
True,
],
# sum_batch = False
[Linear(5, 1), (3, 5), False],
[Linear(20, 10), (5, 20), False],
Expand All @@ -44,6 +52,14 @@ def make_id(layer, input_shape, sum_batch):
(3, 2, 11, 13),
False,
],
[ConvTranspose2d(2, 3, kernel_size=2), (3, 2, 11, 13), False],
[ConvTranspose2d(2, 3, kernel_size=2, padding=1), (3, 2, 11, 13), False],
[ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2), (3, 2, 11, 13), False],
[
ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2, dilation=2),
(3, 2, 11, 13),
False,
],
]
IDS = [
make_id(layer, input_shape, sum_batch)
Expand Down
1 change: 1 addition & 0 deletions test/test_weight_jac_t_mat_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .test_bias_jac_t_mat_prod import ARGS, SETTINGS, make_id
from .test_ea_jac_t_mat_jac_prod import derivative_from_layer, get_output_shape

SETTINGS = [s for s in SETTINGS if not isinstance(s[0], torch.nn.ConvTranspose2d)]
IDS = [
make_id(layer, input_shape, sum_batch)
for (layer, input_shape, sum_batch) in SETTINGS
Expand Down

0 comments on commit 1ac99a3

Please sign in to comment.