# GPT
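GPT consists of a stack of customized Transformer decoder blocks. Instead of ordinary self-attention, each block uses masked (causal) self-attention, as in the Transformer's decoder, so a token can attend only to itself and to the tokens before it.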

In [1]:
import torch.nn as nn
import torch

In [26]:
class GPTTransformerDecoder(nn.Module):
    """One GPT-style decoder block: masked self-attention plus feed-forward.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization (post-norm).

    Args:
        embed_dim: Size of each token embedding.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        ff_dim: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability used inside the feed-forward network.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.multi_head_attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )

        self.layer_normalization_multi_head = nn.LayerNorm(embed_dim)
        self.layer_normalization_feed_forward = nn.LayerNorm(embed_dim)

        # NOTE: the attribute keeps its original spelling ("froward") so that
        # state_dict keys of previously saved checkpoints still match.
        self.feed_froward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x``.

        Args:
            x: Tensor of shape (BATCH, NUM_TOKEN, EMBEDDING_SIZE).

        Returns:
            Tensor with the same shape as ``x``.
        """
        token_num = x.size(1)

        # THIS IS WHERE GPT IS DIFFERENT
        # Causal mask: token i may attend only to tokens 0..i. For a boolean
        # attn_mask, True marks a position that must NOT be attended to, so
        # the strictly-upper triangle (the future tokens) is set to True.
        # A float mask of ones/zeros would merely be added to the attention
        # scores and would not hide anything.
        causal_mask = torch.triu(
            torch.ones(token_num, token_num, dtype=torch.bool, device=x.device),
            diagonal=1,
        )

        # MultiheadAttention does not modify its inputs in place, so the same
        # tensor can serve as query, key and value without cloning.
        attn_output, _ = self.multi_head_attention(
            x, x, x, attn_mask=causal_mask
        )

        # Residual connection + layer norm around the attention sub-layer.
        add_attn_output = self.layer_normalization_multi_head(x + attn_output)

        # Residual connection + layer norm around the feed-forward sub-layer.
        feed_forward_output = self.feed_froward(add_attn_output)
        out = self.layer_normalization_feed_forward(
            add_attn_output + feed_forward_output
        )

        return out

In [27]:
class GPT(nn.Module):
    """A stack of ``GPTTransformerDecoder`` blocks applied in sequence.

    The defaults reproduce the original hard-coded configuration:
    12 blocks, embedding size 768 and 12 attention heads.

    Args:
        embed_dim: Size of each token embedding.
        num_heads: Number of attention heads in every block.
        num_layers: Number of decoder blocks in the stack.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
    ) -> None:
        super().__init__()
        # Build the blocks in order; nn.Sequential registers them under the
        # indices "0", "1", ... exactly as an explicit listing would.
        self.model = nn.Sequential(
            *(
                GPTTransformerDecoder(embed_dim, num_heads)
                for _ in range(num_layers)
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every decoder block and return the result."""
        return self.model(x)

In [28]:
# Smoke test: run a single decoder block on random input.
EMBEDDING_SIZE = 768

# Random tensor laid out as (BATCH, NUM_TOKEN, EMBEDDING_SIZE) = (1, 3, 768).
input_example = torch.randn(1, 3, EMBEDDING_SIZE)
# One decoder block with 4 attention heads.
model = GPTTransformerDecoder(EMBEDDING_SIZE, 4)
# Last expression of the cell, so the notebook displays the output tensor.
model(input_example)

tensor([[[-0.6389, -0.6145,  0.4074,  ..., -0.1969, -1.6490, -0.5305],
         [ 0.8292,  0.8687,  2.4639,  ...,  0.2528,  0.4494, -0.0038],
         [ 0.5725,  1.1218, -0.8308,  ..., -0.7162,  0.8438, -2.5336]]],
       grad_fn=<NativeLayerNormBackward0>)

In [29]:
# Build the GPT stack and run it on the same example input as above.
model_GPT = GPT()
# Last expression of the cell, so the notebook displays the output tensor.
model_GPT(input_example)

tensor([[[ 0.3060, -0.1759, -0.1900,  ..., -0.1140, -0.9538, -0.2425],
         [ 1.4414, -0.0882,  0.0853,  ...,  0.7449,  0.6197,  0.6292],
         [ 0.3551,  0.5918, -0.6042,  ..., -0.1452,  1.2090, -1.0654]]],
       grad_fn=<NativeLayerNormBackward0>)

In [30]:
# Display the module's repr to inspect the layer structure.
model_GPT

GPT(
  (model): Sequential(
    (0): GPTTransformerDecoder(
      (multi_head_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (layer_normalization_multi_head): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layer_normalization_feed_forward): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (feed_froward): Sequential(
        (0): Linear(in_features=768, out_features=2048, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=2048, out_features=768, bias=True)
      )
    )
    (1): GPTTransformerDecoder(
      (multi_head_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (layer_normalization_multi_head): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layer_normalization_feed_forward): LayerNorm((768,), eps=1e-