Skip to content

Commit

Permalink
GPU impl. of hue adjustment op (tensorflow#6818)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkolod authored and ggfan committed Mar 14, 2017
1 parent 3549499 commit 10ab945
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 22 deletions.
17 changes: 14 additions & 3 deletions tensorflow/contrib/android/cmake/CMakeLists.txt
Expand Up @@ -32,18 +32,29 @@ set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
# Change to compile flags should be replicated into bazel build file
# LINT.IfChange
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -fno-rtti -fno-exceptions \
-fpic -O2 -mfpu=neon -DTF_LEAN_BINARY -msse4.1 \
-O2 -mfpu=neon -mfloat-abi=softfp -fPIE \
-DGOOGLE_PROTOBUF_NO_RTTI \
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
# LINT.ThenChange(//tensorflow/tensorflow.bzl)

set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
-Wl,--allow-multiple-definition \
-Wl,--whole-archive")
-Wl,--whole-archive \
-fPIE -pie -v")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD -DSTANDALONE_DEMO_LIB -Wno-narrowing")

file(GLOB tensorflow_inference_sources
${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
add_library(tensorflow_inference SHARED ${tensorflow_inference_sources})
file(GLOB java_api_native_sources
${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc)

add_library(tensorflow_inference SHARED
${tensorflow_inference_sources}
${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc
${java_api_native_sources})

# Include libraries needed for hello-jni lib
target_link_libraries(tensorflow_inference
Expand Down
17 changes: 12 additions & 5 deletions tensorflow/contrib/android/cmake/build.gradle
@@ -1,16 +1,19 @@
apply plugin: 'com.android.library'

// TensorFlow repo root dir on local machine
def TF_SRC_DIR = projectDir.toString() + "/../../../.."

android {
compileSdkVersion 24
buildToolsVersion "24.0.2"
buildToolsVersion '25.0.0'

// for debugging native code purpose
publishNonDefault true

defaultConfig {
archivesBaseName = "Tensorflow-Android-Inference"
minSdkVersion 21
targetSdkVersion 21
minSdkVersion 23
targetSdkVersion 23
versionCode 1
versionName "1.0"
ndk {
Expand All @@ -25,7 +28,11 @@ android {
}
sourceSets {
main {
java.srcDirs = ["../java"]
java {
srcDir "${TF_SRC_DIR}/tensorflow/contrib/android/java"
srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java"
exclude '**/examples/**'
}
}
}

Expand Down Expand Up @@ -86,7 +93,7 @@ if(! Os.isFamily(Os.FAMILY_WINDOWS)) {
// just uncomment this line to use it:
// it can take long time to build by default
// it is disabled to avoid false first impression
// task.dependsOn buildTensorflow
task.dependsOn buildTensorflow
}
}
}
Expand Down
43 changes: 41 additions & 2 deletions tensorflow/core/kernels/adjust_hue_op.cc
@@ -1,5 +1,4 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Expand All @@ -12,16 +11,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/adjust_hue_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

Expand Down Expand Up @@ -77,6 +84,7 @@ template <class Device>
class AdjustHueOp;

namespace internal {

// Helper function to convert a RGB color to H-and-V-range. H is in the range
// of [0, 6] instead of the normal [0, 1]
static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min,
Expand Down Expand Up @@ -185,6 +193,7 @@ static void hv_range_to_rgb(float h, float v_min, float v_max, float* r,
}
} // namespace internal


template <>
class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
public:
Expand Down Expand Up @@ -237,4 +246,34 @@ class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU),
AdjustHueOp<CPUDevice>);

#if GOOGLE_CUDA
template <>
class AdjustHueOp<GPUDevice> : public AdjustHueOpBase {
public:
explicit AdjustHueOp(OpKernelConstruction* context)
: AdjustHueOpBase(context) {}

virtual void DoCompute(OpKernelContext* context, const ComputeOptions& options) override {
const Tensor* input = options.input;
const Tensor* delta = options.delta;
Tensor* output = options.output;
const int64 number_of_elements = input->NumElements();
GPUDevice device = context->eigen_gpu_device();
const auto stream = device.stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
if (number_of_elements > 0) {
const float* input_data = input->flat<float>().data();
const float* delta_h = delta->flat<float>().data();
float* const output_data = output->flat<float>().data();
functor::AdjustHueGPU()(&device, number_of_elements, input_data, delta_h,
output_data);
}
}
};

REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_GPU), AdjustHueOp<GPUDevice>);

#endif

//} // namespace functor
} // namespace tensorflow
42 changes: 42 additions & 0 deletions tensorflow/core/kernels/adjust_hue_op.h
@@ -0,0 +1,42 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
#define _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

struct AdjustHueGPU {
void operator()(
GPUDevice* device,
const int64 number_of_elements,
const float* const input,
const float* const delta,
float* const output
);
};

} // namespace functor
} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
141 changes: 141 additions & 0 deletions tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc
@@ -0,0 +1,141 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/


