In [1]:
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

tf.reset_default_graph()

# Global parameters

In [2]:
batch_size = 4
beam_width = 9 

# Create training data

In [3]:
with open('/tmp/toy_data.txt', 'w') as data_file:
    for _ in range(10000):
        #data_file.write("a b c\td e f á d e f\n")
        #data_file.write("d e á f\ta b c c c a b c\n")
        data_file.write("abc\tdefdef\n")
        data_file.write("def\tabc\n")

# Vocabulary as a lookup table

In [4]:
vocab = ['PAD', 'EOS', 'SOS'] + list("áabcdef")
EOS = 1  # end of sentence
SOS = 2  # start of sentence (GO symbol)
table = lookup_ops.index_table_from_tensor(tf.constant(vocab))
vocab = {k: i for i, k in enumerate(vocab)}
vocab_size = len(vocab)

table_initializer = tf.tables_initializer()

# Reading dataset

Format:

~~~
input TAB output
input TAB output
~~~

In [5]:
dataset = tf.contrib.data.TextLineDataset('/tmp/toy_data.txt')
#raw_data = tf.placeholder(shape=[None], dtype=tf.string)
#raw_data_size = tf.placeholder(shape=[], dtype=tf.int64)

#dataset = tf.contrib.data.Dataset.from_tensor_slices(raw_data)
dataset = dataset.map(lambda string: tf.string_split([string], delimiter='\t').values)
source = dataset.map(lambda string: string[0])
target = dataset.map(lambda string: string[1])

source = source.map(lambda string: tf.string_split([string], delimiter='').values)
source = source.map(lambda words: table.lookup(words))
target = target.map(lambda string: tf.string_split([string], delimiter='').values)
target = target.map(lambda words: table.lookup(words))

src_tgt_dataset = tf.contrib.data.Dataset.zip((source, target))
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([SOS], tgt), 0),
                      tf.concat((tgt, [EOS]), 0),)
)
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
)

# Padded batch

In [6]:
# if I set the third padding shape to the same as the second one,
# it fails if there is no maxlen character long sample in the batch
# WHY???
batched = src_tgt_dataset.padded_batch(batch_size, padded_shapes=(
    tf.TensorShape([6]), tf.TensorShape([10]), tf.TensorShape([None]),
         tf.TensorShape([]), tf.TensorShape([])))
batched_iter = batched.make_initializable_iterator()
src_ids, tgt_in_ids, tgt_out_ids, src_size, tgt_size = batched_iter.get_next()

# Encoder

In [7]:
embedding = tf.get_variable("embedding", [vocab_size, 3], dtype=tf.float32)

encoder_emb_inp = tf.nn.embedding_lookup(embedding, src_ids)
   

In [8]:
fw_cell = [tf.contrib.rnn.BasicLSTMCell(8),
           tf.contrib.rnn.ResidualWrapper(tf.contrib.rnn.BasicLSTMCell(8))]
bw_cell = [tf.contrib.rnn.BasicLSTMCell(8), tf.contrib.rnn.BasicLSTMCell(8)]
#bw_cell = tf.contrib.rnn.BasicLSTMCell(8)
fw_cell = tf.contrib.rnn.MultiRNNCell(fw_cell)
bw_cell = tf.contrib.rnn.MultiRNNCell(bw_cell)

encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
    fw_cell, bw_cell, encoder_emb_inp, dtype=tf.float32, sequence_length=src_size
)

encoder_outputs = tf.concat(encoder_outputs, -1)
#encoder_state = tf.concat(encoder_state, -1)

#other = tf.contrib.rnn.BasicLSTMCell(16)
#other_outputs, other_state = tf.nn.dynamic_rnn(other, encoder_outputs, dtype=tf.float32,
                                               #sequence_length=src_size)

#encoder_outputs = other_outputs
#encoder_state = other_state
#encoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
#
#encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, dtype=tf.float32,
#                                                   sequence_length=src_size)

In [9]:
batch_size = tf.size(src_size)
encoder_outputs
encoder_state

((LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_2:0' shape=(?, 8) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?, 8) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(?, 8) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_5:0' shape=(?, 8) dtype=float32>)),
 (LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_2:0' shape=(?, 8) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(?, 8) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_4:0' shape=(?, 8) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_5:0' shape=(?, 8) dtype=float32>)))

# Decoder

In [10]:
attention = tf.contrib.seq2seq.LuongAttention(16, encoder_outputs,
                                              memory_sequence_length=src_size)
decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell,
    attention,
    attention_layer_size=17,
    alignment_history=False
)
decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32)

decoder_emb_inp = tf.nn.embedding_lookup(embedding, tgt_in_ids)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, tgt_size)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state)
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
sample_id = outputs.sample_id
output_proj = layers_core.Dense(vocab_size, name="output_projection")
logits = output_proj(outputs.rnn_output)

