## Transformer Model Size and Dataset Size Scaling Laws

Reproducing Approach 3 from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf) to determine the optimal model size and dataset size for a decoder-only transformer model. Based on [scaling_laws.ipynb](https://github.com/karpathy/nanoGPT/blob/master/scaling_laws.ipynb) by Andrej Karpathy.

In [None]:
# Analysis and reproduction of the results from Chinchilla - https://arxiv.org/pdf/2203.15556
# Can be used to determine compute-optimal models

import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

# Render all figure text through LaTeX, using a serif (Computer Modern) face
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'font.size': 16,
    'text.usetex': True,
    # amsmath/amsfonts are loaded into the LaTeX preamble for the figure labels
    'text.latex.preamble': r'\usepackage{amsmath}\usepackage{amsfonts}',
})

In [None]:
def abbr_size(i: float) -> str:
    """Abbreviate a number for logging, showing the three most significant digits.

    The value is scaled by powers of 1000 and suffixed with '', 'K', 'M', 'B', 'T'
    or 'P'. Values beyond the 'P' range are still reported in 'P' instead of
    falling off the end of the unit list.

    Args:
        i: The number to abbreviate (int, float or NumPy scalar).

    Returns:
        The abbreviated number followed by its unit suffix, e.g. '124M'.
    """
    units = ['', 'K', 'M', 'B', 'T', 'P']
    value = float(i)
    for unit in units[:-1]:
        # 999.5 and above rounds to 1000 at three significant digits, so such
        # values belong to the next unit (999_999 -> '1M', not '1e+03K').
        if abs(value) < 999.5:
            return f'{value:.3g}{unit}'
        value /= 1000
    if abs(value) < 999.5:
        return f'{value:.3g}{units[-1]}'
    # Past the largest unit: print the full integer part rather than an exponent
    return f'{value:.0f}{units[-1]}'

### Calculating Expected Training Statistics

In [None]:
# GPT model parameters as in GPT-alpha

def gpt_alpha_params(T, V, C, H, N):
    """Break down the parameter count of the bias-free GPT-alpha model by component.

    Args:
        T: Block size (sequence length).
        V: Vocabulary size.
        C: Embedding width.
        H: Number of attention heads (does not affect the parameter count).
        N: Number of transformer layers.

    Returns:
        An OrderedDict mapping each component name to its parameter count,
        ending with the overall 'total'.
    """
    hidden = 4 * C  # Width of the MLP hidden layer

    # Embedding tables
    pos_embd = C * T
    tok_embd = C * V
    embd = pos_embd + tok_embd

    # One attention sub-layer: layer norm gain, fused QKV and output projection
    attn_ln = C
    attn_qkv = 3 * (C**2)
    attn_proj = C**2
    attn = attn_ln + attn_qkv + attn_proj

    # One MLP sub-layer: layer norm gain, expansion and contraction matrices
    mlp_ln = C
    mlp_up = C * hidden
    mlp_down = hidden * C
    mlp = mlp_ln + mlp_up + mlp_down

    # Stack of N identical blocks followed by the final layer norm
    block = attn + mlp
    stack = N * block
    final_ln = C
    head = 0  # The output layer shares its weights with the token embedding

    return OrderedDict([
        ('emebedding/position', pos_embd),
        ('embedding/token', tok_embd),
        ('embedding', embd),
        ('attention/ln', attn_ln),
        ('attention/qkv', attn_qkv),
        ('attention/proj', attn_proj),
        ('attention', attn),
        ('mlp/ln', mlp_ln),
        ('mlp/ffw', mlp_up),
        ('mlp/proj', mlp_down),
        ('mlp', mlp),
        ('block', block),
        ('transformer', stack),
        ('ln_f', final_ln),
        ('dense', head),
        ('total', embd + stack + final_ln + head),
    ])

