From 1ac99a329ff95a43c201e10bae1bca0f6efc8342 Mon Sep 17 00:00:00 2001 From: "F. Dangel" Date: Sun, 26 Apr 2020 18:14:08 +0200 Subject: [PATCH] Add tests for ConvTranspose2d bias_jac_t_mat_prod --- test/test_bias_jac_t_mat_prod.py | 18 +++++++++++++++++- test/test_weight_jac_t_mat_prod.py | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/test/test_bias_jac_t_mat_prod.py b/test/test_bias_jac_t_mat_prod.py index 2d8ba112b..45c79f23d 100644 --- a/test/test_bias_jac_t_mat_prod.py +++ b/test/test_bias_jac_t_mat_prod.py @@ -5,7 +5,7 @@ """ import pytest import torch -from torch.nn import Conv2d, Linear +from torch.nn import Conv2d, ConvTranspose2d, Linear from backpack import extend from backpack.hessianfree.lop import transposed_jacobian_vector_product @@ -33,6 +33,14 @@ def make_id(layer, input_shape, sum_batch): (3, 2, 11, 13), True, ], + [ConvTranspose2d(2, 3, kernel_size=2), (3, 2, 11, 13), True], + [ConvTranspose2d(2, 3, kernel_size=2, padding=1), (3, 2, 11, 13), True], + [ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2), (3, 2, 11, 13), True], + [ + ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2, dilation=2), + (3, 2, 11, 13), + True, + ], # sum_batch = False [Linear(5, 1), (3, 5), False], [Linear(20, 10), (5, 20), False], @@ -44,6 +52,14 @@ def make_id(layer, input_shape, sum_batch): (3, 2, 11, 13), False, ], + [ConvTranspose2d(2, 3, kernel_size=2), (3, 2, 11, 13), False], + [ConvTranspose2d(2, 3, kernel_size=2, padding=1), (3, 2, 11, 13), False], + [ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2), (3, 2, 11, 13), False], + [ + ConvTranspose2d(2, 3, kernel_size=2, padding=1, stride=2, dilation=2), + (3, 2, 11, 13), + False, + ], ] IDS = [ make_id(layer, input_shape, sum_batch) diff --git a/test/test_weight_jac_t_mat_prod.py b/test/test_weight_jac_t_mat_prod.py index a263233fc..801bb7006 100644 --- a/test/test_weight_jac_t_mat_prod.py +++ b/test/test_weight_jac_t_mat_prod.py @@ -14,6 +14,7 @@ from .test_bias_jac_t_mat_prod import ARGS, SETTINGS, make_id from .test_ea_jac_t_mat_jac_prod import derivative_from_layer, get_output_shape +SETTINGS = [s for s in SETTINGS if not isinstance(s[0], torch.nn.ConvTranspose2d)] IDS = [ make_id(layer, input_shape, sum_batch) for (layer, input_shape, sum_batch) in SETTINGS