Skip to content

Commit

Permalink
Update JAX to use new math libraries in ROCm-5.0.
Browse files Browse the repository at this point in the history
  • Loading branch information
reza-amd committed Mar 1, 2022
1 parent 11664f8 commit a0d9d81
Show file tree
Hide file tree
Showing 58 changed files with 5,179 additions and 1,543 deletions.
7 changes: 6 additions & 1 deletion build/BUILD.bazel
Expand Up @@ -52,7 +52,12 @@ py_binary(
"//jaxlib:_cuda_prng",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:rocblas_kernels",
"//jaxlib:hip_gpu_support",
"//jaxlib:_hipblas",
"//jaxlib:_hipsolver",
"//jaxlib:_hipsparse",
"//jaxlib:_hip_linalg",
"//jaxlib:_hip_prng",
]),
deps = ["@bazel_tools//tools/python/runfiles"],
)
5 changes: 3 additions & 2 deletions build/build.py
Expand Up @@ -383,7 +383,7 @@ def main():
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx900,gfx906,gfx90",
default="gfx900,gfx906,gfx908,gfx90a,gfx1030",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--rocm_path",
Expand Down Expand Up @@ -510,7 +510,8 @@ def main():
config_args += ["--config=tpu"]
if args.enable_rocm:
config_args += ["--config=rocm"]
config_args += ["--config=nonccl"]
if not args.enable_nccl:
config_args += ["--config=nonccl"]

