Skip to content

Commit

Permalink
Merge pull request tensorflow#57 from lissyx/ccpp-new
Browse files Browse the repository at this point in the history
Ccpp new
  • Loading branch information
lissyx committed Feb 19, 2018
2 parents 23d3d54 + e79f7cd commit a3fedf3
Show file tree
Hide file tree
Showing 252 changed files with 10,673 additions and 1,468 deletions.
2 changes: 1 addition & 1 deletion .taskcluster.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ tasks:
- pull_request.reopened
- push
branches:
- master
- ccpp

scopes: [
"queue:create-task:lowest:{{ taskcluster.docker.provisionerId }}/deepspeech-worker",
Expand Down
2 changes: 2 additions & 0 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,8 @@ def main():
set_computecpp_toolkit_path(environ_cp)
else:
set_trisycl_include_dir(environ_cp)
set_action_env_var(environ_cp, 'TF_USE_DOUBLE_SYCL', 'double types in SYCL', True)
set_action_env_var(environ_cp, 'TF_USE_HALF_SYCL', 'half types in SYCL', False)

set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
if (environ_cp.get('TF_NEED_CUDA') == '1' and
Expand Down
14 changes: 14 additions & 0 deletions eigen_sycl_intel.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h 2018-02-19 14:48:17.252343698 +0100
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h 2018-02-17 17:38:00.345594829 +0100
@@ -94,9 +94,9 @@
std::transform(vendor.begin(), vendor.end(), vendor.begin(), ::tolower);
bool unsuported_condition = (device.is_cpu() && platform_name.find("amd")!=std::string::npos && vendor.find("apu") == std::string::npos) ||
(device.is_gpu() && platform_name.find("intel")!=std::string::npos);
- if(!unsuported_condition){
+ // if(!unsuported_condition){
supported_devices.push_back(device);
- }
+ // }
}
}
return supported_devices;
2 changes: 1 addition & 1 deletion taskcluster/github-events.cyml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ taskcluster:
events:
- push
branches:
- master
- ccpp
13 changes: 13 additions & 0 deletions taskcluster/linux-amd64-ccpp-opt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
build:
template_file: linux-opt-base.tyml
routes:
- "index.project.deepspeech.tensorflow.pip.${event.head.branch}.ccpp"
- "index.project.deepspeech.tensorflow.pip.${event.head.branch}.${event.head.sha}.ccpp"
- "index.project.deepspeech.tensorflow.pip.ccpp.${event.head.sha}"
maxRunTime: 14400
args:
tcsetup: "--ccpp"
tcbuild: "--ccpp"
metadata:
name: "TensorFlow Linux AMD64 CCPP"
description: "Building TensorFlow for Linux/AMD64, CCPP-enabled, optimized version"
2 changes: 1 addition & 1 deletion taskcluster/linux-amd64-gpu-opt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build:
maxRunTime: 14400
args:
tcsetup: "--cuda"
tcbuild: "--gpu"
tcbuild: "--cuda"
metadata:
name: "TensorFlow Linux AMD64 CUDA"
description: "Building TensorFlow for Linux/AMD64, CUDA-enabled, optimized version"
2 changes: 1 addition & 1 deletion taskcluster/linux-opt-base.tyml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ payload:
- "-cxe"
- >
apt-get -qq update && apt-get -qq -y install git &&
apt-get -qq -y install make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl llvm libncurses5-dev libncursesw5-dev xz-utils tk-dev &&
apt-get -qq -y install make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl llvm libncurses5-dev libncursesw5-dev xz-utils tk-dev ocl-icd-libopencl1 ocl-icd-dev opencl-headers &&
adduser --system --home /home/build-user build-user &&
cd /home/build-user/ &&
echo -e "#!/bin/bash\nset -xe\nenv && id && mkdir ~/DeepSpeech/ && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/tf/ && cd ~/DeepSpeech/tf && git checkout --quiet ${event.head.sha}" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
Expand Down
31 changes: 21 additions & 10 deletions tc-build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,47 @@ set -ex

source $(dirname $0)/tc-vars.sh

build_gpu=no
build_cuda=no
build_ccpp=no
build_arm=no

if [ "$1" = "--gpu" ]; then
build_gpu=yes
if [ "$1" = "--cuda" ]; then
build_cuda=yes
fi

if [ "$1" = "--ccpp" ]; then
build_ccpp=yes
fi

if [ "$1" = "--arm" ]; then
build_gpu=no
build_cuda=no
build_ccpp=no
build_arm=yes
fi

