In this notebook we will implement the GPT-2 model

Architecture:

- Token & Positional Embeddings
- Dropout
- Transformer
- Normalization
- Logits

In [6]:
import torch
import torch.nn as nn

In [None]:
class DummyLayerNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class DummyTransformer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

In [41]:
class DummyGPT(nn.Module):
    def __init__(self, cfg):
        super.__init__()
        # Dropout
        self.dropout = nn.Dropout(cfg["dropRate"])
        # Token embedding
        self.tokenEmbed = nn.Embedding(cfg["vocabSize"], cfg["embedDim"])
        # Positional embedding
        self.posEmbed = nn.Embedding(cfg["contextLength"], cfg["embedDim"])
        # Normalization
        self.normLayer = DummyLayerNorm()
        # Transformer Blocks
        self.trfBlocks = nn.Sequential([
            DummyTransformer() for _ in range(cfg["nLayers"])
        ])
        # Linear projection.
        self.outLayer = nn.Linear(cfg["embedDim"], cfg["vocabSize"], bias=False)

    def forward(self, x):
        b, seqLen = x.shape
        tokEmbeds = self.tokenEmbed(x)
        posEmbeds = self.posEmbed(torch.arange(seqLen, device=x.device))
        x = tokEmbeds + posEmbeds
        x = self.dropout(x)
        x = self.normLayer(x)
        for l in self.trfBlocks:
            x = l(x)
        return self.outLayer(x)

In [42]:
cfg = {
    "vocabSize": 52048,
    "contextLength": 1024, 
    "embedDim": 768,
    "nLayers" : 12,
    "dropRate": 0.1,
}

In [43]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"
batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)
print(batch)

tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])


In [44]:
DummyGPT(cfg)

TypeError: descriptor '__init__' of 'super' object needs an argument