Skip to content

Commit

Permalink
refactor: cast sequence metric/loss functions to handle tf.int64 values
Browse files (browse the repository at this point in the history)
  • Loading branch information
jimthompson5802 committed Apr 19, 2020
1 parent de7b133 commit def27bd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 13 deletions.
11 changes: 2 additions & 9 deletions ludwig/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,6 @@ def __init__(self, feature):

self.decoder_obj = self.initialize_decoder(feature)

# determine required tf dtype to represent sequence encoded values
max_base2_exponent = np.int(np.ceil(np.log2(self.num_classes - 1)))
if max_base2_exponent <= 32:
self._prediction_dtype = tf.int32
else:
self._prediction_dtype = tf.int64

self._setup_loss()
self._setup_metrics()

Expand Down Expand Up @@ -280,15 +273,15 @@ def predictions(

logits = inputs[LOGITS]

probabilities= tf.nn.softmax(
probabilities = tf.nn.softmax(
logits,
name='probabilities_{}'.format(self.name)
)
predictions = tf.argmax(
logits,
-1,
name='predictions_{}'.format(self.name),
output_type=self._prediction_dtype
output_type=tf.int64
)

if self.decoder == 'generator':
Expand Down
3 changes: 2 additions & 1 deletion ludwig/models/modules/loss_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(self, name=None, **kwargs):
def call(self, y_true, y_pred):
# y_true: shape [batch_size, sequence_size]
# y_pred: shape [batch_size, sequence_size, num_classes]

# get sequence lengths from targets
targets_sequence_length = sequence_length_2D(
tf.convert_to_tensor(y_true, dtype=tf.int32)
Expand All @@ -150,7 +151,7 @@ def call(self, y_true, y_pred):

# compute loss based on valid time steps
loss = self.loss_function(
tf.convert_to_tensor(y_true, dtype=tf.int32),
tf.convert_to_tensor(y_true, dtype=tf.int64),
y_pred[LOGITS],
sample_weight=sample_mask
)
Expand Down
7 changes: 4 additions & 3 deletions ludwig/models/modules/metric_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,6 @@ def update_state(self, y, y_hat):
super().update_state(loss)




class SequenceLastAccuracyMetric(tf.keras.metrics.Accuracy):
"""
Sequence accuracy based on last token in the sequence
Expand All @@ -181,7 +179,7 @@ def __init__(self, name=None):
def update_state(self, y_true, y_pred, sample_weight=None):
# TODO TF2 account for weights
targets_sequence_length = sequence_length_2D(
tf.convert_to_tensor(y_true, dtype=tf.int32)
tf.convert_to_tensor(y_true, dtype=tf.int64)
)
last_targets = tf.gather_nd(
y_true,
Expand All @@ -195,6 +193,8 @@ def update_state(self, y_true, y_pred, sample_weight=None):
)
)

last_targets = tf.cast(last_targets, dtype=tf.int64)

super().update_state(last_targets, y_pred)


Expand All @@ -211,6 +211,7 @@ def result(self):
mean = super().result()
return np.exp(mean)


class EditDistanceMetric(tf.keras.metrics.Mean):
def __init__(self, name=None):
super(EditDistanceMetric, self).__init__(name=name)
Expand Down

0 comments on commit def27bd

Please sign in to comment.