def gpt_alpha_flops(T, V, C, H, N):
    """Break down the matrix-multiplication FLOPs of GPT-alpha for one sequence of T tokens.

    Layer norm, softmax and other element-wise operations are treated as
    negligible and left out.

    Args:
        T: Block size (sequence length).
        V: Vocabulary size.
        C: Embedding width.
        H: Number of attention heads.
        N: Number of transformer layers.

    Returns:
        An OrderedDict mapping each component name to its FLOPs count, ending
        with 'forward', 'backward' and 'total'.
    """
    # An (M x N) @ (N x P) matrix multiplication costs 2*M*N*P FLOPs
    head_dim = C // H
    hidden = 4 * C

    # Attention sub-layer
    qkv = 2 * T * (C * 3 * C)  # Projection of Q, K, V
    scores = 2 * T * T * C  # Q @ K
    reduce_ = 2 * H * (T * T * head_dim)  # V @ scores
    proj = 2 * T * (C * C)  # Final projection
    attention = qkv + scores + reduce_ + proj

    # MLP sub-layer
    ffw1 = 2 * T * (C * hidden)
    ffw2 = 2 * T * (hidden * C)
    mlp = ffw1 + ffw2

    # Full stack plus the output (logits) layer
    block = attention + mlp
    transformer = N * block
    dense = 2 * T * (C * V)

    forward = transformer + dense
    backward = 2 * forward  # Backward pass is estimated as twice the forward pass

    return OrderedDict([
        ('attention/qkv', qkv),
        ('attention/scores', scores),
        ('attention/reduce', reduce_),
        ('attention/proj', proj),
        ('attention', attention),
        ('mlp/ffw1', ffw1),
        ('mlp/ffw2', ffw2),
        ('mlp', mlp),
        ('block', block),
        ('transformer', transformer),
        ('dense', dense),
        ('forward', forward),
        ('backward', backward),
        ('total', forward + backward),
    ])

# block_size, vocab_size, n_embd, n_head, n_layer
gpt_alpha = dict(T=1024, V=50257, C=768, H=12, N=12)

# Per-component parameter counts, each also shown as a share of the total
params = gpt_alpha_params(**gpt_alpha)
# Header widths and alignment mirror the row format below (20 + 15 + 10 = 45 columns)
print(f'{"Name":20s}{"Parameters":>15s}{"Ratio (%)":>10s}')
print("-" * 45)
for k, v in params.items():
    print(f'{k:20s}{v:15,d}{(v/params["total"] * 100):10.2f}')

In [None]:
# Per-component FLOPs for one sequence of the GPT-alpha configuration
flops = gpt_alpha_flops(**gpt_alpha)
# Header widths and alignment mirror the row format below (20 + 15 + 10 = 45 columns)
print(f'{"Name":20s}{"FLOPs":>15s}{"Ratio (%)":>10s}')
print("-" * 45)
# Ratios are relative to the forward pass, so 'backward' and 'total' exceed 100%
for k, v in flops.items():
    print(f'{k:20s}{v:15,d}{(v/flops["forward"] * 100):10.2f}')

In [None]:
# Compare this to the number of FLOPS as estimated by PaLM - https://arxiv.org/pdf/2204.02311

def estimate_flops(T, V, C, H, N):
    """Estimate the forward + backward FLOPs for one sequence of T tokens, PaLM style.

    Per token this is 6 FLOPs for every parameter except the position
    embeddings, plus an attention term of 12 * N * H * (C // H) * T.

    Args:
        T: Block size (sequence length).
        V: Vocabulary size.
        C: Embedding width.
        H: Number of attention heads.
        N: Number of transformer layers.

    Returns:
        The estimated FLOPs for a full sequence (per-token estimate times T).
    """
    counts = gpt_alpha_params(T, V, C, H, N)
    # Token embeddings stay in the count because they are shared with the dense layer;
    # only the position embeddings are removed.
    weight_params = counts['total'] - counts['emebedding/position']
    head_dim = C // H
    weight_term = 6 * weight_params
    attention_term = 12 * N * H * head_dim * T
    per_token = weight_term + attention_term
    return per_token * T

