<a href="https://colab.research.google.com/github/juhumkwon/source_code/blob/main/Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k) + mask) V.

    The layer creates no weights; it only combines the query, key and value
    tensors passed to ``call``.
    """

    def __init__(self, **kwargs):
        # Forward standard Keras Layer kwargs (e.g. ``name``); calling with no
        # arguments behaves exactly as before.
        super().__init__(**kwargs)

    @staticmethod
    def _mask_fill_value(dtype):
        """Return a large negative constant that stays finite in ``dtype``.

        -1e9 overflows to -inf in float16, and ``0 * -inf`` is NaN, which
        would poison every *unmasked* logit; use the float16 minimum there.
        """
        if dtype == tf.float16:
            return tf.float16.min
        return -1e9

    def call(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k`` and return the weighted sum of ``v``.

        Args:
            q: Query tensor. In this file it is (batch, num_heads, seq_len, depth);
                other batch ranks are assumed to work via tf.matmul batching.
            k: Key tensor with the same last-axis size (depth) as ``q``.
            v: Value tensor whose second-to-last axis matches ``k``'s.
            mask: Optional numeric tensor broadcastable to the attention logits
                (..., seq_len_q, seq_len_k). 1 marks a position to ignore,
                0 keeps it. It is cast to the logits' dtype before use.

        Returns:
            ``(output, attention_weights)``: ``output`` is softmax-weighted ``v``
            (..., seq_len_q, depth_v); ``attention_weights`` is the softmax over
            the last (key) axis (..., seq_len_q, seq_len_k).

        Raises:
            TypeError: if ``mask`` is boolean. Keras' own masks use True = keep,
                the opposite of this layer's 1 = ignore convention, so a bool
                mask is rejected rather than silently inverted.
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        logits_dtype = matmul_qk.dtype

        # Scale in the logits' own dtype so float16/float64 inputs do not
        # collide with a hard-coded float32 constant.
        dk = tf.cast(tf.shape(k)[-1], logits_dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            mask = tf.convert_to_tensor(mask)
            if mask.dtype == tf.bool:
                raise TypeError(
                    "mask must be numeric with 1 marking positions to ignore; "
                    "got a boolean mask (Keras boolean masks use True = keep)"
                )
            mask = tf.cast(mask, logits_dtype)
            # Masked positions get a very large negative logit, so softmax
            # assigns them (effectively) zero weight.
            scaled_attention_logits += mask * self._mask_fill_value(logits_dtype)

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

        return output, attention_weights

In [3]:

class MultiHeadSelfAttention(Layer):
    """Multi-head self-attention over a single input sequence.

    The same input ``x`` is linearly projected into queries, keys and values,
    split into ``num_heads`` heads of size ``d_model // num_heads``, passed
    through ScaledDotProductAttention, re-merged, and projected back to
    ``d_model``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        """Create the Q/K/V/output projections.

        Args:
            d_model: Size of the last dimension of both input and output.
            num_heads: Number of attention heads; must be a positive divisor
                of ``d_model``.
            **kwargs: Forwarded to the Keras ``Layer`` constructor (e.g. ``name``).

        Raises:
            ValueError: if ``d_model`` or ``num_heads`` is not positive, or if
                ``num_heads`` does not divide ``d_model``.
        """
        super().__init__(**kwargs)
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and the old check also let non-positive sizes through.
        if d_model <= 0 or num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)
        self.attention = ScaledDotProductAttention()

    def split_heads(self, x, batch_size):
        """Reshape the last axis into heads and move the head axis forward.

        (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input tensor, assumed (batch_size, seq_len, d_model).
            mask: Optional mask forwarded unchanged to ScaledDotProductAttention;
                presumably broadcastable to (batch_size, num_heads, seq_len,
                seq_len), with 1 marking positions to ignore.

        Returns:
            ``(output, attention_weights)`` with shapes
            (batch_size, seq_len, d_model) and
            (batch_size, num_heads, seq_len, seq_len).
        """
        batch_size = tf.shape(x)[0]

        q = self.split_heads(self.wq(x), batch_size)
        k = self.split_heads(self.wk(x), batch_size)
        v = self.split_heads(self.wv(x), batch_size)

        scaled_attention, attention_weights = self.attention(q, k, v, mask)

        # Merge heads: (batch, num_heads, seq_len, depth)
        # -> (batch, seq_len, num_heads, depth) -> (batch, seq_len, d_model)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len, d_model)
        return output, attention_weights

In [4]:

# Hyperparameters for a small smoke test of the layer.
batch_size, seq_len, d_model, num_heads = 2, 5, 512, 8

# Random stand-in for a batch of embedded sequences: (batch, seq_len, d_model).
dummy_input = tf.random.uniform(shape=(batch_size, seq_len, d_model))

# Build the self-attention layer and run a single forward pass.
self_attention = MultiHeadSelfAttention(d_model=d_model, num_heads=num_heads)
output, attn_weights = self_attention(dummy_input)

# Expected shapes: output (2, 5, 512); attention weights (2, 8, 5, 5).
print("출력 shape:", output.shape)
print("어텐션 가중치 shape:", attn_weights.shape)

출력 shape: (2, 5, 512)
어텐션 가중치 shape: (2, 8, 5, 5)
