<a href="https://colab.research.google.com/github/newmantic/GPT/blob/main/GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class TokenEmbedding(nn.Module):
    """Learned lookup table mapping integer token ids to dense vectors."""

    def __init__(self, vocab_size, embed_size):
        """Create an embedding table with ``vocab_size`` rows of width ``embed_size``."""
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
        )

    def forward(self, x):
        """Return the embedding vectors for the token ids in ``x``.

        The result has the shape of ``x`` with one trailing axis of size
        ``embed_size`` appended.
        """
        token_vectors = self.embedding(x)
        return token_vectors

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_len, embed_size)`` is precomputed once: even
    columns hold ``sin(pos * rate)`` and odd columns hold ``cos(pos * rate)``
    with ``rate = 1 / 10000 ** (2 * i / embed_size)``.

    The table is stored as a non-persistent buffer named ``encoding`` so that
    it follows the module through ``.to()``, ``.cuda()``, ``.half()`` etc.,
    while staying out of ``state_dict()``.
    """

    def __init__(self, embed_size, max_len=512):
        """Precompute the encoding table.

        Args:
            embed_size: Width of the embeddings the table is added to.
                May be odd; the last (even-indexed) column then holds a sine
                term without a matching cosine column.
            max_len: Number of positions to precompute.
        """
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        # ceil(embed_size / 2) frequencies: enough to fill every even column,
        # including the extra one that exists when embed_size is odd.
        i = torch.arange(0, (embed_size + 1) // 2).float()
        angle_rates = 1 / (10000 ** (2 * i / embed_size))
        angles = pos * angle_rates

        encoding = torch.zeros(max_len, embed_size)
        encoding[:, 0::2] = torch.sin(angles)
        # Odd columns number floor(embed_size / 2); drop the surplus frequency.
        encoding[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # persistent=False keeps state_dict() free of this constant table.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_len, embed_size)``; the table is
        sliced along its position axis and broadcast over the batch axis.
        """
        return x + self.encoding[:, :x.size(1), :]

