In [None]:
# 라이브러리 임포트 
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# 트랜스포머 모델
# 1.위치 정보 전달을 위한 정적 포지셔널 인코딩
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first sequences.

    Even feature indices receive ``sin(pos / 10000^(2i/d_model))`` and odd
    indices ``cos(pos / 10000^(2i/d_model))``. The table is precomputed once
    and registered as a (non-trainable) buffer, so it follows the module in
    ``.to(device)`` calls and is stored in the state dict.

    Args:
        d_model: Feature width of the inputs. Odd widths are supported; the
            final even (sin) column then simply has no cos partner.
        max_len: Longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i/d_model) for every even feature index, computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even indices
        # Odd indices: there are only d_model // 2 of them, which is one fewer
        # than the even ones when d_model is odd, so trim the angle table.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}"
            )
        return x + self.pe[:, :seq_len]


# 2.시계열 학습을 위한 트랜스포머 회귀 모델
class TransformerPredictor(nn.Module):
    """Transformer-encoder regression model for time-series windows.

    Pipeline: per-step input projection -> sinusoidal positional encoding ->
    stacked TransformerEncoder layers -> pooling over the time axis -> MLP head.

    Args:
        input_size: Number of features per time step.
        output_size: Number of values predicted per sequence.
        d_model: Encoder embedding width (PyTorch's multi-head attention
            requires it to be divisible by ``nhead``).
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability used inside the encoder layers.
        use_attention_pool: If True, pool over time with learned attention
            weights; otherwise average the last two time steps.
        max_len: Longest input sequence supported by the positional encoding.
            Defaults to 500, the limit that was previously hard-coded.
    """

    def __init__(self, input_size=5, output_size=4, d_model=256, nhead=16, num_layers=3,
                 dim_feedforward=1024, dropout=0.2, use_attention_pool=True, max_len=500):
        super().__init__()
        # Submodules are created in the original order so that seeded weight
        # initialisation and state_dict keys are unchanged.

        # Attention pooling: a small MLP producing one score per time step;
        # forward() softmaxes these scores over the time axis.
        self.use_attention_pool = use_attention_pool
        if self.use_attention_pool:
            self.attn_pool = nn.Sequential(
                nn.Linear(d_model, 128),
                nn.Tanh(),
                nn.Linear(128, 1),
            )

        # Input projection applied independently to every time step:
        # input_size -> 256 -> d_model.
        self.input_proj = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.GELU(),
            nn.Linear(256, d_model),
        )

        # Encoder: fixed positional encoding gives attention access to order.
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,  # PyTorch default is 2048
            dropout=dropout,
            batch_first=True,                 # tensors are (batch, seq_len, d_model)
            activation="gelu",                # PyTorch default is "relu"
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Regression head (attribute kept as "decoder" for checkpoint
        # compatibility): three Linear-LayerNorm-GELU stages, then a linear output.
        self.decoder = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.LayerNorm(256),
            nn.GELU(),

            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.GELU(),

            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.GELU(),

            nn.Linear(256, output_size),
        )

    def forward(self, x):
        """Predict one output vector per input sequence.

        Args:
            x: Tensor of shape (batch, seq_len, input_size); ``seq_len`` must not
                exceed ``max_len``.

        Returns:
            Tensor of shape (batch, 1, output_size).
        """
        h = self.input_proj(x)            # (batch, seq_len, d_model)
        h = self.pos_encoder(h)
        h = self.transformer_encoder(h)   # (batch, seq_len, d_model)

        if self.use_attention_pool:
            # Weighted sum over time; weights are softmaxed along dim=1 (seq_len).
            scores = self.attn_pool(h)                 # (batch, seq_len, 1)
            weights = torch.softmax(scores, dim=1)
            pooled = (weights * h).sum(dim=1)          # (batch, d_model)
        else:
            # NOTE: this averages only the LAST TWO time steps, not the whole
            # sequence (the earlier "Mean Pooling" comment was misleading).
            pooled = h[:, -2:, :].mean(dim=1)          # (batch, d_model)

        out = self.decoder(pooled)                     # (batch, output_size)
        # Add a length-1 time axis: (batch, 1, output_size).
        return out.unsqueeze(1)