Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

<img src="img/self_attn.png" width="200">

In [1]:
"""
pip3 install tensor2tensor
"""
import tensorflow as tf
import numpy as np
from tensor2tensor.utils import beam_search

In [2]:
# Global hyper-parameters shared by every cell below.
params = dict(
    batch_size=64,        # rows per training batch
    text_iter_step=25,    # stride (in characters) between successive batch windows
    seq_len=200,          # characters per row; also the beam-search decode length
    hidden_dim=128,       # embedding width and attention/FFN output width
    num_head=8,           # number of attention heads
    n_hidden_layer=2,     # number of (self-attention, ffn) layer pairs
    display_step=10,      # print the loss every this many steps
    generate_step=100,    # print a generated sample every this many steps
    beam_size=5,          # beam width used when generating text
)

In [3]:
def parse_text(file_path):
    """Read a text file and encode it as an array of character ids.

    Ids 0, 1 and 2 are reserved for the '<pad>', '<start>' and '<end>'
    tokens; every distinct character of the text gets an id from 3 upwards.

    Args:
        file_path: path of the text file to read.

    Returns:
        A tuple ``(ints, char2idx)`` where ``ints`` is a 1-D numpy array
        holding the id of every character of the text in order, and
        ``char2idx`` maps each character (and the three special tokens) to
        its id.
    """
    with open(file_path) as f:
        text = f.read()

    # Sort the character set before numbering it: iterating a bare set of
    # strings follows hash order, which changes between interpreter runs and
    # would make the vocabulary mapping irreproducible.
    char2idx = {c: i+3 for i, c in enumerate(sorted(set(text)))}
    char2idx['<pad>'] = 0
    char2idx['<start>'] = 1
    char2idx['<end>'] = 2

    ints = np.array([char2idx[char] for char in text])
    return ints, char2idx

def next_batch(ints):
    """Yield overlapping training batches cut from the encoded text.

    A window of ``seq_len * batch_size`` ids slides over ``ints`` in steps of
    ``text_iter_step``; each window is reshaped to ``[batch_size, seq_len]``.

    Args:
        ints: 1-D numpy array of character ids.

    Yields:
        2-D arrays of shape ``[batch_size, seq_len]``.
    """
    rows, cols = params['batch_size'], params['seq_len']
    window = cols * rows
    stride = params['text_iter_step']
    for start in range(0, len(ints) - window, stride):
        stop = start + window
        yield ints[start:stop].reshape([rows, cols])
        
def input_fn(ints):
    """Wrap ``next_batch`` in a tf.data pipeline and return its next-element tensor.

    Args:
        ints: 1-D array of character ids handed to ``next_batch``.

    Returns:
        An int32 tensor of shape ``[None, seq_len]`` that produces one batch
        per evaluation until the generator is exhausted.
    """
    def batch_generator():
        return next_batch(ints)

    batch_shape = tf.TensorShape([None, params['seq_len']])
    dataset = tf.data.Dataset.from_generator(batch_generator, tf.int32, batch_shape)
    return dataset.make_one_shot_iterator().get_next()

In [4]:
def start_sent(x):
    """Prepend a column of '<start>' ids to a 2-D id tensor ``x``."""
    n_rows = tf.shape(x)[0]
    start_col = tf.fill([n_rows, 1], params['char2idx']['<start>'])
    return tf.concat([start_col, x], axis=1)

def end_sent(x):
    """Append a column of '<end>' ids to a 2-D id tensor ``x``."""
    n_rows = tf.shape(x)[0]
    end_col = tf.fill([n_rows, 1], params['char2idx']['<end>'])
    return tf.concat([x, end_col], axis=1)

def embed_seq(x, vocab_sz, embed_dim, name, zero_pad=False, scale=False):
    """Look up embeddings for the ids in ``x``.

    Args:
        x: tensor of ids to embed.
        vocab_sz: number of rows in the embedding table.
        embed_dim: width of each embedding vector.
        name: name of the embedding variable (created or reused via
            ``tf.get_variable``).
        zero_pad: if True, row 0 of the table is replaced by zeros before the
            lookup, so id 0 always embeds to the zero vector.
        scale: if True, the looked-up vectors are multiplied by
            ``sqrt(embed_dim)``.

    Returns:
        The embedded tensor.
    """
    table = tf.get_variable(name, [vocab_sz, embed_dim])
    if zero_pad:
        zero_row = tf.zeros([1, embed_dim])
        table = tf.concat([zero_row, table[1:, :]], 0)
    embedded = tf.nn.embedding_lookup(table, x)
    return embedded * np.sqrt(embed_dim) if scale else embedded

