[Feature request]: Add support for Linux ARM64 #125

Open
1 task done
martin-g opened this issue Aug 15, 2022 · 26 comments
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@martin-g

Solution to issue cannot be found in the documentation.

  • I checked the documentation.

Issue

According to https://anaconda.org/conda-forge/jaxlib, the currently supported OS+CPU architectures are:

  • Mac ARM64
  • Mac AMD64
  • Linux AMD64

I'd like to request adding Linux ARM64 to this list.

At the moment the AlphaFold project cannot be used on Linux ARM64 because there is no jaxlib+CUDA Python wheel for that platform (google-deepmind/alphafold#528).
Currently AlphaFold uses pip3 to install jaxlib from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
It would be nice if it could use conda-forge instead!
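For context, conda publishes packages per platform "subdir", so this request amounts to adding a linux-aarch64 build next to the existing ones. The sketch below maps an OS/CPU pair, as reported by Python's `sys.platform` and `platform.machine()`, to a conda subdir name. The `CONDA_SUBDIRS` table and the `conda_subdir` helper are illustrative assumptions, not conda's own code; linux-ppc64le is listed only because PPC builds come up later in the thread.

```python
import platform
import sys

# Illustrative mapping from (sys.platform, platform.machine()) to conda
# subdir names. This table is an assumption for the example, not conda's code.
CONDA_SUBDIRS = {
    ("linux", "x86_64"): "linux-64",
    ("linux", "aarch64"): "linux-aarch64",
    ("linux", "ppc64le"): "linux-ppc64le",
    ("darwin", "x86_64"): "osx-64",
    ("darwin", "arm64"): "osx-arm64",
}


def conda_subdir(os_name: str, machine: str) -> str:
    """Return the conda subdir for an OS/CPU pair (hypothetical helper)."""
    key = (os_name.lower(), machine.lower())
    if key not in CONDA_SUBDIRS:
        raise ValueError(f"no conda subdir known for {key}")
    return CONDA_SUBDIRS[key]


if __name__ == "__main__":
    # Report the subdir for the current machine, or explain why none matched.
    try:
        print(conda_subdir(sys.platform, platform.machine()))
    except ValueError as err:
        print(err)
```

On the machine described in this issue (Linux, `__archspec=1=aarch64`), the helper would return `linux-aarch64`, which matches the `platform : linux-aarch64` line in the environment info below.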

Installed packages

# packages in environment at /home/mgrigorov/devel/conda:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                      51_gnu  
brotlipy                  0.7.0           py39hfd63f10_1002  
ca-certificates           2022.3.29            hd43f75c_1  
certifi                   2021.10.8        py39hd43f75c_2  
cffi                      1.15.0           py39h9a3cfec_1  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
colorama                  0.4.4              pyhd3eb1b0_0  
conda                     4.12.0           py39hd43f75c_0  
conda-content-trust       0.1.1              pyhd3eb1b0_0  
conda-package-handling    1.8.1            py39h2f4d8fa_0  
cryptography              36.0.0           py39h3d58568_0  
idna                      3.3                pyhd3eb1b0_0  
ld_impl_linux-aarch64     2.36.1               h0ab8de2_3  
libffi                    3.3                  h7c1a80f_2  
libgcc-ng                 10.2.0              h1234567_51  
libgomp                   10.2.0              h1234567_51  
libstdcxx-ng              10.2.0              h1234567_51  
ncurses                   6.3                  h2f4d8fa_2  
openssl                   1.1.1n               h2f4d8fa_0  
pip                       21.2.4           py39hd43f75c_0  
pycosat                   0.6.3            py39hfd63f10_2  
pycparser                 2.21               pyhd3eb1b0_0  
pyopenssl                 22.0.0             pyhd3eb1b0_0  
pysocks                   1.7.1            py39hd43f75c_0  
python                    3.9.12               hc137634_0  
readline                  8.1.2                h2f4d8fa_1  
requests                  2.27.1             pyhd3eb1b0_0  
ruamel_yaml               0.15.100         py39h2f4d8fa_0  
setuptools                61.2.0           py39hd43f75c_0  
six                       1.16.0             pyhd3eb1b0_1  
sqlite                    3.38.2               h6632b73_0  
tk                        8.6.11               h241ca14_0  
tqdm                      4.63.0             pyhd3eb1b0_0  
tzdata                    2022a                hda174b7_0  
urllib3                   1.26.8             pyhd3eb1b0_0  
wheel                     0.37.1             pyhd3eb1b0_0  
xz                        5.2.5                hfd63f10_1  
yaml                      0.2.5                hfd63f10_0  
zlib                      1.2.12               h2f4d8fa_1

Environment info

active environment : None
       user config file : /home/mgrigorov/.condarc
 populated config files : 
          conda version : 4.12.0
    conda-build version : not installed
         python version : 3.9.12.final.0
       virtual packages : __linux=4.19.90=0
                          __glibc=2.28=0
                          __unix=0=0
                          __archspec=1=aarch64
       base environment : /home/mgrigorov/devel/conda  (writable)
      conda av data dir : /home/mgrigorov/devel/conda/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-aarch64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-aarch64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /home/mgrigorov/devel/conda/pkgs
                          /home/mgrigorov/.conda/pkgs
       envs directories : /home/mgrigorov/devel/conda/envs
                          /home/mgrigorov/.conda/envs
               platform : linux-aarch64
             user-agent : conda/4.12.0 requests/2.27.1 CPython/3.9.12 Linux/4.19.90-2207.4.0.0160.oe1.aarch64 openeuler/20.03 glibc/2.28
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
@ngam
Contributor

ngam commented Aug 15, 2022

Hi, contributions are welcome. See this initial PR for an attempt: #105. I can help/guide you if you would like to try to add it.

@hawkinsp
Contributor

hawkinsp commented Aug 15, 2022

BTW, I verified just now that jaxlib builds fine at head on a GCP T2A VM with no special treatment, so I'm pretty confident things should work if we simply change the conda-forge build configuration to also build for aarch64 on Linux. I'd be happy to help here, but @ngam would need to tell me what to do to make that happen...

@ngam
Contributor

ngam commented Aug 15, 2022

@hawkinsp you're everywhere all of a sudden (in a good way!) and I am even getting a 1-question survey to fill out about you in my edu inbox :P (relatedly, #105 was a reaction to the same people who invited you to give that talk wanting a ppc build for jax...)

We have some options. I will explain the process in more detail, but briefly:

  • I submitted "add jaxlib to arch mig" (conda-forge-pinning-feedstock#3253), which will instruct the bots to submit a PR to this feedstock (once all deps have been addressed); we have a long-running "arch" migration to rebuild packages for ppc and aarch. You can already see jaxlib in the queue at https://conda-forge.org/status/#long_migrations, and it seems it doesn't need any other deps
  • Once the PR arrives, we will investigate what type of build makes sense: we can cross-compile on x86_64 or use emulation (the latter tends to take longer, so it may not be doable on our public CIs)
  • We will likely need to deal with a few issues. I expect the PPC build to be the more challenging one (for PPC, we do actually have limited TravisCI credits that we could potentially use, which would build this natively and relatively quickly, but using Travis is quite annoying in these parts of open source)

@ngam ngam added enhancement New feature or request help wanted Extra attention is needed and removed bug Something isn't working labels Aug 15, 2022
@ngam
Contributor

ngam commented Aug 15, 2022

We also have CUDA builds on aarch and ppc now, so we could go all out and add those too... but probably we should take care of the cpu ones first 😅

@hawkinsp
Contributor

Yes, I think it's a great idea to have Linux CUDA aarch64 builds at least because of the upcoming https://www.nvidia.com/en-us/data-center/grace-cpu/ which I'm sure someone will want to use with JAX...

@ngam ngam mentioned this issue Aug 16, 2022
@ngam
Contributor

ngam commented Aug 16, 2022

Well... the bot failed, so I'm starting it manually in #127

@hawkinsp
Contributor

Just a heads up: I was able to cross-compile jaxlib for AArch64 easily enough, but the JIT compiler target detection isn't correct without some upstream TensorFlow changes (tensorflow/tensorflow#57182).

So we will not be able to get a working cross-compiled AArch64 build of 0.3.15 as is; it will need a new jaxlib release or some patching.

@ngam
Contributor

ngam commented Aug 16, 2022

Thank you for the heads-up. Did you use our tooling (conda-forge) or something else for this? How about when you built the native aarch64 version? Did you use our setup here?

@ngam
Contributor

ngam commented Aug 16, 2022

I will apply your method here later in the week to see if we can get this sorted.

@ngam
Contributor

ngam commented Aug 16, 2022

And you're correct, we can test if the build time is reasonable.

@hawkinsp
Contributor

Yup, that's what I did. I suspect the bazel_toolchain package may do the right thing in the conda build already for cross-compilation, I'd certainly try that first.

@ngam
Contributor

ngam commented Aug 16, 2022

> Yup, that's what I did. I suspect the bazel_toolchain package may do the right thing in the conda build already for cross-compilation, I'd certainly try that first.

BTW, there is a way to publish PyPI wheels from here too. I am not sure if the core team is okay with that, but if you want, we can build the PyPI wheels here as well. One example is numba, which publishes some of its wheels on anaconda.org: https://anaconda.org/numba/numba/files?type=pypi

@julien-faye

I also need Jaxlib for Linux ARM64!
Is there any progress on this issue?
Thank you!

@hawkinsp
Contributor

I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime!

@ngam
Contributor

ngam commented Nov 15, 2022

Let's see how #147 pans out (contributions welcome!)

@ngam
Contributor

ngam commented Nov 15, 2022

@hawkinsp here's where it stops in #155:

WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/runtime/archive/c27b720c93f76662ab6d0e0e507d1fc66ab22119.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/001d18664f8bcf63af64f10688809f7681dfbf0b.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading: 
Loading: 1 packages loaded
Analyzing: target //build:build_wheel (2 packages loaded, 0 targets configured)
INFO: ToolchainResolution: Target platform @local_config_platform//:host: Selected execution platform @local_execution_config_platform//:platform, 
INFO: ToolchainResolution:     Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: Rejected toolchain @local_execution_config_python//:py_runtime_pair; mismatching values: platform_constraint
INFO: ToolchainResolution:   Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: execution @local_execution_config_platform//:platform: Selected toolchain @local_config_python//:py_runtime_pair
INFO: ToolchainResolution:   Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: execution @local_config_platform//:host: Selected toolchain @local_config_python//:py_runtime_pair
INFO: ToolchainResolution:   Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: execution @local_execution_config_platform//:platform: Selected toolchain @local_config_cc//:cc-compiler-aarch64
INFO: ToolchainResolution:   Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: execution @local_config_platform//:host: Selected toolchain @local_config_cc//:cc-compiler-aarch64
INFO: ToolchainResolution:     Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: Rejected toolchain @local_config_cc//:cc-compiler-armeabi-v7a; mismatching values: arm, android
INFO: ToolchainResolution: Target platform @local_config_platform//:host: Selected execution platform @local_execution_config_platform//:platform, type @bazel_tools//tools/cpp:toolchain_type -> toolchain @local_config_cc//:cc-compiler-aarch64, type @bazel_tools//tools/python:toolchain_type -> toolchain @local_config_python//:py_runtime_pair
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:11:1: indentation error
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:14:2: Trailing comma is allowed only in parenthesized tuples.
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:14:3: syntax error at 'outdent': expected expression
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/work/build/BUILD.bazel:38:10: every rule of type py_binary implicitly depends upon the target '@bazel_tools//tools/zip:zipper', but this target could not be found because of: no such target '@bazel_tools//tools/zip:zipper': target 'zipper' not declared in package 'tools/zip' defined by /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 58.229s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
ERROR: Build failed. Not running target

@ngam
Contributor

ngam commented Nov 15, 2022

Link to PPC/arm64 builds: https://app.travis-ci.com/github/conda-forge/jaxlib-feedstock/builds/257816827

Note that we have our own toolchain, which may need a thorough update... I can work on that:

https://github.com/conda-forge/bazel-toolchain-feedstock

@ngam
Contributor

ngam commented Nov 16, 2022

It's now slightly clearer how we have your customizations from google/jax#7097 (comment) in our tooling in #157 (see the code below), but I am getting an error:

ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:11:1: indentation error
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:14:2: Trailing comma is allowed only in parenthesized tuples.
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:14:3: syntax error at 'outdent': expected expression
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/work/build/BUILD.bazel:38:10: every rule of type py_binary implicitly depends upon the target '@bazel_tools//tools/zip:zipper', but this target could not be found because of: no such target '@bazel_tools//tools/zip:zipper': target 'zipper' not declared in package 'tools/zip' defined by /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 51.199s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
Traceback (most recent call last):
  File "/home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/work/build/build.py", line 572, in <module>
b''

load("@local_config_cc//:cc_toolchain_config.bzl", "cc_toolchain_config")

package(default_visibility = ["//visibility:public"])

cc_toolchain_suite(
    name = "toolchain",
    toolchains = {
        "k8|compiler": "@local_config_cc//:cc-compiler-k8",
        "k8": "@local_config_cc//:cc-compiler-k8",
        "aarch64": ":cc-compiler-aarch64",
    },
)

cc_toolchain(
    name = "cc-compiler-aarch64",
    all_files = "@local_config_cc//:compiler_deps",
    ar_files = "@local_config_cc//:compiler_deps",
    as_files = "@local_config_cc//:compiler_deps",
    compiler_files = "@local_config_cc//:compiler_deps",
    dwp_files = ":empty",
    linker_files = "@local_config_cc//:compiler_deps",
    module_map = None,
    objcopy_files = ":empty",
    strip_files = ":empty",
    supports_param_files = 1,
    toolchain_config = ":cross_aarch64",
    toolchain_identifier = "cross_aarch64",
)

cc_toolchain_config(
    name = "cross_aarch64",
    abi_libc_version = "local",
    abi_version = "local",
    compile_flags = [
        "-U_FORTIFY_SOURCE",
        "-fstack-protector",
        "-Wall",
        "-Wunused-but-set-parameter",
        "-Wno-free-nonheap-object",
        "-fno-omit-frame-pointer",
    ],
    compiler = "compiler",
    coverage_compile_flags = ["--coverage"],
    coverage_link_flags = ["--coverage"],
    cpu = "aarch64",
    cxx_builtin_include_directories = [
        "/usr/aarch64-linux-gnu/include",
        "/usr/lib/gcc-cross/aarch64-linux-gnu/11/include",
        "/usr/local/include",
        "/usr/include",
        "/usr/include/c++/11",
        "/usr/include/c++/11/backward",
    ],
    cxx_flags = ["-std=c++0x"],
    dbg_compile_flags = ["-g"],
    host_system_name = "local",
    link_flags = [
        "-fuse-ld=gold",
        "-Wl,-no-as-needed",
        "-Wl,-z,relro,-z,now",
        "-B/usr/bin/aarch64-linux-gnu-",
        "-pass-exit-codes",
    ],
    link_libs = [
        "-lstdc++",
        "-lm",
    ],
    opt_compile_flags = [
        "-g0",
        "-O2",
        "-D_FORTIFY_SOURCE=1",
        "-DNDEBUG",
        "-ffunction-sections",
        "-fdata-sections",
    ],
    opt_link_flags = ["-Wl,--gc-sections"],
    supports_start_end_lib = True,
    target_libc = "local",
    target_system_name = "local",
    tool_paths = {
        "ar": "/usr/bin/ar",
        "ld": "/usr/bin/aarch64-linux-gnu-ld",
        "llvm-cov": "/usr/bin/llvm-cov",
        "cpp": "/usr/bin/aarch64-linux-gnu-cpp",
        "gcc": "/usr/bin/aarch64-linux-gnu-gcc",
        "dwp": "/usr/bin/aarch64-linux-gnu-dwp",
        "gcov": "/usr/bin/aarch64-linux-gnu-gcov",
        "nm": "/usr/bin/aarch64-linux-gnu-nm",
        "objcopy": "/usr/bin/aarch64-linux-gnu-objcopy",
        "objdump": "/usr/bin/aarch64-linux-gnu-objdump",
        "strip": "/usr/bin/aarch64-linux-gnu-strip",
    },
    toolchain_identifier = "cross_aarch64",
    unfiltered_compile_flags = [
        "-fno-canonical-system-headers",
        "-Wno-builtin-macro-redefined",
        "-D__DATE__=\"redacted\"",
        "-D__TIMESTAMP__=\"redacted\"",
        "-D__TIME__=\"redacted\"",
    ],
)

@martin-g
Author

> I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime!

Thanks, @hawkinsp!

That's true, but AlphaFold's Dockerfile uses

RUN pip3 install --upgrade pip --no-cache-dir \
    && pip3 install -r /app/alphafold/requirements.txt --no-cache-dir \
    && pip3 install --upgrade --no-cache-dir \
      jax==0.3.17 \
      jaxlib==0.3.15+cuda11.cudnn805 \
      -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

and there are only x86_64 wheels at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Just a few lines above in the Dockerfile they install some more dependencies from conda-forge, but I'm afraid that even if we solve this issue it won't help, because the conda-forge jaxlib won't depend on the correct CUDA version. I hope I am wrong, though!
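Why pip cannot help on ARM here is visible in the wheel filenames themselves: under the wheel filename convention (PEP 427), the last dash-separated field before `.whl` is the platform tag, which may be a dot-separated set of alternatives. A minimal, standard-library-only sketch that checks whether a wheel filename targets a given CPU architecture follows. The helper names are hypothetical, and the example filename is modelled on the `jaxlib==0.3.15+cuda11.cudnn805` pin above rather than copied from the release index.

```python
from typing import List


def wheel_platform_tags(filename: str) -> List[str]:
    """Split the platform tag(s) out of a wheel filename (PEP 427 layout)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    # Layout: name-version[-build]-python-abi-platform
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename}")
    # Compressed tag sets separate alternatives with ".".
    return parts[-1].split(".")


def wheel_supports_arch(filename: str, arch: str) -> bool:
    """True if the wheel is platform-independent ("any") or built for `arch`."""
    return any(
        tag == "any" or tag.endswith("_" + arch)
        for tag in wheel_platform_tags(filename)
    )


if __name__ == "__main__":
    # Hypothetical filename modelled on the jaxlib pin in the Dockerfile above.
    name = "jaxlib-0.3.15+cuda11.cudnn805-cp39-none-manylinux2014_x86_64.whl"
    print(wheel_supports_arch(name, "x86_64"))   # True
    print(wheel_supports_arch(name, "aarch64"))  # False
```

A `+cuda11.cudnn805` local version label contains no dash, so it stays inside the version field and does not disturb the split. In practice pip performs a fuller compatibility check against the interpreter's supported tags; this sketch only looks at the architecture suffix.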

@milot-mirdita

Sorry for reviving this old issue.

With the introduction of linux-aarch64 support for bioconda, my package colabfold should also work on ARM, except for the missing jaxlib on linux-aarch64.

Even compilation without CUDA would be quite useful to me, as I could point users to install Colabfold through conda on e.g. a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine.

However, since I can't selectively disable conda dependencies, I would still need jaxlib to be installable on ARM.

@hawkinsp
Contributor

I don't speak for the conda-forge maintainers, but upstream we ship a Linux aarch64 pip wheel.

@milot-mirdita

milot-mirdita commented Mar 13, 2024

I would still prefer to provide a single conda command for installation to users, since I have a few dependencies that are not pip installable.

I am very thankful for all the jaxlib pip variants though! They are super useful!

@xhochy
Member

xhochy commented Mar 16, 2024

The main issue here is that we have currently reached the CI time limit. Cross-compiled builds, e.g. for linux-aarch64, will take even longer. Once this is fixed, we can look into enabling them here.

@traversaro
Copy link
Contributor

> Even compilation without CUDA would be quite useful to me, as I could point users to install Colabfold through conda on e.g. a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine.

Am I missing something, or are conda packages for jaxlib on linux-aarch64 without CUDA actually already available?

@xhochy
Member

xhochy commented May 24, 2024

They have been available for a year: #183

@traversaro
Contributor

Indeed, that was my understanding, but this was not clear from @milot-mirdita's #125 (comment). Would it make sense to rename the issue to "[Feature request]: Add support for CUDA builds on Linux ARM64"?

7 participants