<a href="https://colab.research.google.com/github/jongheonleee/LLM_study/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch.nn as nn
import torch


# > Tokenization <
# Split the sentence on whitespace; each resulting word is treated as one token.
input_text = '나는 최근 파리 여행을 다녀왔다'
input_text_list = input_text.split()
print('input text list => ', input_text_list)

# Build the vocabulary from the *unique* words in first-seen order.
# Enumerating the raw token list would leave gaps in the ids whenever a word
# repeats, making len(str2idx) smaller than the largest id and overflowing any
# embedding table sized from it.
str2idx = {word: idx for idx, word in enumerate(dict.fromkeys(input_text_list))}
# Reverse mapping, derived from str2idx so the two can never disagree.
idx2str = {idx: word for word, idx in str2idx.items()}
print('str2idx => ', str2idx)
print('idx2str => ', idx2str)

# Encode the sentence as a list of vocabulary ids.
input_ids = [str2idx[word] for word in input_text_list]
print('input_ids => ', input_ids)

# > Converting token ids to token embeddings <
embedding_dim = 16
# Lookup table with one embedding_dim-sized row per vocabulary entry.
embed_layer = nn.Embedding(num_embeddings=len(str2idx), embedding_dim=embedding_dim)

# Look up the row for every token id, then add a leading axis of size 1.
input_embeddings = embed_layer(torch.tensor(input_ids)).unsqueeze(0)
input_embeddings.shape


# > Absolute position encoding <
embedding_dim = 16
max_position = 12
# Two lookup tables: one indexed by token id, one indexed by position.
embed_layer = nn.Embedding(num_embeddings=len(str2idx), embedding_dim=embedding_dim)
position_embed_layer = nn.Embedding(num_embeddings=max_position, embedding_dim=embedding_dim)

# Positions 0 .. len(input_ids) - 1, with a leading axis of size 1.
position_ids = torch.arange(len(input_ids), dtype=torch.long).unsqueeze(0)
position_encodings = position_embed_layer(position_ids)

# Token embeddings get the same leading axis so the two tensors line up.
token_embeddings = embed_layer(torch.tensor(input_ids)).unsqueeze(0)

# The model input is the element-wise sum of token and position embeddings.
input_embeddings = token_embeddings + position_encodings
input_embeddings.shape


input text list =>  ['나는', '최근', '파리', '여행을', '다녀왔다']
str2idx =>  {'나는': 0, '최근': 1, '파리': 2, '여행을': 3, '다녀왔다': 4}
idx2str =>  {0: '나는', 1: '최근', 2: '파리', 3: '여행을', 4: '다녀왔다'}
input_ids =>  [0, 1, 2, 3, 4]


torch.Size([1, 5, 16])

In [7]:
# > nn.Linear layers that produce the query, key, and value vectors <
head_dim = 16

# One independent linear projection each for query, key, and value
# (created in that order).
weight_q, weight_k, weight_v = (
    nn.Linear(embedding_dim, head_dim) for _ in range(3)
)

# Project the same input embeddings through each of the three layers.
querys, keys, values = (
    projection(input_embeddings)
    for projection in (weight_q, weight_k, weight_v)
)

In [8]:
# > 스케일 점곱 방식의 어텐션 <
from math import sqrt
import torch.nn.functional as F

def compute_attention(querys, keys, values, is_causal=False):
  """Apply scaled dot-product attention.

  Computes softmax(Q @ K^T / sqrt(dim_k)) @ V over the last two axes; any
  leading axes (batch, heads, ...) are treated as batch dimensions.

  Args:
    querys: Tensor of shape (..., query_len, dim_k).
    keys: Tensor of shape (..., key_len, dim_k).
    values: Tensor of shape (..., key_len, dim_v).
    is_causal: If True, query position i may only attend to key positions
      j <= i (lower-triangular mask aligned to the top-left corner).

  Returns:
    Tensor of shape (..., query_len, dim_v).
  """
  dim_k = querys.size(-1)
  # Scale by sqrt(dim_k) so the logits do not grow with the head dimension.
  scores = querys @ keys.transpose(-2, -1) / sqrt(dim_k)
  if is_causal:
    query_len, key_len = scores.shape[-2:]
    allowed = torch.ones(
        query_len, key_len, dtype=torch.bool, device=scores.device
    ).tril()
    # -inf logits become exactly zero weight after the softmax.
    scores = scores.masked_fill(~allowed, float('-inf'))
  # Normalise over the key axis (last one) so each query's weights sum to 1.
  weights = F.softmax(scores, dim=-1)
  return weights @ values


In [9]:
# > Input and output of the attention operation <

# Shape of the embeddings before attention is applied.
print('원본 입력 형태: ', input_embeddings.shape)

after_attention_embeddings = compute_attention(
    querys=querys,
    keys=keys,
    values=values,
)