def position_encoding(inputs):
    """Build sinusoidal position encodings matching ``inputs``.

    The static size of the last axis of ``inputs`` gives the encoding width
    and the dynamic size of axis 1 gives the number of positions. Sine values
    fill the first half of the feature axis and cosine values the second half
    (they are concatenated, not interleaved).

    Args:
        inputs: tensor whose axis 0 is the batch and axis 1 the positions;
            the last axis must have a statically known size.

    Returns:
        A float32 tensor of encodings tiled along the batch axis.
    """
    depth = inputs.get_shape()[-1].value
    n_positions = tf.to_float(tf.shape(inputs)[1])
    positions = tf.reshape(tf.range(0.0, n_positions, dtype=tf.float32), [-1, 1])

    # One frequency per pair of channels: 10000 ** (2k / depth).
    even_channels = np.arange(0, depth, 2, np.float32)
    scales = np.reshape(np.power(10000.0, even_channels / depth), [1, -1])

    angles = positions / scales
    table = tf.concat([tf.sin(angles), tf.cos(angles)], 1)
    table = tf.expand_dims(table, 0)
    return tf.tile(table, [tf.shape(inputs)[0], 1, 1])

def layer_norm(inputs, epsilon=1e-8):
    """Normalise ``inputs`` over the last axis, then apply a learned scale and shift.

    Creates (or reuses) the variables 'gamma' (initialised to ones) and
    'beta' (initialised to zeros) in the current variable scope.

    Args:
        inputs: tensor to normalise; its last axis must have a static size.
        epsilon: small constant added to the variance before the square root.

    Returns:
        ``gamma * normalised + beta``.
    """
    mu, var = tf.nn.moments(inputs, [-1], keep_dims=True)
    feature_shape = inputs.get_shape()[-1:]

    scale = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())
    shift = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())

    centred = inputs - mu
    std = tf.sqrt(var + epsilon)
    return scale * (centred / std) + shift

def self_attention(inputs, is_training, activation=None):
    """Causally masked multi-head self-attention with residual add and layer norm.

    Queries, keys and values come from one dense projection of ``inputs``.
    Heads are formed by splitting the feature axis and stacking the pieces
    on the batch axis. Scores above the diagonal are set to -inf so that a
    position only attends to itself and earlier positions.

    Args:
        inputs: 3-D tensor; axis 1 is treated as the time axis and axis 2 as
            the feature axis. It is added to the attention output, so its
            shape must be compatible with ``[..., hidden_dim]``.
        is_training: enables dropout (rate 0.1) on the attention weights.
        activation: optional activation for the joint Q/K/V projection.

    Returns:
        The layer-normalised sum of the attention output and ``inputs``.
    """
    hidden_dim = params['hidden_dim']
    n_heads = params['num_head']
    n_steps = tf.shape(inputs)[1]

    def to_heads(tensor):
        # Slice the feature axis per head and stack the slices on the batch axis.
        return tf.concat(tf.split(tensor, n_heads, axis=2), 0)

    # A single dense layer yields queries, keys and values side by side.
    projected = tf.layers.dense(inputs, 3*hidden_dim, activation)
    query, key, value = [to_heads(part) for part in tf.split(projected, 3, -1)]

    # Dot-product scores, scaled by the square root of the per-head key width.
    scores = tf.matmul(query, tf.transpose(key, [0, 2, 1]))
    scores = scores / np.sqrt(key.get_shape().as_list()[-1])

    # Causal mask: wherever the lower-triangular matrix is 0, force the score to -inf.
    neg_inf = tf.fill(tf.shape(scores), float('-inf'))
    causal = tf.linalg.LinearOperatorLowerTriangular(tf.ones([n_steps, n_steps])).to_dense()
    causal = tf.tile(tf.expand_dims(causal, 0), [tf.shape(scores)[0], 1, 1])
    scores = tf.where(tf.equal(causal, 0), neg_inf, scores)

    weights = tf.nn.softmax(scores)
    weights = tf.layers.dropout(weights, 0.1, training=is_training)
    context = tf.matmul(weights, value)

    # Undo the head stacking: batch-axis pieces go back onto the feature axis.
    context = tf.concat(tf.split(context, n_heads, axis=0), 2)
    return layer_norm(context + inputs)

