In [14]:
import torch
import torch.nn as nn

In [18]:
# Hyperparameters for the 124M configuration, shared through one dict.
GPT_CONFIG_124M = {
    # Vocabulary size
    "vocab_size": 50257,
    # Context length
    "context_length": 1024,
    # Embedding dimension
    "emb_dim": 768,
    # Number of attention heads
    "n_heads": 12,
    # Number of layers
    "n_layers": 12,
    # Dropout rate
    "drop_rate": 0.1,
    # Query-Key-Value bias
    "qkv_bias": False,
}

In [None]:
# A nonlinear activation function: GELU (Gaussian Error Linear Unit).
# "Nonlinear" means the output is not directly proportional to the input.

# Takes in a tensor of any shape and returns a tensor of the same shape.
class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    element-wise, so the output has the same shape as the input.

    The module has no parameters and needs no custom ``__init__``.
    """

    # sqrt(2 / pi) as a plain Python float, evaluated once at class creation.
    # Multiplying by a Python scalar avoids allocating a temporary tensor on
    # every forward call, and keeps full double precision for float64 inputs
    # (a float32 tensor constant would limit accuracy to roughly 1e-8).
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def forward(self, x):
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

In [None]:
# FeedForward implements a feed-forward neural network.
# It consists of two linear layers with a GELU activation function in between.
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Linear.

    The first linear layer widens the features from ``emb_dim`` to
    ``4 * emb_dim``; the second maps them back to ``emb_dim``.
    """

    def __init__(self, cfg):
        """Build the layer stack.

        Args:
            cfg: Mapping that provides the embedding width under ``"emb_dim"``.
        """
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim

        # Linear layers are created in expand-then-project order so parameter
        # initialisation consumes the RNG in the same sequence as before.
        expand = nn.Linear(emb_dim, hidden_dim)
        project = nn.Linear(hidden_dim, emb_dim)

        self.layers = nn.Sequential(expand, GELU(), project)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        out = self.layers(x)
        return out

In [17]:
# Random demo batch: batch size of 3, embedding dimension of 768.
# NOTE: the name `input` shadows Python's builtin `input()` for the rest of the session.
input = torch.randn(3, 768)


In [21]:
# Build the feed-forward block from the 124M config and run the demo batch through it.
ff = FeedForward(GPT_CONFIG_124M)
output = ff(input)

In [22]:
# Confirm the feed-forward block preserves the input shape.
output.shape

torch.Size([3, 768])

In [23]:
# Inspect the values produced by the first pass through `ff`.
output

tensor([[ 0.2758, -0.4343, -0.0090,  ...,  0.2568,  0.0276,  0.2992],
        [-0.3144,  0.2057,  0.2628,  ..., -0.1688,  0.0272,  0.0137],
        [-0.0425,  0.1827,  0.0589,  ..., -0.2293,  0.1868, -0.1039]],
       grad_fn=<AddmmBackward0>)

In [24]:
# Feed the previous result back through the same block; this works because the
# block's output width equals its input width.
# NOTE: rebinding `output` makes this cell non-idempotent -- each re-run applies `ff` once more.
output = ff(output)

In [25]:
# Check the shape again after the second pass through `ff`.
output.shape


torch.Size([3, 768])

In [26]:
# Inspect the values produced by the second pass through `ff`.
output

tensor([[-0.0260, -0.0145, -0.0617,  ...,  0.0376,  0.0079, -0.0329],
        [-0.0266,  0.0184, -0.0061,  ..., -0.0358,  0.0233,  0.0525],
        [-0.0115, -0.0110,  0.0843,  ..., -0.0241, -0.0542, -0.0410]],
       grad_fn=<AddmmBackward0>)