In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, Embedding, Lambda, TimeDistributed, \
                                    Add, Conv1D, Dropout, Concatenate, Activation
from tensorflow.keras.models import Model

from src.multihead import *

In [2]:
import numpy as np
import json
from tqdm import tqdm

In [3]:
# Show the kernel's working directory; informational only, since the data paths below are absolute.
!pwd

/home/sweet/1-workdir/NLP_attention/en_vi_attention_nlp/src


In [4]:
import os

# Location of the preprocessed corpus. The default is the author's local path;
# set EN_VI_DATA_DIR to point the notebook at another checkout.
db_dir = os.environ.get("EN_VI_DATA_DIR", "/home/sweet/1-workdir/NLP_attention/en_vi_data_preprocess/src/")
db_file = "train-test.json"
dict_file = "dictionary.json"

# Read as UTF-8 explicitly: the en-vi corpus presumably contains non-ASCII
# (Vietnamese) text, and the platform default encoding is locale dependent.
with open(os.path.join(db_dir, db_file), 'r', encoding='utf-8') as f_db, \
        open(os.path.join(db_dir, dict_file), 'r', encoding='utf-8') as f_dict:
    db = json.load(f_db)
    dictionary = json.load(f_dict)

train_X = db['train_X']
train_Y = db['train_Y']
test_X = db['test_X']
test_Y = db['test_Y']

# Forward (token -> id) and reverse lookups for both sides of the corpus.
dictionary_from = dictionary['from']['dictionary']
rev_dictionary_from = dictionary['from']['rev_dictionary']

dictionary_to = dictionary['to']['dictionary']
rev_dictionary_to = dictionary['to']['rev_dictionary']

In [5]:
# Ids of the special tokens, looked up in the source-side vocabulary.
GO, PAD, EOS, UNK = (dictionary_from[token] for token in ('GO', 'PAD', 'EOS', 'UNK'))

In [6]:
# Terminate every training source sentence with the EOS token. The guard keeps
# the cell idempotent: re-running it does not append a second ' EOS'.
for i in tqdm(range(len(train_X))):
    if not train_X[i].endswith(' EOS'):
        train_X[i] += ' EOS'

100%|██████████| 133317/133317 [00:00<00:00, 2354864.63it/s]


In [7]:
# Terminate every test source sentence with the EOS token. The guard keeps the
# cell idempotent: re-running it does not append a second ' EOS'.
for i in tqdm(range(len(test_X))):
    if not test_X[i].endswith(' EOS'):
        test_X[i] += ' EOS'

100%|██████████| 2821/2821 [00:00<00:00, 2150123.86it/s]


