* add MFU calculation

In [1]:
from tqdm import tqdm
import tiktoken
import torch 
import datasets
import random

dataset = datasets.load_dataset("HuggingFaceFW/fineweb-edu", data_files=["sample/10BT/000_00000.parquet"], split="train")
dataset = dataset.train_test_split()

enc = tiktoken.get_encoding("gpt2")
assert enc.decode(enc.encode("hello world")) == "hello world"

def encode(string):
    return torch.tensor(enc.encode(string), dtype=torch.long)

def decode(tensor):
    return enc.decode(tensor.cpu().squeeze().numpy())

num_samples = 10_000
dataset_tok_train = torch.cat([encode(dataset["train"][i]["text"]) for i in tqdm(range(num_samples))])
dataset_tok_test = torch.cat([encode(dataset["test"][i]["text"]) for i in tqdm(range(num_samples))])

def get_sample(split, sample_length, batch_size):
    tokens = dataset_tok_train if split == "train" else dataset_tok_test
    idcs = torch.randint(len(tokens)-sample_length, (batch_size,))
    x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
    y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])
    return x, y

100%|█████████████████████████████████████████████| 10000/10000 [00:06<00:00, 1643.43it/s]
100%|█████████████████████████████████████████████| 10000/10000 [00:05<00:00, 1985.75it/s]


In [2]:
print(f"Train data: {len(dataset_tok_train):,} tokens")
print(f"Test data: {len(dataset_tok_test):,} tokens")

Train data: 10,276,684 tokens
Test data: 10,426,816 tokens


In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        # out = F.scaled_dot_product_attention(q, k, v)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [5]:
vocab_size = 50_272
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
batch_size = 64
block_size = 128 
device = "mps"
learning_rate = 3e-4

model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')


49.34896 M parameters


In [13]:

with elapsed_timer() as elapsed:
    time.sleep(1)
    print(elapsed())
    time.sleep(2)
    print(elapsed())
    time.sleep(3)

1.004555249994155
3.005257291981252


In [186]:
import torch
from torch.utils.flop_counter import FlopCounterMode
from contextlib import contextmanager
from timeit import default_timer

@contextmanager
def elapsed_timer():
    start = default_timer()
    elapser = lambda: default_timer() - start
    yield lambda: elapser()
    end = default_timer()
    elapser = lambda: end-start

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()

    
    with elapsed_timer() as elapsed:
        f()
    
    s_per_iter = elapsed()
    # iters_per_second = 1/s_per_iter
    print(f"{s_per_iter}s/iter, {total_flops}flops, {total_flops / s_per_iter / 1e12} TF/s")

def train_one_sample():
    xb, yb = get_sample('train', block_size, batch_size)
    logits, loss = model(xb.to(device), yb.to(device))
    loss.backward()

train_one_sample()
# get_flops_achieved(train_one_sample)

  x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
  y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])


In [148]:
model = model.half()
get_flops_achieved(train_one_sample)

  x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
  y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])


0.8047828749986365, 0.03602341718323606 TF/s


In [322]:
from torch import mps
import time 
n = 1024
a = torch.rand(n, n, device="mps")
num_timetrials = 100

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()

    s_total = 0
    for _ in range(num_timetrials):
        mps.synchronize()
        start = time.time()
        f()
        mps.synchronize()
        s_total += time.time() - start
        print(time.time() - start)
    
    s_per_iter = s_total / num_timetrials
    # iters_per_second = 1/s_per_iter
    print(f"{s_per_iter}s/iter, {total_flops}flops, {total_flops / s_per_iter / 1e12} TF/s")

def single_matmul():
    b = torch.matmul(a, a)

get_flops_achieved(single_matmul)

0.0035593509674072266
0.004293918609619141
0.0035490989685058594
0.0035119056701660156
0.004039287567138672
0.0033559799194335938
0.0034139156341552734
0.0032520294189453125
0.0016567707061767578
0.0016558170318603516
0.0016429424285888672
0.0016100406646728516
0.0015971660614013672
0.00144195556640625
0.0014500617980957031
0.0013990402221679688
0.0014350414276123047
0.00144195556640625
0.0014979839324951172
0.0015919208526611328
0.0016698837280273438
0.0016667842864990234
0.0016551017761230469
0.0015749931335449219
0.0014240741729736328
0.0014469623565673828
0.0014700889587402344
0.0014073848724365234
0.0014209747314453125
0.001592874526977539
0.0016260147094726562
0.0016450881958007812
0.0017039775848388672
0.0019021034240722656
0.001589059829711914
0.0014939308166503906
0.0014808177947998047
0.001483917236328125
0.0014750957489013672
0.0015442371368408203
0.0016360282897949219
0.0016732215881347656
0.0016379356384277344
0.001672983169555664
0.0015940666198730469
0.001455783843994140

