# MiniBERT + Rowlang

In [1]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt


### 1. Base & Utility Layers

In [2]:
from functools import wraps
from abc import ABC, abstractmethod
from collections import namedtuple


def graph_def(f):
    """Decorate a layer's graph-building method with caching and scoping.

    The wrapped method runs at most once per instance: its result is stored
    in ``self.outputs`` and returned on every later call, whatever arguments
    that later call passes. The first (and only) real invocation happens
    inside ``tf.variable_scope(self.name)``.
    """
    @wraps(f)
    def cached_build(self, *args, **kwargs):
        # Already built: hand back the cached result without touching the graph.
        if self.outputs is not None:
            return self.outputs
        with tf.variable_scope(self.name):
            self.outputs = f(self, *args, **kwargs)
        return self.outputs
    return cached_build

class Layer(ABC):
    """Abstract base for graph layers.

    Holds the layer's ``name`` and an ``outputs`` slot that starts out as
    ``None``. Subclasses must implement ``on``.
    """

    def __init__(self, name):
        self.outputs = None
        self.name = name

    @abstractmethod
    def on(self, *args, **kwargs):
        """Apply the layer to its inputs; concrete subclasses override this."""
    
    
class LinearLayer(Layer):
    """Dense projection to ``out_dim`` units with no activation."""

    def __init__(self, out_dim, name):
        super().__init__(name)
        self.out_dim = out_dim

    @graph_def
    def on(self, X):
        """Build the dense op on ``X`` and return its output tensor."""
        projected = tf.layers.dense(
            inputs=X,
            units=self.out_dim,
            activation=None,
            name=self.name,
        )
        return projected

class DropoutLayer(Layer):
    """Dropout layer; ``dropout`` is the drop amount, not the keep amount."""

    def __init__(self, dropout, name="dropout"):
        super().__init__(name)
        self.dropout = dropout

    @graph_def
    def on(self, X):
        """Return ``X`` passed through ``tf.nn.dropout``."""
        # The second positional argument is computed as 1 - dropout.
        keep = 1 - self.dropout
        return tf.nn.dropout(X, keep, name='dropped')

class LayerNormLayer(Layer):
    """Normalize each position's feature vector to zero mean and unit variance.

    No learnable scale or offset is applied. After ``on`` has run, ``mean``
    and ``std`` hold the per-position statistics tensors.
    """

    def __init__(self, name="layernorm"):
        super(LayerNormLayer, self).__init__(name)
        self._eps = 1e-6  # for numerical stability
        self.mean = None
        self.std = None

    @graph_def
    def on(self, X):
        """
        X: [minibatch x seq x dims]

        Returns a tensor of the same shape, normalized over the last axis.
        """
        # tf.nn.moments returns (mean, variance), not (mean, std), so the
        # standard deviation has to be derived with an explicit square root.
        self.mean, variance = tf.nn.moments(X, axes=[-1], keep_dims=True)
        # eps sits inside the sqrt so the value (and its gradient) stays
        # finite when a position has zero variance.
        self.std = tf.sqrt(variance + self._eps)
        return (X - self.mean) / self.std
  

### 2. Multi-Head Attention & Feed-Forward

In [3]:
class FeedForwardLayer(Layer):
    """Feed-forward block: linear up to ``d_ff``, ReLU, dropout, linear down to ``d_model``."""

    def __init__(self, dropout, d_model, d_ff, name="f_fwd"):
        super().__init__(name)

        self.d_ff = d_ff
        self.dropout = DropoutLayer(dropout)
        self.linear1 = LinearLayer(d_ff, "ff_up")
        self.linear2 = LinearLayer(d_model, "ff_down")

    @graph_def
    def on(self, X):
        """Apply the block to ``X`` and return the down-projected tensor."""
        hidden = self.linear1.on(X)
        activated = tf.nn.relu(hidden, name="relu")
        # Dropout is applied to the activations before the down projection.
        regularized = self.dropout.on(activated)
        return self.linear2.on(regularized)
           
