# MiniBERT + Rowlang

In [1]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt


### 1. Encoder Layers

In [2]:
class EncoderLayer:
    def __init__(self, head):
        pass

### 2. MultiHead Attention

In [105]:
from functools import wraps
from abc import ABC, abstractmethod
import functools


def graph_def(f):
    '''Cache layer (w self.outputs) and autoscope (w self.name)'''
    @wraps(f)
    def wrapper(self, *args, **kwargs):
        if self.outputs is None:
            with tf.variable_scope(self.name):
                self.outputs = f(self, *args, **kwargs)
        return self.outputs
    return wrapper


class Layer(ABC):
    def __init__(self, name):
        self.name = name
        self.outputs = None
        
    @abstractmethod 
    def on(self, *args, **kwargs):
        pass
    
class LinearLayer(Layer):
    def __init__(self, out_dim, name):
        super(LinearLayer, self).__init__(name)
        self.out_dim = out_dim
    
    @graph_def
    def on(self, X):
        return tf.layers.dense(X, self.out_dim, activation=None, name=self.name)

class ScaledDotProdAttentionLayer(Layer):
    def __init__(self, scale, name):
        super(ScaledDotProdAttentionLayer, self).__init__(name)
        self.scale = scale
    @graph_def
    def on(self, Q, K, V):
        '''
        Q: queries [ minibatch x queries x dim_k]
        K: keys    [ minibatch x keys x dim_k]
        V: values  [ minibatch x keys x dim_v]
        '''
        dot = tf.einsum('mqd,mkd->mqk', Q, K, name='dot')            
        scores = tf.nn.softmax(tf.scalar_mul(self.scale, dot), name='scores') 
        A = tf.einsum('mqk,mkd->mqd', scores, V, name='a')
        return A
    
class MultiHeadAttention(Layer):
    def __init__(self, h, d_model, dropout=0.1, name="multihead"):
        '''Implement the multiheaded self attention
        '''
        super(MultiHeadAttention, self).__init__(name)
        self.h = h
        self.d_k = d_model // h
        self.d_model = d_model
        self.attentions = []
        self.A = None
        self.O = None

    @graph_def
    def on(self, X):
        Q = K = V = X # todo: separate
        for i in range(self.h):
            with tf.variable_scope("h{}".format(i)):
                q = LinearLayer(self.d_k, "q").on(Q)                
                k = LinearLayer(self.d_k, "k").on(K)                
                v = LinearLayer(self.d_k, "v").on(V)
                scale = 1 / np.sqrt(self.d_k)
                a = ScaledDotProdAttentionLayer(scale, "attn").on(q,k,v)
                self.attentions.append(a)

        self.A = tf.concat(self.attentions, axis=-1, name="A")
        self.O = LinearLayer(self.d_model, "O").on(self.A)
        return self.O
            
                
        

In [127]:
class LayerNormLayer(Layer):
    def __init__(self, name="layernorm"):
        super(LayerNormLayer, self).__init__(name)
        self._eps = 1e-6 # for numerical stability
    
    @graph_def
    def on(self, X):
        """
        X: [minibatch x seq x dims]
        """
        self.mean, self.std = tf.nn.moments(X, axes=-1, keep_dims=True)
        return (X - self.mean)/ (self.std + self._eps)
        
class DropoutLayer(Layer):
    def __init__(self, dropout, name="dropout"):
        super(DropoutLayer, self).__init__(name)
        self.dropout = dropout
    
    @graph_def
    def on(self, X):
        return tf.nn.dropout(X, 1 - self.dropout, name='dropped') 

class EncoderSubLayer(Layer):
    def __init__(self, dropout, sublayer, name, *args, **kwargs):
        super(EncoderSubLayer, self).__init__(name)
        self.dropout = dropout
        self.sublayer = sublayer(*args, **kwargs)
    
    @graph_def
    def on(self, X):
        X_n = LayerNormLayer().on(X)
        return X + DropoutLayer(self.dropout).on(self.sublayer.on(X_n))
    

        
        

