Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update JAX to use new math libraries in ROCm-5.0.
- Loading branch information
Showing
58 changed files
with
5,179 additions
and
1,543 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
FROM ubuntu:bionic | ||
MAINTAINER Reza Rahimi <reza.rahimi@amd.com> | ||
|
||
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.0/ | ||
ARG ROCM_BUILD_NAME=ubuntu | ||
ARG ROCM_BUILD_NUM=main | ||
ARG ROCM_PATH=/opt/rocm-5.0.0 | ||
|
||
ARG DEBIAN_FRONTEND=noninteractive | ||
ENV HOME /root/ | ||
ENV ROCM_PATH=$ROCM_PATH | ||
|
||
RUN apt-get --allow-unauthenticated update && apt install -y wget software-properties-common | ||
RUN apt-get clean all | ||
RUN wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -; | ||
RUN bin/bash -c 'if [[ $ROCM_DEB_REPO == http://repo.radeon.com/rocm/* ]] ; then \ | ||
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \ | ||
else \ | ||
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \ | ||
fi' | ||
|
||
|
||
RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \ | ||
build-essential \ | ||
software-properties-common \ | ||
clang-6.0 \ | ||
clang-format-6.0 \ | ||
curl \ | ||
g++-multilib \ | ||
git \ | ||
vim \ | ||
libnuma-dev \ | ||
virtualenv \ | ||
python3-pip \ | ||
pciutils \ | ||
wget && \ | ||
apt-get clean && \ | ||
rm -rf /var/lib/apt/lists/* | ||
|
||
# Add to get ppa | ||
RUN apt-get update | ||
RUN apt-get install -y software-properties-common | ||
# Install rocm pkgs | ||
RUN apt-get update --allow-insecure-repositories && \ | ||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ | ||
rocm-dev rocm-libs rccl && \ | ||
apt-get clean && \ | ||
rm -rf /var/lib/apt/lists/* | ||
|
||
# Set up paths | ||
ENV HCC_HOME=$ROCM_PATH/hcc | ||
ENV HIP_PATH=$ROCM_PATH/hip | ||
ENV OPENCL_ROOT=$ROCM_PATH/opencl | ||
ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" | ||
ENV PATH="$ROCM_PATH/bin:${PATH}" | ||
ENV PATH="$OPENCL_ROOT/bin:${PATH}" | ||
|
||
# Add target file to help determine which device(s) to build for | ||
RUN bash -c 'echo -e "gfx900\ngfx906\ngfx908\ngfx90a\ngfx1030" >> ${ROCM_PATH}/bin/target.lst' | ||
|
||
# Need to explicitly create the $ROCM_PATH/.info/version file to workaround what seems to be a bazel bug | ||
# The env vars being set via --action_env in .bazelrc and .tf_configure.bazelrc files are sometimes | ||
# not getting set in the build command being spawned by bazel (in theory this should not happen) | ||
# As a consequence ROCM_PATH is sometimes not set for the hipcc commands. | ||
# When hipcc incokes hcc, it specifies $ROCM_PATH/.../include dirs via the `-isystem` options | ||
# If ROCM_PATH is not set, it defaults to /opt/rocm, and as a consequence a dependency is generated on the | ||
# header files included within `/opt/rocm`, which then leads to bazel dependency errors | ||
# Explicitly creating the $ROCM_PATH/.info/version allows ROCM path to be set correrctly, even when ROCM_PATH | ||
# is not explicitly set, and thus avoids the eventual bazel dependency error. | ||
# The bazel bug needs to be root-caused and addressed, but that is out of our control and may take a long time | ||
# to come to fruition, so implementing the workaround to make do till then | ||
# Filed https://github.com/bazelbuild/bazel/issues/11163 for tracking this | ||
RUN touch ${ROCM_PATH}/.info/version | ||
|
||
ENV PATH="/root/bin:/root/.local/bin:$PATH" | ||
|
||
|
||
# Install python3.9 | ||
RUN add-apt-repository ppa:deadsnakes/ppa && \ | ||
apt update && \ | ||
apt install -y python3.9-dev \ | ||
python3-pip \ | ||
python3.9-distutils | ||
|
||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 1 && \ | ||
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2 | ||
|
||
RUN pip3 install --upgrade setuptools pip | ||
|
||
RUN pip3 install absl-py numpy==1.19.5 scipy wheel six setuptools pytest pytest-rerunfailures | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# JAX Builds on ROCm | ||
This directory contains files and setup instructions t0 build and test JAX for ROCm in Docker environment. You can build, test and run JAX on ROCm yourself! | ||
*** | ||
### Build JAX-ROCm in docker | ||
|
||
1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/). | ||
|
||
2. Build JAX by running the following command from JAX root folder. | ||
|
||
./build/rocm/ci_build.sh --keep_image bash -c "./build/rocm/build_rocm.sh" | ||
|
||
3. Launch a contianer: If the build was sucessful, there should be a docker image with name "jax-rocm:latest" in list of docker images (use "docker images" command to list them). | ||
``` | ||
sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax-rocm:latest | ||
``` | ||
|
||
*** | ||
### Build and Test JAX-ROCm in docker (suitable for CI jobs) | ||
This folder has all the scripts necessary to build and run tests for JAX-ROCm. | ||
The following command will build JAX on ROCm and run all the tests inside docker (script should be called from JAX root folder). | ||
``` | ||
./build/rocm/ci_build.sh bash -c "./build/rocm/build_rocm.sh&&./build/rocm/run_single_gpu.py&&build/rocm/run_multi_gpu.sh" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Common Bash functions used by build scripts | ||
|
||
die() { | ||
# Print a message and exit with code 1. | ||
# | ||
# Usage: die <error_message> | ||
# e.g., die "Something bad happened." | ||
|
||
echo $@ | ||
exit 1 | ||
} | ||
|
||
realpath() { | ||
# Get the real path of a file | ||
# Usage: realpath <file_path> | ||
|
||
if [[ "$#" != "1" ]]; then | ||
die "realpath: incorrect usage" | ||
fi | ||
|
||
[[ "$1" = /* ]] && echo "$1" || echo "$PWD/${1#./}" | ||
} | ||
|
||
to_lower() { | ||
# Convert the string to lower case | ||
# Usage: to_lower <string> | ||
|
||
echo "$1" | tr '[:upper:]' '[:lower:]' | ||
} | ||
|
||
calc_elapsed_time() { | ||
# Calculate elapsed time. Takes nanosecond format input of the kind output | ||
# by date +'%s%N' | ||
# | ||
# Usage: calc_elapsed_time <START_TIME> <END_TIME> | ||
|
||
if [[ $# != "2" ]]; then | ||
die "calc_elapsed_time: incorrect usage" | ||
fi | ||
|
||
START_TIME=$1 | ||
END_TIME=$2 | ||
|
||
if [[ ${START_TIME} == *"N" ]]; then | ||
# Nanosecond precision not available | ||
START_TIME=$(echo ${START_TIME} | sed -e 's/N//g') | ||
END_TIME=$(echo ${END_TIME} | sed -e 's/N//g') | ||
ELAPSED="$(expr ${END_TIME} - ${START_TIME}) s" | ||
else | ||
ELAPSED="$(expr $(expr ${END_TIME} - ${START_TIME}) / 1000000) ms" | ||
fi | ||
|
||
echo ${ELAPSED} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#!/usr/bin/env bash | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
set -eux | ||
|
||
ROCM_TF_FORK_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream" | ||
ROCM_TF_FORK_BRANCH="develop-upstream" | ||
rm -rf /tmp/tensorflow-upstream || true | ||
git clone -b ${ROCM_TF_FORK_BRANCH} ${ROCM_TF_FORK_REPO} /tmp/tensorflow-upstream | ||
|
||
python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=org_tensorflow=/tmp/tensorflow-upstream | ||
pip3 install --use-feature=2020-resolver --force-reinstall dist/*.whl # installs jaxlib (includes XLA) | ||
pip3 install --use-feature=2020-resolver --force-reinstall . # installs jax |
Oops, something went wrong.