In [11]:
decoder

<tensorflow.contrib.seq2seq.python.ops.basic_decoder.BasicDecoder at 0x7fcafe573f60>

In [12]:
decoder_initial_state

AttentionWrapperState(cell_state=LSTMStateTuple(c=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state:0' shape=(?, 16) dtype=float32>, h=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state_1:0' shape=(?, 16) dtype=float32>), attention=<tf.Tensor 'AttentionWrapperZeroState/zeros_1:0' shape=(?, 17) dtype=float32>, time=<tf.Tensor 'AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>, alignments=<tf.Tensor 'AttentionWrapperZeroState/zeros_2:0' shape=(?, 6) dtype=float32>, alignment_history=())

In [13]:
decoder_cell

<tensorflow.contrib.seq2seq.python.ops.attention_wrapper.AttentionWrapper at 0x7fcafe58d828>

In [14]:
decoder_initial_state

AttentionWrapperState(cell_state=LSTMStateTuple(c=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state:0' shape=(?, 16) dtype=float32>, h=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state_1:0' shape=(?, 16) dtype=float32>), attention=<tf.Tensor 'AttentionWrapperZeroState/zeros_1:0' shape=(?, 17) dtype=float32>, time=<tf.Tensor 'AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>, alignments=<tf.Tensor 'AttentionWrapperZeroState/zeros_2:0' shape=(?, 6) dtype=float32>, alignment_history=())

# Loss

In [15]:
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tgt_out_ids, logits=logits)
target_weights = tf.sequence_mask(tgt_size, tf.shape(tgt_out_ids)[1], tf.float32)
loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(5)

# Optimizer and gradient update

In [16]:
optimizer =tf.train.AdamOptimizer(0.1)
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
update = optimizer.apply_gradients(zip(gradients, params))

# Greedy decoding with `GreedyEmbeddingHelper`

The encoder stays the same but we need to redefine the decoder.

In [17]:
greedy_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], SOS),
                                                         EOS)
greedy_attention = tf.contrib.seq2seq.LuongAttention(16, encoder_outputs,
                                              memory_sequence_length=src_size)

greedy_decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
greedy_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    greedy_decoder_cell,
    greedy_attention,
    attention_layer_size=16,
    alignment_history=True
)
greedy_decoder_initial_state = greedy_decoder_cell.zero_state(
    batch_size, tf.float32).clone(cell_state=encoder_state )

greedy_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state,
                                         output_layer=output_proj)

greedy_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(greedy_decoder, maximum_iterations=10)

# Beam search decoding

In [18]:
"""
beam_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
beam_src_size = tf.contrib.seq2seq.tile_batch(src_size, multiplier=beam_width)

beam_attention = tf.contrib.seq2seq.LuongAttention(16, beam_memory,
                                              memory_sequence_length=beam_src_size)
beam_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)

beam_decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
beam_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    beam_decoder_cell,
    beam_attention,
    attention_layer_size=16,
    alignment_history=False
)
#beam_decoder_cell = tf.contrib.rnn.DeviceWrapper(beam_decoder_cell, device="/cpu:0")
start_tokens = tf.fill([batch_size], SOS)


beam_decoder_initial_state = beam_decoder_cell.zero_state(
    batch_size * beam_width, tf.float32).clone(
    cell_state=beam_encoder_state
)

beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=beam_decoder_cell,
    embedding=embedding,
    start_tokens=start_tokens,
    initial_state=beam_decoder_initial_state,
    beam_width=beam_width,
    output_layer=output_proj,
    end_token=EOS,
)
beam_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(beam_decoder, maximum_iterations=10)
"""

'\nbeam_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)\nbeam_src_size = tf.contrib.seq2seq.tile_batch(src_size, multiplier=beam_width)\n\nbeam_attention = tf.contrib.seq2seq.LuongAttention(16, beam_memory,\n                                              memory_sequence_length=beam_src_size)\nbeam_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)\n\nbeam_decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)\nbeam_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n    beam_decoder_cell,\n    beam_attention,\n    attention_layer_size=16,\n    alignment_history=False\n)\n#beam_decoder_cell = tf.contrib.rnn.DeviceWrapper(beam_decoder_cell, device="/cpu:0")\nstart_tokens = tf.fill([batch_size], SOS)\n\n\nbeam_decoder_initial_state = beam_decoder_cell.zero_state(\n    batch_size * beam_width, tf.float32).clone(\n    cell_state=beam_encoder_state\n)\n\nbeam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n    cell=beam_decoder_cell,\

# Starting session

