In [None]:
# Transformer hyperparameters shared by the cells below.

# Model size
d_model = 512                # width of every token representation
num_heads = 8                # attention heads; MultiheadAttention needs d_model % num_heads == 0
ffn_hidden = 4 * d_model     # position-wise feed-forward inner width (2048)
num_layers = 5               # depth of the encoder/decoder stacks (passed as num_layers)

# Regularisation and data shape
drop_prob = 0.1              # dropout probability
max_sequence_length = 200    # padded sentence length (read by create_masks)
batch_size = 30              # not referenced in this section

In [None]:
def scaled_dot_product_attention(q,k,v,mask=None):
  """Return ``(softmax(q @ k^T / sqrt(d_k) + mask) @ v, attention_weights)``.

  q, k and v share their leading (batch-like) axes; the last two axes are
  (sequence, depth). The multi-head layers in this file pass
  (batch, heads, seq, head_dim) tensors.

  mask, if given, is added to the scaled scores before the softmax, so it
  should hold 0 for visible positions and a large negative number for
  blocked ones. A rank-3 mask of shape (batch, q_len, k_len) -- the shape
  ``create_masks`` builds -- is given a head axis so it broadcasts over every
  head of rank-4 scores; any other mask must already broadcast to the scores.

  Returns:
    (out, attention): the attention-weighted values and the softmax weights.
  """
  # Cast to the input dtype so float16/float64 inputs also work.
  d_k = tf.cast(tf.shape(q)[-1], q.dtype)
  # transpose_b swaps only the last two axes, so any leading rank works.
  scaled = tf.linalg.matmul(q, k, transpose_b=True) / tf.math.sqrt(d_k)
  if mask is not None:
    mask = tf.cast(mask, scaled.dtype)
    if scaled.shape.rank == 4 and mask.shape.rank == 3:
      # (batch, q_len, k_len) -> (batch, 1, q_len, k_len). Added directly, the
      # batch axis would line up with the head axis and fail to broadcast.
      mask = mask[:, tf.newaxis, :, :]
    scaled = scaled + mask

  attention = tf.nn.softmax(scaled, axis=-1)
  out = tf.linalg.matmul(attention, v)
  return out, attention

