diff --git a/backpack/extensions/curvature.py b/backpack/extensions/curvature.py index 954dc7e0..b9a41ac6 100644 --- a/backpack/extensions/curvature.py +++ b/backpack/extensions/curvature.py @@ -1,4 +1,11 @@ -import torch +"""Modification of second-order module effects during Hessian backpropagation. + +The residual term is tweaked to give rise to the following curvatures: +- No modification: Exact Hessian +- Neglect module second order information: Generalized Gauss-Newton matrix +- Cast negative residual eigenvalue to their absolute value: PCH-abs +- Set negative residual eigenvalues to zero: PCH-clip +""" class ResidualModifications: @@ -18,28 +25,18 @@ def remove_negative_values(res): def to_abs(res): return res.abs() - @staticmethod - def to_med(res, if_negative_return=None): - median = res.median() - if median < 0: - return None - else: - return median * torch.ones_like(res) - class Curvature: HESSIAN = "hessian" GGN = "ggn" PCH_ABS = "pch-abs" PCH_CLIP = "pch-clip" - PCH_MED = "pch-med" CHOICES = [ HESSIAN, GGN, PCH_CLIP, PCH_ABS, - PCH_MED, ] REQUIRE_PSD_LOSS_HESSIAN = { @@ -47,7 +44,6 @@ class Curvature: GGN: True, PCH_ABS: True, PCH_CLIP: True, - PCH_MED: True, } REQUIRE_RESIDUAL = { @@ -55,7 +51,6 @@ class Curvature: GGN: False, PCH_ABS: True, PCH_CLIP: True, - PCH_MED: True, } RESIDUAL_MODS = { @@ -63,7 +58,6 @@ class Curvature: GGN: ResidualModifications.to_zero, PCH_ABS: ResidualModifications.to_abs, PCH_CLIP: ResidualModifications.remove_negative_values, - PCH_MED: ResidualModifications.to_med, } @classmethod