In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Hyperparameters
batch_size = 2 #8
block_size = 48 #1024
dimension = 36 #768
num_heads = 12
n_layers = 4 #12
max_epochs = 2000 #2500
lr = 1e-3
dropout = 0.1
topk = 50
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'
torch.manual_seed(99)
DEBUG = True

In [3]:
with open('shakespeare.txt', 'r') as f:
    text = f.read()
print("Length of dataset in characters:", len(text))
print(f"First {50} characters: {text[:50]}")

Length of dataset in characters: 1115394
First 50 characters: First Citizen:
Before we proceed any further, hear


In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"All unique characters sorted: {''.join(chars)}")
print(f"Number of unique characters: {len(chars)}")

All unique characters sorted: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Number of unique characters: 65


In [5]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda ch : [stoi[c] for c in ch]
decode = lambda ix : ''.join([itos[i] for i in ix])
print(encode('hi there'))
print(decode(encode('hi there')))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [6]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:10])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [7]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
print(train_data.shape)
print(val_data.shape)

torch.Size([1003854])
torch.Size([111540])


In [8]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'when context is {context}, next token is: {target}')

when context is tensor([18]), next token is: 47
when context is tensor([18, 47]), next token is: 56
when context is tensor([18, 47, 56]), next token is: 57
when context is tensor([18, 47, 56, 57]), next token is: 58
when context is tensor([18, 47, 56, 57, 58]), next token is: 1
when context is tensor([18, 47, 56, 57, 58,  1]), next token is: 15
when context is tensor([18, 47, 56, 57, 58,  1, 15]), next token is: 47
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47]), next token is: 58
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58]), next token is: 47
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47]), next token is: 64
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64]), next token is: 43
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43]), next token is: 52
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52]), next token is: 10
when context is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 6

In [9]:
def get_batch(data_source='train'):
    source = train_data if data_source == 'train' else val_data 
    ix = torch.randint(low=0, high=len(source) - block_size, size=(batch_size,))
    x = torch.stack([source[i:i+block_size] for i in ix])
    y = torch.stack([source[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

xb, yb = get_batch('train')
print('batch - x:', xb.shape)
print(xb)
print('batch - y:', yb.shape)
print(yb)

batch - x: torch.Size([2, 48])
tensor([[63,  8,  0,  0, 18, 47, 56, 57, 58,  1, 31, 43, 56, 60, 39, 52, 58, 10,
          0, 37, 53, 59,  1, 39, 56, 43,  1, 50, 53, 53, 49, 43, 42,  1, 44, 53,
         56,  1, 39, 52, 42,  1, 41, 39, 50, 50, 43, 42],
        [59, 42, 45, 51, 43, 52, 58,  1, 60, 39, 52, 47, 57, 46,  5, 42,  1, 44,
         56, 53, 51,  1, 46, 47, 57,  1, 50, 47, 54, 57,  6,  0, 26, 53, 58,  1,
         40, 53, 42, 63,  5, 57,  1, 42, 43, 39, 58, 46]])
batch - y: torch.Size([2, 48])
tensor([[ 8,  0,  0, 18, 47, 56, 57, 58,  1, 31, 43, 56, 60, 39, 52, 58, 10,  0,
         37, 53, 59,  1, 39, 56, 43,  1, 50, 53, 53, 49, 43, 42,  1, 44, 53, 56,
          1, 39, 52, 42,  1, 41, 39, 50, 50, 43, 42,  1],
        [42, 45, 51, 43, 52, 58,  1, 60, 39, 52, 47, 57, 46,  5, 42,  1, 44, 56,
         53, 51,  1, 46, 47, 57,  1, 50, 47, 54, 57,  6,  0, 26, 53, 58,  1, 40,
         53, 42, 63,  5, 57,  1, 42, 43, 39, 58, 46,  6]])


In [11]:
class AttentionHead(torch.nn.Module):
    def __init__(self, dimension_head):
        super(AttentionHead, self).__init__()
        self.query_layer = nn.Linear(dimension, dimension_head, bias=False)
        self.key_layer = nn.Linear(dimension, dimension_head, bias=False)
        self.value_layer = nn.Linear(dimension, dimension_head, bias=False)
        self.scale = dimension_head ** 0.5
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size, device=device)))

    def forward(self, x):
        _, T, _ = x.shape
        assert T <= block_size, f"Input length exceeds block size {block_size}."
        Q = self.query_layer(x)
        K = self.key_layer(x)
        V = self.value_layer(x)

        K_T = K.transpose(1, 2)
        attention_scores = Q @ K_T
        attention_scores = attention_scores / self.scale
        attention_scores = attention_scores.masked_fill(self.tril[:T,:T] == 0, float('-inf'))
        attention_weights = F.softmax(attention_scores, dim=-1)
        output = attention_weights @ V   

        return output


