In [413]:
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

tf.reset_default_graph()

# Create training data

In [414]:
with open('/tmp/toy_data.txt', 'w') as data_file:
    for _ in range(1000):
        data_file.write("a b c\td e f d e f\n")
        data_file.write("d e f\ta b c a b c\n")

# Vocabulary as a lookup table

In [415]:
vocab = ['PAD', 'EOS', 'SOS'] + list("aábcdef")
vocab = ['UNK', 'PAD', 'EOS', 'SOS'] + list("aábcdeéfghijklmnoóöőpqrstuúüűvwxyz-+.")

EOS = 1  # end of sentence
SOS = 2  # start of sentence (GO symbol)
table = lookup_ops.index_table_from_tensor(tf.constant(vocab), default_value=0)
vocab = {k: i for i, k in enumerate(vocab)}
vocab_size = len(vocab)

table_initializer = tf.tables_initializer()

In [416]:
batch_size = 100
lstm_size = 256

# Reading dataset

Format:

~~~
input TAB output
input TAB output
~~~

In [417]:
#dataset = tf.contrib.data.TextLineDataset('/tmp/toy_data.txt')
dataset = tf.contrib.data.TextLineDataset('../../data/webcorp/webcorp.exploded.100k')
dataset = dataset.map(lambda string: tf.string_split([string], delimiter='\t').values)
source = dataset.map(lambda string: string[0])
target = dataset.map(lambda string: string[1])

source = source.map(lambda string: tf.string_split([string], delimiter=' ').values)
source = source.map(lambda words: table.lookup(words))
target = target.map(lambda string: tf.string_split([string], delimiter=' ').values)
target = target.map(lambda words: table.lookup(words))

src_tgt_dataset = tf.contrib.data.Dataset.zip((source, target))
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([SOS], tgt), 0),
                      tf.concat((tgt, [EOS]), 0),)
)
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
)

# Padded batch

In [418]:
# if I set the third padding shape to tf.TensorShape([5]),
# it fails if there is no 4 character long sample in the batch
# WHY???
batched = src_tgt_dataset.padded_batch(batch_size, padded_shapes=(
    tf.TensorShape([32]), tf.TensorShape([32]), tf.TensorShape([None]),
         tf.TensorShape([]), tf.TensorShape([])))
batched_iter = batched.make_initializable_iterator()
src_ids, tgt_in_ids, tgt_out_ids, src_size, tgt_size = batched_iter.get_next()

# Encoder

In [419]:
embedding = tf.get_variable("embedding", [vocab_size, 20], dtype=tf.float32)

encoder_emb_inp = tf.nn.embedding_lookup(embedding, src_ids)
    
encoder_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, dtype=tf.float32,
                                                   sequence_length=src_size)

# Decoder

In [420]:
decoder_initial_state = encoder_state
decoder_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
decoder_emb_inp = tf.nn.embedding_lookup(embedding, tgt_in_ids)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, tgt_size)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state)
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
sample_id = outputs.sample_id
output_proj = layers_core.Dense(vocab_size, name="output_projection")
logits = output_proj(outputs.rnn_output)

# Loss

In [421]:
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tgt_out_ids, logits=logits)
target_weights = tf.sequence_mask(tgt_size, tf.shape(tgt_out_ids)[1], tf.float32)
loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(batch_size)

# Optimizer and gradient update

In [422]:
optimizer =tf.train.AdamOptimizer(0.1)
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
update = optimizer.apply_gradients(zip(gradients, params))

# Starting session

In [423]:
sess = tf.InteractiveSession()
sess.run(table_initializer)
sess.run(batched_iter.initializer)
sess.run(tf.global_variables_initializer())

# Training

In [424]:
for i in range(500):
    sess.run(update)
    l = sess.run(loss)
    if i % 100 == 99:
        print("Iteration: {}, training loss: {}".format(i+1, l))

Iteration: 100, training loss: 37.884029388427734
Iteration: 200, training loss: 34.82209396362305
Iteration: 300, training loss: 35.557891845703125
Iteration: 400, training loss: 35.57179260253906
Iteration: 500, training loss: 35.73652267456055


In [425]:
#s, t1, t2 = sess.run([src_ids, tgt_in_ids, tgt_out_ids])

# Manual greedy decoding

NOTE: running logits iterates over the next batch in the dataset, so running this cell multiple times decodes a different batch in the dataset.

