Skip to content

Commit

Permalink
Rolling back PR tensorflow#44894 because it caused type check errors …
Browse files Browse the repository at this point in the history
…(got bfloat16 when expecting float).

PiperOrigin-RevId: 346214899
Change-Id: I291592f14a7e3b087e1c26d29d7ecdaef4bc2fed
  • Loading branch information
penpornk authored and tensorflower-gardener committed Dec 8, 2020
1 parent 4764f16 commit bb59188
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 49 deletions.
3 changes: 0 additions & 3 deletions tensorflow/core/kernels/sparse_xent_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_xent_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
Expand Down Expand Up @@ -123,8 +122,6 @@ REGISTER(CPU, float, int32)
REGISTER(CPU, float, int64)
REGISTER(CPU, double, int32)
REGISTER(CPU, double, int64)
REGISTER(CPU, bfloat16, int32)
REGISTER(CPU, bfloat16, int64)
REGISTER(CPU, Eigen::half, int32)
REGISTER(CPU, Eigen::half, int64)

Expand Down
57 changes: 28 additions & 29 deletions tensorflow/core/kernels/sparse_xent_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ limitations under the License.

namespace tensorflow {

static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
static Graph* SparseXent(int batch_size, int num_classes) {
Graph* g = new Graph(OpRegistry::Global());
Tensor logits(value_type, TensorShape({batch_size, num_classes}));
Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
logits.flat<float>().setRandom();
Tensor labels(DT_INT64, TensorShape({batch_size}));
std::random_device rd;
Expand All @@ -41,45 +41,44 @@ static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
return g;
}

#define BM_SparseXentDev(BATCH, CLASS, DEVICE, DTYPE) \
static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE( \
#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \
static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE( \
::testing::benchmark::State& state) { \
test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS, DTYPE), \
test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS), \
/*old_benchmark_api*/ false) \
.Run(state); \
state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
CLASS); \
} \
BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE);

#define BM_SPARSE_XENT_DEV_CPU(DTYPE) \
BM_SparseXentDev(8, 1000000, cpu, DTYPE); \
BM_SparseXentDev(16, 10000, cpu, DTYPE); \
BM_SparseXentDev(16, 100000, cpu, DTYPE); \
BM_SparseXentDev(32, 10000, cpu, DTYPE); \
BM_SparseXentDev(32, 100000, cpu, DTYPE); \
BM_SparseXentDev(64, 10000, cpu, DTYPE); \
BM_SparseXentDev(64, 100000, cpu, DTYPE);

// CPU
BM_SPARSE_XENT_DEV_CPU(DT_FLOAT);
BM_SPARSE_XENT_DEV_CPU(DT_BFLOAT16);
BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE);

/// The representative tests for ptb_word on GPU
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
BM_SparseXentDev(8, 1000000, gpu, DT_FLOAT);
BM_SparseXentDev(8, 1000000, gpu);

BM_SparseXentDev(16, 10000, gpu, DT_FLOAT);
BM_SparseXentDev(16, 30000, gpu, DT_FLOAT);
BM_SparseXentDev(16, 100000, gpu, DT_FLOAT);
BM_SparseXentDev(16, 10000, gpu);
BM_SparseXentDev(16, 30000, gpu);
BM_SparseXentDev(16, 100000, gpu);

BM_SparseXentDev(32, 10000, gpu, DT_FLOAT);
BM_SparseXentDev(32, 30000, gpu, DT_FLOAT);
BM_SparseXentDev(32, 100000, gpu, DT_FLOAT);
BM_SparseXentDev(32, 10000, gpu);
BM_SparseXentDev(32, 30000, gpu);
BM_SparseXentDev(32, 100000, gpu);

BM_SparseXentDev(64, 10000, gpu, DT_FLOAT);
BM_SparseXentDev(64, 30000, gpu, DT_FLOAT);
BM_SparseXentDev(64, 100000, gpu, DT_FLOAT);
BM_SparseXentDev(64, 10000, gpu);
BM_SparseXentDev(64, 30000, gpu);
BM_SparseXentDev(64, 100000, gpu);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// CPU
BM_SparseXentDev(8, 1000000, cpu);

BM_SparseXentDev(16, 10000, cpu);
BM_SparseXentDev(16, 100000, cpu);

BM_SparseXentDev(32, 10000, cpu);
BM_SparseXentDev(32, 100000, cpu);

BM_SparseXentDev(64, 10000, cpu);
BM_SparseXentDev(64, 100000, cpu);

} // end namespace tensorflow
17 changes: 0 additions & 17 deletions tensorflow/python/kernel_tests/sparse_xent_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,23 +182,6 @@ def testDouble(self):
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
np.array([0, 3]).astype(label_dtype))

def testBfloat16(self):
for label_dtype in np.int32, np.int64:
np_features = np.array([[1., 1., 1., 1.], [1., 2., 3.,
4.]]).astype(np.float32)
np_labels = np.array([0, 3]).astype(label_dtype)
np_loss, np_backprop = self._npXent(np_features, np_labels)

np_features_bf16 = math_ops.cast(np_features, dtypes.bfloat16)
np_loss_bf16 = math_ops.cast(np_loss, dtypes.bfloat16)
np_backprop_bf16 = math_ops.cast(np_backprop, dtypes.bfloat16)
with self.cached_session(use_gpu=False):
loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
np_features_bf16, np_labels)
tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllCloseAccordingToType(np_loss_bf16, tf_loss)
self.assertAllCloseAccordingToType(np_backprop_bf16, tf_backprop)

def testHalf(self):
for label_dtype in np.int32, np.int64:
self._testXent(
Expand Down

0 comments on commit bb59188

Please sign in to comment.