Merge pull request tensorflow#22231 from MichaelKonobeev:sparse-xent-op-hessian

PiperOrigin-RevId: 260802377
tensorflower-gardener authored and mknbv committed Aug 16, 2019
1 parent b501738 commit b83cc0a
Showing 4 changed files with 78 additions and 42 deletions.
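
As background for the diff below: this change removes the prevent_gradient barrier on the sparse cross-entropy gradient, so tf.gradients can be applied twice (and tf.hessians becomes usable). A minimal graph-mode sketch of that usage, not part of the commit, assuming the tf.compat.v1 API is available; the constants mirror the updated test and the names (loss, v, hvp) are illustrative only:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()  # only needed when running under TF 2.x

labels = tf.constant([3, 0, 1])
logits = tf.constant([[0.3, 0.4, 0.1, 1.2],
                      [0.1, 1.9, 0.1, 0.7],
                      [0.8, 0.2, 1.3, 1.3]], dtype=tf.float64)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# First-order gradient, then a gradient of the gradient; before this commit the
# second tf.gradients call raised a LookupError ("... explicitly disabled").
grad = tf.gradients(loss, [logits])[0]
v = tf.eye(3, num_columns=4, dtype=tf.float64)  # arbitrary direction for a Hessian-vector product
hvp = tf.gradients(tf.reduce_sum(grad * v), [logits])[0]

with tf.Session() as sess:
  print(sess.run(hvp))
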
2 changes: 1 addition & 1 deletion tensorflow/python/eager/pywrap_tfe_src.cc
@@ -2371,7 +2371,6 @@ bool OpGradientDoesntRequireInputIndices(
{"Relu6", {true, {}}},
{"Elu", {true, {}}},
{"Selu", {true, {}}},
{"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
{"Neg", {true, {}}},
{"Inv", {true, {}}},
{"Reciprocal", {true, {}}},
@@ -2389,6 +2388,7 @@

// Ops that don't require a subset of inputs.
{"FusedBatchNorm", {false, {2}}},
{"SparseSoftmaxCrossEntropyWithLogits", {false, {1}}},
});

auto it = m->find(op_name);
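
In this eager-mode registry, {false, {1}} reads as: input 1 (the labels) is not needed by the gradient and may be released, but input 0 (the logits) must be kept, since the rewritten gradient in nn_grad.py below uses op.inputs[0]. A nested-GradientTape sketch of the second-order computation this supports, not part of the commit and assuming TF 2.x eager execution:

import tensorflow as tf

labels = tf.constant([3, 0, 1])
logits = tf.constant([[0.3, 0.4, 0.1, 1.2],
                      [0.1, 1.9, 0.1, 0.7],
                      [0.8, 0.2, 1.3, 1.3]], dtype=tf.float64)

with tf.GradientTape() as outer_tape:
  outer_tape.watch(logits)
  with tf.GradientTape() as inner_tape:
    inner_tape.watch(logits)
    loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))
  grad = inner_tape.gradient(loss, logits)   # softmax(logits) - one_hot(labels)
hessian = outer_tape.jacobian(grad, logits)  # shape [3, 4, 3, 4], block-diagonal across examples
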
57 changes: 39 additions & 18 deletions tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -24,6 +24,7 @@
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -36,9 +37,7 @@
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import app
from tensorflow.python.platform import test
@@ -192,7 +191,7 @@ def testEmpty(self):

@test_util.run_deprecated_v1
def testGradient(self):
with self.session(use_gpu=True):
with self.session(use_gpu=True) as sess:
l = constant_op.constant([3, 0, 1], name="l")
f = constant_op.constant(
[0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
@@ -202,26 +201,48 @@ def testGradient(self):
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])

# Check that no extra computation is performed. When only the first
# derivative is requested, the second derivative must not be computed, so
# there should be no `BatchMatMul` op in the graph.
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertNotIn("BatchMatMul", op_names)
self.assertNotIn("BatchMatMulV2", op_names)

print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)

@test_util.run_deprecated_v1
def testSecondGradient(self):
images_placeholder = array_ops.placeholder(dtypes.float32, shape=(3, 2))
labels_placeholder = array_ops.placeholder(dtypes.int32, shape=(3))
weights = variables.Variable(random_ops.truncated_normal([2], stddev=1.0))
weights_with_zeros = array_ops.stack([array_ops.zeros([2]), weights],
axis=1)
logits = math_ops.matmul(images_placeholder, weights_with_zeros)
cross_entropy = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels_placeholder, logits=logits)
loss = math_ops.reduce_mean(cross_entropy)

# Taking the second gradient should fail, since it is not
# yet supported.
with self.assertRaisesRegexp(LookupError,
"explicitly disabled"):
_ = gradients_impl.hessians(loss, [weights])
with self.session() as sess:
l = constant_op.constant([3, 0, 1], name="l")
f = constant_op.constant(
[0.3, 0.4, 0.1, 1.2, 0.1, 1.9, 0.1, 0.7, 0.8, 0.2, 1.3, 1.3],
shape=[3, 4],
dtype=dtypes.float64,
name="f")
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")

gradients = gradients_impl.gradients(x, [f])[0]
err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
[3, 4])

# Check that the second derivative is calculated
# (its presence is equivalent to a `BatchMatMul` op appearing in the graph,
# given the implementation of the xentropy grad).
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
if compat.forward_compatible(2019, 4, 25):
self.assertIn("BatchMatMulV2", op_names)
else:
self.assertIn("BatchMatMul", op_names)

print("cross entropy hessian err = ", err)
self.assertLess(err, 5e-8)

def _testHighDim(self, features, labels):
np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
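
The BatchMatMul assertions above hold because the grad-of-grad term added in nn_grad.py multiplies a [batch, 1, classes] tensor by a [batch, classes, 1] tensor, and batched matmul lowers to a BatchMatMul op (BatchMatMulV2 after the 2019-04-25 forward-compatibility date checked in the test). A shape-only sketch, not part of the commit, with made-up inputs:

import tensorflow as tf

grad_grad = tf.random.normal([3, 4], dtype=tf.float64)           # upstream gradient for the softmax-grad output
softmax = tf.nn.softmax(tf.random.normal([3, 4], dtype=tf.float64))
inner = tf.matmul(tf.expand_dims(grad_grad, 1),                   # [3, 1, 4]
                  tf.expand_dims(softmax, 2))                     # [3, 4, 1] -> [3, 1, 1], a batched matmul
correction = (grad_grad - tf.squeeze(inner, axis=1)) * softmax    # [3, 4], matches the term in nn_grad.py
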
1 change: 1 addition & 0 deletions tensorflow/python/kernel_tests/xent_op_test.py
@@ -242,6 +242,7 @@ def testGradient(self):
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertNotIn("BatchMatMul", op_names)
self.assertNotIn("BatchMatMulV2", op_names)

print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
60 changes: 37 additions & 23 deletions tensorflow/python/ops/nn_grad.py
@@ -513,6 +513,24 @@ def _BroadcastMul(vec, mat):
return vec * mat


def _IsZero(tensor):
"""Check if tensor contains only zeros.
Args:
tensor: tensor to check
Returns:
True if tensor contains only zeros and False otherwise
"""
if context.executing_eagerly():
# TODO(apassos) add an efficient way to detect eager zeros here.
return False
if tensor.op.type in ("ZerosLike", "Zeros"):
return True
const_fill_value = tensor_util.constant_value(tensor)
return const_fill_value is not None and (const_fill_value == 0).all()


@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
"""Gradient function for SoftmaxCrossEntropyWithLogits."""
@@ -524,18 +542,8 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
softmax_grad = op.outputs[1]
grad = _BroadcastMul(grad_loss, softmax_grad)

def IsZero(g):
# Some introspection to check if the gradient is feeding zeros
if context.executing_eagerly():
# TODO(apassos) add an efficient way to detect eager zeros here.
return False
if g.op.type in ("ZerosLike", "Zeros"):
return True
const_fill_value = tensor_util.constant_value(g)
return const_fill_value is not None and (const_fill_value == 0).all()

logits = op.inputs[0]
if grad_grad is not None and not IsZero(grad_grad):
if grad_grad is not None and not _IsZero(grad_grad):
softmax = nn_ops.softmax(logits)

grad += ((grad_grad - array_ops.squeeze(
@@ -548,22 +556,28 @@ def IsZero(g):


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
"""Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
# grad_0 is the backprop for cost, and we multiply it with the gradients
# grad_loss is the backprop for cost, and we multiply it with the gradients
# (which is output[1])
# grad_grad is the backprop for softmax gradient.
# There is no gradient for the labels
#
# Currently there is no way to take the second derivative of this op
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
op.outputs[1],
message="Currently there is no way to take the second "
"derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
"implementation's interaction with tf.gradients()")
return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
# Second derivative is just softmax derivative w.r.t. logits.
softmax_grad = op.outputs[1]
grad = _BroadcastMul(grad_loss, softmax_grad)

logits = op.inputs[0]
if grad_grad is not None and not _IsZero(grad_grad):
softmax = nn_ops.softmax(logits)

grad += ((grad_grad - array_ops.squeeze(
math_ops.matmul(
array_ops.expand_dims(grad_grad, 1),
array_ops.expand_dims(softmax, 2)),
axis=1)) * softmax)

return grad, None


@ops.RegisterGradient("Conv2D")
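
The correction term added above matches the standard second derivative of softmax cross-entropy; a short derivation, not part of the commit, in notation the diff does not use. For one example with label y and p = softmax(x):

L(x) = -\log p_y, \qquad \nabla_x L = p - e_y, \qquad \nabla_x^2 L = \operatorname{diag}(p) - p\,p^\top .

For an incoming vector v (the grad_grad of the code), the Hessian-vector product is

(\operatorname{diag}(p) - p\,p^\top)\, v = p \odot v - (p^\top v)\, p = \bigl(v - \langle v, p \rangle\bigr) \odot p ,

which is exactly grad += (grad_grad - squeeze(matmul(expand_dims(grad_grad, 1), expand_dims(softmax, 2)), axis=1)) * softmax, applied row-wise over the batch; the same expression already appears in _SoftmaxCrossEntropyWithLogitsGrad for the dense op.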
