Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/mnist_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
Expand Down Expand Up @@ -80,6 +81,7 @@ def update(i, opt_state, batch):
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
Expand All @@ -92,4 +94,3 @@ def update(i, opt_state, batch):
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))

75 changes: 75 additions & 0 deletions jax/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@


class Config(object):
def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False

def update(self, name, val):
self.check_exists(name)
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))
self.values[name] = val

def read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
self.check_exists(name)
return self.values[name]

def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
if name in self.values:
raise Exception("Config option {} already defined".format(name))
self.values[name] = default
self.meta[name] = (opt_type, meta_args, meta_kwargs)

def check_exists(self, name):
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))

def DEFINE_bool(self, name, default, *args, **kwargs):
self.add_option(name, default, bool, args, kwargs)

def DEFINE_integer(self, name, default, *args, **kwargs):
self.add_option(name, default, int, args, kwargs)

def DEFINE_string(self, name, default, *args, **kwargs):
self.add_option(name, default, str, args, kwargs)

def DEFINE_enum(self, name, default, *args, **kwargs):
self.add_option(name, default, 'enum', args, kwargs)

def config_with_absl(self):
# Run this before calling `app.run(main)` etc
import absl.flags as absl_FLAGS
from absl import app, flags as absl_flags

self.use_absl = True
self.absl_flags = absl_flags
absl_defs = { bool: absl_flags.DEFINE_bool,
int: absl_flags.DEFINE_integer,
str: absl_flags.DEFINE_string,
'enum': absl_flags.DEFINE_enum }

for name, val in self.values.items():
flag_type, meta_args, meta_kwargs = self.meta[name]
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

def complete_absl_config(self, absl_flags):
for name, _ in self.values.items():
self.update(name, getattr(absl_flags.FLAGS, name))


class NameSpace(object):
def __init__(self, getter):
self._getter = getter

def __getattr__(self, name):
return self._getter(name)


config = Config()
flags = config
2 changes: 1 addition & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import six
from six.moves import xrange

from absl import flags
from .. config import flags
from .. import core
from .. import ad_util
from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types
Expand Down
17 changes: 2 additions & 15 deletions jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import os
import warnings

from absl import flags
from ..config import flags
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy

from . import xla_data_pb2
Expand Down Expand Up @@ -206,21 +206,8 @@ def dtype_to_etype(dtype):
}


def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
# This function is a thin wrapper around the memoized _canonicalize_dtype to
# handle the case where FLAGS haven't been parsed yet, for example because
# this function is called at module loading time. This situation can't obtain
# during tracing and instead can arise when there are module-level constants
# computed using lax or lax_numpy.
if FLAGS.is_parsed():
return _canonicalize_dtype(dtype)
else:
return dtype


@memoize
def _canonicalize_dtype(dtype):
def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
dtype = onp.dtype(dtype)
if FLAGS.jax_enable_x64:
Expand Down
2 changes: 1 addition & 1 deletion jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
import functools
import re

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp
import numpy.random as npr

from . import api
from .config import flags
from .util import partial
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

Expand Down
2 changes: 2 additions & 0 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from absl.testing import absltest
from absl.testing import parameterized

from jax.config import flags
from jax import api
from jax import core
from jax import numpy as np
Expand Down Expand Up @@ -331,4 +332,5 @@ def test_jvp_2(self):


if __name__ == '__main__':
flags.config_with_absl()
absltest.main()