In [1]:
import os
import yaml
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

# Clear the default graph so that re-running the notebook in the same kernel
# does not stack a second copy of every op on top of the previous one.
tf.reset_default_graph()

# Model parameters

Set `use_toy_data` to `True` for toy experiments. This will train the network on two unique examples.

The real dataset is a morphological reinflection task: Hungarian nouns in the instrumental case.
Hungarian features both vowel harmony and assimilation.
A few examples are listed here (capitalization is added for emphasis):

| input | output | meaning | what happens |
| :-----: | :-----: | :-----: | :-----: |
| autó | autóval | with car | |
| Peti | Petiv**E**l | with Pete | vowel harmony |
| fej | fej**J**el | with head | assimilation |
| pálca | pálc**Á**val | with stick | low vowel lengthening |
| kulcs | kul**CCS**al | with key | digraph + assimilation |

This turns out to be a very easy task for a fairly small seq2seq model.

In [2]:
# --- Paths and run mode ---------------------------------------------------
PROJECT_DIR = "../../"
use_toy_data = False
LOG_DIR = 'logs'  # Tensorboard log directory

# --- Network size ---------------------------------------------------------
# Order of the values: batch_size, embedding_dim, cell_size, max_len.
if use_toy_data:
    batch_size, embedding_dim, cell_size, max_len = 8, 5, 32, 6
else:
    batch_size, embedding_dim, cell_size, max_len = 64, 20, 128, 33

# --- Architecture switches ------------------------------------------------
use_attention = False
use_bidirectional_encoder = True
is_time_major = True

# Download data if necessary

The input data is expected in the following format:

~~~
i n p u t 1 TAB o u t p u t 1
i n p u t 2 TAB o u t p u t 2
~~~

Each line contains a single input-output pair separated by a TAB.
Tokens are space-separated.

In [3]:
if use_toy_data:
    # Two hand-written pairs in the "src TAB tgt" format described above.
    input_fn = 'toy_input.txt'
    with open(input_fn, 'w') as f:
        f.write('a b c\td e f d e f\n')
        f.write('d e f\ta b c a b c\n')
else:
    # `import urllib` alone does not load the `request` submodule, so the
    # submodule has to be imported explicitly.
    import urllib.request

    DATA_DIR = '../../data/'
    input_fn = os.path.join(DATA_DIR, 'instrumental.full.train')
    if not os.path.exists(input_fn):
        # Make sure the target directory exists before downloading into it.
        os.makedirs(DATA_DIR, exist_ok=True)
        # urlretrieve replaces the deprecated URLopener().retrieve().
        urllib.request.urlretrieve(
            "http://sandbox.mokk.bme.hu/~judit/resources/instrumental.full.train", input_fn)

# Load and preprocess data

