Commit 359e660

Merge pull request #309 from kevingreenman/bug/mcc-inf-loss

Fix multiclass MCC loss function

kevingreenman committed Jul 15, 2022 · 2 parents b831fda + 9be4526

Showing 2 changed files with 110 additions and 89 deletions.
82 changes: 37 additions & 45 deletions chemprop/train/loss_functions.py
@@ -72,9 +72,7 @@ def bounded_mse_loss(
    :param greater_than_target: A tensor with boolean values indicating whether the target is a greater-than inequality.
    :return: A tensor containing loss values of shape(batch_size, tasks).
    """
-    predictions = torch.where(
-        torch.logical_and(predictions < targets, less_than_target), targets, predictions
-    )
+    predictions = torch.where(torch.logical_and(predictions < targets, less_than_target), targets, predictions)

    predictions = torch.where(
        torch.logical_and(predictions > targets, greater_than_target),
@@ -106,9 +104,7 @@ def mcc_class_loss(
    FP = torch.sum((1 - targets) * predictions * data_weights * mask, axis=0)
    FN = torch.sum(targets * (1 - predictions) * data_weights * mask, axis=0)
    TN = torch.sum((1 - targets) * (1 - predictions) * data_weights * mask, axis=0)
-    loss = 1 - (
-        (TP * TN - FP * FN) / torch.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
-    )
+    loss = 1 - ((TP * TN - FP * FN) / torch.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)))
    return loss
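
Note (not part of the diff): the soft counts TP, FP, FN, TN above implement the standard binary Matthews correlation coefficient, and the loss is its complement:

$$\text{loss} = 1 - \mathrm{MCC} = 1 - \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$

When any factor under the square root is zero, this expression divides by zero — the same degenerate case the multiclass fix below guards against explicitly.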


@@ -127,21 +123,33 @@ def mcc_multiclass_loss(
    :param mask: A tensor with boolean values indicating whether the loss for this prediction is considered in the gradient descent with shape(batch_size).
    :return: A tensor value for the loss.
    """
    # targets shape (batch)
    # preds shape(batch, classes)
    torch_device = predictions.device
    mask = mask.unsqueeze(1)

    bin_targets = torch.zeros_like(predictions, device=torch_device)
    bin_targets[torch.arange(predictions.shape[0]), targets] = 1
-    c = torch.sum(predictions * bin_targets * data_weights * mask)
-    s = torch.sum(predictions * data_weights * mask)
-    pt = torch.sum(
-        torch.sum(predictions * data_weights * mask, axis=0)
-        * torch.sum(bin_targets * data_weights * mask, axis=0)
-    )
-    p2 = torch.sum(torch.sum(predictions * data_weights * mask, axis=0) ** 2)
-    t2 = torch.sum(torch.sum(bin_targets * data_weights * mask, axis=0) ** 2)
-    loss = 1 - (c * s - pt) / torch.sqrt((s ** 2 - p2) * (s ** 2 - t2))

+    pred_classes = predictions.argmax(dim=1)
+    bin_preds = torch.zeros_like(predictions, device=torch_device)
+    bin_preds[torch.arange(predictions.shape[0]), pred_classes] = 1
+
+    masked_data_weights = data_weights * mask
+
+    t_sum = torch.sum(bin_targets * masked_data_weights, axis=0)  # number of times each class truly occurred
+    p_sum = torch.sum(bin_preds * masked_data_weights, axis=0)  # number of times each class was predicted
+
+    n_correct = torch.sum(bin_preds * bin_targets * masked_data_weights)  # total number of samples correctly predicted
+    n_samples = torch.sum(predictions * masked_data_weights)  # total number of samples
+
+    cov_ytyp = n_correct * n_samples - torch.dot(p_sum, t_sum)
+    cov_ypyp = n_samples**2 - torch.dot(p_sum, p_sum)
+    cov_ytyt = n_samples**2 - torch.dot(t_sum, t_sum)
+
+    if cov_ypyp * cov_ytyt == 0:
+        loss = torch.tensor(0.0)
+    else:
+        loss = cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)
+
    return loss
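
Note (not part of the commit): the rewritten function follows the covariance form of the multiclass MCC from sklearn's documentation, computed from hard argmax predictions, and the zero-denominator case that could previously make the loss non-finite (hence the bug/mcc-inf-loss branch name) is now guarded explicitly. A minimal sketch cross-checking the hard-prediction case against sklearn.metrics.matthews_corrcoef, assuming unit data weights and a full mask:

import torch
from sklearn.metrics import matthews_corrcoef

preds = torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]])
targets = torch.tensor([2, 0, 2])

# Binarize targets and argmax predictions into one-hot indicator matrices.
bin_targets = torch.zeros_like(preds)
bin_targets[torch.arange(preds.shape[0]), targets] = 1
bin_preds = torch.zeros_like(preds)
bin_preds[torch.arange(preds.shape[0]), preds.argmax(dim=1)] = 1

t_sum = bin_targets.sum(dim=0)  # per-class true counts
p_sum = bin_preds.sum(dim=0)    # per-class predicted counts
n_correct = (bin_preds * bin_targets).sum()
n_samples = float(preds.shape[0])  # rows sum to 1, so summing probabilities gives the same value

cov_ytyp = n_correct * n_samples - torch.dot(p_sum, t_sum)
cov_ypyp = n_samples ** 2 - torch.dot(p_sum, p_sum)
cov_ytyt = n_samples ** 2 - torch.dot(t_sum, t_sum)

print((cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)).item())               # 0.6123724...
print(matthews_corrcoef(targets.numpy(), preds.argmax(dim=1).numpy()))   # same value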


@@ -168,21 +176,17 @@ def sid_loss(
    one_sub = torch.ones_like(model_spectra, device=torch_device)
    if threshold is not None:
        threshold_sub = torch.full(model_spectra.shape, threshold, device=torch_device)
-        model_spectra = torch.where(
-            model_spectra < threshold, threshold_sub, model_spectra
-        )
+        model_spectra = torch.where(model_spectra < threshold, threshold_sub, model_spectra)
    model_spectra = torch.where(mask, model_spectra, zero_sub)
    sum_model_spectra = torch.sum(model_spectra, axis=1, keepdim=True)
    model_spectra = torch.div(model_spectra, sum_model_spectra)

    # Calculate loss value
    target_spectra = torch.where(mask, target_spectra, one_sub)
-    model_spectra = torch.where(
-        mask, model_spectra, one_sub
-    )  # losses in excluded regions will be zero because log(1/1) = 0.
-    loss = torch.mul(
-        torch.log(torch.div(model_spectra, target_spectra)), model_spectra
-    ) + torch.mul(torch.log(torch.div(target_spectra, model_spectra)), target_spectra)
+    model_spectra = torch.where(mask, model_spectra, one_sub)  # losses in excluded regions will be zero because log(1/1) = 0.
+    loss = torch.mul(torch.log(torch.div(model_spectra, target_spectra)), model_spectra) + torch.mul(
+        torch.log(torch.div(target_spectra, model_spectra)), target_spectra
+    )

    return loss
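
As a quick illustration of the comment above (an illustrative sketch, not from the commit): positions set to 1 in both spectra contribute exactly zero to the SID sum, since x * log(x / y) vanishes when x = y = 1.

import torch

model = torch.tensor([0.4, 0.6, 1.0])   # last position masked out -> replaced by 1
target = torch.tensor([0.5, 0.5, 1.0])
loss = model * torch.log(model / target) + target * torch.log(target / model)
print(loss)  # last entry is exactly 0, so the masked position adds nothing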

@@ -209,9 +213,7 @@ def wasserstein_loss(
    zero_sub = torch.zeros_like(model_spectra, device=torch_device)
    if threshold is not None:
        threshold_sub = torch.full(model_spectra.shape, threshold, device=torch_device)
-        model_spectra = torch.where(
-            model_spectra < threshold, threshold_sub, model_spectra
-        )
+        model_spectra = torch.where(model_spectra < threshold, threshold_sub, model_spectra)
    model_spectra = torch.where(mask, model_spectra, zero_sub)
    sum_model_spectra = torch.sum(model_spectra, axis=1, keepdim=True)
    model_spectra = torch.div(model_spectra, sum_model_spectra)
@@ -236,9 +238,7 @@ def normal_mve(pred_values, targets):
    # Unpack combined prediction values
    pred_means, pred_var = torch.split(pred_values, pred_values.shape[1] // 2, dim=1)

-    return torch.log(2 * np.pi * pred_var) / 2 + (pred_means - targets) ** 2 / (
-        2 * pred_var
-    )
+    return torch.log(2 * np.pi * pred_var) / 2 + (pred_means - targets) ** 2 / (2 * pred_var)
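
Note (not part of the commit): the MVE unit test below asserts a likelihood of 0.3989 for this negative log-likelihood. That is just the standard normal density at the mean, as a small check shows:

import numpy as np
import torch

pred_values = torch.tensor([[0.0, 1.0]])  # concatenated [mean, variance]
targets = torch.zeros([1, 1])

pred_means, pred_var = torch.split(pred_values, pred_values.shape[1] // 2, dim=1)
nll = torch.log(2 * np.pi * pred_var) / 2 + (pred_means - targets) ** 2 / (2 * pred_var)
print(torch.exp(-nll))  # tensor([[0.3989]]) == 1 / sqrt(2 * pi)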


# evidential classification
@@ -282,7 +282,7 @@ def dirichlet_multiclass_loss(alphas, target_labels, lam=0):
def dirichlet_common_loss(alphas, y_one_hot, lam=0):
    """
    Use Evidential Learning Dirichlet loss from Sensoy et al. This function follows
-    after the classification and multiclass specific functions that reshape the 
+    after the classification and multiclass specific functions that reshape the
    alpha inputs and create one-hot targets.

    :param alphas: Predicted parameters for Dirichlet in shape(datapoints, task, classes).
@@ -294,7 +294,7 @@ def dirichlet_common_loss(alphas, y_one_hot, lam=0):
    # SOS term
    S = torch.sum(alphas, dim=-1, keepdim=True)
    p = alphas / S
-    A = torch.sum((y_one_hot - p)**2, dim=-1, keepdim=True)
+    A = torch.sum((y_one_hot - p) ** 2, dim=-1, keepdim=True)
    B = torch.sum((p * (1 - p)) / (S + 1), dim=-1, keepdim=True)
    SOS = A + B

@@ -304,23 +304,15 @@ def dirichlet_common_loss(alphas, y_one_hot, lam=0):
    S_alpha = torch.sum(alpha_hat, dim=-1, keepdim=True)
    S_beta = torch.sum(beta, dim=-1, keepdim=True)

-    ln_alpha = torch.lgamma(S_alpha) - torch.sum(
-        torch.lgamma(alpha_hat), dim=-1, keepdim=True
-    )
-    ln_beta = torch.sum(torch.lgamma(beta), dim=-1, keepdim=True) - torch.lgamma(
-        S_beta
-    )
+    ln_alpha = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha_hat), dim=-1, keepdim=True)
+    ln_beta = torch.sum(torch.lgamma(beta), dim=-1, keepdim=True) - torch.lgamma(S_beta)

    # digamma terms
    dg_alpha = torch.digamma(alpha_hat)
    dg_S_alpha = torch.digamma(S_alpha)

    # KL
-    KL = (
-        ln_alpha
-        + ln_beta
-        + torch.sum((alpha_hat - beta) * (dg_alpha - dg_S_alpha), dim=-1, keepdim=True)
-    )
+    KL = ln_alpha + ln_beta + torch.sum((alpha_hat - beta) * (dg_alpha - dg_S_alpha), dim=-1, keepdim=True)

    KL = lam * KL

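Note (not part of the commit): the lgamma/digamma lines above are the closed form of KL(Dir(alpha_hat) || Dir(1, ..., 1)). A small cross-check against torch.distributions, assuming the standard evidential construction alpha_hat = y_one_hot + (1 - y_one_hot) * alphas (its definition sits above this hunk and is not shown in the diff):

import torch
from torch.distributions import Dirichlet, kl_divergence

alphas = torch.tensor([2.0, 2.0, 2.0, 2.0])
y_one_hot = torch.tensor([0.0, 1.0, 0.0, 0.0])
alpha_hat = y_one_hot + (1 - y_one_hot) * alphas  # assumed: keep evidence only off the true class
beta = torch.ones_like(alpha_hat)                 # uniform Dirichlet prior

S_alpha, S_beta = alpha_hat.sum(), beta.sum()
ln_alpha = torch.lgamma(S_alpha) - torch.lgamma(alpha_hat).sum()
ln_beta = torch.lgamma(beta).sum() - torch.lgamma(S_beta)
kl_closed = ln_alpha + ln_beta + ((alpha_hat - beta) * (torch.digamma(alpha_hat) - torch.digamma(S_alpha))).sum()

print(kl_closed.item())
print(kl_divergence(Dirichlet(alpha_hat), Dirichlet(beta)).item())  # matches the closed form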
117 changes: 73 additions & 44 deletions tests/test_unit/test_loss_functions.py
@@ -10,6 +10,7 @@
    dirichlet_class_loss,
    evidential_loss,
    get_loss_func,
+    mcc_multiclass_loss,
    normal_mve,
)

@@ -47,7 +48,7 @@ def test_get_regression_function(regression_function):
    """
    args = SimpleNamespace(
        loss_function=regression_function,
-        dataset_type='regression',
+        dataset_type="regression",
    )
    assert get_loss_func(args)

@@ -58,7 +59,7 @@ def test_get_class_function(classification_function):
    """
    args = SimpleNamespace(
        loss_function=classification_function,
-        dataset_type='classification',
+        dataset_type="classification",
    )
    assert get_loss_func(args)

@@ -69,7 +70,7 @@ def test_get_multiclass_function(multiclass_function):
    """
    args = SimpleNamespace(
        loss_function=multiclass_function,
-        dataset_type='multiclass',
+        dataset_type="multiclass",
    )
    assert get_loss_func(args)

@@ -80,7 +81,7 @@ def test_get_spectra_function(spectra_function):
    """
    args = SimpleNamespace(
        loss_function=spectra_function,
-        dataset_type='spectra',
+        dataset_type="spectra",
    )
    assert get_loss_func(args)

@@ -90,9 +91,7 @@ def test_get_unsupported_function(dataset_type):
    Tests the error triggering for unsupported loss functions in get_loss_func.
    """
    with pytest.raises(ValueError):
-        args = SimpleNamespace(
-            dataset_type=dataset_type, loss_function="dummy_loss"
-        )
+        args = SimpleNamespace(dataset_type=dataset_type, loss_function="dummy_loss")
        get_loss_func(args=args)


@@ -104,23 +103,23 @@ def test_get_unsupported_function(dataset_type):
            torch.zeros([2, 2], dtype=float),
            torch.zeros([2, 2], dtype=bool),
            torch.zeros([2, 2], dtype=bool),
-            15
+            15,
        ),
        (
            torch.tensor([[-3, 2], [1, -1]], dtype=float),
            torch.zeros([2, 2], dtype=float),
            torch.zeros([2, 2], dtype=bool),
            torch.ones([2, 2], dtype=bool),
-            10
+            10,
        ),
        (
            torch.tensor([[-3, 2], [1, -1]], dtype=float),
            torch.zeros([2, 2], dtype=float),
            torch.ones([2, 2], dtype=bool),
            torch.zeros([2, 2], dtype=bool),
-            5
+            5,
        ),
-    ]
+    ],
)
def test_bounded_mse(preds, targets, lt_targets, gt_targets, mse):
    """
@@ -132,11 +131,13 @@ def test_bounded_mse(preds, targets, lt_targets, gt_targets, mse):

@pytest.mark.parametrize(
    "preds,targets,likelihood",
-    [(
-        torch.tensor([[0, 1]], dtype=float),
-        torch.zeros([1, 1], dtype=float),
-        [[0.3989]],
-    )]
+    [
+        (
+            torch.tensor([[0, 1]], dtype=float),
+            torch.zeros([1, 1], dtype=float),
+            [[0.3989]],
+        )
+    ],
)
def test_mve(preds, targets, likelihood):
    """
@@ -150,19 +151,9 @@ def test_mve(preds, targets, likelihood):
@pytest.mark.parametrize(
    "alphas,target_labels,lam,expected_loss",
    [
-        (
-            torch.tensor([[2, 2]], dtype=float),
-            torch.ones([1, 1], dtype=float),
-            0,
-            [[0.6]]
-        ),
-        (
-            torch.tensor([[2, 2]], dtype=float),
-            torch.ones([1, 1], dtype=float),
-            0.2,
-            [[0.63862943]]
-        )
-    ]
+        (torch.tensor([[2, 2]], dtype=float), torch.ones([1, 1], dtype=float), 0, [[0.6]]),
+        (torch.tensor([[2, 2]], dtype=float), torch.ones([1, 1], dtype=float), 0.2, [[0.63862943]]),
+    ],
)
def test_dirichlet(alphas, target_labels, lam, expected_loss):
    """
@@ -181,7 +172,7 @@ def test_dirichlet(alphas, target_labels, lam, expected_loss):
            torch.ones([1, 1], dtype=float),
            torch.ones([1, 1], dtype=float),
        ),
-    ]
+    ],
)
def test_dirichlet_wrong_dimensions(alphas, target_labels):
    """
@@ -195,19 +186,9 @@ def test_dirichlet_wrong_dimensions(alphas, target_labels):
@pytest.mark.parametrize(
    "alphas,targets,lam,expected_loss",
    [
-        (
-            torch.tensor([[2, 2, 2, 2]], dtype=float),
-            torch.ones([1, 1], dtype=float),
-            0,
-            [[1.56893861]]
-        ),
-        (
-            torch.tensor([[2, 2, 2, 2]], dtype=float),
-            torch.ones([1, 1], dtype=float),
-            0.2,
-            [[2.768938541]]
-        )
-    ]
+        (torch.tensor([[2, 2, 2, 2]], dtype=float), torch.ones([1, 1], dtype=float), 0, [[1.56893861]]),
+        (torch.tensor([[2, 2, 2, 2]], dtype=float), torch.ones([1, 1], dtype=float), 0.2, [[2.768938541]]),
+    ],
)
def test_evidential(alphas, targets, lam, expected_loss):
    """
@@ -226,7 +207,7 @@ def test_evidential(alphas, targets, lam, expected_loss):
            torch.ones([2, 2], dtype=float),
            torch.ones([2, 2], dtype=float),
        ),
-    ]
+    ],
)
def test_evidential_wrong_dimensions(alphas, targets):
    """
@@ -235,3 +216,51 @@ def test_evidential_wrong_dimensions(alphas, targets):
    """
    with pytest.raises(RuntimeError):
        evidential_loss(alphas, targets)
+
+
+@pytest.mark.parametrize(
+    "predictions,targets,data_weights,mask,expected_loss",
+    [
+        (
+            torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
+            torch.tensor([0, 0, 0], dtype=int),
+            torch.tensor([[1], [1], [1]], dtype=float),
+            torch.tensor([True, True, True], dtype=bool),
+            0.0,
+        ),
+        (
+            torch.tensor([[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.6, 0.3, 0.1]], dtype=float),
+            torch.tensor([1, 2, 0], dtype=int),
+            torch.tensor([[1], [1], [1]], dtype=float),
+            torch.tensor([True, True, True], dtype=bool),
+            0.0,
+        ),
+        (
+            torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
+            torch.tensor([2, 0, 2], dtype=int),
+            torch.tensor([[1], [1], [1]], dtype=float),
+            torch.tensor([True, True, True], dtype=bool),
+            0.6123724356957946,
+        ),
+        (
+            torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
+            torch.tensor([2, 0, 2], dtype=int),
+            torch.tensor([[0.5], [1], [1.5]], dtype=float),
+            torch.tensor([True, True, True], dtype=bool),
+            0.7462025072446364,
+        ),
+        (
+            torch.tensor([[0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]], dtype=float),
+            torch.tensor([2, 0, 2], dtype=int),
+            torch.tensor([[1], [1], [1]], dtype=float),
+            torch.tensor([False, True, True], dtype=bool),
+            1.0,
+        ),
+    ],
+)
+def test_multiclass_mcc(predictions, targets, data_weights, mask, expected_loss):
+    """
+    Test the multiclass MCC loss function by comparing to sklearn's results.
+    """
+    loss = mcc_multiclass_loss(predictions, targets, data_weights, mask)
+    np.testing.assert_almost_equal(loss.item(), expected_loss)
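
For reference, the expected values above can be reproduced directly with sklearn (an illustrative sketch, not part of the commit); the hard class predictions are the per-row argmax of the probability matrices:

from sklearn.metrics import matthews_corrcoef

print(matthews_corrcoef([2, 0, 2], [1, 0, 2]))                               # 0.6123724356957946
print(matthews_corrcoef([2, 0, 2], [1, 0, 2], sample_weight=[0.5, 1, 1.5]))  # 0.7462025072446364
print(matthews_corrcoef([0, 2], [0, 2]))                                     # 1.0 (masked-out row dropped)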
