# Attention

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


### 어텐션 가중치 계산

In [3]:
def attention(query, key, value, *, mask=None, scale=False):
    """Compute dot-product attention and return the context vectors.

    ``torch.matmul`` batches over any leading dimensions, so inputs are
    expected as (..., seq_len, dim).

    Args:
        query: Tensor of shape (..., q_len, d_k).
        key: Tensor of shape (..., k_len, d_k).
        value: Tensor of shape (..., k_len, d_v).
        mask: Optional tensor broadcastable to the score shape
            (..., q_len, k_len). Key positions where ``mask == 0`` (e.g.
            ``<pad>`` tokens) get exactly zero attention weight. A row that
            is masked entirely yields NaN weights. Defaults to no masking.
        scale: If True, divide the scores by sqrt(d_k) (scaled dot-product
            attention). Defaults to False, i.e. plain unscaled scores.

    Returns:
        Context vectors of shape (..., q_len, d_v).

    Side effect: prints the shapes of the scores, the weights and the
    context vectors.
    """
    # 1. Attention scores: similarity of every query with every key.
    scores = torch.matmul(query, key.transpose(-2,-1))
    if scale:
        # Keeps the softmax out of its saturated region for large d_k.
        scores = scores / (query.size(-1) ** 0.5)
    if mask is not None:
        # -inf scores become exactly 0 after the softmax below.
        scores = scores.masked_fill(mask == 0, float('-inf'))
    print('Attention Score shape : ', scores.shape)

    # 2. Softmax over the key axis turns scores into weights that sum to 1.
    attention_weights = F.softmax(scores, dim=-1)
    print('attention_weights shape : ', attention_weights.shape)

    # 3. Weighted sum of the values -> final context vectors.
    context_vector = torch.matmul(attention_weights, value)
    print('context_vector shape : ', context_vector.shape)

    return context_vector

In [4]:
# Toy vocabulary standing in for a tokenizer's output: word -> token id.
vocab = {
    word: index
    for index, word in enumerate(['나는', '학교에', '간다', '<pad>'])
}
vocab_size = len(vocab)   # number of entries, including the padding token
EMBEDDING_DIM = 4         # length of each token's embedding vector

In [5]:
# Input sentence, already split into word tokens.
inputs = ['나는', '학교에', '간다']
# Map each token to its id, then add a leading batch dimension -> shape (1, 3).
inputs_ids = torch.tensor([vocab[token] for token in inputs]).unsqueeze(0)

In [6]:
# Output size of the query/key/value projections.
HIDDEN_DIM = 4

# 1. Embedding: token ids -> one EMBEDDING_DIM-long vector per token.
embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
inputs_embeded = embedding_layer(inputs_ids)

# 2. Linear projections -> Query, Key, Value.
#    The layers are created in query, key, value order, so their random
#    initialisation is the same as creating them one statement at a time.
W_query, W_key, W_value = (nn.Linear(EMBEDDING_DIM, HIDDEN_DIM) for _ in range(3))
input_query, input_key, input_value = (
    projection(inputs_embeded) for projection in (W_query, W_key, W_value)
)

print(input_query.shape, input_key.shape, input_value.shape)


torch.Size([1, 3, 4]) torch.Size([1, 3, 4]) torch.Size([1, 3, 4])


In [7]:
# Apply dot-product attention to the projected Q/K/V and display the result.
context_vector = attention(query=input_query, key=input_key, value=input_value)
context_vector

Attention Score shape :  torch.Size([1, 3, 3])
attention_weights shape :  torch.Size([1, 3, 3])
context_vector shape :  torch.Size([1, 3, 4])


tensor([[[-0.4006,  0.3373,  0.1517,  0.5974],
         [-0.3587,  0.3901,  0.2104,  0.5476],
         [-0.4059,  0.3132,  0.1366,  0.6119]]], grad_fn=<UnsafeViewBackward0>)

### Seq2Seq 모델에 어텐션 추가

