Skip to content

Commit

Permalink
fix: comply with function signature for tfa.seq2seq.dynamic_decode
Browse files Browse the repository at this point in the history
  • Loading branch information
jimthompson5802 committed Feb 17, 2020
1 parent d077649 commit 7e4261c
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions ludwig/models/modules/recurrent_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def recurrent_decoder(encoder_outputs, targets, max_sequence_length, vocab_size,

# ================ Decoding ================
def decode(initial_state, cell, sampler, beam_width=1,
projection_layer=None):
projection_layer=None, inputs=None):
# The decoder itself
if beam_width > 1:
# Tile inputs for beam search decoder
Expand All @@ -414,22 +414,20 @@ def decode(initial_state, cell, sampler, beam_width=1,
else:
decoder = BasicDecoder(
cell=cell, sampler=sampler,
# todo tf2: remove obsolete code #initial_state=initial_state,
output_layer=projection_layer)

initial_state = cell.get_initial_state(
batch_size=batch_size, dtype=tf.float32
)
# todo tf2: remove obsolete code #initial_state=initial_state,

# todo tf2: need to figure out 'inputs' to next function
decoder.initialize(inputs=None, initial_state=initial_state)
decoder.initialize(inputs, initial_state=initial_state)

# The decoding operation
outputs = tfa.seq2seq.dynamic_decode(
decoder=decoder,
output_time_major=False,
impute_finished=False if beam_width > 1 else True,
maximum_iterations=max_sequence_length
maximum_iterations=max_sequence_length,
decoder_init_input=inputs,
decoder_init_kwargs={'initial_state': initial_state}
)

return outputs
Expand All @@ -451,15 +449,25 @@ def decode(initial_state, cell, sampler, beam_width=1,

else:
train_sampler = tfa.seq2seq.sampler.TrainingSampler()
train_sampler.initialize(targets_embedded,
sequence_length=targets_sequence_length_with_eos)
# todo tf2: cleanout obsolete code
# train_helper = tfa.seq2seq.sampler.TrainingSampler(
# inputs=targets_embedded,
# sequence_length=targets_sequence_length_with_eos)

# # todo tf2: test code
# initial_state = cell.get_initial_state(
# batch_size=batch_size, dtype=tf.float32
# )

final_outputs_train, final_state_train, final_sequence_lengths_train = decode(
initial_state,
cell,
train_sampler, # todo: tf2 to be removed #train_helper,
projection_layer=projection_layer)
projection_layer=projection_layer,
inputs=encoder_outputs
)
eval_logits = final_outputs_train.rnn_output
train_logits = final_outputs_train.projection_input
# train_predictions = final_outputs_train.sample_id
Expand Down

0 comments on commit 7e4261c

Please sign in to comment.