In [1]:
import torch
import torch.nn as nn

In [2]:
class Transformer(nn.Module):
    """Causal (decoder-only) Transformer stack, modeled after GPT-2.

    Learned token + absolute position embeddings feed a stack of
    self-attention / position-wise feed-forward blocks, each sublayer
    preceded by a LayerNorm and followed by dropout and a residual add.
    Attention is masked so every position only sees itself and earlier
    positions.

    Args:
        embed_dim: Size of the token/position embeddings and hidden states.
        hidden_dim: Inner size of each position-wise feed-forward layer.
        num_embed: Vocabulary size (number of token embeddings).
        num_pos: Number of position embeddings, i.e. the maximum sequence
            length accepted by ``forward``.
        num_heads: Attention heads per layer (``nn.MultiheadAttention``
            requires ``embed_dim`` to be divisible by it).
        num_layers: Number of stacked attention + feed-forward blocks.
        dropout: Dropout probability for embeddings, attention weights and
            sublayer outputs.
    """

    def __init__(self, embed_dim, hidden_dim,
                 num_embed, num_pos, num_heads, num_layers, dropout):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_embed, embed_dim)
        # (sic) attribute name kept unchanged so existing references /
        # state_dict keys remain valid.
        self.position_embedings = nn.Embedding(num_pos, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.ln_1, self.ln_2 = nn.ModuleList(), nn.ModuleList()  # layer norms

        for _ in range(num_layers):
            # Append the attention module itself; wrapping a single module in
            # nn.ModuleList(...) would try to iterate over it and fail.
            self.attentions.append(
                nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embed_dim)))
            self.ln_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.ln_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x):
        """Run the stack over a batch of token ids.

        Args:
            x: LongTensor of token ids with shape ``(seq_len, batch)``
                (sequence-first, matching ``nn.MultiheadAttention``'s default
                ``batch_first=False``).

        Returns:
            Tensor of shape ``(seq_len, batch, embed_dim)`` holding the hidden
            state of every position.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of position
                embeddings (``num_pos``).
        """
        seq_len = len(x)
        max_len = self.position_embedings.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the number of position "
                f"embeddings ({max_len})")

        # Token + position embeddings. positions: [seq_len] -> [seq_len, 1] so
        # the position embedding ([seq_len, 1, embed_dim]) broadcasts over batch.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(-1)
        h = self.token_embeddings(x)
        h = h + self.position_embedings(positions).expand_as(h)
        h = self.dropout(h)

        # Causal (look-ahead) mask: start from a (seq_len, seq_len) matrix of
        # -inf, then keep -inf only strictly above the diagonal so position i
        # can attend to positions <= i but not to future positions, e.g.
        #   tensor([[0., -inf, -inf],
        #           [0.,   0., -inf],
        #           [0.,   0.,   0.]])
        attn_mask = torch.full((seq_len, seq_len), -float('Inf'),
                               device=h.device, dtype=h.dtype)
        attn_mask = torch.triu(attn_mask, diagonal=1)

        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(
                self.ln_1, self.attentions, self.ln_2, self.feed_forwards):
            # NOTE: the residual is added to the *normalized* h, so the
            # residual stream is re-normalized at every sublayer (strict
            # GPT-2 pre-LN adds the sublayer output to the un-normalized input).

            # Self-attention sublayer; output is [seq_len, batch, embed_dim].
            h = layer_norm_1(h)
            attn_out, _ = attention(h, h, h, attn_mask=attn_mask,
                                    need_weights=False)
            h = self.dropout(attn_out) + h  # residual connection

            # Position-wise feed-forward sublayer (uses its own LayerNorm).
            h = layer_norm_2(h)
            ff_out = feed_forward(h)
            h = self.dropout(ff_out) + h  # residual connection

        return h