In [8]:
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder outputs.

    For each encoder step i the score is ``v . tanh(W [hidden; enc_i])``.
    The scores are softmax-normalised over the encoder steps and used to
    average the encoder outputs into a single context vector.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        # The layer input is the concatenation [decoder state; encoder output],
        # hence hidden_size * 2 input features.
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        # Learned vector that reduces each hidden_size energy vector to one
        # scalar score (dot product with v).
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs, mask=None):
        """Attend from ``hidden`` over ``encoder_outputs``.

        Args:
            hidden: Query state of shape (batch, hidden_size).
            encoder_outputs: Tensor of shape (batch, seq_len, hidden_size).
            mask: Optional tensor broadcastable to (batch, seq_len). Positions
                where ``mask == 0`` (e.g. padding) receive zero weight.
                Defaults to no masking.

        Returns:
            Tuple ``(context_vector, attention_weights)`` with shapes
            (batch, hidden_size) and (batch, seq_len).
        """
        seq_len = encoder_outputs.shape[1]  # number of encoder time steps
        # Pair the same query with every encoder step. expand() is a view, so
        # unlike repeat() it does not copy the tensor seq_len times.
        hidden_expanded = hidden.unsqueeze(1).expand(-1, seq_len, -1)
        energy = torch.tanh(self.attn(torch.cat((hidden_expanded, encoder_outputs), dim=2)))
        # Dot each energy vector with v -> one scalar score per encoder step.
        attention_scores = torch.sum(self.v * energy, dim=2)
        if mask is not None:
            # -inf scores become exactly 0 after the softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(attention_scores, dim=1)
        # Weighted sum of the encoder outputs: (batch, 1, seq) @ (batch, seq, H).
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context_vector, attention_weights

In [9]:
class Seq2SeqWithAttention(nn.Module):
    """GRU encoder-decoder whose decoder outputs are combined with an
    attention context over the encoder outputs before the final projection.

    For decoder step t the attention query is the decoder state *before*
    that step. For t = 0 this is the encoder's final hidden state, which
    also initialises the decoder; for later steps it is the decoder output
    at t - 1. With a single-step decoder input this is exactly the
    encoder's final hidden state.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        # The output layer sees [decoder output; context vector] -> hidden_dim * 2.
        self.fc = nn.Linear(hidden_dim *2, output_dim)
        # Projects raw decoder inputs (input_dim features) to the decoder GRU's input size.
        self.decoder_input_transform = nn.Linear(input_dim, hidden_dim)

    def forward(self, encoder_input, decoder_input):
        """Run the encoder, then the decoder, and attend once per decoder step.

        Args:
            encoder_input: (batch, src_len, input_dim).
            decoder_input: (batch, tgt_len, input_dim); tgt_len >= 1.

        Returns:
            Tensor of shape (batch, tgt_len, output_dim).
        """
        encoder_output, hidden = self.encoder(encoder_input)
        decoder_input = self.decoder_input_transform(decoder_input)
        output, _ = self.decoder(decoder_input, hidden)
        # Query for step t = the decoder state before step t. The decoder is a
        # single-layer GRU, so output[:, t] is its hidden state after step t.
        queries = torch.cat((hidden[-1].unsqueeze(1), output[:, :-1]), dim=1)
        # One context vector per decoder step -> (batch, tgt_len, hidden_dim).
        # Previously a single (batch, 1, hidden_dim) context was concatenated,
        # which failed for tgt_len > 1.
        context_vectors = torch.stack(
            [self.attention(queries[:, t], encoder_output)[0] for t in range(queries.size(1))],
            dim=1,
        )
        combined = torch.cat((output, context_vectors), dim=2)
        return self.fc(combined)

In [10]:
# Toy dimensions for a forward-pass smoke test of Seq2SeqWithAttention.
batch_size = 1
seq_len = 5
input_dim = 10
hidden_dim = 20
output_dim = 15

# Random source sequence and a single decoder step.
encoder_input = torch.randn(batch_size, seq_len, input_dim)
decoder_input = torch.randn(batch_size, 1, input_dim)

# Build the model from the variables above instead of repeating the literal
# sizes, so the model and the dummy inputs cannot drift apart.
model = Seq2SeqWithAttention(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
result = model(encoder_input, decoder_input)
print(result)


tensor([[[-0.1095,  0.0429,  0.1522,  0.1474,  0.0809, -0.1300, -0.1748,
          -0.0850,  0.2255, -0.0280,  0.4096, -0.1512, -0.0450, -0.1638,
          -0.1225]]], grad_fn=<ViewBackward0>)
