Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# Unit Scaling: A How-To Guide

In our paper [Unit Scaling: Out-of-the-Box Low-Precision Training](TODO), we describe a scheme for designing neural networks that have approximate unit variance after every operation in the forward and backward pass.

This can be seen as an alternative to (static) loss scaling, or its automatic variant, as used in Automatic Mixed Precision. Whereas both of those schemes rely on a single, global scaling factor for all the gradients, unit scaling is more fine-grained.

A unit-scaled model adds scaling factors (constant scalar multiplications) to each operation in the computational graph to achieve this unit variance property. The result is a model which naturally produces tensors in the middle of the dynamic range provided by floating-point formats. There's no extra loss-scale hyperparameter—it works out-of-the-box!
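
To make the idea of a constant scaling factor concrete, here is a minimal sketch (not taken from the paper) of a single matmul. The shapes are illustrative assumptions; the point is only that a matmul of unit-variance tensors grows the scale to roughly the square root of the inner dimension, and one fixed multiplication brings it back to roughly 1.

```python
import torch

torch.manual_seed(0)

fan_in = 384                    # assumed inner dimension (matches the hidden size used later)
x = torch.randn(4096, fan_in)   # unit-variance activations
w = torch.randn(fan_in, 512)    # unit-variance weights

unscaled = x @ w                      # scale grows to roughly sqrt(fan_in)
rescaled = (x @ w) * fan_in ** -0.5   # a constant scaling factor restores roughly unit scale

print(f"unscaled std = {unscaled.std():.2f}")
print(f"rescaled std = {rescaled.std():.2f}")
```

The factor depends only on the tensor shapes, so it can be fixed ahead of training rather than tuned.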

## Implementing unit scaling

Here we demonstrate how to go about unit scaling in practice. This involves re-implementing common neural network layers to add variance-preserving scaling factors.

As explained in the paper, we can sometimes justify using different scaling factors in the forward and backward pass. We introduce a special `scaled()` op which allows us to do just that.

Below we implement two models. The first is a simple implementation of a transformer-decoder. The design and hyperparameters are inspired by Andrej Karpathy's [popular NanoGPT implementation](https://github.com/karpathy/nanoGPT) (though some differ). It's also 🤗-compatible!

The second model is the same, but unit-scaled. Let's get stuck in...

## Building a unit-scaled NanoGPT

In [None]:
!pip install git+https://github.com/huggingface/optimum-graphcore.git
!pip install git+https://github.com/graphcore-research/poptorch-experimental-addons
!pip install altair

In [39]:
import math
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers.activations import GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel

### The MLP layer

#### Regular scaling

We'll start by setting up a basic config for a reasonably small transformer:

In [40]:
class NanoGPTConfig(PretrainedConfig):
    model_type = "nano-gpt"

    def __init__(
        self,
        hidden_size: int = 384,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 6,
        dropout: float = 0.1,
        vocab_size: int = 384,
        eos_token_id: int = 1,
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.vocab_size = vocab_size
        self.eos_token_id = eos_token_id
        super().__init__(**kwargs)

Along with a standard (pre-norm) MLP module:

In [41]:
class MLP(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size)
        self.linear_1 = nn.Linear(config.hidden_size, config.hidden_size * 4)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(config.hidden_size * 4, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.ln(hidden_states)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return self.dropout(hidden_states)

In [42]:
# TODO: hide
def init_weights(init_fn) -> None:
    def inner_fn(module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            fan_out, fan_in = module.weight.shape  # TODO: correct base impl
            module.weight.data.normal_(mean=0.0, std=init_fn(fan_in, fan_out))
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=1)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    return inner_fn

In [77]:
# TODO: hide
import pandas as pd

np.seterr(divide = 'ignore')

def instrument(module):
    stats = {}
    instrument_recursive(module, stats)
    return stats

def instrument_recursive(module, stats, name=''):
    children = list(module.named_children())
    if children:
        for c_name, c in children:
            _name = f'{name}.{c_name}' if name and name != 'blocks' else c_name
            instrument_recursive(c, stats, _name)
    else:
        instrument_terminal(module, stats, name)

def instrument_terminal(module, stats, name):
    module_stats = {}
    def require_input_grads(_module, input):
        for i in input:
            if isinstance(i, Tensor) and i.is_floating_point():
                i.requires_grad_()
    
    module.register_forward_pre_hook(require_input_grads)
    
    if name.split('.')[-1] == 'softmax':
        return
    
    def record_fwd_scale(_module, input, output):
        if isinstance(output, Tensor) and output.is_floating_point():
            module_stats['x'] = np.log2(output.std().item())
    
    module.register_forward_hook(record_fwd_scale)
    
    def record_bwd_scales(_module, grad_input, grad_output):
        grad_input = list(grad_input)
        for g in grad_input:
            if g is not None and isinstance(g, Tensor) \
                and g.is_floating_point() and len(grad_input) == 1:
                module_stats['grad_x'] = np.log2(g.std().item())
        
        for param_name, param in _module.named_parameters():
            if param_name == "weight":
                module_stats['w'] = np.log2(param.std().item())
                if param.grad is not None:
                    module_stats['grad_w'] = np.log2(param.grad.std().item())
    
    module.register_full_backward_hook(record_bwd_scales)
    
    stats[name] = module_stats

In [78]:
# TODO: hide
import altair as alt

def visualise(stats, subnormal=False):
    df = pd.DataFrame(stats)
    df = df.stack().to_frame('scale (log₂)').reset_index(names=['type', 'op'])
    plot(df, subnormal)

def plot(df, subnormal=False):
    is_x_or_grad_x = (df['type'] == 'x') | (df['type'] == 'grad_x')
    op_order = df[df['type'] == 'x']['op'].tolist()
    colors = ['#6C8EBF', '#FF8000', '#5D8944', '#ED3434']
    x_range = np.arange(-18 if subnormal else -14, 18+1 if subnormal else 16+1, 2)
    
    fp16_min = alt.Chart().mark_rule(strokeDash=(4, 4)).encode(x=alt.datum(-14))
    fp16_min_text = alt.Chart().mark_text(dy=-740).encode(text=alt.Text(value='Min FP16 (normal)'), x=alt.datum(-10))
    fp16_max = alt.Chart().mark_rule(strokeDash=(4, 4)).encode(x=alt.datum(16))
    fp16_max_text = alt.Chart().mark_text(dy=-740).encode(text=alt.Text(value='Max FP16'), x=alt.datum(13))
    
    x_chart = alt.Chart(df[is_x_or_grad_x]).mark_line().encode(
        x=alt.X(
            'scale (log₂):Q',
            axis=alt.Axis(orient='top', values=x_range),
            scale=alt.Scale(domain=[x_range[0], x_range[-1]]),
        ),
        y=alt.Y('op:O', title='', sort=op_order),
        color=alt.Color(
            'type',
            legend=alt.Legend(title='', labelFontSize=12, symbolSize=100),
            scale=alt.Scale(range=colors[:2]),
            sort='descending'
        ),
    )
    w_chart = alt.Chart(df[~is_x_or_grad_x]).mark_point(size=100).encode(
        x=alt.X(
            'scale (log₂):Q',
            axis=alt.Axis(orient='top', values=x_range),
            scale=alt.Scale(domain=[x_range[0], x_range[-1]])
        ),
        y=alt.Y('op:O', title='', sort=op_order),
        color=alt.Color(
            'type',
            legend=alt.Legend(title='', labelFontSize=12, symbolSize=100),
            scale=alt.Scale(range=colors[2:]),
            sort='descending'
        ),
        shape=alt.Shape(
            'type',
            scale=alt.Scale(range=['square', 'triangle-down']),
            sort='descending'
        ),
    )
    layers = [x_chart, w_chart]
    if subnormal:
        layers += [fp16_min, fp16_max, fp16_min_text, fp16_max_text] 
    combined_chart = alt.layer(
        *layers
    ).resolve_scale(
        color='independent', shape='independent'
    ).configure_axis(
        labelFontSize=12,
        titleFontSize=16
    ).properties(
        width=500
    )
    display(combined_chart)

Now we can analyse the scale (i.e. standard deviation) of each operation within the MLP. To do this, we provide an `instrument()` function, which walks through the module and records the scale coming out of each operation in the forward and backward pass.

Our network's weights will also use the standard Glorot initialisation.

Let's feed in a unit normal tensor in both directions, and examine the result.

In [79]:
def analyse_mlp(mlp, config, batch_size=64, seq_len=16):
    stats = instrument(mlp)
    x = torch.normal(0.0, 1.0, size=(batch_size, seq_len, config.hidden_size))
    y = mlp(x)
    y.backward(torch.normal(0.0, 1.0, size=y.shape))
    visualise(stats)

In [80]:
def glorot_init(fan_in, fan_out):
    return ((fan_in + fan_out) / 2) ** -0.5

config = NanoGPTConfig()
mlp = MLP(config).apply(init_weights(glorot_init))
analyse_mlp(mlp, config)

**How to interpret this chart:** The x-axis shows the scale in $\log_2$ form. The range here represents the minimum and maximum absolute values that can be represented in the FP16 or FP8 E5 number formats (excluding subnormal values). In other words, if the scale exceeds the x-axis bounds we're in the numerics "danger zone" where training begins to degrade.

The y-axis shows the operations in the model, in the order in which they execute. We show the scale of activations (x), gradients (grad_x, grad_w) and weights (w). Activations can be said to "flow" forward through these layers, and grad_xs flow backwards, so we represent these by solid lines, and weights and grad_ws by symbols.
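
The FP16 bounds mentioned above can be checked directly. The following is a small sketch using `torch.finfo`; it only inspects the number format and is independent of the model.

```python
import math
import torch

fp16 = torch.finfo(torch.float16)
print(f"min normal = 2^{math.log2(fp16.tiny):.0f}")   # smallest normal value
print(f"max        = 2^{math.log2(fp16.max):.2f}")    # largest finite value

# The smallest subnormal is 2^-24. Anything much smaller than this
# rounds to zero when cast to FP16.
representable = torch.tensor(2.0 ** -24, dtype=torch.float16)
flushed = torch.tensor(2.0 ** -26, dtype=torch.float16)
print(representable.item() > 0, flushed.item() == 0)
```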

By the end of both passes the x and grad_x scales have dropped by half. This is due to Glorot initialisation slightly under-scaling values, and GeLU dropping the scale further.

Worse, the weights are significantly under-scaled and the grad_ws over-scaled by a factor of $2^4$. This is largely because Glorot scaling (along with all other weight init schemes) only accounts for the forward and grad_x scales.
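
To see how far the weights themselves sit from unit scale, we can evaluate the Glorot standard deviation for the two linear layers in the MLP. This is a quick arithmetic sketch that repeats the `glorot_init()` formula and assumes the default `hidden_size` of 384.

```python
import math

def glorot_std(fan_in: int, fan_out: int) -> float:
    # Same formula as glorot_init() in the notebook
    return ((fan_in + fan_out) / 2) ** -0.5

hidden_size = 384  # default from NanoGPTConfig
for fan_in, fan_out in [(hidden_size, 4 * hidden_size), (4 * hidden_size, hidden_size)]:
    std = glorot_std(fan_in, fan_out)
    print(f"{fan_in:>4} -> {fan_out:<4}: std = {std:.4f}, log2(std) = {math.log2(std):.2f}")
```

Both layers come out at a scale of roughly $2^{-5}$, which is where the weight markers sit on the chart's log scale.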

#### Unit scaling

We'll now implement the equivalent layer using unit scaling, beginning with a basic linear operation:

In [81]:
# TODO: hide
class ScaledGrad(torch.autograd.Function):
  @staticmethod
  def forward(ctx, X, alpha, beta):
    ctx.save_for_backward(
      torch.tensor(beta, dtype=X.dtype))
    return alpha * X

  @staticmethod
  def backward(ctx, grad_Y):
    beta, = ctx.saved_tensors
    return beta * grad_Y, None, None

def scaled(X, alpha=1.0, beta=1.0):
  # Forward: Y = X * alpha
  # Backward: grad_X = grad_Y * beta
  return ScaledGrad.apply(X, alpha, beta)

def geometric_mean(xs):
    xs = np.array(xs)
    return xs.prod() ** (1 / xs.size)

In [82]:
class UnitScaledLinear(nn.Linear):
    def __init__(self, *args, scale_for="fwd, grad_x", **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scale_for = scale_for
    
    def get_scales(self, input):
        fwd_scale = self.weight.shape[1] ** -0.5
        grad_x_scale = self.weight.shape[0] ** -0.5
        grad_w_scale = np.prod(input.shape[:-1]) ** -0.5
        if self.scale_for == "fwd":
            grad_x_scale = fwd_scale
        elif self.scale_for == "grad_x":
            fwd_scale = grad_x_scale
        elif self.scale_for == "fwd, grad_x":
            fwd_scale = grad_x_scale = geometric_mean([fwd_scale, grad_x_scale])
        else:
            assert self.scale_for == "separate", f"demo implementation has no {self.scale_for} scaling"
        return fwd_scale, grad_x_scale, grad_w_scale
           
    def forward(self, input):
        fwd_scale, grad_x_scale, grad_w_scale = self.get_scales(input)
        input = scaled(input, beta=grad_x_scale)
        weight = scaled(self.weight, beta=grad_w_scale)
        bias = scaled(self.bias, beta=grad_w_scale) if self.bias is not None else None
        output = F.linear(input, weight, bias)
        return scaled(output, alpha=fwd_scale)

Let's break this down a bit. The `forward()` method is still based on the fundamental `F.linear(input, weight, bias)` operation, but has each of its operands and output scaled.

This is done via the `scaled(X, alpha, beta)` function. This is a special operation in that it behaves differently in the forward and backward pass, where we have

$$Y = X \cdot \alpha$$

$$\nabla_X = \nabla_Y \cdot \beta$$
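
As a quick check of these two equations, the sketch below uses a stand-in autograd function (named `_ScaledSketch` here purely for illustration, so that the snippet runs on its own) and confirms that the forward values are multiplied by alpha while the gradient is multiplied by beta.

```python
import torch

class _ScaledSketch(torch.autograd.Function):
    # Illustrative stand-in for the notebook's ScaledGrad
    @staticmethod
    def forward(ctx, x, alpha, beta):
        ctx.beta = beta
        return alpha * x

    @staticmethod
    def backward(ctx, grad_y):
        # One gradient per forward argument; alpha and beta get none
        return ctx.beta * grad_y, None, None

x = torch.ones(3, requires_grad=True)
y = _ScaledSketch.apply(x, 2.0, 0.5)   # alpha=2.0, beta=0.5
y.sum().backward()

print(y.detach())   # forward values are multiplied by alpha
print(x.grad)       # gradients are multiplied by beta, not alpha
```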

So what scaling factors do we choose? This is determined in `get_scales()`. The standard approach is to set the forward and grad_x scales as a compromise between their "ideal" scale-preserving values (via a geometric mean). See our paper for more details on how we arrive at these ideal values.

We also provide versions of the operation which select forward and grad_x scales based on only one of their ideal values. We also have a scheme which has separate forward and grad_x scales, though this is only allowed in special circumstances (again, see the paper).
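
As a worked example of the compromise, take the first linear layer of the MLP with the default config (384 inputs, 1536 outputs). The arithmetic below is a sketch that assumes unit-scale inputs, weights and incoming gradients.

```python
import math

fan_in, fan_out = 384, 1536  # shape of linear_1 with the default config

ideal_fwd = fan_in ** -0.5       # would keep the output at unit scale
ideal_grad_x = fan_out ** -0.5   # would keep the input gradient at unit scale
compromise = math.sqrt(ideal_fwd * ideal_grad_x)  # geometric mean of the two

# Resulting scales when the single compromise factor is used in both directions
fwd_std = fan_in ** 0.5 * compromise
grad_x_std = fan_out ** 0.5 * compromise
print(f"compromise = {compromise:.4f}")
print(f"output std = {fwd_std:.3f}, grad_x std = {grad_x_std:.3f}")
```

The output lands slightly below unit scale and grad_x slightly above it by the same factor, so neither direction is favoured.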

Now for the GeLU. This is pretty similar to the linear op. However, as activation functions are nonlinear, our usual approach of calculating scaling values analytically (i.e. by working through the maths) isn't always possible.

Fortunately, for these elementwise ops we can calculate them empirically, like so:

In [83]:
def analyse_elementwise_fn(fn, num_samples=2**22):
    x = torch.normal(0.0, 1.0, size=(num_samples,)).requires_grad_()
    y = fn(x)
    y.backward(torch.normal(0.0, 1.0, size=(num_samples,)))
    print(f"fwd scale={y.std():.3f}, bwd scale={x.grad.std():.3f}")

print("GeLU:", end=" ")
analyse_elementwise_fn(GELUActivation())
print("Tanh:", end=" ")
analyse_elementwise_fn(nn.Tanh())

GeLU: fwd scale=0.588, bwd scale=0.675
Tanh: fwd scale=0.628, bwd scale=0.682


We then put these scaling factors into our unit-scaled op as follows:

In [84]:
class UnitScaledGELU(GELUActivation):
    def forward(self, input):
        fwd_scale = bwd_scale = geometric_mean([0.588, 0.675]) ** -1
        input = scaled(input, beta=bwd_scale)
        output = self.act(input)
        return scaled(output, alpha=fwd_scale)

Simple! Now we just need to unit scale the layernorm and dropout and we're ready.

These take the same kind of approach, so we won't go into too much detail here:

In [85]:
class UnitScaledLayerNorm(nn.LayerNorm):
    def forward(self, input: Tensor) -> Tensor:
        scale = (np.prod(self.normalized_shape) / input.nelement()) ** 0.5
        weight = scaled(self.weight, beta=scale)
        bias = scaled(self.bias, beta=scale)
        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

class UnitScaledDropout(nn.Dropout):
    def forward(self, input: Tensor) -> Tensor:
        # Dropout is typically implemented with a (1-p) ** -1 scaling
        # However, to preserve variance this ought to be (1-p) ** -0.5
        # We correct for this by multiplying by (1-p) ** 0.5
        scale = (1-self.p) ** 0.5
        input = scaled(input, beta=scale)
        output = F.dropout(input, self.p, self.training, self.inplace)
        return scaled(output, alpha=scale)
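
The dropout correction described in the comments above can be verified empirically. This is a small sketch on a unit normal tensor, assuming the default dropout probability of 0.1 and looking at the forward pass only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.1
x = torch.randn(2 ** 20)

standard = F.dropout(x, p, training=True)   # surviving values are scaled by 1 / (1 - p)
corrected = standard * (1 - p) ** 0.5       # net scaling of (1 - p) ** -0.5

print(f"standard std  = {standard.std():.3f}")
print(f"corrected std = {corrected.std():.3f}")
```

Standard dropout increases the scale slightly, and the extra `(1 - p) ** 0.5` factor brings it back to approximately 1.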

We're now ready to unit-scale the full MLP 🥳

All this requires is swapping out the old layers for our new ones:

In [86]:
class UnitScaledMLP(MLP):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.linear_1 = UnitScaledLinear(config.hidden_size, config.hidden_size * 4)
        self.act = UnitScaledGELU()
        self.linear_2 = UnitScaledLinear(config.hidden_size * 4, config.hidden_size)
        self.dropout = UnitScaledDropout(config.dropout)

We also change our initialisation to give our weights unit scale. Let's analyse our new unit-scaled MLP:

In [88]:
def unit_init(*args):
    return 1

mlp = UnitScaledMLP(config).apply(init_weights(unit_init))
analyse_mlp(mlp, config)

This is much better! The final scales are almost exactly 1 in both directions, and weights and grad_ws look much better.

The compromise fwd & grad_x scaling factors in our linear layers do give temporary non-unit scaling, but this is minor and the second linear layer cancels out the first.
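
The cancellation between the two linear layers can be seen in a simplified sketch. Here we stack the two matmuls with their compromise scale and deliberately leave out the layernorm, GeLU and dropout, so this is an illustration of the forward scales under those assumptions rather than a reproduction of the chart.

```python
import torch

torch.manual_seed(0)
hidden = 384                              # default hidden_size
x = torch.randn(1024, hidden)             # unit-scale input
w1 = torch.randn(hidden, 4 * hidden)      # unit-scale weights, as with unit_init
w2 = torch.randn(4 * hidden, hidden)
scale = (hidden * 4 * hidden) ** -0.25    # geometric-mean scale, the same for both layers

h = (x @ w1) * scale   # after the first linear layer: below unit scale
y = (h @ w2) * scale   # after the second linear layer: back to roughly unit scale
h_std, y_std = h.std().item(), y.std().item()
print(f"after linear_1: {h_std:.3f}, after linear_2: {y_std:.3f}")
```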

### The self-attention layer

#### Regular scaling

We start with a basic implementation of a (self-)attention module.

(Note that we use [ALiBi](https://arxiv.org/abs/2108.12409) biases over positional embeddings here. This is primarily for simplicity, though these have also been [shown to perform](https://arxiv.org/abs/2210.15424) remarkably well!)

In [89]:
# TODO: hide
class Matmul(nn.Module):
    def forward(self, a: Tensor, b: Tensor, scale: float=1.0) -> Tensor:
        return (a @ b) * scale

def split_heads(tensor: Tensor, num_heads: int, head_size: int) -> Tensor:
    batch_size, seq_len, hidden_size = tensor.shape
    tensor = tensor.view(batch_size, seq_len, num_heads, head_size)
    return tensor.permute(0, 2, 1, 3)


def merge_heads(tensor: Tensor, num_heads: int) -> Tensor:
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    batch_size, seq_len, num_heads, head_size = tensor.shape
    return tensor.view(batch_size, seq_len, num_heads * head_size)


def causal_mask(seq_len: int, num_heads: int) -> Tensor:
    causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.float16))
    causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
    alibi_mask = gen_alibi_mask(causal_mask, num_heads)
    causal_mask = (1.0 - causal_mask) * -10_000
    return alibi_mask + causal_mask


# Based on https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/alibi/__init__.py
def gen_alibi_mask(causal_mask: Tensor, num_heads: int) -> Tensor:
    distances = causal_mask.to(torch.float32).cumsum(dim=-1)
    slopes = gen_slopes(num_heads)
    return distances.to(torch.float16) * slopes.view(1, num_heads, 1, 1)


def gen_slopes(num_heads: int) -> Tensor:
    n = 2 ** math.floor(math.log2(num_heads))
    m_0 = 2.0 ** (-8.0 / n)
    m = torch.pow(m_0, torch.arange(1, 1 + n))
    if n < num_heads:
        m_hat_0 = 2.0 ** (-4.0 / n)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        m = torch.cat([m, m_hat])
    return m

In [90]:
class Attention(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size)
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        
        self.w_qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.qk_matmul = Matmul()
        self.softmax = nn.Softmax(-1)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.qkv_matmul = Matmul()
        self.w_o = nn.Linear(config.hidden_size, config.hidden_size)
        self.residual_dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        seq_len = hidden_states.shape[1]
        head_size = self.hidden_size // self.num_heads

        hidden_states = self.ln(hidden_states)
        q_k_v = self.w_qkv(hidden_states)
        q, k, v = q_k_v.split(self.hidden_size, dim=-1)
        q, k, v = (split_heads(t, self.num_heads, head_size) for t in (q, k, v))

        qk = self.qk_matmul(q, k.transpose(-1, -2), scale=1 / head_size ** 0.5).clone()
        qk += causal_mask(seq_len, self.num_heads) + attention_mask
        qk = self.softmax(qk)
        qk = self.attn_dropout(qk)

        qkv = self.qkv_matmul(qk, v)
        qkv = merge_heads(qkv, self.num_heads)
        qkvo = self.w_o(qkv)
        return self.residual_dropout(qkvo).clone()

In [91]:
def analyse_attn(attention, config, batch_size=64, seq_len=16):
    stats = instrument(attention)
    x = torch.normal(0.0, 1.0, size=(batch_size, seq_len, config.hidden_size))
    attention_mask = torch.zeros(batch_size, 1, 1, seq_len)
    y = attention(x, attention_mask)
    y.backward(torch.normal(0.0, 1.0, size=y.shape))
    visualise(stats)

In [92]:
attention = Attention(config).apply(init_weights(glorot_init))
analyse_attn(attention, config)

Here again, the activation and grad_x scales fall by half as they pass through the layer, and fluctuate in between. (Note that we omit the softmax operation from the chart because a) its output isn't normally distributed, and b) we usually compute it in higher precision anyway.) We see the same problems as before for the weights and grad_ws.

Let's fix this:

#### Unit scaling

In [93]:
class UnitScaledAttention(Attention):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.w_qkv = UnitScaledLinear(
            config.hidden_size, 3 * config.hidden_size, scale_for="fwd"
        )
        self.qk_matmul = UnitScaledMatmul(scale_for="fwd")
        self.softmax = UnitScaledSoftmax(-1)
        self.attn_dropout = UnitScaledDropout(config.dropout)
        self.qkv_matmul = UnitScaledMatmul(scale_for="fwd")
        self.w_o = UnitScaledLinear(config.hidden_size, config.hidden_size)
        self.residual_dropout = UnitScaledDropout(config.dropout)

class UnitScaledMatmul(Matmul):
    def __init__(self, scale_for="fwd, grad_a, grad_b") -> None:
        super().__init__()
        self.scale_for = scale_for
    
    def get_scales(self, a: Tensor, b: Tensor):
        fwd_scale = a.shape[-1] ** -0.5
        grad_a_scale = b.shape[-1] ** -0.5
        grad_b_scale = a.shape[-2] ** -0.5
        if self.scale_for == "fwd":
            grad_a_scale = grad_b_scale = fwd_scale
        elif self.scale_for == "fwd, grad_a, grad_b":
            fwd_scale = grad_a_scale = grad_b_scale = geometric_mean([
                fwd_scale, grad_a_scale, grad_b_scale
            ])
        else:
            assert False, f"demo implementation has no {self.scale_for} scaling"
        return fwd_scale, grad_a_scale, grad_b_scale
    
    def forward(self, a: Tensor, b: Tensor, scale: float=1.0) -> Tensor:
        # ignores provided scale
        fwd_scale, grad_a_scale, grad_b_scale = self.get_scales(a, b)
        a = scaled(a, beta=grad_a_scale)
        b = scaled(b, beta=grad_b_scale)
        output = a @ b
        return scaled(output, alpha=fwd_scale)

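class UnitScaledSoftmax(nn.Softmax):
    def forward(self, input: Tensor) -> Tensor:
        fwd_scale = bwd_scale = input.shape[-1] ** 0.5
        # Both scales are passed positionally, i.e. as `alpha`: they multiply
        # the forward values only, and gradients pass through unscaled (beta=1)
        input = scaled(input, fwd_scale)
        output = F.softmax(input, self.dim, _stacklevel=5)
        return scaled(output, bwd_scale)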

To unit scale the attention layer we again swap out regular layers for unit-scaled ones. This requires altering two more operations: matmul and softmax.
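
For the matmul, the three ideal scales computed in `UnitScaledMatmul.get_scales()` follow from which dimension each product sums over. The sketch below checks this empirically for an unscaled matmul; the shapes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
m, k, n = 256, 1024, 64   # illustrative shapes: a is (m, k), b is (k, n)
a = torch.randn(m, k, requires_grad=True)
b = torch.randn(k, n, requires_grad=True)

out = a @ b
out.backward(torch.randn(m, n))   # unit-scale incoming gradient

out_std = out.detach().std().item()
grad_a_std = a.grad.std().item()
grad_b_std = b.grad.std().item()
print(f"out std    = {out_std:.1f} (sqrt(k) = {k ** 0.5:.1f})")
print(f"grad_a std = {grad_a_std:.1f} (sqrt(n) = {n ** 0.5:.1f})")
print(f"grad_b std = {grad_b_std:.1f} (sqrt(m) = {m ** 0.5:.1f})")
```

Each measured scale is close to the square root of the summed dimension, so the inverse of each is the corresponding ideal scaling factor.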

As with the linear layer, we provide several options for scaling the matmul. We find that scaling each matmul and linear layer here for the forward pass ensures good scaling in both directions. We see this empirically in our analysis:

In [94]:
attention = UnitScaledAttention(config).apply(init_weights(unit_init))
analyse_attn(attention, config, batch_size=64, seq_len=64)

Again, a great improvement for all four types of tensors, with each having approximate unit scale. Note that these scaling rules are robust to changes in hyperparameters too. You can experiment with the initial config to verify this.

We now have our key building-blocks 🧱 Let's put it all together!

### The transformer block

#### Regular scaling

In [60]:
class TransformerBlock(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.attention_layer = Residual(Attention(config))
        self.mlp_layer = Residual(MLP(config))
        
    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        hidden_states = self.attention_layer(hidden_states, attention_mask)
        return self.mlp_layer(hidden_states)


class Residual(nn.Module):
    def __init__(self, f: nn.Module):
        super().__init__()
        self.f = f
    
    def forward(self, x: Tensor, *args):
        return x + self.f(x, *args)

#### Unit scaling

In [61]:
class UnitScaledTransformerBlock(TransformerBlock):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.attention_layer = UnitScaledResidual(UnitScaledAttention(config))
        self.mlp_layer = UnitScaledResidual(UnitScaledMLP(config))

class UnitScaledResidual(Residual):
    def __init__(self, f: nn.Module, tau: float=0.2):
        super().__init__(f)
        self.tau = tau
    
    def forward(self, x: Tensor, *args):
        y = x * (1 - self.tau) ** 0.5
        z = scaled(x, beta=self.tau ** 0.5)
        z = self.f(z, *args)
        z = scaled(z, alpha=self.tau ** 0.5)
        return y + z

Note that the scaling in our residual layer includes an important trick.

It's necessary for unit-scaled models to use a weighted sum when doing their residual-add, to down-weight the residual branch relative to the skip connection. If this is not included, training can fail as too much signal comes from each residual branch (regular models avoid this as their residual branches implicitly reduce scale).

However, the naïve implementation of this breaks unit scale. We get around this by delaying the weighting until the end of the residual branch in the backward pass (see paper for more details).
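
The forward-pass effect of the weighted sum can be sketched with random tensors. We assume here that each residual branch produces a unit-scale output that is independent of the skip connection, which is a simplification, and we stack six adds as an arbitrary choice of depth.

```python
import torch

torch.manual_seed(0)
tau = 0.2    # default used by UnitScaledResidual
depth = 6    # number of stacked residual adds (illustrative)

x_plain = torch.randn(2 ** 18)
x_weighted = x_plain.clone()
for _ in range(depth):
    branch = torch.randn(2 ** 18)   # stand-in for a unit-scale residual branch output
    x_plain = x_plain + branch
    x_weighted = (1 - tau) ** 0.5 * x_weighted + tau ** 0.5 * branch

print(f"plain add std    = {x_plain.std():.2f}")     # grows like sqrt(depth + 1)
print(f"weighted add std = {x_weighted.std():.2f}")  # stays close to 1
```

Because the squared weights `(1 - tau)` and `tau` sum to 1, the variance is preserved at every add regardless of depth.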

### The full transformer

#### Regular scaling

Note that our implementation contains a little extra boilerplate to make it 🤗-compliant.

In [62]:
class NanoGPTModel(PreTrainedModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.input_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.ln = nn.LayerNorm(config.hidden_size)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.loss_fn = nn.CrossEntropyLoss()
        self.apply(init_weights(glorot_init))

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        position_ids: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithCrossAttentions:
        attention_mask = attention_mask[:, None, None, :].to(dtype=self.dtype)
        attention_mask = (1.0 - attention_mask) * -10_000

        hidden_states = self.input_embeddings(input_ids)
        hidden_states = self.dropout(hidden_states)
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln(hidden_states)
        logits = self.lm_head(hidden_states)

        if labels is None:
            return CausalLMOutputWithCrossAttentions(logits=logits)
        else:
            labels = torch.roll(labels, -1, 1)
            labels[:, -1] = -100  # By default, ignore_index of CrossEntropyLoss is -100
            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            return CausalLMOutputWithCrossAttentions(loss=loss)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.input_embeddings

    def prepare_inputs_for_generation(
        self, input_ids: Tensor, **kwargs
    ) -> Dict[str, Tensor]:
        attention_mask = kwargs["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

In [63]:
def analyse_full_model(model, config, batch_size=64, seq_len=16):
    stats = instrument(model)
    input_ids = labels = torch.randint(0, config.vocab_size, size=(batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)  # 1 = attend; the model maps 0 to a -10,000 bias
    y = model(input_ids, attention_mask, labels=labels).loss
    y.backward()
    visualise(stats, subnormal=True)

Here's the scaling for the entire (non-unit-scaled) transformer:

In [64]:
model = NanoGPTModel(config)
analyse_full_model(model, config)

These results clearly demonstrate the inadequacy of the standard approach to model design when it comes to scale. Grad_x and grad_w values are very far from having unit scale, with the former now relying on subnormal values (dangerous, as these begin to lose precision and disappear altogether below $2^{-24}$).

A key issue introduced by this layer is the cross-entropy loss definition. This typically uses a `1 / batch_size` term to average over labels, which has the effect of dramatically under-scaling gradients in the backward pass.
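
The size of this effect is easy to measure in isolation. The sketch below feeds unit-scale random logits through the standard mean-reduced cross-entropy and reports the scale of the gradient reaching the logits. The shapes assume 64 sequences of 16 tokens and the default vocabulary size of 384, matching the analysis above.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, vocab_size = 64 * 16, 384
logits = torch.randn(num_tokens, vocab_size, requires_grad=True)
targets = torch.randint(0, vocab_size, (num_tokens,))

loss = F.cross_entropy(logits, targets)   # default reduction="mean" divides by num_tokens
loss.backward()

grad_std = logits.grad.std().item()
print(f"grad std = {grad_std:.2e}, log2 = {math.log2(grad_std):.1f}")
```

Under these assumptions the gradient already starts at or below the FP16 minimum normal value before it has passed through a single layer.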

We can also see here that a single loss scaling factor for all gradients isn't ideal. This would shift all grad_xs and grad_ws to the right by the same amount, which would still leave us far short of our unit-scale ideal.

What we really need is per-op scaling! Again, we'll fix this for the unit-scaled implementation:

#### Unit scaling

In [65]:
class UnitScaledNanoGPTModel(NanoGPTModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.input_embeddings = UnitScaledEmbedding(config.vocab_size, config.hidden_size)
        self.dropout = UnitScaledDropout(config.dropout)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.blocks = nn.ModuleList([
            UnitScaledTransformerBlock(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.lm_head = UnitScaledLinear(
            config.hidden_size, config.vocab_size, bias=False, scale_for="separate"
        )
        self.loss_fn = UnitScaledCrossEntropyLoss()
        self.apply(init_weights(unit_init))

class UnitScaledEmbedding(nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        batch_size = np.prod(input.shape)
        weight = scaled(self.weight, beta=self.num_embeddings / batch_size)
        return F.embedding(input, weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq, self.sparse)

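class UnitScaledCrossEntropyLoss(nn.CrossEntropyLoss):
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # input holds the flattened logits, with shape (num_tokens, vocab_size)
        num_tokens, vocab_size = input.shape
        input = scaled(input, beta=vocab_size / (vocab_size - 1) ** 0.5)
        loss = F.cross_entropy(input, target, reduction='sum')
        # Average in the forward pass only; the backward pass skips this factor
        return scaled(loss, alpha=1 / num_tokens)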

In [66]:
model = UnitScaledNanoGPTModel(config)
analyse_full_model(model, config)

Success! 🥂 🎊 🍾 We've managed to keep every tensor in the model at unit scale for the forward and backward pass. All that was required was to re-implement standard layers with the right scaling factors.

Of course, this only gives us the right scaling *at initialisation*. Scales inevitably drift throughout training, but we've given ourselves headroom by starting in the ideal place. The results in our paper show that this is sufficient to enable accurate, out-of-the-box training of many models, up to the size of BERT Large (we haven't tested anything larger, yet!).

## Training

#### Regular scaling

To show that this really does work, let's train both our models.

As in Karpathy's original NanoGPT, we'll train on the small Shakespeare dataset for a few minutes and see if we can get something that captures the rough style. Starting with the regular (non-unit-scaled) model:

In [67]:
# TODO: hide
from optimum.graphcore.generation_utils import IPUGenerationMixin
from optimum.graphcore.modeling_utils import (
    PipelineMixin,
    outline_attribute,
    register,
    tied_weight_model,
)

@tied_weight_model(NanoGPTModel)
@register(NanoGPTModel)
class PipelinedNanoGPTModel(NanoGPTModel, PipelineMixin, IPUGenerationMixin):
    def parallelize(self):
        self._hooks = [outline_attribute(self.ln, "LayerNorm")]

    def deparallelize(self):
        pass

@tied_weight_model(UnitScaledNanoGPTModel)
@register(UnitScaledNanoGPTModel)
class PipelinedUnitScaledNanoGPTModel(
    UnitScaledNanoGPTModel, PipelineMixin, IPUGenerationMixin
):
    def parallelize(self):
        self._hooks = [outline_attribute(self.ln, "UnitScaledLayerNorm")]

    def deparallelize(self):
        pass

In [68]:
import poptorch_experimental_addons as pea

def scaled(X, alpha=1.0, beta=1.0):
    # Forward: Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    return pea.autograd_proxy(X * alpha, X * beta)
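
The `scaled()` op above relies on `pea.autograd_proxy`, which targets the IPU. For experimenting on CPU, the same forward/backward behaviour can be sketched with a custom `torch.autograd.Function`. This is an assumed stand-in for illustration (the names `_ScaledFn` and `scaled_reference` are hypothetical), not the implementation used in this notebook.

```python
import torch


class _ScaledFn(torch.autograd.Function):
    """Multiply by `alpha` in the forward pass and by `beta` in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha, beta):
        ctx.beta = beta  # remember the backward factor
        return x * alpha

    @staticmethod
    def backward(ctx, grad_y):
        # One return value per forward input; alpha and beta get no gradient.
        return grad_y * ctx.beta, None, None


def scaled_reference(x, alpha=1.0, beta=1.0):
    # Hypothetical CPU-friendly equivalent of `scaled()`.
    return _ScaledFn.apply(x, alpha, beta)


x = torch.ones(4, requires_grad=True)
y = scaled_reference(x, alpha=2.0, beta=0.5)
y.sum().backward()
print(y.detach().tolist())  # forward values, scaled by alpha
print(x.grad.tolist())      # gradients, scaled by beta (not alpha)
```

The point to notice is that the gradient is scaled by `beta` alone: the forward factor `alpha` does not leak into the backward pass, which is what lets the two directions be tuned separately.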

In [69]:
from functools import partial

from datasets.load import load_dataset
from optimum.graphcore import (
    IPUConfig,
    IPUTrainer,
    IPUTrainingArguments,
    pipeline,
    pipelines,
)
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

In [70]:
dataset = load_dataset("tiny_shakespeare")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

config = NanoGPTConfig(vocab_size=len(tokenizer), eos_token_id=tokenizer.eos_token_id)
batch_sz = 16
seq_len = 128
ipu_config = IPUConfig(
    gradient_accumulation_steps=5 * 64 // batch_sz,
    layers_per_ipu=[config.num_hidden_layers],
    executable_cache_dir="./exe_cache",
)

def split_and_tokenize(data, seq_len, batch_sz):
    tokens = tokenizer(data["text"])
    seqs = [
        tokens["input_ids"][0][i : i + seq_len]
        for i in range(0, len(tokens["input_ids"][0]), seq_len)
    ]
    seqs = seqs[: int(len(seqs) / batch_sz) * batch_sz]  # make divisible by batch size
    return {"input_ids": seqs}


prep_data = partial(split_and_tokenize, seq_len=seq_len, batch_sz=batch_sz)
tokenized_dataset = dataset.map(
    prep_data, batched=True, remove_columns=dataset["train"].column_names
)
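
The chunking logic in `split_and_tokenize` is easier to follow on a toy input. The sketch below repeats the same two steps (slice the token stream into `seq_len`-sized pieces, then drop trailing pieces so the count is a multiple of the batch size) on a list of integers. The helper `chunk_tokens` and the toy sizes are illustrative assumptions, not part of the notebook.

```python
def chunk_tokens(token_ids, seq_len, batch_sz):
    """Split a flat token list into sequences and trim to a multiple of batch_sz."""
    seqs = [token_ids[i : i + seq_len] for i in range(0, len(token_ids), seq_len)]
    # Keep only as many sequences as fill whole batches.
    return seqs[: len(seqs) // batch_sz * batch_sz]


toy_tokens = list(range(23))  # stand-in for tokenizer output
chunks = chunk_tokens(toy_tokens, seq_len=5, batch_sz=2)

for chunk in chunks:
    print(chunk)
print(f"{len(chunks)} sequences kept")
```

Here 23 tokens give five pieces, the last of which holds only three tokens; trimming to a multiple of two keeps the first four, so the short piece happens to be dropped. Note that this is not guaranteed in general: if the number of pieces is already a multiple of the batch size, a short final piece is retained.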



In [71]:
train_args = IPUTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=batch_sz,
    per_device_eval_batch_size=batch_sz,
    evaluation_strategy="steps",
    eval_steps=250,
    logging_steps=20,
    max_steps=1000,
    weight_decay=0.1,
    warmup_steps=100,
    lr_scheduler_type="linear",
    learning_rate=1e-3,
    loss_scaling=2**6,
)

model = NanoGPTModel(config)

trainer = IPUTrainer(
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    ipu_config=ipu_config,
)

trainer.train()
trainer.save_model("trained_model/")

Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 5.3185042690020055 secs
***** Running training *****
  Num examples = 7840
  Num Epochs = 42
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 320
  Gradient Accumulation steps = 20
  Total optimization steps = 1000
  2%|▏         | 21/1000 [00:02<02:16,  7.18it/s]

{'loss': 4.4369, 'learning_rate': 0.0002, 'epoch': 0.83}


  4%|▍         | 41/1000 [00:05<02:02,  7.82it/s]

{'loss': 2.7566, 'learning_rate': 0.0004, 'epoch': 1.67}


  6%|▌         | 61/1000 [00:08<02:00,  7.82it/s]

{'loss': 2.4258, 'learning_rate': 0.0006, 'epoch': 2.5}


  8%|▊         | 81/1000 [00:11<02:07,  7.23it/s]

{'loss': 2.2702, 'learning_rate': 0.0008, 'epoch': 3.33}


 10%|█         | 101/1000 [00:13<02:14,  6.70it/s]

{'loss': 2.1507, 'learning_rate': 0.001, 'epoch': 4.17}


 12%|█▏        | 121/1000 [00:16<02:02,  7.20it/s]

{'loss': 2.0396, 'learning_rate': 0.0009777777777777777, 'epoch': 5.0}


 14%|█▍        | 141/1000 [00:19<02:01,  7.06it/s]

{'loss': 1.9657, 'learning_rate': 0.0009555555555555556, 'epoch': 5.83}


 16%|█▌        | 161/1000 [00:21<01:56,  7.20it/s]

{'loss': 1.8925, 'learning_rate': 0.0009333333333333333, 'epoch': 6.67}


 18%|█▊        | 181/1000 [00:24<02:00,  6.78it/s]

{'loss': 1.8139, 'learning_rate': 0.0009111111111111111, 'epoch': 7.5}


 20%|██        | 201/1000 [00:27<01:42,  7.82it/s]

{'loss': 1.8059, 'learning_rate': 0.0008888888888888888, 'epoch': 8.33}


 22%|██▏       | 221/1000 [00:30<01:44,  7.48it/s]

{'loss': 1.7422, 'learning_rate': 0.0008666666666666667, 'epoch': 9.17}


 24%|██▍       | 241/1000 [00:32<01:46,  7.15it/s]

{'loss': 1.733, 'learning_rate': 0.0008444444444444444, 'epoch': 10.0}


 25%|██▌       | 250/1000 [00:33<01:38,  7.61it/s]Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:00<00:00]
Compiled/Loaded model in 4.082931288983673 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 25%|██▌       | 250/1000 [00:40<01:38,  7.61it/s]

{'eval_loss': 1.7421875, 'eval_runtime': 0.4696, 'eval_samples_per_second': 919.964, 'eval_steps_per_second': 57.498, 'epoch': 10.42}


Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
 26%|██▌       | 261/1000 [00:47<03:01,  4.06it/s]

{'loss': 1.6966, 'learning_rate': 0.0008222222222222222, 'epoch': 10.83}


 28%|██▊       | 281/1000 [00:50<01:42,  7.05it/s]

{'loss': 1.6454, 'learning_rate': 0.0008, 'epoch': 11.67}


 30%|███       | 301/1000 [00:53<01:31,  7.63it/s]

{'loss': 1.6402, 'learning_rate': 0.0007777777777777778, 'epoch': 12.5}


 32%|███▏      | 321/1000 [00:56<01:42,  6.63it/s]

{'loss': 1.6077, 'learning_rate': 0.0007555555555555555, 'epoch': 13.33}


 34%|███▍      | 341/1000 [00:58<01:27,  7.53it/s]

{'loss': 1.5972, 'learning_rate': 0.0007333333333333333, 'epoch': 14.17}


 36%|███▌      | 361/1000 [01:01<01:25,  7.47it/s]

{'loss': 1.5686, 'learning_rate': 0.0007111111111111111, 'epoch': 15.0}


 38%|███▊      | 381/1000 [01:04<01:28,  7.03it/s]

{'loss': 1.5528, 'learning_rate': 0.000688888888888889, 'epoch': 15.83}


 40%|████      | 401/1000 [01:07<01:19,  7.49it/s]

{'loss': 1.5522, 'learning_rate': 0.0006666666666666666, 'epoch': 16.67}


 42%|████▏     | 421/1000 [01:09<01:19,  7.30it/s]

{'loss': 1.5131, 'learning_rate': 0.0006444444444444444, 'epoch': 17.5}


 44%|████▍     | 441/1000 [01:12<01:19,  7.07it/s]

{'loss': 1.5112, 'learning_rate': 0.0006222222222222223, 'epoch': 18.33}


 46%|████▌     | 461/1000 [01:15<01:10,  7.69it/s]

{'loss': 1.5121, 'learning_rate': 0.0006, 'epoch': 19.17}


 48%|████▊     | 481/1000 [01:17<01:09,  7.49it/s]

{'loss': 1.4964, 'learning_rate': 0.0005777777777777778, 'epoch': 20.0}


 50%|█████     | 500/1000 [01:20<01:04,  7.71it/s]

{'loss': 1.4818, 'learning_rate': 0.0005555555555555556, 'epoch': 20.83}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:00<00:00]
Compiled/Loaded model in 3.906453931995202 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 50%|█████     | 500/1000 [01:29<01:04,  7.71it/s]Saving model checkpoint to out/checkpoint-500
Configuration saved in out/checkpoint-500/ipu_config.json


{'eval_loss': 1.56640625, 'eval_runtime': 0.4722, 'eval_samples_per_second': 914.934, 'eval_steps_per_second': 57.183, 'epoch': 20.83}


Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
 52%|█████▏    | 521/1000 [01:38<01:14,  6.40it/s]

{'loss': 1.4603, 'learning_rate': 0.0005333333333333334, 'epoch': 21.67}


 54%|█████▍    | 541/1000 [01:41<01:08,  6.69it/s]

{'loss': 1.4605, 'learning_rate': 0.0005111111111111111, 'epoch': 22.5}


 56%|█████▌    | 561/1000 [01:44<01:05,  6.72it/s]

{'loss': 1.444, 'learning_rate': 0.0004888888888888889, 'epoch': 23.33}


 58%|█████▊    | 581/1000 [01:46<00:54,  7.68it/s]

{'loss': 1.4429, 'learning_rate': 0.00046666666666666666, 'epoch': 24.17}


 60%|██████    | 601/1000 [01:49<00:59,  6.71it/s]

{'loss': 1.4514, 'learning_rate': 0.0004444444444444444, 'epoch': 25.0}


 62%|██████▏   | 621/1000 [01:52<00:54,  7.01it/s]

{'loss': 1.4254, 'learning_rate': 0.0004222222222222222, 'epoch': 25.83}


 64%|██████▍   | 641/1000 [01:54<00:52,  6.85it/s]

{'loss': 1.4112, 'learning_rate': 0.0004, 'epoch': 26.67}


 66%|██████▌   | 661/1000 [01:57<00:50,  6.75it/s]

{'loss': 1.4068, 'learning_rate': 0.00037777777777777777, 'epoch': 27.5}


 68%|██████▊   | 681/1000 [02:00<00:44,  7.13it/s]

{'loss': 1.4147, 'learning_rate': 0.00035555555555555557, 'epoch': 28.33}


 70%|███████   | 701/1000 [02:03<00:42,  7.07it/s]

{'loss': 1.3746, 'learning_rate': 0.0003333333333333333, 'epoch': 29.17}


 72%|███████▏  | 721/1000 [02:06<00:43,  6.45it/s]

{'loss': 1.3948, 'learning_rate': 0.0003111111111111111, 'epoch': 30.0}


 74%|███████▍  | 741/1000 [02:09<00:36,  7.03it/s]

{'loss': 1.3934, 'learning_rate': 0.0002888888888888889, 'epoch': 30.83}


 75%|███████▌  | 750/1000 [02:10<00:35,  7.02it/s]Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:00<00:00]
Compiled/Loaded model in 3.861648246005643 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 75%|███████▌  | 750/1000 [02:16<00:35,  7.02it/s]

{'eval_loss': 1.5185546875, 'eval_runtime': 0.4588, 'eval_samples_per_second': 941.633, 'eval_steps_per_second': 58.852, 'epoch': 31.25}


Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
 76%|███████▌  | 761/1000 [02:24<01:02,  3.81it/s]

{'loss': 1.3722, 'learning_rate': 0.0002666666666666667, 'epoch': 31.67}


 78%|███████▊  | 781/1000 [02:27<00:31,  7.04it/s]

{'loss': 1.359, 'learning_rate': 0.00024444444444444443, 'epoch': 32.5}


 80%|████████  | 801/1000 [02:29<00:27,  7.12it/s]

{'loss': 1.3687, 'learning_rate': 0.0002222222222222222, 'epoch': 33.33}


 82%|████████▏ | 821/1000 [02:32<00:23,  7.47it/s]

{'loss': 1.3617, 'learning_rate': 0.0002, 'epoch': 34.17}


 84%|████████▍ | 841/1000 [02:35<00:25,  6.13it/s]

{'loss': 1.3544, 'learning_rate': 0.00017777777777777779, 'epoch': 35.0}


 86%|████████▌ | 861/1000 [02:38<00:18,  7.37it/s]

{'loss': 1.3524, 'learning_rate': 0.00015555555555555556, 'epoch': 35.83}


 88%|████████▊ | 881/1000 [02:41<00:16,  7.21it/s]

{'loss': 1.33, 'learning_rate': 0.00013333333333333334, 'epoch': 36.67}


 90%|█████████ | 901/1000 [02:44<00:13,  7.26it/s]

{'loss': 1.3366, 'learning_rate': 0.0001111111111111111, 'epoch': 37.5}


 92%|█████████▏| 921/1000 [02:47<00:11,  7.12it/s]

{'loss': 1.3277, 'learning_rate': 8.888888888888889e-05, 'epoch': 38.33}


 94%|█████████▍| 941/1000 [02:50<00:08,  6.67it/s]

{'loss': 1.342, 'learning_rate': 6.666666666666667e-05, 'epoch': 39.17}


 96%|█████████▌| 961/1000 [02:53<00:06,  6.19it/s]

{'loss': 1.3299, 'learning_rate': 4.4444444444444447e-05, 'epoch': 40.0}


 98%|█████████▊| 981/1000 [02:56<00:02,  7.25it/s]

{'loss': 1.3312, 'learning_rate': 2.2222222222222223e-05, 'epoch': 40.83}


100%|██████████| 1000/1000 [02:58<00:00,  7.21it/s]

{'loss': 1.3361, 'learning_rate': 0.0, 'epoch': 41.67}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:00<00:00]
Compiled/Loaded model in 4.049473897961434 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

100%|██████████| 1000/1000 [03:08<00:00,  7.21it/s]Saving model checkpoint to out/checkpoint-1000
Configuration saved in out/checkpoint-1000/ipu_config.json


{'eval_loss': 1.5078125, 'eval_runtime': 0.4775, 'eval_samples_per_second': 904.799, 'eval_steps_per_second': 56.55, 'epoch': 41.67}




Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 1000/1000 [03:08<00:00,  5.31it/s]
Saving model checkpoint to trained_model/
Configuration saved in trained_model/ipu_config.json


{'train_runtime': 188.3129, 'train_samples_per_second': 1699.3, 'train_steps_per_second': 5.31, 'train_loss': 1.6398095703125, 'epoch': 41.67}
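
The batch-size and epoch figures reported in the log above can be reproduced by hand from the configuration. The arithmetic below is a sketch; it assumes that a final partial batch in each epoch is dropped, which is inferred from the logged epoch values rather than confirmed from the trainer's code.

```python
import math

batch_sz = 16
grad_accum_steps = 5 * 64 // batch_sz          # as set in IPUConfig
total_batch = batch_sz * grad_accum_steps      # samples per optimizer step
num_examples = 7840                            # "Num examples" in the log
max_steps = 1000

# Assumption: only whole batches are used, so 7840 / 320 = 24.5 rounds down.
steps_per_epoch = num_examples // total_batch
epochs_run = max_steps / steps_per_epoch
num_epochs = math.ceil(epochs_run)

print(f"gradient accumulation steps: {grad_accum_steps}")
print(f"total train batch size:      {total_batch}")
print(f"steps per epoch:             {steps_per_epoch}")
print(f"epochs after {max_steps} steps:    {epochs_run:.2f}")
print(f"Num Epochs (rounded up):     {num_epochs}")
```

These values line up with the logged `Gradient Accumulation steps = 20`, total batch size of 320, `Num Epochs = 42` and final `'epoch': 41.67`.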


Note that we use loss scaling here (`loss_scaling=2**6`), although this model is small enough that it can get away without it. This certainly doesn't hold for larger models.
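
To illustrate why loss scaling exists at all, the sketch below rounds a small value to FP16 with and without the `2**6` factor passed to `IPUTrainingArguments` above. It uses the standard library's half-precision `struct` format; the gradient value `1e-8` is an illustrative assumption, not a number measured from this model.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


grad = 1e-8            # hypothetical tiny gradient, below FP16's smallest subnormal
loss_scale = 2 ** 6    # same factor as `loss_scaling` above

unscaled = round_to_fp16(grad)              # underflows to zero
scaled = round_to_fp16(grad * loss_scale)   # lands in the representable range
recovered = scaled / loss_scale             # unscale in higher precision

print(f"without loss scaling: {unscaled}")
print(f"with loss scaling:    {recovered:.3e}")
```

The scaled value survives the round trip to within FP16's precision, whereas the unscaled one is lost entirely. Unit scaling aims to avoid the need for this global factor by keeping tensors near the middle of the format's range in the first place.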

To really get a sense of how the model's doing, we ought to get it to write some Shakespeare. Here are two attempts at *The Tempest*:

In [72]:
pipelines.check_model_type = lambda self, supported_models: ...

final_model = NanoGPTModel.from_pretrained("trained_model/")

TEST_INPUT = """PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits"""
pipe = pipeline(
    "text-generation",
    ipu_config=ipu_config.to_dict(),  # TODO: feature request -> no to_dict()
    model=final_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
)

outputs = pipe(TEST_INPUT, num_return_sequences=2, temperature=0.4)
for i, output in enumerate(outputs):
    print(f"\n===== Completion [{i+1}] =====\n")
    print(output["generated_text"])

IPUConfig {
  "auto_loss_scaling": false,
  "device_iterations": 1,
  "embedding_serialization_factor": 1,
  "enable_half_partials": true,
  "executable_cache_dir": "./exe_cache",
  "execute_encoder_on_cpu_for_generation": false,
  "gradient_accumulation_steps": 20,
  "inference_device_iterations": 1,
  "inference_replication_factor": 1,
  "ipus_per_replica": 1,
  "layers_per_ipu": [
    6
  ],
  "matmul_proportion": 0.6,
  "optimizer_state_offchip": true,
  "optimum_version": "1.6.1",
  "output_mode": "final",
  "recompute_checkpoint_every_layer": false,
  "replicated_tensor_sharding": false,
  "replication_factor": 1,
  "seed": null,
  "transformers_version": "4.25.1"
}

Graph compilation: 100%|██████████| 100/100 [00:00<00:00]



===== Completion [1] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits him
And here is a part to the scorns of him.

BRUTUS:
Well, here's a very foul of the house to the common
To be secret the trumpets of the people's revenge,
That he presently presence.

CORIOLANUS:
The gods of the slanders of the children of the world
The consuls of the cause of the love.

POLIXENES:
The gods and Capulet, good Camillo's such as as the
That resign as her for the country.

BUCKINGHAM:
So shall be the sta

===== Completion [2] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits in the state.

DUKE OF YORK:
I have stay a fair is the strew and parts the fair
And therefore I would not the people.

GLOUCESTER:
The king is the heads of a king, and like thee
Will be proud in the seasons of your father.

KING RICHARD III:
Why, then, then, the heavens it was a state?

LADY ANNE:
Why, what says the hath the prince of th

It's not going to win any literary awards, but it's pretty good for a few minutes of training. Now let's see if our unit-scaled model can do the same...
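
The completions above were sampled with `temperature=0.4`. As a rough illustration of what that setting does, the sketch below applies a softmax to a toy set of logits at two temperatures, assuming the usual convention of dividing the logits by the temperature before the softmax. The helper `softmax_with_temperature` and the logits are hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed stably."""
    scaled_logits = [l / temperature for l in logits]
    m = max(scaled_logits)  # subtract the max to avoid overflow in exp
    exps = [math.exp(s - m) for s in scaled_logits]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.0]
probs_t1 = softmax_with_temperature(toy_logits, 1.0)
probs_t04 = softmax_with_temperature(toy_logits, 0.4)

print("temperature 1.0:", [round(p, 3) for p in probs_t1])
print("temperature 0.4:", [round(p, 3) for p in probs_t04])
```

A temperature below one concentrates probability on the most likely tokens, which tends to give more conservative, repetitive text than sampling at temperature one.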

#### Unit scaling

All we need to do is swap in our new model, remove the loss scaling (of course!), and increase the learning rate (as our weights are now larger due to unit-initialisation).

In [73]:
# TODO: hide?
from types import MethodType

class _IPUConfig(IPUConfig):
    def to_options(self, *args, **kwargs):
        options = super().to_options(*args, **kwargs)
        options._popart.setPatterns(dict(AutogradProxyOpPattern=True))
        return options

    def for_pod_type(self, *args, **kwargs):
        config = super().for_pod_type(*args, **kwargs)
        config.__class__ = _IPUConfig
        return config

ipu_config.__class__ = _IPUConfig

In [74]:
model = UnitScaledNanoGPTModel(config)

train_args.loss_scaling = 1.0
train_args.learning_rate = 2e-2
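
The learning rates that appear in the training logs follow from `warmup_steps=100`, `max_steps=1000` and `lr_scheduler_type="linear"`. Here is a minimal sketch, assuming a linear warm-up from zero followed by a linear decay to zero. The helper `linear_schedule_lr` is hypothetical, written to mirror the logged values rather than taken from the trainer.

```python
def linear_schedule_lr(step, peak_lr, warmup_steps=100, max_steps=1000):
    """Learning rate at `step` for linear warm-up then linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (max_steps - step) / (max_steps - warmup_steps))


for peak_lr in (1e-3, 2e-2):  # regular model, then unit-scaled model
    values = {step: linear_schedule_lr(step, peak_lr) for step in (20, 100, 120, 1000)}
    print(f"peak {peak_lr}: {values}")
```

With a peak of `1e-3` this gives `0.0002` at step 20 and about `0.000978` at step 120, matching the regular model's log; with `2e-2` it gives `0.004` and about `0.01956`, matching the unit-scaled run.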

🤞 The moment of truth ... let's train our unit-scaled model 🤞

In [75]:
unit_scale_trainer = IPUTrainer(
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    ipu_config=ipu_config,
)

unit_scale_trainer.train()
unit_scale_trainer.save_model("trained_model/")

max_steps is given, it will override any value given in num_train_epochs
Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
Compiled/Loaded model in 12.797130857012235 secs
***** Running training *****
  Num examples = 7840
  Num Epochs = 42
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 320
  Gradient Accumulation steps = 20
  Total optimization steps = 1000
  2%|▏         | 21/1000 [00:03<02:30,  6.50it/s]

{'loss': 4.7835, 'learning_rate': 0.004, 'epoch': 0.83}


  4%|▍         | 41/1000 [00:06<02:23,  6.68it/s]

{'loss': 2.8256, 'learning_rate': 0.008, 'epoch': 1.67}


  6%|▌         | 61/1000 [00:09<02:26,  6.39it/s]

{'loss': 2.4451, 'learning_rate': 0.012, 'epoch': 2.5}


  8%|▊         | 81/1000 [00:12<02:04,  7.41it/s]

{'loss': 2.2794, 'learning_rate': 0.016, 'epoch': 3.33}


 10%|█         | 101/1000 [00:15<02:15,  6.61it/s]

{'loss': 2.15, 'learning_rate': 0.02, 'epoch': 4.17}


 12%|█▏        | 121/1000 [00:18<02:02,  7.19it/s]

{'loss': 2.056, 'learning_rate': 0.019555555555555555, 'epoch': 5.0}


 14%|█▍        | 141/1000 [00:21<02:02,  6.99it/s]

{'loss': 1.9731, 'learning_rate': 0.019111111111111113, 'epoch': 5.83}


 16%|█▌        | 161/1000 [00:24<02:03,  6.77it/s]

{'loss': 1.9058, 'learning_rate': 0.018666666666666668, 'epoch': 6.67}


 18%|█▊        | 181/1000 [00:26<01:55,  7.07it/s]

{'loss': 1.8505, 'learning_rate': 0.018222222222222223, 'epoch': 7.5}


 20%|██        | 201/1000 [00:29<01:54,  6.97it/s]

{'loss': 1.7972, 'learning_rate': 0.017777777777777778, 'epoch': 8.33}


 22%|██▏       | 221/1000 [00:32<01:51,  6.99it/s]

{'loss': 1.7667, 'learning_rate': 0.017333333333333336, 'epoch': 9.17}


 24%|██▍       | 241/1000 [00:35<01:40,  7.57it/s]

{'loss': 1.7384, 'learning_rate': 0.01688888888888889, 'epoch': 10.0}


 25%|██▌       | 250/1000 [00:36<01:46,  7.05it/s]Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 10.73408320802264 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 25%|██▌       | 250/1000 [00:50<01:46,  7.05it/s]

{'eval_loss': 1.7333984375, 'eval_runtime': 0.6275, 'eval_samples_per_second': 688.487, 'eval_steps_per_second': 43.03, 'epoch': 10.42}


Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
 26%|██▌       | 261/1000 [01:04<04:29,  2.74it/s]  

{'loss': 1.7074, 'learning_rate': 0.016444444444444446, 'epoch': 10.83}


 28%|██▊       | 281/1000 [01:07<01:45,  6.84it/s]

{'loss': 1.6881, 'learning_rate': 0.016, 'epoch': 11.67}


 30%|███       | 301/1000 [01:10<01:44,  6.71it/s]

{'loss': 1.6322, 'learning_rate': 0.015555555555555557, 'epoch': 12.5}


 32%|███▏      | 321/1000 [01:13<01:48,  6.28it/s]

{'loss': 1.6246, 'learning_rate': 0.015111111111111112, 'epoch': 13.33}


 34%|███▍      | 341/1000 [01:16<01:31,  7.21it/s]

{'loss': 1.6069, 'learning_rate': 0.014666666666666666, 'epoch': 14.17}


 36%|███▌      | 361/1000 [01:19<01:40,  6.33it/s]

{'loss': 1.5772, 'learning_rate': 0.014222222222222223, 'epoch': 15.0}


 38%|███▊      | 381/1000 [01:22<01:23,  7.46it/s]

{'loss': 1.5472, 'learning_rate': 0.013777777777777778, 'epoch': 15.83}


 40%|████      | 401/1000 [01:25<01:32,  6.51it/s]

{'loss': 1.5531, 'learning_rate': 0.013333333333333332, 'epoch': 16.67}


 42%|████▏     | 421/1000 [01:27<01:27,  6.63it/s]

{'loss': 1.5266, 'learning_rate': 0.01288888888888889, 'epoch': 17.5}


 44%|████▍     | 441/1000 [01:30<01:13,  7.62it/s]

{'loss': 1.5199, 'learning_rate': 0.012444444444444445, 'epoch': 18.33}


 46%|████▌     | 461/1000 [01:33<01:11,  7.56it/s]

{'loss': 1.5043, 'learning_rate': 0.012, 'epoch': 19.17}


 48%|████▊     | 481/1000 [01:36<01:15,  6.88it/s]

{'loss': 1.4862, 'learning_rate': 0.011555555555555555, 'epoch': 20.0}


 50%|█████     | 500/1000 [01:38<01:05,  7.65it/s]

{'loss': 1.4942, 'learning_rate': 0.011111111111111112, 'epoch': 20.83}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 10.706514268007595 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 50%|█████     | 500/1000 [01:52<01:05,  7.65it/s]Saving model checkpoint to out/checkpoint-500
Configuration saved in out/checkpoint-500/ipu_config.json


{'eval_loss': 1.5400390625, 'eval_runtime': 0.6162, 'eval_samples_per_second': 701.023, 'eval_steps_per_second': 43.814, 'epoch': 20.83}


Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
 52%|█████▏    | 521/1000 [02:07<01:17,  6.16it/s]  

{'loss': 1.4525, 'learning_rate': 0.010666666666666666, 'epoch': 21.67}


 54%|█████▍    | 541/1000 [02:10<01:10,  6.47it/s]

{'loss': 1.455, 'learning_rate': 0.010222222222222221, 'epoch': 22.5}


 56%|█████▌    | 561/1000 [02:13<01:07,  6.50it/s]

{'loss': 1.4438, 'learning_rate': 0.009777777777777778, 'epoch': 23.33}


 58%|█████▊    | 581/1000 [02:16<00:59,  7.09it/s]

{'loss': 1.4358, 'learning_rate': 0.009333333333333334, 'epoch': 24.17}


 60%|██████    | 601/1000 [02:19<00:56,  7.05it/s]

{'loss': 1.4157, 'learning_rate': 0.008888888888888889, 'epoch': 25.0}


 62%|██████▏   | 621/1000 [02:22<00:53,  7.12it/s]

{'loss': 1.4046, 'learning_rate': 0.008444444444444445, 'epoch': 25.83}


 64%|██████▍   | 641/1000 [02:25<00:55,  6.44it/s]

{'loss': 1.3781, 'learning_rate': 0.008, 'epoch': 26.67}


 66%|██████▌   | 661/1000 [02:27<00:47,  7.07it/s]

{'loss': 1.3857, 'learning_rate': 0.007555555555555556, 'epoch': 27.5}


 68%|██████▊   | 681/1000 [02:30<00:45,  7.01it/s]

{'loss': 1.3919, 'learning_rate': 0.0071111111111111115, 'epoch': 28.33}


 70%|███████   | 701/1000 [02:33<00:42,  7.03it/s]

{'loss': 1.3537, 'learning_rate': 0.006666666666666666, 'epoch': 29.17}


 72%|███████▏  | 721/1000 [02:36<00:37,  7.47it/s]

{'loss': 1.3511, 'learning_rate': 0.006222222222222223, 'epoch': 30.0}


 74%|███████▍  | 741/1000 [02:39<00:37,  6.90it/s]

{'loss': 1.3359, 'learning_rate': 0.0057777777777777775, 'epoch': 30.83}


 75%|███████▌  | 750/1000 [02:40<00:33,  7.53it/s]Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 11.012903372000437 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 75%|███████▌  | 750/1000 [02:57<00:33,  7.53it/s]

{'eval_loss': 1.4873046875, 'eval_runtime': 0.612, 'eval_samples_per_second': 705.875, 'eval_steps_per_second': 44.117, 'epoch': 31.25}


Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
 76%|███████▌  | 761/1000 [03:13<01:39,  2.41it/s]

{'loss': 1.3398, 'learning_rate': 0.005333333333333333, 'epoch': 31.67}


 78%|███████▊  | 781/1000 [03:16<00:33,  6.55it/s]

{'loss': 1.3367, 'learning_rate': 0.004888888888888889, 'epoch': 32.5}


 80%|████████  | 801/1000 [03:19<00:27,  7.12it/s]

{'loss': 1.327, 'learning_rate': 0.0044444444444444444, 'epoch': 33.33}


 82%|████████▏ | 821/1000 [03:22<00:28,  6.38it/s]

{'loss': 1.322, 'learning_rate': 0.004, 'epoch': 34.17}


 84%|████████▍ | 840/1000 [03:25<00:24,  6.48it/s]

{'loss': 1.2932, 'learning_rate': 0.0035555555555555557, 'epoch': 35.0}


 86%|████████▌ | 861/1000 [03:28<00:19,  6.95it/s]

{'loss': 1.2955, 'learning_rate': 0.0031111111111111114, 'epoch': 35.83}


 88%|████████▊ | 881/1000 [03:31<00:17,  6.64it/s]

{'loss': 1.2822, 'learning_rate': 0.0026666666666666666, 'epoch': 36.67}


 90%|█████████ | 901/1000 [03:34<00:14,  6.96it/s]

{'loss': 1.2897, 'learning_rate': 0.0022222222222222222, 'epoch': 37.5}


 92%|█████████▏| 921/1000 [03:37<00:11,  6.70it/s]

{'loss': 1.2786, 'learning_rate': 0.0017777777777777779, 'epoch': 38.33}


 94%|█████████▍| 941/1000 [03:40<00:07,  7.38it/s]

{'loss': 1.2674, 'learning_rate': 0.0013333333333333333, 'epoch': 39.17}


 96%|█████████▌| 961/1000 [03:42<00:05,  6.88it/s]

{'loss': 1.2678, 'learning_rate': 0.0008888888888888889, 'epoch': 40.0}


 98%|█████████▊| 981/1000 [03:45<00:02,  7.32it/s]

{'loss': 1.2452, 'learning_rate': 0.00044444444444444447, 'epoch': 40.83}


100%|██████████| 1000/1000 [03:48<00:00,  6.79it/s]

{'loss': 1.2518, 'learning_rate': 0.0, 'epoch': 41.67}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 11.062077922048047 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

100%|██████████| 1000/1000 [04:03<00:00,  6.79it/s]Saving model checkpoint to out/checkpoint-1000
Configuration saved in out/checkpoint-1000/ipu_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 1000/1000 [04:03<00:00,  4.10it/s]
Saving model checkpoint to trained_model/
Configuration saved in trained_model/ipu_config.json


{'eval_loss': 1.466796875, 'eval_runtime': 0.7101, 'eval_samples_per_second': 608.345, 'eval_steps_per_second': 38.022, 'epoch': 41.67}
{'train_runtime': 243.6652, 'train_samples_per_second': 1313.277, 'train_steps_per_second': 4.104, 'train_loss': 1.632796875, 'epoch': 41.67}


We get a similar final evaluation loss to the regular model (1.467 versus 1.508): a success! In fact, after sweeping the learning rate, unit scaling appears to be slightly better here.

Let's celebrate with some unit-scaled Shakespeare...

In [76]:
pipelines.check_model_type = lambda self, supported_models: ...

final_model = UnitScaledNanoGPTModel.from_pretrained("trained_model/")


TEST_INPUT = """PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits"""
pipe = pipeline(
    "text-generation",
    ipu_config=ipu_config.to_dict(),  # TODO: feature request -> no to_dict()
    model=final_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
)

outputs = pipe(TEST_INPUT, num_return_sequences=2, temperature=0.4)
for i, output in enumerate(outputs):
    print(f"\n===== Completion [{i+1}] =====\n")
    print(output["generated_text"])

Graph compilation: 100%|██████████| 100/100 [00:01<00:00]



===== Completion [1] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of the
provost: the strength of your honours and strength
To meet the world in the violence of your father's country's life
Than the ground of the court-wings of all the death.

KING HENRY VI:
My lord, the last is the father for the court?

KING RICHARD III:
No more than I see thee to be a peace
And more than my brother's love in the gods
And the root of my counsel: therefore I mean not
The sister of the charge of the 

===== Completion [2] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits for the head
To see him that live and dear at his country's head,
The sea of his crown, sir, the last of the ground
Of our highness and the sun of his presence
Of the singer of the steed service of her sense,
That she was not her of your grace was for his son.

KING RICHARD III:
What is the done? what thou art thou dost deserve?

DUKE OF

Some neat verse, from a neatly-scaled model. From which we can only conclude:

*Enter: unit scaling*

*Exeunt: loss scaling, automatic loss scaling, etc.*

FIN

---

We hope that practitioners will consider using unit scaling for future projects, particularly those having difficulties with loss scaling or automatic mixed precision. With FP8 on the horizon, these issues are likely to become more prevalent. We hope unit scaling can help.

If you're interested in using unit scaling yourself, or have questions, please do reach out 🙏☎️ We're keen to hear from anyone who has a problem unit scaling might help solve.

The definitions provided here are the closest we have to an "official" PyTorch implementation, but if there's demand for a library, tell us and we'll make one!