In [2]:
import numpy as np
import pandas as pd

import tensorflow as tf
import keras

In [3]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention.

        attention(Q, K, V) = softmax(Q K^T / sqrt(depth) + mask * (-1e9)) V

    q, k, v must have matching leading dimensions, and k and v must share the
    same sequence length (seq_len_k == seq_len_v) so the final matmul is valid.

    depth : d_model / num_heads
    ... : batch_size

    Args:
        query: query shape == (..., num_heads, seq_len_q, depth)
        key: key shape == (..., num_heads, seq_len_k, depth)
        value: value shape == (..., num_heads, seq_len_v, depth_v)
        mask: None (no masking) or a tensor broadcastable to
            (..., num_heads, seq_len_q, seq_len_k), with 1 at positions to
            exclude and 0 at positions to keep, e.g. (..., 1, 1, seq_len_k)
            for a padding mask. Any numeric or bool dtype is accepted; it is
            cast to the dtype of the attention logits.

    Returns:
        attention_values: (..., num_heads, seq_len_q, depth_v)
        attention_weights: (..., num_heads, seq_len_q, seq_len_k)
    """
    # Q * K^T (K transposed): (..., num_heads, seq_len_q, seq_len_k)
    matmul_qk = tf.matmul(a=query, b=key, transpose_b=True)

    # Scale by sqrt(depth). Cast to the logits dtype rather than a fixed
    # float32 so float64 / float16 inputs do not hit a dtype-mismatch error.
    depth = tf.cast(tf.shape(key)[-1], matmul_qk.dtype)
    attention_logits = matmul_qk / tf.math.sqrt(depth)

    if mask is not None:
        # Masked positions receive a large negative logit so softmax pushes
        # their weight to ~0. Clamp the constant to the dtype's most negative
        # finite value: in float16, -1e9 would overflow to -inf and
        # 0 * -inf = NaN would poison the unmasked positions.
        large_negative = max(-1e9, attention_logits.dtype.min)
        attention_logits += tf.cast(mask, attention_logits.dtype) * large_negative

    # Softmax over the last axis (seq_len_k) gives the attention distribution.
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)  # (..., num_heads, seq_len_q, seq_len_k)

    # Weighted sum of values: (..., num_heads, seq_len_q, depth_v)
    attention_values = tf.matmul(attention_weights, value)

    return attention_values, attention_weights

In [4]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention layer.

    Projects query/key/value with learned dense layers (WQ, WK, WV), splits
    each projection into `num_heads` heads of size `depth = d_model // num_heads`,
    runs scaled dot-product attention per head, concatenates the heads and
    applies a final dense projection (WO).

    Args:
        d_model: model dimension; must be divisible by num_heads.
        num_heads: number of attention heads.
        name: layer name.
        **kwargs: forwarded unchanged to tf.keras.layers.Layer.

    Raises:
        ValueError: if d_model is not divisible by num_heads.
    """

    def __init__(self, d_model, num_heads, name="multi_head_attention", **kwargs):
        super().__init__(name=name, **kwargs)
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})."
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads  # per-head size; exact thanks to the check above

        # WQ, WK, WV: input projections
        self.query_dense = tf.keras.layers.Dense(units=d_model)
        self.key_dense = tf.keras.layers.Dense(units=d_model)
        self.value_dense = tf.keras.layers.Dense(units=d_model)
        # WO: output projection applied after the heads are concatenated
        self.dense = tf.keras.layers.Dense(units=d_model)

    def split_heads(self, inputs, batch_size):
        """Split the last axis into (num_heads, depth) and move heads forward.

        Args:
            inputs: shape == (batch_size, seq_len, d_model)
            batch_size: batch size (a scalar, possibly a dynamic tensor)

        Returns:
            result: shape == (batch_size, num_heads, seq_len, depth)
        """
        # -1 lets tf.reshape infer seq_len from the total element count.
        inputs = tf.reshape(tensor=inputs, shape=(batch_size, -1, self.num_heads, self.depth))  # (batch_size, seq_len, num_heads, depth)
        return tf.transpose(a=inputs, perm=[0, 2, 1, 3])  # (batch_size, num_heads, seq_len, depth)

    def call(self, inputs):
        """Apply multi-head attention.

        Args:
            inputs: dict with keys
                'query': (batch_size, seq_len_q, d_model)
                'key':   (batch_size, seq_len_k, d_model)
                'value': (batch_size, seq_len_v, d_model), with seq_len_v == seq_len_k
                'mask' (optional): None, or a tensor broadcastable to
                    (batch_size, num_heads, seq_len_q, seq_len_k) with 1 at
                    positions to exclude, e.g. (batch_size, 1, 1, seq_len_k)
                    for a padding mask. A missing key is treated as None.

        Returns:
            attention_values: (batch_size, seq_len_q, d_model)
            attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        query, key, value = inputs['query'], inputs['key'], inputs['value']
        # The mask is optional: scaled_dot_product_attention skips masking for None.
        mask = inputs.get('mask')
        batch_size = tf.shape(input=query)[0]

        # 1. Q, K, V linear projections
        query = self.query_dense(query)  # (batch_size, seq_len_q, d_model)
        key = self.key_dense(key)  # (batch_size, seq_len_k, d_model)
        value = self.value_dense(value)  # (batch_size, seq_len_v, d_model)

        # 2. split heads
        query = self.split_heads(query, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        key = self.split_heads(key, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        value = self.split_heads(value, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # 3. scaled dot-product attention per head
        temp_attention_values, attention_weights = scaled_dot_product_attention(query, key, value, mask)  # (batch_size, num_heads, seq_len_q, depth)

        # 4. move heads back next to depth and concatenate them
        temp_attention_values = tf.transpose(a=temp_attention_values, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_temp_attention_values = tf.reshape(tensor=temp_attention_values, shape=(batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        # 5. final linear projection (WO)
        attention_values = self.dense(concat_temp_attention_values)  # (batch_size, seq_len_q, d_model)

        return attention_values, attention_weights
    

### tf.reshape
 - A `-1` in the target shape marks an "unspecified" dimension; at most one dimension may be `-1`.
 - `tf.reshape` infers that dimension from the total number of elements and the other dimensions, i.e. "reshape into whatever size is needed so that the total number of elements stays the same". For example, `split_heads` reshapes to `(batch_size, -1, num_heads, depth)`, so the `-1` recovers `seq_len`.

In [8]:
def create_mask(x, pad_id=0):
    """Create a padding mask for a batch of token-id sequences.

    Positions equal to `pad_id` become 1.0 (masked) and all other positions
    0.0. Attention adds the mask, times a large negative constant, to its
    logits, so 1.0 entries are effectively excluded.

    Args:
        x: token ids, assumed 2-D with shape (batch_size, seq_len).
        pad_id: id used for padding (default 0).

    Returns:
        mask: float32 tensor of shape (batch_size, 1, 1, seq_len). The two
            singleton axes broadcast over num_heads and seq_len_q.
    """
    is_padding = tf.math.equal(x, pad_id)
    mask = tf.cast(is_padding, tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

In [10]:
# Three sample sequences: partially padded, fully padded, and unpadded.
sample_token_ids = tf.constant([
    [1, 2, 3, 0, 0],
    [0, 0, 0, 0, 0],
    [1, 2, 3, 4, 5],
])
padding_mask = create_mask(sample_token_ids)
print(padding_mask)

tf.Tensor(
[[[[0. 0. 0. 1. 1.]]]


 [[[1. 1. 1. 1. 1.]]]


 [[[0. 0. 0. 0. 0.]]]], shape=(3, 1, 1, 5), dtype=float32)