In [354]:
from torch import mps
import time 
n = 1024
a = torch.rand(n, n, device="mps")
num_timetrials = 10

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()

    s_total = 0
    for _ in range(num_timetrials):
        mps.synchronize()
        start = time.time()
        f()
        mps.synchronize()
        s_total += time.time() - start
    
    s_per_iter = s_total / num_timetrials
    # iters_per_second = 1/s_per_iter
    print(f"{s_per_iter}s/iter, {total_flops}flops, {total_flops / s_per_iter / 1e12} TF/s")

xb, yb = get_sample('train', block_size, batch_size)
xb = xb.to(device)
yb = yb.to(device)

def train_one_sample():
    model(xb)
    # logits, loss = model(xb, yb)
    # loss.backward()

get_flops_achieved(train_one_sample)

  x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
  y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])


0.2116630792617798s/iter, 9663676416flops, 0.04565593796378724 TF/s


In [355]:
def train_one_sample():
    # model(xb)
    logits, loss = model(xb, yb)
    # loss.backward()

get_flops_achieved(train_one_sample)

0.25863964557647706s/iter, 9663676416flops, 0.037363476873241194 TF/s


In [356]:
def train_one_sample():
    logits, loss = model(xb, yb)
    loss.backward()

get_flops_achieved(train_one_sample)

0.8026062965393066s/iter, 28991029248flops, 0.03612110865938142 TF/s


In [359]:
from torch import mps
import time 
n = 1024
a = torch.rand(n, n, device="mps")
num_timetrials = 10

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()

    s_total = 0
    for _ in range(num_timetrials):
        # mps.synchronize()
        start = time.time()
        f()
        # mps.synchronize()
        s_total += time.time() - start
    
    s_per_iter = s_total / num_timetrials
    # iters_per_second = 1/s_per_iter
    print(f"{s_per_iter}s/iter, {total_flops}flops, {total_flops / s_per_iter / 1e12} TF/s")

xb, yb = get_sample('train', block_size, batch_size)
xb = xb.to(device)
yb = yb.to(device)

def train_one_sample():
    model(xb)
    # logits, loss = model(xb, yb)
    # loss.backward()

for _ in range(10):
    get_flops_achieved(train_one_sample)

  x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
  y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])


0.2135932445526123s/iter, 9663676416flops, 0.04524336168141143 TF/s
0.2108713150024414s/iter, 9663676416flops, 0.04582736355529493 TF/s
0.21042635440826415s/iter, 9663676416flops, 0.045924268579261546 TF/s
0.2622026205062866s/iter, 9663676416flops, 0.036855758334300485 TF/s
0.21017711162567138s/iter, 9663676416flops, 0.045978728802835364 TF/s
0.2583198308944702s/iter, 9663676416flops, 0.03740973498835961 TF/s
0.2100447654724121s/iter, 9663676416flops, 0.046007699331451594 TF/s
0.23678829669952392s/iter, 9663676416flops, 0.040811461337816324 TF/s
0.2257081985473633s/iter, 9663676416flops, 0.042814910925675326 TF/s
0.22636716365814208s/iter, 9663676416flops, 0.042690274772334065 TF/s


In [360]:
from torch import mps
import time 
n = 1024
a = torch.rand(n, n, device="mps")
num_timetrials = 10

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()

    s_total = 0
    for _ in range(num_timetrials):
        mps.synchronize()
        start = time.time()
        f()
        mps.synchronize()
        s_total += time.time() - start
    
    s_per_iter = s_total / num_timetrials
    # iters_per_second = 1/s_per_iter
    print(f"{s_per_iter}s/iter, {total_flops}flops, {total_flops / s_per_iter / 1e12} TF/s")

xb, yb = get_sample('train', block_size, batch_size)
xb = xb.to(device)
yb = yb.to(device)

def train_one_sample():
    model(xb)
    # logits, loss = model(xb, yb)
    # loss.backward()

for _ in range(10):
    get_flops_achieved(train_one_sample)

  x = torch.stack([torch.tensor(tokens[x:x+sample_length]) for x in idcs])
  y = torch.stack([torch.tensor(tokens[x+1:x+sample_length+1]) for x in idcs])


0.21155788898468017s/iter, 9663676416flops, 0.04567863889348882 TF/s
0.23959777355194092s/iter, 9663676416flops, 0.04033291408655378 TF/s
0.23498382568359374s/iter, 9663676416flops, 0.04112485779771142 TF/s
0.22499163150787355s/iter, 9663676416flops, 0.04295127045941627 TF/s
0.24395358562469482s/iter, 9663676416flops, 0.03961276646643299 TF/s
0.2253368616104126s/iter, 9663676416flops, 0.042885466438720696 TF/s
0.256885027885437s/iter, 9663676416flops, 0.037618682939784674 TF/s
0.24661767482757568s/iter, 9663676416flops, 0.03918484927228521 TF/s
0.24898619651794435s/iter, 9663676416flops, 0.03881209702042074 TF/s
0.33020141124725344s/iter, 9663676416flops, 0.02926600579778831 TF/s
