In [1]:
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

tf.reset_default_graph()

# Global parameters

In [2]:
batch_size = 4
beam_width = 9 

# Create training data

In [3]:
with open('/tmp/toy_data.txt', 'w') as data_file:
    for _ in range(10000):
        data_file.write("a b c\td e f d e f\n")
        data_file.write("d e f\ta b c a b c\n")

# Vocabulary as a lookup table

In [4]:
vocab = ['PAD', 'UNK', 'EOS', 'SOS'] + list("aábcdeéfghijklmnoóöőpqrstuúüűvwxyz-+.")
EOS = 2  # end of sentence
SOS = 3  # start of sentence (GO symbol)
table = lookup_ops.index_table_from_tensor(tf.constant(vocab), default_value=1)
vocab = {k: i for i, k in enumerate(vocab)}
vocab_size = len(vocab)

table_initializer = tf.tables_initializer()

# Reading dataset

Format:

~~~
input TAB output
input TAB output
~~~

In [5]:
dataset = tf.contrib.data.TextLineDataset('/tmp/toy_data.txt')
dataset = dataset.map(lambda string: tf.string_split([string], delimiter='\t').values)
source = dataset.map(lambda string: string[0])
target = dataset.map(lambda string: string[1])

source = source.map(lambda string: tf.string_split([string], delimiter=' ').values)
source = source.map(lambda words: table.lookup(words))
target = target.map(lambda string: tf.string_split([string], delimiter=' ').values)
target = target.map(lambda words: table.lookup(words))

src_tgt_dataset = tf.contrib.data.Dataset.zip((source, target))
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([SOS], tgt), 0),
                      tf.concat((tgt, [EOS]), 0),)
)
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
)

# Padded batch

In [6]:
# if I set the third padding shape to the same as the second one,
# it fails if there is no maxlen character long sample in the batch
# WHY???
batched = src_tgt_dataset.padded_batch(batch_size, padded_shapes=(
    tf.TensorShape([32]), tf.TensorShape([32]), tf.TensorShape([None]),
         tf.TensorShape([]), tf.TensorShape([])))
batched_iter = batched.make_initializable_iterator()
src_ids, tgt_in_ids, tgt_out_ids, src_size, tgt_size = batched_iter.get_next()

# Encoder

In [7]:
embedding = tf.get_variable("embedding", [vocab_size, 3], dtype=tf.float32)

encoder_emb_inp = tf.nn.embedding_lookup(embedding, src_ids)
   

In [8]:
fw_cell = tf.contrib.rnn.BasicLSTMCell(8)
bw_cell = tf.contrib.rnn.BasicLSTMCell(8)

encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
    fw_cell, bw_cell, encoder_emb_inp, dtype=tf.float32, sequence_length=src_size
)

encoder_outputs = tf.concat(encoder_outputs, -1)
#encoder_state = tf.concat(encoder_state, -1)

other = tf.contrib.rnn.BasicLSTMCell(16)
other_outputs, other_state = tf.nn.dynamic_rnn(other, encoder_outputs, dtype=tf.float32,
                                               sequence_length=src_size)

#encoder_outputs = other_outputs
#encoder_state = other_state
#encoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
#
#encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, dtype=tf.float32,
#                                                   sequence_length=src_size)

# Decoder

In [9]:
attention = tf.contrib.seq2seq.LuongAttention(16, encoder_outputs,
                                              memory_sequence_length=src_size)
decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell,
    attention,
    attention_layer_size=16,
    alignment_history=False
)
decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32)

decoder_emb_inp = tf.nn.embedding_lookup(embedding, tgt_in_ids)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, tgt_size)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state)
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
sample_id = outputs.sample_id
output_proj = layers_core.Dense(vocab_size, name="output_projection")
logits = output_proj(outputs.rnn_output)

In [10]:
decoder_initial_state

AttentionWrapperState(cell_state=LSTMStateTuple(c=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state:0' shape=(4, 16) dtype=float32>, h=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state_1:0' shape=(4, 16) dtype=float32>), attention=<tf.Tensor 'AttentionWrapperZeroState/zeros_1:0' shape=(4, 16) dtype=float32>, time=<tf.Tensor 'AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>, alignments=<tf.Tensor 'AttentionWrapperZeroState/zeros_2:0' shape=(4, 32) dtype=float32>, alignment_history=())

# Loss

In [11]:
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tgt_out_ids, logits=logits)
target_weights = tf.sequence_mask(tgt_size, tf.shape(tgt_out_ids)[1], tf.float32)
loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(5)

# Optimizer and gradient update

In [12]:
#optimizer =tf.train.MomentumOptimizer(0.1, 0.9)
optimizer =tf.train.AdamOptimizer(0.1)
#optimizer =tf.train.RMSPropOptimizer(0.1)
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
update = optimizer.apply_gradients(zip(gradients, params))

# Greedy decoding with `GreedyEmbeddingHelper`

The encoder stays the same but we need to redefine the decoder.

In [13]:
greedy_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], SOS),
                                                         EOS)
greedy_attention = tf.contrib.seq2seq.LuongAttention(16, encoder_outputs,
                                              memory_sequence_length=src_size)

greedy_decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
greedy_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    greedy_decoder_cell,
    greedy_attention,
    attention_layer_size=16,
    alignment_history=True
)
greedy_decoder_initial_state = greedy_decoder_cell.zero_state(
    batch_size, tf.float32).clone(cell_state=encoder_state )