def ffn(inputs, activation=tf.nn.relu):
    """Position-wise feed-forward block with residual add and layer norm.

    Two width-1 convolutions: the first expands to ``4 * hidden_dim`` filters
    with ``activation``, the second projects back to ``hidden_dim`` with no
    activation.

    Args:
        inputs: tensor fed to the first convolution and added back to the
            output of the second.
        activation: activation applied after the first convolution.

    Returns:
        The layer-normalised sum of the block output and ``inputs``.
    """
    inner_dim = 4*params['hidden_dim']
    expanded = tf.layers.conv1d(inputs, inner_dim, 1, activation=activation)
    projected = tf.layers.conv1d(expanded, params['hidden_dim'], 1, activation=None)
    return layer_norm(projected + inputs)

In [5]:
def forward(inputs, reuse, is_training):
    """Run the model on a batch of id sequences and return per-position logits.

    Pipeline: scaled embedding lookup (row 0 zeroed) + position encoding,
    dropout (rate 0.1), ``n_hidden_layer`` pairs of self-attention and
    feed-forward blocks, then a dense projection to ``vocab_size``.

    Args:
        inputs: integer id tensor to embed.
        reuse: passed to every ``tf.variable_scope``; False on the first
            call that creates the variables, True to share them afterwards.
        is_training: enables dropout when True.

    Returns:
        Logits tensor whose last axis has size ``vocab_size``.
    """
    vocab_size = params['vocab_size']
    hidden_dim = params['hidden_dim']

    with tf.variable_scope('model', reuse=reuse):
        embedded = embed_seq(
            inputs, vocab_size, hidden_dim, 'word_embedding', zero_pad=True, scale=True)
        hidden = embedded + position_encoding(embedded)
        hidden = tf.layers.dropout(hidden, 0.1, training=is_training)

        for layer_idx in range(params['n_hidden_layer']):
            with tf.variable_scope('attn_%d'%layer_idx, reuse=reuse):
                hidden = self_attention(hidden, is_training)
            with tf.variable_scope('ffn_%d'%layer_idx, reuse=reuse):
                hidden = ffn(hidden)

        return tf.layers.dense(hidden, vocab_size)

In [6]:
def beam_search_decoding():
    """Build the graph op that generates one sequence with beam search.

    Decoding starts from a single '<start>' id, reuses the trained model
    variables with dropout disabled, and runs for up to ``seq_len`` steps
    with ``beam_size`` beams and '<end>' as the end-of-sequence id.

    Returns:
        A 1-D tensor of ids taken from index ``[0, 0, :]`` of the ids
        returned by ``beam_search.beam_search``.
    """
    n_sequences = 1
    start_ids = tf.constant(params['char2idx']['<start>'], tf.int32, [n_sequences])

    def next_char_logits(ids):
        # Re-run the whole model on the prefix and keep only the logits of
        # its last position.
        all_logits = forward(ids, reuse=True, is_training=False)
        last_pos = tf.shape(ids)[1]-1
        return all_logits[:, last_pos, :]

    beams, _ = beam_search.beam_search(
        next_char_logits,
        start_ids,
        params['beam_size'],
        params['seq_len'],
        params['vocab_size'],
        0.0,
        eos_id = params['char2idx']['<end>'])

    return beams[0, 0, :]

In [None]:
# --- Data and vocabulary -------------------------------------------------
ints, params['char2idx'] = parse_text('../temp/anna.txt')
params['vocab_size'] = len(params['char2idx'])
params['idx2char'] = {idx: char for char, idx in params['char2idx'].items()}
print('Vocabulary size:', params['vocab_size'])

