# Transformer - Attention Is All You Need

구조 : Encoder (self-attention + feed forward network) -> Decoder(self-attention + encoder-decoder attention + feed forward network) -> Fully Connected Layer

- Encoder : 입력 문맥 파악 및 정보를 압축

    - self-attention : input sequence 의 각 단어들을 가지고 Q (query), K (key), V (value) 계산 (각각의 Vector) = 단어들의 관련성 파악 
        
        -> 표현 계산 (q,k의 내적 계산 = 유사도 -> softmax -> 결과를 가중치로 변환 -> 가중치 * v) = 문맥에 맞는 단어 학습 (새로운 표현)
        
        -> 이 때 multi-head attention (여러 개의 attention head가 각각 q,k,v를 계산하고 결과를 결합) = attention head가 여러개 이기 때문에 다양한 표현 학습

    - feed forward network (= 2개의 fully connected layer) : self-attention의 결과 표현을 비 선형적 변환 (dim 확장 -> dim 축소) = 복잡한 관계 학습


- Decoder : 인코더 정보와 이전 출력 기반으로 출력 생성

    - self-attention (masked self-attention) : output sequence 에서 현재 timestep 까지의 단어만 고려 -> 현재 단어 이후의 정보를 참조하지 못하도록 = 문맥 파악

        -> 현재 입력을 가지고 q, k, v 계산

    - encoder-decoder attention (cross-attention) : encoder의 출력(문맥)을 가지고 output sequence 생성 -> Encoder의 압축된 문맥을 참고하여 단어 생성

        -> q (masked self-attention의 출력), k (encoder의 출력), v (encoder의 출력) 을 가지고 계산
    
    - feed forward network : encoder와 동일


- Fully Connected Layer : 차원을 변환하여 단어 예측



- 각 부분 연결 시 Residual Connection (잔차 연결), Layer Normalization (레이어 정규화), Positional Encoding (위치 정보) 필요

    * RC & LN : Residual Connection & Layer Normalization

    Encoder (embedding + positional encoding -> self-attention -> rc & ln -> feed forward network -> rc & ln)

    ->
    
    Decoder(embedding + positional encoding -> masked self-attention -> rc & ln -> Encoder-Decoder Attention -> rc & ln -> feed forward network -> rc & ln)



In [1]:
import torch
from torch.nn import Module, Embedding, Dropout, Linear, CrossEntropyLoss, init, utils, LayerNorm, ModuleList
from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

from time import time
import math

In [2]:
# Pick the first available accelerator backend, falling back to the CPU.
# The XPU and MPS backends are looked up defensively: a PyTorch build that
# does not ship them has no such attribute, and a plain `torch.xpu...`
# access would raise AttributeError before the CPU fallback is reached.
_xpu_backend = getattr(torch, "xpu", None)
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif _xpu_backend is not None and _xpu_backend.is_available():
    device = torch.device("xpu")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load the first 10% of the WMT16 German-English training split.
raw_datasets = load_dataset("wmt16", "de-en", split="train[:10%]")

# Tokenizer of the Helsinki-NLP en->de checkpoint; the cells below use it for
# both the English inputs and the German targets.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

In [4]:
def tokenize_function(examples):
    """Tokenize a batch of en/de translation pairs.

    Args:
        examples: batch mapping with a 'translation' entry, a list of dicts
            that each hold an 'en' and a 'de' string.

    Returns:
        dict with 'input_ids' and 'attention_mask' for the English source and
        'labels' holding the German target ids. Every sequence is truncated
        or padded to 64 tokens.
    """
    pairs = examples['translation']

    # Passing the targets through `text_target` lets the tokenizer encode them
    # in its target mode instead of treating them as source-language text.
    # NOTE(review): for this checkpoint the target side presumably uses its
    # own subword model -- confirm against the tokenizer documentation.
    encoded = tokenizer(
        [pair['en'] for pair in pairs],
        text_target=[pair['de'] for pair in pairs],
        max_length=64,
        truncation=True,
        padding="max_length",
    )

    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': encoded['labels'],
    }

