In [1]:
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

tf.reset_default_graph()

# Create training data

In [2]:
with open('/tmp/toy_data.txt', 'w') as data_file:
    for _ in range(10000):
        data_file.write("a b c\td e f d e f\n")
        data_file.write("d e f\ta b c a b c\n")

# Vocabulary as a lookup table

In [3]:
vocab = ['PAD', 'EOS', 'SOS'] + list("aábcdef")
vocab = ['UNK', 'PAD', 'EOS', 'SOS'] + list("aábcdeéfghijklmnoóöőpqrstuúüűvwxyz-+.")

EOS = 1  # end of sentence
SOS = 2  # start of sentence (GO symbol)
table = lookup_ops.index_table_from_tensor(tf.constant(vocab), default_value=0)
vocab = {k: i for i, k in enumerate(vocab)}
vocab_size = len(vocab)

table_initializer = tf.tables_initializer()

In [4]:
batch_size = 32
lstm_size = 32

# Reading dataset

Format:

~~~
i n p u t TAB o u t p u t
i n p u t TAB o u t p u t
~~~

In [5]:
dataset = tf.contrib.data.TextLineDataset('/tmp/toy_data.txt')
dataset = dataset.map(lambda string: tf.string_split([string], delimiter='\t').values)
source = dataset.map(lambda string: string[0])
target = dataset.map(lambda string: string[1])

source = source.map(lambda string: tf.string_split([string], delimiter=' ').values)
source = source.map(lambda words: table.lookup(words))
target = target.map(lambda string: tf.string_split([string], delimiter=' ').values)
target = target.map(lambda words: table.lookup(words))

src_tgt_dataset = tf.contrib.data.Dataset.zip((source, target))
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([SOS], tgt), 0),
                      tf.concat((tgt, [EOS]), 0),)
)
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
)

# Padded batch

In [6]:
# if I set the third padding shape to tf.TensorShape([5]),
# it fails if there is no 4 character long sample in the batch
# WHY???
batched = src_tgt_dataset.padded_batch(batch_size, padded_shapes=(
    tf.TensorShape([32]), tf.TensorShape([32]), tf.TensorShape([None]),
         tf.TensorShape([]), tf.TensorShape([])))
batched_iter = batched.make_initializable_iterator()
src_ids, tgt_in_ids, tgt_out_ids, src_size, tgt_size = batched_iter.get_next()

# Encoder

In [7]:
embedding = tf.get_variable("embedding", [vocab_size, 20], dtype=tf.float32)

encoder_emb_inp = tf.nn.embedding_lookup(embedding, src_ids)
    
encoder_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, dtype=tf.float32,
                                                   sequence_length=src_size)

# Decoder

In [8]:
decoder_initial_state = encoder_state
decoder_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
decoder_emb_inp = tf.nn.embedding_lookup(embedding, tgt_in_ids)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, tgt_size)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state)
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
sample_id = outputs.sample_id
output_proj = layers_core.Dense(vocab_size, name="output_projection")
logits = output_proj(outputs.rnn_output)

# Loss

In [9]:
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tgt_out_ids, logits=logits)
target_weights = tf.sequence_mask(tgt_size, tf.shape(tgt_out_ids)[1], tf.float32)
loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(batch_size)

# Optimizer and gradient update

In [10]:
optimizer =tf.train.AdamOptimizer(0.1)
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
update = optimizer.apply_gradients(zip(gradients, params))

# Starting session

In [11]:
sess = tf.InteractiveSession()
sess.run(table_initializer)
sess.run(batched_iter.initializer)
sess.run(tf.global_variables_initializer())

# Training

In [12]:
for i in range(100):
    sess.run(update)
    l = sess.run(loss)
    if i % 10 == 9:
        print("Iteration: {}, training loss: {}".format(i+1, l))