In [4]:
class Dataset(object):
    """Symbol-level seq2seq input pipeline built on ``tf.contrib.data``.

    Reads a text file with one ``source TAB target`` pair per line, where the
    symbols of each side are separated by spaces, maps the symbols to integer
    ids and exposes padded batches through these tensors:

    * ``src_ids``     -- source ids
    * ``tgt_in_ids``  -- target ids prefixed with ``SOS`` (decoder input)
    * ``tgt_out_ids`` -- target ids followed by ``EOS`` (decoder labels)
    * ``src_size``    -- number of source symbols per example
    * ``tgt_size``    -- length of ``tgt_in_ids`` per example (target + 1)

    Args:
        fn: path of the input file.
        config: object providing ``share_vocab``, ``src_maxlen``,
            ``tgt_maxlen`` and ``batch_size``.
        src_alphabet: optional iterable of source symbols; defaults to
            ``Dataset.hu_alphabet``.
        tgt_alphabet: optional iterable of target symbols; defaults to
            ``Dataset.hu_alphabet``. Ignored when ``config.share_vocab`` is
            true, because the source vocabulary is reused in that case.
    """

    # Ids of the special symbols; they equal the positions in `constants`.
    PAD = 0
    SOS = 1
    EOS = 2
    UNK = 3
    constants = ['PAD', 'SOS', 'EOS', 'UNK']
    hu_alphabet = list("aábcdeéfghiíjklmnoóöőpqrstuúüűvwxyz-+._")

    def __init__(self, fn, config, src_alphabet=None, tgt_alphabet=None):
        self.config = config
        self.create_tables(src_alphabet, tgt_alphabet)
        self.load_and_preproc_dataset(fn)

    @staticmethod
    def _build_vocab(alphabet):
        """Return the special symbols followed by `alphabet` (or the default one)."""
        if alphabet is None:
            alphabet = Dataset.hu_alphabet
        # list() lets callers pass a plain string or any other iterable.
        return Dataset.constants + list(alphabet)

    @staticmethod
    def _build_table(vocab):
        """Create a symbol -> id lookup table; unknown symbols map to UNK."""
        return lookup_ops.index_table_from_tensor(
            tf.constant(vocab), default_value=Dataset.UNK
        )

    def create_tables(self, src_alphabet, tgt_alphabet):
        """Build the source/target vocabularies and their lookup tables.

        Sets ``src_vocab``, ``src_table``, ``tgt_vocab``, ``tgt_table``,
        ``src_vocab_size`` and ``tgt_vocab_size``.
        """
        self.src_vocab = self._build_vocab(src_alphabet)
        self.src_table = self._build_table(self.src_vocab)
        if self.config.share_vocab:
            # One vocabulary and one table serve both sides.
            self.tgt_vocab = self.src_vocab
            self.tgt_table = self.src_table
        else:
            self.tgt_vocab = self._build_vocab(tgt_alphabet)
            self.tgt_table = self._build_table(self.tgt_vocab)
        self.src_vocab_size = len(self.src_vocab)
        self.tgt_vocab_size = len(self.tgt_vocab)

    def load_and_preproc_dataset(self, fn):
        """Create the batched, endlessly repeating input pipeline for `fn`."""
        dataset = tf.contrib.data.TextLineDataset(fn)
        dataset = dataset.repeat()
        # Split every line at the TAB: element 0 is the source, 1 the target.
        dataset = dataset.map(lambda s: tf.string_split([s], delimiter='\t').values)

        src = dataset.map(lambda s: s[0])
        tgt = dataset.map(lambda s: s[1])

        # Tokenize on spaces and truncate to the configured maximum lengths.
        src = src.map(lambda s: tf.string_split([s], delimiter=' ').values)
        src = src.map(lambda s: s[:self.config.src_maxlen])
        tgt = tgt.map(lambda s: tf.string_split([s], delimiter=' ').values)
        tgt = tgt.map(lambda s: s[:self.config.tgt_maxlen])

        # Symbols -> integer ids.
        src = src.map(lambda words: self.src_table.lookup(words))
        tgt = tgt.map(lambda words: self.tgt_table.lookup(words))

        dataset = tf.contrib.data.Dataset.zip((src, tgt))
        # Decoder input is SOS + target, decoder labels are target + EOS.
        dataset = dataset.map(
            lambda src, tgt: (
                src,
                tf.concat(([Dataset.SOS], tgt), 0),
                tf.concat((tgt, [Dataset.EOS]), 0),
            )
        )
        # Record the unpadded lengths before batching pads the sequences.
        dataset = dataset.map(
            lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
        )
        # Sources and decoder inputs are padded to fixed lengths, the labels
        # only to the longest example of the batch.
        batched = dataset.padded_batch(
            self.config.batch_size,
            padded_shapes=(
                tf.TensorShape([self.config.src_maxlen]),
                tf.TensorShape([self.config.tgt_maxlen+2]),
                tf.TensorShape([None]),
                tf.TensorShape([]),
                tf.TensorShape([]),
            )
        )
        self.batched_iter = batched.make_initializable_iterator()
        s = self.batched_iter.get_next()
        self.src_ids = s[0]
        self.tgt_in_ids = s[1]
        self.tgt_out_ids = s[2]
        self.src_size = s[3]
        self.tgt_size = s[4]

    def run_initializers(self, session):
        """Initialize the lookup tables and the batch iterator in `session`."""
        session.run(tf.tables_initializer())
        session.run(self.batched_iter.initializer)

# Create model

## Embedding

The input and output embeddings are the same.

In [5]:
class Config(object):
    """Experiment configuration exposed as plain attributes.

    Every parameter of the default YAML file becomes an attribute; keyword
    arguments passed to the constructor override (or extend) those defaults.
    """

    default_fn = os.path.join(
        PROJECT_DIR, "config", "seq2seq", "default.yaml"
    )

    @staticmethod
    def load_defaults(fn=default_fn):
        """Return the parameters stored in the YAML file `fn` as a dict."""
        with open(fn) as f:
            # safe_load: bare yaml.load() without a Loader is rejected by
            # recent PyYAML versions and can construct arbitrary objects.
            # An empty file yields None, hence the fallback to {}.
            return yaml.safe_load(f) or {}

    @classmethod
    def from_yaml(cls, fn):
        """Create a Config from the YAML file `fn`, layered over the defaults."""
        # The file has to be opened first: passing the path itself to the
        # YAML parser would parse the file name, not the file contents.
        with open(fn) as f:
            params = yaml.safe_load(f) or {}
        return cls(**params)

    def __init__(self, **kwargs):
        defaults = Config.load_defaults()
        for param, val in defaults.items():
            setattr(self, param, val)
        # Explicit keyword arguments take precedence over the defaults.
        for param, val in kwargs.items():
            setattr(self, param, val)
        