In [137]:
tf.reset_default_graph()

DATA_POINTS = 4
SEQ = 3
MODEL_DIM = 5
HEADS = 3
DROPOUT = 1e-20
X = tf.placeholder(tf.float32, shape=[None, None, MODEL_DIM])
# with tf.variable_scope("multihead/"):
#     Y = tf.layers.dense(X, 30, use_bias=False)


mha = MultiHeadAttention(HEADS, MODEL_DIM )
ES = EncoderSubLayer(DROPOUT, MultiHeadAttention, "self_attn", 
                     HEADS, MODEL_DIM)

Y = ES.on(X)
Z = ES.sublayer.on(None)


init = tf.group(tf.global_variables_initializer(),
                tf.local_variables_initializer())    
with tf.Session() as sess:
    sess.run(init)
    x_d = np.random.random((DATA_POINTS, SEQ, MODEL_DIM))
    print(x_d)
    y = (sess.run(Y, feed_dict={X:x_d}))
    z = (sess.run(Z, feed_dict={X:x_d}))
    print(z)
    file_writer = tf.summary.FileWriter('./logdir', sess.graph)
    
print(y.shape)
y

[[[0.44902091 0.80281073 0.00862657 0.14515628 0.0123972 ]
  [0.13858129 0.80313088 0.58091599 0.73941974 0.57524398]
  [0.80027948 0.93625195 0.99246857 0.51971626 0.17310635]]

 [[0.22600938 0.68505564 0.46919082 0.14917206 0.1759088 ]
  [0.11683096 0.21005736 0.79661539 0.36662984 0.45335137]
  [0.43006872 0.79080767 0.6609233  0.15022437 0.89146928]]

 [[0.03756818 0.93855819 0.56245586 0.79120315 0.0049632 ]
  [0.3492735  0.02472538 0.40855029 0.21187245 0.15533361]
  [0.24403327 0.42323078 0.86115713 0.38359963 0.42857641]]

 [[0.24732274 0.25097316 0.57164709 0.08956083 0.17019009]
  [0.45170401 0.66593378 0.8022394  0.15200534 0.47501003]
  [0.16161879 0.50399882 0.2457969  0.62609678 0.91827856]]]
[[[-2.251142    0.88513637 -0.14581622  1.7560604  -1.7967244 ]
  [-2.8017209   0.73333037 -0.3401395   0.89860594 -2.2055368 ]
  [-2.9691918   0.6631591  -0.4129312   0.53556997 -2.3332217 ]]

 [[ 2.8255858  -0.51193964  0.27725762 -1.0102865   1.9443855 ]
  [ 5.2150383  -1.006813  

array([[[-1.8021212 ,  1.687947  , -0.13718966,  1.9012166 ,
         -1.7843273 ],
        [-2.6631396 ,  1.5364612 ,  0.24077648,  1.6380258 ,
         -1.6302929 ],
        [-2.1689124 ,  1.599411  ,  0.5795374 ,  1.0552862 ,
         -2.1601152 ]],

       [[ 3.0515952 ,  0.17311603,  0.74644846, -0.8611144 ,
          2.1202943 ],
        [ 5.331869  , -0.7967557 ,  1.630214  ,  0.1990729 ,
          4.602429  ],
        [ 3.9451294 , -0.6211491 ,  0.9305927 , -2.3950222 ,
          3.7868283 ]],

       [[-4.33772   , -0.80761254, -1.2534559 , -8.123856  ,
         -3.261192  ],
        [ 4.342996  , -2.696918  ,  0.7096808 , -3.970833  ,
          4.296358  ],
        [ 4.9364862 , -2.6132066 ,  1.1203446 , -4.868133  ,
          5.019191  ]],

       [[ 6.183278  , -2.0869732 ,  1.1317971 , -3.5811126 ,
          5.190774  ],
        [11.52078   , -0.4477945 ,  3.065682  ,  3.6645339 ,
          9.281338  ],
        [11.246951  , -0.6167678 ,  2.508426  ,  4.114967  ,
         