pushd ${DS_ROOT_TASK}/DeepSpeech/tf/
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s --explain bazel_monolithic_tf.log --verbose_explanations --experimental_strict_action_env --config=monolithic"

# Pure amd64 CPU-only build
if [ "${build_gpu}" = "no" -a "${build_arm}" = "no" ]; then
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} -c opt ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_GRAPH_TRANSFORMS} ${BUILD_TARGET_GRAPH_SUMMARIZE} ${BUILD_TARGET_GRAPH_BENCHMARK} ${BUILD_TARGET_CONVERT_MMAP} ${BUILD_TARGET_AOT_DEPS}
if [ "${build_cuda}" = "no" -a "${build_ccpp}" = "no" -a "${build_arm}" = "no" ]; then
echo "" | ./configure && ${BAZEL_BUILD} -c opt ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_GRAPH_TRANSFORMS} ${BUILD_TARGET_GRAPH_SUMMARIZE} ${BUILD_TARGET_GRAPH_BENCHMARK} ${BUILD_TARGET_CONVERT_MMAP} ${BUILD_TARGET_AOT_DEPS}
fi

# Cross RPi3 CPU-only build
if [ "${build_gpu}" = "no" -a "${build_arm}" = "yes" ]; then
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} -c opt ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_GRAPH_TRANSFORMS} ${BUILD_TARGET_GRAPH_SUMMARIZE} ${BUILD_TARGET_GRAPH_BENCHMARK} ${BUILD_TARGET_AOT_DEPS}
if [ "${build_cuda}" = "no" -a "${build_ccpp}" = "no" -a "${build_arm}" = "yes" ]; then
echo "" | ./configure && ${BAZEL_BUILD} -c opt ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_GRAPH_TRANSFORMS} ${BUILD_TARGET_GRAPH_SUMMARIZE} ${BUILD_TARGET_GRAPH_BENCHMARK} ${BUILD_TARGET_AOT_DEPS}
fi

# Pure amd64 GPU-enabled build
if [ "${build_gpu}" = "yes" -a "${build_arm}" = "no" ]; then
# Pure amd64 CUDA-enabled build
if [ "${build_cuda}" = "yes" -a "${build_ccpp}" = "no" -a "${build_arm}" = "no" ]; then
eval "export ${TF_CUDA_FLAGS}" && (echo "" | TF_NEED_CUDA=1 ./configure) && ${BAZEL_BUILD} -c opt ${BAZEL_CUDA_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_GRAPH_TRANSFORMS} ${BUILD_TARGET_GRAPH_SUMMARIZE} ${BUILD_TARGET_GRAPH_BENCHMARK}
fi

# Pure amd64 CCPP-enabled build
if [ "${build_cuda}" = "no" -a "${build_ccpp}" = "yes" -a "${build_arm}" = "no" ]; then
eval "export ${TF_CCPP_FLAGS}" && (echo "" | TF_NEED_OPENCL_SYCL=1 ./configure) && ${BAZEL_BUILD} -c opt ${BAZEL_CCPP_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_GRAPH_TRANSFORMS} ${BUILD_TARGET_GRAPH_SUMMARIZE} ${BUILD_TARGET_GRAPH_BENCHMARK}
fi

if [ $? -ne 0 ]; then
# There was a failure, just account for it.
echo "Build failure, please check the output above. Exit code was: $?"
Expand Down
20 changes: 17 additions & 3 deletions tc-setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ if [ "$1" = "--cuda" ]; then
install_cuda=yes
fi

install_ccpp=
if [ "$1" = "--ccpp" ]; then
install_ccpp=yes
fi;

# $1 url
# $2 sha256
download()
Expand All @@ -35,6 +40,10 @@ if [ ! -z "${install_cuda}" ]; then
download $CUDNN_URL $CUDNN_SHA256
fi;

if [ ! -z "${install_ccpp}" ]; then
download $COMPUTECPP_URL $COMPUTECPP_SHA256
fi;

# For debug
ls -hal ${DS_ROOT_TASK}/dls/

Expand Down Expand Up @@ -72,9 +81,14 @@ if [ ! -z "${install_cuda}" ]; then
if [ ! -h "${DS_ROOT_TASK}/DeepSpeech/CUDA/lib64/stubs/libcuda.so.1" ]; then
ln -s "${DS_ROOT_TASK}/DeepSpeech/CUDA/lib64/stubs/libcuda.so" "${DS_ROOT_TASK}/DeepSpeech/CUDA/lib64/stubs/libcuda.so.1"
fi;
fi;

