Merge pull request #371 from kevingreenman/clamp-v-evidential
Clamp evidential 'v' parameter
kevingreenman committed Feb 10, 2023
2 parents a780a06 + eb33981 · commit 86007ca
Showing 1 changed file with 5 additions and 2 deletions.
chemprop/train/loss_functions.py (7 changes: 5 additions & 2 deletions)
```diff
@@ -324,7 +324,7 @@ def dirichlet_common_loss(alphas, y_one_hot, lam=0):
 
 
 # updated evidential regression loss (evidential_loss_new from Amini repo)
-def evidential_loss(pred_values, targets, lam=0, epsilon=1e-8):
+def evidential_loss(pred_values, targets, lam: float = 0, epsilon: float = 1e-8, v_min: float = 1e-5):
     """
     Use Deep Evidential Regression negative log likelihood loss + evidential
     regularizer
@@ -335,14 +335,17 @@ def evidential_loss(pred_values, targets, lam=0, epsilon=1e-8):
     :v: predicted lam parameter for NIG
     :alpha: predicted parameter for NIG
     :beta: predicted parameter for NIG
-    :targets: Outputs to predict
+    :param targets: Outputs to predict
+    :param lam: regularization coefficient
+    :param v_min: clamp any v below this value to prevent Inf from division
     :return: Loss
     """
     # Unpack combined prediction values
     mu, v, alpha, beta = torch.split(pred_values, pred_values.shape[1] // 4, dim=1)
 
     # Calculate NLL loss
+    v = torch.clamp(v, v_min)
     twoBlambda = 2 * beta * (1 + v)
     nll = (
         0.5 * torch.log(np.pi / v)
```

(The hunk is truncated here; all 5 additions and 2 deletions appear above, so the remaining lines of the NLL expression are unchanged context.)
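The new `v_min` guard exists because `v` (the ν parameter of the Normal-Inverse-Gamma distribution) sits in a denominator: the first NLL term is `0.5 * torch.log(np.pi / v)`, which goes to infinity as the network drives `v` toward zero, and a single infinite element poisons the whole batch loss and its gradients. A minimal sketch of the failure mode and the fix, written in plain PyTorch rather than through chemprop's API (the tensor values are made up for illustration):

```python
import numpy as np
import torch

# Reproduce only the first NLL term, 0.5 * log(pi / v).
v = torch.tensor([0.5, 1e-12, 0.0])  # a healthy, a near-zero, and a zero v

print(0.5 * torch.log(np.pi / v))
# tensor([ 0.9189, 14.3879,     inf])  <- the inf poisons the batch loss

v_min = 1e-5               # default of the new v_min keyword argument
v = torch.clamp(v, v_min)  # the same call the commit adds

print(0.5 * torch.log(np.pi / v))
# tensor([0.9189, 6.3288, 6.3288])  <- finite everywhere
```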

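For context on the `torch.split` line in the hunk above: `pred_values` concatenates the four NIG parameters along dim 1, so splitting by a quarter of that width yields four equal-shaped tensors. A toy illustration with hypothetical shapes (2 molecules, 3 targets; these numbers are not chemprop defaults):

```python
import torch

# The model emits 4 * num_targets values per input:
# (mu, v, alpha, beta) for each of the 3 targets.
pred_values = torch.randn(2, 12)

mu, v, alpha, beta = torch.split(pred_values, pred_values.shape[1] // 4, dim=1)
print(mu.shape, v.shape, alpha.shape, beta.shape)
# torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
```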