In [1]:
import os
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

tf.reset_default_graph()

# Model parameters

Set `use_toy_data` to `True` for toy experiments. This will train the network on two unique examples.

The real dataset is morphological reinflection task: Hungarian nouns in the instrumental case.
Hungarian features both vowel harmony and assimilation.
A few examples are listed here (capitalization is added for emphasis):

| input | output | meaning | what happens |
| :-----: | :-----: | :-----: | :-----: |
| autó | autóval | with car | |
| Peti | Petiv**E**l | with Pete | vowel harmony |
| fej | fej**J**el | with head | assimilation |
| pálca | pálc**Á**val | with stick | low vowel lengthening |
| kulcs | kul**CCS**al | with key | digraph + assimilation |

This turns out to be a very easy task for a fairly small seq2seq model.

In [2]:
use_toy_data = False
LOG_DIR = 'logs'  # Tensorboard log directory

if use_toy_data:
    batch_size = 8
    embedding_dim = 5
    cell_size = 32 
    max_len = 6
else:
    batch_size = 64
    embedding_dim = 20
    cell_size = 128
    max_len = 33
    
use_attention = True
use_bidirectional_encoder = True
is_time_major = True

# Download data if necessary

The input data is expected in the following format:

~~~
i n p u t 1 TAB o u t p u t 1
i n p u t 2 TAB o u t p u t 2
~~~

Each line contains a single input-output pair separated by a TAB.
Tokens are space-separated.

In [3]:
if use_toy_data:
    input_fn = 'toy_input.txt'
    with open(input_fn, 'w') as f:
        f.write('a b c\td e f d e f\n')
        f.write('d e f\ta b c a b c\n')
else:
    DATA_DIR = '../../data/'
    input_fn = 'instrumental.full.train'
    input_fn = os.path.join(DATA_DIR, input_fn)
    if not os.path.exists(input_fn):
        import urllib
        u = urllib.request.URLopener()
        u.retrieve(
            "http://sandbox.mokk.bme.hu/~judit/resources/instrumental.full.train", input_fn)

# Load and preprocess data

In [4]:
if use_toy_data:
    vocab = ['PAD', 'UNK', 'EOS', 'SOS'] + list("abcdef")
else:
    vocab = ['PAD', 'UNK', 'EOS', 'SOS'] + list("aábcdeéfghiíjklmnoóöőpqrstuúüűvwxyz-+._")
EOS = 2  # end of sentence
SOS = 3  # start of sentence (GO symbol)
table = lookup_ops.index_table_from_tensor(tf.constant(vocab), default_value=1)
vocab = {k: i for i, k in enumerate(vocab)}
vocab_size = len(vocab)

table_initializer = tf.tables_initializer()

dataset = tf.contrib.data.TextLineDataset(input_fn)
dataset = dataset.repeat()
dataset = dataset.map(lambda string: tf.string_split([string], delimiter='\t').values)
source = dataset.map(lambda string: string[0])
target = dataset.map(lambda string: string[1])

source = source.map(lambda string: tf.string_split([string], delimiter=' ').values)
source = source.map(lambda words: table.lookup(words))
target = target.map(lambda string: tf.string_split([string], delimiter=' ').values)
target = target.map(lambda words: table.lookup(words))

src_tgt_dataset = tf.contrib.data.Dataset.zip((source, target))
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([SOS], tgt), 0),
                      tf.concat((tgt, [EOS]), 0),)
)
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
)

batched = src_tgt_dataset.padded_batch(batch_size, padded_shapes=(
    tf.TensorShape([max_len+2]), tf.TensorShape([max_len+2]), tf.TensorShape([None]),
         tf.TensorShape([]), tf.TensorShape([])))
batched_iter = batched.make_initializable_iterator()
src_ids, tgt_in_ids, tgt_out_ids, src_size, tgt_size = batched_iter.get_next()

# Create model

## Embedding

The input and output embeddings are the same.

In [5]:
with tf.variable_scope("embedding"):
    embedding = tf.get_variable("embedding", [vocab_size, embedding_dim], dtype=tf.float32)
    embedding_input = tf.nn.embedding_lookup(embedding, src_ids)
    decoder_emb_inp = tf.nn.embedding_lookup(embedding, tgt_in_ids)
    if is_time_major:
        embedding_input = tf.transpose(embedding_input, [1, 0, 2])
        decoder_emb_inp = tf.transpose(decoder_emb_inp, [1, 0, 2])

