Skip to content

Commit

Permalink
Internal refactors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 338388042
  • Loading branch information
znado authored and Copybara-Service committed Oct 22, 2020
1 parent 6e96526 commit 04d4241
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
10 changes: 4 additions & 6 deletions baselines/cifar/deterministic.py
Expand Up @@ -96,7 +96,7 @@ def main(argv):
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

train_input_fn = utils.load_input_fn(
split=tfds.Split.TRAIN,
Expand All @@ -110,11 +110,9 @@ def main(argv):
batch_size=FLAGS.per_core_batch_size,
use_bfloat16=FLAGS.use_bfloat16,
data_dir=FLAGS.data_dir)
train_dataset = strategy.experimental_distribute_datasets_from_function(
train_input_fn)
train_dataset = strategy.distribute_datasets_from_function(train_input_fn)
test_datasets = {
'clean': strategy.experimental_distribute_datasets_from_function(
clean_test_input_fn),
'clean': strategy.distribute_datasets_from_function(clean_test_input_fn),
}
if FLAGS.corruptions_interval > 0:
if FLAGS.dataset == 'cifar10':
Expand All @@ -132,7 +130,7 @@ def main(argv):
batch_size=FLAGS.per_core_batch_size,
use_bfloat16=FLAGS.use_bfloat16)
test_datasets['{0}_{1}'.format(corruption, intensity)] = (
strategy.experimental_distribute_datasets_from_function(input_fn))
strategy.distribute_datasets_from_function(input_fn))

ds_info = tfds.builder(FLAGS.dataset).info
batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
Expand Down
12 changes: 7 additions & 5 deletions uncertainty_baselines/models/__init__.py
Expand Up @@ -28,7 +28,6 @@
from uncertainty_baselines.models.resnet50_batchensemble import resnet_batchensemble
from uncertainty_baselines.models.resnet50_deterministic import resnet50_deterministic
from uncertainty_baselines.models.resnet50_dropout import resnet50_dropout
from uncertainty_baselines.models.resnet50_mimo import resnet50_mimo
from uncertainty_baselines.models.resnet50_rank1 import resnet50_rank1
from uncertainty_baselines.models.resnet50_sngp import resnet50_sngp
from uncertainty_baselines.models.resnet50_sngp_be import resnet50_sngp_be
Expand All @@ -40,17 +39,20 @@
from uncertainty_baselines.models.wide_resnet_hyperbatchensemble import e_factory as hyperbatchensemble_e_factory
from uncertainty_baselines.models.wide_resnet_hyperbatchensemble import LambdaConfig as HyperBatchEnsembleLambdaConfig
from uncertainty_baselines.models.wide_resnet_hyperbatchensemble import wide_resnet_hyperbatchensemble
from uncertainty_baselines.models.wide_resnet_mimo import wide_resnet_mimo
from uncertainty_baselines.models.wide_resnet_rank1 import wide_resnet_rank1
from uncertainty_baselines.models.wide_resnet_sngp import wide_resnet_sngp
from uncertainty_baselines.models.wide_resnet_sngp_be import wide_resnet_sngp_be
from uncertainty_baselines.models.wide_resnet_variational import wide_resnet_variational

# When adding a new model, also add to models.py for easier user access.

# pylint: disable=g-import-not-at-top
try:
from uncertainty_baselines.models.bert import create_model as BertBuilder # pylint: disable=g-import-not-at-top
from uncertainty_baselines.models.bert_dropout import create_model as DropoutBertBuilder # pylint: disable=g-import-not-at-top
from uncertainty_baselines.models.bert_sngp import create_model as SngpBertBuilder # pylint: disable=g-import-not-at-top
from uncertainty_baselines.models.bert import create_model as BertBuilder
from uncertainty_baselines.models.bert_dropout import create_model as DropoutBertBuilder
from uncertainty_baselines.models.bert_sngp import create_model as SngpBertBuilder
from uncertainty_baselines.models.resnet50_mimo import resnet50_mimo
from uncertainty_baselines.models.wide_resnet_mimo import wide_resnet_mimo
except ImportError as e:
warnings.warn(f'Skipped due to ImportError: {e}')
# pylint: enable=g-import-not-at-top

0 comments on commit 04d4241

Please sign in to comment.