## Compute-Optimal Models

Reproducing Approach 3 from [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf) to determine the optimal model size and dataset size for a decoder-only transformer model.

In [None]:
# Analysis and reproduction of the results from Chinchilla - https://arxiv.org/pdf/2203.15556.pdf
# Can be used to determine compute-optimal models

import numpy as np
import matplotlib.pyplot as plt

# Render all figure text with LaTeX (text.usetex requires a working LaTeX
# installation). amsmath is loaded because the plot labels use \text{...}.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'font.size': 12,
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}\usepackage{amsfonts}',
})

In [None]:
def abbr_size(i: float) -> str:
    """Abbreviate a number for logging using K/M/B/T/P suffixes and three significant digits.

    Args:
        i: the value to abbreviate (int or float, e.g. a parameter or token count).

    Returns:
        A short string such as '1.23K', '70B' or '1M'. Values beyond the
        'P' range are still expressed in 'P' rather than returning None.
    """
    units = ['', 'K', 'M', 'B', 'T', 'P']
    for unit in units[:-1]:
        # Test the *rounded* value so that e.g. 999_600 becomes '1M', not '1e+03K'.
        if abs(float(f'{i:.3g}')) < 1000:
            return f'{i:.3g}{unit}'
        i /= 1000
    # Largest unit: stop scaling so the function always returns a string.
    return f'{i:.3g}{units[-1]}'

In [None]:
def gpt_params(T, V, C, H, N):
    """Count the parameters of a GPT-2 style decoder, excluding embeddings.

    Token and position embeddings are not counted; the final vocabulary
    projection is. T and H do not influence the count; they appear to be
    accepted so that a full config dict can be splatted in.

    Args:
        T: block size (context length); unused.
        V: vocabulary size.
        C: embedding width (n_embd).
        H: number of attention heads; unused.
        N: number of transformer blocks.

    Returns:
        Total parameter count (excluding embeddings).
    """
    hidden = 4 * C  # MLP hidden width is fixed at 4x the embedding width
    per_block = sum((
        3 * C * C + 3 * C,    # fused Q/K/V projection: weights + biases
        C * C + C,            # attention output projection
        C * hidden + hidden,  # MLP up-projection: weights + biases
        hidden * C + C,       # MLP down-projection: weights + biases
        2 * (2 * C),          # two LayerNorms per block, each with gain and bias
    ))
    final_norm = 2 * C
    lm_head = C * V  # vocabulary projection, counted without a bias
    return N * per_block + final_norm + lm_head

def chinchilla_params(T, V, C, H, N, ffw_size):
    """Count the parameters of a Chinchilla-style model, excluding the token embedding.

    Compared with gpt_params this adds the relative-position parameters of each
    attention layer and takes the feed-forward width explicitly.

    Args:
        T: sequence length; unused.
        V: vocabulary size.
        C: model width (d_model).
        H: number of attention heads; unused.
        N: number of transformer layers.
        ffw_size: hidden width of the feed-forward block.

    Returns:
        Total parameter count (excluding embeddings).
    """
    per_block_terms = {
        'qkv': 3 * C**2 + 3 * C,              # Q/K/V weights and biases
        'relative_position': C**2 + 2 * C,    # relative keys, content bias, relative bias
        'attn_out': C**2 + C,                 # attention output projection
        'ffw_in': C * ffw_size + ffw_size,    # feed-forward weights and biases
        'ffw_out': ffw_size * C + C,          # feed-forward output projection
        'layer_norms': 2 * 2 * C,             # two layer norms per block
    }
    per_block = sum(per_block_terms.values())
    head = 2 * C + C * V  # final layer norm + vocabulary projection (no bias)
    return N * per_block + head

# Example GPT config; these values match the GPT-2 small (124M) hyper-parameters.
# Keys: T = block_size, V = vocab_size, C = n_embd, H = n_head, N = n_layer
gpt_alpha = {'T': 1024, 'V': 50257, 'C': 768, 'H': 12, 'N': 12}
gpt_params(**gpt_alpha)

In [None]:
def chinchilla_flops(T, V, C, H, N, ffw_size):
    """Calculate training FLOPs for one sequence of T tokens, see Chinchilla paper Appendix F.

    Args:
        T: sequence length (tokens per sequence).
        V: vocabulary size; only enters the embedding/logit terms, which are
           excluded below, so it does not affect the result.
        C: model width (d_model).
        H: number of attention heads.
        N: number of transformer layers.
        ffw_size: hidden width of the feed-forward block.

    Returns:
        Forward + backward FLOPs for one sequence, excluding embeddings and logits.
    """
    # Per-head key/query/value size is derived from the model width (not the
    # sequence length), so that key_size * H == C whenever H divides C.
    key_size = C // H
    kv_width = key_size * H
    # embeddings = 2 * T * V * C
    # Q,K,V projections
    attention = 2 * 3 * T * C * kv_width
    # K @ Q logits
    attlogits = 2 * T * T * kv_width
    # Softmax
    attsoftmax = 3 * H * T * T  # 3 * is for subtract (max), exp, divide (?)
    # Softmax @ V reductions
    attvalue = 2 * T * T * kv_width
    # Final linear
    attlinear = 2 * T * kv_width * C
    att = attention + attlogits + attsoftmax + attvalue + attlinear
    # Feed-forward (up- and down-projection)
    dense = 2 * T * (C * ffw_size + C * ffw_size)
    # logits = 2 * T * C * V
    # Forward pass FLOPs; the paper's counts do not include embeddings and logits
    forward_flops = N * (att + dense)
    backward_flops = 2 * forward_flops  # as in https://arxiv.org/abs/2001.08361
    return forward_flops + backward_flops


In [None]:
# Reproduce Table A4 from the Chinchilla paper appendix: compare the exact
# per-sequence FLOPs count (Appendix F) with the 6 * n_params * T approximation.
# Each row is (N, C, ffw_size, H) = (n_layers, d_model, ffw_size, n_heads)
models = [
    [10, 640, 2560, 10],
    [20, 1024, 4096, 16],
    [24, 1280, 5120, 10],
    [26, 1792, 7168, 14],
    [28, 2048, 8192, 16],
    [40, 3584, 14336, 28]
]

for N, C, ffw_size, H in models:
    # Chinchilla models use a vocabulary size of 32k and a sequence length of 2048
    config = dict(T=2048, V=32000, C=C, H=H, N=N, ffw_size=ffw_size)
    flops = chinchilla_flops(**config)
    n_params = chinchilla_params(**config)
    # Rule of thumb: ~6 FLOPs per parameter per token (forward + backward)
    approx_flops = 6 * n_params * config['T']
    print(f'{n_params:.3g} params, {flops:.3g} FLOPs, '
          f'{approx_flops:.3g} FLOPs (6ND), ratio {flops / approx_flops:.2f}')

## Approach 3: Fitting a parametric loss function - $L(n_{\text{params}},D)$ 

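Approach 3 fits a parametric function $L(n_{\text{params}},D)$ that approximates the final training loss given the model size and dataset size. Here the coefficients reported in the paper are reused rather than refitted.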

In [None]:
def loss(n_params, D, *, E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28):
    """Approximate final loss given n_params and D dataset size (in tokens), per Chinchilla paper.

    Evaluates the parametric fit  L(N, D) = E + A / N**alpha + B / D**beta.
    Works element-wise on NumPy arrays (inputs are broadcast together).

    Args:
        n_params: model size in parameters (scalar or array).
        D: dataset size in tokens (scalar or array).
        E: irreducible term - the loss of an infinite model trained on infinite data.
        A, alpha: coefficient and exponent of the model-size term.
        B, beta: coefficient and exponent of the dataset-size term.
        The defaults are the fitted values from the Chinchilla paper.

    Returns:
        The predicted loss (scalar or array with the broadcast shape).
    """
    return A / (n_params ** alpha) + B / (D ** beta) + E

from pathlib import Path

# Log10 grids with a step of 1/16 decade; both endpoints are included.
log_ns = np.linspace(7, 11, 65)  # model sizes from 10M to 100B
log_ds = np.linspace(9, 12, 49)  # dataset sizes from 1B to 1T tokens
ns = 10 ** log_ns
ds = 10 ** log_ds

# imshow's extent describes pixel *edges*: pad by half a grid step so every pixel
# is centred on its grid point. contour is given the grid coordinates directly.
half_step = 2**-5
grid_extent = [log_ds[0] - half_step, log_ds[-1] + half_step,
               log_ns[0] - half_step, log_ns[-1] + half_step]

cache_dir = Path('../cache')
cache_dir.mkdir(parents=True, exist_ok=True)


def _plot_log_field(field, cbar_label, filename):
    """Draw `field` as a heat map with labelled contours over (log10 D, log10 n_params) and save it.

    `field` has shape (len(ns), len(ds)): rows follow model size, columns dataset size.
    Returns the image, the contour set and the colour bar.
    """
    plt.figure(figsize=(12, 8))
    im = plt.imshow(field, extent=grid_extent, origin='lower', alpha=0.5, cmap='plasma', aspect='equal')
    cs = plt.contour(log_ds, log_ns, field, levels=30, colors='white', linewidths=1, alpha=0.8)
    plt.clabel(cs, inline=1, fontsize=8, fmt='%1.1f')
    plt.xlabel(r'$\log_{10}(D)$')
    plt.ylabel(r'$\log_{10}(n_{\text{params}})$')
    cbar = plt.colorbar(im, label=cbar_label, pad=0.02, aspect=30)
    plt.tight_layout()
    plt.savefig(cache_dir / filename, bbox_inches='tight')
    plt.show()
    return im, cs, cbar


# Rows index model size, columns index dataset size.
n_grid, d_grid = np.meshgrid(ns, ds, indexing='ij')

# 2D contour plot of the loss as a function of model size and dataset size
loss2d = np.log10(loss(n_grid, d_grid))
im1, cs1, cbar1 = _plot_log_field(loss2d, r'$\log_{10}(\text{loss})$', 'loss_contour.pdf')

# 2D contour plot of the compute as a function of model size and dataset size using FLOPs = 6 * n_params * D
compute2d = np.log10(6 * n_grid * d_grid)
im2, cs2, cbar2 = _plot_log_field(compute2d, r'$\log_{10}(\text{flops})$', 'compute_contour.pdf')

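Given any $(n_{\text{params}}, D)$, both the loss and the total training FLOPs can be estimated. For a fixed compute budget $C$, the compute-optimal choice is the pair that minimises the predicted loss subject to that budget: $$(n_{\text{params}}^*, D^*) = \operatorname*{arg\,min}_{n_{\text{params}},\, D \;:\; \text{FLOPs}(n_{\text{params}}, D) = C} L(n_{\text{params}}, D).$$ With the approximation $\text{FLOPs} \approx 6\, n_{\text{params}} D$, the constraint fixes $D = C / (6\, n_{\text{params}})$. The search therefore reduces to a one-dimensional minimisation over model size, which determines both the optimal model size and the optimal dataset size.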

In [None]:
from pathlib import Path

# Example compute budget: 8x A100 SXM4 40GB GPUs at their peak dense TF32
# throughput (156 TFLOP/s each, ~1.25e15 FLOP/s in total) for 12 hours:
# 8 * 156e12 * 12 * 60 * 60 ~= 5.39e19 FLOPs.
# Real training runs reach only a fraction of peak, so this is an upper bound.
C = 5.39e19
# Range of model sizes to consider: 10M to 100B (1/16-decade steps, endpoints included)
ns = 10 ** np.linspace(7, 11, 65)
# With a fixed budget, C = 6 * n_params * D determines D for each model size
ds = C / (6 * ns)
# Predicted loss for each (model size, dataset size) pair on the budget curve
losses = loss(ns, ds)
# Best grid point; the optimum is only resolved to the grid step (~15% in n_params)
best = np.argmin(losses)
print(f'Optimal model size: {abbr_size(ns[best])} params, {abbr_size(ds[best])} tokens')

# Plot loss against model size along the fixed-compute curve
plt.figure(figsize=(5, 5))
plt.semilogx(ns, losses, color='black', linewidth=1)
plt.semilogx(ns[best], losses[best], color='black', marker='o', markersize=4, label='Optimal point')
plt.xlabel(r'$n_{\text{params}}$', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.grid(True, which='both', ls='-', alpha=0.2)
plt.legend(fontsize=12)
# Annotate the optimal point with its abbreviated model size
plt.annotate(f'{abbr_size(ns[best])}',
             xy=(ns[best], losses[best]), xytext=(ns[best]*0.75, losses[best]*1.02),
             fontsize=12, ha='left', va='bottom')
plt.tight_layout()
cache_dir = Path('../cache')
cache_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(cache_dir / 'optimal_params.pdf', bbox_inches='tight')
plt.show()

For a fixed compute budget, models to the left of the optimal point are too small: they are trained on more tokens than is compute-optimal. Models to the right are too large and are trained on too few tokens (undertrained). The optimal point is the sweet spot that balances model size against the amount of training data for the given budget.