command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
Expand Down
20 changes: 17 additions & 3 deletions build/build_wheel.py
Expand Up @@ -197,11 +197,21 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cublas.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_linalg.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_prng.so"))
if r.Rlocation("__main__/jaxlib/_hipsolver.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipsolver.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipblas.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_linalg.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_prng.so"))
if r.Rlocation("__main__/jaxlib/_cusolver.pyd") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cusolver.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cublas.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_linalg.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_prng.pyd"))
if r.Rlocation("__main__/jaxlib/_hipsolver.pyd") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipsolver.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipblas.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_linalg.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_prng.pyd"))
if r.Rlocation("__main__/jaxlib/cusolver.py") is not None:
libdevice_dir = os.path.join(jaxlib_dir, "cuda", "nvvm", "libdevice")
os.makedirs(libdevice_dir)
Expand All @@ -210,12 +220,16 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_linalg.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
if r.Rlocation("__main__/jaxlib/rocblas_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocblas_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocsolver.py"))
if r.Rlocation("__main__/jaxlib/hipsolver.py") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hipsolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hip_linalg.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hip_prng.py"))
if r.Rlocation("__main__/jaxlib/_cusparse.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cusparse.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.py"))
if r.Rlocation("__main__/jaxlib/_hipsparse.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipsparse.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hipsparse.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))

mlir_dir = os.path.join(jaxlib_dir, "mlir")
Expand Down
91 changes: 91 additions & 0 deletions build/rocm/Dockerfile.rocm
@@ -0,0 +1,91 @@
FROM ubuntu:bionic
MAINTAINER Reza Rahimi <reza.rahimi@amd.com>

# Build arguments: the three fields of the ROCm apt source line
# ("deb <ROCM_DEB_REPO> <ROCM_BUILD_NAME> <ROCM_BUILD_NUM>") and the ROCm
# installation prefix.
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.0/
ARG ROCM_BUILD_NAME=ubuntu
ARG ROCM_BUILD_NUM=main
ARG ROCM_PATH=/opt/rocm-5.0.0

ARG DEBIAN_FRONTEND=noninteractive
ENV HOME /root/
# Persist ROCM_PATH into the image environment (an ARG alone is build-time only).
ENV ROCM_PATH=$ROCM_PATH

# wget is needed to fetch the ROCm repository signing key below.
RUN apt-get --allow-unauthenticated update && apt install -y wget software-properties-common
RUN apt-get clean all
RUN wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -;
# Register the ROCm apt repository. Repositories under repo.radeon.com/rocm are
# verified with the key added above; any other repository is marked trusted=yes.
# bash is invoked by absolute path so this step does not depend on the current
# working directory being "/" ([[ ... ]] pattern matching requires bash).
RUN /bin/bash -c 'if [[ $ROCM_DEB_REPO == http://repo.radeon.com/rocm/* ]] ; then \
      echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \
    else \
      echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \
    fi'


# Install the general build toolchain and utilities in a single layer, then
# remove the apt caches and package lists in that same layer.
RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
software-properties-common \
clang-6.0 \
clang-format-6.0 \
curl \
g++-multilib \
git \
vim \
libnuma-dev \
virtualenv \
python3-pip \
pciutils \
wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Needed for adding PPAs (a PPA is added further down for Python 3.9).
# NOTE(review): software-properties-common is already installed by the layer
# above; the package lists were removed there, hence the fresh `apt-get update`.
RUN apt-get update
RUN apt-get install -y software-properties-common
# Install the ROCm packages (rocm-dev, rocm-libs, rccl) from the repository
# configured earlier, then clean the apt caches and package lists again.
RUN apt-get update --allow-insecure-repositories && \
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
rocm-dev rocm-libs rccl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Set up paths: tool locations are derived from ROCM_PATH and their bin
# directories are prepended to PATH.
ENV HCC_HOME=$ROCM_PATH/hcc
ENV HIP_PATH=$ROCM_PATH/hip
ENV OPENCL_ROOT=$ROCM_PATH/opencl
ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
ENV PATH="$ROCM_PATH/bin:${PATH}"
ENV PATH="$OPENCL_ROOT/bin:${PATH}"

# Add target file to help determine which device(s) to build for.
# One amdgpu target per line is appended to ${ROCM_PATH}/bin/target.lst.
RUN bash -c 'echo -e "gfx900\ngfx906\ngfx908\ngfx90a\ngfx1030" >> ${ROCM_PATH}/bin/target.lst'

# Need to explicitly create the $ROCM_PATH/.info/version file to work around what seems to be a bazel bug.
# The env vars being set via --action_env in .bazelrc and .tf_configure.bazelrc files are sometimes
# not getting set in the build command being spawned by bazel (in theory this should not happen).
# As a consequence ROCM_PATH is sometimes not set for the hipcc commands.
# When hipcc invokes hcc, it specifies $ROCM_PATH/.../include dirs via the `-isystem` options.
# If ROCM_PATH is not set, it defaults to /opt/rocm, and as a consequence a dependency is generated on the
# header files included within `/opt/rocm`, which then leads to bazel dependency errors.
# Explicitly creating the $ROCM_PATH/.info/version file allows the ROCm path to be set correctly, even when
# ROCM_PATH is not explicitly set, and thus avoids the eventual bazel dependency error.
# The bazel bug needs to be root-caused and addressed, but that is out of our control and may take a long time
# to come to fruition, so this workaround is implemented to make do until then.
# Filed https://github.com/bazelbuild/bazel/issues/11163 for tracking this.
RUN touch ${ROCM_PATH}/.info/version

# Put user-local bin directories first on PATH.
ENV PATH="/root/bin:/root/.local/bin:$PATH"


# Install python3.9 (with headers and distutils) from the deadsnakes PPA.
RUN add-apt-repository ppa:deadsnakes/ppa && \
apt update && \
apt install -y python3.9-dev \
python3-pip \
python3.9-distutils

# Register both interpreters as alternatives for /usr/bin/python3; python3.9 is
# given the higher priority (2 vs 1), so it is presumably the one selected.
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 1 && \
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2

RUN pip3 install --upgrade setuptools pip

# Python packages installed into the image; numpy is pinned to 1.19.5.
RUN pip3 install absl-py numpy==1.19.5 scipy wheel six setuptools pytest pytest-rerunfailures

23 changes: 23 additions & 0 deletions build/rocm/README.md
@@ -0,0 +1,23 @@
# JAX Builds on ROCm
This directory contains files and setup instructions to build and test JAX for ROCm in a Docker environment. You can build, test and run JAX on ROCm yourself!
***
### Build JAX-ROCm in docker

1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).

2. Build JAX by running the following command from JAX root folder.

./build/rocm/ci_build.sh --keep_image bash -c "./build/rocm/build_rocm.sh"

3. Launch a container: If the build was successful, there should be a docker image with name "jax-rocm:latest" in the list of docker images (use the "docker images" command to list them).
```
sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax-rocm:latest
```

