In [1]:
import numpy as np
import pandas as pd

import tensorflow as tf
import keras

In [16]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention.

        attention(Q, K, V) = softmax(Q K^T / sqrt(depth) + mask * large_negative) V

    ``...`` below stands for any leading dimensions; in a multi-head layer these
    would typically be (batch_size, num_heads), with depth = d_model / num_heads.
    q, k and v must have matching leading dimensions, q and k must share the
    same depth, and k and v must share the key-sequence length
    (seq_len_k == seq_len_v). The value depth may differ from the q/k depth.

    Args:
        query: tensor of shape (..., seq_len_q, depth).
        key: tensor of shape (..., seq_len_k, depth).
        value: tensor of shape (..., seq_len_v, depth_v).
        mask: optional tensor broadcastable to (..., seq_len_q, seq_len_k),
            e.g. (..., 1, 1, seq_len_k) for a padding mask. Positions where
            mask == 1 get a very large negative logit, so softmax gives them
            (effectively) zero weight. Float, int or bool masks are accepted.
            Defaults to None (no masking).

    Returns:
        Tuple ``(attention_values, attention_weights)``:
            attention_values: shape (..., seq_len_q, depth_v).
            attention_weights: shape (..., seq_len_q, seq_len_k); softmax over
                the last (key) axis, so each row sums to 1.
    """
    # Raw scores Q K^T: key is transposed over its last two dims, so the depth
    # axis is contracted -> (..., seq_len_q, seq_len_k).
    matmul_qk = tf.matmul(query, key, transpose_b=True)

    # Scale by sqrt(depth). Cast to the scores' own dtype rather than a fixed
    # tf.float32 so float64/float16 inputs do not hit a dtype mismatch; for
    # float32 inputs this is identical to casting to tf.float32.
    depth = tf.cast(tf.shape(key)[-1], matmul_qk.dtype)
    attention_logits = matmul_qk / tf.math.sqrt(depth)

    if mask is not None:
        # -1e9 overflows to -inf in float16, and 0 * -inf == NaN would poison
        # the unmasked positions; use float16's most negative finite value
        # there instead (same approach as the Keras Softmax layer).
        if attention_logits.dtype == tf.float16:
            large_negative = tf.float16.min
        else:
            large_negative = -1e9
        # Cast so boolean/integer masks can be passed directly.
        attention_logits += tf.cast(mask, attention_logits.dtype) * large_negative

    # Normalize over the key axis (seq_len_k) -> attention distribution.
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)

    # Weighted sum of the values -> (..., seq_len_q, depth_v).
    attention_values = tf.matmul(attention_weights, value)

    return attention_values, attention_weights

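### logit :
 - The final output value just before it is converted into a probability (== score).
 - The value at the last node before any activation function (e.g. softmax) is applied.
 - In `scaled_dot_product_attention`, `attention_logits` $= QK^T / \sqrt{d_k}$ (plus the mask term) — exactly the input to `tf.nn.softmax`.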

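### tf.matmul:
 - For inputs of rank > 2, `tf.matmul` multiplies the last two dimensions as matrices and treats all leading dimensions as batch dimensions, e.g. `(..., seq_len_q, depth) x (..., depth, seq_len_k) -> (..., seq_len_q, seq_len_k)`. With `transpose_b=True`, the key can be passed as `(..., seq_len_k, depth)` and is transposed over its last two dimensions.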

In [17]:
# Print floats in fixed-point notation (no scientific notation), so values that
# are practically zero show up as 0. in the tensor printouts below.
np.set_printoptions(suppress=True)

# Toy inputs: one query and four keys, all of depth 3, plus four values of
# width 2. Keys 2 and 3 are identical, so a query aligned with them splits its
# attention evenly between the two.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)

temp_k = tf.constant(
    [[10, 0, 0],
     [0, 10, 0],
     [0, 0, 10],
     [0, 0, 10]],
    dtype=tf.float32,
)  # (4, 3)

temp_v = tf.constant(
    [[1, 0],
     [10, 0],
     [100, 5],
     [1000, 6]],
    dtype=tf.float32,
)  # (4, 2)

In [15]:
# Run the function on the toy inputs.
# temp_q equals temp_k[1] and is orthogonal to every other key, so essentially
# all attention goes to key 1 and the output is (approximately) temp_v[1].
temp_out, temp_attn = scaled_dot_product_attention(temp_q, temp_k, temp_v, None)
print(temp_attn)  # attention distribution (== attention weights)
print(temp_out)   # attention values

# Pin the expectation above so Restart & Run All actually checks it
# (the non-target weights are ~1e-25, hence the small atol).
np.testing.assert_allclose(temp_attn.numpy(), [[0., 1., 0., 0.]], rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(temp_out.numpy(), temp_v.numpy()[1:2], rtol=1e-6, atol=1e-6)

!
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)


In [11]:
# temp_q matches both temp_k[2] and temp_k[3] (which are identical), so the
# attention is split evenly between them:
#   temp_out = 0.5 * temp_v[2] + 0.5 * temp_v[3] = [550, 5.5]
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32)
temp_out, temp_attn = scaled_dot_product_attention(temp_q, temp_k, temp_v, None)
print(temp_attn)
print(temp_out)

# Pin the expected split and the resulting weighted average.
np.testing.assert_allclose(temp_attn.numpy(), [[0., 0., 0.5, 0.5]], rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(temp_out.numpy(), [[550., 5.5]], rtol=1e-6, atol=1e-6)

tf.Tensor([[0.  0.  0.5 0.5]], shape=(1, 4), dtype=float32)
tf.Tensor([[550.    5.5]], shape=(1, 2), dtype=float32)


In [12]:
# A batch of three queries at once; each row is attended independently:
#   row 0 = [0, 0, 10]  -> ties keys 2 and 3 -> 0.5*temp_v[2] + 0.5*temp_v[3]
#   row 1 = [0, 10, 0]  -> matches key 1     -> temp_v[1]
#   row 2 = [10, 10, 0] -> ties keys 0 and 1 -> 0.5*temp_v[0] + 0.5*temp_v[1]
temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32)  # (3, 3)
temp_out, temp_attn = scaled_dot_product_attention(temp_q, temp_k, temp_v, None)
print(temp_attn)
print(temp_out)

# Pin the per-row expectations listed above.
np.testing.assert_allclose(
    temp_attn.numpy(),
    [[0., 0., 0.5, 0.5],
     [0., 1., 0., 0.],
     [0.5, 0.5, 0., 0.]],
    rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(
    temp_out.numpy(),
    [[550., 5.5],
     [10., 0.],
     [5.5, 0.]],
    rtol=1e-6, atol=1e-6)

tf.Tensor(
[[0.  0.  0.5 0.5]
 [0.  1.  0.  0. ]
 [0.5 0.5 0.  0. ]], shape=(3, 4), dtype=float32)
tf.Tensor(
[[550.    5.5]
 [ 10.    0. ]
 [  5.5   0. ]], shape=(3, 2), dtype=float32)
