Skip to content

Commit

Permalink
Fix for //tensorflow/python:stateful_random_ops_test:
Browse files Browse the repository at this point in the history
Move the thread counter into the global namespace
  • Loading branch information
ekuznetsov139 committed Dec 24, 2019
1 parent f7b2819 commit eee5851
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc
Expand Up @@ -26,12 +26,13 @@ limitations under the License.
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

// ROCm hipMemcpyToSymbol can only see this variable if it's in global namespace
__device__ int tensorflow_philox_thread_counter;

namespace tensorflow {

using random::PhiloxRandom;

__device__ int thread_counter;

template <typename Distribution>
__global__ void FillKernel(
Distribution dist, int64 state_size, int64 output_size,
Expand All @@ -50,7 +51,7 @@ __global__ void FillKernel(
.Run(*philox, output_data, output_size, dist);
// The last thread updates the state.
auto total_thread_count = gridDim.x * blockDim.x;
auto old_counter_value = atomicAdd(&thread_counter, 1);
auto old_counter_value = atomicAdd(&tensorflow_philox_thread_counter, 1);
if (old_counter_value == total_thread_count - 1) {
UpdateMemWithPhiloxRandom(*philox, output_size, state_data);
}
Expand All @@ -64,6 +65,8 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
int64 output_size = arg->output_size;
int64 alg_tag_skip = arg->alg_tag_skip;
Tensor* state_tensor = arg->state_tensor;
OP_REQUIRES(ctx, state_tensor != 0,
errors::InvalidArgument("Null state tensor"));
OP_REQUIRES(
ctx, alg_tag_skip == 0,
errors::InvalidArgument(
Expand All @@ -80,9 +83,12 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
GetGpuLaunchConfig(work_element_count, d, FillKernel<Distribution>, 0, 0);
int zero = 0;
#if GOOGLE_CUDA
cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int));
cudaMemcpyToSymbol(tensorflow_philox_thread_counter, &zero, sizeof(int));
#else // TENSORFLOW_USE_ROCM
hipMemcpyToSymbol(HIP_SYMBOL(thread_counter), &zero, sizeof(int));
int status = hipMemcpyToSymbol(HIP_SYMBOL(tensorflow_philox_thread_counter),
&zero, sizeof(int));
OP_REQUIRES(ctx, status == hipSuccess,
errors::InvalidArgument("hipMemcpyToSymbol failed"));
#endif
TF_CHECK_OK(GpuLaunchKernel(
FillKernel<Distribution>, cfg.block_count, cfg.thread_per_block, 0,
Expand Down

0 comments on commit eee5851

Please sign in to comment.