In [426]:
inv_vocab = {v: k for k, v in vocab.items()}
inv_vocab[-1] = 'UNK'
skip_symbols = ('PAD', 'SOS', 'EOS', 'UNK')

## Input and output labels

Greed: just take the highest probabilty along the last axis.

In [427]:
input_ids, out_probs = sess.run([src_ids, logits])
output_ids = out_probs.argmax(axis=-1)

output_ids.shape

(100, 27)

## Convert labels to characters

In [428]:
def decode_ids(input_ids, output_ids):
    decoded = []
    for sample_i in range(output_ids.shape[0]):
        input_sample = input_ids[sample_i]
        output_sample = output_ids[sample_i]
        input_decoded = [inv_vocab[s] for s in input_sample]
        input_decoded = ''.join(c for c in input_decoded if c not in skip_symbols)
        output_decoded = [inv_vocab[s] for s in output_sample]
        output_decoded = ''.join(c for c in output_decoded if c not in skip_symbols)
        decoded.append((input_decoded, output_decoded))
    return decoded
 
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

fiúpóló ---> fek+sa++
tiblábolnak ---> kekee+ar+ak
vzhiányai ---> faj+áas+y+ok
szerszámot. ---> kze++féer+rk+.
hévégén ---> fas+s+k
fantastic-álás ---> fely+t+aks+s++
cenzúrájuk ---> fs+++l++kák
normann-szaracén ---> kakt+l++á+káak++sk
felemelhetetlen ---> fel+fk+l+éel+t+et
sósavmentes ---> kz+s+l+áely+l
hadifoglyaikat ---> fat+kor+aákk++ok
lógatjátok ---> ke++++áá++k
legvitathatóbb ---> ker+fes+++áat++ka
hangtomptást ---> fat++kar+or+ás+t
eurócentre ---> klk++ks+++óa
alapgesztusának ---> kl+t+áyl+++k+s+rak
orgonaiskolát ---> klt+ryts+++r++k
vegyületettel ---> fal++l+t+t++t
reprodukálhassák ---> ker+ek+s+++áat+oá+
tudásszigetként ---> kek+ás+sáákyl+áak+
főrendszerekre ---> fe+se+++ézer+fk+re
végesbe ---> fag+sk+eő
rótehetségek ---> kt+á+salan++g+ok
anarchistáknál. ---> klyt+sak+a+k+ra++.
szolgáltatásformát ---> kzer++++áak+ás+sor+++k
ismeret.tény ---> kns+l+t+hery
nyencfalattal ---> kt++l+sfel+t+áak
megbánássá ---> feg+fe+yss+oa
biciklizőket ---> fekskeeks+é+s+rk
tejbolt ---> ke+tfer+


In [433]:
a, b, c = sess.run([src_ids, tgt_in_ids, tgt_out_ids])

In [435]:
a[:3]

