Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# Unit Scaling: A How-To Guide

In our paper [Unit Scaling: Out-of-the-Box Low-Precision Training](TODO), we describe a scheme for designing neural networks that have approximate unit variance after every operation in the forward and backward pass.

This can be seen as an alternative to (static) loss scaling, or its automatic variant, as used in Automatic Mixed Precision. Whereas both of those schemes rely on a single, global scaling factor for all the gradients, unit scaling is more fine-grained.

A unit-scaled model adds scaling factors (constant scalar multiplications) to each operation in the computational graph to achieve this unit variance property. The result is a model which naturally produces tensors in the middle of the dynamic range provided by floating-point formats. There's no extra loss-scale hyperparameter—it works out-of-the-box!

## Implementing unit scaling

Here we demonstrate how to go about unit scaling in practice. This involves re-implementing common neural network layers to add variance-preserving scaling factors.

As explained in the paper, we can sometimes justify using different scaling factors in the forward and backward pass. We introduce a special `scaled()` op which allows us to do just that.

Below we implement two models. The first is a simple implementation of a transformer-decoder. The design and hyperparameters are inspired by Andrej Karpathy's [popular NanoGPT implementation](https://github.com/karpathy/nanoGPT) (though some differ). It's also 🤗-compatible!

The second model is the same, but unit-scaled. Let's get stuck in...

## Building a unit-scaled NanoGPT

In [1]:
!pip install git+https://github.com/huggingface/optimum-graphcore.git
!pip install git+https://github.com/graphcore-research/poptorch-experimental-addons
!pip install altair


In [2]:
import math
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers.activations import GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel

### The MLP layer

#### Regular scaling

We'll start by setting up a basic config for a reasonably small transformer:

In [3]:
class NanoGPTConfig(PretrainedConfig):
    model_type = "nano-gpt"

    def __init__(
        self,
        hidden_size: int = 384,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 6,
        dropout: float = 0.1,
        vocab_size: int = 384,
        eos_token_id: int = 1,
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.vocab_size = vocab_size
        self.eos_token_id = eos_token_id
        super().__init__(**kwargs)

Along with a standard (pre-norm) MLP module:

In [4]:
class MLP(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size)
        self.linear_1 = nn.Linear(config.hidden_size, config.hidden_size * 4)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(config.hidden_size * 4, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.ln(hidden_states)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return self.dropout(hidden_states)

In [5]:
# TODO: hide
def init_weights(init_fn) -> None:
    def inner_fn(module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            fan_out, fan_in = module.weight.shape  # TODO: correct base impl
            module.weight.data.normal_(mean=0.0, std=init_fn(fan_in, fan_out))
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=1)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    return inner_fn

In [6]:
# TODO: hide
import pandas as pd

np.seterr(divide = 'ignore')

def instrument(module):
    stats = {}
    instrument_recursive(module, stats)
    return stats

def instrument_recursive(module, stats, name=''):
    children = list(module.named_children())
    if children:
        for c_name, c in children:
            _name = f'{name}.{c_name}' if name and name != 'blocks' else c_name
            instrument_recursive(c, stats, _name)
    else:
        instrument_terminal(module, stats, name)

def instrument_terminal(module, stats, name):
    module_stats = {}
    def require_input_grads(_module, input):
        for i in input:
            if isinstance(i, Tensor) and i.is_floating_point():
                i.requires_grad_()
    
    module.register_forward_pre_hook(require_input_grads)
    
    if name.split('.')[-1] == 'softmax':
        return
    
    def record_fwd_scale(_module, input, output):
        if isinstance(output, Tensor) and output.is_floating_point():
            module_stats['x'] = np.log2(output.std().item())
    
    module.register_forward_hook(record_fwd_scale)
    
    def record_bwd_scales(_module, grad_input, grad_output):
        grad_input = list(grad_input)
        for g in grad_input:
            if g is not None and isinstance(g, Tensor) \
                and g.is_floating_point() and len(grad_input) == 1:
                module_stats['grad_x'] = np.log2(g.std().item())
        
        for param_name, param in _module.named_parameters():
            if param_name == "weight":
                module_stats['w'] = np.log2(param.std().item())
                if param.grad is not None:
                    module_stats['grad_w'] = np.log2(param.grad.std().item())
    
    module.register_full_backward_hook(record_bwd_scales)
    
    stats[name] = module_stats

In [7]:
# TODO: hide
import altair as alt

def visualise(stats, subnormal=False):
    df = pd.DataFrame(stats)
    df = df.stack().to_frame('scale (log₂)').reset_index(names=['type', 'op'])
    plot(df, subnormal)

def plot(df, subnormal=False):
    is_x_or_grad_x = (df['type'] == 'x') | (df['type'] == 'grad_x')
    op_order = df[df['type'] == 'x']['op'].tolist()
    colors = ['#6C8EBF', '#FF8000', '#5D8944', '#ED3434']
    x_range = np.arange(-18 if subnormal else -14, 18+1 if subnormal else 16+1, 2)
    
    fp16_min = alt.Chart().mark_rule(strokeDash=(4, 4)).encode(x=alt.datum(-14))
    fp16_min_text = alt.Chart().mark_text(dy=-740).encode(text=alt.Text(value='Min FP16 (normal)'), x=alt.datum(-10))
    fp16_max = alt.Chart().mark_rule(strokeDash=(4, 4)).encode(x=alt.datum(16))
    fp16_max_text = alt.Chart().mark_text(dy=-740).encode(text=alt.Text(value='Max FP16'), x=alt.datum(13))
    
    x_chart = alt.Chart(df[is_x_or_grad_x]).mark_line().encode(
        x=alt.X(
            'scale (log₂):Q',
            axis=alt.Axis(orient='top', values=x_range),
            scale=alt.Scale(domain=[x_range[0], x_range[-1]]),
        ),
        y=alt.Y('op:O', title='', sort=op_order),
        color=alt.Color(
            'type',
            legend=alt.Legend(title='', labelFontSize=12, symbolSize=100),
            scale=alt.Scale(range=colors[:2]),
            sort='descending'
        ),
    )
    w_chart = alt.Chart(df[~is_x_or_grad_x]).mark_point(size=100).encode(
        x=alt.X(
            'scale (log₂):Q',
            axis=alt.Axis(orient='top', values=x_range),
            scale=alt.Scale(domain=[x_range[0], x_range[-1]])
        ),
        y=alt.Y('op:O', title='', sort=op_order),
        color=alt.Color(
            'type',
            legend=alt.Legend(title='', labelFontSize=12, symbolSize=100),
            scale=alt.Scale(range=colors[2:]),
            sort='descending'
        ),
        shape=alt.Shape(
            'type',
            scale=alt.Scale(range=['square', 'triangle-down']),
            sort='descending'
        ),
    )
    layers = [x_chart, w_chart]
    if subnormal:
        layers += [fp16_min, fp16_max, fp16_min_text, fp16_max_text] 
    combined_chart = alt.layer(
        *layers
    ).resolve_scale(
        color='independent', shape='independent'
    ).configure_axis(
        labelFontSize=12,
        titleFontSize=16
    ).properties(
        width=500
    )
    display(combined_chart)

Now we can analyse the scale (i.e. standard deviation) of each operation within the MLP. To do this, we provide an `instrument()` operation, which goes through and tracks the scale coming out of each operation in the forward and backward pass.

Our network's weights will also use the standard Glorot initialisation.

Let's feed in a unit normal tensor in both directions, and examine the result.

In [8]:
def analyse_mlp(mlp, config, batch_size=64, seq_len=16):
    stats = instrument(mlp)
    x = torch.normal(0.0, 1.0, size=(batch_size, seq_len, config.hidden_size))
    y = mlp(x)
    y.backward(torch.normal(0.0, 1.0, size=y.shape))
    visualise(stats)

In [9]:
def glorot_init(fan_in, fan_out):
    return ((fan_in + fan_out) / 2) ** -0.5

config = NanoGPTConfig()
mlp = MLP(config).apply(init_weights(glorot_init))
analyse_mlp(mlp, config)

**How to interpret this chart:** The x-axis shows the scale in $\log_2$ form. The range here represents the minimum and maximum absolute values that can be represented in the FP16 or FP8 E5 number formats (excluding subnormal values). In other words, if the scale exceeds the x-axis bounds we're in the numerics "danger zone" where training begins to degrade.

The y-axis shows the operations in the model, in the order in which they execute. We show the scale of activations (x), gradients (grad_x, grad_w) and weights (w). Activations can be said to "flow" forward through these layers, and grad_xs flow backwards, so we represent these by solid lines, and weights and grad_ws by symbols.
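
To make the x-axis bounds concrete, here is a minimal sketch that decodes the relevant FP16 bit patterns using only the Python standard library. The helper name `fp16_from_bits` is ours, introduced purely for illustration:

```python
import math
import struct


def fp16_from_bits(bits: int) -> float:
    """Decode a 16-bit pattern as an IEEE 754 half-precision float."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


fp16_max = fp16_from_bits(0x7BFF)            # largest finite value
fp16_min_normal = fp16_from_bits(0x0400)     # smallest normal value
fp16_min_subnormal = fp16_from_bits(0x0001)  # smallest subnormal value

for name, value in [
    ("max", fp16_max),
    ("min normal", fp16_min_normal),
    ("min subnormal", fp16_min_subnormal),
]:
    print(f"{name:>13}: {value:.6g} (log2 = {math.log2(value):.2f})")
```

The normal range spans roughly $2^{-14}$ to $2^{16}$, so a scale of $2^0 = 1$ sits close to the middle of it on a log axis.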

By the end of both passes the x and grad_x scales have dropped by half. This is due to Glorot slightly under-scaling values, and GeLU dropping the scale further.

Worse, the weights are significantly under-scaled and the grad_ws over-scaled by a factor of $2^4$. This is largely because Glorot scaling (along with all other weight init schemes) only accounts for the forward and grad_x scales.
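
A rough back-of-the-envelope check of these numbers, as a sketch only. It assumes independent, zero-mean inputs and weights, ignores the GeLU and dropout, and uses helper names (`glorot_std`, `fwd_multiplier`) that are ours rather than part of the notebook:

```python
import math

hidden_size = 384             # NanoGPTConfig default
ffn_size = 4 * hidden_size    # width of the MLP's inner layer
batch_size, seq_len = 64, 16  # analyse_mlp defaults


def glorot_std(fan_in: int, fan_out: int) -> float:
    """Weight standard deviation used by the Glorot init above."""
    return ((fan_in + fan_out) / 2) ** -0.5


def fwd_multiplier(fan_in: int, fan_out: int) -> float:
    """Factor by which a linear layer changes the activation scale.

    Each output sums fan_in products x_i * w_ij, so
    std(y) = sqrt(fan_in) * std(w) * std(x).
    """
    return math.sqrt(fan_in) * glorot_std(fan_in, fan_out)


w_std = glorot_std(hidden_size, ffn_size)
m1 = fwd_multiplier(hidden_size, ffn_size)
m2 = fwd_multiplier(ffn_size, hidden_size)
# Each weight gradient sums one product per token, so with unit-scale
# inputs and output gradients its scale would be sqrt(batch_size * seq_len).
grad_w_multiplier = math.sqrt(batch_size * seq_len)

print(f"weight scale (log2):      {math.log2(w_std):.2f}")
print(f"linear_1 fwd multiplier:  {m1:.3f}")
print(f"linear_2 fwd multiplier:  {m2:.3f}")
print(f"both linears combined:    {m1 * m2:.3f}")
print(f"idealised grad_w scale:   2^{math.log2(grad_w_multiplier):.0f}")
```

Under these assumptions the two linear layers alone shrink the scale to about 0.8, and the weights sit near $2^{-5}$. The idealised grad_w figure is an upper estimate; the chart's value is plausibly lower because the real activations and gradients sit somewhat below unit scale.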

#### Unit scaling

We'll now implement the equivalent layer using unit scaling, beginning with a basic linear operation:

In [10]:
# TODO: hide
class ScaledGrad(torch.autograd.Function):
  @staticmethod
  def forward(ctx, X, alpha, beta):
    ctx.save_for_backward(
      torch.tensor(beta, dtype=X.dtype))
    return alpha * X

  @staticmethod
  def backward(ctx, grad_Y):
    beta, = ctx.saved_tensors
    return beta * grad_Y, None, None

def scaled(X, alpha=1.0, beta=1.0):
  # Forward: Y = X * alpha
  # Backward: grad_X = grad_Y * beta
  return ScaledGrad.apply(X, alpha, beta)

def geometric_mean(xs):
    xs = np.array(xs)
    return xs.prod() ** (1 / xs.size)

In [11]:
class UnitScaledLinear(nn.Linear):
    def __init__(self, *args, scale_for="fwd, grad_x", **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scale_for = scale_for
    
    def get_scales(self, input):
        fwd_scale = self.weight.shape[1] ** -0.5
        grad_x_scale = self.weight.shape[0] ** -0.5
        grad_w_scale = np.prod(input.shape[:-1]) ** -0.5
        if self.scale_for == "fwd":
            grad_x_scale = fwd_scale
        elif self.scale_for == "grad_x":
            fwd_scale = grad_x_scale
        elif self.scale_for == "fwd, grad_x":
            fwd_scale = grad_x_scale = geometric_mean([fwd_scale, grad_x_scale])
        else:
            assert self.scale_for == "separate", f"demo implementation has no {self.scale_for} scaling"
        return fwd_scale, grad_x_scale, grad_w_scale
           
    def forward(self, input):
        fwd_scale, grad_x_scale, grad_w_scale = self.get_scales(input)
        input = scaled(input, beta=grad_x_scale)
        weight = scaled(self.weight, beta=grad_w_scale)
        bias = scaled(self.bias, beta=grad_w_scale) if self.bias is not None else None
        output = F.linear(input, weight, bias)
        return scaled(output, alpha=fwd_scale)

Let's break this down a bit. The `forward()` method is still based on the fundamental `F.linear(input, weight, bias)` operation, but has each of its operands and output scaled.

This is done via the `scaled(X, alpha, beta)` function. This is a special operation in that it behaves differently in the forward and backward pass, where we have

$$Y = X \cdot \alpha$$

$$\nabla_X = \nabla_Y \cdot \beta$$

So what scaling factors do we choose? This is determined in `get_scales()`. The standard approach is to set the forward and grad_x scales as a compromise between their "ideal" scale-preserving values (via a geometric mean). See our paper for more details on how we arrive at these ideal values.
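
As a quick sanity check on the "ideal" forward value, the following standard-library sketch simulates one output element of a linear layer. It assumes independent, unit-normal inputs and weights, which is a simplification and not the paper's derivation; the names `dot_of_unit_normals` and `std` are ours:

```python
import math
import random

random.seed(0)
fan_in, num_outputs = 128, 2000


def std(values):
    """Population standard deviation."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


def dot_of_unit_normals(n: int) -> float:
    """One output element: a sum of n products of unit-normal x and w."""
    return sum(random.gauss(0.0, 1.0) * random.gauss(0.0, 1.0) for _ in range(n))


raw = [dot_of_unit_normals(fan_in) for _ in range(num_outputs)]
raw_std = std(raw)
scaled_std = raw_std * fan_in ** -0.5  # apply the ideal forward scale

print(f"unscaled output std: {raw_std:.2f} (sqrt(fan_in) = {math.sqrt(fan_in):.2f})")
print(f"scaled output std:   {scaled_std:.2f}")
```

The unscaled output grows like `sqrt(fan_in)`, so multiplying by `fan_in ** -0.5` brings it back to roughly unit scale. The same argument applied to the backward pass gives `fan_out ** -0.5` for grad_x.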

We also provide versions of the operation which select forward and grad_x scales based on only one of their ideal values. We also have a scheme which has separate forward and grad_x scales, though this is only allowed in special circumstances (again, see the paper).

Now for the GeLU. This is pretty similar to the linear op. However, as activation functions are nonlinear, our usual approach of calculating scaling values analytically (i.e. by working through the maths) isn't always possible.

Fortunately, for these elementwise ops we can calculate them empirically, like so:

In [12]:
def analyse_elementwise_fn(fn, num_samples=2**22):
    x = torch.normal(0.0, 1.0, size=(num_samples,)).requires_grad_()
    y = fn(x)
    y.backward(torch.normal(0.0, 1.0, size=(num_samples,)))
    print(f"fwd scale={y.std():.3f}, bwd scale={x.grad.std():.3f}")

print("GeLU:", end=" ")
analyse_elementwise_fn(GELUActivation())
print("Tanh:", end=" ")
analyse_elementwise_fn(nn.Tanh())

GeLU: fwd scale=0.587, bwd scale=0.675
Tanh: fwd scale=0.628, bwd scale=0.682


We then put these scaling factors into our unit-scaled op as follows:

In [13]:
class UnitScaledGELU(GELUActivation):
    def forward(self, input):
        fwd_scale = bwd_scale = geometric_mean([0.588, 0.675]) ** -1
        input = scaled(input, beta=bwd_scale)
        output = self.act(input)
        return scaled(output, alpha=fwd_scale)
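
To see what the geometric-mean compromise does to the two measured GeLU scales, here is the arithmetic spelled out as a small sketch (plain Python, no PyTorch needed):

```python
import math

fwd, bwd = 0.588, 0.675  # measured GeLU scales used in UnitScaledGELU
# Inverse of the geometric mean of the two measured scales
correction = 1 / math.sqrt(fwd * bwd)

print(f"correction factor:          {correction:.4f}")
print(f"fwd scale after correction: {fwd * correction:.3f}")
print(f"bwd scale after correction: {bwd * correction:.3f}")
```

Neither direction ends up at exactly 1 (roughly 0.93 forward and 1.07 backward), but the two errors are equal and opposite on a log scale, which is the point of the compromise.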

Simple! Now we just need to unit scale the layernorm and dropout and we're ready.

These take the same kind of approach, so we won't go into too much detail here:

In [14]:
class UnitScaledLayerNorm(nn.LayerNorm):
    def forward(self, input: Tensor) -> Tensor:
        scale = (np.prod(self.normalized_shape) / input.nelement()) ** 0.5
        weight = scaled(self.weight, beta=scale)
        bias = scaled(self.bias, beta=scale)
        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

class UnitScaledDropout(nn.Dropout):
    def forward(self, input: Tensor) -> Tensor:
        # Dropout is typically implemented with a (1-p) ** -1 scaling
        # However, to preserve variance this ought to be (1-p) ** -0.5
        # We correct for this by multiplying by (1-p) ** 0.5
        scale = (1-self.p) ** 0.5
        input = scaled(input, beta=scale)
        output = F.dropout(input, self.p, self.training, self.inplace)
        return scaled(output, alpha=scale)
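
The comment in `UnitScaledDropout` can be checked numerically. This is a minimal standard-library simulation of ordinary (inverted) dropout on unit-normal inputs, not the PyTorch implementation itself; the `std` helper is ours:

```python
import math
import random

random.seed(0)
p = 0.1          # drop probability, matching the config default
keep = 1 - p
n = 100_000


def std(values):
    """Population standard deviation."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


x = [random.gauss(0.0, 1.0) for _ in range(n)]
# Standard dropout: zero with probability p, scale survivors by 1 / (1 - p)
dropped = [xi / keep if random.random() < keep else 0.0 for xi in x]

std_standard = std(dropped)
std_corrected = std_standard * keep ** 0.5  # the extra (1-p) ** 0.5 factor

print(f"standard dropout std:  {std_standard:.3f} (expected {keep ** -0.5:.3f})")
print(f"corrected dropout std: {std_corrected:.3f}")
```

The standard `1 / (1 - p)` rescaling preserves the mean but inflates the standard deviation by `(1 - p) ** -0.5`, which the extra factor undoes.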

We're now ready to unit-scale the full MLP 🥳

All this requires is swapping out the old layers for our new ones:

In [15]:
class UnitScaledMLP(MLP):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.linear_1 = UnitScaledLinear(config.hidden_size, config.hidden_size * 4)
        self.act = UnitScaledGELU()
        self.linear_2 = UnitScaledLinear(config.hidden_size * 4, config.hidden_size)
        self.dropout = UnitScaledDropout(config.dropout)

We also change our initialisation to give our weights unit scale. Let's analyse our new unit-scaled MLP:

In [16]:
def unit_init(*args):
    return 1

mlp = UnitScaledMLP(config).apply(init_weights(unit_init))
analyse_mlp(mlp, config)

This is much better! The final scales are almost exactly 1 in both directions, and weights and grad_ws look much better.

The compromise fwd & grad_x scaling factors in our linear layers do give temporary non-unit scaling, but this is minor and the second linear layer cancels out the first.
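
The cancellation can be worked through by hand. The sketch below assumes unit-scale inputs, weights and incoming gradients, and ignores the GeLU in between; `compromise_scale`, `fwd_std` and `bwd_std` are illustrative names, not part of the notebook:

```python
import math

hidden, ffn = 384, 4 * 384  # MLP widths for the default config


def compromise_scale(fan_in: int, fan_out: int) -> float:
    """Geometric mean of the ideal fwd (fan_in ** -0.5) and grad_x (fan_out ** -0.5) scales."""
    return (fan_in * fan_out) ** -0.25


def fwd_std(fan_in: int, fan_out: int) -> float:
    """Output scale: the raw matmul gives sqrt(fan_in), then the compromise scale is applied."""
    return math.sqrt(fan_in) * compromise_scale(fan_in, fan_out)


def bwd_std(fan_in: int, fan_out: int) -> float:
    """grad_x scale: the raw matmul gives sqrt(fan_out), then the compromise scale is applied."""
    return math.sqrt(fan_out) * compromise_scale(fan_in, fan_out)


f1, b1 = fwd_std(hidden, ffn), bwd_std(hidden, ffn)  # linear_1: 384 -> 1536
f2, b2 = fwd_std(ffn, hidden), bwd_std(ffn, hidden)  # linear_2: 1536 -> 384

print(f"linear_1: fwd x{f1:.3f}, grad_x x{b1:.3f}")
print(f"linear_2: fwd x{f2:.3f}, grad_x x{b2:.3f}")
print(f"combined: fwd x{f1 * f2:.3f}, grad_x x{b1 * b2:.3f}")
```

Because the second layer's shape is the transpose of the first, its factor of about 1.41 exactly undoes the first layer's factor of about 0.71, in both directions.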

### The self-attention layer

#### Regular scaling

We start with a basic implementation of a (self-)attention module.

(Note that we use [ALiBi](https://arxiv.org/abs/2108.12409) biases over positional embeddings here. This is primarily for simplicity, though these have also been [shown to perform](https://arxiv.org/abs/2210.15424) remarkably well!)
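
For readers unfamiliar with ALiBi, each attention head gets its own slope, and the bias added to the attention scores is that slope times a token distance. Below is a pure-Python sketch that mirrors the slope schedule of the `gen_slopes` helper in this notebook; `alibi_slopes` is our own name, used so the snippet runs without PyTorch:

```python
import math
from typing import List


def alibi_slopes(num_heads: int) -> List[float]:
    """Per-head slopes: a geometric sequence, with extras for non-power-of-2 head counts."""
    n = 2 ** math.floor(math.log2(num_heads))  # largest power of 2 <= num_heads
    m_0 = 2.0 ** (-8.0 / n)
    slopes = [m_0 ** i for i in range(1, n + 1)]
    if n < num_heads:
        # Remaining heads take every other slope from a schedule twice as long
        m_hat_0 = 2.0 ** (-4.0 / n)
        slopes += [m_hat_0 ** i for i in range(1, 1 + 2 * (num_heads - n), 2)]
    return slopes


print(alibi_slopes(6))  # the default config uses 6 heads
```

With 6 heads, the first four slopes come from the power-of-2 schedule and the last two from the interleaved one.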

In [17]:
# TODO: hide
class Matmul(nn.Module):
    def forward(self, a: Tensor, b: Tensor, scale: float=1.0) -> Tensor:
        return (a @ b) * scale

def split_heads(tensor: Tensor, num_heads: int, head_size: int) -> Tensor:
    batch_size, seq_len, hidden_size = tensor.shape
    tensor = tensor.view(batch_size, seq_len, num_heads, head_size)
    return tensor.permute(0, 2, 1, 3)


def merge_heads(tensor: Tensor, num_heads: int) -> Tensor:
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    batch_size, seq_len, num_heads, head_size = tensor.shape
    return tensor.view(batch_size, seq_len, num_heads * head_size)


def causal_mask(seq_len: int, num_heads: int) -> Tensor:
    causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.float16))
    causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
    alibi_mask = gen_alibi_mask(causal_mask, num_heads)
    causal_mask = (1.0 - causal_mask) * -10_000
    return alibi_mask + causal_mask


# Based on https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/alibi/__init__.py
def gen_alibi_mask(causal_mask: Tensor, num_heads: int) -> Tensor:
    distances = causal_mask.to(torch.float32).cumsum(dim=-1)
    slopes = gen_slopes(num_heads)
    return distances.to(torch.float16) * slopes.view(1, num_heads, 1, 1)


def gen_slopes(num_heads: int) -> Tensor:
    n = 2 ** math.floor(math.log2(num_heads))
    m_0 = 2.0 ** (-8.0 / n)
    m = torch.pow(m_0, torch.arange(1, 1 + n))
    if n < num_heads:
        m_hat_0 = 2.0 ** (-4.0 / n)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        m = torch.cat([m, m_hat])
    return m

In [18]:
class Attention(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size)
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        
        self.w_qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.qk_matmul = Matmul()
        self.softmax = nn.Softmax(-1)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.qkv_matmul = Matmul()
        self.w_o = nn.Linear(config.hidden_size, config.hidden_size)
        self.residual_dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        seq_len = hidden_states.shape[1]
        head_size = self.hidden_size // self.num_heads

        hidden_states = self.ln(hidden_states)
        q_k_v = self.w_qkv(hidden_states)
        q, k, v = q_k_v.split(self.hidden_size, dim=-1)
        q, k, v = (split_heads(t, self.num_heads, head_size) for t in (q, k, v))

        qk = self.qk_matmul(q, k.transpose(-1, -2), scale=1 / head_size ** 0.5).clone()
        qk += causal_mask(seq_len, self.num_heads) + attention_mask
        qk = self.softmax(qk)
        qk = self.attn_dropout(qk)

        qkv = self.qkv_matmul(qk, v)
        qkv = merge_heads(qkv, self.num_heads)
        qkvo = self.w_o(qkv)
        return self.residual_dropout(qkvo).clone()

In [19]:
def analyse_attn(attention, config, batch_size=64, seq_len=16):
    stats = instrument(attention)
    x = torch.normal(0.0, 1.0, size=(batch_size, seq_len, config.hidden_size))
    attention_mask = torch.zeros(batch_size, 1, 1, seq_len)
    y = attention(x, attention_mask)
    y.backward(torch.normal(0.0, 1.0, size=y.shape))
    visualise(stats)

In [20]:
attention = Attention(config).apply(init_weights(glorot_init))
analyse_attn(attention, config)

Again, the activation and grad_x scales fall by about half as they pass through the layer, fluctuating in between. (Note that we omit the softmax operation here because a) its output isn't normally distributed, and b) we usually compute it in higher precision anyway.) We see the same problems as before for the weights and grad_ws.

Let's fix this:

#### Unit scaling

In [21]:
class UnitScaledAttention(Attention):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.w_qkv = UnitScaledLinear(
            config.hidden_size, 3 * config.hidden_size, scale_for="fwd"
        )
        self.qk_matmul = UnitScaledMatmul(scale_for="fwd")
        self.softmax = UnitScaledSoftmax(-1)
        self.attn_dropout = UnitScaledDropout(config.dropout)
        self.qkv_matmul = UnitScaledMatmul(scale_for="fwd")
        self.w_o = UnitScaledLinear(config.hidden_size, config.hidden_size)
        self.residual_dropout = UnitScaledDropout(config.dropout)

class UnitScaledMatmul(Matmul):
    def __init__(self, scale_for="fwd, grad_a, grad_b") -> None:
        super().__init__()
        self.scale_for = scale_for
    
    def get_scales(self, a: Tensor, b: Tensor):
        fwd_scale = a.shape[-1] ** -0.5
        grad_a_scale = b.shape[-1] ** -0.5
        grad_b_scale = a.shape[-2] ** -0.5
        if self.scale_for == "fwd":
            grad_a_scale = grad_b_scale = fwd_scale
        elif self.scale_for == "fwd, grad_a, grad_b":
            fwd_scale = grad_a_scale = grad_b_scale = geometric_mean([
                fwd_scale, grad_a_scale, grad_b_scale
            ])
        else:
            assert False, f"demo implementation has no {self.scale_for} scaling"
        return fwd_scale, grad_a_scale, grad_b_scale
    
    def forward(self, a: Tensor, b: Tensor, scale: float=1.0) -> Tensor:
        # ignores provided scale
        fwd_scale, grad_a_scale, grad_b_scale = self.get_scales(a, b)
        a = scaled(a, beta=grad_a_scale)
        b = scaled(b, beta=grad_b_scale)
        output = a @ b
        return scaled(output, alpha=fwd_scale)

class UnitScaledSoftmax(nn.Softmax):
    def forward(self, input: Tensor) -> Tensor:
        fwd_scale = bwd_scale = input.shape[-1] ** 0.5
        input = scaled(input, fwd_scale)
        output = F.softmax(input, self.dim, _stacklevel=5)
        return scaled(output, bwd_scale)

To unit scale the attention layer we again swap out regular layers for unit-scaled ones. This requires altering two more operations: matmul and softmax.
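
To get a feel for the matmul scales, this sketch reproduces the shape arithmetic from `UnitScaledMatmul.get_scales` for the two attention matmuls. It works on plain shape tuples, and `matmul_scales` is a hypothetical stand-in for illustration:

```python
def matmul_scales(a_shape, b_shape):
    """Ideal scales for a @ b, before any compromise between them."""
    fwd = a_shape[-1] ** -0.5     # each output sums over the inner dimension
    grad_a = b_shape[-1] ** -0.5  # grad_a sums over b's output dimension
    grad_b = a_shape[-2] ** -0.5  # grad_b sums over a's row dimension
    return fwd, grad_a, grad_b


batch, heads, seq_len, head_size = 64, 6, 16, 64  # default config: 384 // 6 = 64

q_shape = (batch, heads, seq_len, head_size)
k_t_shape = (batch, heads, head_size, seq_len)  # k transposed
attn_shape = (batch, heads, seq_len, seq_len)
v_shape = (batch, heads, seq_len, head_size)

qk_scales = matmul_scales(q_shape, k_t_shape)
qkv_scales = matmul_scales(attn_shape, v_shape)

print("q @ k^T  (fwd, grad_a, grad_b):", qk_scales)
print("attn @ v (fwd, grad_a, grad_b):", qkv_scales)
```

Note that the forward scale for `q @ k^T` is `head_size ** -0.5`, the same factor the regular attention layer passes in by hand.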

As with the linear layer, the matmul can be scaled according to different criteria. We find that scaling each matmul and linear layer here for the forward pass ensures good scaling in both directions. We see this empirically in our analysis:

In [22]:
attention = UnitScaledAttention(config).apply(init_weights(unit_init))
analyse_attn(attention, config, batch_size=64, seq_len=64)

Again, a great improvement for all four types of tensors, with each having approximate unit scale. Note that these scaling rules are robust to changes in hyperparameters too. You can experiment with the initial config to verify this.

We now have our key building-blocks 🧱 Let's put it all together!

### The transformer block

#### Regular scaling

In [23]:
class TransformerBlock(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.attention_layer = Residual(Attention(config))
        self.mlp_layer = Residual(MLP(config))
        
    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        hidden_states = self.attention_layer(hidden_states, attention_mask)
        return self.mlp_layer(hidden_states)


class Residual(nn.Module):
    def __init__(self, f: nn.Module):
        super().__init__()
        self.f = f
    
    def forward(self, x: Tensor, *args):
        return x + self.f(x, *args)

#### Unit scaling

In [24]:
class UnitScaledTransformerBlock(TransformerBlock):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.attention_layer = UnitScaledResidual(UnitScaledAttention(config))
        self.mlp_layer = UnitScaledResidual(UnitScaledMLP(config))

class UnitScaledResidual(Residual):
    def __init__(self, f: nn.Module, tau: float=0.2):
        super().__init__(f)
        self.tau = tau
    
    def forward(self, x: Tensor, *args):
        y = x * (1 - self.tau) ** 0.5
        z = scaled(x, beta=self.tau ** 0.5)
        z = self.f(z, *args)
        z = scaled(z, alpha=self.tau ** 0.5)
        return y + z

Note that the scaling in our residual layer includes an important trick.

It's necessary for unit-scaled models to use a weighted sum when doing their residual-add, to down-weight the residual/trunk branch. If this is not included, training can fail as too much signal comes from each residual (regular models avoid this as their residual implicitly reduces scale).

However, the naïve implementation of this breaks unit scale. We get around this by delaying the weighting until the end of the residual branch in the backward pass (see paper for more details).
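
The forward-pass effect of the weighted sum can be illustrated with a small simulation. This is only a sketch: it stands in for each residual branch with fresh unit-scale noise (so the branch output is assumed independent of the trunk), and it does not model the delayed backward-pass weighting. The helper names are ours:

```python
import math
import random

random.seed(0)
tau = 0.2       # default in UnitScaledResidual
n = 20_000
num_adds = 12   # 6 blocks x 2 residual adds in the default config


def std(values):
    """Population standard deviation."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


def branch_output(size):
    """Stand-in for f(x): unit-scale noise, assumed independent of x."""
    return [random.gauss(0.0, 1.0) for _ in range(size)]


plain = [random.gauss(0.0, 1.0) for _ in range(n)]
weighted = list(plain)
for _ in range(num_adds):
    # Plain residual add: x + f(x)
    plain = [x + f for x, f in zip(plain, branch_output(n))]
    # Weighted residual add: sqrt(1 - tau) * x + sqrt(tau) * f(x)
    weighted = [
        x * (1 - tau) ** 0.5 + f * tau ** 0.5
        for x, f in zip(weighted, branch_output(n))
    ]

plain_std, weighted_std = std(plain), std(weighted)
print(f"plain adds:    std = {plain_std:.2f}")
print(f"weighted adds: std = {weighted_std:.2f}")
```

With unit-scale branches, plain adds let the trunk's variance grow by 1 per add, whereas the weights `sqrt(1 - tau)` and `sqrt(tau)` have squares that sum to 1 and so keep the trunk at unit scale.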

### The full transformer

#### Regular scaling

Note that our implementation contains a little extra boilerplate to make it 🤗-compliant.

In [25]:
class NanoGPTModel(PreTrainedModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.input_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.ln = nn.LayerNorm(config.hidden_size)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.loss_fn = nn.CrossEntropyLoss()
        self.apply(init_weights(glorot_init))

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        position_ids: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithCrossAttentions:
        attention_mask = attention_mask[:, None, None, :].to(dtype=self.dtype)
        attention_mask = (1.0 - attention_mask) * -10_000

        hidden_states = self.input_embeddings(input_ids)
        hidden_states = self.dropout(hidden_states)
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln(hidden_states)
        logits = self.lm_head(hidden_states)

        if labels is None:
            return CausalLMOutputWithCrossAttentions(logits=logits)
        else:
            labels = torch.roll(labels, -1, 1)
            labels[:, -1] = -100  # By default, ignore_index of CrossEntropyLoss is -100
            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            return CausalLMOutputWithCrossAttentions(loss=loss)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.input_embeddings

    def prepare_inputs_for_generation(
        self, input_ids: Tensor, **kwargs
    ) -> Dict[str, Tensor]:
        attention_mask = kwargs["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

In [26]:
def analyse_full_model(model, config, batch_size=64, seq_len=16):
    stats = instrument(model)
    input_ids = labels = torch.randint(0, config.vocab_size, size=(batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)  # 1 = attend; the model maps 0 to a -10_000 bias
    y = model(input_ids, attention_mask, labels=labels).loss
    y.backward()
    visualise(stats, subnormal=True)

Here's the scaling for the entire (non-unit-scaled) transformer:

In [27]:
model = NanoGPTModel(config)
analyse_full_model(model, config)

These results clearly demonstrate the inadequacy of the standard approach to model design when it comes to scale. Grad_x and grad_w values are very far from having unit scale, with the former now relying on subnormal values (dangerous, as these begin to lose precision and disappear altogether at $2^{-24}$).

A key issue introduced by this layer is the cross-entropy loss definition. This typically uses a `1 / batch_size` term to average over labels, which has the effect of dramatically under-scaling gradients in the backward pass.
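
To put a number on this, here is a short sketch for the analysis settings used above. It assumes the default `reduction='mean'` behaviour of averaging over all non-ignored targets, so the "batch" being averaged over is every token with a valid label:

```python
import math

batch_size, seq_len = 64, 16  # analyse_full_model defaults
# The last position of each sequence is set to ignore_index, so it is not averaged over
num_targets = batch_size * (seq_len - 1)

mean_factor = 1 / num_targets
print(f"targets averaged over: {num_targets}")
print(f"gradient scale factor: 2^{math.log2(mean_factor):.1f}")
```

Every gradient in the model inherits this factor of roughly $2^{-10}$ from the loss, before any other source of under-scaling is counted.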

We can also see here that a single loss scaling factor for all gradients isn't ideal. This would shift all grad_xs and grad_ws to the right by the same amount, which would still leave us far short of our unit-scale ideal.

What we really need is per-op scaling! Again, we'll fix this for the unit-scaled implementation:

#### Unit scaling

In [28]:
class UnitScaledNanoGPTModel(NanoGPTModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.input_embeddings = UnitScaledEmbedding(config.vocab_size, config.hidden_size)
        self.dropout = UnitScaledDropout(config.dropout)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.blocks = nn.ModuleList([
            UnitScaledTransformerBlock(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.lm_head = UnitScaledLinear(
            config.hidden_size, config.vocab_size, bias=False, scale_for="separate"
        )
        self.loss_fn = UnitScaledCrossEntropyLoss()
        self.apply(init_weights(unit_init))

class UnitScaledEmbedding(nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        batch_size = np.prod(input.shape)
        weight = scaled(self.weight, beta=self.num_embeddings / batch_size)
        return F.embedding(input, weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq, self.sparse)

class UnitScaledCrossEntropyLoss(nn.CrossEntropyLoss):
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # input holds the flattened logits, with shape (num_tokens, num_classes)
        num_tokens, num_classes = input.shape
        input = scaled(input, beta=num_classes / (num_classes - 1) ** 0.5)
        loss = F.cross_entropy(input, target, reduction='sum')
        return scaled(loss, alpha=1 / num_tokens)

In [29]:
model = UnitScaledNanoGPTModel(config)
analyse_full_model(model, config)

Success! 🥂 🎊 🍾 We've managed to keep every tensor in the model at unit scale for the forward and backward pass. All that was required was to re-implement standard layers with the right scaling factors.

Of course, this only gives us the right scaling *at initialisation*. Scales inevitably drift throughout training, but we've given ourselves headroom by starting in the ideal place. The results in our paper show that this is sufficient to enable accurate, out-of-the-box training of many models, up to the size of BERT Large (we haven't tested anything larger yet!).

## Training

#### Regular scaling

To show that this really does work, let's train both our models.

As in Karpathy's original NanoGPT, we'll train on the small Shakespeare dataset for a few minutes and see if we can get something that captures the rough style. Starting with the regular (non-unit-scaled) model:

In [30]:
# TODO: hide
from optimum.graphcore.generation_utils import IPUGenerationMixin
from optimum.graphcore.modeling_utils import (
    PipelineMixin,
    outline_attribute,
    register,
    tied_weight_model,
)

@tied_weight_model(NanoGPTModel)
@register(NanoGPTModel)
class PipelinedNanoGPTModel(NanoGPTModel, PipelineMixin, IPUGenerationMixin):
    def parallelize(self):
        self._hooks = [outline_attribute(self.ln, "LayerNorm")]

    def deparallelize(self):
        pass

@tied_weight_model(UnitScaledNanoGPTModel)
@register(UnitScaledNanoGPTModel)
class PipelinedUnitScaledNanoGPTModel(
    UnitScaledNanoGPTModel, PipelineMixin, IPUGenerationMixin
):
    def parallelize(self):
        self._hooks = [outline_attribute(self.ln, "UnitScaledLayerNorm")]

    def deparallelize(self):
        pass

In [31]:
import poptorch_experimental_addons as pea

def scaled(X, alpha=1.0, beta=1.0):
    # Forward: Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    return pea.autograd_proxy(X * alpha, X * beta)
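The `scaled()` definition above relies on `pea.autograd_proxy`, which targets the IPU. For readers who want to experiment elsewhere, the same forward and backward behaviour can be sketched in plain PyTorch with a custom `torch.autograd.Function`. This is an assumption-laden equivalent rather than the implementation used in this notebook; `_ScaledFn` and `scaled_reference` are hypothetical names.

```python
import torch

class _ScaledFn(torch.autograd.Function):
    """Multiply by alpha in the forward pass and by beta in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha, beta):
        ctx.beta = beta  # stash the backward factor for later
        return x * alpha

    @staticmethod
    def backward(ctx, grad_y):
        # One return value per forward input; alpha and beta get no gradient.
        return grad_y * ctx.beta, None, None

def scaled_reference(x, alpha=1.0, beta=1.0):
    # Forward: Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    return _ScaledFn.apply(x, alpha, beta)

x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = scaled_reference(x, alpha=2.0, beta=0.5)
y.sum().backward()
print(y.detach())  # forward values: x * 2.0
print(x.grad)      # gradients: 0.5 everywhere, not 2.0
```

Note that the gradient reaching `x` is determined by `beta` alone, which is what lets a layer pick different scaling factors for its forward and backward passes.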

In [32]:
from functools import partial

from datasets.load import load_dataset
from optimum.graphcore import (
    IPUConfig,
    IPUTrainer,
    IPUTrainingArguments,
    pipeline,
    pipelines,
)
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

In [33]:
dataset = load_dataset("tiny_shakespeare")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

config = NanoGPTConfig(vocab_size=len(tokenizer), eos_token_id=tokenizer.eos_token_id)
batch_sz = 16
seq_len = 128
ipu_config = IPUConfig(
    gradient_accumulation_steps=5 * 64 // batch_sz,
    layers_per_ipu=[config.num_hidden_layers],
    executable_cache_dir="./exe_cache",
)

def split_and_tokenize(data, seq_len, batch_sz):
    tokens = tokenizer(data["text"])
    seqs = [
        tokens["input_ids"][0][i : i + seq_len]
        for i in range(0, len(tokens["input_ids"][0]), seq_len)
    ]
    seqs = seqs[: int(len(seqs) / batch_sz) * batch_sz]  # make divisible by batch size
    return {"input_ids": seqs}


prep_data = partial(split_and_tokenize, seq_len=seq_len, batch_sz=batch_sz)
tokenized_dataset = dataset.map(
    prep_data, batched=True, remove_columns=dataset["train"].column_names
)


Downloading and preparing dataset tiny_shakespeare/default to /root/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e...



Dataset tiny_shakespeare downloaded and prepared to /root/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e. Subsequent calls will reuse this data.



In [34]:
train_args = IPUTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=batch_sz,
    per_device_eval_batch_size=batch_sz,
    evaluation_strategy="steps",
    eval_steps=250,
    logging_steps=20,
    max_steps=1000,
    weight_decay=0.1,
    warmup_steps=100,
    lr_scheduler_type="linear",
    learning_rate=1e-3,
    loss_scaling=2**6,
)

model = NanoGPTModel(config)

trainer = IPUTrainer(
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    ipu_config=ipu_config,
)

trainer.train()
trainer.save_model("trained_model/")

max_steps is given, it will override any value given in num_train_epochs
Compiling Model...
Graph compilation: 100%|██████████| 100/100 [01:24<00:00]
Compiled/Loaded model in 89.9241143650006 secs
***** Running training *****
  Num examples = 7840
  Num Epochs = 42
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 320
  Gradient Accumulation steps = 20
  Total optimization steps = 1000



{'loss': 4.5586, 'learning_rate': 0.0002, 'epoch': 0.83}
{'loss': 2.804, 'learning_rate': 0.0004, 'epoch': 1.67}
{'loss': 2.4454, 'learning_rate': 0.0006, 'epoch': 2.5}
{'loss': 2.2854, 'learning_rate': 0.0008, 'epoch': 3.33}
{'loss': 2.1626, 'learning_rate': 0.001, 'epoch': 4.17}
{'loss': 2.0473, 'learning_rate': 0.0009777777777777777, 'epoch': 5.0}
{'loss': 1.9735, 'learning_rate': 0.0009555555555555556, 'epoch': 5.83}
{'loss': 1.8992, 'learning_rate': 0.0009333333333333333, 'epoch': 6.67}
{'loss': 1.8237, 'learning_rate': 0.0009111111111111111, 'epoch': 7.5}
{'loss': 1.8095, 'learning_rate': 0.0008888888888888888, 'epoch': 8.33}
{'loss': 1.7489, 'learning_rate': 0.0008666666666666667, 'epoch': 9.17}
{'loss': 1.7393, 'learning_rate': 0.0008444444444444444, 'epoch': 10.0}


Compiling Model...


{'eval_loss': 1.7431640625, 'eval_runtime': 0.5885, 'eval_samples_per_second': 734.109, 'eval_steps_per_second': 45.882, 'epoch': 10.42}





{'loss': 1.7048, 'learning_rate': 0.0008222222222222222, 'epoch': 10.83}
{'loss': 1.6518, 'learning_rate': 0.0008, 'epoch': 11.67}
{'loss': 1.6482, 'learning_rate': 0.0007777777777777778, 'epoch': 12.5}
{'loss': 1.6165, 'learning_rate': 0.0007555555555555555, 'epoch': 13.33}
{'loss': 1.6036, 'learning_rate': 0.0007333333333333333, 'epoch': 14.17}
{'loss': 1.5772, 'learning_rate': 0.0007111111111111111, 'epoch': 15.0}
{'loss': 1.5662, 'learning_rate': 0.000688888888888889, 'epoch': 15.83}
{'loss': 1.5594, 'learning_rate': 0.0006666666666666666, 'epoch': 16.67}
{'loss': 1.521, 'learning_rate': 0.0006444444444444444, 'epoch': 17.5}
{'loss': 1.5226, 'learning_rate': 0.0006222222222222223, 'epoch': 18.33}
{'loss': 1.5203, 'learning_rate': 0.0006, 'epoch': 19.17}
{'loss': 1.5033, 'learning_rate': 0.0005777777777777778, 'epoch': 20.0}
{'loss': 1.491, 'learning_rate': 0.0005555555555555556, 'epoch': 20.83}


Compiling Model...

Compiled/Loaded model in 3.15173802299978 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16



Saving model checkpoint to out/checkpoint-500
Configuration saved in out/checkpoint-500/ipu_config.json


{'eval_loss': 1.5712890625, 'eval_runtime': 0.4452, 'eval_samples_per_second': 970.327, 'eval_steps_per_second': 60.645, 'epoch': 20.83}




{'loss': 1.4705, 'learning_rate': 0.0005333333333333334, 'epoch': 21.67}
{'loss': 1.4702, 'learning_rate': 0.0005111111111111111, 'epoch': 22.5}
{'loss': 1.4462, 'learning_rate': 0.0004888888888888889, 'epoch': 23.33}
{'loss': 1.4544, 'learning_rate': 0.00046666666666666666, 'epoch': 24.17}
{'loss': 1.4625, 'learning_rate': 0.0004444444444444444, 'epoch': 25.0}
{'loss': 1.4357, 'learning_rate': 0.0004222222222222222, 'epoch': 25.83}
{'loss': 1.4208, 'learning_rate': 0.0004, 'epoch': 26.67}
{'loss': 1.4187, 'learning_rate': 0.00037777777777777777, 'epoch': 27.5}
{'loss': 1.4232, 'learning_rate': 0.00035555555555555557, 'epoch': 28.33}
{'loss': 1.385, 'learning_rate': 0.0003333333333333333, 'epoch': 29.17}
{'loss': 1.4109, 'learning_rate': 0.0003111111111111111, 'epoch': 30.0}
{'loss': 1.4041, 'learning_rate': 0.0002888888888888889, 'epoch': 30.83}


Compiling Model...


{'eval_loss': 1.525390625, 'eval_runtime': 0.5841, 'eval_samples_per_second': 739.651, 'eval_steps_per_second': 46.228, 'epoch': 31.25}





{'loss': 1.3877, 'learning_rate': 0.0002666666666666667, 'epoch': 31.67}
{'loss': 1.379, 'learning_rate': 0.00024444444444444443, 'epoch': 32.5}
{'loss': 1.3865, 'learning_rate': 0.0002222222222222222, 'epoch': 33.33}
{'loss': 1.3761, 'learning_rate': 0.0002, 'epoch': 34.17}
{'loss': 1.3676, 'learning_rate': 0.00017777777777777779, 'epoch': 35.0}
{'loss': 1.3646, 'learning_rate': 0.00015555555555555556, 'epoch': 35.83}
{'loss': 1.3451, 'learning_rate': 0.00013333333333333334, 'epoch': 36.67}
{'loss': 1.3484, 'learning_rate': 0.0001111111111111111, 'epoch': 37.5}
{'loss': 1.3531, 'learning_rate': 8.888888888888889e-05, 'epoch': 38.33}
{'loss': 1.3556, 'learning_rate': 6.666666666666667e-05, 'epoch': 39.17}
{'loss': 1.3431, 'learning_rate': 4.4444444444444447e-05, 'epoch': 40.0}
{'loss': 1.3423, 'learning_rate': 2.2222222222222223e-05, 'epoch': 40.83}
{'loss': 1.3501, 'learning_rate': 0.0, 'epoch': 41.67}


Compiling Model...

Compiled/Loaded model in 3.2015101830002095 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16



Saving model checkpoint to out/checkpoint-1000
Configuration saved in out/checkpoint-1000/ipu_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to trained_model/
Configuration saved in trained_model/ipu_config.json


{'eval_loss': 1.5146484375, 'eval_runtime': 0.44, 'eval_samples_per_second': 981.813, 'eval_steps_per_second': 61.363, 'epoch': 41.67}
{'train_runtime': 296.0934, 'train_samples_per_second': 1080.74, 'train_steps_per_second': 3.377, 'train_loss': 1.653693359375, 'epoch': 41.67}


Note that we use loss scaling here, although this model is small enough to train without it. That certainly doesn't hold for larger models.

To really get a sense of how the model's doing, we ought to get it to write some Shakespeare. Here are two attempts at *The Tempest*:

In [35]:
pipelines.check_model_type = lambda self, supported_models: ...

final_model = NanoGPTModel.from_pretrained("trained_model/")

TEST_INPUT = """PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits"""
pipe = pipeline(
    "text-generation",
    ipu_config=ipu_config.to_dict(),  # TODO: feature request -> no to_dict()
    model=final_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
)

outputs = pipe(TEST_INPUT, num_return_sequences=2, temperature=0.4)
for i, output in enumerate(outputs):
    print(f"\n===== Completion [{i+1}] =====\n")
    print(output["generated_text"])

IPUConfig {
  "auto_loss_scaling": false,
  "device_iterations": 1,
  "embedding_serialization_factor": 1,
  "enable_half_partials": true,
  "executable_cache_dir": "./exe_cache",
  "execute_encoder_on_cpu_for_generation": false,
  "gradient_accumulation_steps": 20,
  "inference_device_iterations": 1,
  "inference_replication_factor": 1,
  "ipus_per_replica": 1,
  "layers_per_ipu": [
    6
  ],
  "matmul_proportion": 0.2,
  "optimizer_state_offchip": true,
  "optimum_version": "1.6.1",
  "output_mode": "final",
  "recompute_checkpoint_every_layer": false,
  "replicated_tensor_sharding": false,
  "replication_factor": 1,
  "seed": null,
  "transformers_version": "4.25.1"
}

Graph compilation: 100%|██████████| 100/100 [00:23<00:00]



===== Completion [1] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of his life;
And then to you are to him so little so.

LEONTES:
What is the body in my brother, and then?

LEONTES:
I have been like a substance and great them?

POMPEY:
No, my lord, I thank your friends the news?

POMPEY:
The worship is for a more so leave, and the first,
Which should I have sworn with my service
Than the land his son of yourself.

KING RICHARD III:
As I am devil to the ground of my liege,
And when I 

===== Completion [2] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of your part,
To the people of your state and so much as live,
And she the bloody and strike the other state
To be the friends of the house of York,
And so a king, and like to the bear the fire
And many shall be as the bloody of the state
To be a subjects of a bawd and life
That would be with for the root of the country's face,
And then 

It's not going to win any literary awards, but it's pretty good for a few minutes of training. Now let's see if our unit-scaled model can do the same...

#### Unit scaling

All we need to do is swap in our new model, remove the loss scaling (of course!), and increase the learning rate (as our weights are now larger due to unit-initialisation).
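One way to see why a larger learning rate is plausible is to look at an Adam-style update in isolation. This is a hedged sketch: the notebook does not state which optimiser `IPUTrainer` uses, so Adam is an assumption here, weight decay is ignored, and the small initialisation scale of `0.02` is a hypothetical value chosen only for illustration. `adam_first_step` is likewise a hypothetical helper.

```python
import math

def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Size of the first Adam update for one scalar weight (no weight decay)."""
    m_hat = ((1 - beta1) * grad) / (1 - beta1)       # bias-corrected first moment
    v_hat = ((1 - beta2) * grad ** 2) / (1 - beta2)  # bias-corrected second moment
    return lr * m_hat / (math.sqrt(v_hat) + eps)

# The step size is roughly lr, whatever the scale of the gradient...
for grad in (1e-4, 1.0, 1e3):
    print(f"grad={grad:g}: step={adam_first_step(grad, lr=1e-3):.6f}")

# ...so the *relative* change to a weight is roughly lr / |weight|.
small_init_std, unit_init_std = 0.02, 1.0  # 0.02 is a hypothetical value
rel_regular = adam_first_step(1.0, lr=1e-3) / small_init_std
rel_unit = adam_first_step(1.0, lr=2e-2) / unit_init_std
print(f"relative step, small init, lr=1e-3: {rel_regular:.3f}")
print(f"relative step, unit init,  lr=2e-2: {rel_unit:.3f}")
```

Under these assumptions, weights initialised at unit scale need a proportionally larger learning rate to change by a comparable relative amount per step.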

In [36]:
# TODO: hide?
from types import MethodType

class _IPUConfig(IPUConfig):
    def to_options(self, *args, **kwargs):
        options = super().to_options(*args, **kwargs)
        options._popart.setPatterns(dict(AutogradProxyOpPattern=True))
        return options

    def for_pod_type(self, *args, **kwargs):
        config = super().for_pod_type(*args, **kwargs)
        config.__class__ = _IPUConfig
        return config

ipu_config.__class__ = _IPUConfig

In [37]:
model = UnitScaledNanoGPTModel(config)

train_args.loss_scaling = 1.0
train_args.learning_rate = 2e-2

🤞 The moment of truth ... let's train our unit-scaled model 🤞

In [38]:
unit_scale_trainer = IPUTrainer(
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    ipu_config=ipu_config,
)

unit_scale_trainer.train()
unit_scale_trainer.save_model("trained_model/")

max_steps is given, it will override any value given in num_train_epochs
Compiling Model...
Graph compilation: 100%|██████████| 100/100 [01:55<00:00]
Compiled/Loaded model in 127.81229283099947 secs
***** Running training *****
  Num examples = 7840
  Num Epochs = 42
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 320
  Gradient Accumulation steps = 20
  Total optimization steps = 1000



{'loss': 4.7835, 'learning_rate': 0.004, 'epoch': 0.83}
{'loss': 2.8256, 'learning_rate': 0.008, 'epoch': 1.67}
{'loss': 2.4451, 'learning_rate': 0.012, 'epoch': 2.5}
{'loss': 2.2794, 'learning_rate': 0.016, 'epoch': 3.33}
{'loss': 2.15, 'learning_rate': 0.02, 'epoch': 4.17}
{'loss': 2.056, 'learning_rate': 0.019555555555555555, 'epoch': 5.0}
{'loss': 1.9731, 'learning_rate': 0.019111111111111113, 'epoch': 5.83}
{'loss': 1.9058, 'learning_rate': 0.018666666666666668, 'epoch': 6.67}
{'loss': 1.8505, 'learning_rate': 0.018222222222222223, 'epoch': 7.5}
{'loss': 1.7972, 'learning_rate': 0.017777777777777778, 'epoch': 8.33}
{'loss': 1.7667, 'learning_rate': 0.017333333333333336, 'epoch': 9.17}
{'loss': 1.7384, 'learning_rate': 0.01688888888888889, 'epoch': 10.0}


Compiling Model...


{'eval_loss': 1.7333984375, 'eval_runtime': 0.8964, 'eval_samples_per_second': 481.929, 'eval_steps_per_second': 30.121, 'epoch': 10.42}





{'loss': 1.7074, 'learning_rate': 0.016444444444444446, 'epoch': 10.83}
{'loss': 1.6881, 'learning_rate': 0.016, 'epoch': 11.67}
{'loss': 1.6322, 'learning_rate': 0.015555555555555557, 'epoch': 12.5}
{'loss': 1.6246, 'learning_rate': 0.015111111111111112, 'epoch': 13.33}
{'loss': 1.6069, 'learning_rate': 0.014666666666666666, 'epoch': 14.17}
{'loss': 1.5772, 'learning_rate': 0.014222222222222223, 'epoch': 15.0}
{'loss': 1.5472, 'learning_rate': 0.013777777777777778, 'epoch': 15.83}
{'loss': 1.5531, 'learning_rate': 0.013333333333333332, 'epoch': 16.67}
{'loss': 1.5266, 'learning_rate': 0.01288888888888889, 'epoch': 17.5}
{'loss': 1.5199, 'learning_rate': 0.012444444444444445, 'epoch': 18.33}
{'loss': 1.5043, 'learning_rate': 0.012, 'epoch': 19.17}
{'loss': 1.4862, 'learning_rate': 0.011555555555555555, 'epoch': 20.0}
{'loss': 1.4942, 'learning_rate': 0.011111111111111112, 'epoch': 20.83}


Compiling Model...

Compiled/Loaded model in 10.19603246600036 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16



Saving model checkpoint to out/checkpoint-500
Configuration saved in out/checkpoint-500/ipu_config.json


{'eval_loss': 1.5400390625, 'eval_runtime': 0.6393, 'eval_samples_per_second': 675.789, 'eval_steps_per_second': 42.237, 'epoch': 20.83}





{'loss': 1.4525, 'learning_rate': 0.010666666666666666, 'epoch': 21.67}
{'loss': 1.455, 'learning_rate': 0.010222222222222221, 'epoch': 22.5}
{'loss': 1.4438, 'learning_rate': 0.009777777777777778, 'epoch': 23.33}
{'loss': 1.4358, 'learning_rate': 0.009333333333333334, 'epoch': 24.17}
{'loss': 1.4157, 'learning_rate': 0.008888888888888889, 'epoch': 25.0}
{'loss': 1.4046, 'learning_rate': 0.008444444444444445, 'epoch': 25.83}
{'loss': 1.3781, 'learning_rate': 0.008, 'epoch': 26.67}
{'loss': 1.3857, 'learning_rate': 0.007555555555555556, 'epoch': 27.5}
{'loss': 1.3919, 'learning_rate': 0.0071111111111111115, 'epoch': 28.33}
{'loss': 1.3537, 'learning_rate': 0.006666666666666666, 'epoch': 29.17}
{'loss': 1.3511, 'learning_rate': 0.006222222222222223, 'epoch': 30.0}
{'loss': 1.3359, 'learning_rate': 0.0057777777777777775, 'epoch': 30.83}


Compiling Model...

Compiled/Loaded model in 10.092855619999682 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16



{'eval_loss': 1.4873046875, 'eval_runtime': 0.6397, 'eval_samples_per_second': 675.286, 'eval_steps_per_second': 42.205, 'epoch': 31.25}





{'loss': 1.3398, 'learning_rate': 0.005333333333333333, 'epoch': 31.67}
{'loss': 1.3367, 'learning_rate': 0.004888888888888889, 'epoch': 32.5}
{'loss': 1.327, 'learning_rate': 0.0044444444444444444, 'epoch': 33.33}
{'loss': 1.322, 'learning_rate': 0.004, 'epoch': 34.17}
{'loss': 1.2932, 'learning_rate': 0.0035555555555555557, 'epoch': 35.0}
{'loss': 1.2955, 'learning_rate': 0.0031111111111111114, 'epoch': 35.83}
{'loss': 1.2822, 'learning_rate': 0.0026666666666666666, 'epoch': 36.67}
{'loss': 1.2897, 'learning_rate': 0.0022222222222222222, 'epoch': 37.5}
{'loss': 1.2786, 'learning_rate': 0.0017777777777777779, 'epoch': 38.33}
{'loss': 1.2674, 'learning_rate': 0.0013333333333333333, 'epoch': 39.17}
{'loss': 1.2678, 'learning_rate': 0.0008888888888888889, 'epoch': 40.0}
{'loss': 1.2452, 'learning_rate': 0.00044444444444444447, 'epoch': 40.83}
{'loss': 1.2518, 'learning_rate': 0.0, 'epoch': 41.67}


Compiling Model...

Compiled/Loaded model in 10.114462663000268 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16



Saving model checkpoint to out/checkpoint-1000
Configuration saved in out/checkpoint-1000/ipu_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to trained_model/
Configuration saved in trained_model/ipu_config.json


{'eval_loss': 1.466796875, 'eval_runtime': 0.6367, 'eval_samples_per_second': 678.511, 'eval_steps_per_second': 42.407, 'epoch': 41.67}
{'train_runtime': 255.7952, 'train_samples_per_second': 1251.001, 'train_steps_per_second': 3.909, 'train_loss': 1.632796875, 'epoch': 41.67}


We get a similar final evaluation loss to the regular model (1.467 versus 1.515): a success! In fact, after sweeping the learning rate, unit scaling appears to be slightly better here.

Let's celebrate with some unit-scaled Shakespeare...

In [39]:
pipelines.check_model_type = lambda self, supported_models: ...

final_model = UnitScaledNanoGPTModel.from_pretrained("trained_model/")


TEST_INPUT = """PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits"""
pipe = pipeline(
    "text-generation",
    ipu_config=ipu_config.to_dict(),  # TODO: feature request -> no to_dict()
    model=final_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
)

outputs = pipe(TEST_INPUT, num_return_sequences=2, temperature=0.4)
for i, output in enumerate(outputs):
    print(f"\n===== Completion [{i+1}] =====\n")
    print(output["generated_text"])

IPUConfig {
  "auto_loss_scaling": false,
  "device_iterations": 1,
  "embedding_serialization_factor": 1,
  "enable_half_partials": true,
  "executable_cache_dir": "./exe_cache",
  "execute_encoder_on_cpu_for_generation": false,
  "gradient_accumulation_steps": 20,
  "inference_device_iterations": 1,
  "inference_replication_factor": 1,
  "ipus_per_replica": 1,
  "layers_per_ipu": [
    6
  ],
  "matmul_proportion": 0.2,
  "optimizer_state_offchip": true,
  "optimum_version": "1.6.1",
  "output_mode": "final",
  "recompute_checkpoint_every_layer": false,
  "replicated_tensor_sharding": false,
  "replication_factor": 1,
  "seed": null,
  "transformers_version": "4.25.1"
}

Graph compilation: 100%|██████████| 100/100 [00:37<00:00]



===== Completion [1] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of the
provost: the strength of your honours and strength
To meet the world in the violence of your father's country's life
Than the ground of the court-wings of all the death.

KING HENRY VI:
My lord, the last is the father for the court?

KING RICHARD III:
No more than I see thee to be a peace
And more than my brother's love in the gods
And the root of my counsel: therefore I mean not
The sister of the charge of the 

===== Completion [2] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits for the head
To see him that live and dear at his country's head,
The sea of his crown, sir, the last of the ground
Of our highness and the sun of his presence
Of the singer of the steed service of her sense,
That she was not her of your grace was for his son.

KING RICHARD III:
What is the done? what thou art thou dost deserve?

DUKE OF

Some neat verse, from a neatly-scaled model. From which we can only conclude:

*Enter: unit scaling*

*Exeunt: loss scaling, automatic loss scaling, etc.*

FIN

---

We hope that practitioners will consider using unit scaling for future projects, particularly those having difficulties with loss scaling or automatic mixed precision. With FP8 on the horizon, these issues are likely to become more prevalent. We hope unit scaling can help.

If you're interested in using unit scaling yourself, or have questions, please do reach out 🙏☎️ We're keen to hear from anyone that has a problem unit scaling might help solve.

The definitions provided here are the closest we have to an "official" PyTorch implementation, but if there's demand for a library, tell us and we'll make one!