# Default parameters with the sequence length limits overridden, then the
# input pipeline built from the selected input file.
config = Config(
    src_maxlen=30,
    tgt_maxlen=33,
)
dataset = Dataset(fn=input_fn, config=config)

In [6]:
with tf.variable_scope("embedding"):
    # A single embedding matrix serves both the encoder and the decoder input.
    embedding = tf.get_variable(
        name="embedding",
        shape=[dataset.src_vocab_size, embedding_dim],
        dtype=tf.float32,
    )
    embedding_input = tf.nn.embedding_lookup(params=embedding, ids=dataset.src_ids)
    decoder_emb_inp = tf.nn.embedding_lookup(params=embedding, ids=dataset.tgt_in_ids)
    if is_time_major:
        # Swap the first two axes of both embedded inputs.
        swap_first_two_axes = [1, 0, 2]
        embedding_input = tf.transpose(embedding_input, perm=swap_first_two_axes)
        decoder_emb_inp = tf.transpose(decoder_emb_inp, perm=swap_first_two_axes)

## Encoder

In [7]:
with tf.variable_scope("encoder"):
    # Both encoder variants need a forward LSTM with dropout on its inputs.
    fw_cell = tf.contrib.rnn.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(cell_size), input_keep_prob=0.8)

    if use_bidirectional_encoder:
        bw_cell = tf.contrib.rnn.DropoutWrapper(
            tf.nn.rnn_cell.BasicLSTMCell(cell_size), input_keep_prob=0.8)
        o, e = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=fw_cell,
            cell_bw=bw_cell,
            inputs=embedding_input,
            sequence_length=dataset.src_size,
            dtype='float32',
            time_major=is_time_major,
        )
        # Join the forward and backward outputs along the feature axis.
        encoder_outputs = tf.concat(o, -1)
    else:
        o, e = tf.nn.dynamic_rnn(
            cell=fw_cell,
            inputs=embedding_input,
            sequence_length=dataset.src_size,
            dtype='float32',
            time_major=is_time_major,
        )
        encoder_outputs = o

    # The final state is passed on to the decoder unchanged.
    encoder_state = e
    

## Decoder

In [8]:
# Training-time decoder: it is fed the gold decoder inputs (`decoder_emb_inp`)
# and is initialized from the encoder's final state.
with tf.variable_scope("decoder", dtype="float32") as scope:
    if use_bidirectional_encoder:
        # Two stacked LSTM layers with input dropout.
        # NOTE(review): presumably two layers are used so that the decoder state
        # has the same two-element structure as the (forward, backward) state of
        # the bidirectional encoder -- confirm.
        decoder_cells = []
        for i in range(2):
            decoder_cell = tf.contrib.rnn.BasicLSTMCell(cell_size)
            decoder_cell = tf.contrib.rnn.DropoutWrapper(decoder_cell, input_keep_prob=0.8)
            decoder_cells.append(decoder_cell)
        decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

        if use_attention:
            # The attention memory is passed batch-major, so time-major encoder
            # outputs are transposed back first.
            if is_time_major:
                attention_states = tf.transpose(encoder_outputs, [1, 0, 2])
            else:
                attention_states = encoder_outputs
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                cell_size, attention_states, memory_sequence_length=dataset.src_size,
                scale=True
            )
            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                decoder_cell, attention_mechanism, attention_layer_size=cell_size,
                name="attention"
            )
            # Start from the wrapper's zero state and replace only the wrapped
            # cell's state with the encoder state. The batch size is read from
            # whichever axis of the decoder input holds it.
            if is_time_major:
                decoder_initial_state = decoder_cell.zero_state(
                    tf.shape(decoder_emb_inp)[1], tf.float32).clone(cell_state=encoder_state)
            else:
                decoder_initial_state = decoder_cell.zero_state(
                    tf.shape(decoder_emb_inp)[0], tf.float32).clone(cell_state=encoder_state)
        else:
            decoder_initial_state = encoder_state

    else:
        # Unidirectional encoder: a single LSTM layer without dropout.
        # `use_attention` has no effect in this branch.
        decoder_cell = tf.contrib.rnn.BasicLSTMCell(cell_size)
        decoder_initial_state = encoder_state

    helper = tf.contrib.seq2seq.TrainingHelper(
        decoder_emb_inp, dataset.tgt_size, time_major=is_time_major)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        decoder_cell, helper, decoder_initial_state)

    outputs, final, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, output_time_major=is_time_major, swap_memory=True, scope=scope)

    # The projection to vocabulary logits is applied after decoding, on the
    # outputs of all time steps at once. The same layer object is handed to the
    # inference decoders as their `output_layer`.
    output_proj = layers_core.Dense(dataset.tgt_vocab_size, name="output_proj")
    logits = output_proj(outputs.rnn_output)
    
    