# Shape of the result after attention is applied.
print('어텐션 적용 후 형태: ', after_attention_embeddings.shape)

원본 입력 형태:  torch.Size([1, 5, 16])
어텐션 적용 후 형태:  torch.Size([1, 5, 16])


In [12]:
# > 어텐션 연산을 수행하는 AttentionHead 클래스 <

class AttentionHead(nn.Module):
  """A single attention head with its own query/key/value projections.

  Args:
    token_embed_dim: Size of the last axis of the tensors passed to forward.
    head_dim: Output size of each of the three linear projections.
    is_causal: Forwarded to compute_attention on every call.
  """

  def __init__(self, token_embed_dim, head_dim, is_causal=False):
    super().__init__()
    self.is_causal = is_causal
    self.weight_q = nn.Linear(token_embed_dim, head_dim)
    self.weight_k = nn.Linear(token_embed_dim, head_dim)
    self.weight_v = nn.Linear(token_embed_dim, head_dim)

  def forward(self, querys, keys, values):
    """Project the inputs and return the result of compute_attention.

    Args:
      querys: Tensor fed to the query projection.
      keys: Tensor fed to the key projection.
      values: Tensor fed to the value projection.

    Returns:
      Whatever compute_attention returns for the projected tensors.
    """
    # The result must be returned; without it the module yields None.
    return compute_attention(
        self.weight_q(querys),
        self.weight_k(keys),
        self.weight_v(values),
        is_causal=self.is_causal,
    )

# Self-attention: the same embeddings are passed as query, key, and value.
attention_head = AttentionHead(
    token_embed_dim=embedding_dim,
    head_dim=embedding_dim,
)
after_attention_embeddings = attention_head(
    querys=input_embeddings,
    keys=input_embeddings,
    values=input_embeddings,
)

In [14]:
# > 멀티 헤드 어텐션 구현 <

class MultiheadAttention(nn.Module):
  """Multi-head attention built on compute_attention.

  Inputs are projected to d_model features, split into n_head heads of
  d_model // n_head features each, attended per head, merged back to d_model
  features and passed through a final linear layer.

  Args:
    token_embed_dim: Size of the last axis of the tensors passed to forward.
    d_model: Output size of the q/k/v projections and of the module.
    n_head: Number of heads; must be positive and divide d_model evenly.
    is_causal: Forwarded to compute_attention on every call.

  Raises:
    ValueError: If n_head is not positive or does not divide d_model.
  """

  def __init__(self, token_embed_dim, d_model, n_head, is_causal=False):
    super().__init__()
    # Fail at construction time instead of with an opaque view() error later.
    if n_head <= 0 or d_model % n_head != 0:
      raise ValueError(
          f'd_model ({d_model}) must be divisible by a positive '
          f'n_head ({n_head})'
      )
    self.n_head = n_head
    self.d_model = d_model
    self.head_dim = d_model // n_head
    self.is_causal = is_causal

    self.weight_q = nn.Linear(token_embed_dim, d_model)
    self.weight_k = nn.Linear(token_embed_dim, d_model)
    self.weight_v = nn.Linear(token_embed_dim, d_model)
    self.concat_linear = nn.Linear(d_model, d_model)

  def _split_heads(self, projected):
    """Reshape (B, L, d_model) to (B, n_head, L, head_dim)."""
    batch_size, seq_len, _ = projected.size()
    return projected.view(
        batch_size, seq_len, self.n_head, self.head_dim
    ).transpose(1, 2)

  def forward(self, querys, keys, values):
    """Run multi-head attention.

    Args:
      querys: Tensor of shape (B, T, token_embed_dim).
      keys: Tensor of shape (B, S, token_embed_dim).
      values: Tensor of shape (B, S, token_embed_dim).

    Returns:
      Tensor of shape (B, T, d_model).
    """
    B, T, _ = querys.size()
    # Head sizes come from d_model (the projection width), not from the raw
    # input width, and each tensor keeps its own sequence length.
    querys = self._split_heads(self.weight_q(querys))
    keys = self._split_heads(self.weight_k(keys))
    values = self._split_heads(self.weight_v(values))
    attention = compute_attention(querys, keys, values, self.is_causal)
    # Merge the heads back: (B, n_head, T, head_dim) -> (B, T, d_model).
    output = attention.transpose(1, 2).contiguous().view(B, T, self.d_model)
    output = self.concat_linear(output)
    return output


n_head = 4
# Self-attention across n_head heads: the same embeddings act as q, k, and v.
mh_attention = MultiheadAttention(
    token_embed_dim=embedding_dim,
    d_model=embedding_dim,
    n_head=n_head,
)
after_attention_embeddings = mh_attention(
    querys=input_embeddings,
    keys=input_embeddings,
    values=input_embeddings,
)
after_attention_embeddings.shape


torch.Size([1, 5, 16])