In [2]:
from dataclasses import dataclass
import tiktoken
from torchsummary import summary

import torch
import torch.nn as nn
import torch.nn.functional as F

<b>1. Pre-Processing</b>

1.1 Tokenisation + Encoding

In [3]:
# Load the GPT-2 tokenizer from tiktoken; this module-level `Encoder` is reused later in the notebook.
Encoder = tiktoken.get_encoding('gpt2')
# Example: turn a short string into its list of integer token ids (displayed as the cell output).
Encoder.encode('Hello world!')

[15496, 995, 0]

1.2 Embeddings

In [4]:
class Embedder(nn.Module):
    """Token embeddings plus learned positional embeddings, followed by dropout.

    Args:
        config: object exposing ``vocab_size``, ``sequence_length``,
            ``embed_dim`` and ``dropout``.
    """

    def __init__(self, config):
        super(Embedder, self).__init__()

        self.token_embedder = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedder = nn.Embedding(config.sequence_length, config.embed_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, idx):
        """Embed a tensor of token ids.

        Args:
            idx: integer tensor of token ids with shape ``(sequence_length,)``
                or ``(batch_size, sequence_length)``.

        Returns:
            Tensor of shape ``(batch_size, sequence_length, embed_dim)``.
            A 1-D input is treated as a batch of one.

        Raises:
            ValueError: if ``idx`` is neither 1-D nor 2-D, or if it is longer
                than the positional-embedding table.
        """
        # `dim` is a method: it must be called. A 1-D sequence becomes a batch of one.
        if idx.dim() == 1:
            idx = idx.unsqueeze(0)
        elif idx.dim() != 2:
            raise ValueError(
                f"expected a 1-D or 2-D tensor of token ids, got {idx.dim()}-D"
            )

        sequence_length = idx.size(1)
        max_length = self.position_embedder.num_embeddings
        if sequence_length > max_length:
            raise ValueError(
                f"sequence length {sequence_length} exceeds the maximum of {max_length}"
            )

        # Build the position ids on the same device as the input so the lookup
        # also works when the model lives on an accelerator.
        pos = torch.arange(sequence_length, dtype=torch.long, device=idx.device)

        token_embeddings = self.token_embedder(idx)          # (batch, seq, embed)
        positional_embeddings = self.position_embedder(pos)  # (seq, embed), broadcast over batch

        out = self.dropout(token_embeddings + positional_embeddings)

        return out

<b>2. Transformer</b>

2.1 Multi-Head Self Attention (MHSA)