In [8]:
def get_positional_encoding_matrix(length, d_model):
    """Build the sinusoidal positional-encoding table.

    Row ``p`` holds the encoding of position ``p``. Even columns ``2k`` hold
    ``sin(p / 10000**(2k/d_model))`` and odd columns ``2k+1`` hold
    ``cos(p / 10000**(2k/d_model))``.

    Args:
        length: number of positions (rows), starting at position 0.
        d_model: encoding width (columns); may be odd.

    Returns:
        float32 ndarray of shape ``(length, d_model)``.
    """
    # Column vector of positions, broadcast against one denominator per sin/cos pair.
    positions = np.arange(length, dtype=np.float32)[:, np.newaxis]
    pair_denoms = np.power(10000.0, np.arange(0, d_model, 2, np.float32) / d_model)
    angles = positions / pair_denoms  # shape (length, ceil(d_model / 2))

    table = np.zeros((length, d_model), dtype=np.float32)
    table[:, 0::2] = np.sin(angles)
    # There are only d_model // 2 odd columns, so drop the last pair when d_model is odd.
    table[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return table

In [9]:
def get_pos_seq(x, null_token_value=0):
    """Map each token to a 1-based position index, or 0 where it is the null token.

    The running index counts along axis 1 and is not renumbered around null
    tokens; null entries are simply zeroed. Assumes ``x`` is (batch, time) —
    TODO confirm against the callers.

    Returns:
        float32 tensor with the same shape as ``x``.
    """
    running_index = K.cumsum(K.ones_like(x, 'float32'), 1)
    is_token = K.cast(K.not_equal(x, null_token_value), 'float32')
    return running_index * is_token

In [10]:
def get_loss(args, null_token_value):
    """Masked mean sparse cross-entropy over the non-null target tokens.

    Args:
        args: pair ``(y_pred, y_true)``. ``y_pred`` holds unnormalised logits
            (the loss uses ``from_logits=True``); ``y_true`` holds target ids
            (any numeric dtype; cast to int32 here).
        null_token_value: target id whose positions are excluded from the loss.

    Returns:
        Scalar tensor: the summed per-token loss divided by the number of
        unmasked tokens. It is 0 instead of NaN when every token is masked.
    """
    y_pred, y_true = args

    y_true_id = K.cast(y_true, "int32")

    mask = 1.0 - K.cast(K.equal(y_true_id, null_token_value), K.floatx())
    token_loss = K.sparse_categorical_crossentropy(y_true_id, y_pred, from_logits=True) * mask

    # Average w.r.t. the number of unmasked entries. The denominator is clamped
    # to >= 1 so an all-padding batch yields 0 (its token_loss is all 0)
    # rather than 0/0 = NaN, which would poison the weights.
    return K.sum(token_loss) / K.maximum(K.sum(mask), 1.0)

In [11]:
def get_accuracy(args, null_token_value):
    """Mean per-sequence token accuracy, ignoring null (padding) targets.

    Each row's accuracy is its correct unmasked predictions divided by its
    unmasked tokens. The result averages these over rows that have at least
    one unmasked token. Rows with no unmasked tokens previously made the whole
    batch NaN (0/0); they are now excluded.

    Args:
        args: pair ``(y_pred, y_true)``. ``y_pred`` holds scores over the
            vocabulary on its last axis (argmax is taken there); ``y_true``
            holds target ids.
        null_token_value: target id treated as padding.

    Returns:
        Scalar tensor in [0, 1]; 0 when no row has an unmasked token.
    """
    y_pred, y_true = args

    y_true = K.cast(y_true, "int32")
    mask = 1.0 - K.cast(K.equal(y_true, null_token_value), K.floatx())

    y_pred_id = K.cast(K.argmax(y_pred, axis=-1), "int32")
    correct = K.cast(K.equal(y_pred_id, y_true), K.floatx())

    tokens_per_row = K.sum(mask, -1)
    # Clamp to >= 1: a fully masked row has 0 correct tokens, so it yields 0 here
    # and is then dropped by valid_rows below.
    row_accuracy = K.sum(correct * mask, -1) / K.maximum(tokens_per_row, 1.0)

    valid_rows = K.cast(K.greater(tokens_per_row, 0.0), K.floatx())
    return K.sum(row_accuracy * valid_rows) / K.maximum(K.sum(valid_rows), 1.0)

In [12]:
class PositionWiseFeedForward(object):
    """Position-wise feed-forward sub-layer built from two kernel-size-1 Conv1D layers.

    Computes ``Conv1D(d_model)(Conv1D(d_ff, relu)(x))``. With kernel_size=1,
    each convolution applies the same linear map independently at every time
    step. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, d_model=512, d_ff=512, **kwargs):
        self._d_model = d_model
        self._d_ff = d_ff

        self._conv1 = Conv1D(d_ff, kernel_size=1, activation="relu")
        self._conv2 = Conv1D(d_model, kernel_size=1)

    def __call__(self, x):
        return self._conv2(self._conv1(x))

In [13]:
class EncoderLayer(object):
    """One Transformer encoder block (post-norm).

    Self-attention followed by a position-wise feed-forward network. Each
    sub-layer output is added to its input (residual ``Add``) and then passed
    through ``LayerNormalization``.

    NOTE(review): self-attention is called as ``mha(x, x, x)`` with no padding
    mask. Padded positions are attended to unless MultiHeadAttention masks
    them internally — confirm against src.multihead.
    """

    def __init__(self, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048):
        self._mha = MultiHeadAttention(n_head=h, d_k=d_k, d_v=d_v, d_model=d_model)
        self._ln_a = LayerNormalization()
        self._psfw = PositionWiseFeedForward(d_model=d_model, d_ff=d_inner_hid)
        self._ln_b = LayerNormalization()
        self._add_a = Add()
        self._add_b = Add()

    def __call__(self, x):
        # Sub-layer 1: self-attention; queries, keys and values all come from x.
        attended = self._mha(x, x, x)
        x = self._ln_a(self._add_a([x, attended]))

        # Sub-layer 2: position-wise feed-forward.
        transformed = self._psfw(x)
        return self._ln_b(self._add_b([x, transformed]))

class Encoder(object):
    """Stack of ``n`` EncoderLayer blocks over token + position embeddings.

    Args:
        embedding: layer mapping token ids to ``d_model`` vectors.
        position_embedding: layer mapping the position indices produced by
            get_pos_seq (1-based, 0 for null tokens) to vectors.
        n: number of stacked EncoderLayer blocks.
        h, d_k, d_v, d_model, d_inner_hid: forwarded to every EncoderLayer.
        null_token_value: token id that get_pos_seq maps to position 0.
    """

    def __init__(self, embedding, position_embedding,
                 n=6, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048, null_token_value=0):
        self._embedding = embedding
        self._position_embedding = position_embedding
        self._n = n
        self._position_encoding = Lambda(get_pos_seq, arguments={"null_token_value": null_token_value})

        self._layers = [
            EncoderLayer(h=h, d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=d_inner_hid)
            for _ in range(n)
        ]

    def __call__(self, x):
        token_vectors = self._embedding(x)
        position_vectors = self._position_embedding(self._position_encoding(x))
        hidden = Add()([token_vectors, position_vectors])

        for layer in self._layers:
            hidden = layer(hidden)

        return hidden

In [14]:
class DecoderLayer(object):
    """One Transformer decoder block (post-norm).

    The block has three sub-layers. Each is followed by a residual ``Add`` and
    ``LayerNormalization``:
      1. self-attention over ``x``,
      2. attention from ``x`` over ``encoder_output``,
      3. a position-wise feed-forward network.

    Args:
        h, d_k, d_v, d_model: forwarded to both MultiHeadAttention layers.
        d_inner_hid: hidden width of the feed-forward sub-layer.
        return_attention: when True, ``__call__`` returns
            ``[output, self_attention, encoder_attention]``. When False it
            returns only ``output``; previously this path crashed because the
            attention result was always tuple-unpacked.

    NOTE(review): self-attention is called as ``mha(x, x, x)`` with no
    look-ahead (causal) mask. Unless MultiHeadAttention applies one
    internally, position t can attend to later target tokens, which leaks the
    answer during teacher-forced training — confirm against src.multihead.
    """

    def __init__(self, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048, return_attention=True):
        self._mha_a = MultiHeadAttention(n_head=h, d_k=d_k, d_v=d_v, d_model=d_model, return_attention=return_attention)
        self._mha_b = MultiHeadAttention(n_head=h, d_k=d_k, d_v=d_v, d_model=d_model, return_attention=return_attention)
        self._psfw = PositionWiseFeedForward(d_model=d_model, d_ff=d_inner_hid)
        self._ln_a = LayerNormalization()
        self._ln_b = LayerNormalization()
        self._ln_c = LayerNormalization()
        self._add_a = Add()
        self._add_b = Add()
        self._add_c = Add()
        self._return_attention = return_attention

    def _attend(self, mha, queries, memory):
        """Call ``mha(queries, memory, memory)`` and return ``(output, attention)``.

        ``attention`` is None when the layer was built with return_attention=False.
        """
        result = mha(queries, memory, memory)
        if self._return_attention:
            output, attention = result
            return output, attention
        # Assumes MultiHeadAttention without return_attention yields a single
        # tensor, as EncoderLayer relies on — confirm in src.multihead.
        return result, None

    def __call__(self, x, encoder_output):
        y, self_atn = self._attend(self._mha_a, x, x)
        x = self._ln_a(self._add_a([x, y]))

        y, enc_atn = self._attend(self._mha_b, x, encoder_output)
        x = self._ln_b(self._add_b([x, y]))

        y = self._psfw(x)
        x = self._ln_c(self._add_c([x, y]))

        if self._return_attention:
            return [x, self_atn, enc_atn]
        return x

class Decoder(object):
    """Stack of ``n`` DecoderLayer blocks over token + position embeddings.

    Args:
        embedding: layer mapping target token ids to ``d_model`` vectors.
        position_embedding: layer mapping the position indices produced by
            get_pos_seq (1-based, 0 for null tokens) to vectors.
        n: number of stacked DecoderLayer blocks.
        h, d_k, d_v, d_model, d_inner_hid: forwarded to every DecoderLayer.
        null_token_value: token id that get_pos_seq maps to position 0.

    ``__call__(x, encoder_output, return_attention=False)`` returns the final
    hidden states. With ``return_attention=True`` it returns
    ``[hidden, self_attentions, encoder_attentions]``, where the two lists
    hold one entry per layer.
    """

    def __init__(self, embedding, position_embedding,
                 n=6, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048, null_token_value=0):
        self._embedding = embedding
        self._position_embedding = position_embedding
        self._n = n
        self._position_encoding = Lambda(get_pos_seq, arguments={"null_token_value": null_token_value})

        self._layers = [
            DecoderLayer(h=h, d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=d_inner_hid)
            for _ in range(n)
        ]

    def __call__(self, x, encoder_output, return_attention=False):
        token_vectors = self._embedding(x)
        position_vectors = self._position_embedding(self._position_encoding(x))
        hidden = Add()([token_vectors, position_vectors])

        self_atts, enc_atts = [], []
        for layer in self._layers:
            hidden, self_att, enc_att = layer(hidden, encoder_output)
            if return_attention:
                self_atts.append(self_att)
                enc_atts.append(enc_att)

        if not return_attention:
            return hidden
        return [hidden, self_atts, enc_atts]

In [15]:
# ---- Model hyper-parameters ----
num_encoders = 6        # layers per stack
num_multi_heads = 8     # attention heads per MultiHeadAttention
d_k = 64                # per-head key width
d_v = 64                # per-head value width
d_model = 512           # embedding / hidden width
optimizer = "adam"

# Token id treated as padding by the loss/accuracy masking.
# NOTE(review): cell 5 reads PAD from the dictionary; confirm the batching code
# pads with 0 (otherwise this should presumably be PAD).
null_token_value = 0

source_vocab_size = len(dictionary_from)
target_vocab_size = len(dictionary_to)
share_word_embedding = False   # reuse the source embedding for the target side
MAXIMUM_TEXT_LENGTH = 250

In [16]:
# Build transformer


def _drop_first_step(seq):
    """Drop the first time step: ``seq[:, 1:]``."""
    return seq[:, 1:]


def _drop_last_step(seq):
    """Drop the last time step: ``seq[:, :-1]``."""
    return seq[:, :-1]


# Placeholders for the raw source and target id sequences.
source_input = Input(shape=(None,), name="source_input")
target_input = Input(shape=(None,), name="target_input")

# NOTE(review): this drops the first source token. That is only correct if
# source batches start with a GO token; the cells above only append ' EOS' to
# train_X/test_X. Confirm against the batching code.
enc_input = Lambda(_drop_first_step)(source_input)
# Teacher forcing: the decoder reads target[:, :-1] and is trained to predict target[:, 1:].
dec_input = Lambda(_drop_last_step)(target_input)
dec_target_output = Lambda(_drop_first_step)(target_input)

In [17]:
# create embedding

source_word_embedding = Embedding(source_vocab_size, d_model, name="source_embedding")

if share_word_embedding:
    target_word_embedding = source_word_embedding
else:
    target_word_embedding = Embedding(target_vocab_size, d_model, name="target_embedding")

# Fixed (non-trainable) sinusoidal table indexed by the output of get_pos_seq.
# Those positions are 1-based (0 is reserved for null tokens), so a sequence of
# MAXIMUM_TEXT_LENGTH tokens looks up row MAXIMUM_TEXT_LENGTH. The table needs
# one more row than the maximum length.
POSITION_TABLE_SIZE = MAXIMUM_TEXT_LENGTH + 1
pos_enc_mat = get_positional_encoding_matrix(POSITION_TABLE_SIZE, d_model)
position_encoding = Embedding(POSITION_TABLE_SIZE, d_model, trainable=False,
                              weights=[pos_enc_mat],
                              name="position_embedding")

In [18]:
# Encoder and decoder stacks share the config hyper-parameters. null_token_value
# is forwarded so positional masking uses the same padding id as the
# loss/accuracy masking; previously both stacks silently used their default of 0.
D_INNER_HID = 512  # feed-forward hidden width (the layer classes default to 2048)
enc = Encoder(source_word_embedding, position_encoding,
              n=num_encoders, h=num_multi_heads, d_k=d_k, d_v=d_v, d_model=d_model,
              d_inner_hid=D_INNER_HID, null_token_value=null_token_value)
dec = Decoder(target_word_embedding, position_encoding,
              n=num_encoders, h=num_multi_heads, d_k=d_k, d_v=d_v, d_model=d_model,
              d_inner_hid=D_INNER_HID, null_token_value=null_token_value)

In [20]:
enc_output = enc(enc_input)
dec_output = dec(dec_input, enc_output)

# Project decoder states to unnormalised vocabulary logits. get_loss uses
# from_logits=True and get_accuracy takes the argmax, so no softmax here.
fin_output = TimeDistributed(Dense(target_vocab_size, activation=None, use_bias=False), name="output")
fin_output_out = fin_output(dec_output)

accuracy = Lambda(get_accuracy, arguments={"null_token_value": null_token_value})([fin_output_out, dec_target_output])
loss = Lambda(get_loss, arguments={"null_token_value": null_token_value})([fin_output_out, dec_target_output])

# The loss is computed inside the graph, so compile() receives no external loss.
train_model = Model(inputs=[source_input, target_input], outputs=loss)
train_model.add_loss([loss])
# In tf.keras 2.x, `metrics` / `metrics_names` are computed properties, so
# appending to them is silently discarded. Register the accuracy tensor explicitly.
train_model.add_metric(accuracy, name="accuracy", aggregation="mean")
train_model.compile(optimizer, None)

inference_model = Model([source_input, target_input], fin_output_out)

In [21]:
# Print the layer-by-layer structure and parameter counts of the training graph.
train_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
target_input (InputLayer)       [(None, None)]       0                                            
__________________________________________________________________________________________________
source_input (InputLayer)       [(None, None)]       0                                            
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None)         0           target_input[0][0]               
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None)         0           source_input[0][0]               
____________________________________________________________________________________________