## Loss and training operations

In [9]:
with tf.variable_scope("train"):
    # The labels are batch-major, so time-major logits are transposed to
    # [batch, time, vocab] before the loss is computed. The loss itself is
    # identical for both layouts.
    if is_time_major:
        logits = tf.transpose(logits, [1, 0, 2])
    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=dataset.tgt_out_ids, logits=logits)
    # Mask out the padded positions so that they do not contribute to the loss.
    target_weights = tf.sequence_mask(dataset.tgt_size, tf.shape(logits)[1], tf.float32)
    # Normalize by the batch size the input pipeline actually uses
    # (config.batch_size), not by the unrelated notebook-level `batch_size`.
    loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(config.batch_size)
    tf.summary.scalar("loss", loss)

    # Both values are fed at every training step.
    learning_rate = tf.placeholder(dtype=tf.float32, name="learning_rate")
    max_global_norm = tf.placeholder(dtype=tf.float32, name="max_global_norm")
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.5)
    params = tf.trainable_variables()
    gradients = tf.gradients(loss, params)
    for grad, var in zip(gradients, params):
        # tf.gradients yields None for variables the loss does not depend on.
        if grad is not None:
            tf.summary.histogram(var.op.name+'/gradient', grad)
    gradients, _ = tf.clip_by_global_norm(gradients, max_global_norm)
    for grad, var in zip(gradients, params):
        if grad is not None:
            tf.summary.histogram(var.op.name+'/clipped_gradient', grad)
    update = optimizer.apply_gradients(zip(gradients, params))

## Greedy decoder for inference

In [10]:
with tf.variable_scope("greedy_decoder"):
    # Inference decoder: starts every sequence with SOS and feeds the embedding
    # of the highest-scoring symbol back as the next input until EOS.
    g_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding, tf.fill([dataset.config.batch_size], dataset.SOS), dataset.EOS)
    # Reuses the training decoder's cell, initial state and output projection.
    g_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, g_helper, decoder_initial_state,
                                             output_layer=output_proj)

    # Cap the output length at the configured target length instead of a
    # hard-coded 30, which was shorter than the longest allowed target.
    g_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        g_decoder, maximum_iterations=dataset.config.tgt_maxlen)

## Beam search decoder

In [11]:
with tf.variable_scope("beam_search"):
    beam_width = 4
    # One SOS start symbol per example of the batch.
    start_tokens = tf.fill(dims=[config.batch_size], value=dataset.SOS)
    # The encoder state is tiled `beam_width` times to serve as the initial
    # state of the beam search decoder.
    bm_dec_initial_state = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
    bm_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=decoder_cell,
        embedding=embedding,
        start_tokens=start_tokens,
        end_token=dataset.EOS,
        initial_state=bm_dec_initial_state,
        beam_width=beam_width,
        output_layer=output_proj,
    )
    bm_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=bm_decoder,
        maximum_iterations=config.tgt_maxlen,
    )

# Starting session

In [None]:
# CPU-only alternative:
#sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
sess = tf.Session()

# Lookup tables and the batch iterator first, then the model variables.
dataset.run_initializers(sess)
sess.run(tf.global_variables_initializer())

# Collect every summary op defined so far and write them, together with the
# graph, to the Tensorboard log directory.
merged_summary = tf.summary.merge_all()
summary_dir = os.path.join(LOG_DIR, 's2s_sandbox', 'tmp')
writer = tf.summary.FileWriter(summary_dir)
writer.add_graph(sess.graph)

# Training

In [None]:
%%time

def train(epochs, logstep, lr):
    """Run `epochs` training steps with a fixed learning rate.

    Despite the parameter name, one "epoch" here is a single update on one
    batch, not a pass over the whole dataset.

    Args:
        epochs: number of update steps to run.
        logstep: print the loss after every `logstep` steps.
        lr: learning rate fed to the optimizer.
    """
    print("Running {} epochs with learning rate {}".format(epochs, lr))
    feed = {learning_rate: lr, max_global_norm: 5.0}
    # Continue the summary step count where the previous call stopped, so that
    # successive training phases do not overwrite each other in Tensorboard.
    first_step = getattr(train, "global_step", 0)
    for i in range(epochs):
        # Fetch the loss in the same run as the update: a separate
        # sess.run(loss) would pull a fresh batch from the iterator, reporting
        # the loss of a batch that was never trained on and skipping it.
        _, s, l = sess.run([update, merged_summary, loss], feed_dict=feed)
        writer.add_summary(s, first_step + i)
        if i % logstep == logstep - 1:
            print("Iter {}, learning rate {}, loss {}".format(i+1, lr, l))
    train.global_step = first_step + epochs
            