Iteration: 10, training loss: 8.13936996459961
Iteration: 20, training loss: 2.7662346363067627
Iteration: 30, training loss: 0.8806182146072388
Iteration: 40, training loss: 0.2316165566444397
Iteration: 50, training loss: 0.10096192359924316
Iteration: 60, training loss: 0.0590992197394371
Iteration: 70, training loss: 0.042138297110795975
Iteration: 80, training loss: 0.03173879534006119
Iteration: 90, training loss: 0.024419207125902176
Iteration: 100, training loss: 0.01931704953312874


In [13]:
#s, t1, t2 = sess.run([src_ids, tgt_in_ids, tgt_out_ids])

# Manual greedy decoding

NOTE: running logits iterates over the next batch in the dataset, so running this cell multiple times decodes a different batch in the dataset.

In [14]:
inv_vocab = {v: k for k, v in vocab.items()}
inv_vocab[-1] = 'UNK'
skip_symbols = ('PAD', 'SOS', 'EOS', 'UNK')

## Input and output labels

Greed: just take the highest probabilty along the last axis.

In [15]:
input_ids, out_probs = sess.run([src_ids, logits])
output_ids = out_probs.argmax(axis=-1)

output_ids.shape

(32, 7)

## Convert labels to characters

In [16]:
def decode_ids(input_ids, output_ids):
    decoded = []
    for sample_i in range(output_ids.shape[0]):
        input_sample = input_ids[sample_i]
        output_sample = output_ids[sample_i]
        input_decoded = [inv_vocab[s] for s in input_sample]
        input_decoded = ''.join(c for c in input_decoded if c not in skip_symbols)
        output_decoded = [inv_vocab[s] for s in output_sample]
        output_decoded = ''.join(c for c in output_decoded if c not in skip_symbols)
        decoded.append((input_decoded, output_decoded))
    return decoded
 
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc


# Greedy decoding with `GreedyEmbeddingHelper`

The encoder stays the same but we need to redefine the decoder.

In [17]:
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], SOS), EOS)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state,
                                         output_layer=output_proj)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=30)

In [18]:
input_ids, output_ids = sess.run([src_ids, outputs.sample_id])
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc


# Beam search decoding

In [19]:
beam_width = 2
start_tokens = tf.fill([4], SOS)
bm_decoder_initial_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)
bm_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding,
    start_tokens=start_tokens,
    initial_state=bm_decoder_initial_state,
    beam_width=beam_width,
    output_layer=output_proj,
    end_token=EOS,
)
bm_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(bm_decoder, maximum_iterations=100)

In [20]:
input_ids, output_ids = sess.run([src_ids, bm_outputs.predicted_ids])
output_ids

all_decoded = []
for beam_i in range(beam_width):
    inputs = []
    all_decoded.append([])
    decoded = decode_ids(input_ids, output_ids[:,:,beam_i])
    for dec in decoded:
        all_decoded[-1].append(dec[1])
        inputs.append(dec[0])

print('\n'.join(
    '{} ---> {}'.format(inputs[i], ' / '.join(d[i] for d in all_decoded))
                        for i in range(len(inputs))
))

InvalidArgumentError: Input to reshape is a tensor with 2048 values, but the requested shape has 256
	 [[Node: Reshape_1 = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](tile_batch/Reshape_1, concat_2)]]

Caused by op 'Reshape_1', defined at:
  File "/usr/lib64/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib64/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2802, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-96cf957ab113>", line 11, in <module>
    end_token=EOS,
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 193, in __init__
    initial_state, self._cell.state_size)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 325, in map_structure
    structure[0], [func(*x) for x in entries])
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 325, in <listcomp>
    structure[0], [func(*x) for x in entries])
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 374, in _maybe_split_batch_beams
    return self._split_batch_beams(t, s)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 339, in _split_batch_beams
    ([self._batch_size, self._beam_width], t_shape[1:]), 0))
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 2451, in reshape
    name=name)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/judit/.virtualenvs/deep/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 2048 values, but the requested shape has 256
	 [[Node: Reshape_1 = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](tile_batch/Reshape_1, concat_2)]]
