![Self-attention](img/self_attn.png)

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Hyper-parameters shared by every cell below. Later cells also store the
# vocabulary here ('char2idx', 'idx2char', 'vocab_size').
params = dict(
    batch_size=64,         # rows per training batch
    text_iter_step=25,     # stride (in characters) between successive batches
    seq_len=200,           # characters per row
    hidden_dim=128,        # embedding / model width
    num_head=8,            # attention heads per self-attention layer
    n_hidden_layer=2,      # number of (self-attention, ffn) layer pairs
    display_step=10,       # print the loss every N steps
    generate_step=100,     # print a generated sample every N steps
)

In [3]:
def parse_text(file_path):
    """Read a text file and encode it as a sequence of character ids.

    Ids 0, 1 and 2 are reserved for ``<pad>``, ``<start>`` and ``<end>``;
    every distinct character in the file gets an id from 3 upward.

    Args:
        file_path: path of the text file to read.

    Returns:
        (ints, char2idx): a 1-D numpy array with one id per character of the
        file, and the character -> id mapping (special tokens included).
    """
    with open(file_path) as f:
        text = f.read()

    # Sort the characters so the vocabulary (and therefore every id) is the
    # same on every run. Iterating a bare set() of str depends on hash
    # randomisation, which reshuffled the ids between interpreter sessions.
    char2idx = {c: i + 3 for i, c in enumerate(sorted(set(text)))}
    char2idx['<pad>'] = 0
    char2idx['<start>'] = 1
    char2idx['<end>'] = 2

    ints = np.array([char2idx[c] for c in text])
    return ints, char2idx

def next_batch(ints, *, batch_size=None, seq_len=None, text_iter_step=None):
    """Yield overlapping (batch_size, seq_len) windows over an id sequence.

    Each batch is a contiguous slice of ``batch_size * seq_len`` ids reshaped
    row-major, so every row is itself a contiguous run of ``seq_len`` ids.
    Successive batches start ``text_iter_step`` ids apart.

    Args:
        ints: 1-D numpy array of token ids.
        batch_size, seq_len, text_iter_step: optional overrides; when omitted
            the matching entries of the global ``params`` dict are used.

    Yields:
        numpy arrays of shape (batch_size, seq_len). Nothing is yielded when
        ``ints`` is shorter than one full batch window.
    """
    if batch_size is None:
        batch_size = params['batch_size']
    if seq_len is None:
        seq_len = params['seq_len']
    if text_iter_step is None:
        text_iter_step = params['text_iter_step']

    len_win = seq_len * batch_size
    # "+ 1" so the last window, which ends exactly at len(ints), is included.
    # Without it, an input of exactly one window yielded no batch at all.
    for start in range(0, len(ints) - len_win + 1, text_iter_step):
        clip = ints[start:start + len_win]
        yield clip.reshape([batch_size, seq_len])
        
def input_fn(ints):
    """Build a one-shot tf.data pipeline over the id array ``ints``.

    Batches are produced by ``next_batch``. The returned tensor is the
    iterator's next element: int32 with static shape (None, params['seq_len']).
    Once the generator is exhausted, running it raises OutOfRangeError.
    """
    def _batches():
        return next_batch(ints)

    batch_shape = tf.TensorShape([None, params['seq_len']])
    dataset = tf.data.Dataset.from_generator(_batches, tf.int32, batch_shape)
    return dataset.make_one_shot_iterator().get_next()

In [4]:
def start_sent(x):
    """Prepend a ``<start>`` id column to every row of the 2-D id tensor ``x``."""
    n_rows = tf.shape(x)[0]
    start_col = tf.fill([n_rows, 1], params['char2idx']['<start>'])
    return tf.concat([start_col, x], 1)

def end_sent(x):
    """Append an ``<end>`` id column to every row of the 2-D id tensor ``x``."""
    n_rows = tf.shape(x)[0]
    end_col = tf.fill([n_rows, 1], params['char2idx']['<end>'])
    return tf.concat([x, end_col], 1)

def embed_seq(x, vocab_sz, embed_dim, name, zero_pad=False, scale=False):
    """Embed the id tensor ``x`` with a trainable (vocab_sz, embed_dim) table.

    Args:
        x: integer id tensor.
        vocab_sz, embed_dim: shape of the table variable ``name``.
        name: variable name passed to ``tf.get_variable``.
        zero_pad: if True, id 0 always maps to a constant zero vector.
        scale: if True, multiply the embeddings by sqrt(embed_dim).

    Returns:
        Tensor of shape ``x.shape + (embed_dim,)``.
    """
    table = tf.get_variable(name, [vocab_sz, embed_dim])
    if zero_pad:
        # Swap row 0 for a non-trainable zero row; the remaining rows are kept.
        zero_row = tf.zeros([1, embed_dim])
        table = tf.concat([zero_row, table[1:, :]], 0)
    embedded = tf.nn.embedding_lookup(table, x)
    return embedded * np.sqrt(embed_dim) if scale else embedded