# PaLM-style FLOPs estimate for one sequence of the GPT-alpha configuration
flops_estimate = estimate_flops(**gpt_alpha)

# Printed with thousands separators
print(f'Estimated FLOPs: {flops_estimate:,d}')

In [None]:
# Achieved throughput: FLOPs for one sequence (forward + backward) times the number of
# sequences processed per second.
batch_size = 32 * 16 # Using gradient accumulation to simulate a batch size of 512
dt = 3.696 # Seconds per batch for a single A100 SXM4 40GB GPU (using compiled PyTorch)
flops_achieved = flops['total'] * (batch_size / dt)

# A100 peak is 312 TFLOPS for bfloat16 running on tensor cores
# https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet.pdf
a100_flops = 312e12

# Model flops utilisation (MFU): achieved throughput as a fraction of the A100 peak
mfu = flops_achieved / a100_flops
# Reported as a percentage of the FLOPs promised by the A100
print(f'Fraction of A100 used: {mfu * 100:.2f}%')

In [None]:
n_params = params['total'] # Total number of parameters
D = 40e9 # 40B tokens to process
# Sustained throughput (FLOPs per second) of an 8x A100 node at the measured MFU.
# It gets its own name so the `flops` breakdown dict from gpt_alpha_flops() is not
# overwritten, which would break re-running the cells that index into that dict.
node_flops_per_sec = a100_flops * 8 * mfu
total_flops = 6 * n_params * D # 6 FLOPs per parameter per token approximation
train_time = total_flops / node_flops_per_sec # Seconds
print(f'Total expected training time: {train_time / 3600:.2f} hours')

In [None]:
# Calculate the size of each checkpoint, which is the total number of parameters and buffers
params_bytes = params['total'] * 4 # 4 bytes per fp32 parameter
params_bytes += 2 * params_bytes # AdamW has two buffers per parameter for stats
# Convert bytes to GB directly (three significant digits), so the printed unit is
# correct whatever the magnitude of the byte count.
print(f'Estimated checkpoint size: {params_bytes / 1e9:.3g}GB')

# Calculate the ratio of GPU memory used just for the parameters and buffers
gpu_mem = 40e9 # 40 GB A100 GPU
print(f'GPU memory usage: {params_bytes / gpu_mem * 100:.2f}%')

This is not much at all. Most of the memory is used by activations and intermediate calculations for forward and backward passes.

### Chinchilla Helper Functions

In [None]:
def gpt_params(T, V, C, H, N):
    """Count the parameters of a GPT model with biases, excluding the embedding tables.

    The vocabulary-sized matrix is counted once, as the final fully connected
    layer. The position embeddings are not counted, so the result does not
    depend on T.

    Args:
        T: Block size (sequence length); accepted for a uniform signature, unused.
        V: Vocabulary size.
        C: Embedding width.
        H: Number of attention heads; accepted for a uniform signature, unused.
        N: Number of transformer layers.

    Returns:
        Number of parameters in the transformer blocks, the final layer norm
        and the final fully connected layer.
    """
    # Attention blocks (layer norm has gain and bias; linear layers have biases)
    att_ln = 2 * C
    att_qkv = 3 * (C**2 + C)
    att_proj = C**2 + C
    # MLP blocks
    ffw_size = 4 * C
    mlp_ln = 2 * C
    mlp_ffw = C * ffw_size + ffw_size
    mlp_proj = ffw_size * C + C
    # Transformer blocks
    attention = att_ln + att_qkv + att_proj
    mlp = mlp_ln + mlp_ffw + mlp_proj
    block = attention + mlp
    transformer = N * block
    # Final layer norm and fully connected layer
    ln_f = 2 * C
    dense = C * V
    # Total number of parameters (excluding embeddings)
    return transformer + ln_f + dense