In [5]:
# Tokenize the whole dataset batch-wise and expose the three model columns as
# torch tensors.
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Hold out 10% of the examples for validation.
# NOTE(review): no seed is passed, so the split presumably changes between
# runs -- confirm whether reproducibility matters here.
train_test_split = tokenized_datasets.train_test_split(test_size=0.1)
train_data = train_test_split['train']
valid_data = train_test_split['test']

In [6]:
batch_size = 128
# Only the training loader shuffles its examples.
train_iterator = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_iterator = DataLoader(valid_data, batch_size=batch_size)

# One tokenizer serves both languages, so both vocabulary sizes come from it.
input_dim = tokenizer.vocab_size
output_dim = tokenizer.vocab_size
# Padding id: read as a module-level global by Encoder/Decoder and passed
# explicitly to the loss function and to Transformer.
pad_idx = tokenizer.pad_token_id

In [7]:
# Injects token position information into the model.
class PositionalEncoding(Module):
    """Fixed (non-trainable) sinusoidal positional-encoding table.

    Even columns hold sines and odd columns hold cosines of
    position * 10000^(-2i / model_dim).

    Args:
        model_dim: embedding dimension of the model; may be even or odd.
        max_len: longest sequence length the table covers.
    """

    def __init__(self, model_dim, max_len=5000):
        super().__init__()

        # Encoding table, one row per position.
        position_encoding_matrics = torch.zeros(max_len, model_dim)

        # Column vector of positions 0 .. max_len - 1.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Inverse frequencies shared by each sine/cosine column pair.
        div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim))
        angles = position * div_term

        position_encoding_matrics[:, 0::2] = torch.sin(angles)
        # With an odd model_dim there is one odd column fewer than even
        # columns, so the surplus frequency is dropped for the cosine half.
        position_encoding_matrics[:, 1::2] = torch.cos(angles[:, : model_dim // 2])

        # Add a leading batch dimension -> (1, max_len, model_dim).
        position_encoding_matrics = position_encoding_matrics.unsqueeze(0)
        # Stored as a buffer: part of the module state, but not a parameter,
        # so it is not trained.
        self.register_buffer('position_encoding_matrics', position_encoding_matrics)

    def forward(self, x):
        """Return the table rows for the first `x.size(1)` positions.

        Only the length of `x` along dimension 1 is used; the caller adds the
        returned (1, length, model_dim) slice to its embeddings.
        """
        return self.position_encoding_matrics[:, :x.size(1)]

In [8]:
# Multi-head attention (sequence -> tokens -> attention).
class MultiHeadAttention(Module):
    """Scaled dot-product attention computed over several heads in parallel.

    Args:
        model_dim: input/output embedding dimension.
        n_heads: number of attention heads; must divide `model_dim`.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if `model_dim` is not divisible by `n_heads`.
    """

    def __init__(self, model_dim, n_heads, dropout):
        super().__init__()

        if model_dim % n_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.model_dim = model_dim
        self.n_heads = n_heads
        # Width of a single head.
        self.head_dim = model_dim // n_heads

        # Query / key / value projections and the final output projection.
        self.fc_q = Linear(model_dim, model_dim)
        self.fc_k = Linear(model_dim, model_dim)
        self.fc_v = Linear(model_dim, model_dim)
        self.fc_o = Linear(model_dim, model_dim)

        self.dropout = Dropout(dropout)

        # Scaling factor for the attention scores. A plain Python float works
        # with tensors on any device and of any dtype, so the module needs no
        # module-level `device` and stays valid after `.to(...)`.
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask = None):
        """Attend from `query` positions to `key`/`value` positions.

        Args:
            query, key, value: tensors whose first dimension is the batch and
                whose last dimension is `model_dim`.
            mask: optional tensor broadcastable to the score tensor
                (batch, n_heads, query_len, key_len); positions where it
                equals 0 are excluded from attention.

        Returns:
            Tuple of the attended output (batch, query_len, model_dim) and
            the attention weights (batch, n_heads, query_len, key_len).
        """
        batch_size = query.shape[0]

        q = self.fc_q(query)
        k = self.fc_k(key)
        v = self.fc_v(value)

        # Split the feature dimension into heads so that all heads are
        # computed in parallel: (batch, n_heads, length, head_dim).
        q = q.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled query-key similarity.
        energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            # Fill masked scores with a very large negative number so softmax
            # turns them into ~0 weights. Filling with 0 would not work,
            # because e^0 = 1 still receives weight.
            energy = energy.masked_fill(mask == 0, -1e9)

        attention = torch.softmax(energy, dim = -1)

        # Weighted sum of the values.
        x = torch.matmul(self.dropout(attention), v)

        # Back to (batch, length, n_heads, head_dim); contiguous() re-lays the
        # memory out after the permute.
        x = x.permute(0, 2, 1, 3).contiguous()
        # Concatenate the heads.
        x = x.reshape(batch_size, -1, self.model_dim)

        x = self.fc_o(x)

        return x, attention

In [9]:
# Position-wise: the same weights are applied at every position.
# Attention mixes information between positions, whereas the feed-forward
# block transforms each token vector independently to sharpen its features.
class PositionwiseFeedforward(Module):
    """Two-layer feed-forward block: expand, ReLU, dropout, project back.

    Args:
        model_dim: input and output feature size.
        feedforward_dim: hidden feature size between the two linear layers.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, model_dim, feedforward_dim, dropout):
        super().__init__()

        # Expansion followed by contraction.
        self.fc_1 = Linear(model_dim, feedforward_dim)
        self.fc_2 = Linear(feedforward_dim, model_dim)
        self.dropout = Dropout(dropout)

    def forward(self, x):
        """Return fc_2(dropout(relu(fc_1(x))))."""
        hidden = torch.relu(self.fc_1(x))
        hidden = self.dropout(hidden)
        return self.fc_2(hidden)

In [10]:
# Residual connection followed by layer normalisation: the ordering used in
# the original paper (post-norm).
class SublayerConnection(Module):
    """Post-norm wrapper: norm(x + dropout(sublayer(x))).

    Args:
        model_dim: feature size normalised by the LayerNorm.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, model_dim, dropout):
        super().__init__()
        self.norm = LayerNorm(model_dim)
        self.dropout = Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply `sublayer` (a callable taking x) on the residual branch."""
        branch = self.dropout(sublayer(x))
        return self.norm(x + branch)

In [11]:
# Layer normalisation first, residual connection afterwards (pre-norm);
# this ordering tends to train more stably.
class SublayerConnection_PreNorm(Module):
    """Pre-norm wrapper: x + dropout(sublayer(norm(x))).

    Args:
        model_dim: feature size normalised by the LayerNorm.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, model_dim, dropout):
        super().__init__()
        self.norm = LayerNorm(model_dim)
        self.dropout = Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply `sublayer` to the normalised input and add the residual."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [12]:
