<a href="https://colab.research.google.com/github/juhumkwon/source_code/blob/main/SelfAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class SelfAttention(Layer):
    """Single-head scaled dot-product self-attention.

    The same input is projected by three independent ``Dense`` layers into
    queries, keys and values (each with ``d_model`` units). Attention scores
    are the query/key dot products divided by ``sqrt(d_model)``, normalised
    with a softmax over the last axis, and used to take a weighted sum of the
    values.

    Args:
        d_model: Number of units of each projection; its square root is the
            divisor applied to the raw attention scores.
        **kwargs: Standard Keras ``Layer`` keyword arguments (for example
            ``name`` or ``dtype``), forwarded to the base constructor.
    """

    def __init__(self, d_model, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model

        # Separate learned projections for queries, keys and values.
        self.query_dense = tf.keras.layers.Dense(d_model)
        self.key_dense = tf.keras.layers.Dense(d_model)
        self.value_dense = tf.keras.layers.Dense(d_model)
        # Normalise each query's scores across all key positions.
        self.softmax = tf.keras.layers.Softmax(axis=-1)

    def call(self, inputs):
        """Apply self-attention to ``inputs``.

        Args:
            inputs: Float tensor whose last two axes are
                (sequence, features), e.g. ``(batch, seq_len, features)``.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` is the
            attention-weighted sum of the value projections and
            ``attention_weights`` holds the softmax-normalised scores
            (last two axes: query position, key position).
        """
        queries = self.query_dense(inputs)
        keys = self.key_dense(inputs)
        values = self.value_dense(inputs)

        attention_scores = tf.matmul(queries, keys, transpose_b=True)
        # Cast the divisor to the scores' own dtype: a hard-coded float32
        # divisor fails with a dtype mismatch when the layer computes in
        # float16/bfloat16 (e.g. under a mixed-precision policy).
        scale = tf.math.sqrt(tf.cast(self.d_model, attention_scores.dtype))
        attention_scores = attention_scores / scale
        attention_weights = self.softmax(attention_scores)

        output = tf.matmul(attention_weights, values)
        return output, attention_weights

    def get_config(self):
        """Return the layer config, including ``d_model``, for serialization."""
        config = super().get_config()
        config.update({"d_model": self.d_model})
        return config

# Build random test data and run it through the layer once.
batch_size, seq_length, d_model = 2, 5, 8

# Random input of shape (batch_size, seq_length, d_model).
inputs = tf.random.normal(shape=(batch_size, seq_length, d_model))

self_attention = SelfAttention(d_model=d_model)
# The layer returns both the attended output and the attention weights.
output, attention_weights = self_attention(inputs)

print("Self-Attention Output:", output)
print("Attention Weights:", attention_weights)

Self-Attention Output: tf.Tensor(
[[[ 0.9311143  -0.33381045  0.36613807 -0.69959617 -0.63235754
    0.02820292  0.4311988  -0.24781424]
  [-0.28204525  0.34441727  0.18293333 -0.06207423 -0.4488752
   -0.94877803 -0.1723257   0.65572953]
  [-0.4450016   0.30996057  0.08544388  0.02746993 -0.4649567
   -1.0028745  -0.28625324  0.7357754 ]
  [-0.19831698  0.13410234  0.25178906  0.05277795 -0.6322757
   -0.793494   -0.03390846  0.6526605 ]
  [-0.4789609   0.08307049 -0.00669652  0.14956263 -0.5012684
   -0.7799514  -0.23635475  0.68860555]]

 [[ 0.96474886 -0.01032655 -0.13513587 -0.9784777   0.19818676
    0.6600754  -1.1026492  -0.5322181 ]
  [ 1.013519    0.1624079  -0.95928025 -1.4860604   0.5269492
    0.57625806 -1.848207   -0.5674479 ]
  [ 1.3793604  -0.19381861  0.3637418  -1.3125342   0.08677528
    0.50825506 -0.82970864 -0.24490663]
  [ 1.4154563  -0.09034033 -0.01997704 -1.5280592   0.2084867
    0.5239569  -1.2390686  -0.31188726]
  [ 0.28618154 -0.11962916 -0.39009106 -0.5