In [4]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are linearly projected to queries, keys and values, split into
    ``heads`` heads of width ``head_dim = embed_size // heads``, attended
    per head, concatenated and passed through a final linear layer.
    """

    def __init__(self, embed_size, heads):
        """Build the projection layers.

        Args:
            embed_size: Width of the input and output embeddings.
            heads: Number of attention heads; must divide ``embed_size``.

        Raises:
            AssertionError: If ``embed_size`` is not divisible by ``heads``.
        """
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"

        # Full-width projections applied before the per-head split, so every
        # head gets its own learned slice of the projected features.
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` positions over ``keys``/``values`` positions.

        Args:
            values: Tensor indexed as ``(batch, value_len, embed_size)``.
            keys: Tensor indexed as ``(batch, key_len, embed_size)``;
                ``key_len`` must equal ``value_len``.
            query: Tensor indexed as ``(batch, query_len, embed_size)``.
            mask: ``None``, or a tensor broadcastable to
                ``(batch, heads, query_len, key_len)``; positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding axis into (heads, head_dim).
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # Dot products are taken over head_dim, so that is the scaling width.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        energy = energy / math.sqrt(self.head_dim)

        if mask is not None:
            # Most negative finite value of the current dtype: drives the
            # softmax weight to zero without overflowing in half precision.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values per head, then merge heads back together.
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )

        out = self.fc_out(out)
        return out

In [5]:
class FeedForward(nn.Module):
    """Two-layer position-wise network: linear, ReLU, dropout, linear."""

    def __init__(self, embed_size, ff_hidden_size, dropout):
        """Create the layers.

        Args:
            embed_size: Width of the input and of the output.
            ff_hidden_size: Width of the hidden layer between the two linears.
            dropout: Drop probability applied to the hidden activations.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features=embed_size, out_features=ff_hidden_size)
        self.fc2 = nn.Linear(in_features=ff_hidden_size, out_features=embed_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Expand ``x`` to the hidden width, then project it back to ``embed_size``."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        regularized = self.dropout(activated)
        return self.fc2(regularized)

In [6]:
class TransformerBlock(nn.Module):
    """One attention sub-layer followed by one feed-forward sub-layer.

    Each sub-layer output is added to its input, layer-normalised, and then
    passed through dropout.
    """

    def __init__(self, embed_size, heads, ff_hidden_size, dropout):
        """Create the attention, normalisation, feed-forward and dropout layers.

        Args:
            embed_size: Width of the embeddings flowing through the block.
            heads: Number of heads handed to ``SelfAttention``.
            ff_hidden_size: Hidden width handed to ``FeedForward``.
            dropout: Drop probability, used by ``FeedForward`` and by the
                dropout applied after each normalisation.
        """
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden_size, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run the block; ``query`` doubles as the residual input.

        ``value``, ``key``, ``query`` and ``mask`` are forwarded unchanged to
        the attention layer.
        """
        # Attention sub-layer with a residual connection to the query.
        attended = self.attention(value, key, query, mask)
        attended_residual = attended + query
        x = self.dropout(self.norm1(attended_residual))

        # Feed-forward sub-layer with a residual connection to its own input.
        ff_residual = self.ff(x) + x
        return self.dropout(self.norm2(ff_residual))

In [7]:
class GPT(nn.Module):
    """Token embedding, positional encoding, a stack of transformer blocks,
    and a linear projection to vocabulary-sized scores."""

    def __init__(self, vocab_size, embed_size, num_layers, heads, ff_hidden_size, dropout, max_len):
        """Assemble the model.

        Args:
            vocab_size: Number of token ids; also the width of the output scores.
            embed_size: Embedding width used throughout the stack.
            num_layers: Number of ``TransformerBlock`` layers.
            heads: Attention heads per block.
            ff_hidden_size: Hidden width of each block's feed-forward network.
            dropout: Drop probability handed to every block.
            max_len: Number of positions handed to ``PositionalEncoding``.
        """
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.position_encoding = PositionalEncoding(embed_size, max_len)
        self.layers = nn.ModuleList(
            TransformerBlock(embed_size, heads, ff_hidden_size, dropout)
            for _ in range(num_layers)
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask):
        """Map token ids ``x`` to per-position vocabulary scores.

        ``mask`` is handed unchanged to every block; this method does not
        build a mask itself, so passing ``None`` applies no masking here.
        """
        hidden = self.position_encoding(self.token_embedding(x))

        # Each block attends over its own input (value = key = query).
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, mask)

        return self.fc_out(hidden)

In [8]:
import torch

# Hyperparameters for a small demonstration model
vocab_size, embed_size = 10000, 128
num_layers, heads = 2, 8
ff_hidden_size, dropout = 512, 0.3
max_len = 512

# Build the model from the hyperparameters above
model = GPT(
    vocab_size=vocab_size,
    embed_size=embed_size,
    num_layers=num_layers,
    heads=heads,
    ff_hidden_size=ff_hidden_size,
    dropout=dropout,
    max_len=max_len,
)

# Random token ids: a batch of 2 sequences, 10 tokens each
x = torch.randint(low=0, high=vocab_size, size=(2, 10))

# The example runs without an attention mask
mask = None

# Single forward pass
output = model(x, mask)

# Show the input ids and the resulting output tensor
print(f"Input (x): \n{x}")
print(f"Model Output: \n{output}")

Input (x): 
tensor([[9781, 5019, 8708,  102, 5421, 6176, 6697, 1499, 7350, 6237],
        [6452, 3634, 3411, 9905, 5510, 2496, 1891, 7042, 9769, 9932]])
Model Output: 
tensor([[[-0.5209, -0.1617, -0.4542,  ..., -0.3552,  0.2403, -0.4144],
         [ 0.7682,  0.1492,  0.8142,  ...,  0.9354, -0.8658, -1.2806],
         [ 1.0452, -0.3901, -0.1208,  ..., -0.2633, -1.0446,  0.7894],
         ...,
         [ 0.5350, -0.9450,  0.5673,  ...,  0.1257,  0.6041,  0.8342],
         [ 0.0353,  0.9506,  0.4805,  ..., -0.2886,  0.4188,  0.2447],
         [ 0.2212,  0.1105, -0.1346,  ..., -0.0510, -0.0761, -0.9258]],

        [[-0.2333, -1.0283,  0.3152,  ...,  1.1060,  0.1645,  0.4280],
         [ 1.1838,  0.6025, -0.8950,  ..., -0.4215, -0.5368, -0.5116],
         [-0.1613,  0.9625, -0.3746,  ...,  0.7593,  0.7765, -0.7956],
         ...,
         [-0.7567, -0.0351, -1.5185,  ..., -0.5424,  0.3597,  0.0461],
         [-1.4078,  0.8243, -0.0842,  ...,  0.3214, -0.1682,  0.0677],
         [-0.8614, -0