def chinchilla_params(T, V, C, H, N, ffw_size):
    """Count the parameters of a Chinchilla model, excluding the token embedding table.

    Chinchilla uses relative position embeddings, so each attention block
    carries extra relative-position parameters and there is no position table.

    Args:
        T: Block size (sequence length); not used in the count.
        V: Vocabulary size.
        C: Embedding width.
        H: Number of attention heads; not used in the count.
        N: Number of transformer layers.
        ffw_size: Width of the MLP hidden layer.

    Returns:
        Number of parameters in the transformer blocks, the final layer norm
        and the final fully connected layer.
    """
    layer_norm = 2 * C  # Gain and bias
    per_block = sum((
        layer_norm,                      # Attention layer norm
        3 * (C**2 + C),                  # Q, K, V projections with biases
        C**2 + C,                        # Attention output projection with bias
        C**2 + 2 * C,                    # Relative keys, content bias, relative bias
        layer_norm,                      # MLP layer norm
        C * ffw_size + ffw_size,         # MLP expansion with bias
        ffw_size * C + C,                # MLP contraction with bias
    ))
    # N blocks, then the final layer norm and the fully connected output layer
    return N * per_block + layer_norm + C * V

# gpt_params() differs from gpt_alpha_params() in two ways: it adds the bias terms of the
# attention, feed-forward and layer normalisation layers, and it leaves out the position
# embeddings that gpt_alpha_params() counts. The vocabulary-sized matrix is counted once in
# both (as `dense` here), since the token embeddings are shared with the final fully
# connected layer.
print(f'{gpt_params(**gpt_alpha):,d}')

In [None]:
def chinchilla_flops(T, V, C, H, N, ffw_size):
    """Calculate the total training FLOPs for one sequence, see Chinchilla paper Appendix F.

    Args:
        T: Block size (sequence length).
        V: Vocabulary size.
        C: Embedding width (model dimension).
        H: Number of attention heads.
        N: Number of transformer layers.
        ffw_size: Width of the MLP hidden layer.

    Returns:
        Forward plus backward FLOPs, with the backward pass taken as twice the
        forward pass.
    """
    # The per-head key/query/value width is derived from the model width C,
    # not from the sequence length T.
    key_size = C // H
    embd = 2 * T * V * C
    # Attention blocks
    att_qkv = 2 * 3 * T * C * (key_size * H) # Q,K,V @ W
    att_logits = 2 * T * T * (key_size * H) # K @ Q
    att_softmax = 3 * H * T * T # 3 * is for subtract (max), exp, divide
    att_value = 2 * T * T * (key_size * H) # softmax @ V
    att_linear = 2 * T * (key_size * H) * C # Final linear layer
    attention = att_qkv + att_logits + att_softmax + att_value + att_linear
    # MLP blocks
    mlp = 2 * T * (C * ffw_size + C * ffw_size)
    # Logits
    logits = 2 * T * C * V
    # Forward pass FLOPs, unlike paper, embeddings and logits are included
    forward_flops = embd + N * (attention + mlp) + logits
    backward_flops = 2 * forward_flops # as in https://arxiv.org/abs/2001.08361
    return forward_flops + backward_flops


In [None]:
# Reproduce the results from Table A4 from Chinchilla paper
# (N, C, ffw_size, H)
models = [
    [10, 640, 2560, 10],
    [20, 1024, 4096, 16],
    [24, 1280, 5120, 10],
    [26, 1792, 7168, 14],
    [28, 2048, 8192, 16],
    [40, 3584, 14336, 28]
]

for N, C, ffw_size, H in models:
    # Chinchilla models use a vocabulary size of 32k and a block size of 2048
    config = dict(T=2048, V=32000, C=C, H=H, N=N, ffw_size=ffw_size)
    # Loop-specific names keep the notebook-level `flops` and `n_params` from being
    # overwritten, which would break re-running the cells that read them.
    model_flops = chinchilla_flops(**config)
    model_params = chinchilla_params(**config)
    print(f'{model_params:.3g} params, {model_flops:.3g} FLOPs')

### Approach 3: Fitting a parametric loss function - $L(|\theta|,D)$ 

Approach 3 of Chinchilla fits a parametric loss function $L(|\theta|,D)$ that approximates the final loss given the model size and dataset size.

