In [1]:
import torch
import torch.nn as nn

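GELU is a smoother alternative to ReLU that tends to behave better during optimization. In the feed-forward block it sits between two linear layers: a linear layer is followed by the nonlinear activation function, which is then followed by a second linear layer.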

In [2]:
# Hyperparameters for the 124M-parameter GPT configuration used in this notebook.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # vocabulary size
    context_length=1024,  # context length
    emb_dim=768,          # embedding dimension
    n_heads=12,           # number of attention heads
    n_layers=12,          # number of layers
    drop_rate=0.1,        # dropout rate
    qkv_bias=False,       # whether the query/key/value projections use a bias
)

In [3]:
class GELU(nn.Module):
    """Tanh approximation of the GELU activation function.

    Computes, element-wise,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    # sqrt(2 / pi) as a plain Python float. Keeping it out of forward() avoids
    # allocating a new tensor on every call, and it is no longer rounded to
    # float32 before being combined with higher-precision (e.g. float64) inputs.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Cubic correction term of the tanh approximation.
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


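The feed-forward block below first expands and then contracts the layer width: small → large → small (from `emb_dim` to `4 * emb_dim` and back to `emb_dim`). The wider intermediate layer gives the network more room to extract information from the input; once the weight matrices are optimized, this helps the LLM pull out the relevant, useful information.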

In [4]:
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x the embedding width, apply GELU, project back."""

    def __init__(self, cfg):
        """Build the layers from ``cfg["emb_dim"]``.

        Args:
            cfg: Mapping that provides the key ``"emb_dim"``.
        """
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim

        expand = nn.Linear(emb_dim, hidden_dim)   # small -> large
        activation = GELU()                       # nonlinearity on the wide layer
        project = nn.Linear(hidden_dim, emb_dim)  # large -> small
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through the expand / activate / project stack."""
        out = self.layers(x)
        return out

In [5]:
# Seed the global RNG before any weights are created so that a fresh
# "Restart & Run All" reproduces the same parameters and subsequent random draws.
torch.manual_seed(123)
ff = FeedForward(GPT_CONFIG_124M)

In [6]:
# Dummy batch for a smoke test of the feed-forward block:
# 3 samples, each with 4 tokens, each token a 768-dimensional embedding.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = torch.randn((3, 4, 768))

In [7]:
input.shape  # (batch, tokens, embedding dim) -> torch.Size([3, 4, 768])

torch.Size([3, 4, 768])

In [8]:
input  # display the random input tensor (the repr elides most values with "...")

tensor([[[-0.1075, -0.2459, -1.0045,  ...,  0.6016, -1.5415,  1.0489],
         [-0.2028, -0.6508, -0.5252,  ..., -0.7364,  1.3857,  0.7752],
         [ 1.1746,  0.2382, -0.5163,  ..., -0.7314, -0.6422, -0.5772],
         [-0.9940, -0.0569,  0.2630,  ..., -0.6018,  0.8611, -1.6849]],

        [[-1.3929, -0.3243, -1.0600,  ...,  0.4747,  0.0769, -0.8110],
         [ 0.1953,  1.4797,  0.4379,  ..., -0.4202, -1.5305, -0.9016],
         [ 0.4767,  0.8204,  0.7494,  ...,  0.5691,  0.8258, -0.3055],
         [ 0.6394,  0.1500, -0.7287,  ..., -0.9913, -0.8217,  0.0419]],

        [[-0.7472,  2.0156,  0.8460,  ...,  1.3574, -0.2736, -0.3243],
         [-0.0990,  0.4752,  0.8166,  ..., -1.1673,  0.7969,  1.1714],
         [-0.6079, -0.4061, -2.6400,  ..., -0.6105, -0.5942,  0.4367],
         [-0.6790,  0.0346, -0.1611,  ..., -0.0938, -1.6204, -0.3449]]])

In [9]:
# Forward pass of the dummy batch through the feed-forward block.
output = ff( input )

In [10]:
output.shape  # same shape as the input: the second linear layer projects back to emb_dim

torch.Size([3, 4, 768])

In [11]:
output  # display the feed-forward activations

tensor([[[ 1.5231e-01,  9.3932e-02,  6.7874e-02,  ..., -2.8243e-01,
           7.9161e-02, -1.0680e-01],
         [ 2.9400e-01, -9.5439e-02,  6.9948e-02,  ..., -7.3149e-02,
          -1.6704e-03, -1.4875e-01],
         [ 1.3994e-01, -1.2535e-01,  6.5960e-02,  ..., -1.0853e-01,
          -1.3622e-01, -1.1611e-01],
         [-2.8352e-01,  8.8919e-02, -1.2949e-02,  ..., -6.5900e-02,
          -2.7742e-01,  9.6893e-02]],

        [[ 4.1891e-02,  4.9116e-03,  2.3661e-01,  ...,  4.7974e-02,
          -4.4154e-02,  1.4358e-01],
         [-8.6580e-02,  5.3038e-02, -9.2203e-02,  ...,  8.2365e-02,
           1.7490e-02,  1.2583e-01],
         [-2.3628e-01, -2.3842e-01,  1.7354e-01,  ..., -5.4518e-01,
           1.5104e-01, -5.0939e-01],
         [ 2.4275e-01,  3.9560e-01, -2.6566e-01,  ...,  2.0590e-01,
          -1.2851e-01, -3.7776e-02]],

        [[ 6.9482e-03,  1.7645e-01, -2.7216e-01,  ...,  1.1824e-01,
          -1.2129e-01, -2.2817e-02],
         [ 2.5770e-02, -3.7166e-02,  1.8667e-01,  .

In [12]:
# Second pass: the block maps emb_dim -> emb_dim, so its own output is a valid input.
# NOTE(review): this overwrites `output`, so the cell is not idempotent; every
# re-run applies `ff` one more time to the previous result.
output = ff( output )

In [13]:
output.shape  # still torch.Size([3, 4, 768]) after the second pass

torch.Size([3, 4, 768])