In [1]:
import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
class TransformerNetwork(object):
    '''
    Stripped-down Transformer network that scores a (query, passage) pair.

    In MSAIC 2018 we have to select the proper paragraph with respect to the query passed. The
    idea is to attend to the important elements in the query and the passage, measure how well
    they match and then decide which passage is the appropriate one. The query is encoded with
    self attention; the passage is encoded with self attention followed by attention over the
    encoded query, so the query representation of every stack is fed into the passage stack.

    Labels can optionally be smoothed (see `ls_epsilon`) to soften the effect of the
    disproportionate number of negative samples.

    ========

    [GOTO: https://stackoverflow.com/a/35688187]

    The embedding matrix is external (pre-trained). It is passed to build_model() and baked into
    the graph as a constant, so the regular TF embedding lookup can be used on it.
    '''
    def __init__(self,
        scope,
        model_name,
        save_folder,
        pad_id,
        save_freq = 10,
        is_training = True,
        dim_model = 50,
        ff_mid = 128,
        ff_mid1 = 128,
        ff_mid2 = 128,
        num_stacks = 2,
        num_heads = 5,
        ls_epsilon = 0.0):
        '''
        Args:
            scope: scope of graph
            model_name: name for model
            save_folder: folder for model saves
            pad_id: integer of <PAD>
            save_freq: frequency of saving
            is_training: bool if network is in training mode
            dim_model: same as embedding dimension, must be divisible by num_heads
            ff_mid: dimension in middle layer of inner feed forward network
            ff_mid1: dimension in middle layer of outer feed forward network (L1)
            ff_mid2: dimension in middle layer of outer feed forward network (L2),
                stored but not used by build_model()
            num_stacks: number of stacks to use
            num_heads: number of heads in SDPA
            ls_epsilon: label smoothing strength, 0.0 (default) disables smoothing

        '''
        self.scope = scope
        self.model_name = model_name
        self.is_training = is_training
        self.save_folder = save_folder
        self.save_freq = save_freq
        self.pad_id = pad_id

        self.num_stacks = num_stacks
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.ff_mid = ff_mid
        self.ff_mid1 = ff_mid1
        self.ff_mid2 = ff_mid2
        self.ls_epsilon = ls_epsilon

        self.global_step = 0


    def build_model(self, emb, seqlen, batch_size = 32, print_stack = False):
        '''
        function to build the model end to end

        Creates the placeholders, the query/passage stacks, the scoring head, the loss,
        the accuracy, the train op and the merged summary as attributes of this object.

        Args:
            emb: embedding matrix (array-like, [vocab_size, dim_model]), stored as a constant
            seqlen: fixed length of both the query and the passage sequences
            batch_size: fixed batch size of the placeholders
            print_stack: if True print the intermediate tensors while building
        Raises:
            ValueError: if dim_model cannot be split evenly across num_heads
        '''
        # fail early with a readable message instead of deep inside tf.split
        if self.num_heads <= 0 or self.dim_model % self.num_heads != 0:
            raise ValueError(
                'dim_model ({0}) must be divisible by num_heads ({1})'.format(
                    self.dim_model, self.num_heads))

        self.batch_size = batch_size
        self.seqlen = seqlen
        self.print_stack = print_stack

        with tf.variable_scope(self.scope):
            # declaring the placeholders
            self.query_input = tf.placeholder(tf.int32, [self.batch_size, self.seqlen], name = 'query_placeholder')
            self.passage_input = tf.placeholder(tf.int32, [self.batch_size, self.seqlen], name = 'passage_placeholder')
            self.target_input = tf.placeholder(tf.float32, [self.batch_size, 1], name = 'target_placeholder')

            # embedding matrix, frozen into the graph as a constant (not trainable)
            self.embedding_matrix = tf.constant(emb, name = 'embedding_matrix', dtype = tf.float32)

            if self.print_stack:
                print('[!] Building model...')
                print('[*] self.query_input:', self.query_input)
                print('[*] self.passage_input:', self.passage_input)
                print('[*] self.target_input:', self.target_input)
                print('[*] embedding_matrix:', self.embedding_matrix)

            # now we need to add the padding in the computation graph
            # masking
            query_mask = self.construct_padding_mask(self.query_input)
            passage_mask = self.construct_padding_mask(self.passage_input)

            if self.print_stack:
                print('[*] query_mask:', query_mask)
                print('[*] passage_mask:', passage_mask)

            # lookup from embedding matrix
            query_emb = self.get_embedding(self.embedding_matrix, self.query_input)
            passage_emb = self.get_embedding(self.embedding_matrix, self.passage_input)

            if self.print_stack:
                print('[*] query_emb:', query_emb)
                print('[*] passage_emb:', passage_emb)

            # perform label smoothening on the labels (only when requested)
            if self.ls_epsilon > 0:
                label_smooth = self.label_smoothning(self.target_input)
            else:
                label_smooth = self.target_input

            # model
            q_out = query_emb
            p_out = passage_emb
            for i in range(self.num_stacks):
                q_out = self.query_stack(q_in = q_out, mask = query_mask, scope = 'q_stk_{0}'.format(i))
                if self.print_stack:
                    print('[*] q_out ({0}):'.format(i), q_out)
                p_out = self.passage_stack(p_in = p_out, q_out = q_out,
                    query_mask = query_mask, passage_mask = passage_mask, scope = 'p_stk_{0}'.format(i))
                if self.print_stack:
                    print('[*] p_out ({0})'.format(i), p_out)

            # now the custom part
            ff_out = tf.layers.dense(p_out, self.ff_mid1, activation = tf.nn.relu) # (batch_size, seqlen, ff_mid1)
            ff_out = tf.layers.dense(ff_out, 1, activation = tf.nn.relu) # (batch_size, seqlen, 1)
            ff_out_reshaped = tf.reshape(ff_out, [-1, seqlen]) # (batch_size, seqlen)
            # keep the raw logits around: loss and accuracy must always be computed from them
            self._logits = tf.layers.dense(ff_out_reshaped, 1) # (batch_size, 1)

            # self.pred keeps its original meaning: logits while training,
            # probabilities at inference time
            if self.is_training:
                self.pred = self._logits # (batch_size, 1)
            else:
                self.pred = tf.sigmoid(self._logits) # (batch_size, 1)

            if self.print_stack:
                print('[*] predictions:', self.pred)

            # loss and accuracy
            # sigmoid(logit) > 0.5 <=> logit > 0, so threshold the logits directly and compare
            # the resulting hard 0/1 labels with the (unsmoothed) targets
            predicted_labels = tf.cast(tf.greater(self._logits, 0.0), tf.float32)
            self._accuracy = tf.reduce_mean(
                tf.cast(tf.equal(predicted_labels, self.target_input), tf.float32)
                )

            # always feed logits here, never the sigmoid output (would apply sigmoid twice)
            self._loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels = label_smooth, logits = self._logits)
                )

            optim = tf.train.AdamOptimizer(beta1 = 0.9, beta2 = 0.98, epsilon = 1e-9)
            self._train = optim.minimize(self._loss)

            if self.print_stack:
                print('[*] accuracy:', self._accuracy)
                print('[*] loss:', self._loss)
                print('... Done!')

        with tf.variable_scope(self.model_name + "_summary"):
            tf.summary.scalar("loss", self._loss)
            tf.summary.scalar("accuracy", self._accuracy)
            # NOTE: merge_all() collects every summary registered in the default graph
            self.merged_summary = tf.summary.merge_all()

    # NETWORK FUNCTIONS
    # =================
    #
    # Following functions were placed outside this file with an aim to increase the
    # code reuse, but that caused several issues, especially with config file
    # redundancy. So they live here: the model is simpler to use at the cost of a
    # larger class.

    ##### OPERATIONAL LAYERS #####

    def get_embedding(self, emb, inp):
        '''
        get embeddings
        Args:
            emb: embedding matrix
            inp: integer ids to look up in emb
        '''
        return tf.nn.embedding_lookup(emb, inp)

    ##### CORE LAYERS #####

    def sdpa(self, Q, K, V, mask = None):
        '''
        Scaled Dot Product Attention
        Args:
            Q:    (num_heads * batch_size, q_size, d_model / num_heads)
            K:    (num_heads * batch_size, k_size, d_model / num_heads)
            V:    (num_heads * batch_size, k_size, d_model / num_heads)
            mask: (num_heads * batch_size, q_size, k_size) with 1 for positions that may be
                  attended to and 0 for padding, or None to attend to everything
        Returns:
            tensor of shape (num_heads * batch_size, q_size, d_model / num_heads)
        '''

        qkt = tf.matmul(Q, tf.transpose(K, [0, 2, 1]))
        qkt /= tf.sqrt(float(self.dim_model // self.num_heads))

        # perform masking: padded positions get a huge negative score so that the
        # softmax assigns them (numerically) zero weight
        if mask is not None:
            qkt = tf.multiply(qkt, mask) + (1.0 - mask) * (-1e10)

        soft = tf.nn.softmax(qkt) # (num_heads * batch_size, q_size, k_size)
        soft = tf.layers.dropout(soft, training = self.is_training)
        out = tf.matmul(soft, V) # (num_heads * batch_size, q_size, d_model / num_heads)

        return out

    def multihead_attention(self, query, key, value, mask = None, scope = 'attn'):
        '''
        Multihead attention with masking option
        Args:
            query: (batch_size, q_size, d_model)
            key:   (batch_size, k_size, d_model)
            value: (batch_size, k_size, d_model)
            mask:  (batch_size, q_size, k_size) padding mask over the key positions,
                   or None for no masking
        Returns:
            tensor of shape (batch_size, q_size, d_model)
        '''
        with tf.variable_scope(scope):
            # projection blocks (note: these use a relu activation)
            Q = tf.layers.dense(query, self.dim_model, activation = tf.nn.relu)
            K = tf.layers.dense(key, self.dim_model, activation = tf.nn.relu)
            V = tf.layers.dense(value, self.dim_model, activation = tf.nn.relu)

            # split the matrix into multiple heads and then concatenate them to get
            # a larger batch size: (num_heads * batch_size, q_size, d_model / num_heads)
            Q_reshaped = tf.concat(tf.split(Q, self.num_heads, axis = 2), axis = 0)
            K_reshaped = tf.concat(tf.split(K, self.num_heads, axis = 2), axis = 0)
            V_reshaped = tf.concat(tf.split(V, self.num_heads, axis = 2), axis = 0)
            # heads are stacked head-major along the batch axis, tiling the mask matches that
            if mask is not None:
                mask = tf.tile(mask, [self.num_heads, 1, 1])

            # scaled dot product attention
            sdpa_out = self.sdpa(Q_reshaped, K_reshaped, V_reshaped, mask)
            out = tf.concat(tf.split(sdpa_out, self.num_heads, axis = 0), axis = 2)

            # final linear layer
            out_linear = tf.layers.dense(out, self.dim_model)
            out_linear = tf.layers.dropout(out_linear, training = self.is_training)

        return out_linear

    def feed_forward(self, x, scope = 'ff'):
        '''
        Position-wise feed forward network, applied to each position separately
        and identically. Implemented as two convolutions with kernel size 1:
        dim_model -> ff_mid (relu) -> dim_model
        '''
        with tf.variable_scope(scope):
            out = tf.layers.conv1d(x, filters = self.ff_mid, kernel_size = 1,
                activation = tf.nn.relu)
            out = tf.layers.conv1d(out, filters = self.dim_model, kernel_size = 1)

        return out

    def layer_norm(self, x):
        '''
        perform layer normalisation
        '''
        # NOTE(review): with the defaults of tf.contrib.layers.layer_norm the statistics are
        # presumably taken over every non-batch axis (sequence and features), not per position
        # as in the Transformer paper - confirm this is intended
        out = tf.contrib.layers.layer_norm(x, center = True, scale = True)
        return out

    def label_smoothning(self, x, num_classes = 2):
        '''
        perform label smoothning on the input label
        Args:
            x: labels (1.0 for positive, 0.0 for negative)
            num_classes: number of classes the smoothing mass is spread over,
                2 for the binary target used by this network
        Returns:
            (1 - ls_epsilon) * x + ls_epsilon / num_classes
        '''
        smoothed = (1.0 - self.ls_epsilon) * x + (self.ls_epsilon / num_classes)
        return smoothed

    ###### STACKS ######

    def query_stack(self, q_in, mask, scope):
        '''
        Single query stack: self attention and feed forward, each wrapped in a
        residual connection followed by layer normalisation
        Args:
            q_in: (batch_size, seqlen, embed_size)
            mask: (batch_size, seqlen, seqlen)
        '''
        with tf.variable_scope(scope):
            out = self.layer_norm(q_in + self.multihead_attention(q_in, q_in, q_in, mask))
            out = self.layer_norm(out + self.feed_forward(out))

        return out

    def passage_stack(self, p_in, q_out, query_mask, passage_mask, scope):
        '''
        Single passage stack: self attention over the passage, attention over the
        encoded query and feed forward, each wrapped in a residual connection
        followed by layer normalisation
        Args:
            p_in: (batch_size, seqlen, embed_size)
            q_out: output from query stack
            query_mask: padding mask of the query, (batch_size, seqlen, seqlen)
            passage_mask: padding mask of the passage, (batch_size, seqlen, seqlen)
        '''
        with tf.variable_scope(scope):
            out = self.layer_norm(p_in + self.multihead_attention(p_in, p_in, p_in, mask = passage_mask))
            # passage positions attend to the query: keys AND values come from q_out, so the
            # key axis of the scores lines up with query_mask (which marks the query padding)
            out = self.layer_norm(out + self.multihead_attention(out, q_out, q_out, mask = query_mask, scope = 'attn2'))
            out = self.layer_norm(out + self.feed_forward(out))

        return out

    def construct_padding_mask(self, inp):
        '''
        Args:
            inp: Original input of word ids, shape: [batch_size, seqlen]
        Returns:
            a mask of shape [batch_size, seqlen, seqlen] where <pad> is 0 and others are 1;
            every row along axis 1 is the same copy of the per-token mask
        '''
        seqlen = inp.shape.as_list()[1]
        mask = tf.cast(tf.not_equal(inp, self.pad_id), tf.float32)
        mask = tf.tile(tf.expand_dims(mask, 1), [1, seqlen, 1])
        return mask

    def print_network(self):
        '''
        Print the trainable variables that live under this network's scope
        '''
        network_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = self.scope)
        for x in network_variables:
            print(x)

In [3]:
# Smoke test: build the graph on random embeddings and list its trainable variables.
VOCAB_SIZE = 12        # real token ids are 0 .. VOCAB_SIZE - 1
PAD_ID = VOCAB_SIZE    # <PAD> gets the first id after the real vocabulary
EMB_DIM = 50           # has to match dim_model of the network (default 50)
SEQLEN = 12

network = TransformerNetwork(scope = 'trial1',
                             model_name = 'trial1',
                             save_folder = './trial',
                             pad_id = PAD_ID)

# embeddings: one row per real token plus one extra row for <PAD>, so that looking up
# PAD_ID stays inside the matrix; a private seeded RNG keeps the cell reproducible
emb = np.random.RandomState(0).rand(VOCAB_SIZE + 1, EMB_DIM)

# build network
network.build_model(emb = emb, seqlen = SEQLEN, print_stack = False)
network.print_network()

<tf.Variable 'trial1/q_stk_0/attn/dense/kernel:0' shape=(50, 50) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense/bias:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense_1/kernel:0' shape=(50, 50) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense_1/bias:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense_2/kernel:0' shape=(50, 50) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense_2/bias:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense_3/kernel:0' shape=(50, 50) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/attn/dense_3/bias:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/LayerNorm/beta:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/LayerNorm/gamma:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/ff/conv1d/kernel:0' shape=(1, 50, 128) dtype=float32_ref>
<tf.Variable 'trial1/q_stk_0/ff/conv1d/bias:0' shape=(128,) dtype=float32_ref>
<tf.Variable 

In [7]:
# scratch tensor for checking the head-splitting shapes: 32 rows, 50 features
x = tf.placeholder(dtype = tf.float32, shape = [32, 50], name = 'test_var')

In [8]:
# Cut the feature axis into 5 equal chunks, then stack the chunks along the batch axis -
# the same split/concat trick multihead_attention uses to turn heads into extra batch entries.
x_heads = tf.split(x, num_or_size_splits = 5, axis = 1)
print(x_heads)
x_reshape = tf.concat(x_heads, axis = 0)
print(x_reshape)

[<tf.Tensor 'split:0' shape=(32, 10) dtype=float32>, <tf.Tensor 'split:1' shape=(32, 10) dtype=float32>, <tf.Tensor 'split:2' shape=(32, 10) dtype=float32>, <tf.Tensor 'split:3' shape=(32, 10) dtype=float32>, <tf.Tensor 'split:4' shape=(32, 10) dtype=float32>]
Tensor("concat:0", shape=(160, 10), dtype=float32)
