Skip to content

Commit

Permalink
internal merge of PR tensorflow#1336
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 227913649
  • Loading branch information
ywkim authored and kpe committed Mar 2, 2019
1 parent 018e44b commit c1562d2
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 18 deletions.
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/problem.py
Expand Up @@ -368,6 +368,7 @@ def eval_metrics(self):
]

def eval_metric_fns(self, model_hparams):
del model_hparams
metric_names = self.eval_metrics()
if not all([m in metrics.METRICS_FNS for m in metric_names]):
error_str = ("Unrecognized metric. Problem %s specified metrics "
Expand Down
16 changes: 0 additions & 16 deletions tensor2tensor/models/transformer.py
Expand Up @@ -1121,22 +1121,6 @@ def body(self, features):

return encoder_output

@registry.register_model
class TransformerRegressor(TransformerEncoder):
"""Transformer inheriting from Encoder, for the regression problem.
Final res is a tensor that has a shape of (?, 1, 1, 1)
"""

def top(self, body_output, features):
"""Computes single scalar value from body_output
"""
with tf.variable_scope("reg_top_ffn"):
# scalar = common_layers.dense(body_output,hparams)
x = body_output
x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
res = tf.layers.dense(x, 1, name="model_top")
return res


def features_to_nonpadding(features, inputs_or_targets="inputs"):
key = inputs_or_targets + "_segmentation"
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/utils/metrics.py
Expand Up @@ -638,7 +638,7 @@ def create_eager_metrics_for_problem(problem, model_hparams):
metric_fns = problem.eval_metric_fns(model_hparams)
tm = problem.get_hparams(model_hparams).modality["targets"]
return create_eager_metrics_internal(
metric_fns, weights_fn=tm.targets_weights_fn)
metric_fns, weights_fn=tm.targets_weights_fn)


def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
Expand All @@ -664,7 +664,7 @@ def create_eager_metrics_internal(metric_fns,
"""Create metrics accumulators and averager for Eager mode.
Args:
metric_names: dict<metric name, metric function>
metric_fns: dict<metric name, metric function>
weights_fn: function that takes labels and returns a weights mask. Defaults
to weights of all 1, i.e. common_layers.weights_all. Use
common_layers.weights_nonzero if labels have 0-padding.
Expand Down

0 comments on commit c1562d2

Please sign in to comment.