
Commit 383d414: minor updates
chentingpc committed May 22, 2023
1 parent 2fc637b
Showing 6 changed files with 25 additions and 23 deletions.
model.py (7 changes: 4 additions & 3 deletions)
@@ -26,6 +26,7 @@
 import objective as obj_lib
 
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator
 import tensorflow.compat.v2 as tf2
 
 FLAGS = flags.FLAGS
@@ -35,7 +36,7 @@ def build_model_fn(model, num_classes, num_train_examples):
   """Build model function."""
   def model_fn(features, labels, mode, params=None):
     """Build model and optimizer."""
-    is_training = mode == tf.estimator.ModeKeys.TRAIN
+    is_training = mode == tf_estimator.ModeKeys.TRAIN
 
     # Check training mode.
     if FLAGS.train_mode == 'pretrain':
@@ -183,7 +184,7 @@ def scaffold_fn():
       else:
         scaffold_fn = None
 
-      return tf.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
           mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
     else:
 
@@ -215,7 +216,7 @@ def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
               tf.losses.get_regularization_loss()),
       }
 
-      return tf.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
           mode=mode,
           loss=loss,
           eval_metrics=(metric_fn, metrics),
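The model.py changes swap attribute access on the compat module for an explicit aliased import. A minimal sketch of the pattern, assuming a TF release where `tensorflow.compat.v1` still exposes the estimator package so both spellings resolve to the same module:

```python
# Sketch of the import-alias pattern this commit adopts throughout;
# it assumes tensorflow.compat.v1 still ships the estimator API.
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

def is_training_mode(mode):
  # Previously written as: mode == tf.estimator.ModeKeys.TRAIN
  return mode == tf_estimator.ModeKeys.TRAIN
```

Importing the module once under an alias keeps call sites short and no longer depends on `estimator` being attached as an attribute of the compat module.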
resnet.py (4 changes: 2 additions & 2 deletions)
@@ -64,10 +64,10 @@ def _cross_replica_average(self, t):
     num_shards = tpu_function.get_tpu_context().number_of_shards
     return tf.tpu.cross_replica_sum(t) / tf.cast(num_shards, t.dtype)
 
-  def _moments(self, inputs, reduction_axes, keep_dims):
+  def _moments(self, inputs, reduction_axes, keep_dims, mask=None):
     """Compute the mean and variance: it overrides the original _moments."""
     shard_mean, shard_variance = super(BatchNormalization, self)._moments(
         inputs, reduction_axes, keep_dims=keep_dims, mask=mask)
 
     num_shards = tpu_function.get_tpu_context().number_of_shards
     if num_shards and num_shards > 1:
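The `_moments` override is brought in line with the base Keras layer, which gained a `mask` parameter in newer releases; accepting and forwarding it keeps the subclass signature-compatible. A simplified sketch of the pattern (`_moments` is a private Keras API, so the exact signature depends on the installed TF version):

```python
import tensorflow.compat.v1 as tf

class BatchNormalization(tf.keras.layers.BatchNormalization):
  """Cross-replica batch normalization (simplified sketch, not the full class)."""

  def _moments(self, inputs, reduction_axes, keep_dims, mask=None):
    # Accept the newer `mask` argument and forward it unchanged, so the
    # override keeps working as the base-class signature evolves.
    shard_mean, shard_variance = super(BatchNormalization, self)._moments(
        inputs, reduction_axes, keep_dims=keep_dims, mask=mask)
    # The real class goes on to average these statistics across TPU replicas.
    return shard_mean, shard_variance
```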
run.py (11 changes: 6 additions & 5 deletions)
@@ -31,6 +31,7 @@
 import model_util as model_util
 
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator
 import tensorflow_datasets as tfds
 import tensorflow_hub as hub
 
@@ -397,10 +398,10 @@ def main(argv):
     tf.config.experimental_connect_to_cluster(cluster)
     tf.tpu.experimental.initialize_tpu_system(cluster)
 
-  default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
-  sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
-  run_config = tf.estimator.tpu.RunConfig(
-      tpu_config=tf.estimator.tpu.TPUConfig(
+  default_eval_mode = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V1
+  sliced_eval_mode = tf_estimator.tpu.InputPipelineConfig.SLICED
+  run_config = tf_estimator.tpu.RunConfig(
+      tpu_config=tf_estimator.tpu.TPUConfig(
           iterations_per_loop=checkpoint_steps,
           eval_training_input_configuration=sliced_eval_mode
           if FLAGS.use_tpu else default_eval_mode),
@@ -410,7 +411,7 @@
       keep_checkpoint_max=FLAGS.keep_checkpoint_max,
       master=FLAGS.master,
       cluster=cluster)
-  estimator = tf.estimator.tpu.TPUEstimator(
+  estimator = tf_estimator.tpu.TPUEstimator(
       model_lib.build_model_fn(model, num_classes, num_train_examples),
       config=run_config,
       train_batch_size=FLAGS.train_batch_size,
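Beyond the renames, this block picks the SLICED eval input pipeline on TPU and PER_HOST_V1 otherwise. A hedged sketch of the overall wiring, with hypothetical literals standing in for the flags the real script reads:

```python
# Illustrative wiring only: model_fn is any estimator-style model function,
# and every literal value below stands in for a FLAGS value in run.py.
run_config = tf_estimator.tpu.RunConfig(
    tpu_config=tf_estimator.tpu.TPUConfig(
        iterations_per_loop=1000,  # checkpoint_steps in the real code
        eval_training_input_configuration=(
            tf_estimator.tpu.InputPipelineConfig.SLICED)),
    model_dir='/tmp/simclr',  # hypothetical path
    save_checkpoints_steps=1000)
estimator = tf_estimator.tpu.TPUEstimator(
    model_fn,
    config=run_config,
    train_batch_size=512,
    eval_batch_size=256,
    use_tpu=True)
```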
tf2/data.py (3 changes: 2 additions & 1 deletion)
@@ -105,10 +105,11 @@ def get_preprocess_fn(is_training, is_pretrain):
     test_crop = False
   else:
     test_crop = True
+  color_jitter_strength = FLAGS.color_jitter_strength if is_pretrain else 0.
   return functools.partial(
       data_util.preprocess_image,
       height=FLAGS.image_size,
       width=FLAGS.image_size,
       is_training=is_training,
-      color_distort=is_pretrain,
+      color_jitter_strength=color_jitter_strength,
       test_crop=test_crop)
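The jitter strength is now resolved from FLAGS at the call site and passed down as a plain float, which removes data_util's dependency on absl flags. A standalone sketch of the partial this function now builds, with hypothetical literals where the repo reads FLAGS:

```python
import functools
import data_util  # as imported inside the repo's tf2/ directory

# 224 and 0.5 are stand-ins for FLAGS.image_size and
# FLAGS.color_jitter_strength; passing 0. would disable color distortion.
preprocess_fn = functools.partial(
    data_util.preprocess_image,
    height=224,
    width=224,
    is_training=True,
    color_jitter_strength=0.5,
    test_crop=True)
```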
tf2/data_util.py (21 changes: 10 additions & 11 deletions)
@@ -16,12 +16,9 @@
 """Data preprocessing and augmentation."""
 
 import functools
-from absl import flags
 
 import tensorflow.compat.v2 as tf
 
-FLAGS = flags.FLAGS
-
 CROP_PROPORTION = 0.875  # Standard for ImageNet.
 
 
@@ -446,7 +443,7 @@ def generate_selector(p, bsz):
 def preprocess_for_train(image,
                          height,
                          width,
-                         color_distort=True,
+                         color_jitter_strength=0.,
                          crop=True,
                          flip=True,
                          impl='simclrv2'):
@@ -456,11 +453,12 @@ def preprocess_for_train(image,
     image: `Tensor` representing an image of arbitrary size.
     height: Height of output image.
     width: Width of output image.
-    color_distort: Whether to apply the color distortion.
+    color_jitter_strength: `float` between 0 and 1 indicating the color
+      distortion strength; color distortion is disabled if it is 0 or less.
     crop: Whether to crop the image.
     flip: Whether or not to flip left and right of an image.
     impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
-        version of random brightness.
+      version of random brightness.
 
   Returns:
     A preprocessed image `Tensor`.
@@ -469,8 +467,8 @@
     image = random_crop_with_resize(image, height, width)
   if flip:
     image = tf.image.random_flip_left_right(image)
-  if color_distort:
-    image = random_color_jitter(image, strength=FLAGS.color_jitter_strength,
+  if color_jitter_strength > 0:
+    image = random_color_jitter(image, strength=color_jitter_strength,
                                 impl=impl)
   image = tf.reshape(image, [height, width, 3])
   image = tf.clip_by_value(image, 0., 1.)
@@ -497,15 +495,16 @@ def preprocess_for_eval(image, height, width, crop=True):
 
 
 def preprocess_image(image, height, width, is_training=False,
-                     color_distort=True, test_crop=True):
+                     color_jitter_strength=0., test_crop=True):
   """Preprocesses the given image.
 
   Args:
     image: `Tensor` representing an image of arbitrary size.
     height: Height of output image.
     width: Width of output image.
     is_training: `bool` for whether the preprocessing is for training.
-    color_distort: whether to apply the color distortion.
+    color_jitter_strength: `float` between 0 and 1 indicating the color
+      distortion strength; color distortion is disabled if it is 0 or less.
     test_crop: whether or not to extract a central crop of the images
         (as for standard ImageNet evaluation) during the evaluation.
 
@@ -514,6 +513,6 @@ def preprocess_image(image, height, width, is_training=False,
   """
   image = tf.image.convert_image_dtype(image, dtype=tf.float32)
   if is_training:
-    return preprocess_for_train(image, height, width, color_distort)
+    return preprocess_for_train(image, height, width, color_jitter_strength)
   else:
     return preprocess_for_eval(image, height, width, test_crop)
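
A minimal usage sketch of the new signature, assuming a decoded image tensor; 0.5 is an arbitrary example strength, and any value of 0 or less now disables color distortion entirely:

```python
import tensorflow.compat.v2 as tf
import data_util  # the module changed above

image = tf.zeros([300, 400, 3], tf.uint8)  # stand-in for a decoded image
train_view = data_util.preprocess_image(
    image, height=224, width=224, is_training=True,
    color_jitter_strength=0.5)
eval_view = data_util.preprocess_image(
    image, height=224, width=224, is_training=False, test_crop=True)
```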
tf2/lars_optimizer.py (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@
 EETA_DEFAULT = 0.001
 
 
-class LARSOptimizer(tf.keras.optimizers.Optimizer):
+class LARSOptimizer(tf.keras.optimizers.legacy.Optimizer):
   """Layer-wise Adaptive Rate Scaling for large batch training.
 
   Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
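Starting around TF 2.11, `tf.keras.optimizers.Optimizer` points at a rewritten optimizer base with a different interface, while the old implementation lives on under `tf.keras.optimizers.legacy`. Subclassing the legacy base preserves the old-style hyperparameter and slot machinery that LARSOptimizer relies on. A small sketch of the pattern:

```python
import tensorflow.compat.v2 as tf

class LARSOptimizer(tf.keras.optimizers.legacy.Optimizer):
  """Skeleton only; the real class also implements the LARS update rule."""

  def __init__(self, learning_rate, name='LARSOptimizer'):
    super(LARSOptimizer, self).__init__(name)
    # _set_hyper is part of the legacy (OptimizerV2) API; the rewritten
    # base class in newer TF releases no longer provides it.
    self._set_hyper('learning_rate', learning_rate)
```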
