Skip to content

Commit

Permalink
[OpenCL] Fixes SYCL stateless random ops (tensorflow#57)
Browse files Browse the repository at this point in the history
* [OpenCL] Create all SYCL FillPhiloxRandom functors

Adds the two distributions currently not being instantiated in the SYCL
kernel registration. These are needed for StatelessRandomOps.

* [OpenCL] Registers SYCL StatelessRandomOps

Provides SYCL device kernels for the StatelessRandomOps.
  • Loading branch information
jwlawson authored and Luke Iwanski committed Jun 8, 2017
1 parent 94285da commit 71f5d40
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 26 deletions.
58 changes: 32 additions & 26 deletions tensorflow/core/kernels/random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -695,32 +695,38 @@ void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(

}

// REGISTER(TYPE): SYCL registration of the Philox-based random kernels for
// element type TYPE.
// NOTE(review): this 26-line span appears to be the removed (pre-commit) side
// of the hunk; confirm against the raw diff.
//
// In this form the macro:
//   * explicitly instantiates functor::FillPhiloxRandom<SYCLDevice, ...> for
//     random::UniformDistribution only;
//   * registers "RandomUniform", "RandomStandardNormal" and "TruncatedNormal"
//     on DEVICE_SYCL, each with .HostMemory("shape") and a "dtype" type
//     constraint of TYPE, implemented by PhiloxRandomOp<SYCLDevice, ...>.
// No FillPhiloxRandom instantiation for NormalDistribution or
// TruncatedNormalDistribution is requested by this version of the macro.
#define REGISTER(TYPE) \
template struct functor::FillPhiloxRandom< \
SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE> >; \
REGISTER_KERNEL_BUILDER( \
Name("RandomUniform") \
.Device(DEVICE_SYCL) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
PhiloxRandomOp<SYCLDevice, random::UniformDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("RandomStandardNormal") \
.Device(DEVICE_SYCL) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
PhiloxRandomOp<SYCLDevice, random::NormalDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("TruncatedNormal") \
.Device(DEVICE_SYCL) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
PhiloxRandomOp< \
SYCLDevice, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
// REGISTER(TYPE): SYCL registration of the Philox-based random kernels for
// element type TYPE.
// NOTE(review): this 32-line span appears to be the added (post-commit) side
// of the hunk; confirm against the raw diff.
//
// The macro first explicitly instantiates functor::FillPhiloxRandom on
// SYCLDevice for all three distributions used here:
//   * random::UniformDistribution<random::PhiloxRandom, TYPE>
//   * random::NormalDistribution<random::PhiloxRandom, TYPE>
//   * random::TruncatedNormalDistribution<
//         random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>
// It then registers "RandomUniform", "RandomStandardNormal" and
// "TruncatedNormal" on DEVICE_SYCL, each with .HostMemory("shape") and a
// "dtype" type constraint of TYPE, implemented by PhiloxRandomOp<SYCLDevice,
// ...> over the matching distribution.
#define REGISTER(TYPE) \
template struct functor::FillPhiloxRandom< \
SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>; \
template struct functor::FillPhiloxRandom< \
SYCLDevice, random::NormalDistribution<random::PhiloxRandom, TYPE>>; \
template struct functor::FillPhiloxRandom< \
SYCLDevice, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>; \
REGISTER_KERNEL_BUILDER( \
Name("RandomUniform") \
.Device(DEVICE_SYCL) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
PhiloxRandomOp<SYCLDevice, random::UniformDistribution< \
random::PhiloxRandom, TYPE>>); \
REGISTER_KERNEL_BUILDER( \
Name("RandomStandardNormal") \
.Device(DEVICE_SYCL) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
PhiloxRandomOp<SYCLDevice, \
random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
REGISTER_KERNEL_BUILDER( \
Name("TruncatedNormal") \
.Device(DEVICE_SYCL) \
.HostMemory("shape") \
.TypeConstraint<TYPE>("dtype"), \
PhiloxRandomOp< \
SYCLDevice, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);

#define REGISTER_INT(IntType) \
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
Expand Down
38 changes: 38 additions & 0 deletions tensorflow/core/kernels/stateless_random_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ namespace tensorflow {

// Short aliases for the Eigen device types.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
// SYCLDevice is declared only when TENSORFLOW_USE_SYCL is defined. The SYCL
// kernel registration macro further down in this file forms this name by
// token pasting (DEV##Device with DEV = SYCL).
#ifdef TENSORFLOW_USE_SYCL
using SYCLDevice = Eigen::SyclDevice;
#endif // TENSORFLOW_USE_SYCL

namespace {

Expand Down Expand Up @@ -170,4 +173,39 @@ TF_CALL_double(REGISTER);

#endif // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
// REGISTER(DEV, TYPE): registers the three stateless random kernels
// ("StatelessRandomUniform", "StatelessRandomNormal",
// "StatelessTruncatedNormal") for element type TYPE.
//
// DEV is token-pasted twice: DEVICE_##DEV selects the device constant passed
// to .Device(), and DEV##Device names the device class given to
// StatelessRandomOp. Each registration marks both the "shape" and "seed"
// inputs with .HostMemory() and constrains "dtype" to TYPE.
//
// Comments are kept outside the macro body on purpose: every body line ends
// in a line continuation, so a // comment inside it would swallow the rest of
// the definition.
#define REGISTER(DEV, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniform") \
.Device(DEVICE_##DEV) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<DEV##Device, random::UniformDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomNormal") \
.Device(DEVICE_##DEV) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<DEV##Device, random::NormalDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("StatelessTruncatedNormal") \
.Device(DEVICE_##DEV) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp< \
DEV##Device, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);

// One-argument adapter that fixes DEV to SYCL so REGISTER can be driven by a
// TF_CALL_* type-list macro.
// NOTE(review): presumably TF_CALL_GPU_NUMBER_TYPES_NO_HALF expands to the
// non-half floating-point types (float, double); confirm where it is defined.
#define REGISTER_SYCL_KERNELS(TYPE) REGISTER(SYCL, TYPE)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS)
// Both helper macros are undefined here so they are not visible past this
// section.
#undef REGISTER_SYCL_KERNELS
#undef REGISTER
#endif // TENSORFLOW_USE_SYCL

} // namespace tensorflow

0 comments on commit 71f5d40

Please sign in to comment.