diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index cc0d816204..cfb67e4a4c 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -84,6 +84,14 @@ Our default training scripts train using ImageNet. Because we cannot distribute We look forward to delivering great pre-trained models in KerasCV with the help of your contributions! +## Contributing custom ops + +We do not plan to accept contributed custom ops due to the maintenance burden that they introduce. If there is a clear need for a specific custom op that should live in KerasCV, please consult the KerasCV team before implementing it, as we expect to reject contributions of custom ops by default. + +We currently support only a small handful of ops that run on CPU and are not used at inference time. + +If you are updating existing custom ops, you can re-compile the binaries from source using the instructions in the `Tests that require custom ops` section below. + ## Setup environment Setting up your KerasCV development environment requires you to fork the KerasCV repository, @@ -129,6 +137,16 @@ You can run the unit tests for KerasCV by running: pytest keras_cv/ ``` +### Tests that require custom ops +For tests that require custom ops, you'll have to compile the custom ops and make them available to your local Python code: +```shell +python build_deps/configure.py +bazel build keras_cv/custom_ops:all +cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/ +``` + +Tests which use custom ops are disabled by default, but can be run by setting the environment variable `TEST_CUSTOM_OPS=true`. + ## Formatting the Code We use `flake8`, `isort` and `black` for code formatting. 
You can run the following commands manually every time you want to format your code: diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 57afec8d26..02a0c20eb5 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -32,9 +32,16 @@ jobs: ${{ runner.os }}-pip- - name: Install dependencies run: | - pip install tensorflow==2.10.0 + pip install tensorflow-cpu==2.10.0 pip install -e ".[tests]" --progress-bar off --upgrade + - name: Build custom ops for tests + run: | + python build_deps/configure.py + bazel build keras_cv/custom_ops:all + cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/ - name: Test with pytest + env: + TEST_CUSTOM_OPS: true run: | pytest keras_cv/ --ignore keras_cv/models format: diff --git a/BUILD b/BUILD new file mode 100644 index 0000000000..8e1e7999ed --- /dev/null +++ b/BUILD @@ -0,0 +1,11 @@ +sh_binary( + name = "build_pip_pkg", + srcs = ["build_deps/build_pip_pkg.sh"], + data = [ + "LICENSE", + "MANIFEST.in", + "README.md", + "setup.py", + "//keras_cv", + ], +) diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..0684ba4458 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include keras_cv/custom_ops *.so diff --git a/README.md b/README.md index 808f4a99a8..0b0d15580e 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,23 @@ Thank you to all of our wonderful contributors! + +## Installing Custom Ops from Source +Installing from source requires the [Bazel](https://bazel.build/) build system +(version >= 1.0.0). + +``` +git clone https://github.com/keras-team/keras-cv.git +cd keras-cv + +python3 build_deps/configure.py + +bazel build build_pip_pkg +bazel-bin/build_pip_pkg wheels + +pip install wheels/keras-cv-*.whl +``` + ## Pretrained Weights Many models in KerasCV come with pre-trained weights. 
With the exception of StableDiffusion, all of these weights are trained using Keras and KerasCV components and training scripts in this @@ -72,6 +89,7 @@ history for backbone models [here](examples/training/classification/imagenet/tra All results are reproducible using the training scripts in this repository. Pre-trained weights operate on images that have been rescaled using a simple `1/255` rescaling layer. + ## Citing KerasCV If KerasCV helps your research, we appreciate your citations. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000000..d8ea963a08 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,3 @@ +load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure") + +tf_configure(name = "local_config_tf") diff --git a/build_deps/build_pip_pkg.sh b/build_deps/build_pip_pkg.sh new file mode 100755 index 0000000000..34aebc5de9 --- /dev/null +++ b/build_deps/build_pip_pkg.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash +# Copyright 2022 The KerasCV Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Builds a wheel of KerasCV for Pip. Requires Bazel. 
+# Adapted from https://github.com/tensorflow/addons/blob/master/build_deps/build_pip_pkg.sh + +set -e +set -x + +PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" +function is_windows() { + if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then + true + else + false + fi +} + +if is_windows; then + PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.exe.runfiles/__main__/" +else + PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.runfiles/__main__/" +fi + +function main() { + while [[ ! -z "${1}" ]]; do + if [[ ${1} == "make" ]]; then + echo "Using Makefile to build pip package." + PIP_FILE_PREFIX="" + else + DEST=${1} + fi + shift + done + + if [[ -z ${DEST} ]]; then + echo "No destination dir provided" + exit 1 + fi + + # Create the directory, then resolve it to an absolute, symlink-free path. + # The macOS branch avoids `readlink -f` by cd-ing into the directory and + # printing its physical path, which handles relative and absolute inputs. + mkdir -p "${DEST}" + if [[ ${PLATFORM} == "darwin" ]]; then + DEST="$(cd "${DEST}" && pwd -P)" + else + DEST=$(readlink -f "${DEST}") + fi + echo "=== destination directory: ${DEST}" + + TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) + + echo $(date) : "=== Using tmpdir: ${TMPDIR}" + + echo "=== Copy KerasCV Custom op files" + + cp ${PIP_FILE_PREFIX}setup.py "${TMPDIR}" + cp ${PIP_FILE_PREFIX}MANIFEST.in "${TMPDIR}" + cp ${PIP_FILE_PREFIX}README.md "${TMPDIR}" + cp ${PIP_FILE_PREFIX}LICENSE "${TMPDIR}" + rsync -avm -L --exclude='*_test.py' ${PIP_FILE_PREFIX}keras_cv "${TMPDIR}" + + pushd "${TMPDIR}" + echo $(date) : "=== Building wheel" + + python3 setup.py bdist_wheel > /dev/null + + cp dist/*.whl "${DEST}" + popd + rm -rf "${TMPDIR}" + echo $(date) : "=== Output wheel file is in: ${DEST}" +} + +main "$@" diff --git a/build_deps/configure.py b/build_deps/configure.py new file mode 100644 index 0000000000..35a0a0628e --- /dev/null +++ b/build_deps/configure.py @@ -0,0 +1,174 @@ +# Copyright 2022 The KerasCV Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Usage: python configure.py +"""Configures local environment to prepare for building KerasCV from source.""" + + +import logging +import os +import pathlib +import platform + +import tensorflow as tf +from packaging.version import Version + +_TFA_BAZELRC = ".bazelrc" + + +# Writes variables to bazelrc file +def write(line): + with open(_TFA_BAZELRC, "a") as f: + f.write(line + "\n") + + +def write_action_env(var_name, var): + write('build --action_env {}="{}"'.format(var_name, var)) + + +def is_macos(): + return platform.system() == "Darwin" + + +def is_windows(): + return platform.system() == "Windows" + + +def is_linux(): + return platform.system() == "Linux" + + +def is_raspi_arm(): + return platform.machine() in ("armv7l", "aarch64") + + +def is_linux_ppc64le(): + return is_linux() and platform.machine() == "ppc64le" + + +def is_linux_x86_64(): + return is_linux() and platform.machine() == "x86_64" + + +def is_linux_arm(): + return is_linux() and platform.machine().startswith("arm") + + +def is_linux_aarch64(): + return is_linux() and platform.machine() == "aarch64" + + +def is_linux_s390x(): + return is_linux() and platform.machine() == "s390x" + + +def get_tf_header_dir(): + import tensorflow as tf + + tf_header_dir = tf.sysconfig.get_compile_flags()[0][2:] + if is_windows(): + tf_header_dir = tf_header_dir.replace("\\", "/") + 
return tf_header_dir + + +def get_cpp_version(): + cpp_version = "c++14" + if Version(tf.__version__) >= Version("2.10"): + cpp_version = "c++17" + return cpp_version + + +def get_tf_shared_lib_dir(): + import tensorflow as tf + + # OS Specific parsing + if is_windows(): + tf_shared_lib_dir = tf.sysconfig.get_compile_flags()[0][2:-7] + "python" + return tf_shared_lib_dir.replace("\\", "/") + elif is_raspi_arm(): + return tf.sysconfig.get_compile_flags()[0][2:-7] + "python" + else: + return tf.sysconfig.get_link_flags()[0][2:] + + +# Converts the linkflag namespec to the full shared library name +def get_shared_lib_name(): + import tensorflow as tf + + namespec = tf.sysconfig.get_link_flags() + if is_macos(): + # MacOS + return "lib" + namespec[1][2:] + ".dylib" + elif is_windows(): + # Windows + return "_pywrap_tensorflow_internal.lib" + elif is_raspi_arm(): + # The below command for linux would return an empty list + return "_pywrap_tensorflow_internal.so" + else: + # Linux + return namespec[1][3:] + + +def create_build_configuration(): + print() + print("Configuring KerasCV to be built from source...") + + if os.path.isfile(_TFA_BAZELRC): + os.remove(_TFA_BAZELRC) + + logging.disable(logging.WARNING) + + write_action_env("TF_HEADER_DIR", get_tf_header_dir()) + write_action_env("TF_SHARED_LIBRARY_DIR", get_tf_shared_lib_dir()) + write_action_env("TF_SHARED_LIBRARY_NAME", get_shared_lib_name()) + write_action_env("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG) + + # This should be replaced with a call to tf.sysconfig if it's added + write_action_env("TF_CPLUSPLUS_VER", get_cpp_version()) + + write("build --spawn_strategy=standalone") + write("build --strategy=Genrule=standalone") + write("build --experimental_repo_remote_exec") + write("build -c opt") + write( + "build --cxxopt=" + + '"-D_GLIBCXX_USE_CXX11_ABI="' + + str(tf.sysconfig.CXX11_ABI_FLAG) + ) + + if is_windows(): + write("build --config=windows") + write("build:windows --enable_runfiles") + 
write("build:windows --copt=/experimental:preprocessor") + write("build:windows --host_copt=/experimental:preprocessor") + write("build:windows --copt=/arch=AVX") + write("build:windows --cxxopt=/std:" + get_cpp_version()) + write("build:windows --host_cxxopt=/std:" + get_cpp_version()) + + if is_macos() or is_linux(): + if not is_linux_ppc64le() and not is_linux_arm() and not is_linux_aarch64(): + write("build --copt=-mavx") + write("build --cxxopt=-std=" + get_cpp_version()) + write("build --host_cxxopt=-std=" + get_cpp_version()) + + print("> Building only CPU ops") + + print() + print("Build configurations successfully written to", _TFA_BAZELRC, ":\n") + print(pathlib.Path(_TFA_BAZELRC).read_text()) + + +if __name__ == "__main__": + create_build_configuration() diff --git a/build_deps/tf_dependency/BUILD b/build_deps/tf_dependency/BUILD new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build_deps/tf_dependency/BUILD.tpl b/build_deps/tf_dependency/BUILD.tpl new file mode 100644 index 0000000000..aae0d7ddb9 --- /dev/null +++ b/build_deps/tf_dependency/BUILD.tpl @@ -0,0 +1,18 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "tf_header_lib", + hdrs = [":tf_header_include"], + includes = ["include"], + visibility = ["//visibility:public"], +) + + +cc_library( + name = "libtensorflow_framework", + srcs = ["%{TF_SHARED_LIBRARY_NAME}"], + visibility = ["//visibility:public"], +) + +%{TF_HEADER_GENRULE} +%{TF_SHARED_LIBRARY_GENRULE} diff --git a/build_deps/tf_dependency/build_defs.bzl.tpl b/build_deps/tf_dependency/build_defs.bzl.tpl new file mode 100644 index 0000000000..f6573b67a4 --- /dev/null +++ b/build_deps/tf_dependency/build_defs.bzl.tpl @@ -0,0 +1,4 @@ +# Addons Build Definitions inherited from TensorFlow Core + +D_GLIBCXX_USE_CXX11_ABI = "%{tf_cx11_abi}" +CPLUSPLUS_VERSION = "%{tf_cplusplus_ver}" diff --git a/build_deps/tf_dependency/tf_configure.bzl b/build_deps/tf_dependency/tf_configure.bzl new file mode 100644 
index 0000000000..0c0b5e7064 --- /dev/null +++ b/build_deps/tf_dependency/tf_configure.bzl @@ -0,0 +1,244 @@ +"""Setup TensorFlow as external dependency""" + +_TF_HEADER_DIR = "TF_HEADER_DIR" + +_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR" + +_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME" + +_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG" + +_TF_CPLUSPLUS_VER = "TF_CPLUSPLUS_VER" + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl + repository_ctx.template( + out, + Label("//build_deps/tf_dependency:%s.tpl" % tpl), + substitutions, + ) + +def _fail(msg): + """Output failure message when auto configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + os_name = repository_ctx.os.name.lower() + if os_name.find("windows") != -1: + return True + return False + +def _execute( + repository_ctx, + cmdline, + error_msg = None, + error_details = None, + empty_stdout_fine = False): + """Executes an arbitrary shell command. + + Helper for executes an arbitrary shell command. + + Args: + repository_ctx: the repository_ctx object. + cmdline: list of strings, the command to execute. + error_msg: string, a summary of the error if the command fails. + error_details: string, details about the error or steps to fix it. + empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise + it's an error. + + Returns: + The result of repository_ctx.execute(cmdline). + """ + result = repository_ctx.execute(cmdline) + if result.stderr or not (empty_stdout_fine or result.stdout): + _fail("\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else "", + ])) + return result + +def _read_dir(repository_ctx, src_dir): + """Returns a string with all files in a directory. 
+ + Finds all files inside a directory, traversing subfolders and following + symlinks. The returned string contains the full path of all files + separated by line breaks. + + Args: + repository_ctx: the repository_ctx object. + src_dir: directory to find files from. + + Returns: + A string of all files inside the given dir. + """ + if _is_windows(repository_ctx): + src_dir = src_dir.replace("/", "\\") + find_result = _execute( + repository_ctx, + ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"], + empty_stdout_fine = True, + ) + + # src_files will be used in genrule.outs where the paths must + # use forward slashes. + result = find_result.stdout.replace("\\", "/") + else: + find_result = _execute( + repository_ctx, + ["find", src_dir, "-follow", "-type", "f"], + empty_stdout_fine = True, + ) + result = find_result.stdout + return result + +def _genrule(genrule_name, command, outs): + """Returns a string with a genrule. + + Genrule executes the given command and produces the given outputs. + + Args: + genrule_name: A unique name for genrule target. + command: The command to run. + outs: A list of files generated by this rule. + + Returns: + A genrule target. + """ + return ( + "genrule(\n" + + ' name = "' + + genrule_name + '",\n' + + " outs = [\n" + + outs + + "\n ],\n" + + ' cmd = """\n' + + command + + '\n """,\n' + + ")\n" + ) + +def _norm_path(path): + """Returns a path with '/' and remove the trailing slash.""" + path = path.replace("\\", "/") + if path[-1] == "/": + path = path[:-1] + return path + +def _symlink_genrule_for_dir( + repository_ctx, + src_dir, + dest_dir, + genrule_name, + src_files = [], + dest_files = [], + tf_pip_dir_rename_pair = []): + """Returns a genrule to symlink(or copy if on Windows) a set of files. + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files. + Args: + repository_ctx: the repository_ctx object. + src_dir: source directory. 
+ dest_dir: directory to create symlink in. + genrule_name: genrule name. + src_files: list of source files instead of src_dir. + dest_files: list of corresonding destination files. + tf_pip_dir_rename_pair: list of the pair of tf pip parent directory to + replace. For example, in TF pip package, the source code is under + "tensorflow_core", and we might want to replace it with + "tensorflow" to match the header includes. + Returns: + genrule target that creates the symlinks. + """ + + # Check that tf_pip_dir_rename_pair has the right length + tf_pip_dir_rename_pair_len = len(tf_pip_dir_rename_pair) + if tf_pip_dir_rename_pair_len != 0 and tf_pip_dir_rename_pair_len != 2: + _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % tf_pip_dir_rename_pair_len) + + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) + + # Create a list with the src_dir stripped to use for outputs. + if tf_pip_dir_rename_pair_len: + dest_files = files.replace(src_dir, "").replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1]).splitlines() + else: + dest_files = files.replace(src_dir, "").splitlines() + src_files = files.splitlines() + command = [] + outs = [] + + for i in range(len(dest_files)): + if dest_files[i] != "": + # If we have only one file to link we do not want to use the dest_dir, as + # $(@D) will include the full path to the file. + dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] + + # Copy the headers to create a sandboxable setup. 
+ cmd = "cp -f" + command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) + outs.append(' "' + dest_dir + dest_files[i] + '",') + + genrule = _genrule( + genrule_name, + ";\n".join(command), + "\n".join(outs), + ) + return genrule + +def _tf_pip_impl(repository_ctx): + tf_header_dir = repository_ctx.os.environ[_TF_HEADER_DIR] + tf_header_rule = _symlink_genrule_for_dir( + repository_ctx, + tf_header_dir, + "include", + "tf_header_include", + tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"], + ) + + tf_shared_library_dir = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR] + tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME] + tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name) + tf_cx11_abi = "-D_GLIBCXX_USE_CXX11_ABI=%s" % (repository_ctx.os.environ[_TF_CXX11_ABI_FLAG]) + tf_cplusplus_ver = "-std=%s" % repository_ctx.os.environ[_TF_CPLUSPLUS_VER] + + tf_shared_library_rule = _symlink_genrule_for_dir( + repository_ctx, + None, + "", + tf_shared_library_name, + [tf_shared_library_path], + [tf_shared_library_name], + ) + + _tpl(repository_ctx, "BUILD", { + "%{TF_HEADER_GENRULE}": tf_header_rule, + "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule, + "%{TF_SHARED_LIBRARY_NAME}": tf_shared_library_name, + }) + + _tpl( + repository_ctx, + "build_defs.bzl", + { + "%{tf_cx11_abi}": tf_cx11_abi, + "%{tf_cplusplus_ver}": tf_cplusplus_ver, + }, + ) + +tf_configure = repository_rule( + environ = [ + _TF_HEADER_DIR, + _TF_SHARED_LIBRARY_DIR, + _TF_SHARED_LIBRARY_NAME, + _TF_CXX11_ABI_FLAG, + _TF_CPLUSPLUS_VER, + ], + implementation = _tf_pip_impl, +) diff --git a/cloudbuild/README.md b/cloudbuild/README.md index e7a9d2216c..d4cc28d70f 100644 --- a/cloudbuild/README.md +++ b/cloudbuild/README.md @@ -30,8 +30,14 @@ To add a dependency for GPU tests: - Have a Keras team member update the Docker image for GPU tests by running the remaining steps - Create a `Dockerfile` with the following contents: ``` -FROM 
tensorflow/tensorflow:2.9.1-gpu +FROM tensorflow/tensorflow:2.10.0-gpu +RUN \ + apt-get -y update && \ + apt-get -y install openjdk-8-jdk && \ + echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \ + curl https://bazel.build/bazel-release.pub.gpg | apt-key add RUN apt-get -y update +RUN apt-get -y install bazel RUN apt-get -y install git RUN git clone https://github.com/{path_to_keras_cv_fork}.git RUN cd keras-cv && git checkout {branch_name} diff --git a/cloudbuild/unit_test_jobs.jsonnet b/cloudbuild/unit_test_jobs.jsonnet index b003b7debb..8dbf236c17 100644 --- a/cloudbuild/unit_test_jobs.jsonnet +++ b/cloudbuild/unit_test_jobs.jsonnet @@ -22,6 +22,12 @@ local unittest = base.BaseTest { 'bash', '-c', ||| + # Build custom ops from source + python build_deps/configure.py + bazel build keras_cv/custom_ops:all --verbose_failures + cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/ + export TEST_CUSTOM_OPS=true + # Run whatever is in `command` here. 
${@:0} ||| diff --git a/keras_cv/BUILD b/keras_cv/BUILD new file mode 100644 index 0000000000..bd084b7b08 --- /dev/null +++ b/keras_cv/BUILD @@ -0,0 +1,16 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "windows", + constraint_values = ["@bazel_tools//platforms:windows"], +) + +py_library( + name = "keras_cv", + srcs = glob(["**/*.py"]), + data = [ + "//keras_cv/custom_ops:_keras_cv_custom_ops.so", + ] +) diff --git a/keras_cv/custom_ops/BUILD b/keras_cv/custom_ops/BUILD new file mode 100644 index 0000000000..2cd7d47f06 --- /dev/null +++ b/keras_cv/custom_ops/BUILD @@ -0,0 +1,44 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "windows", + constraint_values = ["@bazel_tools//platforms:windows"], +) + +cc_library( + name = "box_util", + srcs = ["box_util.cc"], + hdrs = ["box_util.h"], + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], + copts = select({ + ":windows": ["/DEIGEN_STRONG_INLINE=inline", "-DTENSORFLOW_MONOLITHIC_BUILD", "/DPLATFORM_WINDOWS", "/DEIGEN_HAS_C99_MATH", "/DTENSORFLOW_USE_EIGEN_THREADPOOL", "/DEIGEN_AVOID_STL_ARRAY", "/Iexternal/gemmlowp", "/wd4018", "/wd4577", "/DNOGDI", "/UTF_COMPILE_LIBRARY"], + "//conditions:default": ["-pthread", "-std=c++17"], + }), +) + +cc_binary( + name = '_keras_cv_custom_ops.so', + srcs = [ + "kernels/pairwise_iou_kernel.cc", + "ops/pairwise_iou_op.cc" + ], + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ":box_util", + ], + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + copts = select({ + ":windows": ["/DEIGEN_STRONG_INLINE=inline", "-DTENSORFLOW_MONOLITHIC_BUILD", "/DPLATFORM_WINDOWS", "/DEIGEN_HAS_C99_MATH", "/DTENSORFLOW_USE_EIGEN_THREADPOOL", "/DEIGEN_AVOID_STL_ARRAY", "/Iexternal/gemmlowp", 
"/wd4018", "/wd4577", "/DNOGDI", "/UTF_COMPILE_LIBRARY"], + "//conditions:default": ["-pthread", "-std=c++17"], + }), +) diff --git a/keras_cv/custom_ops/__init__.py b/keras_cv/custom_ops/__init__.py new file mode 100644 index 0000000000..65be099991 --- /dev/null +++ b/keras_cv/custom_ops/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_cv/custom_ops/box_util.cc b/keras_cv/custom_ops/box_util.cc new file mode 100644 index 0000000000..6b4cdec4c7 --- /dev/null +++ b/keras_cv/custom_ops/box_util.cc @@ -0,0 +1,327 @@ +/* Copyright 2022 The KerasCV Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "keras_cv/custom_ops/box_util.h" + +#include <algorithm> +#include <cmath> + +namespace tensorflow { +namespace kerascv { +namespace box { + +const double kEPS = 1e-8; + +// Min,max box dimensions (length, width, height). Boxes with dimensions that +// exceed these values will have box intersections of 0. +constexpr double kMinBoxDim = 1e-3; +constexpr double kMaxBoxDim = 1e6; + +// A line with the representation a*x + b*y + c = 0. +struct Line { + double a = 0; + double b = 0; + double c = 0; + + Line(const Vertex& v1, const Vertex& v2) + : a(v2.y - v1.y), b(v1.x - v2.x), c(v2.x * v1.y - v2.y * v1.x) {} + + // Computes the line value for a vertex v as a * v.x + b * v.y + c + double LineValue(const Vertex& v) const { return a * v.x + b * v.y + c; } + + // Computes the intersection point with the other line. + Vertex IntersectionPoint(const Line& other) const { + const double w = a * other.b - b * other.a; + CHECK_GT(std::fabs(w), kEPS) << "No intersection between the two lines."; + return Vertex((b * other.c - c * other.b) / w, + (c * other.a - a * other.c) / w); + } +}; + +// Computes the coordinates of the four vertices of a 2D rotated box. +std::vector<Vertex> ComputeBoxVertices(const double cx, const double cy, + const double w, const double h, + const double heading) { + const double dxcos = (w / 2.) * std::cos(heading); + const double dxsin = (w / 2.) * std::sin(heading); + const double dycos = (h / 2.) * std::cos(heading); + const double dysin = (h / 2.) * std::sin(heading); + return {Vertex(cx - dxcos + dysin, cy - dxsin - dycos), + Vertex(cx + dxcos + dysin, cy + dxsin - dycos), + Vertex(cx + dxcos - dysin, cy + dxsin + dycos), + Vertex(cx - dxcos - dysin, cy - dxsin + dycos)}; +} + +// Computes the intersection points between two rotated boxes, by following: +// +// 1. 
Initializes the current intersection points with the vertices of one box, +// and the other box is taken as the cutting box; +// +// 2. For each cutting line in the cutting box (four cutting lines in total): +// For each point in the current intersection points: +// If the point is inside of the cutting line: +// Adds it to the new intersection points; +// if current point and its next point are in the opposite side of the +// cutting line: +// Computes the line of current points and its next point as tmp_line; +// Computes the intersection point between the cutting line and +// tmp_line; +// Adds the intersection point to the new intersection points; +// After checking each cutting line, sets current intersection points as +// new intersection points; +// +// 3. Returns the final intersection points. +std::vector<Vertex> ComputeIntersectionPoints( + const std::vector<Vertex>& rbox_1, const std::vector<Vertex>& rbox_2) { + std::vector<Vertex> intersection = rbox_1; + const int vertices_len = rbox_2.size(); + for (int i = 0; i < vertices_len; ++i) { + const int len = intersection.size(); + if (len <= 2) { + break; + } + const Vertex& p = rbox_2[i]; + const Vertex& q = rbox_2[(i + 1) % vertices_len]; + Line cutting_line(p, q); + // Computes line value. + std::vector<double> line_values; + line_values.reserve(len); + for (int j = 0; j < len; ++j) { + line_values.push_back(cutting_line.LineValue(intersection[j])); + } + // Updates current intersection points. + std::vector<Vertex> new_intersection; + for (int j = 0; j < len; ++j) { + const double s_val = line_values[j]; + const Vertex& s = intersection[j]; + // Adds the current vertex. + if (s_val <= 0 || std::fabs(s_val) <= kEPS) { + new_intersection.push_back(s); + } + const double t_val = line_values[(j + 1) % len]; + // Skips the checking of intersection point if the next vertex is on the + // line. + if (std::fabs(t_val) <= kEPS) { + continue; + } + // Adds the intersection point. 
+ if ((s_val > 0 && t_val < 0) || (s_val < 0 && t_val > 0)) { + Line s_t_line(s, intersection[(j + 1) % len]); + new_intersection.push_back(cutting_line.IntersectionPoint(s_t_line)); + } + } + intersection = new_intersection; + } + return intersection; +} + +// Computes the area of a convex polygon. +double ComputePolygonArea(const std::vector<Vertex>& convex_polygon) { + const int len = convex_polygon.size(); + if (len <= 2) { + return 0; + } + double area = 0; + for (int i = 0; i < len; ++i) { + const Vertex& p = convex_polygon[i]; + const Vertex& q = convex_polygon[(i + 1) % len]; + area += p.x * q.y - p.y * q.x; + } + return std::fabs(0.5 * area); +} + +RotatedBox2D::RotatedBox2D(const double cx, const double cy, const double w, + const double h, const double heading) + : cx_(cx), cy_(cy), w_(w), h_(h), heading_(heading) { + // Compute loose bounds on dimensions of box that doesn't require computing + // full intersection. We can do this by trying to compute the largest circle + // swept by rotating the box around its center. The radius of that circle + // is the length of the ray from the center to the box corner. The upper + // bound for this value is the length of the longer dimension divided by two + // and then multiplied by root(2) (worst-case being a square box); we choose + // 1.5 as slightly higher than root(2), and then use these extrema to do + // simple extrema box checks without having to compute the true cos/sin value. + double max_dim = std::max(w_, h_) / 2. * 1.5; + loose_min_x_ = cx_ - max_dim; + loose_max_x_ = cx_ + max_dim; + loose_min_y_ = cy_ - max_dim; + loose_max_y_ = cy_ + max_dim; + + extreme_box_dim_ = (w_ <= kMinBoxDim || h_ <= kMinBoxDim); + extreme_box_dim_ |= (w_ >= kMaxBoxDim || h_ >= kMaxBoxDim); +} + +double RotatedBox2D::Area() const { + if (area_ < 0) { + const double area = ComputePolygonArea(box_vertices()); + area_ = std::fabs(area) <= kEPS ? 
0 : area; + } + return area_; +} + +const std::vector<Vertex>& RotatedBox2D::box_vertices() const { + if (box_vertices_.empty()) { + box_vertices_ = ComputeBoxVertices(cx_, cy_, w_, h_, heading_); + } + + return box_vertices_; +} + +bool RotatedBox2D::NonZeroAndValid() const { return !extreme_box_dim_; } + +bool RotatedBox2D::MaybeIntersects(const RotatedBox2D& other) const { + // If the box dimensions of either box are too small / large, + // assume they are not well-formed boxes (otherwise we are + // subject to issues due to catastrophic cancellation). + if (extreme_box_dim_ || other.extreme_box_dim_) { + return false; + } + + // Check whether the loose extrema overlap -- if not, then there is + // no chance that the two boxes overlap even when computing the true, + // more expensive overlap. + if ((loose_min_x_ > other.loose_max_x_) || + (loose_max_x_ < other.loose_min_x_) || + (loose_min_y_ > other.loose_max_y_) || + (loose_max_y_ < other.loose_min_y_)) { + return false; + } + + return true; +} + +double RotatedBox2D::Intersection(const RotatedBox2D& other) const { + // Do a fast intersection check - if the boxes are not near each other + // then we can return early. If they are close enough to maybe overlap, + // we do the full check. + if (!MaybeIntersects(other)) { + return 0.0; + } + + // Computes the intersection polygon. + const std::vector<Vertex> intersection_polygon = + ComputeIntersectionPoints(box_vertices(), other.box_vertices()); + // Computes the intersection area. + const double intersection_area = ComputePolygonArea(intersection_polygon); + + return std::fabs(intersection_area) <= kEPS ? 0 : intersection_area; +} + +double RotatedBox2D::IoU(const RotatedBox2D& other) const { + // Computes the intersection area. + const double intersection_area = Intersection(other); + if (intersection_area == 0) { + return 0; + } + // Computes the union area. 
+ const double union_area = Area() + other.Area() - intersection_area; + if (std::fabs(union_area) <= kEPS) { + return 0; + } + return intersection_area / union_area; +} + +std::vector ParseBoxesFromTensor(const Tensor& boxes_tensor) { + int num_boxes = boxes_tensor.dim_size(0); + + const auto t_boxes_tensor = boxes_tensor.matrix(); + + std::vector bboxes3d; + bboxes3d.reserve(num_boxes); + for (int i = 0; i < num_boxes; ++i) { + const double center_x = t_boxes_tensor(i, 0); + const double center_y = t_boxes_tensor(i, 1); + const double center_z = t_boxes_tensor(i, 2); + const double dimension_x = t_boxes_tensor(i, 3); + const double dimension_y = t_boxes_tensor(i, 4); + const double dimension_z = t_boxes_tensor(i, 5); + const double heading = t_boxes_tensor(i, 6); + const double z_min = center_z - dimension_z / 2; + const double z_max = center_z + dimension_z / 2; + RotatedBox2D box2d(center_x, center_y, dimension_x, dimension_y, heading); + if (dimension_x <= 0 || dimension_y <= 0) { + bboxes3d.emplace_back(RotatedBox2D(), z_min, z_max); + } else { + bboxes3d.emplace_back(box2d, z_min, z_max); + } + } + return bboxes3d; +} + +bool Upright3DBox::NonZeroAndValid() const { + // If min is larger than max, the upright box is invalid. + // + // If the min and max are equal, the height of the box is 0. and thus the box + // is zero. + if (z_min - z_max >= 0.) { + return false; + } + + return rbox.NonZeroAndValid(); +} + +double Upright3DBox::IoU(const Upright3DBox& other) const { + // Check that both boxes are non-zero and valid. Otherwise, + // return 0. + if (!NonZeroAndValid() || !other.NonZeroAndValid()) { + return 0; + } + + // Quickly check whether z's overlap; if they don't, we can return 0. 
+ const double z_inter = + std::max(.0, std::min(z_max, other.z_max) - std::max(z_min, other.z_min)); + if (z_inter == 0) { + return 0; + } + + const double base_inter = rbox.Intersection(other.rbox); + if (base_inter == 0) { + return 0; + } + + const double volume_1 = rbox.Area() * (z_max - z_min); + const double volume_2 = other.rbox.Area() * (other.z_max - other.z_min); + const double volume_inter = base_inter * z_inter; + const double volume_union = volume_1 + volume_2 - volume_inter; + return volume_inter > 0 ? volume_inter / volume_union : 0; +} + +double Upright3DBox::Overlap(const Upright3DBox& other) const { + // Check that both boxes are non-zero and valid. Otherwise, + // return 0. + if (!NonZeroAndValid() || !other.NonZeroAndValid()) { + return 0; + } + + const double z_inter = + std::max(.0, std::min(z_max, other.z_max) - std::max(z_min, other.z_min)); + if (z_inter == 0) { + return 0; + } + + const double base_inter = rbox.Intersection(other.rbox); + if (base_inter == 0) { + return 0; + } + + const double volume_1 = rbox.Area() * (z_max - z_min); + const double volume_inter = base_inter * z_inter; + // Normalizes intersection of volume by the volume of this box. + return volume_inter > 0 ? volume_inter / volume_1 : 0; +} + +} // namespace box +} // namespace kerascv +} // namespace tensorflow diff --git a/keras_cv/custom_ops/box_util.h b/keras_cv/custom_ops/box_util.h new file mode 100644 index 0000000000..a384abc844 --- /dev/null +++ b/keras_cv/custom_ops/box_util.h @@ -0,0 +1,141 @@ +/* Copyright 2022 The Keras CV Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_PY_KERAS_CV_OPS_BOX_UTIL_H_ +#define THIRD_PARTY_PY_KERAS_CV_OPS_BOX_UTIL_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace kerascv { +namespace box { + +// A vertex with (x, y) coordinate. +// +// This is an internal implementation detail of RotatedBox2D. +struct Vertex { + // Creates an empty Vertex. + Vertex() = default; + + Vertex(const double x, const double y) : x(x), y(y) {} + + double x = 0; + double y = 0; +}; + +// A rotated 2D bounding box represented as (cx, cy, w, h, heading). cx, cy are +// the box center coordinates; w, h are the box width and height; heading is +// the rotation angle in radians relative to the 'positive x' direction. +class RotatedBox2D { + public: + // Creates an empty rotated 2D box. + RotatedBox2D() : RotatedBox2D(0, 0, 0, 0, 0) {} + + RotatedBox2D(const double cx, const double cy, const double w, const double h, + const double heading); + + // Returns the area of the box. + double Area() const; + + // Returns the intersection area between this box and the given box. + double Intersection(const RotatedBox2D& other) const; + + // Returns the IoU between this box and the given box. + double IoU(const RotatedBox2D& other) const; + + // Returns true if the box is valid (width and height are not extremely + // large or small). + bool NonZeroAndValid() const; + + private: + // Computes / caches box_vertices_ calculation. 
+ const std::vector& box_vertices() const; + + // Returns true if this box and 'other' might intersect. + // + // If this returns false, the two boxes definitely do not intersect. If this + // returns true, it is still possible that the two boxes do not intersect, and + // the more expensive intersection code will be called. + bool MaybeIntersects(const RotatedBox2D& other) const; + + double cx_ = 0; + double cy_ = 0; + double w_ = 0; + double h_ = 0; + double heading_ = 0; + + // Loose boundaries for fast intersection test. + double loose_min_x_ = -1; + double loose_max_x_ = -1; + double loose_min_y_ = -1; + double loose_max_y_ = -1; + + // True if the dimensions of the box are very small or very large in any + // dimension. + bool extreme_box_dim_ = false; + + // The following fields are computed on demand. They are logically + // const. + + // Cached area. Access via Area() public API. + mutable double area_ = -1; + + // Stores the vertices of the box. Access via box_vertices(). + mutable std::vector box_vertices_; +}; + +// A 3D box of 7-DOFs: only allows rotation around the z-axis. +struct Upright3DBox { + RotatedBox2D rbox = RotatedBox2D(); + double z_min = 0; + double z_max = 0; + + // Creates an empty rotated 3D box. + Upright3DBox() = default; + + // Creates a 3D box from the raw input data with size 7. The data format is + // (center_x, center_y, center_z, dimension_x, dimension_y, dimension_z, + // heading) + Upright3DBox(const std::vector& raw) + : rbox(raw[0], raw[1], raw[3], raw[4], raw[6]), + z_min(raw[2] - raw[5] / 2.0), + z_max(raw[2] + raw[5] / 2.0) {} + + Upright3DBox(const RotatedBox2D& rb, const double z_min, const double z_max) + : rbox(rb), z_min(z_min), z_max(z_max) {} + + // Computes intersection over union (of the volume). + double IoU(const Upright3DBox& other) const; + + // Computes overlap: intersection of this box and the given box normalized + // over the volume of this box. 
+ double Overlap(const Upright3DBox& other) const; + + // Returns true if the box is valid (width and height are not extremely + // large or small, and zmin < zmax). + bool NonZeroAndValid() const; +}; + +// Converts a [N, 7] tensor to a vector of N Upright3DBox objects. +std::vector ParseBoxesFromTensor(const Tensor& boxes_tensor); + +} // namespace box +} // namespace kerascv +} // namespace tensorflow + +#endif // THIRD_PARTY_PY_KERAS_CV_OPS_BOX_UTIL_H_ diff --git a/keras_cv/custom_ops/kernels/pairwise_iou_kernel.cc b/keras_cv/custom_ops/kernels/pairwise_iou_kernel.cc new file mode 100644 index 0000000000..fb3a9b6e1e --- /dev/null +++ b/keras_cv/custom_ops/kernels/pairwise_iou_kernel.cc @@ -0,0 +1,72 @@ +/* Copyright 2022 The KerasCV Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "keras_cv/custom_ops/box_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace kerascv { +namespace { + +class PairwiseIoUOp : public OpKernel { + public: + explicit PairwiseIoUOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()), + errors::InvalidArgument("In[0] must be a matrix, but get ", + a.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), + errors::InvalidArgument("In[1] must be a matrix, but get ", + b.shape().DebugString())); + OP_REQUIRES(ctx, 7 == a.dim_size(1), + errors::InvalidArgument("Matrix size-incompatible: In[0]: ", + a.shape().DebugString())); + OP_REQUIRES(ctx, 7 == b.dim_size(1), + errors::InvalidArgument("Matrix size-incompatible: In[1]: ", + b.shape().DebugString())); + + const int n_a = a.dim_size(0); + const int n_b = b.dim_size(0); + + Tensor* iou_a_b = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("iou", TensorShape({n_a, n_b}), &iou_a_b)); + + auto t_iou_a_b = iou_a_b->matrix(); + + std::vector box_a = box::ParseBoxesFromTensor(a); + std::vector box_b = box::ParseBoxesFromTensor(b); + for (int i_a = 0; i_a < n_a; ++i_a) { + for (int i_b = 0; i_b < n_b; ++i_b) { + t_iou_a_b(i_a, i_b) = box_a[i_a].IoU(box_b[i_b]); + } + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("PairwiseIou3D").Device(DEVICE_CPU), + PairwiseIoUOp); + +} // namespace +} // namespace kerascv +} // namespace tensorflow diff --git a/keras_cv/custom_ops/ops/pairwise_iou_op.cc b/keras_cv/custom_ops/ops/pairwise_iou_op.cc new file mode 100644 index 0000000000..79bcc0909d --- 
/dev/null +++ b/keras_cv/custom_ops/ops/pairwise_iou_op.cc @@ -0,0 +1,40 @@ +/* Copyright 2022 The KerasCV Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; + +REGISTER_OP("PairwiseIou3D") + .Input("boxes_a: float") + .Input("boxes_b: float") + .Output("iou: float") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + // Require rank-2 inputs before reading dim 0; Dim(..., 0) on a rank-0 + // shape would index past the end of its dims. + tensorflow::shape_inference::ShapeHandle boxes_a; + tensorflow::shape_inference::ShapeHandle boxes_b; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes_a)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes_b)); + c->set_output(0, c->MakeShape({c->Dim(boxes_a, 0), c->Dim(boxes_b, 0)})); + return tensorflow::Status(); + }) + .Doc(R"doc( +Calculate pairwise IoUs between two set of 3D bboxes. Every bbox is represented +as [center_x, center_y, center_z, dim_x, dim_y, dim_z, heading]. 
+boxes_a: A tensor of shape [num_boxes_a, 7] +boxes_b: A tensor of shape [num_boxes_b, 7] +)doc"); diff --git a/keras_cv/ops/__init__.py b/keras_cv/ops/__init__.py index 19888495bf..46d735c699 100644 --- a/keras_cv/ops/__init__.py +++ b/keras_cv/ops/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from keras_cv.ops.iou_3d import IoU3D from keras_cv.ops.point_cloud import _box_area from keras_cv.ops.point_cloud import _center_xyzWHD_to_corner_xyz from keras_cv.ops.point_cloud import _is_on_lefthand_side diff --git a/keras_cv/ops/iou_3d.py b/keras_cv/ops/iou_3d.py new file mode 100644 index 0000000000..e6b05509ae --- /dev/null +++ b/keras_cv/ops/iou_3d.py @@ -0,0 +1,48 @@ +# Copyright 2022 The KerasCV Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""IoU3D using a custom TF op.""" + +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + + +class IoU3D: + """Implements IoU computation for 3D upright rotated bounding boxes. + + Note that this is implemented using a custom TensorFlow op. Initializing an + IoU3D object will attempt to load the binary for that op. + + Boxes should have the format [center_x, center_y, center_z, dimension_x, + dimension_y, dimension_z, heading (in radians)]. 
+ + Sample Usage: + ```python + y_true = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]] + y_pred = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]] + iou = IoU3D() + iou(y_true, y_pred)  # pairwise IoUs, shape [2, 2] + ``` + """ + + def __init__(self): + pairwise_iou_op = load_library.load_op_library( + resource_loader.get_path_to_datafile( + "../custom_ops/_keras_cv_custom_ops.so" + ) + ) + self.iou_3d = pairwise_iou_op.pairwise_iou3d + + def __call__(self, y_true, y_pred): + return self.iou_3d(y_true, y_pred) diff --git a/keras_cv/ops/iou_3d_test.py b/keras_cv/ops/iou_3d_test.py new file mode 100644 index 0000000000..b654e332fb --- /dev/null +++ b/keras_cv/ops/iou_3d_test.py @@ -0,0 +1,54 @@ +# Copyright 2022 The KerasCV Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for IoU3D using custom op.""" +import math +import os + +import pytest +import tensorflow as tf + +from keras_cv.ops import IoU3D + + +class IoU3DTest(tf.test.TestCase): + @pytest.mark.skipif( + "TEST_CUSTOM_OPS" not in os.environ or os.environ["TEST_CUSTOM_OPS"] != "true", + reason="Requires binaries compiled from source", + ) + def testOpCall(self): + # Predicted boxes: + # 0: a 2x2x2 box centered at 0,0,0, rotated 0 degrees + # 1: a 2x2x2 box centered at 1,1,1, rotated 135 degrees + # Ground Truth boxes: + # 0: a 2x2x2 box centered at 1,1,1, rotated 45 degrees (identical to predicted box 1) + # 1: a 2x2x2 box centered at 1,1,1, rotated 0 degrees + box_preds = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]] + box_gt = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]] + + # Predicted box 0 and both ground truth boxes overlap by 1/8th of the box. + # Therefore, IoU is 1/15 + # Predicted box 1 is the same as ground truth box 0, therefore IoU is 1 + # Predicted box 1 shares an origin with ground truth box 1, but is rotated by 135 degrees. + # Their IoU can be reduced to that of two overlapping squares that share a center with + # the same offset of 135 degrees, which reduces to the square root of 0.5. 
+ expected_ious = [[1 / 15, 1 / 15], [1, 0.5**0.5]] + + iou_3d = IoU3D() + + self.assertAllClose(iou_3d(box_preds, box_gt), expected_ious) + + +if __name__ == "__main__": + tf.test.main() diff --git a/setup.py b/setup.py index cc19cccb90..7973d77591 100644 --- a/setup.py +++ b/setup.py @@ -18,10 +18,22 @@ from setuptools import find_packages from setuptools import setup +from setuptools.dist import Distribution HERE = pathlib.Path(__file__).parent README = (HERE / "README.md").read_text() + +class BinaryDistribution(Distribution): + """This class is needed in order to create OS specific wheels.""" + + def has_ext_modules(self): + return True + + def is_pure(self): + return False + + setup( name="keras-cv", description="Industry-strength computer Vision extensions for Keras.", @@ -37,6 +49,7 @@ "tests": ["flake8", "isort", "black", "pytest", "tensorflow-datasets"], "examples": ["tensorflow_datasets", "matplotlib"], }, + distclass=BinaryDistribution, classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 3.7", @@ -48,4 +61,5 @@ "Topic :: Software Development", ], packages=find_packages(exclude=("*_test.py",)), + include_package_data=True, )