Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- added weighting for MSELoss + test
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Jan 8, 2018
1 parent 9ea69f8 commit b0d03fd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
30 changes: 30 additions & 0 deletions inferno/extensions/criteria/elementwise_measures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch.nn as nn
from torch.autograd import Variable
from ...utils.exceptions import assert_


class WeightedMSELoss(nn.Module):
NEGATIVE_CLASS_WEIGHT = 1.

def __init__(self, positive_class_weight=1., positive_class_value=1., size_average=True):
super(WeightedMSELoss, self).__init__()
assert_(positive_class_weight >= 0,
"Positive class weight can't be less than zero, got {}."
.format(positive_class_weight),
ValueError)
self.mse = nn.MSELoss(size_average=size_average)
self.positive_class_weight = positive_class_weight
self.positive_class_value = positive_class_value

def forward(self, input, target):
# Get a mask
positive_class_mask = target.data.eq(self.positive_class_value).type_as(target.data)
# Get differential weights (positive_weight - negative_weight,
# i.e. subtract 1, assuming the negative weight is gauged at 1)
weight_differential = (positive_class_mask
.mul_(self.positive_class_weight - self.NEGATIVE_CLASS_WEIGHT))
# Get final weight by adding weight differential to a tensor with negative weights
weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT)
# `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with.
sqrt_weights = Variable(weights.sqrt_(), requires_grad=False)
return self.mse(input * sqrt_weights, target * sqrt_weights)
20 changes: 20 additions & 0 deletions tests/extensions/criteria/elementwise_measures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
import inferno.extensions.criteria.elementwise_measures as em
import torch
from torch.autograd import Variable


class TestElementwiseMeasures(unittest.TestCase):
def test_weighted_mse_loss(self):
input = Variable(torch.zeros(10, 10))
target = Variable(torch.ones(10, 10))
loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target)
self.assertAlmostEqual(loss.data[0], 2., delta=1e-5)
target = Variable(torch.zeros(10, 10))
input = Variable(torch.ones(10, 10))
loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target)
self.assertAlmostEqual(loss.data[0], 1., delta=1e-5)


if __name__ == '__main__':
unittest.main()

0 comments on commit b0d03fd

Please sign in to comment.