# Encapsulates the block that is repeated inside the Encoder.
class EncoderLayer(Module):
    """One encoder block: pre-norm self-attention, then pre-norm feed-forward.

    Args:
        model_dim: embedding dimension.
        n_heads: number of attention heads.
        feedforward_dim: hidden size of the feed-forward block.
        dropout: dropout probability shared by all sub-modules.
    """

    def __init__(self, model_dim, n_heads, feedforward_dim, dropout):
        super().__init__()

        # Self-attention and its norm/residual wrapper.
        self.self_attention = MultiHeadAttention(model_dim, n_heads, dropout)
        self.self_attention_sublayer = SublayerConnection_PreNorm(model_dim, dropout)
        # Feed-forward and its own wrapper. The two wrappers are separate
        # instances so that each one has independent LayerNorm weights.
        self.feed_forward = PositionwiseFeedforward(model_dim, feedforward_dim, dropout)
        self.feedforward_sublayer = SublayerConnection_PreNorm(model_dim, dropout)

    def forward(self, src, src_mask):
        """Run one encoder block.

        Args:
            src: source sequence representation.
            src_mask: mask hiding the padding tokens of the source.
        """
        def attend(normed):
            # The attention module also returns its weights; only the
            # attended output is needed here.
            context, _ = self.self_attention(normed, normed, normed, src_mask)
            return context

        # norm -> self-attention -> residual
        attended = self.self_attention_sublayer(src, attend)
        # norm -> feed-forward -> residual
        return self.feedforward_sublayer(attended, self.feed_forward)