class MultiHeadAttention(torch.nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        assert dimension % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        dimension_head = dimension // num_heads  
        self.heads = nn.ModuleList([AttentionHead(dimension_head) for _ in range(num_heads)])
        self.proj = nn.Linear(dimension, dimension)
        
    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]                
        output = torch.cat(head_outputs, dim=-1)
        output = self.proj(output)                                       
        return output


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dimension, 4*dimension,),
            nn.GELU(approximate="tanh"),
            nn.Linear(4*dimension, dimension)
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(dimension)
        self.sa = MultiHeadAttention()
        self.ln2 = nn.LayerNorm(dimension)
        self.ff = FeedForward()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.dropout(self.sa(self.ln1(x)))
        x = x + self.ff(self.ln2(x))
        return x


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, dimension)
        self.position_embedding_table = nn.Embedding(block_size, dimension)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.lm_head = nn.Linear(dimension, vocab_size)

    def forward(self, idx, targets=None):
        _, T_idx = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos = torch.arange(T_idx, device=device)
        pos_emb = self.position_embedding_table(pos)
        pos_emb = pos_emb.unsqueeze(0)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            logits_flattened = logits.view(-1, logits.size(-1))
            targets = targets.view(-1)
            loss = F.cross_entropy(logits_flattened, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]                              
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_indices = torch.topk(probs, k=topk, dim=-1)
            idx_next = torch.multinomial(topk_probs, num_samples=1)
            xcol = torch.gather(topk_indices, dim=-1, index=idx_next)
            idx = torch.cat((idx, xcol), dim=1)
        return idx

In [12]:
model = Model().to(device)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)

torch.Size([2, 48, 65])
tensor(4.5552, grad_fn=<NllLossBackward0>)


In [13]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = model.generate(idx, 50)[0].tolist()
print(decode(generated))



;AqDAJdM'XH!p&&tuLARVdmrijt
Db;k&BpxoO,op'3dcqeRAo


In [14]:
def train():
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for epoch in range(max_epochs):
        xb, yb = get_batch('train')
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if epoch == 0 or (epoch + 1) % 500 == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

In [15]:
train()

Epoch 1, Loss: 4.518505573272705
Epoch 500, Loss: 2.4197661876678467
Epoch 1000, Loss: 2.6544950008392334
Epoch 1500, Loss: 2.662320137023926
Epoch 2000, Loss: 2.6340301036834717


In [97]:
model_state_path = 'gpt-char-level.pth'
torch.save(model.state_dict(), model_state_path)

In [98]:
state_dict = torch.load(model_state_path, map_location="cpu")
state_dict

OrderedDict([('token_embedding_table.weight',
              tensor([[ 1.1282,  0.0148,  1.0783,  ...,  0.9926, -0.4567, -1.9195],
                      [ 0.0685,  1.2233, -1.0039,  ...,  1.5837,  0.0771,  1.3063],
                      [-0.6846, -0.6239, -0.7813,  ..., -1.8358, -0.9043, -1.3798],
                      ...,
                      [-1.2493, -0.0464, -0.1076,  ...,  0.9245, -0.3058,  0.5545],
                      [-0.9794, -1.2060, -1.0753,  ..., -1.2805, -1.4240,  1.2248],
                      [ 1.1210,  1.2918,  0.8996,  ...,  0.7433, -0.3463, -0.7676]])),
             ('position_embedding_table.weight',
              tensor([[-1.7494e+00,  1.6451e+00, -1.1817e+00, -6.3707e-01, -1.4959e+00,
                        1.5132e+00,  7.4200e-01, -1.2802e-01,  2.6503e-02,  8.3261e-01,
                       -4.8534e-01, -1.2051e+00,  1.2452e+00, -1.2536e+00,  4.0152e-02,
                        8.3950e-01,  5.9007e-01,  6.0966e-01, -7.0797e-01, -3.8498e-01,
                   

In [234]:
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params}")

Number of parameters: 17873


In [None]:
idx = get_batch('val')[0]
print("prompt:")
print(decode(idx[0].tolist()))
print("generated:")
print(decode(model.generate(idx, 50)[0].tolist()))

prompt:
NIO:
If but one of his pockets could speak, woul
generated:
NIO:
If but one of his pockets could speak, woulochen!
As or ou Anooma f thin sthe crafeen, a dort
