Skip to content

Commit

Permalink
Register batch normalization kernels for OpenCL (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
ville-k authored and Luke committed Apr 10, 2017
1 parent 76ecf13 commit 608336f
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tensorflow/core/kernels/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ namespace tensorflow {

// Short aliases for the Eigen device types. SYCLDevice is used as the Device
// template argument of the SYCL kernel registrations in this file and is only
// declared when TENSORFLOW_USE_SYCL is defined.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
#ifdef TENSORFLOW_USE_SYCL
using SYCLDevice = Eigen::SyclDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
class BatchNormOp : public OpKernel {
Expand Down Expand Up @@ -201,6 +204,18 @@ TF_CALL_float(REGISTER_GPU_KERNEL);

#endif // GOOGLE_CUDA

// Registers BatchNormOp as the "BatchNormWithGlobalNormalization" kernel on
// DEVICE_SYCL for float and double.
//
// The guard is #ifdef (not #if) so that it matches the guard under which the
// SYCLDevice alias is declared: the registrations are then compiled exactly
// when SYCLDevice exists, and a value-less definition of TENSORFLOW_USE_SYCL
// does not turn into a "#if with no expression" preprocessor error.
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_SYCL)                 \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
// Release the macro name so later registration sections can redefine it.
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
.Device(DEVICE_CPU) \
Expand Down Expand Up @@ -248,4 +263,17 @@ TF_CALL_float(REGISTER_GPU_KERNEL);

#endif // GOOGLE_CUDA

// Registers BatchNormGradOp as the "BatchNormWithGlobalNormalizationGrad"
// kernel on DEVICE_SYCL for float and double.
//
// The guard is #ifdef (not #if) so that it matches the guard under which the
// SYCLDevice alias is declared: the registrations are then compiled exactly
// when SYCLDevice exists, and a value-less definition of TENSORFLOW_USE_SYCL
// does not turn into a "#if with no expression" preprocessor error.
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_SYCL)                     \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
// Release the macro name so it does not leak past this section.
#undef REGISTER_KERNEL

#endif  // TENSORFLOW_USE_SYCL

} // namespace tensorflow

0 comments on commit 608336f

Please sign in to comment.