Skip to content

Commit

Permalink
fix: support port to tf2
Browse files Browse the repository at this point in the history
  • Loading branch information
jimthompson5802 committed Feb 16, 2020
1 parent b2673eb commit 3994e2c
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions ludwig/models/modules/recurrent_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def recurrent_decoder(encoder_outputs, targets, max_sequence_length, vocab_size,
logger.debug(' {}: {}'.format(v.name, v))

# ================ Decoding ================
def decode(initial_state, cell, helper, beam_width=1,
def decode(initial_state, cell, sampler, beam_width=1,
projection_layer=None):
# The decoder itself
if beam_width > 1:
Expand All @@ -412,8 +412,8 @@ def decode(initial_state, cell, helper, beam_width=1,
output_layer=projection_layer)
else:
decoder = BasicDecoder(
cell=cell, helper=helper,
initial_state=initial_state,
cell=cell, sampler=sampler,
# todo tf2: remove obsolete code #initial_state=initial_state,
output_layer=projection_layer)

# The decoding operation
Expand Down Expand Up @@ -442,13 +442,15 @@ def decode(initial_state, cell, helper, beam_width=1,
predictions_sequence_length_with_eos = final_sequence_lengths_pred

else:
train_helper = tfa.seq2seq.sampler.TrainingSampler(
inputs=targets_embedded,
sequence_length=targets_sequence_length_with_eos)
train_sampler = tfa.seq2seq.sampler.TrainingSampler()
# todo tf2: cleanout obsolete code
# train_helper = tfa.seq2seq.sampler.TrainingSampler(
# inputs=targets_embedded,
# sequence_length=targets_sequence_length_with_eos)
final_outputs_train, final_state_train, final_sequence_lengths_train = decode(
initial_state,
cell,
train_helper,
train_sampler, # todo: tf2 to be removed #train_helper,
projection_layer=projection_layer)
eval_logits = final_outputs_train.rnn_output
train_logits = final_outputs_train.projection_input
Expand Down

0 comments on commit 3994e2c

Please sign in to comment.