In [None]:
import torch
import torch.nn as nn

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Dropout -> Linear.

    The block expands the last dimension from ``d_model`` to ``d_ff``, applies
    a non-linearity and dropout, then projects back to ``d_model``.
    """

    # Supported activation names mapped to the module class implementing them.
    _ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU}

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation: str = "relu"):
        """
        Args:
            d_model (int): Model (input and output) dimension.
            d_ff (int): Hidden dimension of the feed-forward layer.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
            activation (str, optional): Name of the non-linearity, either
                ``"relu"`` or ``"gelu"``. Defaults to ``"relu"``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )

        # First layer: d_model -> d_ff
        self.w_1 = nn.Linear(d_model, d_ff)
        # Second layer: d_ff -> d_model
        self.w_2 = nn.Linear(d_ff, d_model)
        # Dropout applied to the hidden activations
        self.dropout = nn.Dropout(dropout)
        # Non-linearity. The original paper uses ReLU (the default here), while
        # recent LLMs often use GELU, hence the selectable activation.
        self.activation = self._ACTIVATIONS[activation]()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: [batch_size, seq_len, d_model]
        """
        # Linear -> activation -> Dropout -> Linear
        # x: [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_ff]
        hidden = self.activation(self.w_1(x))
        hidden = self.dropout(hidden)

        # [batch_size, seq_len, d_ff] -> [batch_size, seq_len, d_model]
        return self.w_2(hidden)

In [None]:
def verify_ffn():
    """Smoke-test PositionwiseFeedForward: output shape must equal input shape."""
    # Dummy problem size: (batch_size, seq_len, d_model), plus the hidden width.
    input_dims = (2, 10, 512)
    hidden_dim = 2048

    model = PositionwiseFeedForward(input_dims[-1], hidden_dim)

    # Random dummy input pushed through the network once.
    sample = torch.randn(*input_dims)
    result = model(sample)

    for label, tensor in (("Input shape:", sample), ("Output shape:", result)):
        print(label, tensor.shape)

    assert result.shape == sample.shape, "Output shape must match input shape"
    print("PositionwiseFeedForward verification passed!")

# Run the shape check as soon as this cell executes.
verify_ffn()