In [19]:
sess = tf.InteractiveSession()
sess.run(table_initializer)
sess.run(batched_iter.initializer)
sess.run(tf.global_variables_initializer())

# Training

In [20]:
for i in range(100):
    sess.run(update)
    l = sess.run(loss)
    if i % 10 == 9:
        print("Iteration: {}, training loss: {}".format(i+1, l))

Iteration: 10, training loss: 4.075344085693359
Iteration: 20, training loss: 0.12878328561782837
Iteration: 30, training loss: 0.005970187485218048
Iteration: 40, training loss: 0.002689577406272292
Iteration: 50, training loss: 0.0015475597465410829
Iteration: 60, training loss: 0.0008760251221247017
Iteration: 70, training loss: 0.00039950694190338254
Iteration: 80, training loss: 0.0002776928013190627
Iteration: 90, training loss: 0.00022996659390628338
Iteration: 100, training loss: 0.00019992880697827786


In [21]:
sess.run(src_ids)

array([[4, 5, 6, 0, 0, 0],
       [7, 8, 9, 0, 0, 0],
       [4, 5, 6, 0, 0, 0],
       [7, 8, 9, 0, 0, 0]])

In [22]:
sess.run(tgt_in_ids)

array([[2, 7, 8, 9, 7, 8, 9, 0, 0, 0],
       [2, 4, 5, 6, 0, 0, 0, 0, 0, 0],
       [2, 7, 8, 9, 7, 8, 9, 0, 0, 0],
       [2, 4, 5, 6, 0, 0, 0, 0, 0, 0]])

In [23]:
sess.run(tgt_out_ids)

array([[7, 8, 9, 7, 8, 9, 1],
       [4, 5, 6, 1, 0, 0, 0],
       [7, 8, 9, 7, 8, 9, 1],
       [4, 5, 6, 1, 0, 0, 0]])

# Manual greedy decoding

NOTE: running logits iterates over the next batch in the dataset, so running this cell multiple times decodes a different batch in the dataset.

In [24]:
inv_vocab = {v: k for k, v in vocab.items()}
inv_vocab[-1] = 'UNK'
skip_symbols = ('PAD', 'SOS', 'EOS', 'UNK')

## Input and output labels

Greed: just take the highest probabilty along the last axis.

In [25]:
input_ids, out_probs = sess.run([src_ids, logits])
output_ids = out_probs.argmax(axis=-1)

output_ids.shape
src_ids

<tf.Tensor 'IteratorGetNext:0' shape=(?, 6) dtype=int64>

In [26]:
src_ids

<tf.Tensor 'IteratorGetNext:0' shape=(?, 6) dtype=int64>

## Convert labels to characters

In [27]:
def decode_ids(input_ids, output_ids):
    decoded = []
    for sample_i in range(output_ids.shape[0]):
        input_sample = input_ids[sample_i]
        output_sample = output_ids[sample_i]
        input_decoded = [inv_vocab[s] for s in input_sample]
        input_decoded = ''.join(c for c in input_decoded if c not in skip_symbols)
        output_decoded = [inv_vocab[s] for s in output_sample]
        output_decoded = ''.join(c for c in output_decoded if c not in skip_symbols)
        decoded.append((input_decoded, output_decoded))
    return decoded
 
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcac
abc ---> defdef
def ---> abcac


# Run greedy inference

In [28]:
input_ids, output_ids, logits = sess.run([src_ids, greedy_outputs.sample_id,
                                          greedy_outputs.rnn_output])
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcac
abc ---> defdef
def ---> abcac


In [29]:
logits.shape

(4, 7, 10)

# Run beam search inference

In [30]:
"""
input_ids, output_ids = sess.run([src_ids, beam_outputs.predicted_ids])
output_ids

all_decoded = []
for beam_i in range(beam_width):
    inputs = []
    all_decoded.append([])
    decoded = decode_ids(input_ids, output_ids[:,:,beam_i])
    for dec in decoded:
        all_decoded[-1].append(dec[1])
        inputs.append(dec[0])

print('\n'.join(
    '{} ---> {}'.format(inputs[i], ' / '.join(d[i] for d in all_decoded))
                        for i in range(len(inputs))
))
"""

"\ninput_ids, output_ids = sess.run([src_ids, beam_outputs.predicted_ids])\noutput_ids\n\nall_decoded = []\nfor beam_i in range(beam_width):\n    inputs = []\n    all_decoded.append([])\n    decoded = decode_ids(input_ids, output_ids[:,:,beam_i])\n    for dec in decoded:\n        all_decoded[-1].append(dec[1])\n        inputs.append(dec[0])\n\nprint('\n'.join(\n    '{} ---> {}'.format(inputs[i], ' / '.join(d[i] for d in all_decoded))\n                        for i in range(len(inputs))\n))\n"