# In [1]:  (Jupyter notebook cell prompt left over from the export)
import tensorflow as tf
import numpy as np

class SelfAttenModel(object):
    """Bidirectional-LSTM sequence classifier with multi-row self-attention.

    The whole TensorFlow 1.x graph is built inside the constructor:

        embedding lookup -> bidirectional LSTM (state matrix ``H``)
        -> attention ``A = softmax(W_s2 . tanh(W_s1 . H^T))``, ``M = A . H``
        -> ReLU fully connected layer -> linear layer producing logits
        -> sparse softmax cross-entropy loss minimised with Adam.

    Graph handles exposed as attributes:
        sources: int64 placeholder, ``[batch_size, max_sequence_length]``.
        labels: int64 placeholder, ``[batch_size]``.
        dropout_keep_prob: scalar float32 placeholder that defaults to
            ``keep_prob``; feed ``1.0`` to switch dropout off (e.g. when
            evaluating).
        A: attention weights, ``[batch_size, r, max_sequence_length]``.
        M: attended representation, ``[batch_size, r, 2 * num_units]``.
        logits: unnormalised class scores, ``[batch_size, label_num]``.
        loss: scalar mean cross-entropy.
        accuracy: scalar fraction of correct arg-max predictions.
        train_op: the Adam minimisation op.
        optimizer: alias of ``train_op``, kept for existing callers.
    """

    def __init__(self,
                 batch_size=40,
                 vocab_size=200,
                 hidden_size=2000,
                 label_num=4,
                 layer_num=2,
                 embedding_size=100,
                 keep_prob=0.8,
                 max_sequence_length=10,
                 num_units=128,
                 d_a=350,
                 r=30, learning_rate=0.01):
        """Store the hyper-parameters and build the graph.

        Args:
            batch_size: fixed batch dimension of the input placeholders.
            vocab_size: number of rows of the embedding matrix.
            hidden_size: width of the hidden fully connected layer.
            label_num: number of output classes.
            layer_num: stored on the instance but not used by the graph.
            embedding_size: dimensionality of the word embeddings.
            keep_prob: default output keep probability of the LSTM dropout.
            max_sequence_length: fixed time dimension of ``sources``.
            num_units: LSTM units per direction.
            d_a: number of rows of ``W_s1`` (hidden size of the attention).
            r: number of rows of ``W_s2`` (attention rows per example).
            learning_rate: learning rate passed to Adam.
        """
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.label_num = label_num
        self.layer_num = layer_num
        self.embedding_size = embedding_size
        self.keep_prob = keep_prob
        self.n = self.max_sequence_length = max_sequence_length
        self.u = self.num_units = num_units
        self.d_a = d_a
        self.r = r
        self.learning_rate = learning_rate

        self._build_placeholder()
        self._build_model()
        self._build_optimizer()

    def _build_placeholder(self):
        """Create the placeholders that are fed on every ``session.run``."""
        self.sources = tf.placeholder(name='sources',
                                      shape=[self.batch_size, self.max_sequence_length],
                                      dtype=tf.int64)
        self.labels = tf.placeholder(name='labels',
                                     shape=[self.batch_size],
                                     dtype=tf.int64)
        # Falls back to the configured keep probability when nothing is fed,
        # so existing feed dicts keep working; feed 1.0 to evaluate without
        # dropout.
        self.dropout_keep_prob = tf.placeholder_with_default(
            tf.constant(self.keep_prob, dtype=tf.float32),
            shape=[],
            name='dropout_keep_prob')

    def _build_single_cell(self):
        """Return one LSTM cell of ``num_units`` units with output dropout."""
        cell = tf.contrib.rnn.BasicLSTMCell(self.num_units)
        cell = tf.nn.rnn_cell.DropoutWrapper(cell,
                                             output_keep_prob=self.dropout_keep_prob)
        return cell

    def _build_model(self):
        """Build the forward pass; sets ``self.A``, ``self.M`` and ``self.logits``."""
        # Word embedding #
        with tf.variable_scope("embedding"):
            initializer = tf.contrib.layers.xavier_initializer()
            embeddings = tf.get_variable(name="embedding_encoder",
                                         shape=[self.vocab_size, self.embedding_size],
                                         dtype=tf.float32,
                                         initializer=initializer,
                                         trainable=True)

            input_embeddings = tf.nn.embedding_lookup(params=embeddings,
                                                      ids=self.sources)

        # Bidirectional rnn #
        with tf.variable_scope("bidirectional_rnn"):
            cell_forward = self._build_single_cell()
            cell_backward = self._build_single_cell()

            # outputs is state 'H'
            outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_forward,
                                                         cell_bw=cell_backward,
                                                         inputs=input_embeddings,
                                                         dtype=tf.float32)

            # Forward and backward outputs side by side: [batch, n, 2u].
            H = tf.concat(outputs, -1)

        # Self Attention #
        with tf.variable_scope("self_attention"):
            initializer = tf.contrib.layers.xavier_initializer()
            W_s1 = tf.get_variable(name="W_s1", shape=[self.d_a, 2 * self.u],
                                   initializer=initializer)
            W_s2 = tf.get_variable(name='W_s2', shape=[self.r, self.d_a],
                                   initializer=initializer)

            # Batched form of W_s2 . tanh(W_s1 . H^T): contracting the last
            # axis of the activations with the second axis of each weight
            # matrix handles every example at once instead of looping over
            # the batch with tf.map_fn.
            hidden = tf.tanh(tf.tensordot(H, W_s1, axes=[[2], [1]]))   # [batch, n, d_a]
            scores = tf.tensordot(hidden, W_s2, axes=[[2], [1]])       # [batch, n, r]
            scores = tf.transpose(scores, perm=[0, 2, 1])              # [batch, r, n]

            # Each of the r attention rows is normalised over the n time steps.
            self.A = tf.nn.softmax(scores, axis=-1)
            self.M = tf.matmul(self.A, H)                              # [batch, r, 2u]

        # Fully connected layer #
        with tf.variable_scope("fully_connected_layer"):
            input_fc = tf.layers.flatten(self.M)
            layer_fc = tf.contrib.layers.fully_connected(inputs=input_fc,
                                                         num_outputs=self.hidden_size,
                                                         activation_fn=tf.nn.relu)

            self.logits = tf.contrib.layers.fully_connected(inputs=layer_fc,
                                                            num_outputs=self.label_num,
                                                            activation_fn=None)

    def _build_optimizer(self):
        """Build the loss, the Adam training op and the accuracy metric."""
        with tf.variable_scope("optimizer"):
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
                                                                           labels=self.labels)
            self.loss = tf.reduce_mean(cross_entropy)

            adam = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            self.train_op = adam.minimize(self.loss)
            # Historical name: callers run ``model.optimizer`` to take a
            # training step, so it stays bound to the minimisation op.
            self.optimizer = self.train_op

            correct_pred = tf.equal(tf.argmax(self.logits, -1), self.labels)
            self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
            
            
            