## Encoder

In [6]:
with tf.variable_scope("encoder"):
    
    if use_bidirectional_encoder:
        fw_cell = tf.nn.rnn_cell.BasicLSTMCell(cell_size)
        fw_cell = tf.contrib.rnn.DropoutWrapper(fw_cell, input_keep_prob=0.8)
        bw_cell = tf.nn.rnn_cell.BasicLSTMCell(cell_size)
        bw_cell = tf.contrib.rnn.DropoutWrapper(bw_cell, input_keep_prob=0.8)

        o, e = tf.nn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell, embedding_input, dtype='float32', sequence_length=src_size,
            time_major=is_time_major)
        encoder_outputs = tf.concat(o, -1)
        encoder_state = e
    
    else:
        fw_cell = tf.nn.rnn_cell.BasicLSTMCell(cell_size)
        fw_cell = tf.contrib.rnn.DropoutWrapper(fw_cell, input_keep_prob=0.8)
        o, e = tf.nn.dynamic_rnn(fw_cell, embedding_input, dtype='float32',
                                 sequence_length=src_size, time_major=is_time_major)
        encoder_outputs = o
        encoder_state = e
    

## Decoder

In [7]:
with tf.variable_scope("decoder", dtype="float32") as scope:
    if use_bidirectional_encoder:
        decoder_cells = []
        for i in range(2):
            decoder_cell = tf.contrib.rnn.BasicLSTMCell(cell_size)
            decoder_cell = tf.contrib.rnn.DropoutWrapper(decoder_cell, input_keep_prob=0.8)
            decoder_cells.append(decoder_cell)
        decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

        if use_attention:
            if is_time_major:
                attention_states = tf.transpose(encoder_outputs, [1, 0, 2])
            else:
                attention_states = encoder_outputs
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                cell_size, attention_states, memory_sequence_length=src_size,
                scale=True
            )
            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                decoder_cell, attention_mechanism, attention_layer_size=cell_size,
                name="attention"
            )
            if is_time_major:
                decoder_initial_state = decoder_cell.zero_state(
                    tf.shape(decoder_emb_inp)[1], tf.float32).clone(cell_state=encoder_state)
            else:
                decoder_initial_state = decoder_cell.zero_state(
                    tf.shape(decoder_emb_inp)[0], tf.float32).clone(cell_state=encoder_state)
        else:
            decoder_initial_state = encoder_state
            
    else:
        decoder_cell = tf.contrib.rnn.BasicLSTMCell(cell_size)
        decoder_initial_state = encoder_state
        
    helper = tf.contrib.seq2seq.TrainingHelper(
        decoder_emb_inp, tgt_size, time_major=is_time_major)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        decoder_cell, helper, decoder_initial_state)
    
    outputs, final, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, output_time_major=is_time_major, swap_memory=True, scope=scope)
    
    output_proj = layers_core.Dense(vocab_size, name="output_proj")
    logits = output_proj(outputs.rnn_output)
    
    

## Loss and training operations

In [8]:
with tf.variable_scope("train"):
    if is_time_major:
        logits = tf.transpose(logits, [1, 0, 2])
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tgt_out_ids, logits=logits)
        target_weights = tf.sequence_mask(tgt_size, tf.shape(logits)[1], tf.float32)
    else:
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tgt_out_ids, logits=logits)
        target_weights = tf.sequence_mask(tgt_size, tf.shape(logits)[1], tf.float32)
    loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(batch_size)
    tf.summary.scalar("loss", loss)

    learning_rate = tf.placeholder(dtype=tf.float32, name="learning_rate")
    max_global_norm = tf.placeholder(dtype=tf.float32, name="max_global_norm")
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.5)
    params = tf.trainable_variables()
    gradients = tf.gradients(loss, params)
    for grad, var in zip(gradients, params):
        tf.summary.histogram(var.op.name+'/gradient', grad)
    gradients, _ = tf.clip_by_global_norm(gradients, max_global_norm)
    for grad, var in zip(gradients, params):
        tf.summary.histogram(var.op.name+'/clipped_gradient', grad)
    update = optimizer.apply_gradients(zip(gradients, params))

## Greedy decoder for inference

In [9]:
g_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch_size], SOS), EOS)
g_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, g_helper, decoder_initial_state,
                                         output_layer=output_proj)

g_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(g_decoder, maximum_iterations=30)

