Skip to content

Commit

Permalink
[FIX] Use batch size 1 in KFAC ResNet tests (#265)
Browse files Browse the repository at this point in the history
* [FIX] Use batch size 1

* [DEL] Remove unused import

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
  • Loading branch information
f-dangel and f-dangel committed Sep 24, 2022
1 parent 53ddd86 commit 5bb1810
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions test/extensions/secondorder/hbp/kfac_settings.py
Expand Up @@ -9,7 +9,6 @@
from torch import rand
from torch.nn import (
CrossEntropyLoss,
Flatten,
Identity,
Linear,
MSELoss,
Expand All @@ -31,15 +30,12 @@
BATCH_SIZE_1_SETTINGS = [
{
"input_fn": lambda: rand(1, 7),
"module_fn": lambda: Sequential(
Linear(7, 3), ReLU(), Flatten(start_dim=1, end_dim=-1), Linear(3, 1)
),
"module_fn": lambda: Sequential(Linear(7, 3), ReLU(), Linear(3, 1)),
"loss_function_fn": lambda: MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((1, 1)),
"id_prefix": "one-additional",
},
{
"input_fn": lambda: rand(3, 10),
"input_fn": lambda: rand(1, 10),
"module_fn": lambda: Sequential(
Linear(10, 5),
ReLU(),
Expand All @@ -53,11 +49,11 @@
Linear(5, 4),
),
"loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 4),
"target_fn": lambda: classification_targets((1,), 4),
"id_prefix": "branching-linear",
},
{
"input_fn": lambda: rand(3, 10),
"input_fn": lambda: rand(1, 10),
"module_fn": lambda: Sequential(
Linear(10, 5),
ReLU(),
Expand All @@ -71,7 +67,7 @@
Linear(5, 4),
),
"loss_function_fn": lambda: CrossEntropyLoss(),
"target_fn": lambda: classification_targets((3,), 4),
"target_fn": lambda: classification_targets((1,), 4),
"id_prefix": "branching-scalar",
},
]

0 comments on commit 5bb1810

Please sign in to comment.