From 5bb18100955f2c66bd85498fe88c43bb3c0b640a Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Sat, 24 Sep 2022 18:50:43 +0200 Subject: [PATCH] [FIX] Use batch size 1 in KFAC ResNet tests (#265) * [FIX] Use batch size 1 * [DEL] Remove unused import Co-authored-by: Felix Dangel --- test/extensions/secondorder/hbp/kfac_settings.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/test/extensions/secondorder/hbp/kfac_settings.py b/test/extensions/secondorder/hbp/kfac_settings.py index 7b244883..82bfa9b2 100644 --- a/test/extensions/secondorder/hbp/kfac_settings.py +++ b/test/extensions/secondorder/hbp/kfac_settings.py @@ -9,7 +9,6 @@ from torch import rand from torch.nn import ( CrossEntropyLoss, - Flatten, Identity, Linear, MSELoss, @@ -31,15 +30,12 @@ BATCH_SIZE_1_SETTINGS = [ { "input_fn": lambda: rand(1, 7), - "module_fn": lambda: Sequential( - Linear(7, 3), ReLU(), Flatten(start_dim=1, end_dim=-1), Linear(3, 1) - ), + "module_fn": lambda: Sequential(Linear(7, 3), ReLU(), Linear(3, 1)), "loss_function_fn": lambda: MSELoss(reduction="mean"), "target_fn": lambda: regression_targets((1, 1)), - "id_prefix": "one-additional", }, { - "input_fn": lambda: rand(3, 10), + "input_fn": lambda: rand(1, 10), "module_fn": lambda: Sequential( Linear(10, 5), ReLU(), @@ -53,11 +49,11 @@ Linear(5, 4), ), "loss_function_fn": lambda: CrossEntropyLoss(), - "target_fn": lambda: classification_targets((3,), 4), + "target_fn": lambda: classification_targets((1,), 4), "id_prefix": "branching-linear", }, { - "input_fn": lambda: rand(3, 10), + "input_fn": lambda: rand(1, 10), "module_fn": lambda: Sequential( Linear(10, 5), ReLU(), @@ -71,7 +67,7 @@ Linear(5, 4), ), "loss_function_fn": lambda: CrossEntropyLoss(), - "target_fn": lambda: classification_targets((3,), 4), + "target_fn": lambda: classification_targets((1,), 4), "id_prefix": "branching-scalar", }, ]