Provide AArch64 (ARM) Linux jaxlib wheels #7097

Closed
bensworth opened this issue Jun 24, 2021 · 19 comments

Potentially related to #6932 or #7052.

I am using a heterogeneous cluster with various compute options. I got JAX and Flax installed fine on the CPU system. Then I tried using a GPU node (specifically, dual-socket Cavium ThunderX2 99xx 32-core processors with two NVIDIA Tesla V100 GPUs), but I am unable to install jaxlib:

$ pip install --upgrade jax jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Requirement already satisfied: jax in /home/southworth/.conda/envs/jax_arch/lib/python3.7/site-packages (0.2.16)
ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.68+cuda111 (from versions: none)
ERROR: No matching distribution found for jaxlib==0.1.68+cuda111

or even

pip install jaxlib
ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib

Below is my pip version, which is up to date. Pip also installs other libraries fine (numpy, scipy, etc.).

$ pip --version
pip 21.1.2 from /home/southworth/.conda/envs/jax_arch/lib/python3.7/site-packages/pip (python 3.7)
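
For completeness, a quick way to confirm the machine's architecture (hypothetical commands; pip can only install wheels whose platform tag matches the interpreter's platform, which is why no jaxlib distribution matched):

$ uname -m
aarch64
$ python -c "import platform; print(platform.machine())"
aarch64
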
bensworth added the bug label Jun 24, 2021
hawkinsp (Member) commented Jun 24, 2021

The issue is that we don't provide Linux AArch64 wheels for jaxlib, and the Cavium ThunderX2 is an ARM machine.

I suspect you may be able to build jaxlib from source on those machines: https://jax.readthedocs.io/en/latest/developer.html#building-from-source

Try it out and see how it goes?

(I'm not sure we have the resources to release ARM wheels ourselves right now, but I wouldn't rule it out for the future.)
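
For reference, a minimal sketch of the build-from-source path (assuming the prerequisites from the linked developer docs, such as a working C++ toolchain, are installed; exact flags vary by jaxlib version):

git clone https://github.com/google/jax
cd jax
python build/build.py    # builds a jaxlib wheel for the host CPU into dist/
pip install dist/*.whl   # install the freshly built jaxlib
pip install jax          # the pure-Python jax package installs on any platform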

hawkinsp added the contributions welcome and enhancement labels and removed the bug label Jun 24, 2021
hawkinsp changed the title from "No matching distribution for jaxlib" to "Provide AArch64 (ARM) Linux jaxlib wheels" Jun 24, 2021
hawkinsp (Member):

@bensworth Unfortunately we don't support Intel (Iris) GPUs either. Currently your hardware options are CPU, NVIDIA GPU, AMD GPU, or Google TPU.

So perhaps your best bet is to use CPU on that machine or build jaxlib with CUDA support on ARM.

bensworth (Author) commented Jun 24, 2021

Yep, I realized that a few minutes after I posted; I am now on an NVIDIA GPU and have things working! Previously I mostly worked with CPUs/MPI and didn't realize GPUs were so architecture-specific. This issue can be closed unless you want to leave it open for somebody who may be able to build wheels for ARM or Intel; that's outside my area of work, though.

One thought: it may be worth adding a note in the README about which GPUs the build instructions support.

angusdunnett commented Mar 31, 2022

I'm attempting to build jaxlib wheels for aarch64 Linux via Docker, but without success. My Dockerfile is the following:

FROM python:3.9-slim-bullseye

RUN apt-get update \
    && apt-get install -y --no-install-recommends g++ git python3.9-dev \
    && python3.9 -m venv jax-env \
    && /bin/bash -c "source ./jax-env/bin/activate; python3.9 -m pip install numpy six wheel; git clone https://github.com/google/jax; \
cd jax; python3.9 build/build.py --enable_mkl_dnn False --target_cpu_features default; python3.9 -m pip install dist/*.whl"

My machine is a MacBook Pro with an M1 Apple silicon chip, running macOS 12.2. Building with docker build -t test-minimal results in:

...
#5 35.42 Cloning into 'jax'...
#5 91.72 Extracting Bazel installation...
#5 94.11 Starting local Bazel server and connecting to it...
#5 94.87 INFO: Options provided by the client:
#5 94.87   Inherited 'common' options: --isatty=0 --terminal_columns=80
#5 94.87 INFO: Reading rc options for 'run' from /jax/.bazelrc:
#5 94.87   Inherited 'common' options: --experimental_repo_remote_exec
#5 94.87 INFO: Reading rc options for 'run' from /jax/.bazelrc:
#5 94.87   Inherited 'build' options: --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
#5 94.87 INFO: Reading rc options for 'run' from /jax/.jax_configure.bazelrc:
#5 94.87   Inherited 'build' options: --strategy=Genrule=standalone --repo_env PYTHON_BIN_PATH=/jax-env/bin/python3.9 --action_env=PYENV_ROOT --python_path=/jax-env/bin/python3.9 --distinct_host_configuration=false
#5 94.87 INFO: Found applicable config definition build:short_logs in file /jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
#5 94.87 INFO: Found applicable config definition build:linux in file /jax/.bazelrc: --config=posix --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
#5 94.87 INFO: Found applicable config definition build:posix in file /jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
#5 94.97 Loading: 0 packages loaded
...
#5 120.7 WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/runtime/archive/0d730e48a1e31c0ce9033a222e0a8db10b4cb1f7.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
...
#5 178.5 WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
#5 178.6 Analyzing: target //build:build_wheel (1 packages loaded, 0 targets configured)
#5 181.1 DEBUG: Rule 'io_bazel_rules_docker' indicated that a canonical reproducible form can be obtained by modifying arguments shallow_since = "1596824487 -0400"
#5 181.1 DEBUG: Repository io_bazel_rules_docker instantiated at:
#5 181.1   /jax/WORKSPACE:37:14: in <toplevel>
#5 181.1   /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/external/org_tensorflow/tensorflow/workspace0.bzl:107:34: in workspace
#5 181.1   /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/external/bazel_toolchains/repositories/repositories.bzl:35:23: in repositories
#5 181.1 Repository rule git_repository defined at:
#5 181.1   /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/external/bazel_tools/tools/build_defs/repo/git.bzl:199:33: in <toplevel>
#5 191.3 Analyzing: target //build:build_wheel (34 packages loaded, 12 targets configured)
#5 207.3 Analyzing: target //build:build_wheel (34 packages loaded, 12 targets configured)
#5 222.1 INFO: Analyzed target //build:build_wheel (214 packages loaded, 14537 targets configured).
#5 222.1 INFO: Found 1 target...
#5 222.3 [0 / 10] [Prepa] Creating source manifest for //build:build_wheel
#5 241.5 [135 / 2,939] Compiling src/google/protobuf/descriptor.cc; 3s local ... (5 actions, 4 running)
#5 263.7 [274 / 3,221] Compiling llvm/utils/TableGen/DFAEmitter.cpp; 1s local ... (5 actions, 4 running)
#5 289.3 [420 / 3,221] Compiling llvm/utils/TableGen/X86MnemonicTables.cpp; 0s local ... (5 actions, 4 running)
#5 318.5 [634 / 3,455] Compiling llvm/lib/Support/LockFileManager.cpp; 0s local ... (5 actions, 4 running)
#5 352.1 [1,443 / 4,187] Compiling src/google/protobuf/descriptor.pb.cc; 2s local ... (5 actions, 4 running)
#5 390.8 [1,917 / 4,204] Compiling src/google/protobuf/wire_format.cc; 0s local ... (5 actions, 4 running)
#5 435.2 [2,554 / 4,208] Compiling src/core/ext/transport/chttp2/transport/stream_lists.cc; 0s local ... (5 actions, 4 running)
#5 486.9 [2,891 / 4,290] Compiling mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp; 3s local ... (5 actions, 4 running)
#5 546.2 [3,004 / 4,290] Compiling mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp; 3s local ... (5 actions, 4 running)
#5 615.7 [3,062 / 4,290] Compiling tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc; 36s local ... (5 actions, 4 running)
#5 693.9 [3,267 / 4,290] Compiling external/org_tensorflow/tensorflow/stream_executor/dnn.pb.cc; 1s local ... (5 actions, 4 running)
#5 786.0 [3,507 / 4,290] Compiling tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc; 6s local ... (5 actions running)
#5 890.3 [3,547 / 4,290] Compiling tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc; 110s local ... (5 actions running)
#5 920.3 ERROR: /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/BUILD:433:15: Compiling tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc failed: (Exit 1): gcc failed: error executing command 
#5 920.3   (cd /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/execroot/__main__ && \
#5 920.3   exec env - \
#5 920.3     PATH=/jax-env/bin:/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
#5 920.3     PWD=/proc/self/cwd \
#5 920.3   /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections -fdata-sections '-std=c++0x' -MD -MF bazel-out/aarch64-opt/bin/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/_objs/tensorflow_ops_n_z/tf_ops_n_z.pic.d '-frandom-seed=bazel-out/aarch64-opt/bin/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/_objs/tensorflow_ops_n_z/tf_ops_n_z.pic.o' -fPIC '-DLLVM_ON_UNIX=1' '-DHAVE_BACKTRACE=1' '-DBACKTRACE_HEADER=<execinfo.h>' '-DLTDL_SHLIB_EXT=".so"' '-DLLVM_PLUGIN_EXT=".so"' '-DLLVM_ENABLE_THREADS=1' '-DHAVE_DEREGISTER_FRAME=1' '-DHAVE_LIBPTHREAD=1' '-DHAVE_PTHREAD_GETNAME_NP=1' '-DHAVE_PTHREAD_GETSPECIFIC=1' '-DHAVE_PTHREAD_H=1' '-DHAVE_PTHREAD_SETNAME_NP=1' '-DHAVE_REGISTER_FRAME=1' '-DHAVE_SETENV_R=1' '-DHAVE_STRERROR_R=1' '-DHAVE_SYSEXITS_H=1' '-DHAVE_UNISTD_H=1' -D_GNU_SOURCE '-DHAVE_LINK_H=1' '-DHAVE_LSEEK64=1' '-DHAVE_MALLINFO=1' '-DHAVE_SBRK=1' '-DHAVE_STRUCT_STAT_ST_MTIM_TV_NSEC=1' '-DLLVM_NATIVE_ARCH="AArch64"' '-DLLVM_NATIVE_ASMPARSER=LLVMInitializeAArch64AsmParser' '-DLLVM_NATIVE_ASMPRINTER=LLVMInitializeAArch64AsmPrinter' '-DLLVM_NATIVE_DISASSEMBLER=LLVMInitializeAArch64Disassembler' '-DLLVM_NATIVE_TARGET=LLVMInitializeAArch64Target' '-DLLVM_NATIVE_TARGETINFO=LLVMInitializeAArch64TargetInfo' '-DLLVM_NATIVE_TARGETMC=LLVMInitializeAArch64TargetMC' '-DLLVM_NATIVE_TARGETMCA=LLVMInitializeAArch64TargetMCA' '-DLLVM_HOST_TRIPLE="aarch64-unknown-linux-gnu"' '-DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-unknown-linux-gnu"' -D__STDC_LIMIT_MACROS -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -DBLAKE3_NO_AVX2 -DBLAKE3_NO_AVX512 -DBLAKE3_NO_SSE2 -DBLAKE3_NO_SSE41 '-DBLAKE3_USE_NEON=0' -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -DHAVE_SYS_UIO_H -DTF_USE_SNAPPY -iquote external/org_tensorflow -iquote bazel-out/aarch64-opt/bin/external/org_tensorflow -iquote external/llvm-project -iquote bazel-out/aarch64-opt/bin/external/llvm-project -iquote external/llvm_terminfo -iquote bazel-out/aarch64-opt/bin/external/llvm_terminfo -iquote external/llvm_zlib -iquote bazel-out/aarch64-opt/bin/external/llvm_zlib -iquote external/com_google_absl -iquote bazel-out/aarch64-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/aarch64-opt/bin/external/nsync -iquote external/eigen_archive -iquote bazel-out/aarch64-opt/bin/external/eigen_archive -iquote external/gif -iquote bazel-out/aarch64-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/aarch64-opt/bin/external/libjpeg_turbo -iquote external/com_google_protobuf -iquote bazel-out/aarch64-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/aarch64-opt/bin/external/zlib -iquote external/com_googlesource_code_re2 -iquote bazel-out/aarch64-opt/bin/external/com_googlesource_code_re2 -iquote external/farmhash_archive -iquote bazel-out/aarch64-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/aarch64-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/aarch64-opt/bin/external/highwayhash -iquote external/double_conversion -iquote bazel-out/aarch64-opt/bin/external/double_conversion -iquote external/snappy -iquote bazel-out/aarch64-opt/bin/external/snappy -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinAttributeInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinAttributesIncGen 
-Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinDialectIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinLocationAttributesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinOpsIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinTypeInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinTypesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/CallOpInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/CastOpInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/FunctionInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/InferTypeOpInterfaceIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/OpAsmInterfaceIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/RegionKindInterfaceIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/SideEffectInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/SubElementInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/SymbolInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/TensorEncodingIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/ParserTokenKinds -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/ArithmeticBaseIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/ArithmeticCanonicalizationIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/ArithmeticOpsIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/VectorInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/ControlFlowInterfacesIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/ControlFlowOpsIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/FuncIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/DerivedAttributeOpInterfaceIncGen -Ibazel-out/aarch64-opt/bin/external/llvm-project/mlir/_virtual_includes/LoopLikeInterfaceIncGen -isystem external/llvm-project/mlir/include -isystem bazel-out/aarch64-opt/bin/external/llvm-project/mlir/include -isystem external/llvm-project/llvm/include -isystem bazel-out/aarch64-opt/bin/external/llvm-project/llvm/include -isystem external/nsync/public -isystem bazel-out/aarch64-opt/bin/external/nsync/public -isystem external/eigen_archive -isystem bazel-out/aarch64-opt/bin/external/eigen_archive -isystem external/gif -isystem bazel-out/aarch64-opt/bin/external/gif -isystem external/com_google_protobuf/src -isystem bazel-out/aarch64-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/aarch64-opt/bin/external/zlib -isystem external/farmhash_archive/src -isystem bazel-out/aarch64-opt/bin/external/farmhash_archive/src '-fvisibility=hidden' -Wno-sign-compare -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' 
'-std=c++14' -fno-canonical-system-headers -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -c external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc -o bazel-out/aarch64-opt/bin/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/_objs/tensorflow_ops_n_z/tf_ops_n_z.pic.o)
#5 920.3 # Configuration: 067882a40c13978b4ad738e4cc220caed1e342024c529713672fd04174120478
#5 920.3 # Execution platform: @local_execution_config_platform//:platform
#5 920.3 gcc: fatal error: Killed signal terminated program cc1plus
#5 920.3 compilation terminated.
#5 922.7 Target //build:build_wheel failed to build
#5 922.9 INFO: Elapsed time: 830.976s, Critical Path: 153.84s
#5 922.9 INFO: 3552 processes: 347 internal, 3205 local.
#5 922.9 FAILED: Build did NOT complete successfully
#5 922.9 ERROR: Build failed. Not running target
#5 922.9 FAILED: Build did NOT complete successfully
#5 923.0 
#5 923.0      _   _  __  __
#5 923.0     | | / \ \ \/ /
#5 923.0  _  | |/ _ \ \  /
#5 923.0 | |_| / ___ \/  \
#5 923.0  \___/_/   \/_/\_\
#5 923.0 
#5 923.0 
#5 923.0 Downloading bazel from: https://github.com/bazelbuild/bazel/releases/download/5.0.0/bazel-5.0.0-linux-arm64
#5 923.0 
#5 923.0 Bazel binary path: ./bazel-5.0.0-linux-arm64
#5 923.0 Bazel version: 5.0.0
#5 923.0 Python binary path: /jax-env/bin/python3.9
#5 923.0 Python version: 3.9
#5 923.0 NumPy version: 1.22.3
#5 923.0 MKL-DNN enabled: no
#5 923.0 Target CPU: aarch64
#5 923.0 Target CPU features: default
#5 923.0 CUDA enabled: no
#5 923.0 TPU enabled: no
#5 923.0 ROCm enabled: no
#5 923.0 
#5 923.0 Building XLA and installing it in the jaxlib source tree...
#5 923.0 ./bazel-5.0.0-linux-arm64 run --verbose_failures=true :build_wheel -- --output_path=/jax/dist --cpu=aarch64
#5 923.0 b''
#5 923.0 Traceback (most recent call last):
#5 923.0   File "/jax/build/build.py", line 527, in <module>
#5 923.0     main()
#5 923.0   File "/jax/build/build.py", line 522, in main
#5 923.0     shell(command)
#5 923.0   File "/jax/build/build.py", line 53, in shell
#5 923.0     output = subprocess.check_output(cmd)
#5 923.0   File "/usr/local/lib/python3.9/subprocess.py", line 424, in check_output
#5 923.0     return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
#5 923.0   File "/usr/local/lib/python3.9/subprocess.py", line 528, in run
#5 923.0     raise CalledProcessError(retcode, process.args,
#5 923.0 subprocess.CalledProcessError: Command '['./bazel-5.0.0-linux-arm64', 'run', '--verbose_failures=true', ':build_wheel', '--', '--output_path=/jax/dist', '--cpu=aarch64']' returned non-zero exit status 1.
#5 923.5 WARNING: Requirement 'dist/*.whl' looks like a filename, but the file does not exist
#5 923.5 ERROR: *.whl is not a valid wheel filename.
------
executor failed running [/bin/sh -c apt-get update     && apt-get install -y --no-install-recommends g++ git python3.9-dev     && python3.9 -m venv jax-env     && /bin/bash -c "source ./jax-env/bin/activate; python3.9 -m pip install numpy six wheel; git clone https://github.com/google/jax; cd jax; python3.9 build/build.py --enable_mkl_dnn False --target_cpu_features default; python3.9 -m pip install dist/*.whl"]: exit code: 1

In #5501 they describe how to build jaxlib on the M1 chip, but not through Docker. Would it be possible to provide instructions for how this could be adapted to work in Docker?

yoziru commented Apr 27, 2022

☝️ This indicates you're running out of memory; since the 0.3.x releases you need pretty much the entire 16 GB available.
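
If you can't give the Docker VM that much memory, one possible workaround (a sketch, assuming your jax checkout's build/build.py forwards --bazel_options, as shown in later comments here) is to throttle Bazel so peak memory stays under the limit, trading build time for RAM:

python build/build.py \
    --bazel_options=--jobs=2 \
    --bazel_options=--local_ram_resources=HOST_RAM*.5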

I managed to build manylinux aarch64 wheels on M1, you can get them here: https://github.com/yoziru/jax/releases/tag/jaxlib-v0.3.5

Here is the Dockerfile for reproducing these wheels (run from the root of this repo after cloning):

FROM quay.io/pypa/manylinux2014_aarch64

COPY .bazelversion .
RUN BAZEL_VERSION=$(cat .bazelversion) \
    && curl -SL https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-linux-arm64 \
    -o /usr/local/bin/bazel \
    && chmod +x /usr/local/bin/bazel

ARG PYTHON_VERSION_MAIN=38 \
    JAXLIB_VERSION=0.3.5

RUN ln -s /opt/python/cp${PYTHON_VERSION_MAIN}-cp${PYTHON_VERSION_MAIN}/bin/pip /usr/local/bin/pip \
    && ln -s /opt/python/cp${PYTHON_VERSION_MAIN}-cp${PYTHON_VERSION_MAIN}/bin/python /usr/local/bin/python \
    && ln -s /opt/python/cp${PYTHON_VERSION_MAIN}-cp${PYTHON_VERSION_MAIN}/bin/python /usr/local/bin/python3
RUN pip install numpy six
WORKDIR /builder
COPY . .
RUN python build/build.py
RUN auditwheel repair dist/jaxlib-${JAXLIB_VERSION}-cp${PYTHON_VERSION_MAIN}-none-manylinux2014_aarch64.whl
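
A hypothetical invocation of that Dockerfile (the image tag and output directory are illustrative; auditwheel repair writes the repaired manylinux wheel into wheelhouse/ under the WORKDIR by default):

docker build -t jaxlib-aarch64-builder .
docker run --rm -v "$PWD/dist:/out" jaxlib-aarch64-builder \
    sh -c 'cp wheelhouse/*.whl /out/'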

martin-g:

+1 for official support for Linux ARM64 from me too!

mjsML (Collaborator) commented Aug 15, 2022

@sudhakarsingh27 for visibility

hawkinsp (Member):

I'm curious whether using conda and conda-forge would work here.

We don't have infrastructure to build and test ARM wheels at the moment, but conda-forge does, as I understand it, and you can get JAX using conda install jax from the conda-forge channel. Would you be happy using conda instead?
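
That is, roughly (standard conda usage; whether a given platform is covered is up to the conda-forge feedstock):

conda install -c conda-forge jax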

martin-g:

Hi @hawkinsp !
Yes, using conda-forge is an option!
As you already noticed, I've created conda-forge/jaxlib-feedstock#125.

hawkinsp added a commit to hawkinsp/jax that referenced this issue Aug 16, 2022
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue google#7097
Fixes google#7639
hawkinsp (Member):

I verified we can successfully build a Linux aarch64 jaxlib on a Google Cloud t2a virtual machine.

I was also able to cross-compile a Linux aarch64 jaxlib from a Debian Linux x86-64 machine by doing the following:

• adding a toolchain/BUILD file (the //toolchain:toolchain label used in the build command below) with these contents:

load("@local_config_cc//:cc_toolchain_config.bzl", "cc_toolchain_config")

package(default_visibility = ["//visibility:public"])

cc_toolchain_suite(
    name = "toolchain",
    toolchains = {
        "k8|compiler": "@local_config_cc//:cc-compiler-k8",
        "k8": "@local_config_cc//:cc-compiler-k8",
        "aarch64": ":cc-compiler-aarch64",
    },
)

cc_toolchain(
    name = "cc-compiler-aarch64",
    all_files = "@local_config_cc//:compiler_deps",
    ar_files = "@local_config_cc//:compiler_deps",
    as_files = "@local_config_cc//:compiler_deps",
    compiler_files = "@local_config_cc//:compiler_deps",
    dwp_files = ":empty",
    linker_files = "@local_config_cc//:compiler_deps",
    module_map = None,
    objcopy_files = ":empty",
    strip_files = ":empty",
    supports_param_files = 1,
    toolchain_config = ":cross_aarch64",
    toolchain_identifier = "cross_aarch64",
)

cc_toolchain_config(
    name = "cross_aarch64",
    abi_libc_version = "local",
    abi_version = "local",
    compile_flags = [
        "-U_FORTIFY_SOURCE",
        "-fstack-protector",
        "-Wall",
        "-Wunused-but-set-parameter",
        "-Wno-free-nonheap-object",
        "-fno-omit-frame-pointer",
    ],
    compiler = "compiler",
    coverage_compile_flags = ["--coverage"],
    coverage_link_flags = ["--coverage"],
    cpu = "aarch64",
    cxx_builtin_include_directories = [
        "/usr/aarch64-linux-gnu/include",
        "/usr/lib/gcc-cross/aarch64-linux-gnu/11/include",
        "/usr/local/include",
        "/usr/include",
        "/usr/include/c++/11",
        "/usr/include/c++/11/backward",
    ],
    cxx_flags = ["-std=c++0x"],
    dbg_compile_flags = ["-g"],
    host_system_name = "local",
    link_flags = [
        "-fuse-ld=gold",
        "-Wl,-no-as-needed",
        "-Wl,-z,relro,-z,now",
        "-B/usr/bin/aarch64-linux-gnu-",
        "-pass-exit-codes",
    ],
    link_libs = [
        "-lstdc++",
        "-lm",
    ],
    opt_compile_flags = [
        "-g0",
        "-O2",
        "-D_FORTIFY_SOURCE=1",
        "-DNDEBUG",
        "-ffunction-sections",
        "-fdata-sections",
    ],
    opt_link_flags = ["-Wl,--gc-sections"],
    supports_start_end_lib = True,
    target_libc = "local",
    target_system_name = "local",
    tool_paths = {
        "ar": "/usr/bin/ar",
        "ld": "/usr/bin/aarch64-linux-gnu-ld",
        "llvm-cov": "/usr/bin/llvm-cov",
        "cpp": "/usr/bin/aarch64-linux-gnu-cpp",
        "gcc": "/usr/bin/aarch64-linux-gnu-gcc",
        "dwp": "/usr/bin/aarch64-linux-gnu-dwp",
        "gcov": "/usr/bin/aarch64-linux-gnu-gcov",
        "nm": "/usr/bin/aarch64-linux-gnu-nm",
        "objcopy": "/usr/bin/aarch64-linux-gnu-objcopy",
        "objdump": "/usr/bin/aarch64-linux-gnu-objdump",
        "strip": "/usr/bin/aarch64-linux-gnu-strip",
    },
    toolchain_identifier = "cross_aarch64",
    unfiltered_compile_flags = [
        "-fno-canonical-system-headers",
        "-Wno-builtin-macro-redefined",
        "-D__DATE__=\"redacted\"",
        "-D__TIMESTAMP__=\"redacted\"",
        "-D__TIME__=\"redacted\"",
    ],
)
• and build with:
python build/build.py --bazel_options=--crosstool_top=//toolchain:toolchain --target_cpu=aarch64 --bazel_options=--override_repository=org_tensorflow=/path/to/the/tensorflow/checkout

I verified that the resulting wheel works, at least for some simple tests, on a t2a VM.
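
Note that the tool_paths above assume a Debian-style cross toolchain under /usr/bin/aarch64-linux-gnu-*; on Debian or Ubuntu, packages along these lines would provide it (names worth double-checking for your release):

sudo apt-get install g++-aarch64-linux-gnu binutils-aarch64-linux-gnu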

emiliofernandes:

Hi! What is stopping progress here? I also need jaxlib Linux aarch64 wheels!

treyra added a commit to treyra/jax that referenced this issue Oct 27, 2022
Currently, non-x86_64 Linux architectures are not supported; see google#7097 for the request to change this. This can lead to installation confusion, as jax will install but jaxlib will not. For example, see google#12307.

If this can be more clearly phrased or explained, let me know.
treyra added a commit to treyra/jax that referenced this issue Nov 4, 2022
Currently, non-x86_64 Linux architectures are not supported; see google#7097 for the request to change this. This can lead to installation confusion, as jax will install but jaxlib will not. For example, see google#12307. This adds a note to the install sections for the relevant pip wheels.
markjens:

I am glad to see there is already some progress on this task! I hope we get aarch64 wheels at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html soon!

nikitosl pushed a commit to nikitosl/jokes-generator-web that referenced this issue Jan 8, 2023
# Current blocker is that jax does not provide a Linux library for ARM64
# google/jax#7097
hawkinsp removed the contributions welcome label Mar 30, 2023
ridwan-salau:

If you experience this while building a Docker image like I did, the answer here solved the problem easily for me and many others (as seen in the number of upvotes).

martin-g commented May 4, 2023

@ridwan-salau The solution you linked to is a workaround to use the AMD64 image on Mac ARM64 via emulation (Rosetta).
This issue is about providing a proper ARM64 image so there is no need for emulation (Rosetta or QEMU).
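
In other words, the linked workaround boils down to forcing the x86-64 image under emulation, roughly like this (an illustrative command, with the heavy Rosetta/QEMU performance cost that entails):

docker run --platform linux/amd64 --rm python:3.9-slim-bullseye \
    pip install jax jaxlib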

thomasaarholt:

I've come back to this issue several times in the last 6 months. I'd love to use JAX in our M1 Mac Docker images without resorting to Rosetta.

Is there anything I can do to help push this forward? I can have a go at modifying the current build jobs to support aarch64, but only if such a PR is welcome.

thomasaarholt commented Jun 1, 2023

I had a go at building from source, and it was remarkably easy: following the "build from source" instructions using python build/build.py worked without a hitch. Unfortunately, the process took 38 minutes, and that is without emulation, in an aarch64 Debian image on my M1 Mac. I assume this is the real issue: if the build process were to run on e.g. a GitHub Action emulated on QEMU, it would probably take several times longer 🤔
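
For CI, the usual route would be QEMU-backed buildx, roughly (a sketch against a hypothetical Dockerfile; as noted, expect it to run several times slower than the 38 native minutes):

docker buildx build --platform linux/arm64 -t jaxlib-aarch64 .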

joverlee521 added a commit to nextstrain/docker-base that referenced this issue Jul 7, 2023
The latest release of evofr included updates to its dependencies that raised the minimum version of jaxlib.¹ I'm using the latest release of jaxlib available in the user-forked repo, since there are still no official pre-built binaries for linux/arm64.²

I did not realize this change was needed until I saw the
"Validate Platform" job fail in the CI run³ that I triggered to update
evofr to v0.1.20. Added a note for our future selves to check the
jaxlib installation if we run into different versions of evofr on the
different platforms.

¹ blab/evofr@46f744b
² google/jax#7097
³ https://github.com/nextstrain/docker-base/actions/runs/5480222698
victorlin:

@yoziru the v0.4.13 manylinux aarch64 wheels aren't working properly. Can you enable issues on your fork so I can send a detailed report?

victorlin added a commit to victorlin/build-jax that referenced this issue Jul 18, 2023
from github.com/google/jax/issues/7097#issuecomment-1110730040
hawkinsp (Member) commented Oct 5, 2023

We've started publishing aarch64 Linux nightly wheels (https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html) and we will include aarch64 Linux in the next jaxlib release.
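
Installing a nightly from that index follows the usual pip pattern (a sketch; nightly builds are dev releases, hence --pre):

pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html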

hawkinsp (Member) commented Oct 9, 2023

Jaxlib 0.4.18 includes Linux aarch64 wheels (thanks @yashk2810!)
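
With that release, a plain install on an aarch64 Linux machine should now resolve without any custom index (for the CPU wheels):

pip install --upgrade jax jaxlib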

hawkinsp closed this as completed Oct 9, 2023