Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

{tools}[foss/2023a] jax v0.4.24 w/ CUDA 12.1.1 #19841

Closed
124 changes: 124 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.24'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ThomasHoffmann77 There's a jax 0.4.25 release now, can we try updating to that and see if we're still seeing failing tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@boegel I'll give 0.4.25 a try today

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need this, so happy to review once you have updated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need this, so happy to review once you have updated.

@verdurin see #20119. (My local test runs did not finsh yet)

versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://pypi.python.org/pypi/jax'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'foss', 'version': '2023a'}
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
('Bazel', '6.3.1'),
('pytest-xdist', '3.3.1'),
# git 2.x required to fetch repository 'io_bazel_rules_docker'
('git', '2.41.0', '-nodocs'),
('matplotlib', '3.7.2'), # required for tests/lobpcg_test.py
('poetry', '1.5.1'),
]

dependencies = [
('CUDA', '12.1.1', '', SYSTEM),
('cuDNN', '8.9.2.26', versionsuffix, SYSTEM),
('NCCL', '2.18.3', versionsuffix),
('Python', '3.11.3'),
('SciPy-bundle', '2023.07'),
('flatbuffers-python', '23.5.26'),
('zlib', '1.2.13'),
('ml_dtypes', '0.3.2'),
]

# downloading xla and other tarballs to avoid that Bazel downloads it during the build
# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl
local_xla_commit = '12eee889e1f2ad41e27d7b0e970cb92d282d3ec5'
local_tfrt_commit = '4665f7483063a16b6113a05eb45f98103cc1d611'
local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit
local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit

# deliberately not testing in parallel, as that results in (additional) failing tests;
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
# see https://github.com/google/jax/issues/7323 and
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
# use NVIDIA_TF32_OVERRIDE=0 to avoid lossing numerical precision by disabling TF32 Tensor Cores;
local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 "
local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform "
local_test += "JAX_ENABLE_X64=true pytest -vv tests "
ThomasHoffmann77 marked this conversation as resolved.
Show resolved Hide resolved

use_pip = True

default_easyblock = 'PythonPackage'
default_component_specs = {
'sources': [SOURCE_TAR_GZ],
'source_urls': [PYPI_SOURCE],
'start_dir': '%(name)s-%(version)s',
'use_pip': True,
'sanity_pip_check': True,
'download_dep_fail': True,
}

components = [
('absl-py', '2.1.0', {
'options': {'modulename': 'absl'},
'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'],
}),
('jaxlib', version, {
'sources': [
'%(name)s-v%(version)s.tar.gz',
{
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit,
},
{
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit,
},
],
'source_urls': [
'https://github.com/google/jax/archive/',
'https://github.com/tensorflow/runtime/archive',
'https://github.com/openxla/xla/archive'
],
'patches': [
('jax-0.4.24_xla-%s_indexing_analysis_small_vector.patch' % local_xla_commit[:7],
'../xla-%s' % local_xla_commit),
# cuda-noncanonical-include-paths still required?:
# ('jax-0.4.24_xla-%s_cuda-noncanonical-include-paths.patch' % local_xla_commit[:7],
# '../xla-%s' % local_xla_commit),
],
'checksums': [
{'jaxlib-v0.4.24.tar.gz':
'c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28'},
{'xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5.tar.gz':
'db007b6628cfe108c63f45d611c6de910abe3ee827e55f08314ce143c4887d66'},
{'tf_runtime-4665f7483063a16b6113a05eb45f98103cc1d611.tar.gz':
'3aa0ab30fe94dab33f20824b9c2d8e7c3b6017106c833b12070f71d2e0f1d6d6'},
{'jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch':
'7187cdd08cce12d0af889494317cb8c32865487d1d6d9254064cb62fd3453b6d'},
],
'start_dir': 'jax-jaxlib-v%(version)s',
'buildopts': local_repo_opt
}),
]

exts_list = [
(name, version, {
'runtest': local_test,
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
'checksums': [
{'jax-v0.4.24.tar.gz': '6e52d8b547624bd70d423e6bf85f4fcd47336b529f1a4f6a94fac3096017a694'},
],
}),
]

sanity_pip_check = True

moduleclass = 'tools'
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
diff -ru xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5_old/xla/service/gpu/model/indexing_analysis.cc xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5/xla/service/gpu/model/indexing_analysis.cc
--- xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5_old/xla/service/gpu/model/indexing_analysis.cc 2024-02-05 19:41:29.000000000 +0100
+++ xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5/xla/service/gpu/model/indexing_analysis.cc 2024-02-12 12:09:35.301680070 +0100
@@ -687,7 +687,7 @@
llvm::SmallVector<AffineExpr, 4> DelinearizeInBoundsIndex(
mlir::AffineExpr linear, absl::Span<const int64_t> sizes,
absl::Span<const int64_t> strides) {
- llvm::SmallVector<AffineExpr> result;
+ llvm::SmallVector<AffineExpr, 4> result; // THEMBL; see commit c10075688d773c43c22e658c814a94ade3cbb372
result.reserve(sizes.size());
for (auto [size, stride] : llvm::zip(sizes, strides)) {
result.push_back(linear.floorDiv(stride) % size);
50 changes: 50 additions & 0 deletions easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/02
easyblock = 'PythonBundle'

name = 'ml_dtypes'
version = '0.3.2'

homepage = 'https://github.com/jax-ml/ml_dtypes'
description = """
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used
in machine learning libraries, including:

bfloat16: an alternative to the standard float16 format
float8_*: several experimental 8-bit floating point representations including:
float8_e4m3b11fnuz
float8_e4m3fn
float8_e4m3fnuz
float8_e5m2
float8_e5m2fnuz
"""

toolchain = {'name': 'foss', 'version': '2023a'}

dependencies = [
('Python', '3.11.3'),
('SciPy-bundle', '2023.07'),
]


use_pip = True

exts_list = [
('opt_einsum', '3.3.0', {
'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'],
}),
('etils', '1.6.0', {
'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'],
}),
(name, version, {
'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)],
'checksums': [
{'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'},
{'ml_dtypes-0.3.2_EigenAvx512.patch':
'fef229a24515b9c03be0d2e932c499965212e3a03ae3ede5d037874f88f93c46'},
],
})
]

sanity_pip_check = True

moduleclass = 'tools'