|
30 | 30 | logging._warn_preinit_stderr = 0 |
31 | 31 |
|
32 | 32 | import jax.lib |
33 | | -from .._src.config import flags |
| 33 | +from .._src.config import flags, bool_env |
34 | 34 | from . import tpu_driver_client |
35 | 35 | from . import xla_client |
36 | 36 | from jax._src import util, traceback_util |
|
52 | 52 | 'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.') |
53 | 53 | flags.DEFINE_string( |
54 | 54 | 'jax_backend_target', 'local', |
55 | | - 'Either "local" or "rpc:address" to connect to a remote service target.') |
| 55 | + 'Either "local" or "rpc:address" to connect to a remote service target. ' |
| 56 | + 'The default is "local".') |
56 | 57 | flags.DEFINE_string( |
57 | 58 | 'jax_platform_name', |
58 | | - os.getenv('JAX_PLATFORM_NAME', ''), |
| 59 | + os.getenv('JAX_PLATFORM_NAME', '').lower(), |
59 | 60 | 'Platform name for XLA. The default is to attempt to use a GPU or TPU if ' |
60 | 61 | 'available, but fall back to CPU otherwise. To set the platform manually, ' |
61 | | - 'pass "cpu" for CPU, "gpu" for GPU, etc.') |
| 62 | + 'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, ' |
| 63 | + 'setting the platform name to "cpu" can silence warnings that appear with ' |
| 64 | + 'the default setting.') |
62 | 65 | flags.DEFINE_bool( |
63 | | - 'jax_disable_most_optimizations', False, |
| 66 | + 'jax_disable_most_optimizations', |
| 67 | + bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False), |
64 | 68 | 'Try not to do much optimization work. This can be useful if the cost of ' |
65 | 69 | 'optimization is greater than that of running a less-optimized program.') |
66 | 70 | flags.DEFINE_string( |
67 | | - 'jax_cpu_backend_variant', 'tfrt', |
68 | | - 'jax_cpu_backend_variant selects cpu backend variant: stream_executor or ' |
69 | | - 'tfrt') |
| 71 | + 'jax_cpu_backend_variant', |
| 72 | + os.getenv('JAX_CPU_BACKEND_VARIANT', 'tfrt'), |
| 73 | + 'Selects CPU backend runtime variant: "stream_executor" or "tfrt". The ' |
| 74 | + 'default is "tfrt".') |
70 | 75 |
|
71 | 76 | def get_compile_options( |
72 | 77 | num_replicas: int, |
|
0 commit comments