Skip to content

Commit 5e92fac

Browse files
committed
tweak xla_bridge.py flags
* add environment variables for jax_disable_most_optimizations and jax_cpu_backend_variant * comment on the default values in help strings
1 parent c97d63d commit 5e92fac

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

jax/lib/xla_bridge.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
logging._warn_preinit_stderr = 0
3131

3232
import jax.lib
33-
from .._src.config import flags
33+
from .._src.config import flags, bool_env
3434
from . import tpu_driver_client
3535
from . import xla_client
3636
from jax._src import util, traceback_util
@@ -52,21 +52,26 @@
5252
'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
5353
flags.DEFINE_string(
5454
'jax_backend_target', 'local',
55-
'Either "local" or "rpc:address" to connect to a remote service target.')
55+
'Either "local" or "rpc:address" to connect to a remote service target. '
56+
'The default is "local".')
5657
flags.DEFINE_string(
5758
'jax_platform_name',
58-
os.getenv('JAX_PLATFORM_NAME', ''),
59+
os.getenv('JAX_PLATFORM_NAME', '').lower(),
5960
'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
6061
'available, but fall back to CPU otherwise. To set the platform manually, '
61-
'pass "cpu" for CPU, "gpu" for GPU, etc.')
62+
'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
63+
'setting the platform name to "cpu" can silence warnings that appear with '
64+
'the default setting.')
6265
flags.DEFINE_bool(
63-
'jax_disable_most_optimizations', False,
66+
'jax_disable_most_optimizations',
67+
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
6468
'Try not to do much optimization work. This can be useful if the cost of '
6569
'optimization is greater than that of running a less-optimized program.')
6670
flags.DEFINE_string(
67-
'jax_cpu_backend_variant', 'tfrt',
68-
'jax_cpu_backend_variant selects cpu backend variant: stream_executor or '
69-
'tfrt')
71+
'jax_cpu_backend_variant',
72+
os.getenv('JAX_CPU_BACKEND_VARIANT', 'tfrt'),
73+
'Selects CPU backend runtime variant: "stream_executor" or "tfrt". The '
74+
'default is "tfrt".')
7075

7176
def get_compile_options(
7277
num_replicas: int,

0 commit comments

Comments
 (0)