else
echo "No CUDA/CuDNN to install"
fi
if [ ! -z "${install_ccpp}" ]; then
mkdir -p ${DS_ROOT_TASK}/DeepSpeech/ComputeCpp-CE/ || true
pushd ${DS_ROOT_TASK}
COMPUTECPP_FILE=`basename ${COMPUTECPP_URL}`
tar xvf ${DS_ROOT_TASK}/dls/${COMPUTECPP_FILE} --strip-components=1 -C ${DS_ROOT_TASK}/DeepSpeech/ComputeCpp-CE/
popd
fi;

mkdir -p ${TASKCLUSTER_ARTIFACTS} || true
6 changes: 6 additions & 0 deletions tc-vars.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ if [ "${OS}" = "Linux" ]; then
CUDNN_URL=http://developer.download.nvidia.com/compute/redist/cudnn/v7.0.5/cudnn-9.0-linux-x64-v7.tgz
CUDNN_SHA256=1a3e076447d5b9860c73d9bebe7087ffcb7b0c8814fd1e506096435a2ad9ab0e

COMPUTECPP_URL=https://computecpp.codeplay.com/downloads/computecpp-ce/0.5.1/ubuntu-14.04-64bit.tar.gz
COMPUTECPP_SHA256=57bd757a878f0ce81557e1dad62e7dfa88eaf48b111f692d31f5e5861ca4c0a2

elif [ "${OS}" = "Darwin" ]; then
if [ -z "${TASKCLUSTER_TASK_DIR}" -o -z "${TASKCLUSTER_ARTIFACTS}" ]; then
echo "Inconsistent OSX setup: missing some vars."
Expand Down Expand Up @@ -55,6 +58,7 @@ fi;
export TF_NEED_JEMALLOC
export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
export TF_NEED_OPENCL_SYCL=0
export TF_NEED_MKL=0
export TF_NEED_VERBS=0
Expand Down Expand Up @@ -96,8 +100,10 @@ fi;

### Define build parameters/env variables that we will re-ues in sourcing scripts.
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=9.0 TF_CUDNN_VERSION=7 CUDA_TOOLKIT_PATH=${DS_ROOT_TASK}/DeepSpeech/CUDA CUDNN_INSTALL_PATH=${DS_ROOT_TASK}/DeepSpeech/CUDA TF_CUDA_COMPUTE_CAPABILITIES=\"3.0,3.5,3.7,5.2,6.0,6.1\""
TF_CCPP_FLAGS="TF_NEED_COMPUTECPP=1 COMPUTECPP_TOOLKIT_PATH=${DS_ROOT_TASK}/DeepSpeech/ComputeCpp-CE/"
BAZEL_ARM_FLAGS="--config=rpi3"
BAZEL_CUDA_FLAGS="--config=cuda"
BAZEL_CCPP_FLAGS="--config=sycl --copt=-DCTC_DISABLE_OMP"
BAZEL_EXTRA_FLAGS="--copt=-fvisibility=hidden"

### Define build targets that we will re-ues in sourcing scripts.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@


def data_format():
return "channels_first" if tf.test.is_gpu_available() else "channels_last"
is_gpu = tf.test.is_gpu_available(cuda_only=True)
return "channels_first" if is_gpu else "channels_last"


