#### SwiGLU FFN implementation (using the GPT-2 config as an example)
> Taken from: https://github.com/garg-aayush/building-from-scratch/blob/main/gpt-2/play-nbs/swiglu.ipynb

In [None]:
from dataclasses import dataclass

import torch.nn as nn
import torch.nn.functional as F

In [None]:
@dataclass
class GPTConfig:
    """Model hyperparameters; the defaults follow the GPT-2 example used in this notebook."""

    block_size: int = 1024    # maximum sequence length
    vocab_size: int = 50_257  # 50,000 merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12         # number of layers
    n_embd: int = 768         # embedding dimension (model width)
    n_head: int = 12          # number of attention heads

class MLP(nn.Module):
    """GELU feed-forward block: expand to 4 * n_embd, apply GELU, project back.

    Args:
        config: model config; only ``config.n_embd`` is read.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # Expansion projection: n_embd -> 4 * n_embd.
        self.c_fc = nn.Linear(width, hidden)
        # Tanh-approximated GELU non-linearity.
        self.gelu = nn.GELU(approximate="tanh")
        # Contraction projection: 4 * n_embd -> n_embd.
        self.c_proj = nn.Linear(hidden, width)
        # Marker attribute; presumably read by a weight-init routine elsewhere
        # (not visible here) — TODO confirm.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension is n_embd in and out."""
        return self.c_proj(self.gelu(self.c_fc(x)))

In [None]:
# Baseline: the GELU MLP. Count every parameter (weights and biases).
ffn = MLP(GPTConfig())
num_params = sum(param.numel() for param in ffn.parameters())
print(f"Number of parameters: {num_params / 1e6:.4f}M")
ffn


In [None]:
# SwiGLU: https://arxiv.org/pdf/2002.05202
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block ("GLU Variants Improve Transformer", arXiv:2002.05202).

    Computes ``c_proj(c_fc1(x) * silu(c_fc2(x)))``. Two parallel input
    projections of width ``int(config.n_embd * factor)`` are applied to the
    same input. One of them goes through SiLU (Swish with beta=1) and gates
    the other elementwise. The result is then projected back to
    ``config.n_embd``.

    Args:
        config: model config; only ``config.n_embd`` is read.
        factor: hidden-width multiplier. With three weight matrices of size
            n_embd x (factor * n_embd), factor=8/3 gives 8 * n_embd**2 weights.
            That equals a two-matrix MLP with a 4x hidden width.
    """

    # "GPTConfig" is a forward reference, so this class can be defined even
    # when GPTConfig is not yet in scope.
    def __init__(self, config: "GPTConfig", factor: float = 8 / 3):
        super().__init__()
        # Compute the hidden width once, so all three layers are guaranteed to agree.
        hidden_dim = int(config.n_embd * factor)
        # Two parallel input projections: c_fc1 is the value path, and
        # c_fc2 is the gate path (passed through SiLU in forward).
        self.c_fc1 = nn.Linear(config.n_embd, hidden_dim)
        self.c_fc2 = nn.Linear(config.n_embd, hidden_dim)
        # Output projection back to the input width.
        self.c_proj = nn.Linear(hidden_dim, config.n_embd)
        # Marker attribute, same as on MLP.c_proj; presumably read by a
        # weight-init routine elsewhere (not visible here) — TODO confirm.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Return ``c_proj(c_fc1(x) * silu(c_fc2(x)))``.

        Both input projections must see the ORIGINAL ``x``. Feeding the
        output of ``c_fc1`` into ``c_fc2`` is a shape error whenever
        ``factor != 1``.
        """
        value = self.c_fc1(x)
        gate = F.silu(self.c_fc2(x))
        return self.c_proj(value * gate)


In [None]:
# SwiGLU with the same 4x hidden width as the GELU MLP, for a direct comparison.
ffn_new = SwiGLU(GPTConfig(), factor=4)
num_params = sum(p.numel() for p in ffn_new.parameters())
print(f"Number of parameters: {num_params / 1e6:.4f}M")
ffn_new
# Here, the hidden width matches the GELU MLP (4 * n_embd), but SwiGLU has three
# weight matrices instead of two, so it has roughly 1.5x as many parameters.


In [None]:
# To keep the parameter count close to the GELU FFN, follow PaLM/LLaMA practice
# and set factor = 8/3: three matrices at 8/3 width ~= two matrices at 4x width.
ffn_new = SwiGLU(GPTConfig(), factor=8 / 3)
num_params = sum(weight.numel() for weight in ffn_new.parameters())
print(f"Number of parameters: {num_params / 1e6:.4f}M")
ffn_new