greedy_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state,
                                         output_layer=output_proj)

greedy_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(greedy_decoder, maximum_iterations=10)

# Beam search decoding

In [14]:
"""
beam_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
beam_src_size = tf.contrib.seq2seq.tile_batch(src_size, multiplier=beam_width)

beam_attention = tf.contrib.seq2seq.LuongAttention(16, beam_memory,
                                              memory_sequence_length=beam_src_size)
beam_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)

beam_decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
beam_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    beam_decoder_cell,
    beam_attention,
    attention_layer_size=16,
    alignment_history=False
)
#beam_decoder_cell = tf.contrib.rnn.DeviceWrapper(beam_decoder_cell, device="/cpu:0")
start_tokens = tf.fill([batch_size], SOS)


beam_decoder_initial_state = beam_decoder_cell.zero_state(
    batch_size * beam_width, tf.float32).clone(
    cell_state=beam_encoder_state
)

beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=beam_decoder_cell,
    embedding=embedding,
    start_tokens=start_tokens,
    initial_state=beam_decoder_initial_state,
    beam_width=beam_width,
    output_layer=output_proj,
    end_token=EOS,
)
beam_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(beam_decoder, maximum_iterations=10)
"""

'\nbeam_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)\nbeam_src_size = tf.contrib.seq2seq.tile_batch(src_size, multiplier=beam_width)\n\nbeam_attention = tf.contrib.seq2seq.LuongAttention(16, beam_memory,\n                                              memory_sequence_length=beam_src_size)\nbeam_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)\n\nbeam_decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)\nbeam_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n    beam_decoder_cell,\n    beam_attention,\n    attention_layer_size=16,\n    alignment_history=False\n)\n#beam_decoder_cell = tf.contrib.rnn.DeviceWrapper(beam_decoder_cell, device="/cpu:0")\nstart_tokens = tf.fill([batch_size], SOS)\n\n\nbeam_decoder_initial_state = beam_decoder_cell.zero_state(\n    batch_size * beam_width, tf.float32).clone(\n    cell_state=beam_encoder_state\n)\n\nbeam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n    cell=beam_decoder_cell,\

# Starting session

In [15]:
sess = tf.InteractiveSession()
sess.run(table_initializer)
sess.run(batched_iter.initializer)
sess.run(tf.global_variables_initializer())

# Training

In [16]:
for i in range(100):
    sess.run(update)
    l = sess.run(loss)
    if i % 10 == 9:
        print("Iteration: {}, training loss: {}".format(i+1, l))

Iteration: 10, training loss: 8.901555061340332
Iteration: 20, training loss: 5.586587905883789
Iteration: 30, training loss: 3.339617967605591
Iteration: 40, training loss: 1.880027174949646
Iteration: 50, training loss: 0.7130416035652161
Iteration: 60, training loss: 1.0728628635406494
Iteration: 70, training loss: 6.814558506011963
Iteration: 80, training loss: 1.8322842121124268
Iteration: 90, training loss: 1.224880337715149
Iteration: 100, training loss: 1.2833908796310425


# Manual greedy decoding

NOTE: running logits iterates over the next batch in the dataset, so running this cell multiple times decodes a different batch in the dataset.

In [17]:
inv_vocab = {v: k for k, v in vocab.items()}
inv_vocab[-1] = 'UNK'
skip_symbols = ('PAD', 'SOS', 'EOS', 'UNK')

## Input and output labels

Greed: just take the highest probabilty along the last axis.

In [18]:
input_ids, out_probs = sess.run([src_ids, logits])
output_ids = out_probs.argmax(axis=-1)

output_ids.shape

(4, 7)

## Convert labels to characters

In [19]:
def decode_ids(input_ids, output_ids):
    decoded = []
    for sample_i in range(output_ids.shape[0]):
        input_sample = input_ids[sample_i]
        output_sample = output_ids[sample_i]
        input_decoded = [inv_vocab[s] for s in input_sample]
        input_decoded = ''.join(c for c in input_decoded if c not in skip_symbols)
        output_decoded = [inv_vocab[s] for s in output_sample]
        output_decoded = ''.join(c for c in output_decoded if c not in skip_symbols)
        decoded.append((input_decoded, output_decoded))
    return decoded
 
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcabca
abc ---> defdef
def ---> abcabca


# Run greedy inference

In [20]:
input_ids, output_ids, logits = sess.run([src_ids, greedy_outputs.sample_id,
                                          greedy_outputs.rnn_output])
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcabca
abc ---> defdef
def ---> abcabca


In [21]:
logits.shape

(4, 7, 41)

# Run beam search inference

In [22]:
""""""
input_ids, output_ids = sess.run([src_ids, beam_outputs.predicted_ids])
output_ids

all_decoded = []
for beam_i in range(beam_width):
    inputs = []
    all_decoded.append([])
    decoded = decode_ids(input_ids, output_ids[:,:,beam_i])
    for dec in decoded:
        all_decoded[-1].append(dec[1])
        inputs.append(dec[0])

print('\n'.join(
    '{} ---> {}'.format(inputs[i], ' / '.join(d[i] for d in all_decoded))
                        for i in range(len(inputs))
))

NameError: name 'beam_outputs' is not defined