[OpenCL] Registers ApplyRMSProp and ApplyCenteredRMSProp (tensorflow#61)
jeffmak authored and Luke Iwanski committed Jun 12, 2017
1 parent b58839f commit 909577d
Showing 1 changed file with 70 additions and 0 deletions.
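
For reference, the update rules that the new SYCL specializations implement (mirroring the existing CPU functors in this file; $g$ is the gradient and $\mu$ the momentum scalar) are:

$$
\begin{aligned}
ms &\leftarrow ms + (1-\rho)\,(g^2 - ms) \\
mom &\leftarrow \mu\, mom + \frac{lr \cdot g}{\sqrt{\epsilon + ms}} \\
var &\leftarrow var - mom
\end{aligned}
$$

The centered variant additionally tracks $mg \leftarrow mg + (1-\rho)(g - mg)$ and replaces the denominator with $\sqrt{ms - mg^2 + \epsilon}$.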
tensorflow/core/kernels/training_ops.cc (70 additions, 0 deletions)
@@ -357,6 +357,36 @@ struct ApplyRMSProp<CPUDevice, T> {
}
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct ApplyRMSProp<SYCLDevice, T> {
  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<int, 1> rank1{1};
#else
    Eigen::IndexList<Eigen::type2index<1> > rank1;
#endif
    const int size = grad.dimension(0);
    Eigen::array<int, 1> broadcast_dim{size};
    const auto one = static_cast<T>(1.0);
    // ms <- ms + (1 - rho) * (grad^2 - ms)
    ms.device(d) = ms +
        (rho.constant(one) - rho).reshape(rank1).broadcast(broadcast_dim) *
            (grad.square() - ms);
    // mom <- momentum * mom + lr * grad / sqrt(epsilon + ms)
    mom.device(d) =
        mom * momentum.reshape(rank1).broadcast(broadcast_dim) +
        lr.reshape(rank1).broadcast(broadcast_dim) * grad /
            ((epsilon.reshape(rank1).broadcast(broadcast_dim) + ms).sqrt());
    var.device(d) -= mom;
  }
};
#endif  // TENSORFLOW_USE_SYCL
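
The reshape-then-broadcast idiom above is how the rank-0 scalars (`lr`, `rho`, `momentum`, `epsilon`) are tiled to the flat gradient's length so they can enter element-wise expressions. A minimal standalone sketch of the same `Eigen::Tensor` pattern (illustrative only, run on the host; not part of this commit):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 0> lr;  // rank-0 "scalar" tensor, like TTypes<T>::ConstScalar
  lr() = 0.01f;

  const Eigen::Index size = 4;                        // length of the flat gradient
  Eigen::array<Eigen::Index, 1> rank1{1};             // reshape scalar -> length-1 vector
  Eigen::array<Eigen::Index, 1> broadcast_dim{size};  // tile it to `size` elements

  // Same expression shape as lr.reshape(rank1).broadcast(broadcast_dim) above.
  Eigen::Tensor<float, 1> tiled = lr.reshape(rank1).broadcast(broadcast_dim);
  for (Eigen::Index i = 0; i < size; ++i) std::cout << tiled(i) << ' ';
  std::cout << '\n';  // prints: 0.01 0.01 0.01 0.01
}
```

The `EIGEN_HAS_INDEX_LIST` branch prefers `Eigen::IndexList<Eigen::type2index<1>>` where available, since its compile-time-constant dimension spares Eigen the runtime index storage that `Eigen::array` needs.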

template <typename T>
struct ApplyCenteredRMSProp<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -375,6 +405,37 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {
}
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct ApplyCenteredRMSProp<SYCLDevice, T> {
  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<int, 1> rank1{1};
#else
    Eigen::IndexList<Eigen::type2index<1> > rank1;
#endif
    const int size = grad.dimension(0);
    Eigen::array<int, 1> broadcast_dim{size};
    const auto one = static_cast<T>(1.0);
    const auto one_minus_rho =
        (rho.constant(one) - rho).reshape(rank1).broadcast(broadcast_dim);
    // ms <- ms + (1 - rho) * (grad^2 - ms); mg <- mg + (1 - rho) * (grad - mg)
    ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
    mg.device(d) = mg + one_minus_rho * (grad - mg);
    // denom ~= Var(grad) + epsilon, since ms ~= E[g^2] and mg ~= E[g]
    auto denom = (ms - mg.square()) +
                 epsilon.reshape(rank1).broadcast(broadcast_dim);
    mom.device(d) = mom * momentum.reshape(rank1).broadcast(broadcast_dim) +
                    lr.reshape(rank1).broadcast(broadcast_dim) * grad /
                        denom.sqrt();
    var.device(d) -= mom;
  }
};
#endif  // TENSORFLOW_USE_SYCL
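
A one-line sanity check on the "centered" denominator: $ms$ is a running estimate of $\mathbb{E}[g^2]$ and $mg$ of $\mathbb{E}[g]$, so

$$
ms - mg^2 \;\approx\; \mathbb{E}[g^2] - \mathbb{E}[g]^2 \;=\; \operatorname{Var}(g),
$$

i.e. the centered update normalizes by an estimate of the gradient's standard deviation rather than its raw second moment.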

} // namespace functor


@@ -2950,6 +3011,15 @@ REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);

TF_CALL_float(REGISTER_SYCL_KERNELS);
TF_CALL_double(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#endif
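
For context (not part of the diff): in TensorFlow's `register_types.h`, `TF_CALL_float(m)` and `TF_CALL_double(m)` expand to `m(float)` and `m(double)`, so the block above is equivalent to:

```cpp
// Effective expansion of the SYCL registration block (illustrative):
REGISTER_KERNELS(SYCL, float);
REGISTER_KERNELS(SYCL, double);
```

Only `float` and `double` are registered for SYCL here; `Eigen::half`, which the GPU path above registers, is left out of the OpenCL build.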

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

