# Attention Mechanism

![image](./resources/Attention1.png)

Attention mechanism을 간단하게 말하자면, input의 중요한 부분에 집중하여 output 예측 정확도를 향상시키는 기술입니다.

아래의 예제 코드는 Encoder-Decoder 구조의 Attention mechanism입니다. 일반적으로 encoder에서 input을 feature vector로 압축하면, 저차원의 vector로 표현되기 때문에 input에 대한 모든 정보를 담기는 어렵습니다. 즉, 정보의 손실이 발생합니다. 따라서 decoder은 더 많은 정보를 알기 위해서, encoder의 모든 hidden state를 참고합니다. 참고를 할 때는 중요한 hidden state에 더 높은 가중치를 부여하고, decoder의 input에 값을 더하거나 concat하여 output 예측에 이용됩니다.

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import tensorflow as tf
import numpy as np

토이 문제로 간단한 machine translation 문제를 풀어보겠습니다.

이 문제에서는 숫자로만 된 변역 문제입니다.

문제 풀이는 attention을 사용한 seq2seq 모델을 사용합니다.

문제를 단순하게 만들기 위해서 입력 문장의 단어는 전부 1-5 까지 숫자이고, 출력 문장의 단어는 그에 해당하는 6-10까지의 숫자입니다. (즉 "입력 단어" + 5 -> "출력 단어")

In [2]:
#Dataset define
SOS_token = 0
EOS_token = 11
n_samples=6000
seq_len = 5

X = np.array([[np.random.randint(1, 6) for _ in range(seq_len)] for _ in range(n_samples)])
Y = X+5

X= np.insert(X, 0, SOS_token, axis=1)
Y = np.insert(Y, seq_len, EOS_token, axis=1)

x_data = X[:int(n_samples * 0.5),:]
y_data = Y[:int(n_samples * 0.5),:]

x_eval = X[int(n_samples * 0.5): ,:]
y_eval = Y[int(n_samples * 0.5): ,:]

예제 문장 다섯개를 보면

In [9]:
for _ in range(5):
    idx = np.random.randint(len(X))
    print(f"입력문장: {X[idx][1:]} -> 출력문장: {Y[idx][:-1]}")

입력문장: [4 2 5 5 5] -> 출력문장: [ 9  7 10 10 10]
입력문장: [4 2 4 4 2] -> 출력문장: [9 7 9 9 7]
입력문장: [2 4 4 3 4] -> 출력문장: [7 9 9 8 9]
입력문장: [3 5 3 1 4] -> 출력문장: [ 8 10  8  6  9]
입력문장: [2 4 2 3 4] -> 출력문장: [7 9 7 8 9]


In [15]:
#----vocab----#
vocab_size = 12
embedding_dim =32

#----training----#
batch_size = 3000
epochs = 1000

#----encoder,decoder----#
hidden_dim = 128

In [16]:
def batch_generator(x_data, y_data, batch_size):
    n_samples = len(x_data)
    while True:
        batches = range(0, n_samples, batch_size)
        for start in batches:
            end = start + batch_size
            X_batch = x_data[start:end]
            Y_batch = y_data[start:end]
            
            all_data = {
                'Encoder_input' : X_batch,
                'Decoder_input' : Y_batch
            }
            yield (all_data)

In [17]:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        # make embedding matrix
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        '''
        LSTM args.  
            
            return_sequences
                True : [batch_size, input_seq_len, hidden_dim]  => return all sequence
                False: [batch_size, hidden_dim] => return last output
        '''
        self.lstm = tf.keras.layers.LSTM(hidden_dim,
                                        return_sequences=True,
                                        return_state=True) 
   
    def call(self, inputs):
        # embeded_input : [batch_size, input_seq_len, embedding_dim]
        embeded_inputs = self.embedding(inputs)
        output, memory_state, carry_state = self.lstm(embeded_inputs, initial_state = self.init_hidden_state())
        return output, memory_state, carry_state
    
    def init_hidden_state(self):
        return (tf.zeros((self.batch_size, self.hidden_dim)),tf.zeros((self.batch_size, self.hidden_dim))) #init state : tuple
    

![image](./resources/Attention2.png)

In [18]:
class Attention(tf.keras.layers.Layer):
    '''
        Attention Mechanism
        1. encoder hidden state 생성
        2. previous decoder hidden state와 모든 encoder hidden state와의 Alignment score 계산 => attention score
        3. attention score softmax함 => attention weight
        4. attention weight X 모든 encoder hidden state => context vector
        5. previous decoder output과 context vector concat하여 Decoder의 input으로 사용함
        6. 1~5반복
    '''
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(hidden_dim)
        self.W2 = tf.keras.layers.Dense(hidden_dim)
        self.V = tf.keras.layers.Dense(1)
    def call(self, encoder_output, decoder_state, BahdanauAttn=True):
        '''
            encoder_output : [batch_size, time_step, hidden_dim]
            decoder_state : [batch_size, hidden_dim]
        '''
        # encoder_output과 shape 맞춤
        decoder_state = tf.expand_dims(decoder_state, 1)
        
        # Equation (4) [batch_size, time_step, 1]
        # self.W2(encoder_output)의 각 time step에 self.W1(decoder_state)더함
        attn_score = self.V(tf.nn.tanh(self.W1(decoder_state) + self.W2(encoder_output)))
    
        # Equation (1) [batch_size, input_seq_len, 1]
        attn_weights = tf.nn.softmax(attn_score, axis =1)
        
        # Equation (2)
        # [batch_size, input_seq_len, 1], 각 encoder_hidden에 attn_weigts 곱함
        context_vector = attn_weights * encoder_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attn_weights