In [None]:
def loss(n_params, D, *, A=406.4, B=410.7, E=1.69, alpha=0.34, beta=0.28):
    """Approximate the final loss for a model size and dataset size, per the Chinchilla paper.

    Evaluates A / n_params**alpha + B / D**beta + E. The keyword defaults are
    the constants used for the Chinchilla fit; pass other values to evaluate a
    different fit.

    Args:
        n_params: Number of model parameters (scalar or NumPy array).
        D: Dataset size in tokens (scalar or NumPy array).
        A: Scale of the model-size term.
        B: Scale of the dataset-size term.
        E: Entropy of natural language, the limit of an infinite model on infinite data.
        alpha: Exponent of the model-size term.
        beta: Exponent of the dataset-size term.

    Returns:
        The approximate loss, with the broadcast shape of the inputs.
    """
    return A / (n_params ** alpha) + B / (D ** beta) + E

# Model sizes from 10M to 100B, log-spaced with 16 points per decade
ns = 10 ** np.arange(7, 11, step=2**-4)
# Dataset sizes from 1B to 1T, log-spaced with 16 points per decade
ds = 10 ** np.arange(9, 12, step=2**-4)

# 2D contour plot of loss as a function of model size and dataset size
plt.figure(figsize=(12, 8))
# Rows index the model size and columns the dataset size; values are log10(loss)
loss2d = np.log10(np.array([[loss(n, d) for d in ds] for n in ns]))
# extent maps the grid to log10(D) on the x-axis and log10(|theta|) on the y-axis
im1 = plt.imshow(loss2d, extent=[9, 12, 7, 11], origin='lower', alpha=0.5, cmap='plasma', aspect='auto')
cs1 = plt.contour(loss2d, levels=30, extent=[9, 12, 7, 11], origin='lower', colors='white', linewidths=1, alpha=0.8)
# GPT-alpha model (marker position is hard-coded)
plt.scatter(np.log10(40e9), np.log10(124e6), color='black', marker='x', s=100, label=f'GPT-$\\alpha$ ($124M$, $40$B)')
# Optimal model (marker position is hard-coded)
plt.scatter(np.log10(14.4e9), np.log10(365e6), color='black', marker='^', s=100, label=f'Optimal Model ($365M$, $14.4$B)')
plt.clabel(cs1, inline=1, fontsize=14, fmt='%1.1f')
plt.xlabel(r'$\log_{10}(D)$')
plt.ylabel(r'$\log_{10}(|\theta|)$')
plt.legend(loc='upper right')
cbar1 = plt.colorbar(im1, label=r'$\log_{10}(\text{loss})$', pad=0.02, aspect=30)
plt.tight_layout()
# Saved relative to the working directory before the figure is shown
plt.savefig(f'../cache/loss_contour.pdf', bbox_inches='tight')
plt.show()

# Contour plot of compute as a function of model size and dataset size
plt.figure(figsize=(12, 8))
# Compute is approximated as 6 FLOPs per parameter per token; values are log10(FLOPs)
compute2d = np.log10(np.array([[6*n*d for d in ds] for n in ns]))
im2 = plt.imshow(compute2d, extent=[9, 12, 7, 11], origin='lower', alpha=0.5, cmap='plasma', aspect='auto')
cs2 = plt.contour(compute2d, levels=30, extent=[9, 12, 7, 11], origin='lower', colors='white', linewidths=1, alpha=0.8)
# GPT-alpha model (marker position is hard-coded)
plt.scatter(np.log10(40e9), np.log10(124e6), color='black', marker='x', s=100, label=f'GPT-$\\alpha$ ($124M$, $40$B)')
# Optimal model (marker position is hard-coded)
plt.scatter(np.log10(14.4e9), np.log10(365e6), color='black', marker='^', s=100, label=f'Optimal Model ($365M$, $14.4$B)')
plt.clabel(cs2, inline=1, fontsize=14, fmt='%1.1f')
plt.xlabel(r'$\log_{10}(D)$')
plt.ylabel(r'$\log_{10}(|\theta|)$')
plt.legend(loc='upper right')
cbar2 = plt.colorbar(im2, label=r'$\log_{10}(\text{flops})$', pad=0.02, aspect=30)
plt.tight_layout()
# Saved relative to the working directory before the figure is shown
plt.savefig(f'../cache/compute_contour.pdf', bbox_inches='tight')
plt.show()

