Skip to content

Commit

Permalink
Fix pip install jax[tpu]
Browse files Browse the repository at this point in the history
* Updates jax_releases.html index to include libtpu wheels
* Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL

The full install command will now be:
`pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html`
(similar to the cuda install commands)

I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
  • Loading branch information
skye committed Jun 23, 2021
1 parent 2460f91 commit 55276d1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
37 changes: 26 additions & 11 deletions build/generate_release_index.py
Expand Up @@ -20,6 +20,7 @@
gsutil cp jax_releases.html gs://jax-releases/
"""

from itertools import chain
import subprocess

FILENAME = "jax_releases.html"
Expand All @@ -33,20 +34,34 @@

FOOTER = "</body>\n</html>\n"

print("Running command: gsutil ls gs://jax-releases/cuda*")
ls_output = subprocess.check_output(["gsutil", "ls", "gs://jax-releases/cuda*"])
def get_entries(gcs_uri, whl_filter=".whl"):
entries = []
print(f"Running command: gsutil ls {gcs_uri}")
ls_output = subprocess.check_output(["gsutil", "ls", gcs_uri])
for line in ls_output.decode("utf-8").split("\n"):
# Skip incorrectly formatted wheel filenames and other gsutil output
if not whl_filter in line: continue
# Example lines:
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
# gs://cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl

# Link title should be the innermost directory + wheel filename
# Example link titles:
# cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
# libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
link_title_idx = line.rfind('/', 0, line.rfind('/')) + 1
link_title = line[link_title_idx:]
link_href = line.replace("gs://", "https://storage.googleapis.com/")
entries.append(f'<a href="{link_href}">{link_title}</a><br>\n')
return entries

jaxlib_cuda_entries = get_entries("gs://jax-releases/cuda*", whl_filter="+cuda")
libtpu_entries = get_entries("gs://cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/")

print(f"Writing index to {FILENAME}")
with open(FILENAME, "w") as f:
f.write(HEADER)
for line in ls_output.decode("utf-8").split("\n"):
# Skip incorrectly formatted wheel filenames and other gsutil output
if not "+cuda" in line: continue
# Example line:
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
assert line.startswith("gs://jax-releases/cuda")
link_title = line[len("gs://jax-releases/"):]
link_href = line.replace("gs://", "https://storage.googleapis.com/")
f.write(f'<a href="{link_href}">{link_title}</a><br>\n')
for entry in chain(jaxlib_cuda_entries, libtpu_entries):
f.write(entry)
f.write(FOOTER)
print("Done.")
9 changes: 3 additions & 6 deletions setup.py
Expand Up @@ -24,10 +24,7 @@
__version__ = _dct['__version__']
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']

_libtpu_version = '20210615'
_libtpu_url = (
f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/'
f'libtpu-nightly/libtpu_nightly-0.1.dev{_libtpu_version}-py3-none-any.whl')
_libtpu_version = '0.1.dev20210615'

setup(
name='jax',
Expand All @@ -52,9 +49,9 @@
'cpu': [f'jaxlib>={_minimum_jaxlib_version}'],

# Cloud TPU VM jaxlib can be installed via:
# $ pip install jax[tpu]
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html
'tpu': [f'jaxlib=={_current_jaxlib_version}',
f'libtpu-nightly @ {_libtpu_url}'],
f'libtpu-nightly=={_libtpu_version}'],

# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
Expand Down

0 comments on commit 55276d1

Please sign in to comment.