Skip to content

Commit

Permalink
[TEST] Individual gradients of linear with additional dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Apr 12, 2021
1 parent bf89c8b commit 8672e75
Showing 1 changed file with 39 additions and 5 deletions.
44 changes: 39 additions & 5 deletions test/extensions/firstorder/batch_grad/batchgrad_settings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,48 @@
"""Test configurations to test batch_grad
"""Test cases for ``backpack.extensions.BatchGrad``.
The tests are taken from `test.extensions.firstorder.firstorder_settings`,
but additional custom tests can be defined here by appending it to the list.
The cases are taken from ``test.extensions.firstorder.firstorder_settings``.
Additional local cases can be defined by appending them to ``LOCAL_SETTINGS``.
"""

from test.core.derivatives.utils import regression_targets
from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS

import torch

BATCHGRAD_SETTINGS = []

SHARED_SETTINGS = FIRSTORDER_SETTINGS
LOCAL_SETTING = []
LOCAL_SETTINGS = [
# nn.Linear with one additional dimension
{
"input_fn": lambda: torch.rand(3, 4, 5),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Linear(5, 3), torch.nn.Linear(3, 2)
),
"loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((3, 4, 2)),
"id_prefix": "one-additional",
},
# nn.Linear with two additional dimensions
{
"input_fn": lambda: torch.rand(3, 4, 2, 5),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Linear(5, 3), torch.nn.Linear(3, 2)
),
"loss_function_fn": lambda: torch.nn.MSELoss(reduction="mean"),
"target_fn": lambda: regression_targets((3, 4, 2, 2)),
"id_prefix": "two-additional",
},
# nn.Linear with three additional dimensions, sum reduction
{
"input_fn": lambda: torch.rand(3, 4, 2, 3, 5),
"module_fn": lambda: torch.nn.Sequential(
torch.nn.Linear(5, 3), torch.nn.Linear(3, 2)
),
"loss_function_fn": lambda: torch.nn.MSELoss(reduction="sum"),
"target_fn": lambda: regression_targets((3, 4, 2, 3, 2)),
"id_prefix": "three-additional",
},
]

BATCHGRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTING
BATCHGRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS

0 comments on commit 8672e75

Please sign in to comment.