def position_embedding(inputs):
    """Return a learned embedding of each position 0..T-1 for every batch row.

    ``inputs`` must have a statically known axis 1 (T), because T sizes the
    'position_embedding' table. The result has shape (N, T, hidden_dim).
    """
    seq_len = inputs.get_shape().as_list()[1]
    n_rows = tf.shape(inputs)[0]
    positions = tf.tile(tf.expand_dims(tf.range(seq_len), 0), [n_rows, 1])  # (N, T)
    return embed_seq(positions, seq_len, params['hidden_dim'], 'position_embedding')

def layer_norm(inputs, epsilon=1e-8):
    """Normalise ``inputs`` over its last axis, then apply a learned scale/shift.

    Creates variables named 'gamma' (init 1) and 'beta' (init 0) with
    ``tf.get_variable``. Each call therefore needs its own variable scope.
    """
    feature_shape = inputs.get_shape()[-1:]
    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
    normalized = (inputs - mean) / tf.sqrt(variance + epsilon)
    gamma = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())
    beta = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())
    return gamma * normalized + beta

def self_attention(inputs, is_training, activation=None):
    """Causal multi-head self-attention sublayer with residual + layer norm.

    Args:
        inputs: 3-D float tensor (N, T, params['hidden_dim']). T must be known
            statically because it sizes the causal mask. The last dim must
            equal hidden_dim for the residual add below.
        is_training: enables dropout (rate 0.1) on the attention weights.
        activation: optional activation applied to the fused Q/K/V projection.

    Returns:
        Tensor with the same shape as ``inputs``.

    Creates a dense layer plus layer_norm's 'gamma'/'beta' variables, so it
    must be called inside its own variable scope.
    """
    num_units = params['hidden_dim']
    num_heads = params['num_head']
    T_q = T_k = inputs.get_shape().as_list()[1]

    # One dense layer produces Q, K and V together, then split into three.
    Q_K_V = tf.layers.dense(inputs, 3*num_units, activation)
    Q, K, V = tf.split(Q_K_V, 3, -1)
    # Split features into heads and stack the heads along the batch axis:
    # (N, T, num_units) -> (num_heads*N, T, num_units/num_heads).
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), 0)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), 0)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), 0)

    # Scaled dot-product scores, (num_heads*N, T_q, T_k), divided by sqrt(d_head).
    align = tf.matmul(Q_, tf.transpose(K_, [0,2,1]))
    align = align / np.sqrt(K_.get_shape().as_list()[-1])

    # Causal mask: query i may only attend to keys j <= i. Entries above the
    # diagonal become -inf, so softmax assigns them zero weight. The diagonal
    # is never masked, so no row is all -inf (which would produce NaN).
    paddings = tf.fill(tf.shape(align), float('-inf'))
    lower_tri = tf.ones([T_q, T_k])
    lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()
    # tf.where here needs operands of identical shape, hence the explicit tile.
    masks = tf.tile(tf.expand_dims(lower_tri,0), [tf.shape(align)[0],1,1])
    align = tf.where(tf.equal(masks, 0), paddings, align)

    align = tf.nn.softmax(align)
    align = tf.layers.dropout(align, 0.1, training=is_training)
    x = tf.matmul(align, V_)
    # Undo the head stacking: (num_heads*N, T, d_head) -> (N, T, num_units).
    # NOTE(review): there is no output projection after merging heads, unlike
    # the standard Transformer; adding one would change the checkpoint layout.
    x = tf.concat(tf.split(x, num_heads, axis=0), 2)
    x += inputs
    x = layer_norm(x)
    return x

def ffn(inputs, activation=tf.nn.relu):
    """Position-wise feed-forward sublayer with residual + layer norm.

    Two kernel-size-1 conv1d layers widen to 4*hidden_dim (with
    ``activation``) and project back to hidden_dim. The result is added to
    ``inputs``, whose last dim must therefore be hidden_dim, and normalised.
    Must be called inside its own variable scope (layer_norm's variables).
    """
    width = params['hidden_dim']
    widened = tf.layers.conv1d(inputs, 4 * width, 1, activation=activation)
    projected = tf.layers.conv1d(widened, width, 1, activation=None)
    return layer_norm(projected + inputs)