print("Start training...")
# Each entry is (number of steps, logging interval, learning rate); the
# phases are run in the listed order.
if use_toy_data:
    schedule = [(100, 10, .5)]
else:
    schedule = [
        (350, 50, 1),
        (1000, 100, 0.1),
        (1000, 100, 0.01),
    ]
for n_steps, log_every, rate in schedule:
    train(n_steps, log_every, rate)

Start training...
Running 350 epochs with learning rate 1
Iter 50, learning rate 1, loss 45.26033020019531
Iter 100, learning rate 1, loss 34.81024169921875
Iter 150, learning rate 1, loss 31.522985458374023
Iter 200, learning rate 1, loss 23.517223358154297
Iter 250, learning rate 1, loss 18.17040252685547
Iter 300, learning rate 1, loss 16.15471839904785
Iter 350, learning rate 1, loss 13.37446403503418
Running 1000 epochs with learning rate 0.1
Iter 100, learning rate 0.1, loss 9.242836952209473
Iter 200, learning rate 0.1, loss 7.943110466003418
Iter 300, learning rate 0.1, loss 8.02791690826416
Iter 400, learning rate 0.1, loss 7.113648414611816
Iter 500, learning rate 0.1, loss 7.066655158996582
Iter 600, learning rate 0.1, loss 6.1783928871154785


# Inference

In [None]:
# id -> symbol mapping of the target vocabulary.
inv_vocab = dict(enumerate(dataset.tgt_vocab))
# NOTE(review): id -1 is presumably the filler that beam search emits after a
# hypothesis has ended -- confirm.
inv_vocab[-1] = 'UNK'
# Symbols that are dropped from the decoded strings.
skip_symbols = ('PAD',)

def decode_ids(input_ids, output_ids, vocab=None, skip=None):
    """Turn batches of symbol ids back into strings.

    Args:
        input_ids: 2-D sequence (one row per example) of source ids.
        output_ids: 2-D sequence (one row per example) of predicted ids.
        vocab: mapping from id to symbol; defaults to the notebook-level
            `inv_vocab`.
        skip: symbols that are left out of the strings; defaults to the
            notebook-level `skip_symbols`.

    Returns:
        A list of `(input_string, output_string)` tuples, one per example.
        The output is cut at the first 'EOS' symbol if there is one.
    """
    if vocab is None:
        vocab = inv_vocab
    if skip is None:
        skip = skip_symbols

    def to_text(symbols):
        return ''.join(c for c in symbols if c not in skip)

    decoded = []
    # zip() walks the rows pairwise, so plain nested lists work as well as arrays.
    for input_sample, output_sample in zip(input_ids, output_ids):
        input_symbols = [vocab[s] for s in input_sample]
        output_symbols = [vocab[s] for s in output_sample]
        # Keep only what precedes the first end-of-sequence symbol.
        if 'EOS' in output_symbols:
            output_symbols = output_symbols[:output_symbols.index('EOS')]
        decoded.append((to_text(input_symbols), to_text(output_symbols)))
    return decoded

# All three tensors are fetched in one run so that the greedy and the beam
# search outputs belong to the same input batch.
input_ids, output_ids, bm_output_ids = sess.run(
    [dataset.src_ids, g_outputs.sample_id, bm_outputs.predicted_ids]
)
decoded = decode_ids(input_ids, output_ids)
greedy_lines = ['{} ---> {}'.format(*pair) for pair in decoded]
print('\n'.join(greedy_lines))

## Beam search decoding

In [None]:
# all_decoded[b][i] is the output string of beam b for example i.
all_decoded = []
for beam_i in range(beam_width):
    # The last axis of the beam search output selects the beam.
    decoded = decode_ids(input_ids, bm_output_ids[:, :, beam_i])
    inputs = [source for source, _ in decoded]
    all_decoded.append([hypothesis for _, hypothesis in decoded])

# One line per example: the input followed by the outputs of all beams.
beam_lines = [
    '{} ---> {}'.format(source, ' / '.join(hypotheses))
    for source, hypotheses in zip(inputs, zip(*all_decoded))
]
print('\n'.join(beam_lines))