In [None]:
# FLOPs from 10^18 to 10^23, log-spaced with 16 points per decade. The stop value must
# match the x-range passed as `extent` below, otherwise the grid is stretched along the
# x-axis and the markers no longer sit over the data they refer to.
fs = 10 ** np.arange(18, 23, step=2**-4)

# Compute corresponding D values for each combination of FLOPs and n_params, using
# FLOPs = 6 * n_params * D. Rows index the model size, columns the FLOPs budget.
ds_flops = fs[None, :] / (6 * ns[:, None])

# 2D contour plot of loss as a function of model size and FLOPs budget
plt.figure(figsize=(12, 8))
# loss() broadcasts over arrays, giving log10(loss) for every (model size, budget) pair
loss2d_flops = np.log10(loss(ns[:, None], ds_flops))
im1 = plt.imshow(loss2d_flops, extent=[18, 23, 7, 11], origin='lower', alpha=0.5, cmap='plasma', aspect='auto')
cs1 = plt.contour(loss2d_flops, levels=30, extent=[18, 23, 7, 11], origin='lower', colors='white', linewidths=1, alpha=0.8)
# GPT-alpha model (marker position is hard-coded)
plt.scatter(np.log10(3.16e19), np.log10(124e6), color='black', marker='x', s=100, label=f'GPT-$\\alpha$ ($124M$)')
# Optimal model (marker position is hard-coded)
plt.scatter(np.log10(3.16e19), np.log10(365e6), color='black', marker='^', s=100, label=f'Optimal Model ($365M$)')
plt.clabel(cs1, inline=1, fontsize=14, fmt='%1.1f')
plt.xlabel(r'$\log_{10}(\text{FLOPs})$')
plt.ylabel(r'$\log_{10}(|\theta|)$')
plt.legend(loc='upper right')
cbar1 = plt.colorbar(im1, label=r'$\log_{10}(\text{loss})$', pad=0.02, aspect=30)
plt.tight_layout()
plt.savefig(f'../cache/loss_contour_flops.pdf', bbox_inches='tight')
plt.show()

Given any $(|\theta|, D)$, both the loss and the total FLOPs can be estimated. Therefore, for a specific FLOPs budget $C$, it is possible to find the pair $$(|\theta|^*, D^*) = \text{argmin}_{\text{FLOPs}(|\theta|, D) = C} L(|\theta|, D)$$ that minimises the loss, i.e. the optimal model size and dataset size for that compute budget.

In [None]:
# 8x A100 SXM4 40GB GPUs produce a theoretical maximum of 2.50e15 FLOPs per second using BF16
# The MFU is 0.3885, which is the fraction of the total FLOPs used by the model,
# giving an expected FLOPs budget of 9.70e14 FLOPs per second. Over a training time of 12 hours,
# this produces 9.70e14 * 12 * 60 * 60 = 4.19e19 total FLOPs. Note that this has been
# reduced to 3.16e19 FLOPs to account for overhead in training time with validation.

C = 3.16e19 # Example compute budget (total FLOPs)
# Range of model sizes to consider from 10M to 100B
ns = 10 ** np.arange(7, 11, step=2**-4)
# Compute D for each model size using C = 6 * n_params * D
ds = C / (6 * ns)
# Calculate the loss for each model size and dataset size (element-wise over the arrays)
losses = loss(ns, ds)
# Find the optimal model size: the index with the lowest approximate loss on this grid
best = np.argmin(losses)
print(f'Optimal model size: {abbr_size(ns[best])} params, {abbr_size(ds[best])} tokens')