Skip to content

Commit

Permalink
Fix jax 0.3.11 GPU breakge when used with jaxlib 0.3.10.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed May 16, 2022
1 parent 1381afc commit 337ec47
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/lib/__init__.py
Expand Up @@ -183,7 +183,7 @@ def _parse_version(v: str) -> Tuple[int, ...]:
hip_linalg = None

try:
import jaxlib.cuda_linalg as gpu_linalg # pytype: disable=import-error
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
except ImportError:
gpu_linalg = None

Expand Down

0 comments on commit 337ec47

Please sign in to comment.