def main(batch_sources=None, batch_labels=None, training_steps=40000, display_step=200):
    """Train a default ``SelfAttenModel`` on one batch and print progress.

    The same batch is used for every training step and for the final
    "Testing Accuracy" line, so that figure is not a held-out score.

    Args:
        batch_sources: token ids shaped like ``model.sources``
            (``[batch_size, max_sequence_length]`` of the default model).
            When omitted together with ``batch_labels``, a random synthetic
            batch is generated so the script runs end to end.
        batch_labels: class ids shaped like ``model.labels``
            (``[batch_size]``). Must be given iff ``batch_sources`` is.
        training_steps: number of optimisation steps to run.
        display_step: loss/accuracy are printed every ``display_step`` steps
            (and additionally at step 1).

    Raises:
        ValueError: if only one of ``batch_sources`` / ``batch_labels`` is
            provided, or if ``display_step`` is not positive.
    """
    # Validate before building the graph so bad arguments fail fast.
    if (batch_sources is None) != (batch_labels is None):
        raise ValueError("batch_sources and batch_labels must be provided together")
    if display_step <= 0:
        raise ValueError("display_step must be a positive integer")

    model = SelfAttenModel()

    if batch_sources is None:
        # No data supplied: draw random ids within the model's vocabulary and
        # label ranges. This only smoke-tests the graph; it is not real data.
        batch_sources = np.random.randint(
            0, model.vocab_size,
            size=(model.batch_size, model.max_sequence_length),
            dtype=np.int64)
        batch_labels = np.random.randint(
            0, model.label_num, size=model.batch_size, dtype=np.int64)

    feed = {model.sources: batch_sources,
            model.labels: batch_labels}

    with tf.Session() as sess:

        # Run the initializer
        sess.run(tf.global_variables_initializer())

        for step in range(training_steps):
            # Run optimization op (backprop)
            sess.run(model.optimizer, feed_dict=feed)

            if (step % display_step) == 0 or step == 1:
                # Calculate batch accuracy & loss
                loss, accuracy = sess.run([model.loss, model.accuracy], feed_dict=feed)

                # The number after "Step" is step * batch_size, i.e. examples
                # consumed so far, as in the original message.
                print("Step " + str(step * model.batch_size) + ", Minibatch Loss= " +
                      "{:.6f}".format(loss) + ", Training Accuracy= " +
                      "{:.5f}".format(accuracy))

        print("Optimization Finished!")
        print("Testing Accuracy:", sess.run(model.accuracy, feed_dict=feed))
    
# Run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
    