In [13]:
class Encoder(Module):
    """Stack of encoder layers on top of token + positional embeddings.

    Args:
        input_dim: source vocabulary size.
        model_dim: embedding dimension.
        n_layers: number of EncoderLayer blocks.
        n_heads: attention heads per layer.
        feedforward_dim: hidden size of each feed-forward block.
        dropout: dropout probability.
        max_len: longest sequence the positional table covers.

    The padding index of the embedding is read from the module-level
    `pad_idx`.
    """

    def __init__(self, input_dim, model_dim, n_layers, n_heads, feedforward_dim, dropout, max_len=64):
        super().__init__()

        # Token ids -> vectors, plus the position table that is added to them.
        self.tok_embedding = Embedding(input_dim, model_dim, padding_idx=pad_idx)
        self.position_embedding = PositionalEncoding(model_dim, max_len)
        self.dropout = Dropout(dropout)

        # n_layers copies of the repeated encoder block.
        self.layers = ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderLayer(model_dim, n_heads, feedforward_dim, dropout))

        # Final normalisation; the result serves as keys/values for the decoder.
        self.norm = LayerNorm(model_dim)

    def forward(self, src, src_mask):
        """Encode source token ids `src` under the padding mask `src_mask`."""
        token_vectors = self.tok_embedding(src)
        hidden = self.dropout(token_vectors + self.position_embedding(token_vectors))

        for layer in self.layers:
            hidden = layer(hidden, src_mask)

        return self.norm(hidden)