# --- Training graph ------------------------------------------------------
# The model reads <start> + X and is trained to predict X + <end>, i.e. the
# same sequence shifted by one position.
X = input_fn(ints)
logits = forward(start_sent(X), reuse=False, is_training=True)
targets = end_sent(X)

# Every target position carries weight 1.
loss_weights = tf.to_float(tf.ones_like(targets))
sequence_loss = tf.contrib.seq2seq.sequence_loss(
    logits = logits,
    targets = targets,
    weights = loss_weights)

ops = {'global_step': tf.Variable(0, trainable=False)}
ops['loss'] = tf.reduce_mean(sequence_loss)
ops['train'] = tf.train.AdamOptimizer().minimize(ops['loss'], global_step=ops['global_step'])

# --- Generation graph (shares the variables created above) ---------------
ops['generate'] = beam_search_decoding()

Vocabulary size: 86


In [None]:
# Train until the one-shot input iterator is exhausted, printing the loss
# every `display_step` steps and a beam-search sample every `generate_step`
# steps.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
while True:
    try:
        _, step, loss = sess.run([ops['train'], ops['global_step'], ops['loss']])
    except tf.errors.OutOfRangeError:
        # The data generator has no more batches: training is finished.
        break
    else:
        if step % params['display_step'] == 0 or step == 1:
            print("Step %d | Loss %.3f" % (step, loss))
        if step % params['generate_step'] == 0 and step > 1:
            # Keep the sample under its own name: rebinding `ints` here would
            # overwrite the encoded training corpus created by the data cell.
            generated_ids = sess.run(ops['generate'])
            print('\n'+''.join([params['idx2char'][i] for i in generated_ids])+'\n')

Step 1 | Loss 5.045
Step 10 | Loss 2.873
Step 20 | Loss 2.677
Step 30 | Loss 2.587
Step 40 | Loss 2.536
Step 50 | Loss 2.495
Step 60 | Loss 2.463
Step 70 | Loss 2.437
Step 80 | Loss 2.416
Step 90 | Loss 2.404
Step 100 | Loss 2.384

<start>in his his his his he his his his his his his his his his his his his his his he his his his his his his his his his his his his his his his his he his he his his his his his his his his he his his h<end>

Step 110 | Loss 2.367
Step 120 | Loss 2.348
Step 130 | Loss 2.322
Step 140 | Loss 2.296
Step 150 | Loss 2.271
Step 160 | Loss 2.254
Step 170 | Loss 2.234
Step 180 | Loss 2.201
Step 190 | Loss 2.194
Step 200 | Loss 2.144

<start>and this his his his his his his, his his his his his his his his his his his his, his his his his his his his his his his his his his his his his his his his his his his his hat his his his his his h

Step 210 | Loss 2.114
Step 220 | Loss 2.089
Step 230 | Loss 2.057
Step 240 | Loss 2.043
Step 250 | Loss 2.023
Step 260 | Loss

Step 1810 | Loss 0.643
Step 1820 | Loss 0.624
Step 1830 | Loss 0.628
Step 1840 | Loss 0.618
Step 1850 | Loss 0.639
Step 1860 | Loss 0.632
Step 1870 | Loss 0.646
Step 1880 | Loss 0.642
Step 1890 | Loss 0.640
Step 1900 | Loss 0.619

<start>ing," responded Oblonsky. "We'll talk it over. But what's
brought you up to town?"

"Oh, we'll talk about that, too, later on," said Levin, reddening again
up to his ears.

"All right. I seemed <end><pad><pad><pad><pad><pad>

Step 1910 | Loss 0.621
Step 1920 | Loss 0.612
Step 1930 | Loss 0.620
Step 1940 | Loss 0.606
Step 1950 | Loss 0.600
Step 1960 | Loss 0.625
Step 1970 | Loss 0.602
Step 1980 | Loss 0.606
Step 1990 | Loss 0.616
Step 2000 | Loss 0.612

<start>in them, but under the
poetical veil that shrouded them he assumed the existence of the
loftiest sentiments and every possible perfection. Why it was the three
young ladies had one day to speak French