In [5]:
def forward(inputs, reuse, is_training):
    """Run the decoder-only Transformer over a batch of id sequences.

    A ``<start>`` id is prepended to each row, so for T input ids the output
    has T+1 positions. Position p predicts the id that follows the first p
    input ids.

    Args:
        inputs: 2-D int id tensor (N, T); T is presumably static, since
            position_embedding needs it. TODO confirm for new callers.
        reuse: passed to the 'model' variable scope (True to share weights).
        is_training: enables dropout.

    Returns:
        Logits of shape (N, T+1, params['vocab_size']).
    """
    ids = start_sent(inputs)
    with tf.variable_scope('model', reuse=reuse):
        h = embed_seq(ids, params['vocab_size'], params['hidden_dim'], 'word_embedding',
                      zero_pad=True, scale=True)
        h += position_embedding(h)
        h = tf.layers.dropout(h, 0.1, training=is_training)

        for layer in range(params['n_hidden_layer']):
            with tf.variable_scope('attn_%d' % layer, reuse=reuse):
                h = self_attention(h, is_training)
            with tf.variable_scope('ffn_%d' % layer, reuse=reuse):
                h = ffn(h)

        return tf.layers.dense(h, params['vocab_size'])

In [6]:
def autoregressive():
    """Build a graph that greedily generates ``params['seq_len']`` ids.

    Inside a ``tf.while_loop`` the model is re-run once per position, with
    weights shared via ``reuse=True`` and dropout off. The argmax prediction
    at the current position becomes the next id. The model variables must
    already exist, i.e. ``forward(..., reuse=False)`` was called first.

    Returns:
        1-D int32 tensor of length ``params['seq_len']`` with the generated ids.
    """
    seq_len = params['seq_len']

    def _keep_going(step, window, history):
        return step < seq_len

    def _generate_one(step, window, history):
        logits = forward(window, reuse=True, is_training=False)
        # forward() prepends <start>, so logits position `step` is the
        # prediction for the id that follows window[:, :step].
        next_id = tf.expand_dims(tf.argmax(logits, -1, output_type=tf.int32)[:, step], -1)

        # history holds the last seq_len ids: drop the oldest, append the newest.
        history = tf.concat([history[:, 1:], next_id], -1)

        # Rotate the step+1 generated ids to the front:
        # window = [g_0, ..., g_step, 0, ..., 0]. The trailing zeros are never
        # seen by the position read next step, because of the causal mask.
        n_done = step + 1
        window = tf.concat([history[:, -n_done:], history[:, :-n_done]], -1)
        window = tf.reshape(window, [1, seq_len])
        return n_done, window, history

    blank = tf.zeros([1, seq_len], tf.int32)
    _, generated, _ = tf.while_loop(_keep_going, _generate_one,
                                    [tf.constant(0), blank, blank])
    return generated[0]

In [None]:
# Vocabulary: char <-> id maps and size, stored in the shared params dict.
ints, params['char2idx'] = parse_text('../temp/anna.txt')
params['vocab_size'] = len(params['char2idx'])
params['idx2char'] = {idx: ch for ch, idx in params['char2idx'].items()}
print('Vocabulary size:', params['vocab_size'])

# Training graph. This call creates the model variables; later calls reuse them.
X = input_fn(ints)
logits = forward(X, reuse=False, is_training=True)

ops = {'global_step': tf.Variable(0, trainable=False)}

# Targets are the inputs followed by <end>. This matches the T+1 positions
# forward() produces after prepending <start>.
targets = end_sent(X)
uniform_weights = tf.ones_like(targets, dtype=tf.float32)
ops['loss'] = tf.reduce_mean(tf.contrib.seq2seq.sequence_loss(
    logits=logits, targets=targets, weights=uniform_weights))
ops['train'] = tf.train.AdamOptimizer().minimize(
    ops['loss'], global_step=ops['global_step'])

# Greedy-generation graph that shares the trained weights.
ops['generate'] = autoregressive()

Vocabulary size: 86
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead


In [None]:
# Train until the one-shot input iterator is exhausted, periodically logging
# the loss and printing a greedily generated sample.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
while True:
    try:
        # Fetch the update and the loss in ONE run so both use the same batch.
        # A separate sess.run(ops['loss']) pulls a fresh batch from the
        # one-shot iterator. The printed loss would then belong to an untrained
        # batch, a training batch would be skipped at every display step, and
        # an exhausted iterator there would raise OutOfRangeError outside this
        # try.
        _, loss = sess.run([ops['train'], ops['loss']])
    except tf.errors.OutOfRangeError:
        break
    # global_step does not touch the input pipeline, so reading it separately
    # is safe. It returns the value after this step's increment.
    step = sess.run(ops['global_step'])
    if step % params['display_step'] == 0 or step == 1:
        print("Step %d | Loss %.3f" % (step, loss))
    if step % params['generate_step'] == 0 and step > 1:
        # Use a dedicated name so the encoded corpus in `ints` is not clobbered.
        sample_ids = sess.run(ops['generate'])
        print('\n' + ''.join(params['idx2char'][i] for i in sample_ids) + '\n')