class ScaledDotProdAttentionLayer(Layer):
    """Scaled dot-product attention with dropout on the attention weights.

    After ``on`` has run, ``dot`` holds the raw query/key products and
    ``scores`` holds the softmax-normalized weights (before dropout).
    """

    def __init__(self, scale, dropout, name):
        super().__init__(name)
        self.scale = scale
        self.dropout = DropoutLayer(dropout)
        self.dot = None
        self.scores = None

    @graph_def
    def on(self, Q, K, V):
        """
        Q: queries [ minibatch x queries x dim_k]
        K: keys    [ minibatch x keys x dim_k]
        V: values  [ minibatch x keys x dim_v]

        Returns the attended values [ minibatch x queries x dim_v].
        """
        # Similarity of every query with every key.
        similarity = tf.einsum('mqd,mkd->mqk', Q, K, name='dot')
        self.dot = similarity

        weights = tf.nn.softmax(self.scale * similarity, name='scores')
        self.scores = weights

        # Weighted sum of the values using the dropped-out weights.
        return tf.einsum('mqk,mkd->mqd', self.dropout.on(weights), V, name='a')
    
class MultiHeadAttention(Layer):
    """Multi-headed self attention.

    Each of the ``h`` heads projects the input to queries, keys and values of
    width ``d_model // h`` and runs scaled dot-product attention on them. The
    head outputs are concatenated on the last axis and projected back to
    ``d_model``.
    """

    def __init__(self, dropout, d_model, h, name="multihead"):
        """Create the per-head projection/attention layers and the output layer."""
        super().__init__(name)
        self.h = h
        self.d_model = d_model
        self.d_k = d_model // h
        self.scale = 1 / np.sqrt(self.d_k)

        Head = namedtuple("Head", ["to_q", "to_k", "to_v", "attn"])
        self.heads = [
            Head(
                LinearLayer(self.d_k, "q"),
                LinearLayer(self.d_k, "k"),
                LinearLayer(self.d_k, "v"),
                ScaledDotProdAttentionLayer(self.scale, dropout, "attn"),
            )
            for _ in range(h)
        ]
        self.attentions = []
        self.A = None
        self.out_layer = LinearLayer(self.d_model, "O")

    @graph_def
    def on(self, X):
        """Run every head on ``X`` and return the combined, projected output."""
        for idx, head in enumerate(self.heads):
            # Heads share layer names, so each one gets its own scope.
            with tf.variable_scope("h{}".format(idx)):
                queries = head.to_q.on(X)
                keys = head.to_k.on(X)
                values = head.to_v.on(X)
                self.attentions.append(head.attn.on(queries, keys, values))

        self.A = tf.concat(self.attentions, axis=-1, name="A")
        return self.out_layer.on(self.A)

### 3. Composite Encoder Layers

In [4]:
class EncoderSubLayer(Layer):
    """Residual wrapper: ``X + dropout(sublayer(layernorm(X)))``.

    ``sublayer`` is a layer class (not an instance); it is constructed as
    ``sublayer(dropout, d_model, *args, **kwargs)``.
    """

    def __init__(self, sublayer, dropout, d_model,  name, *args, **kwargs):
        super().__init__(name)
        self.dropout = DropoutLayer(dropout)
        self.sublayer = sublayer(dropout, d_model, *args, **kwargs)
        self.layer_norm = LayerNormLayer()

    @graph_def
    def on(self, X):
        """Normalize, transform, drop out, then add the input back."""
        normed = self.layer_norm.on(X)
        transformed = self.sublayer.on(normed)
        residual = self.dropout.on(transformed)
        return X + residual
    
