Skip to content

Commit

Permalink
Merge pull request #28 from f-dangel/remove-pch-med
Browse files Browse the repository at this point in the history
Remove PCH-med
  • Loading branch information
fKunstner committed Jan 10, 2020
2 parents f2a9806 + 06c9479 commit 3a92411
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions backpack/extensions/curvature.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import torch
"""Modification of second-order module effects during Hessian backpropagation.
The residual term is tweaked to give rise to the following curvatures:
- No modification: Exact Hessian
- Neglect module second order information: Generalized Gauss-Newton matrix
- Cast negative residual eigenvalue to their absolute value: PCH-abs
- Set negative residual eigenvalues to zero: PCH-clip
"""


class ResidualModifications:
Expand All @@ -18,52 +25,39 @@ def remove_negative_values(res):
def to_abs(res):
return res.abs()

@staticmethod
def to_med(res, if_negative_return=None):
median = res.median()
if median < 0:
return None
else:
return median * torch.ones_like(res)


class Curvature:
HESSIAN = "hessian"
GGN = "ggn"
PCH_ABS = "pch-abs"
PCH_CLIP = "pch-clip"
PCH_MED = "pch-med"

CHOICES = [
HESSIAN,
GGN,
PCH_CLIP,
PCH_ABS,
PCH_MED,
]

REQUIRE_PSD_LOSS_HESSIAN = {
HESSIAN: False,
GGN: True,
PCH_ABS: True,
PCH_CLIP: True,
PCH_MED: True,
}

REQUIRE_RESIDUAL = {
HESSIAN: True,
GGN: False,
PCH_ABS: True,
PCH_CLIP: True,
PCH_MED: True,
}

RESIDUAL_MODS = {
HESSIAN: ResidualModifications.nothing,
GGN: ResidualModifications.to_zero,
PCH_ABS: ResidualModifications.to_abs,
PCH_CLIP: ResidualModifications.remove_negative_values,
PCH_MED: ResidualModifications.to_med,
}

@classmethod
Expand Down

0 comments on commit 3a92411

Please sign in to comment.