Support keras squared_hinge and sparse_categorical_crossentropy (#1865)
* add squared hinge loss

* classNLL support prob as input

* add python api

* meet code review

* fix python tests

* more doc

* more doc
yangw1234 committed Nov 14, 2017
1 parent 71498f0 commit 6b23f56
Showing 8 changed files with 258 additions and 42 deletions.
20 changes: 12 additions & 8 deletions docs/docs/APIGuide/Losses.md
@@ -263,24 +263,25 @@ Gives the gradInput,

**Scala:**
```scala
val criterion = ClassNLLCriterion(weights = null, sizeAverage = true)
val criterion = ClassNLLCriterion(weights = null, sizeAverage = true, logProbAsInput = true)
```
**Python:**
```python
criterion = ClassNLLCriterion(weights=None, size_average=True)
criterion = ClassNLLCriterion(weights=None, size_average=True, logProbAsInput=True)
```

The negative log likelihood criterion. It is useful to train a classification problem with n
classes. If provided, the optional argument weights should be a 1D Tensor assigning weight to
each of the classes. This is particularly useful when you have an unbalanced training set.

The input given through a `forward()` is expected to contain log-probabilities of each class:
input has to be a 1D Tensor of size `n`. Obtaining log-probabilities in a neural network is easily
achieved by adding a `LogSoftMax` layer in the last layer of your neural network. You may use
The input given through a `forward()` is expected to contain log-probabilities/probabilities of each class:
input has to be a 1D Tensor of size `n`. Obtaining log-probabilities/probabilities in a neural network is easily
achieved by adding a `LogSoftMax`/`SoftMax` layer in the last layer of your neural network. You may use
`CrossEntropyCriterion` instead, if you prefer not to add an extra layer to your network. This
criterion expects a class index (1 to the number of classes) as target when calling
`forward(input, target)` and `backward(input, target)`.

In the log-probabilities case, the loss can be described as:
`loss(x, class) = -x[class]`
or in the case of the weights argument it is specified as follows:
@@ -300,6 +301,8 @@ Parameters:
* `weights` weights of each element of the input
* `sizeAverage` whether to normalize the loss by the number of elements in the input.
Default: true
* `logProbAsInput` whether the input contains log-probabilities (`true`, the default) or plain probabilities (`false`).
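
A minimal sketch of the two input modes (assuming the standard BigDL imports used elsewhere in this guide; the tensor values are random and purely illustrative):

```scala
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, LogSoftMax, SoftMax}
import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.tensor.Tensor

// random scores for a batch of 2 samples over 3 classes, with 1-based class targets
val scores = Tensor(2, 3).rand()
val target = Tensor(2)
target.setValue(1, 1f)
target.setValue(2, 3f)

// default mode: the criterion expects log-probabilities, e.g. the output of LogSoftMax
val logProbs = LogSoftMax().forward(scores)
val nll = ClassNLLCriterion()
val lossFromLogProbs = nll.forward(logProbs, target)

// logProbAsInput = false: the criterion takes plain probabilities, e.g. the output of SoftMax
val probs = SoftMax().forward(scores)
val nllOnProbs = ClassNLLCriterion(logProbAsInput = false)
val lossFromProbs = nllOnProbs.forward(probs, target)
```

Both calls should produce approximately the same loss, since the probability mode simply clips the input and takes the log internally.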

**Scala example:**
```scala
@@ -799,16 +802,17 @@ creating: createHingeEmbeddingCriterion

**Scala:**
```scala
criterion = MarginCriterion(margin=1.0, sizeAverage=true)
val criterion = MarginCriterion(margin = 1.0, sizeAverage = true, squared = false)
```
**Python:**
```python
criterion = MarginCriterion(margin=1.0, sizeAverage=true, bigdl_type="float")
criterion = MarginCriterion(margin=1.0, sizeAverage=True, squared=False, bigdl_type="float")
```

Creates a criterion that optimizes a two-class classification hinge loss (margin-based loss) between input x (a Tensor of dimension 1) and output y.
Creates a criterion that optimizes a two-class classification (squared) hinge loss (margin-based loss) between input x (a Tensor of dimension 1) and output y.
* `margin` the margin value; 1.0 by default.
* `sizeAverage` whether to average the loss; true by default.
* `squared` whether to calculate the squared hinge loss
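
Per the keras equivalences noted in the doc comments of this change, a minimal sketch of the two variants (random values, purely illustrative):

```scala
import com.intel.analytics.bigdl.nn.MarginCriterion
import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.tensor.Tensor

// predictions in [0, 1) and +1/-1 labels for 4 samples
val input = Tensor(4).rand()
val target = Tensor(4)
target.setValue(1, 1f)
target.setValue(2, -1f)
target.setValue(3, 1f)
target.setValue(4, -1f)

// plain hinge loss (matches keras `hinge` when margin = 1 and sizeAverage = true)
val hinge = MarginCriterion()
val hingeLoss = hinge.forward(input, target)

// squared hinge loss (matches keras `squared_hinge` when margin = 1 and sizeAverage = false)
val squaredHinge = MarginCriterion(margin = 1.0, sizeAverage = false, squared = true)
val squaredHingeLoss = squaredHinge.forward(input, target)
```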

**Scala example:**
```scala
33 changes: 22 additions & 11 deletions pyspark/bigdl/nn/criterion.py
@@ -103,13 +103,14 @@ class ClassNLLCriterion(Criterion):
classes. If provided, the optional argument weights should be a 1D Tensor assigning weight to
each of the classes. This is particularly useful when you have an unbalanced training set.
The input given through a forward() is expected to contain log-probabilities of each class:
input has to be a 1D Tensor of size n. Obtaining log-probabilities in a neural network is easily
achieved by adding a LogSoftMax layer in the last layer of your neural network. You may use
CrossEntropyCriterion instead, if you prefer not to add an extra layer to your network. This
criterion expects a class index (1 to the number of class) as target when calling
forward(input, target) and backward(input, target).
The input given through a forward() is expected to contain log-probabilities/probabilities of
each class: input has to be a 1D Tensor of size n. Obtaining log-probabilities/probabilities
in a neural network is easily achieved by adding a LogSoftMax/SoftMax layer in the last layer
of your neural network. You may use CrossEntropyCriterion instead, if you prefer not to add an
extra layer to your network. This criterion expects a class index (1 to the number of classes) as
target when calling forward(input, target) and backward(input, target).
In the log-probabilities case, the loss can be described as:
loss(x, class) = -x[class]
or in the case of the weights argument it is specified as follows:
@@ -124,14 +125,18 @@ class ClassNLLCriterion(Criterion):
By default, the losses are averaged over observations for each minibatch. However, if the field
sizeAverage is set to false, the losses are instead summed for each minibatch.
In particular, when weights=None, size_average=True and logProbAsInput=False, this is the same as
the `sparse_categorical_crossentropy` loss in keras.
:param weights: weights of each class
:param size_average: whether to average or not
:param logProbAsInput: whether to accept log-probabilities (True, the default) or probabilities (False) as input.
>>> np.random.seed(123)
>>> weights = np.random.uniform(0, 1, (2,)).astype("float32")
>>> classNLLCriterion = ClassNLLCriterion(weights,True)
>>> classNLLCriterion = ClassNLLCriterion(weights, True, True)
creating: createClassNLLCriterion
>>> classNLLCriterion = ClassNLLCriterion()
creating: createClassNLLCriterion
@@ -140,10 +145,11 @@ class ClassNLLCriterion(Criterion):
def __init__(self,
weights=None,
size_average=True,
logProbAsInput=True,
bigdl_type="float"):
super(ClassNLLCriterion, self).__init__(None, bigdl_type,
JTensor.from_ndarray(weights),
size_average)
size_average, logProbAsInput)


class MSECriterion(Criterion):
@@ -344,22 +350,27 @@ class MarginCriterion(Criterion):
Creates a criterion that optimizes a two-class classification hinge loss (margin-based loss)
between input x (a Tensor of dimension 1) and output y.
When margin = 1, size_average = True and squared = False, this is the same as hinge loss in keras;
When margin = 1, size_average = False and squared = True, this is the same as squared_hinge loss in keras.
:param margin: the margin value; 1.0 by default.
:param size_average: whether to average the loss over a mini-batch
:param squared: whether to calculate the squared hinge loss
>>> marginCriterion = MarginCriterion(1e-5, True)
>>> marginCriterion = MarginCriterion(1e-5, True, False)
creating: createMarginCriterion
'''

def __init__(self,
margin=1.0,
size_average=True,
squared=False,
bigdl_type="float"):
super(MarginCriterion, self).__init__(None, bigdl_type,
margin,
size_average)
size_average,
squared)


class MarginRankingCriterion(Criterion):
Expand Down
@@ -30,17 +30,19 @@ import com.intel.analytics.bigdl.utils.Engine
* classes. If provided, the optional argument weights should be a 1D Tensor assigning weight to
* each of the classes. This is particularly useful when you have an unbalanced training set.
*
* The input given through a forward() is expected to contain log-probabilities of each class:
* input has to be a 1D Tensor of size n. Obtaining log-probabilities in a neural network is easily
* achieved by adding a LogSoftMax layer in the last layer of your neural network. You may use
* CrossEntropyCriterion instead, if you prefer not to add an extra layer to your network. This
* criterion expects a class index (1 to the number of class) as target when calling
* forward(input, target) and backward(input, target).
* The input given through a forward() is expected to contain log-probabilities/probabilities of
* each class: input has to be a 1D Tensor of size n. Obtaining log-probabilities/probabilities
* in a neural network is easily achieved by adding a LogSoftMax/SoftMax layer in the last layer
* of your neural network. You may use CrossEntropyCriterion instead, if you prefer not to add
 * an extra layer to your network. This criterion expects a class index (1 to the number of classes)
* as target when calling forward(input, target) and backward(input, target).
*
 * In the log-probabilities case, the loss can be described as:
* loss(x, class) = -x[class]
* or in the case of the weights argument it is specified as follows:
* loss(x, class) = -weights[class] * x[class]
*
* Due to the behaviour of the backend code, it is necessary to set sizeAverage to false when
* calculating losses in non-batch mode.
*
@@ -51,14 +53,19 @@ import com.intel.analytics.bigdl.utils.Engine
* By default, the losses are averaged over observations for each minibatch. However, if the field
* sizeAverage is set to false, the losses are instead summed for each minibatch.
*
 * In particular, when weights = null, sizeAverage = true and logProbAsInput = false, this is the
 * same as the `sparse_categorical_crossentropy` loss in keras.
*
* @param weights weights of each element of the input
 * @param sizeAverage whether to average the loss over the batch
 * @param logProbAsInput whether to accept log-probabilities (true, the default) or probabilities
 * (false) as input.
* @param ev numeric operator
* @tparam T numeric type
*/
@SerialVersionUID(- 8696382776046599502L)
class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
(weights: Tensor[T] = null, sizeAverage: Boolean = true)
(weights: Tensor[T] = null, sizeAverage: Boolean = true, logProbAsInput: Boolean = true)
(implicit ev: TensorNumeric[T]) extends TensorCriterion[T] {
private var total_weight = ev.fromType[Int](0)
if (weights != null) require(weights.dim() == 1,
@@ -70,6 +77,12 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
@transient
private var resultsBackward: Array[Future[_]] = null

private val epsilon: T = ev.fromType(1e-8)

private val oneMinusEpsilon: T = ev.minus(ev.one, epsilon)
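// epsilon and oneMinusEpsilon clip probability inputs away from 0 and 1 before the log is
// taken (only when logProbAsInput is false), so that log(0) never appears in the loss.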



override def updateOutput(input: Tensor[T], target: Tensor[T]): T = {
require(input.dim() == 1 || input.dim() == 2,
"ClassNLLCriterion: " +
@@ -85,7 +98,15 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
s"curTarget ${curTarget} is out of range, should be 1 to ${nClasses}")
total_weight = if (weights != null) weights(Array(curTarget)) else ev.fromType[Int](1)
output = if (curTarget == -1) ev.zero
else ev.times(ev.negative(input.valueAt(curTarget)), total_weight)
else {
if (!logProbAsInput) {
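// probability input: clip to [epsilon, 1 - epsilon] and take the log explicitly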
val clipped = ev.clip(input.valueAt(curTarget), epsilon, oneMinusEpsilon)
ev.times(ev.negative(ev.log(clipped)), total_weight)
} else {
ev.times(ev.negative(input.valueAt(curTarget)), total_weight)
}
}

} else if (input.dim() == 2) {
val batchSize = input.size(1)
val targetSize = target.size()
@@ -111,7 +132,13 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
if (curTarget == -1) (ev.zero, ev.one)
else {
val curWeight = if (weights != null) weights.valueAt(curTarget) else ev.fromType[Int](1)
(ev.times(input.valueAt(_i, curTarget), curWeight), curWeight)
if (!logProbAsInput) {
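// same probability clipping as the 1D case, applied per sample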
val clipped = ev.clip(input.valueAt(_i, curTarget), epsilon, oneMinusEpsilon)
(ev.times(ev.log(clipped), curWeight), curWeight)
} else {
(ev.times(input.valueAt(_i, curTarget), curWeight), curWeight)
}

}
})
i += 1
@@ -152,6 +179,11 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
else ev.fromType[Int](-1))
if (sizeAverage) gradInput.setValue(curTarget, ev.divide(gradInput.valueAt(curTarget),
total_weight))
if (!logProbAsInput) {
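// probability input: d(-w * log(p))/dp = -w / p, so scale the log-probability gradient by 1 / clipped(p)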
val clipped = ev.clip(input.valueAt(curTarget), epsilon, oneMinusEpsilon)
gradInput.setValue(curTarget,
ev.times(gradInput.valueAt(curTarget), ev.inv(clipped)))
}
}
else if (input.dim() == 2) {
val batchSize = input.size(1)
@@ -172,6 +204,11 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
else ev.fromType[Int](-1))
if (sizeAverage) gradInput.setValue(_i, curTarget, ev.divide(gradInput.valueAt(_i,
curTarget), total_weight))
if (!logProbAsInput) {
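// same 1 / p scaling as the 1D gradient, applied per sample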
val clipped = ev.clip(input.valueAt(_i, curTarget), epsilon, oneMinusEpsilon)
gradInput.setValue(_i, curTarget,
ev.times(gradInput.valueAt(_i, curTarget), ev.inv(clipped)))
}
}
})
i += 1
@@ -191,7 +228,8 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
object ClassNLLCriterion {
def apply[@specialized(Float, Double) T: ClassTag](
weights: Tensor[T] = null,
sizeAverage: Boolean = true)(implicit ev: TensorNumeric[T]) : ClassNLLCriterion[T] = {
new ClassNLLCriterion[T](weights, sizeAverage)
sizeAverage: Boolean = true,
logProbAsInput: Boolean = true)(implicit ev: TensorNumeric[T]) : ClassNLLCriterion[T] = {
new ClassNLLCriterion[T](weights, sizeAverage, logProbAsInput)
}
}
@@ -22,16 +22,21 @@ import com.intel.analytics.bigdl.tensor.{DenseTensorApply, Tensor, TensorFunc4,
import scala.reflect.ClassTag

/**
* Creates a criterion that optimizes a two-class classification hinge loss (margin-based loss)
* between input x (a Tensor of dimension 1) and output y.
* Creates a criterion that optimizes a two-class classification (squared)
* hinge loss (margin-based loss) between input x (a Tensor of dimension 1) and output y.
*
 * When margin = 1, sizeAverage = true and squared = false, this is the same as the hinge loss in keras;
 * when margin = 1, sizeAverage = false and squared = true, this is the same as the squared_hinge loss
 * in keras.
*
 * @param margin the margin value; 1.0 by default.
* @param sizeAverage whether to average the loss
* @param squared whether to calculate the squared hinge loss
*/

@SerialVersionUID( - 5028892499250398130L)
class MarginCriterion[@specialized(Float, Double) T: ClassTag]
(val margin: Double = 1.0, val sizeAverage: Boolean = true)
(val margin: Double = 1.0, val sizeAverage: Boolean = true, squared: Boolean = false)
(implicit ev: TensorNumeric[T]) extends TensorCriterion[T] {

override def updateOutput(input: Tensor[T], target: Tensor[T]): T = {
@@ -40,7 +45,13 @@ class MarginCriterion[@specialized(Float, Double) T: ClassTag]
val func = new TensorFunc4[T] {
override def apply(data1: Array[T], index1: Int, data2: Array[T], index2: Int): Unit = {
val z = ev.minus(ev.fromType(margin), ev.times(data1(index1), data2(index2)))
if (ev.isGreater(z, ev.fromType(0))) sum = ev.plus(sum, z)
if (ev.isGreater(z, ev.fromType(0))) {
if (squared) {
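// squared hinge accumulates (margin - x * y)^2 instead of (margin - x * y)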
sum = ev.plus(sum, ev.times(z, z))
} else {
sum = ev.plus(sum, z)
}
}
}
}
DenseTensorApply.apply2[T](input, target, func)
@@ -49,15 +60,23 @@ class MarginCriterion[@specialized(Float, Double) T: ClassTag]
}

override def updateGradInput(input: Tensor[T], target: Tensor[T]): Tensor[T] = {
val norm = ev.fromType(if (sizeAverage) -1.0 / input.nElement() else 1.0)
val norm = ev.fromType(if (sizeAverage) -1.0 / input.nElement() else -1.0)
gradInput.resizeAs(input)

// todo: the performance of contiguous tensor should be optimized
val func = new TensorFunc6[T] {
override def apply (data1: Array[T], offset1: Int, data2: Array[T],
offset2: Int, data3: Array[T], offset3: Int): Unit = {
if (ev.isGreater(ev.fromType(margin), ev.times(data2(offset2), data3(offset3)))) {
data1(offset1) = ev.times(norm, data3(offset3))
if (squared) {
// squared hinge gradient: dl/dx = -2 * y * (margin - x * y); norm carries the sign and the optional 1/n
data1(offset1) = ev.times(
ev.times(ev.times(ev.fromType(2), norm), data3(offset3)),
ev.minus(ev.fromType(margin),
ev.times(data2(offset2), data3(offset3))))
} else {
data1(offset1) = ev.times(norm, data3(offset3))
}
}
}
}
@@ -90,7 +109,8 @@ class MarginCriterion[@specialized(Float, Double) T: ClassTag]
object MarginCriterion {
def apply[@specialized(Float, Double) T: ClassTag](
margin: Double = 1.0,
sizeAverage: Boolean = true)(implicit ev: TensorNumeric[T]) : MarginCriterion[T] = {
new MarginCriterion[T](margin, sizeAverage)
sizeAverage: Boolean = true,
squared: Boolean = false)(implicit ev: TensorNumeric[T]) : MarginCriterion[T] = {
new MarginCriterion[T](margin, sizeAverage, squared)
}
}
@@ -1419,10 +1419,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
}

def createClassNLLCriterion(weights: JTensor = null,
sizeAverage: Boolean = true)
sizeAverage: Boolean = true, logProbAsInput: Boolean = true)
: ClassNLLCriterion[T] = {
ClassNLLCriterion[T](if (weights == null) null else toTensor(weights),
sizeAverage)
sizeAverage, logProbAsInput)
}

def createMSECriterion: MSECriterion[T] = {
@@ -1470,10 +1470,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
}

def createMarginCriterion(margin: Double = 1.0,
sizeAverage: Boolean = true)
sizeAverage: Boolean = true, squared: Boolean = false)
: MarginCriterion[T] = {
MarginCriterion[T](margin,
sizeAverage)
sizeAverage, squared)
}

def createMarginRankingCriterion(margin: Double = 1.0,
