In [1]:
"""
Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

'\nFull definition of a GPT Language Model, all of it in this single file.\nReferences:\n1) the official GPT-2 TensorFlow implementation released by OpenAI:\nhttps://github.com/openai/gpt-2/blob/master/src/model.py\n2) huggingface/transformers PyTorch implementation:\nhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py\n'

In [2]:
import math
from dataclasses import dataclass
import inspect

import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [3]:
# Read the whole training corpus into memory as one string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# Display the first few characters as a quick check that the file was read.
text[:8]

'First Ci'

In [4]:
# Size of the corpus in characters.
len(text)

1115394

In [5]:
# Character-level tokenizer: every distinct character in the corpus gets an integer id,
# assigned in sorted character order.
chars = sorted(set(text))  # iterating a str already yields characters; no join/list copy needed
stoi = {c: i for i, c in enumerate(chars)}  # character -> id
itos = {i: c for c, i in stoi.items()}      # id -> character


def encode(s):
    """Encoder: take a string, output the list of integer ids of its characters.

    Raises KeyError for a character that does not occur in the corpus.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, output the string they spell.

    Raises KeyError for an id outside the vocabulary.
    """
    return ''.join(itos[i] for i in l)


vocab_size = len(itos)
print(itos)
print(vocab_size)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}
65


In [6]:
# Encode the whole corpus as a 1-D tensor of token ids, then split it into
# a training part and a held-out evaluation part.
data = torch.tensor(encode(text), dtype=torch.long)
n_train = int(0.9*len(data))

train_data = data[:n_train]       # first 90% of the tokens
eval_data = data[n_train:]         # remaining 10%, used only for evaluation
# Preview the first 100 token ids.
data[:100]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

In [7]:
# hyperparameters
block_size = 32 # length of each sampled sequence (see get_batch); also the default GPTConfig.block_size
batch_size = 16 # how many independent sequences will we process in parallel?
epochs = 1000 # number of optimisation steps in the training loop (one batch per step, not passes over the data)
eval_interval = 100 # estimate train/eval loss every this many steps
learning_rate = 1e-3 # learning rate passed to the AdamW optimizer
device = 'cuda' if torch.cuda.is_available() else 'cpu' # use the GPU when one is available
eval_iters = 200 # number of batches averaged per split in estimate_loss

In [8]:
# Seed PyTorch's random number generator before any sampling or weight initialisation.
torch.manual_seed(1337)

<torch._C.Generator at 0x210ae9e2270>

In [9]:
def get_batch(split):
    """Sample a random batch of (input, target) sequences from one data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``eval_data``.

    Returns:
        A pair ``(X, y)`` of ``(batch_size, block_size)`` tensors on ``device``,
        where ``y`` is ``X`` shifted one position to the right in the source data.
    """
    source = train_data if split == 'train' else eval_data
    # Random start offsets; the upper bound leaves room for the shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

In [10]:
# Sanity check: draw one evaluation batch and show its first two rows.
# Each row of y is the corresponding row of X shifted one token to the right.
X, y = get_batch('eval')
X[:2], y[:2]

(tensor([[59, 56, 43, 42,  1, 51, 43,  1, 57, 39, 63,  1, 51, 63,  1, 51, 47, 52,
          42,  6,  0, 13, 52, 42,  1, 47, 44,  1, 63, 53, 59,  1],
         [47, 58,  1, 52, 53, 58,  1, 58, 46, 43, 52,  1, 53, 59, 56,  1, 43, 63,
          43, 50, 47, 42, 57,  1, 57, 47, 52, 49, 12,  1, 21,  1]]),
 tensor([[56, 43, 42,  1, 51, 43,  1, 57, 39, 63,  1, 51, 63,  1, 51, 47, 52, 42,
           6,  0, 13, 52, 42,  1, 47, 44,  1, 63, 53, 59,  1, 41],
         [58,  1, 52, 53, 58,  1, 58, 46, 43, 52,  1, 53, 59, 56,  1, 43, 63, 43,
          50, 47, 42, 57,  1, 57, 47, 52, 49, 12,  1, 21,  1, 44]]))

In [11]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the global ``model`` on the train and eval splits.

    For each split, ``eval_iters`` random batches are drawn with ``get_batch`` and
    their losses averaged. The model is put in eval mode for the measurement and
    is returned to whatever mode it was in on entry, even if an error occurs.

    Returns:
        dict mapping 'train' and 'eval' to a 0-d tensor holding the mean loss.
    """
    was_training = model.training  # remember the caller's mode instead of assuming train
    model.eval()
    out = {}
    try:
        for split in ['train', 'eval']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, y = get_batch(split)
                _, loss = model(X, y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

In [12]:
class LayerNorm(nn.Module):
    """Layer normalisation with a learnable scale and an optional learnable shift.

    Args:
        ndim: size of the normalised (last) dimension.
        bias: when truthy a zero-initialised shift parameter is created,
            otherwise ``self.bias`` is ``None`` and no shift is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, X):
        """Normalise ``X`` over dimensions matching the shape of ``self.weight``."""
        normalized_shape = self.weight.shape
        eps = 1e-5
        return F.layer_norm(X, normalized_shape, self.weight, self.bias, eps)


@dataclass
class GPTConfig:
    """Architecture settings read by the model classes in this notebook.

    The defaults for ``block_size`` and ``vocab_size`` are taken from the
    notebook-level variables of the same names at the moment this class is
    defined, so those variables must already exist when this cell runs.
    """
    # Larger GPT-2-style reference values, kept for comparison:
    #     block_size: int = 1024
    #     vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    #     n_layer: int = 12
    #     n_head: int = 12
    #     n_embd: int = 768
    #     dropout: float = 0.0
    #     bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    # Small settings used in this notebook:
    block_size: int = block_size
    vocab_size: int = vocab_size
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 64
    dropout: float = 0.0
    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions.

    Reads ``n_embd``, ``n_head``, ``bias`` and ``dropout`` from ``config``;
    ``block_size`` is read only when the fused attention kernel is unavailable.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one linear layer produces query, key and value for every head at once (hence 3 * n_embd)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection applied after the heads are merged back together
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # the fused kernel exists only when torch.nn.functional exposes scaled_dot_product_attention
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if self.flash:
            return
        print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # lower-triangular causal mask: position t may look at positions <= t only
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size))
        self.register_buffer('bias', causal_mask.view(1, 1, size, size))

    def forward(self, X):
        """Apply causal self-attention to ``X`` of shape (B, T, C); returns a tensor of the same shape."""
        B, T, C = X.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): the head axis moves next to the batch axis
            return t.view(B, T, self.n_head, head_size).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(X).split(self.n_embd, dim=2))

        if self.flash:
            # fused attention kernel; attention dropout is applied only in training mode
            drop_p = self.dropout if self.training else 0
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)
        else:
            # manual attention: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
            scale = 1.0 / math.sqrt(k.size(-1))
            scores = (q @ k.transpose(-2, -1)) * scale
            future = self.bias[:, :, :T, :T] == 0
            weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
            y = self.attn_dropout(weights) @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)

        # put the heads back side by side: (B, nh, T, hs) -> (B, T, C)
        merged = y.transpose(1, 2).contiguous().view(B, T, C)
        # output projection
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x the embedding width, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width  # same 4x expansion ratio as the 512 -> 2048 layer in the transformer paper
        self.c_fc    = nn.Linear(width, hidden, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(hidden, width, bias=config.bias)  # back down to the embedding width
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, X):
        """Apply the feed-forward network to the last dimension of ``X``."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(X))))


class Block(nn.Module):
    """One transformer block: LayerNorm then attention, and LayerNorm then MLP, each added back to its input."""

    def __init__(self, config):
        super().__init__()
        width, use_bias = config.n_embd, config.bias
        self.ln_1 = LayerNorm(width, bias=use_bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(width, bias=use_bias)
        self.mlp = MLP(config)

    def forward(self, X):
        """Run both sub-layers on ``X``; the output has the same shape as the input."""
        attended = self.attn(self.ln_1(X))
        X = X + attended  # residual connection around attention
        transformed = self.mlp(self.ln_2(X))
        return X + transformed  # residual connection around the MLP


class NanoGPTLanguageModel(nn.Module):
    """GPT-style language model.

    Token and learned position embeddings are summed, passed through
    ``config.n_layer`` transformer blocks and a final LayerNorm, and projected
    to vocabulary logits by ``lm_head``, whose weight is shared with the
    token embedding.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            token_embedding = nn.Embedding(config.vocab_size, config.n_embd), # (vocab_size, C)
            position_embedding = nn.Embedding(config.block_size, config.n_embd), # (block_size, C)
            dropout = nn.Dropout(config.dropout),
            blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias), # final layer norm
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.token_embedding.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.position_embedding.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, X, y=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            X: (B, T) integer token ids with T <= config.block_size.
            y: optional (B, T) integer targets; positions equal to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)``. With targets, logits are (B, T, vocab_size) and
            loss is a scalar tensor. Without targets, only the last position is
            projected, so logits are (B, 1, vocab_size) and loss is None.
        """
        device = X.device
        B, T = X.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=device) # shape (T)

        # forward the GPT model itself
        tok_emb = self.transformer.token_embedding(X) # (B, T, C)
        pos_emb = self.transformer.position_embedding(pos) # (T, C), broadcast over the batch
        X = self.transformer.dropout(tok_emb + pos_emb)
        for block in self.transformer.blocks:
            X = block(X) # (B, T, C)
        X = self.transformer.ln_f(X)   # (B, T, C)

        if y is not None:
            # if we are given some desired y also calculate the loss
            logits = self.lm_head(X) # (B, T, vocab_size)
            # cross_entropy wants (N, C) logits and (N,) targets, so flatten batch and time
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
        else:
            # inference time mini optimization: only forward the lm_head on the very last position
            logits = self.lm_head(X[:, [-1], :]) # note: using list[-1] to preserve the time dim
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate(self, X, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``X``.

        Args:
            X: (B, T) integer token ids used as the starting context.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by this before the softmax; the
                default 1.0 leaves the distribution unchanged.
            top_k: if given, only the ``top_k`` most likely tokens can be sampled.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.

        The module's train/eval mode is not changed here; call ``eval()`` first
        if dropout should be disabled while sampling.
        """
        # Crop with the model's own context length rather than a notebook-level global,
        # so the method works for any config and outside the notebook.
        max_context = self.config.block_size
        for _ in range(max_new_tokens):
            # crop X to the last block_size tokens
            X_cond = X[:, -max_context:]
            # get the predictions
            logits, _ = self(X_cond)
            # focus only on the last time step and apply the temperature
            logits = logits[:, -1, :] / temperature # becomes (B, vocab_size)
            if top_k is not None:
                # mask out everything below the k-th largest logit
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float('-inf'))
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, vocab_size)
            # sample from the distribution
            idx_nxt = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            X = torch.cat((X, idx_nxt), dim=1) # (B, T+1)
        return X
        

In [13]:
# Build the model with the default GPTConfig and move its parameters to the selected device.
model = NanoGPTLanguageModel(GPTConfig())
model.to(device)

number of parameters: 0.20M


In [14]:
# Exact count over model.parameters(); unlike get_num_params() with its default
# argument, this total still includes the position embedding.
print('total parameters number:', sum(p.nelement() for p in model.parameters()))

total parameters number: 203392


In [15]:
# train model
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for iter in range(epochs):
    # every once in a while evaluate the loss on train and eval sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
    
    # sample a batch of data
    Xb, yb = get_batch('train')
    
    # evaluate the loss
    logits, loss = model(Xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.1844, eval loss 4.1840
step 100: train loss 2.6713, eval loss 2.6584
step 200: train loss 2.5072, eval loss 2.5084
step 300: train loss 2.4233, eval loss 2.4314
step 400: train loss 2.3848, eval loss 2.3849
step 500: train loss 2.3422, eval loss 2.3411
step 600: train loss 2.2891, eval loss 2.2915
step 700: train loss 2.2536, eval loss 2.2576
step 800: train loss 2.2422, eval loss 2.2439
step 900: train loss 2.1927, eval loss 2.2089


In [16]:
# generate from the model
# Start from a single token with id 0 (the newline character in this vocabulary),
# sample 2000 more tokens, and decode the first (only) sequence back to text.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))


ESIRY IY:
Paveng he-Kard see honen Ixen hangs of-Zends,
Ind brou, the sour ing me, thit die.


PUtom habeth, your grasijar'nzen say the hat your coung oof wand
thablk : slises frushe. I sicot of orbe hefblt theat, larde ther hiply't hit ich berer. wighwesss a math
Fou, li.

Hupast an Rice?
TRICMOLITAM'NENYIORIUK:
Kenfaits.

Sighn Rusm'd I to abuts I det the denmearss
My, fakvivengrn's our mon he may with ridey kish ing thish is--
Bndbuost homurghy, you haghting ath, cowaull twer wackeden renst.

NORUKulan:
Nay, the Vakest hekoks nazitt mellbhiste.

HEFLOLEEY:
The ise thassen bly, a youu poll atim'?

KENprure Thats igh:
Hhavll oum the trofe of wethy taiths duse ayAns thene to ean sucay dise,
JOrversuknk groikesiong stomrpient atin'sIn tupeeld baet; I loche omartore; ikn, wis ton prostevear, growerst con, hand;
Fell I kneam salll: be'bllt to aegrimant
Soimt nour hanmf supeom with, Aglaed
Shang, bous ay buartong obrs; her squiken be,
Arused the but the the cut my you ournwanch unell',
Fa