diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index cc0d816204..cfb67e4a4c 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -84,6 +84,14 @@ Our default training scripts train using ImageNet. Because we cannot distribute
We look forward to delivering great pre-trained models in KerasCV with the help of your contributions!
+## Contributing custom ops
+
+We do not plan to accept contributed custom ops due to the maintenance burden that they introduce. If there is a clear need for a specific custom op that should live in KerasCV, please consult the KerasCV team before implementing it, as we expect to reject contributions of custom ops by default.
+
+We currently support only a small handful of ops that run on CPU and are not used at inference time.
+
+If you are updating existing custom ops, you can re-compile the binaries from source using the instructions in the `Tests that require custom ops` section below.
+
## Setup environment
Setting up your KerasCV development environment requires you to fork the KerasCV repository,
@@ -129,6 +137,16 @@ You can run the unit tests for KerasCV by running:
pytest keras_cv/
```
+### Tests that require custom ops
+For tests that require custom ops, you'll have to compile the custom ops and make them available to your local Python code:
+```shell
+python build_deps/configure.py
+bazel build keras_cv/custom_ops:all
+cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/
+```
+
+Tests which use custom ops are disabled by default, but can be run by setting the environment variable `TEST_CUSTOM_OPS=true`.
+
## Formatting the Code
We use `flake8`, `isort` and `black` for code formatting. You can run
the following commands manually every time you want to format your code:
diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 57afec8d26..02a0c20eb5 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -32,9 +32,16 @@ jobs:
${{ runner.os }}-pip-
- name: Install dependencies
run: |
- pip install tensorflow==2.10.0
+ pip install tensorflow-cpu==2.10.0
pip install -e ".[tests]" --progress-bar off --upgrade
+ - name: Build custom ops for tests
+ run: |
+ python build_deps/configure.py
+ bazel build keras_cv/custom_ops:all
+ cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/
- name: Test with pytest
+ env:
+ TEST_CUSTOM_OPS: true
run: |
pytest keras_cv/ --ignore keras_cv/models
format:
diff --git a/BUILD b/BUILD
new file mode 100644
index 0000000000..8e1e7999ed
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,11 @@
+# Runs build_deps/build_pip_pkg.sh with the package metadata files and the
+# //keras_cv target available as data dependencies.
+sh_binary(
+ name = "build_pip_pkg",
+ srcs = ["build_deps/build_pip_pkg.sh"],
+ data = [
+ "LICENSE",
+ "MANIFEST.in",
+ "README.md",
+ "setup.py",
+ "//keras_cv",
+ ],
+)
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000..0684ba4458
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include keras_cv/custom_ops *.so
diff --git a/README.md b/README.md
index 808f4a99a8..0b0d15580e 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,23 @@ Thank you to all of our wonderful contributors!
+
+## Installing Custom Ops from Source
+Installing from source requires the [Bazel](https://bazel.build/) build system
+(version >= 1.0.0).
+
+```
+git clone https://github.com/keras-team/keras-cv.git
+cd keras-cv
+
+python3 build_deps/configure.py
+
+bazel build build_pip_pkg
+bazel-bin/build_pip_pkg wheels
+
+pip install wheels/keras_cv-*.whl
+```
+
## Pretrained Weights
Many models in KerasCV come with pre-trained weights. With the exception of StableDiffusion,
all of these weights are trained using Keras and KerasCV components and training scripts in this
@@ -72,6 +89,7 @@ history for backbone models [here](examples/training/classification/imagenet/tra
All results are reproducible using the training scripts in this repository. Pre-trained weights
operate on images that have been rescaled using a simple `1/255` rescaling layer.
+
## Citing KerasCV
If KerasCV helps your research, we appreciate your citations.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000000..d8ea963a08
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,3 @@
+load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure")
+
+tf_configure(name = "local_config_tf")
diff --git a/build_deps/build_pip_pkg.sh b/build_deps/build_pip_pkg.sh
new file mode 100755
index 0000000000..34aebc5de9
--- /dev/null
+++ b/build_deps/build_pip_pkg.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+# Copyright 2022 The KerasCV Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Builds a wheel of KerasCV for Pip. Requires Bazel.
+# Adapted from https://github.com/tensorflow/addons/blob/master/build_deps/build_pip_pkg.sh
+
+set -e
+set -x
+
+PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
+# Succeeds (status 0) when PLATFORM matches one of the Windows POSIX layers
+# (cygwin, mingw32, mingw64, msys); fails (status 1) otherwise.
+function is_windows() {
+  [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]
+}
+
+if is_windows; then
+ PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.exe.runfiles/__main__/"
+else
+ PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.runfiles/__main__/"
+fi
+
+# Builds the KerasCV wheel and copies it into a destination directory.
+#
+# Arguments:
+#   make   Optional keyword: read the package files from the current
+#          directory instead of the Bazel runfiles tree.
+#   <dir>  Destination directory for the wheel (created if missing). If
+#          several directories are given, the last one wins.
+#
+# Exits with status 1 when no destination directory is provided.
+function main() {
+  while [[ -n "${1}" ]]; do
+    if [[ "${1}" == "make" ]]; then
+      echo "Using Makefile to build pip package."
+      PIP_FILE_PREFIX=""
+    else
+      DEST="${1}"
+    fi
+    shift
+  done
+
+  if [[ -z "${DEST}" ]]; then
+    echo "No destination dir provided"
+    exit 1
+  fi
+
+  # Create the directory first so that it can be resolved to an absolute,
+  # symlink-free path.
+  mkdir -p "${DEST}"
+  if [[ "${PLATFORM}" == "darwin" ]]; then
+    # cd + pwd -P yields the right answer for both relative and absolute
+    # destinations; prefixing $(pwd -P) would corrupt an absolute one.
+    DEST="$(cd "${DEST}" && pwd -P)"
+  else
+    DEST="$(readlink -f "${DEST}")"
+  fi
+  echo "=== destination directory: ${DEST}"
+
+  TMPDIR="$(mktemp -d -t tmp.XXXXXXXXXX)"
+
+  echo $(date) : "=== Using tmpdir: ${TMPDIR}"
+
+  echo "=== Copy KerasCV Custom op files"
+
+  # Quote every expansion so that paths containing spaces survive.
+  cp "${PIP_FILE_PREFIX}setup.py" "${TMPDIR}"
+  cp "${PIP_FILE_PREFIX}MANIFEST.in" "${TMPDIR}"
+  cp "${PIP_FILE_PREFIX}README.md" "${TMPDIR}"
+  cp "${PIP_FILE_PREFIX}LICENSE" "${TMPDIR}"
+  rsync -avm -L --exclude='*_test.py' "${PIP_FILE_PREFIX}keras_cv" "${TMPDIR}"
+
+  pushd "${TMPDIR}"
+  echo $(date) : "=== Building wheel"
+
+  python3 setup.py bdist_wheel > /dev/null
+
+  cp dist/*.whl "${DEST}"
+  popd
+  rm -rf "${TMPDIR}"
+  echo $(date) : "=== Output wheel file is in: ${DEST}"
+}
+
+main "$@"
diff --git a/build_deps/configure.py b/build_deps/configure.py
new file mode 100644
index 0000000000..35a0a0628e
--- /dev/null
+++ b/build_deps/configure.py
@@ -0,0 +1,174 @@
+# Copyright 2022 The KerasCV Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Usage: python build_deps/configure.py (writes .bazelrc to the current directory)
+"""Configures local environment to prepare for building KerasCV from source."""
+
+
+import logging
+import os
+import pathlib
+import platform
+
+import tensorflow as tf
+from packaging.version import Version
+
+_TFA_BAZELRC = ".bazelrc"
+
+
+# Writes variables to bazelrc file
+def write(line):
+ with open(_TFA_BAZELRC, "a") as f:
+ f.write(line + "\n")
+
+
+def write_action_env(var_name, var):
+ write('build --action_env {}="{}"'.format(var_name, var))
+
+
+def is_macos():
+ return platform.system() == "Darwin"
+
+
+def is_windows():
+ return platform.system() == "Windows"
+
+
+def is_linux():
+ return platform.system() == "Linux"
+
+
+def is_raspi_arm():
+ return os.uname()[4] == "armv7l" or os.uname()[4] == "aarch64"
+
+
+def is_linux_ppc64le():
+ return is_linux() and platform.machine() == "ppc64le"
+
+
+def is_linux_x86_64():
+ return is_linux() and platform.machine() == "x86_64"
+
+
+def is_linux_arm():
+ return is_linux() and platform.machine() == "arm"
+
+
+def is_linux_aarch64():
+ return is_linux() and platform.machine() == "aarch64"
+
+
+def is_linux_s390x():
+ return is_linux() and platform.machine() == "s390x"
+
+
+def get_tf_header_dir():
+ import tensorflow as tf
+
+ tf_header_dir = tf.sysconfig.get_compile_flags()[0][2:]
+ if is_windows():
+ tf_header_dir = tf_header_dir.replace("\\", "/")
+ return tf_header_dir
+
+
+def get_cpp_version():
+ cpp_version = "c++14"
+ if Version(tf.__version__) >= Version("2.10"):
+ cpp_version = "c++17"
+ return cpp_version
+
+
+def get_tf_shared_lib_dir():
+ import tensorflow as tf
+
+ # OS Specific parsing
+ if is_windows():
+ tf_shared_lib_dir = tf.sysconfig.get_compile_flags()[0][2:-7] + "python"
+ return tf_shared_lib_dir.replace("\\", "/")
+ elif is_raspi_arm():
+ return tf.sysconfig.get_compile_flags()[0][2:-7] + "python"
+ else:
+ return tf.sysconfig.get_link_flags()[0][2:]
+
+
+# Converts the linkflag namespec to the full shared library name
+def get_shared_lib_name():
+ import tensorflow as tf
+
+ namespec = tf.sysconfig.get_link_flags()
+ if is_macos():
+ # MacOS
+ return "lib" + namespec[1][2:] + ".dylib"
+ elif is_windows():
+ # Windows
+ return "_pywrap_tensorflow_internal.lib"
+ elif is_raspi_arm():
+ # The below command for linux would return an empty list
+ return "_pywrap_tensorflow_internal.so"
+ else:
+ # Linux
+ return namespec[1][3:]
+
+
+def create_build_configuration():
+ print()
+ print("Configuring KerasCV to be built from source...")
+
+ if os.path.isfile(_TFA_BAZELRC):
+ os.remove(_TFA_BAZELRC)
+
+ logging.disable(logging.WARNING)
+
+ write_action_env("TF_HEADER_DIR", get_tf_header_dir())
+ write_action_env("TF_SHARED_LIBRARY_DIR", get_tf_shared_lib_dir())
+ write_action_env("TF_SHARED_LIBRARY_NAME", get_shared_lib_name())
+ write_action_env("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG)
+
+ # This should be replaced with a call to tf.sysconfig if it's added
+ write_action_env("TF_CPLUSPLUS_VER", get_cpp_version())
+
+ write("build --spawn_strategy=standalone")
+ write("build --strategy=Genrule=standalone")
+ write("build --experimental_repo_remote_exec")
+ write("build -c opt")
+ write(
+ "build --cxxopt="
+ + '"-D_GLIBCXX_USE_CXX11_ABI="'
+ + str(tf.sysconfig.CXX11_ABI_FLAG)
+ )
+
+ if is_windows():
+ write("build --config=windows")
+ write("build:windows --enable_runfiles")
+ write("build:windows --copt=/experimental:preprocessor")
+ write("build:windows --host_copt=/experimental:preprocessor")
+ write("build:windows --copt=/arch=AVX")
+ write("build:windows --cxxopt=/std:" + get_cpp_version())
+ write("build:windows --host_cxxopt=/std:" + get_cpp_version())
+
+ if is_macos() or is_linux():
+ if not is_linux_ppc64le() and not is_linux_arm() and not is_linux_aarch64():
+ write("build --copt=-mavx")
+ write("build --cxxopt=-std=" + get_cpp_version())
+ write("build --host_cxxopt=-std=" + get_cpp_version())
+
+ print("> Building only CPU ops")
+
+ print()
+ print("Build configurations successfully written to", _TFA_BAZELRC, ":\n")
+ print(pathlib.Path(_TFA_BAZELRC).read_text())
+
+
+if __name__ == "__main__":
+ create_build_configuration()
diff --git a/build_deps/tf_dependency/BUILD b/build_deps/tf_dependency/BUILD
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/build_deps/tf_dependency/BUILD.tpl b/build_deps/tf_dependency/BUILD.tpl
new file mode 100644
index 0000000000..aae0d7ddb9
--- /dev/null
+++ b/build_deps/tf_dependency/BUILD.tpl
@@ -0,0 +1,18 @@
+package(default_visibility = ["//visibility:public"])
+
+# Header-only library over the :tf_header_include target, with "include" added
+# to the include path.
+cc_library(
+ name = "tf_header_lib",
+ hdrs = [":tf_header_include"],
+ includes = ["include"],
+ visibility = ["//visibility:public"],
+)
+
+
+# Wraps the TensorFlow shared library; %{TF_SHARED_LIBRARY_NAME} is replaced
+# when this template is instantiated.
+cc_library(
+ name = "libtensorflow_framework",
+ srcs = ["%{TF_SHARED_LIBRARY_NAME}"],
+ visibility = ["//visibility:public"],
+)
+
+%{TF_HEADER_GENRULE}
+%{TF_SHARED_LIBRARY_GENRULE}
diff --git a/build_deps/tf_dependency/build_defs.bzl.tpl b/build_deps/tf_dependency/build_defs.bzl.tpl
new file mode 100644
index 0000000000..f6573b67a4
--- /dev/null
+++ b/build_deps/tf_dependency/build_defs.bzl.tpl
@@ -0,0 +1,4 @@
+# KerasCV build definitions inherited from TensorFlow
+
+D_GLIBCXX_USE_CXX11_ABI = "%{tf_cx11_abi}"
+CPLUSPLUS_VERSION = "%{tf_cplusplus_ver}"
diff --git a/build_deps/tf_dependency/tf_configure.bzl b/build_deps/tf_dependency/tf_configure.bzl
new file mode 100644
index 0000000000..0c0b5e7064
--- /dev/null
+++ b/build_deps/tf_dependency/tf_configure.bzl
@@ -0,0 +1,244 @@
+"""Setup TensorFlow as external dependency"""
+
+_TF_HEADER_DIR = "TF_HEADER_DIR"
+
+_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"
+
+_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
+
+_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG"
+
+_TF_CPLUSPLUS_VER = "TF_CPLUSPLUS_VER"
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+    """Instantiates the template `<tpl>.tpl` from //build_deps/tf_dependency.
+
+    Args:
+      repository_ctx: the repository_ctx object.
+      tpl: template name without the ".tpl" suffix.
+      substitutions: dict passed to repository_ctx.template.
+      out: output path; `tpl` is used when this is empty or None.
+    """
+    template_label = Label("//build_deps/tf_dependency:%s.tpl" % tpl)
+    repository_ctx.template(out or tpl, template_label, substitutions)
+
+def _fail(msg):
+    """Aborts with `msg` behind a red "Python Configuration Error:" label."""
+    color_codes = ("\033[0;31m", "\033[0m")  # red, then reset
+    fail("%sPython Configuration Error:%s %s\n" % (color_codes[0], color_codes[1], msg))
+
+def _is_windows(repository_ctx):
+    """Returns True if the lower-cased host OS name contains "windows"."""
+    return "windows" in repository_ctx.os.name.lower()
+
+def _execute(
+        repository_ctx,
+        cmdline,
+        error_msg = None,
+        error_details = None,
+        empty_stdout_fine = False):
+    """Runs a command and fails the configuration when it reports a problem.
+
+    The command counts as failed when it writes anything to stderr, or when
+    it produces no stdout while `empty_stdout_fine` is False.
+
+    Args:
+      repository_ctx: the repository_ctx object.
+      cmdline: list of strings, the command to execute.
+      error_msg: string, a summary of the error if the command fails.
+      error_details: string, details about the error or steps to fix it.
+      empty_stdout_fine: bool, if True, an empty stdout result is fine,
+        otherwise it's an error.
+
+    Returns:
+      The result of repository_ctx.execute(cmdline).
+    """
+    result = repository_ctx.execute(cmdline)
+    stdout_acceptable = empty_stdout_fine or result.stdout
+    if not result.stderr and stdout_acceptable:
+        return result
+
+    summary = error_msg.strip() if error_msg else "Repository command failed"
+    details = error_details if error_details else ""
+    _fail("\n".join([summary, result.stderr.strip(), details]))
+    return result
+
+def _read_dir(repository_ctx, src_dir):
+    """Lists every file below a directory, one path per line.
+
+    On Windows the listing comes from `cmd.exe /c dir <src_dir> /b /s /a-d`
+    and backslashes in the output are converted to forward slashes.
+    Elsewhere it comes from `find <src_dir> -follow -type f`.
+
+    Args:
+      repository_ctx: the repository_ctx object.
+      src_dir: directory to find files from.
+
+    Returns:
+      The stdout of the listing command as a string.
+    """
+    if not _is_windows(repository_ctx):
+        found = _execute(
+            repository_ctx,
+            ["find", src_dir, "-follow", "-type", "f"],
+            empty_stdout_fine = True,
+        )
+        return found.stdout
+
+    windows_dir = src_dir.replace("/", "\\")
+    listed = _execute(
+        repository_ctx,
+        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
+        empty_stdout_fine = True,
+    )
+
+    # The paths end up in genrule.outs, which must use forward slashes.
+    return listed.stdout.replace("\\", "/")
+
+def _genrule(genrule_name, command, outs):
+    """Renders the text of a genrule target.
+
+    Args:
+      genrule_name: A unique name for genrule target.
+      command: The command to run, placed inside a triple-quoted `cmd`.
+      outs: The already formatted entries of the `outs` list, as one string.
+
+    Returns:
+      A string containing the genrule declaration.
+    """
+    pieces = [
+        "genrule(\n",
+        ' name = "',
+        genrule_name,
+        '",\n',
+        " outs = [\n",
+        outs,
+        "\n ],\n",
+        ' cmd = """\n',
+        command,
+        '\n """,\n',
+        ")\n",
+    ]
+    return "".join(pieces)
+
+def _norm_path(path):
+    """Returns `path` with backslashes turned into '/' and one trailing '/' removed.
+
+    An empty path is returned unchanged instead of failing on an
+    out-of-range index.
+    """
+    path = path.replace("\\", "/")
+    if path.endswith("/"):
+        path = path[:-1]
+    return path
+
+def _symlink_genrule_for_dir(
+        repository_ctx,
+        src_dir,
+        dest_dir,
+        genrule_name,
+        src_files = [],
+        dest_files = [],
+        tf_pip_dir_rename_pair = []):
+    """Returns a genrule that copies a set of files into the repository.
+
+    If src_dir is passed, files will be read from the given directory;
+    otherwise we assume files are in src_files and dest_files.
+
+    Args:
+      repository_ctx: the repository_ctx object.
+      src_dir: source directory.
+      dest_dir: directory to place the copies in.
+      genrule_name: genrule name.
+      src_files: list of source files instead of src_dir.
+      dest_files: list of corresponding destination files.
+      tf_pip_dir_rename_pair: list of the pair of tf pip parent directory to
+        replace. For example, in TF pip package, the source code is under
+        "tensorflow_core", and we might want to replace it with
+        "tensorflow" to match the header includes.
+
+    Returns:
+      genrule target that copies the files.
+    """
+
+    # The rename pair is either absent or exactly (old, new).
+    pair_len = len(tf_pip_dir_rename_pair)
+    if pair_len not in (0, 2):
+        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % pair_len)
+
+    if src_dir != None:
+        src_dir = _norm_path(src_dir)
+        dest_dir = _norm_path(dest_dir)
+        listing = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
+
+        # Strip src_dir (and apply the rename) to obtain the output paths.
+        relative = listing.replace(src_dir, "")
+        if pair_len:
+            relative = relative.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
+        dest_files = relative.splitlines()
+        src_files = listing.splitlines()
+
+    # With a single file, $(@D) already includes the full path to the file,
+    # so dest_dir must not be prepended to the copy target.
+    single_file = len(dest_files) == 1
+    commands = []
+    outputs = []
+    for index, dest_file in enumerate(dest_files):
+        if dest_file == "":
+            continue
+        if single_file:
+            target = "$(@D)/" + dest_file
+        else:
+            target = "$(@D)/" + dest_dir + dest_file
+
+        # Copy the headers to create a sandboxable setup.
+        commands.append('cp -f "%s" "%s"' % (src_files[index], target))
+        outputs.append(' "' + dest_dir + dest_file + '",')
+
+    return _genrule(genrule_name, ";\n".join(commands), "\n".join(outputs))
+
+def _tf_pip_impl(repository_ctx):
+    """Implements tf_configure: generates BUILD and build_defs.bzl.
+
+    Reads the TF_* environment variables, builds one genrule that copies the
+    TensorFlow headers and one that copies the shared library, and fills the
+    "BUILD" and "build_defs.bzl" templates with the results.
+    """
+    env = repository_ctx.os.environ
+
+    header_genrule = _symlink_genrule_for_dir(
+        repository_ctx,
+        env[_TF_HEADER_DIR],
+        "include",
+        "tf_header_include",
+        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
+    )
+
+    library_dir = env[_TF_SHARED_LIBRARY_DIR]
+    library_name = env[_TF_SHARED_LIBRARY_NAME]
+    library_path = "%s/%s" % (library_dir, library_name)
+    abi_define = "-D_GLIBCXX_USE_CXX11_ABI=%s" % (env[_TF_CXX11_ABI_FLAG])
+    std_flag = "-std=%s" % env[_TF_CPLUSPLUS_VER]
+
+    # No src_dir: the single source and destination file are given explicitly.
+    library_genrule = _symlink_genrule_for_dir(
+        repository_ctx,
+        None,
+        "",
+        library_name,
+        [library_path],
+        [library_name],
+    )
+
+    build_substitutions = {
+        "%{TF_HEADER_GENRULE}": header_genrule,
+        "%{TF_SHARED_LIBRARY_GENRULE}": library_genrule,
+        "%{TF_SHARED_LIBRARY_NAME}": library_name,
+    }
+    _tpl(repository_ctx, "BUILD", build_substitutions)
+
+    defs_substitutions = {
+        "%{tf_cx11_abi}": abi_define,
+        "%{tf_cplusplus_ver}": std_flag,
+    }
+    _tpl(repository_ctx, "build_defs.bzl", defs_substitutions)
+
+# Repository rule implemented by _tf_pip_impl; it declares the TF_*
+# environment variables the implementation reads.
+tf_configure = repository_rule(
+    implementation = _tf_pip_impl,
+    environ = [
+        _TF_HEADER_DIR,
+        _TF_SHARED_LIBRARY_DIR,
+        _TF_SHARED_LIBRARY_NAME,
+        _TF_CXX11_ABI_FLAG,
+        _TF_CPLUSPLUS_VER,
+    ],
+)
diff --git a/cloudbuild/README.md b/cloudbuild/README.md
index e7a9d2216c..d4cc28d70f 100644
--- a/cloudbuild/README.md
+++ b/cloudbuild/README.md
@@ -30,8 +30,14 @@ To add a dependency for GPU tests:
- Have a Keras team member update the Docker image for GPU tests by running the remaining steps
- Create a `Dockerfile` with the following contents:
```
-FROM tensorflow/tensorflow:2.9.1-gpu
+FROM tensorflow/tensorflow:2.10.0-gpu
+RUN \
+ apt-get -y update && \
+ apt-get -y install openjdk-8-jdk && \
+ echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add
RUN apt-get -y update
+RUN apt-get -y install bazel
RUN apt-get -y install git
RUN git clone https://github.com/{path_to_keras_cv_fork}.git
RUN cd keras-cv && git checkout {branch_name}
diff --git a/cloudbuild/unit_test_jobs.jsonnet b/cloudbuild/unit_test_jobs.jsonnet
index b003b7debb..8dbf236c17 100644
--- a/cloudbuild/unit_test_jobs.jsonnet
+++ b/cloudbuild/unit_test_jobs.jsonnet
@@ -22,6 +22,12 @@ local unittest = base.BaseTest {
'bash',
'-c',
|||
+ # Build custom ops from source and export the flag that enables their tests
+ python build_deps/configure.py
+ bazel build keras_cv/custom_ops:all --verbose_failures
+ cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/
+ export TEST_CUSTOM_OPS=true
+
# Run whatever is in `command` here.
${@:0}
|||
diff --git a/keras_cv/BUILD b/keras_cv/BUILD
new file mode 100644
index 0000000000..bd084b7b08
--- /dev/null
+++ b/keras_cv/BUILD
@@ -0,0 +1,16 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+# Matches when the @bazel_tools//platforms:windows constraint value is set.
+config_setting(
+ name = "windows",
+ constraint_values = ["@bazel_tools//platforms:windows"],
+)
+
+# All Python sources of the package, bundled with the compiled custom-op
+# shared library.
+py_library(
+    name = "keras_cv",
+    srcs = glob(["**/*.py"]),
+    data = [
+        # Must name the cc_binary target declared in //keras_cv/custom_ops.
+        "//keras_cv/custom_ops:_keras_cv_custom_ops.so",
+    ],
+)
diff --git a/keras_cv/custom_ops/BUILD b/keras_cv/custom_ops/BUILD
new file mode 100644
index 0000000000..2cd7d47f06
--- /dev/null
+++ b/keras_cv/custom_ops/BUILD
@@ -0,0 +1,44 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+# Matches when the @bazel_tools//platforms:windows constraint value is set;
+# selected on by the copts/features of the targets in this package.
+config_setting(
+ name = "windows",
+ constraint_values = ["@bazel_tools//platforms:windows"],
+)
+
+# Box helpers (box_util.cc / box_util.h) built against the TensorFlow headers
+# and framework library from @local_config_tf.
+cc_library(
+ name = "box_util",
+ srcs = ["box_util.cc"],
+ hdrs = ["box_util.h"],
+ deps = [
+ "@local_config_tf//:libtensorflow_framework",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ # Separate flag set for the :windows config; -pthread and C++17 otherwise.
+ copts = select({
+ ":windows": ["/DEIGEN_STRONG_INLINE=inline", "-DTENSORFLOW_MONOLITHIC_BUILD", "/DPLATFORM_WINDOWS", "/DEIGEN_HAS_C99_MATH", "/DTENSORFLOW_USE_EIGEN_THREADPOOL", "/DEIGEN_AVOID_STL_ARRAY", "/Iexternal/gemmlowp", "/wd4018", "/wd4577", "/DNOGDI", "/UTF_COMPILE_LIBRARY"],
+ "//conditions:default": ["-pthread", "-std=c++17"],
+ }),
+)
+
+# Shared library holding the pairwise IoU op and kernel, linked against the
+# box_util helpers and the TensorFlow framework library.
+cc_binary(
+    name = "_keras_cv_custom_ops.so",
+    srcs = [
+        "kernels/pairwise_iou_kernel.cc",
+        "ops/pairwise_iou_op.cc",
+    ],
+    copts = select({
+        ":windows": [
+            "/DEIGEN_STRONG_INLINE=inline",
+            "-DTENSORFLOW_MONOLITHIC_BUILD",
+            "/DPLATFORM_WINDOWS",
+            "/DEIGEN_HAS_C99_MATH",
+            "/DTENSORFLOW_USE_EIGEN_THREADPOOL",
+            "/DEIGEN_AVOID_STL_ARRAY",
+            "/Iexternal/gemmlowp",
+            "/wd4018",
+            "/wd4577",
+            "/DNOGDI",
+            "/UTF_COMPILE_LIBRARY",
+        ],
+        "//conditions:default": [
+            "-pthread",
+            "-std=c++17",
+        ],
+    }),
+    features = select({
+        ":windows": ["windows_export_all_symbols"],
+        "//conditions:default": [],
+    }),
+    linkshared = 1,
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+        ":box_util",
+    ],
+)
diff --git a/keras_cv/custom_ops/__init__.py b/keras_cv/custom_ops/__init__.py
new file mode 100644
index 0000000000..65be099991
--- /dev/null
+++ b/keras_cv/custom_ops/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/keras_cv/custom_ops/box_util.cc b/keras_cv/custom_ops/box_util.cc
new file mode 100644
index 0000000000..6b4cdec4c7
--- /dev/null
+++ b/keras_cv/custom_ops/box_util.cc
@@ -0,0 +1,327 @@
+/* Copyright 2022 The KerasCV Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "keras_cv/custom_ops/box_util.h"
+
+#include <algorithm>
+#include <cmath>
+
+namespace tensorflow {
+namespace kerascv {
+namespace box {
+
+const double kEPS = 1e-8;
+
+// Min,max box dimensions (length, width, height). Boxes with dimensions that
+// exceed these values will have box intersections of 0.
+constexpr double kMinBoxDim = 1e-3;
+constexpr double kMaxBoxDim = 1e6;
+
+// A 2D line kept in implicit form: a*x + b*y + c = 0.
+struct Line {
+  double a = 0;
+  double b = 0;
+  double c = 0;
+
+  // Builds the line passing through the vertices v1 and v2.
+  Line(const Vertex& v1, const Vertex& v2)
+      : a(v2.y - v1.y), b(v1.x - v2.x), c(v2.x * v1.y - v2.y * v1.x) {}
+
+  // Evaluates a * v.x + b * v.y + c for the vertex v.
+  double LineValue(const Vertex& v) const {
+    return a * v.x + b * v.y + c;
+  }
+
+  // Returns the point where this line crosses `other`. CHECK-fails when the
+  // determinant of the two lines is within kEPS of zero.
+  Vertex IntersectionPoint(const Line& other) const {
+    const double det = a * other.b - b * other.a;
+    CHECK_GT(std::fabs(det), kEPS) << "No intersection between the two lines.";
+    const double x = (b * other.c - c * other.b) / det;
+    const double y = (c * other.a - a * other.c) / det;
+    return Vertex(x, y);
+  }
+};
+
+// Computes the four vertices of a 2D rotated box.
+//
+// (cx, cy) is the box center, w and h are its extents along the box's own
+// axes and heading is the rotation passed to std::cos / std::sin.
+std::vector<Vertex> ComputeBoxVertices(const double cx, const double cy,
+                                       const double w, const double h,
+                                       const double heading) {
+  // Evaluate the trigonometric functions once and reuse them.
+  const double cos_heading = std::cos(heading);
+  const double sin_heading = std::sin(heading);
+  const double dxcos = (w / 2.) * cos_heading;
+  const double dxsin = (w / 2.) * sin_heading;
+  const double dycos = (h / 2.) * cos_heading;
+  const double dysin = (h / 2.) * sin_heading;
+  return {Vertex(cx - dxcos + dysin, cy - dxsin - dycos),
+          Vertex(cx + dxcos + dysin, cy + dxsin - dycos),
+          Vertex(cx + dxcos - dysin, cy + dxsin + dycos),
+          Vertex(cx - dxcos - dysin, cy - dxsin + dycos)};
+}
+
+// Computes the intersection points between two rotated boxes, by following:
+//
+// 1. Initializes the current intersection points with the vertices of one box,
+//    and the other box is taken as the cutting box;
+//
+// 2. For each cutting line in the cutting box (four cutting lines in total):
+//      For each point in the current intersection points:
+//        If the point is inside of the cutting line:
+//          Adds it to the new intersection points;
+//        If the current point and its next point are on opposite sides of the
+//        cutting line:
+//          Computes the line of current point and its next point as tmp_line;
+//          Computes the intersection point between the cutting line and
+//          tmp_line;
+//          Adds the intersection point to the new intersection points;
+//    After checking each cutting line, sets current intersection points as
+//    new intersection points;
+//
+// 3. Returns the final intersection points.
+std::vector<Vertex> ComputeIntersectionPoints(
+    const std::vector<Vertex>& rbox_1, const std::vector<Vertex>& rbox_2) {
+  std::vector<Vertex> intersection = rbox_1;
+  const int vertices_len = rbox_2.size();
+  // Compare against the int length to avoid a signed/unsigned comparison.
+  for (int i = 0; i < vertices_len; ++i) {
+    const int len = intersection.size();
+    // Fewer than three points cannot enclose any area; stop clipping.
+    if (len <= 2) {
+      break;
+    }
+    const Vertex& p = rbox_2[i];
+    const Vertex& q = rbox_2[(i + 1) % vertices_len];
+    Line cutting_line(p, q);
+    // Computes line value.
+    std::vector<double> line_values;
+    line_values.reserve(len);
+    for (int j = 0; j < len; ++j) {
+      line_values.push_back(cutting_line.LineValue(intersection[j]));
+    }
+    // Updates current intersection points.
+    std::vector<Vertex> new_intersection;
+    for (int j = 0; j < len; ++j) {
+      const double s_val = line_values[j];
+      const Vertex& s = intersection[j];
+      // Adds the current vertex.
+      if (s_val <= 0 || std::fabs(s_val) <= kEPS) {
+        new_intersection.push_back(s);
+      }
+      const double t_val = line_values[(j + 1) % len];
+      // Skips the checking of intersection point if the next vertex is on the
+      // line.
+      if (std::fabs(t_val) <= kEPS) {
+        continue;
+      }
+      // Adds the intersection point.
+      if ((s_val > 0 && t_val < 0) || (s_val < 0 && t_val > 0)) {
+        Line s_t_line(s, intersection[(j + 1) % len]);
+        new_intersection.push_back(cutting_line.IntersectionPoint(s_t_line));
+      }
+    }
+    intersection = new_intersection;
+  }
+  return intersection;
+}
+
+// Computes the area of a convex polygon from its vertices.
+//
+// Returns 0 for polygons with two or fewer vertices; otherwise returns the
+// absolute value of half the summed cross products of consecutive vertices.
+double ComputePolygonArea(const std::vector<Vertex>& convex_polygon) {
+  const int len = convex_polygon.size();
+  if (len <= 2) {
+    return 0;
+  }
+  double area = 0;
+  for (int i = 0; i < len; ++i) {
+    const Vertex& p = convex_polygon[i];
+    // Wrap around so that the last vertex pairs with the first one.
+    const Vertex& q = convex_polygon[(i + 1) % len];
+    area += p.x * q.y - p.y * q.x;
+  }
+  return std::fabs(0.5 * area);
+}
+
+// Stores the box parameters and precomputes the loose bounds and the
+// extreme-dimension flag used by MaybeIntersects().
+RotatedBox2D::RotatedBox2D(const double cx, const double cy, const double w,
+                           const double h, const double heading)
+    : cx_(cx), cy_(cy), w_(w), h_(h), heading_(heading) {
+  // Loose, rotation-independent bounds: the box always fits in the circle
+  // swept by rotating it around its center. That circle's radius (center to
+  // corner) is at most half the longer side times root(2), reached by a
+  // square box. 1.5 is used as a factor slightly above root(2), which avoids
+  // evaluating cos/sin for the cheap overlap pre-check.
+  const double loose_half_extent = std::max(w_, h_) / 2. * 1.5;
+  loose_min_x_ = cx_ - loose_half_extent;
+  loose_max_x_ = cx_ + loose_half_extent;
+  loose_min_y_ = cy_ - loose_half_extent;
+  loose_max_y_ = cy_ + loose_half_extent;
+
+  const bool too_small = (w_ <= kMinBoxDim || h_ <= kMinBoxDim);
+  const bool too_large = (w_ >= kMaxBoxDim || h_ >= kMaxBoxDim);
+  extreme_box_dim_ = too_small || too_large;
+}
+
+// Returns the box area, computing and caching it in area_ while the cached
+// value is still negative. Areas within kEPS of zero are stored as 0.
+double RotatedBox2D::Area() const {
+  if (!(area_ < 0)) {
+    return area_;
+  }
+  const double polygon_area = ComputePolygonArea(box_vertices());
+  if (std::fabs(polygon_area) <= kEPS) {
+    area_ = 0;
+  } else {
+    area_ = polygon_area;
+  }
+  return area_;
+}
+
+// Returns the four box vertices, computing them on first use and caching
+// them in box_vertices_.
+const std::vector<Vertex>& RotatedBox2D::box_vertices() const {
+  if (box_vertices_.empty()) {
+    box_vertices_ = ComputeBoxVertices(cx_, cy_, w_, h_, heading_);
+  }
+
+  return box_vertices_;
+}
+
+// Returns true unless the constructor flagged the box as having an extreme
+// (too small or too large) width or height.
+bool RotatedBox2D::NonZeroAndValid() const { return !extreme_box_dim_; }
+
+// Cheap pre-check: returns false when the two boxes certainly do not
+// overlap, true when the full intersection still has to be computed.
+bool RotatedBox2D::MaybeIntersects(const RotatedBox2D& other) const {
+  // Boxes with too small / too large dimensions are treated as not
+  // well-formed (otherwise we are subject to issues due to catastrophic
+  // cancellation).
+  if (extreme_box_dim_ || other.extreme_box_dim_) {
+    return false;
+  }
+
+  // If the loose extrema are separated along either axis, the boxes cannot
+  // overlap even under the true, more expensive computation.
+  const bool separated_in_x = (loose_min_x_ > other.loose_max_x_) ||
+                              (loose_max_x_ < other.loose_min_x_);
+  const bool separated_in_y = (loose_min_y_ > other.loose_max_y_) ||
+                              (loose_max_y_ < other.loose_min_y_);
+  return !(separated_in_x || separated_in_y);
+}
+
+double RotatedBox2D::Intersection(const RotatedBox2D& other) const {
+  // Returns the overlap area of this box and `other` (0 if they are disjoint).
+  // Fast rejection first: when the loose bounds show the boxes are not near
+  // each other, skip the full polygon intersection.
+  if (!MaybeIntersects(other)) {
+    return 0.0;
+  }
+
+  // `auto` picks up the vertex container type returned by the helper.
+  const auto intersection_polygon =
+      ComputeIntersectionPoints(box_vertices(), other.box_vertices());
+  const double intersection_area = ComputePolygonArea(intersection_polygon);
+
+  // Areas within kEPS of zero are reported as exactly zero.
+  return std::fabs(intersection_area) <= kEPS ? 0 : intersection_area;
+}
+
+double RotatedBox2D::IoU(const RotatedBox2D& other) const {
+  // No overlap means an IoU of zero; skip the area computations entirely.
+  const double overlap = Intersection(other);
+  if (overlap == 0) {
+    return 0;
+  }
+
+  // Union = sum of both areas minus the region counted twice.
+  const double combined = Area() + other.Area() - overlap;
+  // A (numerically) zero union would make the ratio meaningless.
+  const bool degenerate_union = std::fabs(combined) <= kEPS;
+  return degenerate_union ? 0 : overlap / combined;
+}
+
+std::vector<Upright3DBox> ParseBoxesFromTensor(const Tensor& boxes_tensor) {
+  // Each row is (center_x, center_y, center_z, dimension_x, dimension_y,
+  // dimension_z, heading); one Upright3DBox is produced per row.
+  const int num_boxes = boxes_tensor.dim_size(0);
+  const auto t_boxes_tensor = boxes_tensor.matrix<float>();
+  std::vector<Upright3DBox> bboxes3d;
+  bboxes3d.reserve(num_boxes);
+  for (int i = 0; i < num_boxes; ++i) {
+    const double center_x = t_boxes_tensor(i, 0);
+    const double center_y = t_boxes_tensor(i, 1);
+    const double center_z = t_boxes_tensor(i, 2);
+    const double dimension_x = t_boxes_tensor(i, 3);
+    const double dimension_y = t_boxes_tensor(i, 4);
+    const double dimension_z = t_boxes_tensor(i, 5);
+    const double heading = t_boxes_tensor(i, 6);
+    const double z_min = center_z - dimension_z / 2;
+    const double z_max = center_z + dimension_z / 2;
+    if (dimension_x <= 0 || dimension_y <= 0) {  // degenerate footprint
+      bboxes3d.emplace_back(RotatedBox2D(), z_min, z_max);
+    } else {
+      bboxes3d.emplace_back(RotatedBox2D(center_x, center_y, dimension_x,
+                                         dimension_y, heading), z_min, z_max);
+    }
+  }
+  return bboxes3d;
+}
+
+bool Upright3DBox::NonZeroAndValid() const {
+  // The vertical extent must be strictly positive:
+  //   * z_min > z_max  -> the upright box is invalid;
+  //   * z_min == z_max -> the box has zero height, i.e. it is a zero box.
+  const bool non_positive_height = (z_min - z_max >= 0.);
+
+  // Short-circuit keeps the evaluation order: the 2D box is consulted only
+  // after the height check has passed.
+  const bool valid = !non_positive_height && rbox.NonZeroAndValid();
+  return valid;
+}
+
+double Upright3DBox::IoU(const Upright3DBox& other) const {
+  // Zero-sized or malformed boxes never overlap anything.
+  if (!(NonZeroAndValid() && other.NonZeroAndValid())) return 0;
+
+  // Height of the shared z interval, clamped at zero when the boxes are
+  // vertically disjoint.
+  const double shared_top = std::min(z_max, other.z_max);
+  const double shared_bottom = std::max(z_min, other.z_min);
+  const double overlap_height = std::max(.0, shared_top - shared_bottom);
+  if (overlap_height == 0) return 0;
+
+  // Area shared by the two rotated 2D boxes in the x-y plane.
+  const double overlap_area = rbox.Intersection(other.rbox);
+  if (overlap_area == 0) return 0;
+
+  // IoU = intersection volume / (volume_1 + volume_2 - intersection volume).
+  const double own_volume = rbox.Area() * (z_max - z_min);
+  const double other_volume = other.rbox.Area() * (other.z_max - other.z_min);
+  const double overlap_volume = overlap_area * overlap_height;
+  const double union_volume = own_volume + other_volume - overlap_volume;
+  if (overlap_volume > 0) {
+    return overlap_volume / union_volume;
+  }
+  return 0;
+}
+
+double Upright3DBox::Overlap(const Upright3DBox& other) const {
+  // Zero-sized or malformed boxes never overlap anything.
+  if (!(NonZeroAndValid() && other.NonZeroAndValid())) return 0;
+
+  // Height of the shared z interval, clamped at zero when disjoint.
+  const double shared_top = std::min(z_max, other.z_max);
+  const double shared_bottom = std::max(z_min, other.z_min);
+  const double overlap_height = std::max(.0, shared_top - shared_bottom);
+  if (overlap_height == 0) return 0;
+
+  // Area shared by the two rotated 2D boxes in the x-y plane.
+  const double overlap_area = rbox.Intersection(other.rbox);
+  if (overlap_area == 0) return 0;
+
+  // The intersection volume is normalized by the volume of this box only,
+  // so the result is not symmetric in the two boxes.
+  const double own_volume = rbox.Area() * (z_max - z_min);
+  const double overlap_volume = overlap_area * overlap_height;
+  if (overlap_volume > 0) {
+    return overlap_volume / own_volume;
+  }
+  return 0;
+}
+
+} // namespace box
+} // namespace kerascv
+} // namespace tensorflow
diff --git a/keras_cv/custom_ops/box_util.h b/keras_cv/custom_ops/box_util.h
new file mode 100644
index 0000000000..a384abc844
--- /dev/null
+++ b/keras_cv/custom_ops/box_util.h
@@ -0,0 +1,141 @@
+/* Copyright 2022 The Keras CV Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_PY_KERAS_CV_OPS_BOX_UTIL_H_
+#define THIRD_PARTY_PY_KERAS_CV_OPS_BOX_UTIL_H_
+
+#include
+#include
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace kerascv {
+namespace box {
+
+// A 2D point given by its (x, y) coordinates.
+//
+// Internal implementation detail of RotatedBox2D.
+struct Vertex {
+  // Default-constructs a vertex at the origin (0, 0).
+  Vertex() = default;
+  // Constructs a vertex at the given coordinates.
+  Vertex(const double x_in, const double y_in) : x(x_in), y(y_in) {}
+
+  double x = 0.0;
+  double y = 0.0;
+};
+
+// A rotated 2D bounding box represented as (cx, cy, w, h, heading). cx, cy
+// are the box center coordinates; w, h are the box width and height; heading
+// is the rotation angle in radians relative to the 'positive x' direction.
+class RotatedBox2D {
+ public:
+  // Creates an empty rotated 2D box (all five parameters are zero).
+  RotatedBox2D() : RotatedBox2D(0, 0, 0, 0, 0) {}
+
+  RotatedBox2D(const double cx, const double cy, const double w, const double h,
+               const double heading);
+
+  // Returns the area of the box. The value is cached after the first call.
+  double Area() const;
+
+  // Returns the intersection area between this box and the given box.
+  double Intersection(const RotatedBox2D& other) const;
+
+  // Returns the IoU between this box and the given box.
+  double IoU(const RotatedBox2D& other) const;
+
+  // Returns true if the box is valid (width and height are not extremely
+  // large or small).
+  bool NonZeroAndValid() const;
+
+ private:
+  // Returns the box vertices, computing and caching them on first use.
+  const std::vector<Vertex>& box_vertices() const;
+
+  // Returns true if this box and 'other' might intersect.
+  //
+  // If this returns false, the two boxes definitely do not intersect. If this
+  // returns true, it is still possible that the two boxes do not intersect, and
+  // the more expensive intersection code will be called.
+  bool MaybeIntersects(const RotatedBox2D& other) const;
+
+  double cx_ = 0;
+  double cy_ = 0;
+  double w_ = 0;
+  double h_ = 0;
+  double heading_ = 0;
+
+  // Loose boundaries for fast intersection test.
+  double loose_min_x_ = -1;
+  double loose_max_x_ = -1;
+  double loose_min_y_ = -1;
+  double loose_max_y_ = -1;
+
+  // True if the dimensions of the box are very small or very large in any
+  // dimension.
+  bool extreme_box_dim_ = false;
+
+  // The following fields are computed on demand. They are logically
+  // const, hence `mutable` so that const accessors can fill them in.
+
+  // Cached area; negative until computed. Access via Area() public API.
+  mutable double area_ = -1;
+
+  // Stores the vertices of the box. Access via box_vertices().
+  mutable std::vector<Vertex> box_vertices_;
+};
+
+// A 3D box of 7-DOFs: only allows rotation around the z-axis.
+struct Upright3DBox {
+  RotatedBox2D rbox = RotatedBox2D();  // Box outline in the x-y plane.
+  double z_min = 0;                    // Lower bound along z.
+  double z_max = 0;                    // Upper bound along z.
+
+  // Creates an empty rotated 3D box.
+  Upright3DBox() = default;
+
+  // Creates a 3D box from raw data laid out as (center_x, center_y, center_z,
+  // dimension_x, dimension_y, dimension_z, heading). `raw` must hold at least
+  // 7 values; elements are read with operator[] and are not bounds-checked.
+  Upright3DBox(const std::vector<double>& raw)
+      : rbox(raw[0], raw[1], raw[3], raw[4], raw[6]),
+        z_min(raw[2] - raw[5] / 2.0),
+        z_max(raw[2] + raw[5] / 2.0) {}
+
+  Upright3DBox(const RotatedBox2D& rb, const double z_min, const double z_max)
+      : rbox(rb), z_min(z_min), z_max(z_max) {}
+
+  // Computes intersection over union (of the volume).
+  double IoU(const Upright3DBox& other) const;
+
+  // Computes overlap: intersection of this box and the given box normalized
+  // over the volume of this box.
+  double Overlap(const Upright3DBox& other) const;
+
+  // Returns true if the box is valid (width and height are not extremely
+  // large or small, and zmin < zmax).
+  bool NonZeroAndValid() const;
+};
+
+// Converts a [N, 7] tensor to a vector of N Upright3DBox objects.
+std::vector<Upright3DBox> ParseBoxesFromTensor(const Tensor& boxes_tensor);
+
+} // namespace box
+} // namespace kerascv
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_PY_KERAS_CV_OPS_BOX_UTIL_H_
diff --git a/keras_cv/custom_ops/kernels/pairwise_iou_kernel.cc b/keras_cv/custom_ops/kernels/pairwise_iou_kernel.cc
new file mode 100644
index 0000000000..fb3a9b6e1e
--- /dev/null
+++ b/keras_cv/custom_ops/kernels/pairwise_iou_kernel.cc
@@ -0,0 +1,72 @@
+/* Copyright 2022 The KerasCV Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+
+#include "keras_cv/custom_ops/box_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace kerascv {
+namespace {
+
+class PairwiseIoUOp : public OpKernel {
+ public:
+  explicit PairwiseIoUOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  // Checks that both inputs are [*, 7] matrices and fills the "iou" output.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& a = ctx->input(0);
+    const Tensor& b = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
+                errors::InvalidArgument("In[0] must be a matrix, but get ",
+                                        a.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
+                errors::InvalidArgument("In[1] must be a matrix, but get ",
+                                        b.shape().DebugString()));
+    OP_REQUIRES(ctx, 7 == a.dim_size(1),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        a.shape().DebugString()));
+    OP_REQUIRES(ctx, 7 == b.dim_size(1),
+                errors::InvalidArgument("Matrix size-incompatible: In[1]: ",
+                                        b.shape().DebugString()));
+    // One output row per box in `a`, one output column per box in `b`.
+    const int n_a = a.dim_size(0);
+    const int n_b = b.dim_size(0);
+
+    Tensor* iou_a_b = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("iou", TensorShape({n_a, n_b}), &iou_a_b));
+
+    auto t_iou_a_b = iou_a_b->matrix<float>();
+    // Parse each input once, then evaluate every (a, b) pair.
+    std::vector<box::Upright3DBox> box_a = box::ParseBoxesFromTensor(a);
+    std::vector<box::Upright3DBox> box_b = box::ParseBoxesFromTensor(b);
+    for (int i_a = 0; i_a < n_a; ++i_a) {
+      for (int i_b = 0; i_b < n_b; ++i_b) {
+        t_iou_a_b(i_a, i_b) = box_a[i_a].IoU(box_b[i_b]);
+      }
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("PairwiseIou3D").Device(DEVICE_CPU),
+ PairwiseIoUOp);  // Kernel is registered for DEVICE_CPU only.
+
+} // namespace
+} // namespace kerascv
+} // namespace tensorflow
diff --git a/keras_cv/custom_ops/ops/pairwise_iou_op.cc b/keras_cv/custom_ops/ops/pairwise_iou_op.cc
new file mode 100644
index 0000000000..79bcc0909d
--- /dev/null
+++ b/keras_cv/custom_ops/ops/pairwise_iou_op.cc
@@ -0,0 +1,35 @@
+/* Copyright 2022 The KerasCV Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("PairwiseIou3D")
+ .Input("boxes_a: float")
+ .Input("boxes_b: float")
+ .Output("iou: float")
+ .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
+ c->set_output(  // Output shape: [dim 0 of boxes_a, dim 0 of boxes_b].
+ 0, c->MakeShape({c->Dim(c->input(0), 0), c->Dim(c->input(1), 0)}));
+ return tensorflow::Status();
+ })
+ .Doc(R"doc(
+Calculate pairwise IoUs between two set of 3D bboxes. Every bbox is represented
+as [center_x, center_y, center_z, dim_x, dim_y, dim_z, heading].
+boxes_a: A tensor of shape [num_boxes_a, 7]
+boxes_b: A tensor of shape [num_boxes_b, 7]
+)doc");
diff --git a/keras_cv/ops/__init__.py b/keras_cv/ops/__init__.py
index 19888495bf..46d735c699 100644
--- a/keras_cv/ops/__init__.py
+++ b/keras_cv/ops/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from keras_cv.ops.iou_3d import IoU3D
from keras_cv.ops.point_cloud import _box_area
from keras_cv.ops.point_cloud import _center_xyzWHD_to_corner_xyz
from keras_cv.ops.point_cloud import _is_on_lefthand_side
diff --git a/keras_cv/ops/iou_3d.py b/keras_cv/ops/iou_3d.py
new file mode 100644
index 0000000000..e6b05509ae
--- /dev/null
+++ b/keras_cv/ops/iou_3d.py
@@ -0,0 +1,48 @@
+# Copyright 2022 The KerasCV Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""IoU3D using a custom TF op."""
+
+from tensorflow.python.framework import load_library
+from tensorflow.python.platform import resource_loader
+
+
+class IoU3D:
+ """Implements IoU computation for 3D upright rotated bounding boxes.
+
+ Note that this is implemented using a custom TensorFlow op. Initializing an
+ IoU3D object will attempt to load the binary for that op.
+
+ Boxes should have the format [center_x, center_y, center_z, dimension_x,
+ dimension_y, dimension_z, heading (in radians)].
+
+ Sample Usage:
+ ```python
+ y_true = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]]
+ y_pred = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]]
+ iou = IoU3D()
+ iou(y_true, y_pred)
+ ```
+ """
+
+ def __init__(self):
+ pairwise_iou_op = load_library.load_op_library(
+ resource_loader.get_path_to_datafile(
+ "../custom_ops/_keras_cv_custom_ops.so"
+ )
+ )
+ self.iou_3d = pairwise_iou_op.pairwise_iou3d
+
+ def __call__(self, y_true, y_pred):
+ return self.iou_3d(y_true, y_pred)
diff --git a/keras_cv/ops/iou_3d_test.py b/keras_cv/ops/iou_3d_test.py
new file mode 100644
index 0000000000..b654e332fb
--- /dev/null
+++ b/keras_cv/ops/iou_3d_test.py
@@ -0,0 +1,54 @@
+# Copyright 2022 The KerasCV Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for IoU3D using custom op."""
+import math
+import os
+
+import pytest
+import tensorflow as tf
+
+from keras_cv.ops import IoU3D
+
+
+class IoU3DTest(tf.test.TestCase):
+ @pytest.mark.skipif(
+ "TEST_CUSTOM_OPS" not in os.environ or os.environ["TEST_CUSTOM_OPS"] != "true",
+ reason="Requires binaries compiled from source",
+ )
+ def testOpCall(self):
+ # Predicted boxes:
+ # 0: a 2x2x2 box centered at 0,0,0, rotated 0 degrees
+ # 1: a 2x2x2 box centered at 1,1,1, rotated 135 degrees
+ # Ground Truth boxes:
+ # 0: a 2x2x2 box centered at 1,1,1, rotated 45 degrees (idential to predicted box 1)
+ # 1: a 2x2x2 box centered at 1,1,1, rotated 0 degrees
+ box_preds = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]]
+ box_gt = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]]
+
+ # Predicted box 0 and both ground truth boxes overlap by 1/8th of the box.
+ # Therefore, IiU is 1/15
+ # Predicted box 1 is the same as ground truth box 0, therefore IoU is 1
+ # Predicted box 1 shares an origin with ground truth box 1, but is rotated by 135 degrees.
+ # Their IoU can be reduced to that of two overlapping squares that share a center with
+ # the same offset of 135 degrees, which reduces to the square root of 0.5.
+ expected_ious = [[1 / 15, 1 / 15], [1, 0.5**0.5]]
+
+ iou_3d = IoU3D()
+
+ self.assertAllClose(iou_3d(box_preds, box_gt), expected_ious)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/setup.py b/setup.py
index cc19cccb90..7973d77591 100644
--- a/setup.py
+++ b/setup.py
@@ -18,10 +18,22 @@
from setuptools import find_packages
from setuptools import setup
+from setuptools.dist import Distribution
HERE = pathlib.Path(__file__).parent
README = (HERE / "README.md").read_text()
+
+class BinaryDistribution(Distribution):
+ """This class is needed in order to create OS specific wheels."""
+
+ def has_ext_modules(self):
+ return True
+
+ def is_pure(self):
+ return False
+
+
setup(
name="keras-cv",
description="Industry-strength computer Vision extensions for Keras.",
@@ -37,6 +49,7 @@
"tests": ["flake8", "isort", "black", "pytest", "tensorflow-datasets"],
"examples": ["tensorflow_datasets", "matplotlib"],
},
+ distclass=BinaryDistribution,
classifiers=[
"Programming Language :: Python",
"Programming Language :: Python :: 3.7",
@@ -48,4 +61,5 @@
"Topic :: Software Development",
],
packages=find_packages(exclude=("*_test.py",)),
+ include_package_data=True,
)