Step 1 | Loss 3.928
Step 10 | Loss 2.773
Step 20 | Loss 2.598
Step 30 | Loss 2.524
Step 40 | Loss 2.486
Step 50 | Loss 2.462
Step 60 | Loss 2.435
Step 70 | Loss 2.416
Step 80 | Loss 2.395
Step 90 | Loss 2.384
Step 100 | Loss 2.366

and the he he he he he he he he he he he he he he he he he he han he he he he he he he he han he he he he he he he he he he han he he he he he he he he he he he he he he he he he he he he he he han he

Step 110 | Loss 2.348
Step 120 | Loss 2.334
Step 130 | Loss 2.311
Step 140 | Loss 2.293
Step 150 | Loss 2.271
Step 160 | Loss 2.254
Step 170 | Loss 2.242
Step 180 | Loss 2.213
Step 190 | Loss 2.191
Step 200 | Loss 2.174

 and the the the the is he he he he se se se he se sean the the se the cofon he han he he he han hathe he he he he hand he se he he he se se se hand se se the se he suse se thatheathe se hathe suse th

Step 210 | Loss 2.161
Step 220 | Loss 2.151
Step 230 | Loss 2.119
Step 240 | Loss 2.103
Step 250 | Loss 2.096
Step 260 | Loss 2.070
Step 270 | 

Step 1940 | Loss 0.645
Step 1950 | Loss 0.656
Step 1960 | Loss 0.660
Step 1970 | Loss 0.642
Step 1980 | Loss 0.649
Step 1990 | Loss 0.657
Step 2000 | Loss 0.647

 as as hart at you would dordis, wifith anded
adis that he und noto Oblonsky, whe mar ther
way proferour hat ad atter the tatired ap turon a phile int his my to
whim. Buter the connection wed he that 

Step 2010 | Loss 0.654
Step 2020 | Loss 0.632
Step 2030 | Loss 0.629
Step 2040 | Loss 0.656
Step 2050 | Loss 0.666
Step 2060 | Loss 0.653
Step 2070 | Loss 0.657
Step 2080 | Loss 0.638
Step 2090 | Loss 0.632
Step 2100 | Loss 0.647

 that he his
yould mer un they questin dis hner ot begice him.

"I cant sanot to the de nof asent o meent Levin went he to profent she prort, and she went
back to his argument her, and the tin own shi

Step 2110 | Loss 0.651
Step 2120 | Loss 0.643
Step 2130 | Loss 0.651
Step 2140 | Loss 0.650
Step 2150 | Loss 0.642
Step 2160 | Loss 0.649
Step 2170 | Loss 0.653
Step 2180 | Loss 0.627
Step 2190 | Loss 0.

Step 3830 | Loss 0.479
Step 3840 | Loss 0.483
Step 3850 | Loss 0.480
Step 3860 | Loss 0.479
Step 3870 | Loss 0.488
Step 3880 | Loss 0.451
Step 3890 | Loss 0.454
Step 3900 | Loss 0.454

 the at of her other make alone, and their disturbed faces. Levin
bowed to her, and said nothing. Kitty did not speak nor lift her eyes.
"Thank God, she has refused him," thought the mother, and her f

Step 3910 | Loss 0.438
Step 3920 | Loss 0.441
Step 3930 | Loss 0.452
Step 3940 | Loss 0.447
Step 3950 | Loss 0.474
Step 3960 | Loss 0.463
Step 3970 | Loss 0.467
Step 3980 | Loss 0.466
Step 3990 | Loss 0.455
Step 4000 | Loss 0.443

 as surely in and cannot even be offended by each
other.

The Countess Nordston pounced upon Levin at once.

"Ah, Konstantin Dmitrievitch! So you've come back to our corrupt
Babylon," she said, giving

Step 4010 | Loss 0.437
Step 4020 | Loss 0.448
Step 4030 | Loss 0.459
Step 4040 | Loss 0.439
Step 4050 | Loss 0.454
Step 4060 | Loss 0.449
Step 4070 | Loss 0.462
Step 4080 | Loss 0.