Skip to content

Commit

Permalink
Add a test for alpha usage in DiceFocalLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
kephale committed Jun 12, 2024
1 parent 8693ada commit 06a2509
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,28 @@ def test_script(self):
test_input = torch.ones(2, 1, 8, 8)
test_script_save(loss, test_input, test_input)

def test_result_with_alpha(self):
size = [3, 3, 5, 5]
label = torch.randint(low=0, high=2, size=size)
pred = torch.randn(size)
alpha_values = [0.25, 0.5, 0.75]
for reduction in ["sum", "mean", "none"]:
for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:
common_params = {
"include_background": True,
"to_onehot_y": False,
"reduction": reduction,
"weight": weight,
}
for lambda_focal in [0.5, 1.0, 1.5]:
for alpha in alpha_values:
dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
dice = DiceLoss(**common_params)
focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)
result = dice_focal(pred, label)
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
np.testing.assert_allclose(result, expected_val)


if __name__ == "__main__":
unittest.main()

0 comments on commit 06a2509

Please sign in to comment.