class EncoderLayer(Layer):
    """One encoder layer: a self-attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, dropout, d_model, heads, d_ff, name):
        super().__init__(name)
        self.mha_sublayer = EncoderSubLayer(
            MultiHeadAttention, dropout, d_model, "self_attn", heads)
        self.ffwd_sublayer = EncoderSubLayer(
            FeedForwardLayer, dropout, d_model, "feed_fwd", d_ff)

    @graph_def
    def on(self, X):
        """Pass ``X`` through the attention sub-layer, then the feed-forward one."""
        attended = self.mha_sublayer.on(X)
        return self.ffwd_sublayer.on(attended)

### 4. Final Model

In [8]:
class BERTModel(Layer):
    """A stack of ``layers`` encoder layers applied one after another."""

    def __init__(self, layers, dropout, d_model, heads, d_ff, name="BERT"):
        super().__init__(name)
        # Encoder layers are named layer0, layer1, ... in stacking order.
        self.layers = [
            EncoderLayer(dropout, d_model, heads, d_ff, "layer{}".format(depth))
            for depth in range(layers)
        ]

    @graph_def
    def on(self, X):
        """Feed ``X`` through every encoder layer and return the final output."""
        hidden = X
        for encoder in self.layers:
            hidden = encoder.on(hidden)
        return hidden
            

In [7]:
# Demo: build a small BERT encoder stack and run one random batch through it.
DATA_POINTS = 4   # sequences in the random batch
SEQ = 3           # positions per sequence
MODEL_DIM = 5     # feature width of the model (d_model)
MODEL_FF = 12     # hidden width of the feed-forward block (d_ff)
HEADS = 3         # attention heads per encoder layer
LAYERS = 6        # encoder layers in the stack
DROPOUT = 0.1     # dropout value fed to the graph at run time



tf.reset_default_graph()

X = tf.placeholder(tf.float32, shape=[None, None, MODEL_DIM])
# Dropout is a scalar placeholder so it can be changed per run.
d = tf.placeholder(tf.float32, shape=tuple())

bert = BERTModel(LAYERS, d, MODEL_DIM, HEADS, MODEL_FF)

Y = bert.on(X)
init = tf.group(tf.global_variables_initializer(),
                tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init)
    x_d = np.random.random((DATA_POINTS, SEQ, MODEL_DIM))
    print(x_d)
    # Feed the declared DROPOUT constant rather than a duplicated literal.
    y = sess.run(Y, feed_dict={X: x_d, d: DROPOUT})
    file_writer = tf.summary.FileWriter('./logdir', sess.graph)
    # Close the writer so the graph event is flushed to ./logdir.
    file_writer.close()

print(y.shape)
y

[[[0.86631261 0.21646375 0.89855265 0.69724927 0.34142794]
  [0.93314301 0.97796753 0.14437069 0.52645811 0.86292791]
  [0.32033719 0.92523595 0.88753367 0.17552919 0.89574283]]

 [[0.07789924 0.56715091 0.49229367 0.47560545 0.73355563]
  [0.5279176  0.65726757 0.0765989  0.90232849 0.88114423]
  [0.06543972 0.31373356 0.24131088 0.67517648 0.192631  ]]

 [[0.71670188 0.08114562 0.22180622 0.17892    0.99999689]
  [0.30849568 0.77699325 0.88885983 0.0181805  0.69846696]
  [0.43597488 0.57099396 0.18301371 0.48416136 0.07300577]]

 [[0.57180019 0.90015768 0.89639054 0.42244088 0.6499929 ]
  [0.06288385 0.8844261  0.09732825 0.49169101 0.21321628]
  [0.87738454 0.45994459 0.69213015 0.941501   0.62377234]]]
(4, 3, 5)


array([[[ 0.9131644 ,  1.1837215 , -3.9401248 , -1.6667563 ,
          4.174741  ],
        [ 2.306495  ,  0.15267938,  5.5956373 , -0.3653612 ,
          0.08902161],
        [ 0.94705117,  2.6214416 , -5.46303   , -1.6032618 ,
          4.207371  ]],

       [[ 2.3319218 ,  3.0767753 ,  3.67152   ,  0.20346087,
         -3.693404  ],
        [ 3.1243339 ,  0.8846858 ,  7.0291624 ,  1.7395796 ,
         -1.5512298 ],
        [ 0.1226456 ,  6.291777  ,  1.2744232 ,  0.33477584,
         -0.20087352]],

       [[-8.546267  ,  3.0947614 , -0.5664164 , 13.124961  ,
         -2.4376817 ],
        [-8.39151   ,  3.820796  , -5.1076126 , 10.065827  ,
          1.4659433 ],
        [ 4.2532444 , -3.639643  , 12.330132  ,  0.28948504,
          0.72488177]],

       [[-3.4543557 ,  5.0694294 ,  0.6532984 ,  5.4662213 ,
         -3.150346  ],
        [ 0.07833236, -0.89091146, 13.4288    ,  5.837252  ,
         -5.912517  ],
        [-2.3734071 ,  1.5151955 , -1.9402819 ,  4.287184  ,
         