Merge pull request tensorflow#526 from ROCmSoftwarePlatform/disable_qint8_pooling

Disable qint8 forward pooling on ROCm properly.
whchung committed Jul 1, 2019
2 parents 970a686 + bce502e commit 3784582
Showing 2 changed files with 10 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -101,6 +101,7 @@ __global__ void MaxPoolForwardNCHW(
   }
 }
 
+#if GOOGLE_CUDA
 // The parameters for MaxPoolForwardNoMaskKernel_NCHW_VECT_C are the same as for
 // MaxPoolForwardNCHW above, except that mask is not supported, and each
 // element of the input and output contains 4 adjacent channel values for
@@ -130,18 +131,13 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int idx = (c * height + h) * width + w;
-#if GOOGLE_CUDA
         maxval = __vmaxs4(maxval, bottom_data_n[idx]);
-#elif TENSORFLOW_USE_ROCM
-        // ROCM TODO properly implement this function with corresponding GCN
-        // instruction
-        maxval = maxval;
-#endif
       }
     }
     top_data[index] = maxval;
   }
 }
+#endif  // GOOGLE_CUDA
 
 template <bool propagate_nans, typename dtype>
 __global__ void MaxPoolForwardNHWC(
@@ -387,6 +383,7 @@ __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
 
 namespace functor {
 
+#if GOOGLE_CUDA
 // Note: channels is the outer channels (dim 1) which has already been
 // divided by 4.
 bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()(
@@ -406,6 +403,7 @@ bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()(
       pad_t, pad_l, top_data));
   return d.ok();
 }
+#endif  // GOOGLE_CUDA
 
 template <typename T>
 bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
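For reference: __vmaxs4 is a CUDA SIMD-within-a-word intrinsic that treats each 32-bit operand as four packed signed 8-bit lanes and returns the lane-wise maximum; the deleted ROCm branch (maxval = maxval;) left maxval unchanged, so the ROCm path never computed a real maximum. Guarding the kernel with GOOGLE_CUDA removes that dead path. Below is a minimal sketch, not part of this commit, of the lane-wise behavior a future ROCm port would need to reproduce; the helper name vmaxs4_emulated is hypothetical.

#include <cstdint>

// Lane-wise signed max over the four int8 values packed into each 32-bit
// word; mirrors what CUDA's __vmaxs4(a, b) computes.
inline uint32_t vmaxs4_emulated(uint32_t a, uint32_t b) {
  uint32_t result = 0;
  for (int lane = 0; lane < 4; ++lane) {
    // Extract byte `lane` from each operand and sign-extend it.
    int8_t va = static_cast<int8_t>((a >> (8 * lane)) & 0xFF);
    int8_t vb = static_cast<int8_t>((b >> (8 * lane)) & 0xFF);
    // Pack the per-lane max back into the same byte position.
    result |= static_cast<uint32_t>(static_cast<uint8_t>(va > vb ? va : vb))
              << (8 * lane);
  }
  return result;
}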
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/pooling_ops_common.h
@@ -274,6 +274,7 @@ template <>
 struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
   static void launch(OpKernelContext* context, const PoolParameters& params,
                      const Tensor& input, Tensor* output) {
+#if GOOGLE_CUDA
     bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()(
         reinterpret_cast<const int32*>(input.flat<qint8>().data()),
         params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
@@ -286,6 +287,11 @@ struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
       context->SetStatus(errors::Internal(
           "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C"));
     }
+#else
+    // ROCm TODO: add support __vmaxs4 on ROCm
+    context->SetStatus(errors::Internal(
+        "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C"));
+#endif
   }
 };
 #endif
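The reinterpret_cast<const int32*> in the launcher works because the NCHW_VECT_C layout stores four adjacent qint8 channel values in each 32-bit word; with the GOOGLE_CUDA guard in place, ROCm now fails fast with an Internal error instead of launching a kernel that cannot compute the maximum. A minimal sketch of that packing, not from the diff; the helper names are hypothetical, and the byte order assumes the usual little-endian GPU layout.

#include <cstdint>
#include <cstring>

// Four adjacent channel values (c .. c+3) share one 32-bit word in
// NCHW_VECT_C, which is why the launcher can view qint8 data as int32.
inline int32_t PackQint8x4(const int8_t channels[4]) {
  int32_t word;
  std::memcpy(&word, channels, sizeof(word));  // byte i <- channel c + i (assumed)
  return word;
}

inline void UnpackQint8x4(int32_t word, int8_t channels[4]) {
  std::memcpy(channels, &word, sizeof(word));
}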