In [5]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a causal mask.

    Args:
        config: object exposing ``embed_dim``, ``num_heads`` and ``qkv_bias``.
            An optional boolean attribute ``causal`` (default ``True``) can be
            set to ``False`` to let every position attend to every other one.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, config):
        super(MultiHeadSelfAttention, self).__init__()

        self.embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        self.qkv_bias = config.qkv_bias
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by num_heads ({self.num_heads})"
            )
        self.head_dim = self.embed_dim // self.num_heads
        # Autoregressive language modelling needs each position to see only the past.
        self.causal = getattr(config, "causal", True)

        self.query = nn.Linear(self.embed_dim, self.embed_dim, bias=self.qkv_bias)
        self.key = nn.Linear(self.embed_dim, self.embed_dim, bias=self.qkv_bias)
        self.value = nn.Linear(self.embed_dim, self.embed_dim, bias=self.qkv_bias)

        self.fc = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: tensor of shape ``(batch_size, sequence_length, embed_dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size, sequence_length, _ = x.shape

        def split_heads(t):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, sequence_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Calculation of q,k,v
        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # Scaled dot-product scores: dividing by sqrt(head_dim) keeps the
        # softmax out of its saturated region as head_dim grows.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if self.causal:
            # True above the diagonal marks "future" positions, which must not be attended to.
            future = torch.triu(
                torch.ones(sequence_length, sequence_length, device=x.device), diagonal=1
            ).bool()
            scores = scores.masked_fill(future, float("-inf"))

        weights = F.softmax(scores, dim=-1)
        attention = torch.matmul(weights, v)

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        out = attention.permute(0, 2, 1, 3).contiguous().view(batch_size, sequence_length, self.embed_dim)
        out = self.fc(out)

        return out

2.2 Feed Forward Neural Network (FFNN)

In [6]:
class FeedForwardNeuralNetwork(nn.Module):
    """Two-layer MLP: widen to 4x ``embed_dim``, apply GELU, project back, then dropout.

    Args:
        config: object exposing ``embed_dim``, ``ffnn_bias`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()

        model_width = config.embed_dim
        hidden_width = 4 * model_width
        use_bias = config.ffnn_bias

        self.expand = nn.Linear(model_width, hidden_width, bias=use_bias)
        self.gelu = nn.GELU()
        self.project = nn.Linear(hidden_width, model_width, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Run the MLP on ``x`` (last dimension ``embed_dim``) and return a tensor of the same shape."""
        activated = self.gelu(self.expand(x))
        return self.dropout(self.project(activated))

2.3 Layer Normalisation (LayerNorm)

In [7]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and an optional bias.

    Args:
        config: object exposing ``embed_dim`` and ``layernorm_bias``; the bias
            parameter is only created when ``layernorm_bias`` is truthy.
    """

    def __init__(self, config):
        super().__init__()

        width = config.embed_dim
        self.weight = nn.Parameter(torch.ones(width))
        if config.layernorm_bias:
            self.bias = nn.Parameter(torch.zeros(width))
        else:
            self.bias = None

    def forward(self, x):
        """Normalise ``x`` across its last dimension, which must match the size of ``weight``."""
        return F.layer_norm(x, self.weight.shape, weight=self.weight, bias=self.bias)

2.4 Complete Transformer Block

In [8]:
class Block(nn.Module):
    """One Transformer block.

    Self-attention and the feed-forward network are each wrapped in a residual
    connection, and the layer norm is applied *after* the residual sum.

    Args:
        config: configuration object forwarded unchanged to every sub-layer.
    """

    def __init__(self, config):
        super().__init__()
        self.MHSA = MultiHeadSelfAttention(config)
        self.FFNN = FeedForwardNeuralNetwork(config)
        self.LN1 = LayerNorm(config)
        self.LN2 = LayerNorm(config)

    def forward(self, x):
        """Apply attention then feed-forward, each as ``norm(x + sublayer(x))``."""
        stages = ((self.MHSA, self.LN1), (self.FFNN, self.LN2))
        for sublayer, norm in stages:
            x = norm(x + sublayer(x))
        return x

<b>3. Post-Processing</b>

3.1 Transformer Output to Probabilities

In [9]:
class PostTransformerLayers(nn.Module):
    """Vocabulary head: a bias-free linear projection followed by a softmax.

    The output is a probability distribution over the vocabulary, not raw logits.

    Args:
        config: object exposing ``embed_dim`` and ``vocab_size``.
    """

    def __init__(self, config):
        super().__init__()

        self.fc = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

    def forward(self, x):
        """Project ``x`` (last dimension ``embed_dim``) to vocabulary size and normalise over that axis."""
        logits = self.fc(x)
        return F.softmax(logits, dim=-1)

3.2 Next Token Choice

In [10]:
def PostProcess(probs, topK, temperature):
    """Sample the next token id using top-k filtering and temperature.

    Args:
        probs: probabilities over the vocabulary in the last dimension. Either a
            1-D tensor ``(vocab,)`` or a tensor with leading dimensions such as
            ``(1, sequence_length, vocab)``; only the final row (the last
            position of the last sequence) is used, because that is the
            distribution that predicts the *next* token.
        topK: number of most likely tokens to sample from; clamped to the
            vocabulary size.
        temperature: positive float. Values below 1 sharpen the distribution,
            values above 1 flatten it, and 1 leaves the top-k probabilities
            unchanged (after renormalisation).

    Returns:
        LongTensor of shape ``(1,)`` holding the sampled vocabulary index.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Collapse any leading dimensions and keep only the last position; sampling
    # across every position would mix in predictions for tokens already seen.
    last_probs = probs.reshape(-1, probs.size(-1))[-1]

    k = min(topK, last_probs.size(-1))
    top_probs, top_indices = torch.topk(last_probs, k)

    # Temperature must act on log-probabilities: softmax(log(p) / T) is
    # proportional to p ** (1 / T). Applying softmax to the probabilities
    # themselves would squash them into a near-uniform distribution.
    tiny = torch.finfo(top_probs.dtype).tiny  # avoids log(0) = -inf for zero-probability tokens
    scaled = torch.log(top_probs.clamp_min(tiny)) / temperature

    choice = torch.multinomial(torch.softmax(scaled, dim=0), 1)
    next_token = top_indices[choice]
    return next_token

3.3 Decode

In [11]:
# Round trip: decode the token ids shown in section 1.1 back into text.
Encoder.decode([15496, 995, 0])

'Hello world!'

<b>4. Config Format</b>

In [12]:
@dataclass
class pashkoConfig:
    """Hyper-parameters for the model, tokenizer and sampling."""

    # Model dimensions
    sequence_length: int = 1024
    vocab_size: int = 50304
    embed_dim: int = 768

    batch_size: int = 64

    num_heads: int = 12
    num_blocks: int = 12

    dropout: float = 0.0

    # Whether the corresponding layers carry a bias term
    ffnn_bias: bool = False
    qkv_bias: bool = False

    # Sampling
    topK: int = 10
    temperature: float = 1.0

    # These two used to be un-annotated, which made them plain class attributes
    # rather than dataclass fields: they could not be passed to the constructor
    # and were missing from repr/eq. They are declared last so the positional
    # order of the pre-existing fields is unchanged.
    encoder: str = 'gpt2'
    layernorm_bias: bool = False

<b>5. Complete GPT</b>

In [13]:
class pashko(nn.Module):
    """GPT-style language model: embeddings -> Transformer blocks -> vocabulary head.

    Args:
        config: object exposing the hyper-parameters read by the sub-modules
            plus ``num_blocks``, ``vocab_size``, ``sequence_length``, ``topK``
            and ``temperature``. An optional ``encoder`` attribute names the
            tiktoken encoding to load (defaults to ``'gpt2'``).
    """

    def __init__(self, config):
        super(pashko, self).__init__()

        # Use the tokenizer named in the config instead of a hard-coded one.
        self.Encoder = tiktoken.get_encoding(getattr(config, 'encoder', 'gpt2'))
        self.Embedder = Embedder(config)

        # Kept as a plain list + add_module (rather than nn.ModuleList) so the
        # parameter names in state_dict stay the same as before.
        self.Blocks = [Block(config) for _ in range(config.num_blocks)]
        for i, block in enumerate(self.Blocks):
            self.add_module(f'Transformer Block {i}', block)

        self.PostTransformer = PostTransformerLayers(config)

        self.LossFunction = nn.CrossEntropyLoss()

        self.config = config

        #Weight Initialisation according to GPT2 Paper
        self.apply(self.init_weights)

        #Weight tying Embedding to Final Linear
        self.Embedder.token_embedder.weight = self.PostTransformer.fc.weight

    def _run_blocks(self, x):
        """Embed token ids and pass them through every Transformer block."""
        x = self.Embedder(x)
        for block in self.Blocks:
            x = block(x)
        return x

    def forward(self, x, targets=None):
        """Compute per-position vocabulary probabilities and, optionally, the loss.

        Args:
            x: integer tensor of token ids.
            targets: optional integer tensor of target token ids with the same
                number of elements as there are positions in ``x``.

        Returns:
            ``(probs, loss)`` where ``probs`` has shape
            ``(num_positions, vocab_size)`` and ``loss`` is the cross-entropy
            against ``targets``, or ``None`` when no targets are given.
        """
        hidden = self._run_blocks(x)

        # nn.CrossEntropyLoss applies log-softmax itself, so it must be fed the
        # raw logits; feeding it probabilities would apply softmax twice.
        logits = self.PostTransformer.fc(hidden).view(-1, self.config.vocab_size)
        probs = F.softmax(logits, dim=-1)

        loss = None
        if targets is not None:
            loss = self.LossFunction(logits, targets.view(-1))
        return probs, loss

    @torch.no_grad()
    def inference(self, x):
        """Sample one next-token id for the sequence ``x``.

        Returns:
            The tensor returned by ``PostProcess`` for the last position.
        """
        hidden = self._run_blocks(x)

        # Only the final position predicts the next token, so only that one is
        # projected to the vocabulary and sampled from.
        probs = self.PostTransformer(hidden[:, -1:, :])
        token = PostProcess(probs, self.config.topK, self.config.temperature)
        return token

    @torch.no_grad()
    def generate(self, context, max_new_tokens=64, show=True):
        """Autoregressively extend ``context`` by ``max_new_tokens`` tokens.

        Args:
            context: prompt string; must encode to at least one token.
            max_new_tokens: number of tokens to sample.
            show: when true, print each token's text as soon as it is sampled.

        Returns:
            The generated continuation (without the prompt) as a string.

        Raises:
            ValueError: if ``context`` encodes to no tokens.
        """
        token_ids = self.Encoder.encode(context)
        if not token_ids:
            raise ValueError("context must encode to at least one token")

        # Keep the running sequence on the same device as the model weights.
        device = self.PostTransformer.fc.weight.device
        x = torch.tensor(token_ids, dtype=torch.long, device=device)

        response = []

        # Sample in eval mode (dropout off) and restore the previous mode afterwards.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # x is 1-D: keep at most the last `sequence_length` tokens.
                x = x[-self.config.sequence_length:]

                next_token = self.inference(x).view(1)
                token_id = int(next_token.item())
                response.append(token_id)

                # NOTE(review): vocab_size may exceed the tokenizer's real
                # vocabulary; an id outside it would presumably fail to decode.
                if show:
                    print(self.Encoder.decode([token_id]), end='', flush=True)

                x = torch.cat((x, next_token), dim=0)
        finally:
            self.train(was_training)

        return self.Encoder.decode(response)

    #Utils
    def init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def num_params(self, embedding=False):
        """Count the model's parameters.

        Args:
            embedding: when false, the positional-embedding table is left out
                of the count.

        Returns:
            ``(formatted, count)``: the count in millions as a string such as
            ``'123.60M'``, and the raw integer count.
        """
        num_params = sum(p.numel() for p in self.parameters())
        if not embedding:
            num_params -= self.Embedder.position_embedder.weight.numel()
        return "{:.2f}M".format(num_params / 1000000), num_params

<b>6. Model Summary</b>

In [14]:
# Instantiate the model with the default hyper-parameters from pashkoConfig.
config = pashkoConfig()
Pashko = pashko(config)

In [15]:
# Returns (formatted string, raw count). The tied token-embedding / output-projection weight
# is counted once, and the positional embeddings are excluded by default (embedding=False).
Pashko.num_params()

('123.60M', 123595776)

Model Summary

In [16]:
# Print a per-layer parameter table for the model using torchsummary.
summary(Pashko, input_size=(config.batch_size,config.sequence_length))

Layer (type:depth-idx)                        Param #
├─Embedder: 1-1                               --
|    └─Embedding: 2-1                         38,633,472
|    └─Embedding: 2-2                         786,432
|    └─Dropout: 2-3                           --
├─Block: 1-2                                  --
|    └─MultiHeadSelfAttention: 2-4            --
|    |    └─Linear: 3-1                       589,824
|    |    └─Linear: 3-2                       589,824
|    |    └─Linear: 3-3                       589,824
|    |    └─Linear: 3-4                       590,592
|    └─FeedForwardNeuralNetwork: 2-5          --
|    |    └─Linear: 3-5                       2,359,296
|    |    └─GELU: 3-6                         --
|    |    └─Linear: 3-7                       2,359,296
|    |    └─Dropout: 3-8                      --
|    └─LayerNorm: 2-6                         768
|    └─LayerNorm: 2-7                         768
├─Block: 1-3                                  --
|    └─MultiHea

Layer (type:depth-idx)                        Param #
├─Embedder: 1-1                               --
|    └─Embedding: 2-1                         38,633,472
|    └─Embedding: 2-2                         786,432
|    └─Dropout: 2-3                           --
├─Block: 1-2                                  --
|    └─MultiHeadSelfAttention: 2-4            --
|    |    └─Linear: 3-1                       589,824
|    |    └─Linear: 3-2                       589,824
|    |    └─Linear: 3-3                       589,824
|    |    └─Linear: 3-4                       590,592
|    └─FeedForwardNeuralNetwork: 2-5          --
|    |    └─Linear: 3-5                       2,359,296
|    |    └─GELU: 3-6                         --
|    |    └─Linear: 3-7                       2,359,296
|    |    └─Dropout: 3-8                      --
|    └─LayerNorm: 2-6                         768
|    └─LayerNorm: 2-7                         768
├─Block: 1-3                                  --
|    └─MultiHea