In [14]:
class DecoderLayer(Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped in its own pre-norm residual
    connection.

    Args:
        model_dim: embedding dimension.
        n_heads: number of attention heads.
        feedforward_dim: hidden size of the feed-forward block.
        dropout: dropout probability shared by all sub-modules.
    """

    def __init__(self, model_dim, n_heads, feedforward_dim, dropout):
        super().__init__()

        # Masked self-attention over the target sequence.
        self.self_attention = MultiHeadAttention(model_dim, n_heads, dropout)
        self.self_attention_sublayer = SublayerConnection_PreNorm(model_dim, dropout)

        # Encoder-decoder (cross) attention.
        self.cross_attention = MultiHeadAttention(model_dim, n_heads, dropout)
        self.cross_attention_sublayer = SublayerConnection_PreNorm(model_dim, dropout)

        # Same feed-forward block as in the encoder.
        self.feed_forward = PositionwiseFeedforward(model_dim, feedforward_dim, dropout)
        self.feedforward_sublayer = SublayerConnection_PreNorm(model_dim, dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Run one decoder block.

        Args:
            trg: target sequence representation.
            enc_src: encoder output, used as keys and values in cross-attention.
            trg_mask: mask applied in the target self-attention.
            src_mask: source padding mask applied in cross-attention.
        """
        def self_attend(normed):
            context, _ = self.self_attention(normed, normed, normed, trg_mask)
            return context

        def cross_attend(normed):
            # Queries come from the decoder; keys and values from the encoder.
            context, _ = self.cross_attention(normed, enc_src, enc_src, src_mask)
            return context

        hidden = self.self_attention_sublayer(trg, self_attend)
        hidden = self.cross_attention_sublayer(hidden, cross_attend)
        return self.feedforward_sublayer(hidden, self.feed_forward)

In [15]:
class Decoder(Module):
    """Stack of decoder layers on top of token + positional embeddings.

    Returns hidden states of size `model_dim`. The projection to the
    vocabulary is not part of this module (a decoder-only model would keep
    that layer here); the Transformer wrapper applies it.

    Args:
        output_dim: target vocabulary size.
        model_dim: embedding dimension.
        n_layers: number of DecoderLayer blocks.
        n_heads: attention heads per layer.
        feedforward_dim: hidden size of each feed-forward block.
        dropout: dropout probability.
        max_len: longest sequence the positional table covers.

    The padding index of the embedding is read from the module-level
    `pad_idx`.
    """

    def __init__(self, output_dim, model_dim, n_layers, n_heads, feedforward_dim, dropout, max_len=64):
        super().__init__()

        self.tok_embedding = Embedding(output_dim, model_dim, padding_idx=pad_idx)
        self.pos_embedding = PositionalEncoding(model_dim, max_len)
        self.dropout = Dropout(dropout)

        self.layers = ModuleList()
        for _ in range(n_layers):
            self.layers.append(DecoderLayer(model_dim, n_heads, feedforward_dim, dropout))

        self.norm = LayerNorm(model_dim)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode target token ids `trg` against the encoder output `enc_src`."""
        token_vectors = self.tok_embedding(trg)
        hidden = self.dropout(token_vectors + self.pos_embedding(token_vectors))

        for layer in self.layers:
            hidden = layer(hidden, enc_src, trg_mask, src_mask)

        return self.norm(hidden)

In [16]:
def create_masks(src, trg, pad_idx, trg_pad_idx=None):
    """Build the attention masks for a source/target batch.

    Args:
        src: source token ids; assumed to be 2-D (batch, src_len).
        trg: target token ids; assumed to be 2-D (batch, trg_len).
        pad_idx: padding id of the source (and of the target when
            `trg_pad_idx` is not given).
        trg_pad_idx: optional separate padding id for the target.

    Returns:
        Tuple `(src_mask, trg_mask)` of boolean tensors:
        `src_mask` has shape (batch, 1, 1, src_len) and is False at source
        padding; `trg_mask` has shape (batch, 1, trg_len, trg_len) and
        combines the lower-triangular causal mask with the target padding
        mask, which is applied along the query (row) axis.
    """
    if trg_pad_idx is None:
        trg_pad_idx = pad_idx

    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)

    trg_pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(3)

    trg_len = trg.shape[1]
    # Build the causal mask on the same device as the input rather than on a
    # module-level device, so the function works wherever the batch lives.
    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()

    trg_mask = trg_pad_mask & trg_sub_mask

    return src_mask, trg_mask

In [17]:
class Transformer(Module):
    """Encoder-decoder model with a final projection onto the vocabulary.

    Args:
        model_dim: hidden size produced by the decoder.
        output_dim: target vocabulary size.
        encoder: module called as `encoder(src, src_mask)`.
        decoder: module called as `decoder(trg, enc_src, trg_mask, src_mask)`.
        src_pad_idx: padding id of the source sequences.
        trg_pad_idx: padding id of the target sequences.
    """

    def __init__(self, model_dim, output_dim, encoder, decoder, src_pad_idx, trg_pad_idx):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.fully_connected_layer = Linear(model_dim, output_dim)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def forward(self, src, trg):
        """Return vocabulary logits for every target position.

        Args:
            src: source token ids.
            trg: target token ids fed to the decoder.
        """
        src_mask, trg_mask = create_masks(src, trg, self.src_pad_idx)
        if self.trg_pad_idx != self.src_pad_idx:
            # The target mask has to be built from the target padding id;
            # previously `trg_pad_idx` was stored but never consulted.
            _, trg_mask = create_masks(src, trg, self.trg_pad_idx)

        enc_src = self.encoder(src, src_mask)
        dec_output = self.decoder(trg, enc_src, trg_mask, src_mask)
        output = self.fully_connected_layer(dec_output)

        return output

In [18]:
# Model hyper-parameters shared by the encoder and the decoder.
model_dim = 512
n_heads = 8
n_layers = 3
# The feed-forward hidden size is four times the model dimension.
feedforward_dim = model_dim * 4
dropout = 0.1
# Matches the max_length=64 used when tokenizing.
max_len = 64

In [19]:
# Build the encoder and decoder stacks from the hyper-parameters above.
encoder = Encoder(input_dim, model_dim, n_layers, n_heads, feedforward_dim, dropout, max_len)
decoder = Decoder(output_dim, model_dim, n_layers, n_heads, feedforward_dim, dropout, max_len)

In [20]:
# The same padding id is used for the source and the target side.
model = Transformer(model_dim, output_dim, encoder, decoder, pad_idx, pad_idx).to(device)

In [21]:
# Weight initialisation with Xavier uniform.
def init_weights(model):
    """Xavier-initialise every weight matrix reachable from `model`.

    Parameters with more than one dimension are re-drawn with
    `xavier_uniform_`; 1-D parameters (biases, LayerNorm weights) are left
    untouched. Embedding rows reserved by `padding_idx` are reset to zero
    afterwards, because the Xavier draw fills the whole weight matrix and
    would otherwise overwrite them with random values.
    """
    for param in model.parameters():
        if param.dim() > 1:
            init.xavier_uniform_(param)

    for module in model.modules():
        if isinstance(module, Embedding) and module.padding_idx is not None:
            with torch.no_grad():
                module.weight[module.padding_idx].fill_(0)

# `apply` calls init_weights on every submodule and finally on the model
# itself, so matrices are re-drawn more than once; the end state is the same.
model.apply(init_weights)

Transformer(
  (encoder): Encoder(
    (tok_embedding): Embedding(58101, 512, padding_idx=58100)
    (position_embedding): PositionalEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (self_attention): MultiHeadAttention(
          (fc_q): Linear(in_features=512, out_features=512, bias=True)
          (fc_k): Linear(in_features=512, out_features=512, bias=True)
          (fc_v): Linear(in_features=512, out_features=512, bias=True)
          (fc_o): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (self_attention_sublayer): SublayerConnection_PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedforward(
          (fc_1): Linear(in_features=512, out_features=2048, bias=True)
          (fc_2): Linear(in_features=2048, out

In [22]:
# Learning-rate scheduler:
# the rate grows during the warm-up phase and then decays with the inverse
# square root of the step number.
class NoamOptimizer:
    """Optimizer wrapper that sets the learning rate before every update.

    rate(step) = model_dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    Args:
        model_dim: model dimension used to scale the rate.
        warmup_steps: number of steps over which the rate increases.
        optimizer: wrapped optimizer exposing `param_groups`, `step()` and
            `zero_grad()`.
    """

    def __init__(self, model_dim, warmup_steps, optimizer):
        self.optimizer = optimizer
        self.model_dim = model_dim
        self.warmup_steps = warmup_steps
        self._step = 0
        self._rate = 0.

    # Called once per batch.
    def step(self):
        """Advance the step counter, apply the new rate and update the weights."""
        self._step += 1
        rate = self.rate()

        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    # Learning rate for the given (or the current) step.
    def rate(self, step=None):
        """Return the learning rate for `step` (default: the current step).

        Step 0, i.e. before the first update, yields 0.0 instead of raising
        ZeroDivisionError from `0 ** -0.5`.
        """
        if step is None:
            step = self._step

        if step <= 0:
            return 0.0

        scale = self.model_dim ** (-0.5)

        return scale * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

In [23]:
# Target positions holding the padding id do not contribute to the loss.
loss_function = CrossEntropyLoss(ignore_index=pad_idx)

# NOTE: NoamOptimizer.step() overwrites 'lr' in every param group before each
# update, so this value is only the initial setting handed to Adam.
learning_rate = 0.0005
epsilon = 1e-9
betas=(0.9, 0.98)
optimizer = Adam(model.parameters(), lr=learning_rate, betas=betas, eps=epsilon)

warmup_steps = 4000
# `optimizer` is rebound: from here on it names the scheduling wrapper.
optimizer = NoamOptimizer(model_dim, warmup_steps, optimizer)

In [24]:
def train(model, iterator, optimizer, loss_function):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: called as `model(src, decoder_input)`; returns logits.
        iterator: sized iterable of batches with 'input_ids' and 'labels'.
        optimizer: object exposing `zero_grad()` and `step()`.
        loss_function: called as `loss_function(logits_2d, targets_1d)`.

    Batches are moved to the module-level `device`. A progress line is
    printed every 100 batches together with the time accumulated since the
    previous line.
    """
    model.train()
    running_loss = 0
    num_batches = len(iterator)
    window_time = 0

    for step, batch in enumerate(iterator):
        started = time()

        src = batch['input_ids'].to(device)
        trg = batch['labels'].to(device)

        optimizer.zero_grad()

        # Teacher forcing: the decoder sees every token except the last one
        # and is scored against the sequence shifted left by one position.
        logits = model(src, trg[:, :-1])
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_targets = trg[:, 1:].reshape(-1)

        loss = loss_function(flat_logits, flat_targets)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the update.
        utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        window_time += time() - started

        if step != 0 and step % 100 == 0:
            batch_mins = int(window_time / 60)
            batch_secs = int(window_time% 60)
            print(f"\t train batch: {step:5d}/{num_batches} \t loss: {loss.item():.3f} \t batch_time_for_100 : {batch_mins}m {batch_secs}s")
            window_time = 0

    return running_loss / len(iterator)

In [25]:
def evaluate(model, iterator, loss_function):
    """Compute the mean batch loss over `iterator` without updating weights.

    Args:
        model: called as `model(src, decoder_input)`; returns logits.
        iterator: sized, non-empty iterable of batches with 'input_ids' and
            'labels'.
        loss_function: called as `loss_function(logits_2d, targets_1d)`.

    Batches are moved to the module-level `device`. One summary line with
    the last batch index, the last batch loss and the total time is printed
    at the end; an empty iterator is not supported.
    """
    model.eval()
    total_loss = 0
    num_batches = len(iterator)
    elapsed = 0

    with torch.no_grad():
        for i, batch in enumerate(iterator):
            started = time()

            src = batch['input_ids'].to(device)
            trg = batch['labels'].to(device)

            # Same input/target shift as in training.
            logits = model(src, trg[:, :-1])
            flat_logits = logits.reshape(-1, logits.shape[-1])
            flat_targets = trg[:, 1:].reshape(-1)

            loss = loss_function(flat_logits, flat_targets)
            total_loss += loss.item()

            elapsed += time() - started

        batch_mins = int(elapsed / 60)
        batch_secs = int(elapsed% 60)

        print(f"\t eval batch: {i:5d}/{num_batches} \t loss: {loss.item():.3f} \t batch_time_for_100 : {batch_mins}m {batch_secs}s")

    return total_loss / len(iterator)

In [26]:
epochs = 30

# Wall-clock seconds accumulated over all epochs.
total_time = 0

for epoch in range(1,epochs+1):
    print(f"[epoch : {epoch}]")
    start_time = time()
    
    # One pass over the training data, then one over the validation data.
    train_loss = train(model, train_iterator, optimizer, loss_function)
    valid_loss = evaluate(model, valid_iterator, loss_function)
    
    # Perplexity is the exponential of the mean loss.
    train_ppl = math.exp(train_loss)
    valid_ppl = math.exp(valid_loss)

    end_time = time()
    
    epoch_time = end_time - start_time
    total_time += epoch_time
    
    epoch_mins = int((epoch_time) / 60)
    epoch_secs = int((epoch_time) % 60)

    print(f"epoch: {epoch:3d}/{epochs} \t train ppl: {train_ppl:4.3f} val ppl: {valid_ppl:4.3f} \t {epoch_mins}m {epoch_secs}s \n")

print(f"\ntotal time : {total_time}")

[epoch : 1]
	 train batch:   100/3199 	 loss: 10.101 	 batch_time_for_100 : 0m 15s
	 train batch:   200/3199 	 loss: 8.004 	 batch_time_for_100 : 0m 14s
	 train batch:   300/3199 	 loss: 6.458 	 batch_time_for_100 : 0m 14s
	 train batch:   400/3199 	 loss: 6.172 	 batch_time_for_100 : 0m 14s
	 train batch:   500/3199 	 loss: 6.109 	 batch_time_for_100 : 0m 14s
	 train batch:   600/3199 	 loss: 6.082 	 batch_time_for_100 : 0m 14s
	 train batch:   700/3199 	 loss: 5.907 	 batch_time_for_100 : 0m 14s
	 train batch:   800/3199 	 loss: 5.739 	 batch_time_for_100 : 0m 14s
	 train batch:   900/3199 	 loss: 5.641 	 batch_time_for_100 : 0m 14s
	 train batch:  1000/3199 	 loss: 5.467 	 batch_time_for_100 : 0m 14s
	 train batch:  1100/3199 	 loss: 5.228 	 batch_time_for_100 : 0m 14s
	 train batch:  1200/3199 	 loss: 5.144 	 batch_time_for_100 : 0m 14s
	 train batch:  1300/3199 	 loss: 4.919 	 batch_time_for_100 : 0m 14s
	 train batch:  1400/3199 	 loss: 4.861 	 batch_time_for_100 : 0m 14s
	 train

In [27]:
def predict_translation(model, sentence, tokenizer, max_len=64):
    """Greedily translate one sentence.

    Args:
        model: object exposing `encoder`, `decoder` and
            `fully_connected_layer`, as the Transformer in this file does.
        sentence: source text.
        tokenizer: tokenizer providing `pad_token_id`, `eos_token_id` and
            `decode`.
        max_len: maximum number of source tokens and of decoder positions.

    Returns:
        The decoded translation with special tokens removed.

    The model is switched to eval mode and left there. Tensors are created
    on the device of the model's parameters rather than on a module-level
    device.
    """
    model.eval()
    device = next(model.parameters()).device

    tokenized = tokenizer(sentence, return_tensors='pt', truncation=True, max_length=max_len)
    src = tokenized['input_ids'].to(device)

    pad_token_id = tokenizer.pad_token_id
    eos_token_id = tokenizer.eos_token_id
    # NOTE(review): decoding starts from the padding id, while train() and
    # evaluate() feed `labels[:, :-1]` to the decoder and tokenize_function
    # does not visibly prepend a start token. If the tokenizer adds none, the
    # first target token is never predicted during training and this start
    # token is never seen -- confirm, as it would explain truncated output.
    sos_token_id = pad_token_id

    trg_tokens = torch.tensor([[sos_token_id]], dtype=torch.long, device=device)

    with torch.no_grad():
        # The source is encoded once and reused at every decoding step.
        src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)
        enc_src = model.encoder(src, src_mask)

        for _ in range(max_len - 1):
            _, trg_mask = create_masks(src, trg_tokens, pad_token_id)
            dec_output = model.decoder(trg_tokens, enc_src, trg_mask, src_mask)
            logits = model.fully_connected_layer(dec_output)

            # Greedy choice from the logits of the last position; `.item()`
            # relies on the batch holding exactly one sentence.
            next_token_id = logits[:, -1, :].argmax(1).item()

            next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
            trg_tokens = torch.cat([trg_tokens, next_token], dim=1)

            if next_token_id == eos_token_id:
                break

    # Drop the start token before decoding.
    translated_tokens = trg_tokens.squeeze(0).tolist()[1:]

    return tokenizer.decode(translated_tokens, skip_special_tokens=True)

In [None]:
test_sentence = "I am happy. Because studing is hard." 

predicted_translation = predict_translation(model, test_sentence, tokenizer)

print(f"input sequence: {test_sentence}")
print(f"output sequence: {predicted_translation}")
# Expected if training went well: Ich bin glücklich. Denn das Studium ist anstrengend. (or a similar translation)
# Actual output: uern, weil die Zulassung schwierig ist. (roughly "the reason for paying is that the approval process is difficult"...?)

input sequence: I am happy. Because studing is hard.
output sequence: uern, weil die Zulassung schwierig ist.