# Starting session

In [10]:
sess = tf.Session()
sess.run(table_initializer)
sess.run(batched_iter.initializer)
sess.run(tf.global_variables_initializer())

merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 's2s_sandbox', 'tmp'))
writer.add_graph(sess.graph)

# Training

In [11]:
%%time

def train(epochs, logstep, lr):
    print("Running {} epochs with learning rate {}".format(epochs, lr))
    for i in range(epochs):
        _, s = sess.run([update, merged_summary], feed_dict={learning_rate: lr, max_global_norm: 5.0})
        l = sess.run(loss)
        writer.add_summary(s, i)
        if i % logstep == logstep - 1:
            print("Iter {}, learning rate {}, loss {}".format(i+1, lr, l))
            
print("Start training...")
if use_toy_data:
    train(100, 10, .5)
else:
    train(500, 50, 1)
    train(1000, 100, 0.1)

Start training...
Running 500 epochs with learning rate 1
Iter 50, learning rate 1, loss 49.17654800415039
Iter 100, learning rate 1, loss 42.914772033691406
Iter 150, learning rate 1, loss 35.850791931152344
Iter 200, learning rate 1, loss 23.160964965820312
Iter 250, learning rate 1, loss 10.841547966003418
Iter 300, learning rate 1, loss 7.734435081481934
Iter 350, learning rate 1, loss 4.6222639083862305
Iter 400, learning rate 1, loss 5.883413314819336
Iter 450, learning rate 1, loss 5.117784023284912
Iter 500, learning rate 1, loss 7.258327960968018
Running 1000 epochs with learning rate 0.1
Iter 100, learning rate 0.1, loss 1.5203056335449219
Iter 200, learning rate 0.1, loss 1.5339192152023315
Iter 300, learning rate 0.1, loss 1.2612941265106201
Iter 400, learning rate 0.1, loss 1.3094332218170166
Iter 500, learning rate 0.1, loss 1.751520037651062
Iter 600, learning rate 0.1, loss 0.835232675075531
Iter 700, learning rate 0.1, loss 0.7679431438446045
Iter 800, learning rate 0.

# Inference

In [12]:
inv_vocab = {v: k for k, v in vocab.items()}
skip_symbols = ('PAD',)

def decode_ids(input_ids, output_ids):
    decoded = []
    for sample_i in range(output_ids.shape[0]):
        input_sample = input_ids[sample_i]
        output_sample = output_ids[sample_i]
        input_decoded = [inv_vocab[s] for s in input_sample]
        input_decoded = ''.join(c for c in input_decoded if c not in skip_symbols)
        output_decoded = [inv_vocab[s] for s in output_sample]
        try:
            eos_idx = output_decoded.index('EOS')
        except ValueError:  # EOS not in list
            eos_idx = len(output_decoded)
        output_decoded = output_decoded[:eos_idx]
        output_decoded = ''.join(c for c in output_decoded if c not in skip_symbols)
        decoded.append((input_decoded, output_decoded))
    return decoded

input_ids, output_ids = sess.run([src_ids, g_outputs.sample_id])
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

fémgyapot ---> fémgyapottal
bútorvasalat ---> bútorvasalattal
vargha-gyógymód ---> vargha-gyógymóddal
energiatartalom ---> energiatartalommal
gravisz ---> gavisszel
projektdíj ---> projektdíjjal
fogvatartás ---> fogvatartással
gépésztervező ---> gépésztervezővel
béla-alkotás ---> béla-alkotással
UNKdörzspapír ---> drrzspapírral
priusz ---> priusszal
terméktámogatás ---> terméktámogatással
nirvána ---> nirvánával
zuhanórepülés ---> zuhanórepüléssel
szerep-teljesítés ---> szerep-teljesítéssel
részterhelés ---> részterheléssel
kávéháztörténet ---> kávézténettel
töretés ---> töretéssel
burgonyafánk ---> burgonyafánkkal
tengőd ---> tengővel
alapkőzet ---> alapkőzettel
palánta ---> palántával
internet-cím ---> internet-címmel
jet-ski ---> jet-skivel
zalahús ---> zalahússal
porkoncentráció ---> porkoncentrációval
aufschlág ---> aufschlággal
púderecset ---> púderecsettel
igazgató ---> igazgatóval
magángép ---> magángéppel
nyelvtehetség ---> nyelvtehetséggel
biomanipuláció ---> biomanipulációva