In [7]:

import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

In [8]:
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model.

    Defaults follow the GPT-2 layout (12 layers, 12 heads, 768-wide embeddings).
    Field order is part of the interface: positional construction such as
    ``GPTConfig(256, 65)`` maps to (block_size, vocab_size, ...).
    """
    # maximum context length == number of rows in the learned position table
    block_size: int = 1_024
    # GPT-2's 50257-token vocabulary rounded up to a multiple of 64 for efficiency
    vocab_size: int = 50_304
    n_layer: int = 12     # number of transformer blocks
    n_head: int = 12      # attention heads per block
    n_embd: int = 768     # embedding / residual width
    dropout: float = 0.0  # dropout probability used throughout
    # True: bias in Linears (GPT-2 style); False: a bit better and faster.
    # Only the attention and MLP Linear layers read this flag; LayerNorms and
    # lm_head are built without consulting it.
    bias: bool = True


In [86]:
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer.

    Expands n_embd -> 4*n_embd, applies a single GELU, projects back to
    n_embd, then applies dropout. Shape is preserved: (..., n_embd) in and out.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        # registration order (c_fc, gelu, c_proj, dropout) fixes parameter order
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # expand -> nonlinearity -> contract -> regularize
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Each position attends only to itself and earlier positions.
    Input and output are (B, T, n_embd) with T <= config.block_size.

    Raises:
        ValueError: if config.n_embd is not divisible by config.n_head.
    """
    def __init__(self, config):
        super().__init__()
        # Each head gets n_embd // n_head channels; fail fast instead of a
        # confusing view() error on the first forward pass.
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection back to the residual width
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Register the mask as a buffer so model.to(device) / .double() move and
        # cast it together with the parameters. As a plain attribute it stayed on
        # CPU and masked_fill failed once the model was on a GPU.
        # persistent=False keeps it out of state_dict (same checkpoint layout as before).
        mask = torch.tril(torch.ones((config.block_size, config.block_size)))
        self.register_buffer(
            "mask",
            mask.view(1, 1, config.block_size, config.block_size),
            persistent=False,
        )

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # scaled dot-product scores
        # (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # causal mask: block attention to future positions
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)

        # re-assemble all head outputs side by side -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y



class Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    Computes x + attn(ln_1(x)), then adds mlp(ln_2(.)) to that result. The
    LayerNorms are built with nn.LayerNorm defaults and do not read config.bias.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # keep registration order: ln_1, attn, ln_2, mlp
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # normalize before each sub-layer; residual adds keep the skip path clean
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned position embeddings are passed through
    dropout, ``config.n_layer`` ``Block`` modules and a final LayerNorm. A
    linear head then produces vocabulary logits.
    """
    def __init__(self, config):
        super().__init__()
        # Keep the caller's config. generate() and forward() read block_size from
        # it, so a fresh default config would disagree with the layers built below.
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: (B, T) long tensor of token ids, T <= config.block_size.
            targets: optional (B, T) long tensor of next-token ids; entries
                equal to -1 are ignored by the loss.

        Returns:
            (logits, loss). With targets: logits are (B, T, vocab_size) and loss
            is the cross-entropy. Without: logits are (B, 1, vocab_size) for the
            last position only and loss is None.

        Raises:
            IndexError: if T exceeds config.block_size (no position embedding).
        """
        device = idx.device
        B, T = idx.size()
        if T > self.config.block_size:
            raise IndexError(
                f"Cannot forward sequence of length {T}, "
                f"block size is only {self.config.block_size}"
            )
        pos = torch.arange(0, T, dtype=torch.long, device=device)

        # embedding: (B, T, n_embd) + (T, n_embd) broadcast over the batch
        word_embed = self.transformer.wte(idx)
        pos_embed = self.transformer.wpe(pos)
        embed = word_embed + pos_embed

        # input to transformer blocks
        x = self.transformer.drop(embed)

        # transformer blocks
        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x) # (B, T, vocab_size)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   targets.view(-1), ignore_index=-1)
        else:
            # inference-time sampling. Only need logits at the last token;
            # the [-1] list index keeps the time dimension -> (B, 1, vocab_size)
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) long tensor used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling.
            top_k: if given, sample only among the k highest-scoring tokens.

        Returns:
            (B, T + max_new_tokens) long tensor: the prompt followed by the samples.
        """
        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]

            logits, _ = self(idx_cond) # (B, 1, vocab_size)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature

            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # apply softmax to convert logits to probs
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

        

model = GPT(config=GPTConfig())  # default GPT-2-sized model


In [27]:
# Weird data loader

import os
import numpy as np

# Data-loader settings: which prepared dataset to read, and where batches go.
dataset = "shakespeare_char"
data_dir = os.path.join('data', dataset)
device = "cpu"

# Every sample spans the model's full context window.
block_size = GPTConfig().block_size
batch_size = 8

def get_batch(split):
    """Sample a random batch of (input, target) token windows from disk.

    ``split == 'train'`` reads train.bin; any other value reads val.bin.
    Targets are the inputs shifted one token to the right. Both tensors are
    int64 of shape (batch_size, block_size) and are moved to ``device``.
    """
    # A fresh np.memmap per call avoids a memory leak, see
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode='r')
    starts = torch.randint(len(data) - block_size, (batch_size,))
    inputs = torch.stack(
        [torch.from_numpy(data[s:s + block_size].astype(np.int64)) for s in starts]
    )
    targets = torch.stack(
        [torch.from_numpy(data[s + 1:s + 1 + block_size].astype(np.int64)) for s in starts]
    )
    return inputs.to(device), targets.to(device)

In [22]:
# Fetch one training batch; the training loop below reuses it on every step.
one_batch_x, one_batch_y = get_batch(split="train")
one_batch_x.shape

torch.Size([8, 1024])

In [77]:
# Training loop
# Repeatedly fit the single batch fetched above (presumably an overfitting
# sanity check: the loss should drop steadily).
max_iters = 50
optimizer = torch.optim.AdamW(model.parameters())

# The generation cell calls model.eval(); switch back explicitly so re-running
# this cell afterwards trains with dropout active (when config.dropout > 0).
model.train()
for iter_num in range(max_iters):
    optimizer.zero_grad(set_to_none=True)
    logits, loss = model(one_batch_x, one_batch_y)
    loss.backward()
    optimizer.step()
    # .item() prints a plain float instead of the tensor repr with grad_fn
    print(f"step {iter_num} loss is {loss.item():.4f}")


KeyboardInterrupt: 

In [87]:
# Generation loop

# Sample a continuation of a fixed two-token prompt.
start_ids = [0,3]
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)  # add batch dim -> (1, len(start_ids))
print(x.size())

# sampling knobs
temperature = 1.0
top_k = 5
max_new_tokens = 128

model.eval()
with torch.no_grad():  # generate() is already decorated with no_grad; kept for clarity
    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    print(y)


torch.Size([1, 2])
tensor([[    0,     3, 12004,  7833, 44420, 30150, 12485, 29746, 30229, 35873,
         35500, 39430, 46640, 34643,   912,  8409, 36031, 15794, 48022, 22688,
         35370, 47789,  7774, 16445, 26783,  7359, 17979, 25597, 46483, 20834,
         16022, 39541, 38831, 10983,   667, 13785, 40031, 36093, 39538, 45678,
         44696,  1799, 18824, 46107, 37801,  6473, 35437, 11110, 32789, 25602,
         44209, 24678,  2848,  9209, 37609, 32691, 15366, 22268, 17047, 45684,
         15540,   441, 32331, 14722, 20978, 10061, 19869, 30836, 27996, 19678,
         13030, 48259, 48572, 13586, 48731, 38471,  4235, 29605, 32585, 43913,
         22796, 28636, 42118, 50144, 29371, 27211, 44080, 14573, 41490, 13433,
         35850, 24283,  7379, 36939, 23327, 49748, 45308,  8151, 14326,  9823,
         21694,   781, 32968, 27074, 45877,  6526, 16685, 25500,  5918, 33960,
         35178, 43360, 46966,  4558, 29772,   538, 44094, 30761, 38943, 47454,
          3326, 17947, 18954, 413