class MNISTGraphTest(tf.test.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1537,7 +1537,8 @@ def test_there_is_no_xpu(self):

def test_whether_there_is_a_gpu(self):
if test.is_gpu_available():
self.assertTrue(len(replicate_model_fn._get_local_devices('GPU')))
gpu_type = test_util.gpu_device_type()
self.assertTrue(len(replicate_model_fn._get_local_devices(gpu_type)))


class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
Expand Down
12 changes: 6 additions & 6 deletions tensorflow/contrib/gdr/gdr_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
} else {
// Non-DMA cases.
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA
const DeviceContext* send_dev_context = send_args.device_context;
#if GOOGLE_CUDA || TENSORFLOW_USE_SYCL
DeviceContext* send_dev_context = send_args.device_context;
AllocatorAttributes alloc_attrs;
alloc_attrs.set_gpu_compatible(true);
alloc_attrs.set_on_host(true);
Expand All @@ -127,17 +127,17 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
CHECK(send_dev_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on a GPU. Uses GPUUtil to fill the response proto.
// "val" is on a GPU/SYCL. Uses CopyDeviceTensorToCPU to fill the
// response proto.
StatusCallback copy_ready = [response, done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
done(s);
delete copy;
};

GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
copy_ready);
send_dev_context->CopyDeviceTensorToCPU(&val,
"IGNORE_MY_TENSOR_NAME", src_dev , copy, copy_ready);
#else
done(errors::Internal("No GPU device in process"));
#endif // GOOGLE_CUDA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def testLinearlySeparableBinaryDataNoKernels(self):
# Since the data is linearly separable, the classifier should have small
# loss and perfect accuracy.
self.assertLess(metrics['loss'], 0.1)
self.assertEqual(metrics['accuracy'], 1.0)
self.assertAllClose(metrics['accuracy'], 1.0)

# As a result, it should assign higher probability to class 1 for the 1st
# and 3rd example and higher probability to class 0 for the second example.
Expand Down
30 changes: 15 additions & 15 deletions tensorflow/contrib/metrics/python/ops/metric_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,8 @@ def testAllCorrect(self):

with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
self.assertAlmostEqual(1.0, sess.run(update_op), 6)
self.assertAlmostEqual(1.0, precision.eval(), 6)

def testSomeCorrect(self):
predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
Expand Down Expand Up @@ -1249,7 +1249,7 @@ def testAllCorrect(self):
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
self.assertAlmostEqual(1, recall.eval(), 6)

def testSomeCorrect(self):
predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def testAllIncorrect(self):
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fpr.eval())
self.assertAlmostEqual(1.0, fpr.eval(), 6)

def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self):
predictions = array_ops.ones((1, 4))
Expand Down Expand Up @@ -1547,7 +1547,7 @@ def testAllIncorrect(self):
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fnr.eval())
self.assertAlmostEqual(1.0, fnr.eval(), 6)

def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self):
predictions = array_ops.zeros((1, 4))
Expand Down Expand Up @@ -2356,8 +2356,8 @@ def testAllCorrect(self):

with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
self.assertAlmostEqual(1.0, sess.run(update_op), 6)
self.assertAlmostEqual(1.0, specificity.eval(), 6)

def testSomeCorrectHighSensitivity(self):
predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9]
Expand Down Expand Up @@ -2492,8 +2492,8 @@ def testAllCorrect(self):

with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
self.assertAlmostEqual(1.0, sess.run(update_op), 6)
self.assertAlmostEqual(1.0, specificity.eval(), 6)

def testSomeCorrectHighSpecificity(self):
predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
Expand Down Expand Up @@ -3128,8 +3128,8 @@ def testAllCorrect(self):

with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, recall.eval())
self.assertAlmostEqual(1.0, sess.run(update_op), 6)
self.assertAlmostEqual(1.0, recall.eval(), 6)

def testSomeCorrectHighPrecision(self):
predictions_values = [1, .9, .8, .7, .6, .5, .4, .3]
Expand Down Expand Up @@ -5118,8 +5118,8 @@ def testSingleUpdateWithError(self):

with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
self.assertAlmostEqual(6.0, sess.run(update_op), 6)
self.assertAlmostEqual(6.0, error.eval(), 6)

def testSingleUpdateWithErrorAndWeights(self):
predictions = constant_op.constant(
Expand Down Expand Up @@ -5790,8 +5790,8 @@ def testSingleUpdateZeroError(self):

with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
self.assertAlmostEqual(0.0, sess.run(update_op), 6)
self.assertAlmostEqual(0.0, error.eval(), 6)

def testSingleUpdateWithError1(self):
np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/contrib/rnn/kernels/gru_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL

template <typename Device, typename T, bool USE_CUBLAS>
class GRUCellBlockOp : public OpKernel {
Expand Down Expand Up @@ -167,6 +170,16 @@ class GRUCellBlockOp : public OpKernel {
REGISTER_KERNEL(float);
#undef REGISTER_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("GRUBlockCell").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
GRUCellBlockOp<SYCLDevice, T, false>);

REGISTER_KERNEL(float);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL

template <typename Device, typename T, bool USE_CUBLAS>
class GRUBlockCellGradOp : public OpKernel {
public:
Expand Down Expand Up @@ -379,6 +392,16 @@ class GRUBlockCellGradOp : public OpKernel {
REGISTER_KERNEL(float);
#undef REGISTER_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("GRUBlockCellGrad").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
GRUBlockCellGradOp<SYCLDevice, T, false>);

REGISTER_KERNEL(float);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL

// GPU support.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
Expand Down

0 comments on commit a3fedf3

Please sign in to comment.