# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 1. Positional encoding: adds word-order information to the embedding vectors
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``[1, max_len, d_model]`` is computed once at
    construction and registered as the buffer ``pe``, so it is stored with
    the module but is not a trainable parameter.

    Args:
        d_model: Size of the feature (last) dimension of the inputs. Both
            even and odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model] table of zeros
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)  # even feature indices get sin
        # Odd feature indices get cos. When d_model is odd there is one fewer
        # odd column than even columns, so drop the surplus angle column;
        # for even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # registered as a non-trainable buffer

    def forward(self, x):
        """Return ``x`` with the positional table added.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``; ``seq_len``
                is expected not to exceed ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Slice the table to the sequence length; it broadcasts over the batch.
        return x + self.pe[:, :x.size(1), :]

# 2. Multi-head attention: analyzes the sentence from several perspectives
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        # Fail early: with a remainder, d_k would be truncated and the
        # reshapes in forward() would either error out or scramble tokens.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head feature size
        self.w_q = nn.Linear(d_model, d_model)  # query projection
        self.w_k = nn.Linear(d_model, d_model)  # key projection
        self.w_v = nn.Linear(d_model, d_model)  # value projection
        self.fc = nn.Linear(d_model, d_model)  # combines the concatenated heads

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor, ``[batch, q_len, d_model]``.
            k: Key tensor, ``[batch, kv_len, d_model]``.
            v: Value tensor, ``[batch, kv_len, d_model]``.
            mask: Optional tensor broadcastable to
                ``[batch, heads, q_len, kv_len]``; positions where it equals
                0 are excluded from attention.

        Returns:
            Tensor of shape ``[batch, q_len, d_model]``.
        """
        batch_size = q.size(0)
        # Project, then split the feature axis into heads: [batch, heads, seq_len, d_k]
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: [batch, heads, q_len, kv_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype rather
            # than a fixed -1e9, which is not representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge the heads back into d_model.
        context = torch.matmul(attn, v).transpose(1, 2).contiguous()
        return self.fc(context.view(batch_size, -1, self.n_heads * self.d_k))

# 3. Feed-forward network: refines each word vector individually
class PoswiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the output has the same shape as ``x``."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        projected = self.fc2(activated)
        return projected

# 4. Encoder layer: Attention + FFN
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer is wrapped as ``norm(x + sublayer(x))``.

    Args:
        d_model: Feature size of the inputs and outputs.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, n_heads)
        self.ffn = PoswiseFeedForward(d_model, d_ff)
        # torch.nn provides layer normalization as nn.LayerNorm; there is no
        # nn.LayerNormalization, so the previous name failed at construction.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run the layer on ``x``.

        Args:
            x: Input tensor; it is used as query, key and value of the
                self-attention.
            mask: Optional attention mask forwarded unchanged to the
                attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.norm1(x + self.mha(x, x, x, mask))  # self-attention + residual
        x = self.norm2(x + self.ffn(x))  # feed-forward + residual
        return x

# 5. Decoder layer: Masked Self-Attn + Encoder-Decoder Attn + FFN
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then FFN.

    Each sub-layer is wrapped as ``norm(x + sublayer(x))``.

    Args:
        d_model: Feature size of the inputs and outputs.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super(DecoderLayer, self).__init__()
        self.masked_mha = MultiHeadAttention(d_model, n_heads)
        self.cross_mha = MultiHeadAttention(d_model, n_heads)
        self.ffn = PoswiseFeedForward(d_model, d_ff)
        # torch.nn provides layer normalization as nn.LayerNorm; there is no
        # nn.LayerNormalization, so the previous name failed at construction.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, self_mask=None, cross_mask=None):
        """Run the layer on the decoder input ``x``.

        Args:
            x: Decoder-side input tensor.
            enc_out: Encoder output, used as key and value of the
                cross-attention.
            self_mask: Optional mask forwarded to the self-attention.
            cross_mask: Optional mask forwarded to the cross-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.norm1(x + self.masked_mha(x, x, x, self_mask))  # 1. masked self-attention
        x = self.norm2(x + self.cross_mha(x, enc_out, enc_out, cross_mask))  # 2. cross-attention
        x = self.norm3(x + self.ffn(x))  # 3. feed-forward
        return x

# 6. Final Transformer model
class Transformer(nn.Module):
    """Encoder-decoder model built from stacked encoder and decoder layers.

    Args:
        src_vocab: Number of entries in the source embedding table.
        trg_vocab: Number of entries in the target embedding table; also the
            size of the output projection.
        d_model: Embedding / feature size used throughout the model.
        n_layers: Number of encoder layers and of decoder layers.
        n_heads: Number of attention heads per layer.
        d_ff: Hidden size of each feed-forward network.
        max_len: Number of positions passed to the positional encoding.
    """

    def __init__(self, src_vocab, trg_vocab, d_model, n_layers, n_heads, d_ff, max_len):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model)
        self.trg_emb = nn.Embedding(trg_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)

        # Build n_layers encoder layers, then n_layers decoder layers.
        layer_args = (d_model, n_heads, d_ff)
        encoder_stack = [EncoderLayer(*layer_args) for _ in range(n_layers)]
        self.encoder = nn.ModuleList(encoder_stack)
        decoder_stack = [DecoderLayer(*layer_args) for _ in range(n_layers)]
        self.decoder = nn.ModuleList(decoder_stack)

        # Output projection onto the target vocabulary.
        self.fc_out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, trg, src_mask, trg_mask):
        """Encode ``src``, decode ``trg`` against it, and project the result.

        Args:
            src: Source token indices fed to the source embedding.
            trg: Target token indices fed to the target embedding.
            src_mask: Mask given to every encoder layer and, as the
                cross-attention mask, to every decoder layer.
            trg_mask: Mask given to the self-attention of every decoder layer.

        Returns:
            Tensor of shape ``[batch, seq_len, trg_vocab]``.
        """
        # 1. Encoder pass.
        memory = self.src_emb(src)
        memory = self.pos_enc(memory)
        for enc_layer in self.encoder:
            memory = enc_layer(memory, src_mask)

        # 2. Decoder pass; every layer attends to the encoder result.
        hidden = self.trg_emb(trg)
        hidden = self.pos_enc(hidden)
        for dec_layer in self.decoder:
            hidden = dec_layer(hidden, memory, trg_mask, src_mask)

        logits = self.fc_out(hidden)
        return logits