구현에서 Attention 클래스는 Decoder에서 생성하고, Decoder 의 call 함수에서 사용한다는 점을 들 수 있습니다. 즉, 가장 위 개념도에서와 달리 Decoder의 부품으로 생각하고 구현합니다.

In [19]:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size):
        super(Decoder, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        self.lstm = tf.keras.layers.LSTM(self.hidden_dim,
                                        return_sequences=True,
                                        return_state=True)
        self.attn = Attention(hidden_dim)
        #vocab_logit
        self.projection_layer = tf.keras.layers.Dense(vocab_size)
        
    def call(self, inputs, prev_decoder_state, encoder_output):
        context_vector,_ = self.attn(encoder_output, prev_decoder_state)

        #embeded_inputs: [batch_size, 1 , embedding_dim]
        embeded_inputs = self.embedding(inputs)
        
        # Equation (3)
        #new_input: [batch_size, 1 , embedding_dim + hidden_dim]
        new_input = tf.concat([tf.expand_dims(context_vector, 1), embeded_inputs], axis=-1)
        decoder_output, state, carry_state = self.lstm(new_input)

        #[batch, hidden_dim]
        decoder_output = tf.reshape(decoder_output, (-1, decoder_output.shape[2]))
        v =  self.projection_layer(decoder_output)
        return v, state

In [20]:
def train_step(encoder, decoder, optimizer, loss_object, encoder_input, target):
    loss = 0
    with tf.GradientTape() as tape:
        encoder_output, memory_state, carry_state= encoder(encoder_input)
        # decoder init state : feature vector
        hidden = memory_state 
        
        # [[0]* batch_size]
        decoder_input = tf.expand_dims([SOS_token] * batch_size, 1)
        
        # Teacher forcing 
        for t in range(0,target.shape[1]):
            pred, hidden = decoder(decoder_input, hidden, encoder_output)
            # pred와 target[:,t]가 매칭되도록 학습 
            loss += loss_function(loss_object, target[:,t], pred)
            decoder_input = tf.expand_dims(target[:,t], 1)
            
        batch_loss = (loss/int(decoder_input.shape[1]))
        variables = encoder.trainable_variables + decoder.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
    
        return batch_loss


In [21]:
def train_run(epoch, encoder, decoder, x_data, y_data):  
    generator= batch_generator(x_data,y_data,batch_size)
    
    optimizer = tf.keras.optimizers.Adam()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
    
    for e in range(epoch):
        total_loss = 0
        for step in range(len(x_data)//batch_size):
            data = next(generator)
            batch_loss = train_step(encoder, decoder, optimizer, loss_object ,data['Encoder_input'] ,data['Decoder_input'])
        if e % 50 ==0:
            print(f'Epochs :{e}/{epoch}, \t Batch_loss : {batch_loss:.5f}')

In [22]:
def loss_function(loss_object, real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    # padding된 부분 masking하여 loss에 영향을 주지 않도록함
    loss_ *= mask
    return tf.reduce_mean(loss_)

In [23]:
def evaluate(encoder, decoder, x_eval, y_eval, n_samples= 10):
    generator= batch_generator(x_eval,y_eval,1)
    encoder.batch_size = 1
    for i in range(n_samples):
        data = next(generator)
        inputs = data['Encoder_input']
        target = data['Decoder_input'][:,:-1].squeeze()
        encoder_output, memory_state, carry_state= encoder(inputs)
        hidden = memory_state
        
        decoder_input = tf.expand_dims([0], 1)
        result=''
        for t in range(inputs.shape[1]):
            pred, hidden = decoder(decoder_input, hidden, encoder_output)
            pred_id = tf.argmax(pred,axis=1).numpy()
            if pred_id ==EOS_token:
                break
            else :
                result += str(pred_id[0]) +' '
            decoder_input = tf.expand_dims(pred_id,1)
        print(f'real: {target} \t pred : {result}')
        result=''

In [24]:
def run(): 
    encoder = Encoder(vocab_size, embedding_dim, hidden_dim, batch_size)
    decoder = Decoder(vocab_size, embedding_dim, hidden_dim, batch_size)
    train_run(epochs, encoder, decoder, x_data, y_data)
    evaluate(encoder, decoder, x_eval ,y_eval)

In [25]:
run()

Epochs :0/1000, 	 Batch_loss : 14.90709
Epochs :50/1000, 	 Batch_loss : 10.74005
Epochs :100/1000, 	 Batch_loss : 10.53249
Epochs :150/1000, 	 Batch_loss : 9.37914
Epochs :200/1000, 	 Batch_loss : 8.67753
Epochs :250/1000, 	 Batch_loss : 8.23183
Epochs :300/1000, 	 Batch_loss : 8.04186
Epochs :350/1000, 	 Batch_loss : 12.75408
Epochs :400/1000, 	 Batch_loss : 10.17719
Epochs :450/1000, 	 Batch_loss : 8.97124
Epochs :500/1000, 	 Batch_loss : 6.99341
Epochs :550/1000, 	 Batch_loss : 3.70626
Epochs :600/1000, 	 Batch_loss : 1.86160
Epochs :650/1000, 	 Batch_loss : 0.97385
Epochs :700/1000, 	 Batch_loss : 0.51756
Epochs :750/1000, 	 Batch_loss : 0.29511
Epochs :800/1000, 	 Batch_loss : 0.18802
Epochs :850/1000, 	 Batch_loss : 0.13040
Epochs :900/1000, 	 Batch_loss : 0.09644
Epochs :950/1000, 	 Batch_loss : 0.07472
real: [ 9  9 10  6 10] 	 pred : 9 9 10 6 10 
real: [ 8 10  8  9  9] 	 pred : 8 10 8 9 9 
real: [ 7  6  9  7 10] 	 pred : 7 6 9 7 10 
real: [ 8 10  7  8  8] 	 pred : 8 10 7 8 8 
r