***
### Build and Test JAX-ROCm in docker (suitable for CI jobs)
This folder has all the scripts necessary to build and run tests for JAX-ROCm.
The following command will build JAX on ROCm and run all the tests inside docker (script should be called from JAX root folder).
```
./build/rocm/ci_build.sh bash -c "./build/rocm/build_rocm.sh&&./build/rocm/run_single_gpu.py&&build/rocm/run_multi_gpu.sh"
```
70 changes: 70 additions & 0 deletions build/rocm/build_common.sh
@@ -0,0 +1,70 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Common Bash functions used by build scripts

die() {
  # Print a message and exit with code 1.
  #
  # Usage: die <error_message>
  #   e.g., die "Something bad happened."
  #
  # All arguments are joined with single spaces and written to stdout
  # verbatim: the message is quoted so that embedded whitespace is preserved
  # and characters such as "*" are not glob-expanded against the current
  # directory.

  printf '%s\n' "$*"
  exit 1
}

realpath() {
  # Print an absolute form of the given path using string operations only.
  # Usage: realpath <file_path>
  #
  # A path that already starts with "/" is printed unchanged; any other path
  # is prefixed with $PWD after one leading "./" (if present) is dropped.

  if [[ $# -ne 1 ]]; then
    die "realpath: incorrect usage"
  fi

  case "$1" in
    /*) echo "$1" ;;
    *)  echo "$PWD/${1#./}" ;;
  esac
}

to_lower() {
  # Print the first argument with every upper-case character replaced by its
  # lower-case counterpart.
  # Usage: to_lower <string>

  local text="$1"
  echo "${text}" | tr '[:upper:]' '[:lower:]'
}

calc_elapsed_time() {
  # Calculate elapsed time. Takes nanosecond format input of the kind output
  # by date +'%s%N'
  #
  # Usage: calc_elapsed_time <START_TIME> <END_TIME>
  #
  # Prints "<n> ms" when the inputs are plain nanosecond counts, or "<n> s"
  # when START_TIME ends in a literal "N" (nanosecond precision not
  # available, so the digits are whole seconds).
  #
  # The arithmetic uses shell arithmetic expansion rather than `expr`: `expr`
  # exits with status 1 whenever its result is 0, which made a zero (or
  # sub-millisecond) elapsed time abort callers running under `set -e`.
  #
  # NOTE: START_TIME, END_TIME and ELAPSED are assigned as globals, exactly as
  # before, in case a caller reads them after the call.

  if [[ $# != "2" ]]; then
    die "calc_elapsed_time: incorrect usage"
  fi

  START_TIME=$1
  END_TIME=$2

  if [[ ${START_TIME} == *"N" ]]; then
    # Nanosecond precision not available: strip the literal "N" and report
    # the difference in seconds.
    START_TIME=${START_TIME//N/}
    END_TIME=${END_TIME//N/}
    ELAPSED="$(( END_TIME - START_TIME )) s"
  else
    # Nanoseconds -> milliseconds (integer division, truncating).
    ELAPSED="$(( (END_TIME - START_TIME) / 1000000 )) ms"
  fi

  echo "${ELAPSED}"
}


25 changes: 25 additions & 0 deletions build/rocm/build_rocm.sh
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Abort on any error or unset variable, and echo each command as it runs.
# ROCM_PATH is not assigned here, so it must already be set in the environment.
set -eux

# Fetch a fresh copy of the ROCm fork of TensorFlow; the build below overrides
# bazel's org_tensorflow repository with this checkout.
ROCM_TF_FORK_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
ROCM_TF_FORK_BRANCH="develop-upstream"
rm -rf /tmp/tensorflow-upstream || true
git clone -b "${ROCM_TF_FORK_BRANCH}" "${ROCM_TF_FORK_REPO}" /tmp/tensorflow-upstream

# Expansions are quoted so a ROCM_PATH containing whitespace is passed as a
# single argument instead of being word-split.
python3 ./build/build.py --enable_rocm --rocm_path="${ROCM_PATH}" --bazel_options=--override_repository=org_tensorflow=/tmp/tensorflow-upstream
pip3 install --use-feature=2020-resolver --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
pip3 install --use-feature=2020-resolver --force-reinstall . # installs jax

0 comments on commit a0d9d81

Please sign in to comment.