class MultiheadAttention(tf.keras.layers.Layer):
  """Multi-head self-attention over an input assumed to be (batch, seq, d_model).

  One Dense layer produces the query, key and value projections for all heads
  at once. The per-head results are concatenated and passed through a final
  Dense projection back to d_model.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      # call() reshapes 3*d_model features into num_heads * 3*head_dim, which
      # only works when the heads split d_model evenly.
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model
    self.head_dim = d_model//num_heads
    self.num_heads = num_heads
    self.qkv_layer = tf.keras.layers.Dense(units=3*d_model, activation=None)
    self.linear_layer = tf.keras.layers.Dense(units=d_model, activation=None)

  def call(self,x,mask=None):
    """Attend from every position of ``x`` to every position of ``x``.

    Args:
      x: input tensor, assumed (batch, seq, d_model).
      mask: optional additive mask forwarded to scaled_dot_product_attention.

    Returns:
      Tensor of shape (batch, seq, d_model).
    """
    # Dynamic shapes keep this working when batch or sequence sizes are unknown
    # while tracing (x.shape would contain None there).
    batch_size = tf.shape(x)[0]
    sequence_length = tf.shape(x)[1]
    qkv = self.qkv_layer(x)
    qkv = tf.reshape(qkv, [batch_size, sequence_length, self.num_heads, 3*self.head_dim])
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])  # (batch, heads, seq, 3*head_dim)
    q, k, v = tf.split(qkv, 3, axis=-1)
    values, _ = scaled_dot_product_attention(q, k, v, mask)
    values = tf.transpose(values, perm=[0, 2, 1, 3])  # (batch, seq, heads, head_dim)
    values = tf.reshape(values, [batch_size, sequence_length, self.num_heads*self.head_dim])
    return self.linear_layer(values)

class LayerNormalization(tf.keras.layers.Layer):
  """Normalize over the trailing axes described by ``parameters_shape``, then scale and shift.

  Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``. The mean and the
  biased variance (mean of squared deviations) are taken over the last
  ``len(parameters_shape)`` axes. gamma and beta are trainable weights of
  shape ``parameters_shape``, initialised to 1 and 0.
  """
  def __init__(self, parameters_shape, eps=1e-5):
    super().__init__()
    self.parameters_shape = parameters_shape
    self.eps = eps
    self.gamma = self.add_weight(
        shape=parameters_shape,
        initializer=tf.keras.initializers.Constant(value=1),
        trainable=True,
    )
    self.beta = self.add_weight(
        shape=parameters_shape,
        initializer=tf.keras.initializers.Constant(value=0),
        trainable=True,
    )

  def call(self, input):
    # Axes -1, -2, ..., -len(parameters_shape): the trailing dims being normalized.
    norm_axes = list(range(-1, -len(self.parameters_shape) - 1, -1))
    centered = input - tf.math.reduce_mean(input, axis=norm_axes, keepdims=True)
    variance = tf.math.reduce_mean(centered ** 2, axis=norm_axes, keepdims=True)
    return self.gamma * (centered / tf.math.sqrt(variance + self.eps)) + self.beta

class PositionalEncoding(tf.keras.layers.Layer):
  """Sinusoidal positional-encoding table; the layer has no weights.

  Row ``pos`` interleaves sine and cosine:
  ``PE[pos, 2i] = sin(pos / 10000**(2i/d_model))`` and
  ``PE[pos, 2i+1] = cos(pos / 10000**(2i/d_model))``.

  The sizes may be given to the constructor, to ``call``, or both; ``call``
  arguments take precedence.
  NOTE(review): Keras's ``Layer.__call__`` may insist on an explicit first
  argument, so ``layer()`` with no arguments might be rejected even when the
  constructor stored the sizes. Pass the sizes to be safe.
  """
  def __init__(self, d_model=None, max_sequence_length=None, **kwargs):
    super().__init__(**kwargs)
    self.d_model = d_model
    self.max_sequence_length = max_sequence_length

  def call(self, d_model=None, max_sequence_length=None):
    """Return the (max_sequence_length, d_model) table in the layer's compute dtype.

    Raises:
      ValueError: if a size was given neither here nor to the constructor.
    """
    if d_model is None:
      d_model = self.d_model
    if max_sequence_length is None:
      max_sequence_length = self.max_sequence_length
    if d_model is None or max_sequence_length is None:
      raise ValueError(
          "PositionalEncoding needs d_model and max_sequence_length, "
          "either in the constructor or in call().")
    # tf.reshape requires integer sizes. Keep the sizes integral and do the
    # trigonometry in float64 for precision.
    depth = tf.cast(d_model, tf.int32)
    length = tf.cast(max_sequence_length, tf.int32)
    even_i = tf.cast(tf.range(0, depth, 2), tf.float64)
    denominator = tf.math.pow(tf.constant(10000.0, tf.float64), even_i / tf.cast(depth, tf.float64))
    position = tf.cast(tf.range(length), tf.float64)[:, tf.newaxis]  # (length, 1)
    angles = position / denominator  # (length, ceil(depth / 2))
    stacked = tf.stack([tf.math.sin(angles), tf.math.cos(angles)], axis=-1)
    # Merging the last two axes interleaves sin/cos column by column. The slice
    # drops the surplus cosine column when depth is odd.
    PE = tf.reshape(stacked, [length, -1])[:, :depth]
    # Match the layer's compute dtype so the table can be added to embeddings.
    return tf.cast(PE, self.compute_dtype)

class PositionwiseFeedForward(tf.keras.layers.Layer):
  """Feed-forward sub-layer: Dense(hidden) -> ReLU -> Dropout -> Dense(d_model).

  The Dense layers act on the last axis, so each sequence position is
  transformed independently with shared weights.
  """
  def __init__(self,d_model, hidden, drop_prob):
    super().__init__()
    # Expand to the hidden width, then project back to d_model.
    self.linear1 = tf.keras.layers.Dense(hidden, input_shape=[d_model, 1], activation=None)
    self.linear2 = tf.keras.layers.Dense(d_model, input_shape=[hidden, 1], activation=None)
    self.relu = tf.keras.layers.ReLU()
    self.dropout = tf.keras.layers.Dropout(drop_prob)

  def call(self, x):
    hidden_activations = self.dropout(self.relu(self.linear1(x)))
    return self.linear2(hidden_activations)

class EncoderLayer(tf.keras.layers.Layer):
  """One encoder block with two sub-layers.

  The sub-layers are self-attention and a position-wise feed-forward network.
  Each is followed by dropout, a residual add and LayerNormalization.
  """
  def __init__(self,d_model,ffn_hidden,num_heads,drop_prob):
    super(EncoderLayer, self).__init__()
    self.attention = MultiheadAttention(d_model, num_heads)
    self.norm1 = LayerNormalization(parameters_shape=[d_model])
    self.dropout1 = tf.keras.layers.Dropout(drop_prob)
    self.ffn = PositionwiseFeedForward(d_model = d_model, hidden=ffn_hidden, drop_prob=drop_prob)
    self.norm2 = LayerNormalization(parameters_shape=[d_model])
    self.dropout2 = tf.keras.layers.Dropout(drop_prob)

  def call(self, x, self_attention_mask=None):
    """Run one encoder block.

    Args:
      x: input tensor, assumed (batch, seq, d_model).
      self_attention_mask: optional additive mask for the self-attention,
        e.g. the encoder mask from ``create_masks``. None attends everywhere.

    Returns:
      Tensor with the same shape as ``x``.
    """
    residual_x = x
    # Forward the caller's mask so padded positions are not attended to.
    x = self.attention(x, mask=self_attention_mask)
    x = self.dropout1(x)
    x = self.norm1(x + tf.cast(residual_x, dtype=x.dtype))
    residual_x = x
    x = self.ffn(x)
    x = self.dropout2(x)
    x = self.norm2(x + tf.cast(residual_x, dtype=x.dtype))
    return x

class Encoder(tf.keras.layers.Layer):
  """Stack of ``num_layers`` EncoderLayer blocks applied in order.

  max_sequence_length, language_to_index and the special tokens are accepted
  for a sentence-embedding step that is currently disabled. They are unused,
  as are ``start_token``/``end_token`` in ``call``. ``x`` is therefore
  presumably already embedded by the caller.
  """
  def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, language_to_index, START_TOKEN,END_TOKEN, PADDING_TOKEN):
    super().__init__()
    # Disabled: self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
    self.layers = [
        EncoderLayer(d_model=d_model, ffn_hidden=ffn_hidden, num_heads=num_heads, drop_prob=drop_prob)
        for _ in range(num_layers)
    ]
    self.num_layers = num_layers

  def call(self, x, self_attention_mask, start_token, end_token):
    # Disabled: x = self.sentence_embedding(x, start_token, end_token)
    for encoder_layer in self.layers:
      x = encoder_layer(x, self_attention_mask)
    return x

In [None]:
class MultiheadCrossAttention(tf.keras.layers.Layer):
  """Multi-head encoder-decoder attention.

  Keys and values are projected from the encoder output ``x``. Queries are
  projected from the decoder state ``y``. The two sequences may differ in length.
  """
  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      # The per-head reshapes in call() need d_model split evenly across heads.
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model
    self.head_dim = d_model//num_heads
    self.num_heads = num_heads
    self.kv_layer = tf.keras.layers.Dense(units=2*d_model, activation=None)
    self.q_layer = tf.keras.layers.Dense(units=d_model, activation=None)
    self.linear_layer = tf.keras.layers.Dense(units=d_model, activation=None)

  def call(self, x, y, mask=None):
    """Attend from each decoder position in ``y`` to the encoder output ``x``.

    Args:
      x: encoder output, assumed (batch, kv_len, d_model).
      y: decoder state, assumed (batch, q_len, d_model).
      mask: optional additive mask forwarded to scaled_dot_product_attention.

    Returns:
      (out, attention): ``out`` is (batch, q_len, d_model); ``attention`` holds
      the per-head attention weights.
    """
    batch_size = tf.shape(x)[0]
    kv_length = tf.shape(x)[1]
    q_length = tf.shape(y)[1]
    kv = self.kv_layer(x)
    # Queries must come from the decoder side; projecting them from x would
    # ignore the decoder state entirely.
    q = self.q_layer(y)
    kv = tf.reshape(kv, [batch_size, kv_length, self.num_heads, 2*self.head_dim])
    q = tf.reshape(q, [batch_size, q_length, self.num_heads, self.head_dim])
    kv = tf.transpose(kv, perm=[0, 2, 1, 3])  # (batch, heads, kv_len, 2*head_dim)
    q = tf.transpose(q, perm=[0, 2, 1, 3])    # (batch, heads, q_len, head_dim)
    k, v = tf.split(kv, 2, axis=-1)
    values, attention = scaled_dot_product_attention(q, k, v, mask)
    values = tf.transpose(values, perm=[0, 2, 1, 3])  # (batch, q_len, heads, head_dim)
    values = tf.reshape(values, [batch_size, q_length, self.num_heads*self.head_dim])
    out = self.linear_layer(values)
    return out, attention

class DecoderLayer(tf.keras.layers.Layer):
  """One decoder block with three sub-layers.

  The sub-layers run in order: self-attention over ``y``, encoder-decoder
  attention against ``x``, then a position-wise feed-forward network. Each
  output goes through dropout, a residual add and LayerNormalization.
  """
  def __init__(self,d_model,ffn_hidden,num_heads,drop_prob):
    super().__init__()
    # Sub-layer 1: self-attention over the decoder input.
    self.self_attention = MultiheadAttention(d_model, num_heads)
    self.norm1 = LayerNormalization(parameters_shape=[d_model])
    self.dropout1 = tf.keras.layers.Dropout(drop_prob)

    # Sub-layer 2: attention from the decoder state to the encoder output.
    self.encoder_decoder_attention = MultiheadCrossAttention(d_model=d_model, num_heads=num_heads)
    self.norm2 = LayerNormalization(parameters_shape=[d_model])
    self.dropout2 = tf.keras.layers.Dropout(drop_prob)

    # Sub-layer 3: position-wise feed-forward network.
    self.ffn = PositionwiseFeedForward(d_model = d_model, hidden=ffn_hidden, drop_prob=drop_prob)
    self.norm3 = LayerNormalization(parameters_shape=[d_model])
    self.dropout3 = tf.keras.layers.Dropout(drop_prob)

  @staticmethod
  def _add_and_norm(sublayer_out, residual, dropout, norm):
    """Apply dropout, add the residual (cast to the same dtype), then normalize."""
    dropped = dropout(sublayer_out)
    return norm(dropped + tf.cast(residual, dtype=dropped.dtype))

  def call(self, x, y , self_attention_mask, cross_attention_mask):
    attended = self.self_attention(y, mask=self_attention_mask)
    y = self._add_and_norm(attended, y, self.dropout1, self.norm1)

    # The attention weights returned by the cross-attention are not used here.
    crossed, _ = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)
    y = self._add_and_norm(crossed, y, self.dropout2, self.norm2)

    fed_forward = self.ffn(y)
    y = self._add_and_norm(fed_forward, y, self.dropout3, self.norm3)
    return y

class Decoder(tf.keras.layers.Layer):
  """Stack of ``num_layers`` DecoderLayer blocks applied in order.

  max_sequence_length, language_to_index and the special tokens are accepted
  for a sentence-embedding step that is currently disabled. They are unused,
  as are ``start_token``/``end_token`` in ``call``. ``y`` is therefore
  presumably already embedded by the caller.
  """
  def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
    super().__init__()
    # Disabled: self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
    self.layers = [
        DecoderLayer(d_model=d_model, ffn_hidden=ffn_hidden, num_heads=num_heads, drop_prob=drop_prob)
        for _ in range(num_layers)
    ]
    self.num_layers = num_layers

  def call(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
    # Disabled: y = self.sentence_embedding(y, start_token, end_token)
    for decoder_layer in self.layers:
      y = decoder_layer(x, y, self_attention_mask, cross_attention_mask)
    return y

In [None]:
class SentenceEmbedding(tf.keras.layers.Layer):
    """Turn a batch of raw sentences into embeddings with positional encoding added.

    Sentences are tokenized character by character through ``language_to_index``.
    The tokens are embedded, summed with a sinusoidal positional table and
    passed through dropout.
    """
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN, drop_prob=0.1):
        """
        Args:
          max_sequence_length: length every tokenized sentence is padded to.
          d_model: embedding width.
          language_to_index: mapping from character/special token to vocabulary index.
          START_TOKEN, END_TOKEN, PADDING_TOKEN: keys of the special tokens in
            ``language_to_index``.
          drop_prob: dropout rate applied after adding positions. Defaults to
            0.1, the notebook's ``drop_prob`` hyperparameter value.
        """
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        # Sizes are passed at call time, so this works whether or not
        # PositionalEncoding's constructor accepts them.
        self.position_encoder = PositionalEncoding()
        self.dropout = tf.keras.layers.Dropout(drop_prob)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN

    def batch_tokenize(self, batch, start_token, end_token):
        """Map each sentence to a tensor of vocabulary indices and stack them.

        Each character is one token. START/END indices are added when
        ``start_token``/``end_token`` are truthy. PADDING_TOKEN then fills
        the row up to ``max_sequence_length``. Rows that are already longer
        are not truncated, so ``tf.stack`` fails for such a batch.

        Raises:
          KeyError: if a character or a requested special token is missing
            from ``language_to_index``.
        """

        def tokenize(sentence):
            indices = [self.language_to_index[token] for token in sentence]
            if start_token:
                indices.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                indices.append(self.language_to_index[self.END_TOKEN])
            padding = self.max_sequence_length - len(indices)
            if padding > 0:
                indices.extend([self.language_to_index[self.PADDING_TOKEN]] * padding)
            return tf.convert_to_tensor(indices)

        return tf.stack([tokenize(sentence) for sentence in batch])

    def call(self, x, start_token, end_token): # sentence
        """Embed a batch of sentences into (batch, max_sequence_length, d_model)."""
        x = self.batch_tokenize(x, start_token, end_token)
        x = self.embedding(x)
        pos = self.position_encoder(self.d_model, self.max_sequence_length)
        # The positional table may use a different float dtype than the embeddings.
        x = self.dropout(x + tf.cast(pos, x.dtype))
        return x

# Additive value for blocked attention positions. It is finite, so a row that
# is blocked everywhere still gets a (uniform) softmax instead of NaN.
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch, seq_len=None):
    """Build additive attention masks for a batch of English/Kannada sentence pairs.

    For a sentence of length ``L``, positions ``0..L`` stay visible and
    positions ``L+1 .. seq_len-1`` are treated as padding. The extra visible
    slot is presumably for a START or END token. Padding is blocked both as a
    key (column) and as a query (row). Blocked entries hold NEG_INFTY and
    visible ones 0.0, so the masks can be added to attention scores before
    the softmax.

    Args:
      eng_batch: encoder-side sentences; only ``len`` of each is used.
      kn_batch: decoder-side sentences, indexed in step with eng_batch.
      seq_len: padded sequence length. Defaults to the module-level
        ``max_sequence_length``.

    Returns:
      (encoder_self_attention_mask, decoder_self_attention_mask,
      decoder_cross_attention_mask), each a float32 tensor of shape
      (len(eng_batch), seq_len, seq_len). The decoder self-attention mask also
      blocks future positions (look-ahead).
    """
    if seq_len is None:
        seq_len = max_sequence_length
    num_sentences = len(eng_batch)
    # Build the masks as NumPy arrays: TensorFlow tensors do not support item
    # assignment, and boolean tensors cannot be combined with `+`.
    look_ahead_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    encoder_padding_mask = np.zeros((num_sentences, seq_len, seq_len), dtype=bool)
    decoder_padding_mask_self_attention = np.zeros_like(encoder_padding_mask)
    decoder_padding_mask_cross_attention = np.zeros_like(encoder_padding_mask)

    for idx, eng_sentence in enumerate(eng_batch):
        kn_sentence = kn_batch[idx]
        eng_chars_to_padding_mask = np.arange(len(eng_sentence) + 1, seq_len)
        kn_chars_to_padding_mask = np.arange(len(kn_sentence) + 1, seq_len)
        encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
        encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
        decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
        # Cross attention: decoder queries (rows) attend to encoder keys (columns).
        decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
        decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    def to_additive(blocked):
        return tf.convert_to_tensor(np.where(blocked, NEG_INFTY, 0.0).astype(np.float32))

    encoder_self_attention_mask = to_additive(encoder_padding_mask)
    # Logical OR broadcasts the (seq, seq) look-ahead mask over the batch.
    decoder_self_attention_mask = to_additive(look_ahead_mask | decoder_padding_mask_self_attention)
    decoder_cross_attention_mask = to_additive(decoder_padding_mask_cross_attention)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [None]:
class Transformer(tf.keras.layers.Layer):
  """Encoder-decoder transformer ending in a softmax over the Kannada vocabulary."""
  def __init__(self,
               d_model,
               ffn_hidden,
               num_heads,
               drop_prob,
               num_layers,
               max_sequence_length,
               kn_vocab_size,
               english_to_index,
               START_TOKEN,
               END_TOKEN,
               PADDING_TOKEN):
    super().__init__()

    self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, english_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
    # NOTE(review): the decoder also receives english_to_index. It is not used
    # while the decoder's sentence embedding is disabled, but it will need the
    # Kannada vocabulary once that embedding is enabled.
    self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, english_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
    # Project every decoder position onto the target (Kannada) vocabulary.
    self.linear = tf.keras.layers.Dense(kn_vocab_size, activation=None)

  def call(self,
           x,
           y,
           encoder_self_attention_mask=None,
           decoder_self_attention_mask=None,
           decoder_cross_attention_mask=None,
           enc_start_token=False,
           enc_end_token=False,
           dec_start_token=False, # TODO: presumably should default to True
           dec_end_token=False):
    """Encode ``x``, decode ``y`` against it, and return per-position vocabulary probabilities.

    Args:
      x, y: encoder and decoder inputs (batches of sentences per the original
        notes; with the embedding steps disabled they must already be embedded).
      *_mask: optional additive masks, e.g. from ``create_masks``.
      *_start_token / *_end_token: forwarded to the encoder/decoder.

    Returns:
      Softmax probabilities (not logits) over the last axis, whose size is kn_vocab_size.
    """
    x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
    out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
    out = self.linear(out)
    out = tf.nn.softmax(out)
    return out