#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/adjust_hue_op.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace internal {

namespace {
typedef struct RgbTuple {
float r;
float g;
float b;
} RgbTuple;

typedef struct HsvTuple {
float h;
float s;
float v;
} HsvTuple;
} // anon namespace

__device__ HsvTuple rgb2hsv_cuda(const float r, const float g, const float b)
{
HsvTuple tuple;
const float M = fmaxf(r, fmaxf(g, b));
const float m = fminf(r, fminf(g, b));
const float chroma = M - m;
float h = 0.0f, s = 0.0f;
// hue
if (chroma > 0.0f) {
if (M == r) {
const float num = (g - b) / chroma;
const float sign = copysignf(1.0f, num);
h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f;
} else if (M == g) {
h = ((b - r) / chroma + 2.0f) / 6.0f;
} else {
h = ((r - g) / chroma + 4.0f) / 6.0f;
}
} else {
h = 0.0f;
}
// saturation
if (M > 0.0) {
s = chroma / M;
} else {
s = 0.0f;
}
tuple.h = h;
tuple.s = s;
tuple.v = M;
return tuple;
}

__device__ RgbTuple hsv2rgb_cuda(const float h, const float s, const float v)
{
RgbTuple tuple;
const float new_h = h * 6.0f;
const float chroma = v * s;
const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f));
const float new_m = v - chroma;
const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f;
const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f;
const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f;
const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f;
const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f;
const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f;
tuple.r = chroma * (between_0_and_1 || between_5_and_6) +
x * (between_1_and_2 || between_4_and_5) + new_m;
tuple.g = chroma * (between_1_and_2 || between_2_and_3) +
x * (between_0_and_1 || between_3_and_4) + new_m;
tuple.b = chroma * (between_3_and_4 || between_4_and_5) +
x * (between_2_and_3 || between_5_and_6) + new_m;
return tuple;
}

__global__ void adjust_hue_nhwc(const int64 number_elements,
const float * const __restrict__ input,
float * const output,
const float * const hue_delta)
{
// multiply by 3 since we're dealing with contiguous RGB bytes for each pixel (NHWC)
const int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3;
// bounds check
if (idx > number_elements - 1) {
return;
}
const float delta = hue_delta[0];
const HsvTuple hsv = rgb2hsv_cuda(input[idx], input[idx + 1], input[idx + 2]);
// hue adjustment
float new_h = fmodf(hsv.h + delta, 1.0f);
if (new_h < 0.0f) {
new_h = fmodf(1.0f + new_h, 1.0f);
}
const RgbTuple rgb = hsv2rgb_cuda(new_h, hsv.s, hsv.v);
output[idx] = rgb.r;
output[idx + 1] = rgb.g;
output[idx + 2] = rgb.b;
}
} // namespace internal


namespace functor {

void AdjustHueGPU::operator()(
GPUDevice* device,
const int64 number_of_elements,
const float* const input,
const float* const delta,
float* const output
) {
const auto stream = device->stream();
const CudaLaunchConfig config = GetCudaLaunchConfig(number_of_elements, *device);
const int threads_per_block = config.thread_per_block;
const int block_count = (number_of_elements + threads_per_block - 1) / threads_per_block;
internal::adjust_hue_nhwc<<<block_count, threads_per_block, 0, stream>>>(
number_of_elements, input, output, delta
);
}
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA
2 changes: 1 addition & 1 deletion tensorflow/core/platform/default/build_config.bzl
Expand Up @@ -7,7 +7,7 @@ load("//tensorflow:tensorflow.bzl", "if_not_mobile")
# configure may change the following lines
WITH_GCP_SUPPORT = False
WITH_HDFS_SUPPORT = False
WITH_JEMALLOC = True
WITH_JEMALLOC = False

# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
Expand Down

0 comments on commit 10ab945

Please sign in to comment.