Skip to content

Commit

Permalink
[TEST] Add cases for ConvTranspose{1,3}d
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 14, 2022
1 parent f2510e3 commit ba0f78e
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion test/extensions/secondorder/hbp/kfac_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@
},
# transpose convolution with single output is a linear layer (no weight
# sharing across input)
{
"input_fn": lambda: rand(1, 2, 9),
"module_fn": lambda: Sequential(
ConvTranspose1d(2, 4, 3, padding=5),
Sigmoid(),
ConvTranspose1d(4, 1, 1),
Flatten(),
),
"loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((1, 1)),
"id_prefix": "convtranspose1d-single-output",
},
# transpose convolution with single output is a linear layer (no weight
# sharing across input)
{
"input_fn": lambda: rand(1, 2, 4, 4),
"module_fn": lambda: Sequential(
Expand All @@ -104,7 +118,21 @@
),
"loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((1, 1)),
"id_prefix": "convtranspose1d-single-output",
"id_prefix": "convtranspose2d-single-output",
},
# transpose convolution with single output is a linear layer (no weight
# sharing across input)
{
"input_fn": lambda: rand(1, 2, 6, 6, 6),
"module_fn": lambda: Sequential(
ConvTranspose3d(2, 4, 2, padding=3),
Sigmoid(),
ConvTranspose3d(4, 1, 1),
Flatten(),
),
"loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((1, 1)),
"id_prefix": "convtranspose3d-single-output",
},
]

Expand Down

0 comments on commit ba0f78e

Please sign in to comment.