array([[11, 23, 27, 37,  9, 26, 16,  9, 27, 37, 28,  9, 28, 28,  9, 18,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [16, 10, 27, 37, 31, 17, 28, 22, 18,  6,  9, 19,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [26,  4, 11, 11, 14, 19,  5, 28, 29, 18,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [437]:
b[:3]

array([[ 2, 11, 23, 27, 37,  9, 26, 16,  9, 27, 37, 28,  9, 28, 28,  9, 18,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 2, 16, 10, 27, 37, 31, 17, 28, 39, 22, 18, 39,  6,  9, 19,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 2, 26,  4, 11, 11, 14, 19,  5, 28, 29, 18,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [438]:
c[:3]

array([[11, 23, 27, 37,  9, 26, 16,  9, 27, 37, 28,  9, 28, 28,  9, 18,  1,
         0,  0,  0,  0,  0,  0,  0,  0],
       [16, 10, 27, 37, 31, 17, 28, 39, 22, 18, 39,  6,  9, 19,  1,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0],
       [26,  4, 11, 11, 14, 19,  5, 28, 29, 18,  1,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0]])

# Greedy decoding with `GreedyEmbeddingHelper`

The encoder stays the same but we need to redefine the decoder.

In [429]:
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], SOS), EOS)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state,
                                         output_layer=output_proj)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=30)

In [430]:
input_ids, output_ids = sess.run([src_ids, outputs.sample_id])
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

hittankönyvét ---> fel+fel+és+szer+i
tépázásig ---> kiker+fel+.
hangrögztésünk ---> fel+fel+és+szer+kék
szgorú ---> kiker+ke
mesealakjai ---> fel+fel+és+ség
kamaraegyüttesnek ---> kiker+fel+és+szer+kék
kódexek- ---> kiker+kék
nádalkotmányt ---> kiker+fel+és+ség
szem-hasonlóság ---> kiker+fel+és+szer+i
részszakaszok ---> kiker+fel+és+ség
konszenzusnyelv ---> kiker+fel+és+szer+i
süvegje ---> kiker+ke
vakegér ---> fel+és+ek
tündérkedik ---> kiker+fel+és+t
távolságival ---> kiker+fel+és+ek
itáliától ---> kiker+fel+.
elhzok ---> kiker+ke
élethelyzetüknél ---> kiker+fel+és+szer+ke
öslakosokat ---> kiker+fel+és+t
moralizmus-pragmatizmus ---> fel+fel+és+szer+fel+és+szer+ke
omnidirekcionális ---> fel+fel+és+szer+fel+ő+k
kilátókocsikból ---> kiker+fel+és+szer+i
cisrkedarabokat ---> fel+fel+és+szer+kék
bizottságba ---> fel+fel+és+ség
időtorzulás ---> kiker+fel+és+t
nemzeti-polgári ---> kiker+fel+és+szer+i
bundestag-választásokat ---> fel+fel+és+szer+fel+és+szer+ke
rózsa-rozetta ---> kiker+fel+és+

# Beam search decoding

In [431]:
beam_width = 2
start_tokens = tf.fill([4], SOS)
bm_decoder_initial_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)
bm_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding,
    start_tokens=start_tokens,
    initial_state=bm_decoder_initial_state,
    beam_width=beam_width,
    output_layer=output_proj,
    end_token=EOS,
)
bm_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(bm_decoder, maximum_iterations=100)

In [432]:
input_ids, output_ids = sess.run([src_ids, bm_outputs.predicted_ids])
output_ids

all_decoded = []
for beam_i in range(beam_width):
    inputs = []
    all_decoded.append([])
    decoded = decode_ids(input_ids, output_ids[:,:,beam_i])
    for dec in decoded:
        all_decoded[-1].append(dec[1])
        inputs.append(dec[0])

print('\n'.join(
    '{} ---> {}'.format(inputs[i], ' / '.join(d[i] for d in all_decoded))
                        for i in range(len(inputs))
))

InvalidArgumentError: Input to reshape is a tensor with 51200 values, but the requested shape has 2048
	 [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:0"](tile_batch/Reshape, concat_1)]]
	 [[Node: decoder_2/while/BeamSearchDecoderStep/truediv_1/_573 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_709_decoder_2/while/BeamSearchDecoderStep/truediv_1", tensor_type=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/cpu:0"](^_cloopdecoder_2/while/BeamSearchDecoderStep/Reshape_13/shape/_489)]]

Caused by op 'Reshape', defined at:
  File "/usr/lib/python3.4/runpy.py", line 170, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.4/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/kernelapp.py", line 474, in start
    ioloop.IOLoop.instance().start()
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tornado/ioloop.py", line 887, in start
    handler_func(fd_obj, events)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
    handler(stream, idents, msg)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/kernelbase.py", line 390, in execute_request
    user_expressions, allow_stdin)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/ipykernel/zmqshell.py", line 501, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-431-96cf957ab113>", line 11, in <module>
    end_token=EOS,
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 193, in __init__
    initial_state, self._cell.state_size)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/python/util/nest.py", line 325, in map_structure
    structure[0], [func(*x) for x in entries])
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/python/util/nest.py", line 325, in <listcomp>
    structure[0], [func(*x) for x in entries])
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 374, in _maybe_split_batch_beams
    return self._split_batch_beams(t, s)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py", line 339, in _split_batch_beams
    ([self._batch_size, self._beam_width], t_shape[1:]), 0))
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/python/ops/gen_array_ops.py", line 2451, in reshape
    name=name)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/mnt/store/judit/.virtualenvs/deep/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 51200 values, but the requested shape has 2048
	 [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:0"](tile_batch/Reshape, concat_1)]]
	 [[Node: decoder_2/while/BeamSearchDecoderStep/truediv_1/_573 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_709_decoder_2/while/BeamSearchDecoderStep/truediv_1", tensor_type=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/cpu:0"](^_cloopdecoder_2/while/BeamSearchDecoderStep/Reshape_13/shape/_489)]]
