# Unit Scaling: A How-To Guide

In our paper [Unit Scaling: Out-of-the-Box Low-Precision Training](), we describe a scheme for designing neural networks that have approximate unit variance after every operation.

This can be seen as an alternative to (static) loss scaling, or its automatic variant, as used in Automatic Mixed Precision. Whereas both of those schemes rely on a single, global scaling factor for all the gradients, unit scaling is more fine-grained.

A unit-scaled model adds scaling factors (constant scalar multiplications) to each operation in the computational graph to achieve this unit variance property. The result is a model which naturally produces tensors in the middle of the dynamic range provided by floating-point formats. There's no extra loss-scale hyperparameter—it just works out-of-the-box!
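
To make the dynamic-range point concrete, the sketch below rounds a few values to IEEE half precision using only the Python standard library. The helper name `to_fp16` and the example magnitudes are illustrative assumptions, not values from the paper.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision (FP16) value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Illustrative magnitudes (assumed, not measured from any model).
tiny_grad = 1e-8        # below FP16's smallest subnormal (about 6e-8)
loss_scale = 2.0 ** 10  # a hypothetical global loss-scaling factor

# Without scaling, the value underflows to zero in FP16.
unscaled = to_fp16(tiny_grad)

# Scaling up before the cast, then undoing the scale in higher precision,
# preserves the value to within FP16 rounding error.
recovered = to_fp16(tiny_grad * loss_scale) / loss_scale

# A value near 1.0 sits mid-range and needs no extra scaling.
unit_value = to_fp16(1.0)

print(unscaled, recovered, unit_value)
```

A single global factor only helps if every gradient tensor needs roughly the same shift; keeping each tensor near unit scale removes the need to pick one.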

## Implementing unit scaling

Here we demonstrate how to go about unit scaling in practice. This involves re-implementing common neural network layers to add variance-preserving scaling factors.

As explained in the paper, we can sometimes justify using different scaling factors in the forward and backward pass. We introduce a special `scaled()` op which allows us to do just that.

Below we implement two models. The first is a simple transformer decoder. Its design and hyperparameters are inspired by (though not identical to) Andrej Karpathy's [popular NanoGPT implementation](https://github.com/karpathy/nanoGPT). It's also 🤗-compatible!

The second is the same, but unit-scaled. Let's get stuck in...

## Building a unit-scaled NanoGPT

In [1]:
import math
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers.activations import GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel



### The MLP layer

We'll start by setting up a basic config for a reasonably small transformer:

In [2]:
class NanoGPTConfig(PretrainedConfig):
    model_type = "nano-gpt"

    def __init__(
        self,
        hidden_size: int = 384,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 6,
        dropout: float = 0.1,
        vocab_size: int = 384,
        eos_token_id: int = 1,
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.vocab_size = vocab_size
        self.eos_token_id = eos_token_id
        super().__init__(**kwargs)

Along with a standard (pre-norm) MLP module:

In [3]:
class MLP(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size)
        self.linear_1 = nn.Linear(config.hidden_size, config.hidden_size * 4)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(config.hidden_size * 4, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.ln(hidden_states)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return self.dropout(hidden_states)

In [4]:
# TODO: hide
def init_weights(init_fn) -> None:
    def inner_fn(module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            fan_out, fan_in = module.weight.shape  # TODO: correct base impl
            module.weight.data.normal_(mean=0.0, std=init_fn(fan_in, fan_out))
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            embed_dim = module.weight.shape[-1]
            module.weight.data.normal_(mean=0.0, std=init_fn(embed_dim, embed_dim))
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    return inner_fn

In [5]:
# TODO: hide
from seaborn import heatmap
import pandas as pd

np.seterr(divide = 'ignore')
hist_bins = np.array([e for e in range(-14, 16+1)])

def instrument(module):
    stats = {}
    instrument_recursive(module, stats)
    return stats

def instrument_recursive(module, stats, name=''):
    children = list(module.named_children())
    if children:
        for c_name, c in children:
            _name = f'{name}.{c_name}' if name and name != 'blocks' else c_name
            instrument_recursive(c, stats, _name)
    else:
        instrument_terminal(module, stats, name)

def instrument_terminal(module, stats, name):
    module_stats = {}
    def require_input_grads(_module, input):
        for i in input:
            if isinstance(i, Tensor) and i.is_floating_point():
                i.requires_grad_()
    
    module.register_forward_pre_hook(require_input_grads)
    
    if name.split('.')[-1] == 'softmax':
        return
    
    def record_fwd_scale(_module, input, output):
        if isinstance(output, Tensor) and output.is_floating_point():
            module_stats['x'] = np.log2(output.std().item())
    
    module.register_forward_hook(record_fwd_scale)
    
    def record_bwd_scales(_module, grad_input, grad_output):
        grad_input = list(grad_input)
        for g in grad_input:
            if g is not None and isinstance(g, Tensor) \
                and g.is_floating_point() and len(grad_input) == 1:
                module_stats['grad_x'] = np.log2(g.std().item())
        
        for param_name, param in _module.named_parameters():
            if param_name == "weight":
                module_stats['w'] = np.log2(param.std().item())
                if param.grad is not None:
                    module_stats['grad_w'] = np.log2(param.grad.std().item())
    
    module.register_full_backward_hook(record_bwd_scales)
    
    stats[name] = module_stats

def create_histogram(t: Tensor) -> np.ndarray:
    return np.histogram(np.log(abs(t.detach())), hist_bins, density=True)[0]

In [6]:
# TODO: hide
import altair as alt

def visualise(stats):
    df = pd.DataFrame(stats)
    df = df.stack().to_frame('scale (log₂)').reset_index(names=['type', 'op'])
    plot(df)

def plot(df):
    is_x_or_grad_x = (df['type'] == 'x') | (df['type'] == 'grad_x')
    op_order = df[df['type'] == 'x']['op'].tolist()
    colors = ['#6C8EBF', '#FF8000', '#5D8944', '#ED3434']
    
    a = alt.Chart(df[is_x_or_grad_x]).mark_line().encode(
        x=alt.X(
            'scale (log₂):Q',
            axis=alt.Axis(orient='top', values=np.arange(-14, 16+1, 2)),
            scale=alt.Scale(domain=[-14, 16]),
        ),
        y=alt.Y(
            'op:O',
            title='',
            sort=op_order,
        ),
        color=alt.Color(
            'type',
            legend=alt.Legend(title='', labelFontSize=12, symbolSize=100),
            scale=alt.Scale(range=colors[:2]),
            sort='descending'
        ),
    )
    b = alt.Chart(df[~is_x_or_grad_x]).mark_point(size=100).encode(
        x=alt.X(
            'scale (log₂):Q',
            axis=alt.Axis(orient='top', values=np.arange(-14, 16+1, 2)),
            scale=alt.Scale(domain=[-14, 16])
        ),
        y=alt.Y(
            'op:O',
            title='',
            sort=op_order,
        ),
        color=alt.Color(
            'type',
            legend=alt.Legend(title='', labelFontSize=12, symbolSize=100),
            scale=alt.Scale(range=colors[2:]),
            sort='descending'
        ),
        shape=alt.Shape(
            'type',
            scale=alt.Scale(range=['square', 'triangle-down']),
            sort='descending'
        ),
    )
    display(alt.layer(a, b).resolve_scale(color='independent', shape='independent').configure_axis(
        labelFontSize=12,
        titleFontSize=16
    ).properties(
        width=500
    ))

Now we can analyse the scale (i.e. standard deviation) of each operation within the MLP. To do this, we provide an `instrument()` function, which walks the module tree and records the scale of each operation's output in the forward and backward pass.
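
As a rough, framework-free illustration of the quantity being plotted (the base-2 log of a standard deviation), consider the sketch below. The helper `log2_scale` is a hypothetical name for this example; the notebook's own hooks compute the same thing on tensors.

```python
import math
import random
import statistics

random.seed(0)


def log2_scale(values):
    """Scale of a collection of numbers: log2 of their standard deviation."""
    return math.log2(statistics.pstdev(values))


# A unit-normal sample should have a log2 scale close to 0.
unit_samples = [random.gauss(0.0, 1.0) for _ in range(20_000)]

# Multiplying by 2**-4 shifts the log2 scale down by 4.
small_samples = [x * 2.0 ** -4 for x in unit_samples]

print(f"unit normal:     {log2_scale(unit_samples):+.2f}")
print(f"scaled by 2**-4: {log2_scale(small_samples):+.2f}")
```

On this axis, "unit scale" means a value near 0, and each step of 1 corresponds to a factor of two in standard deviation.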

Our network's weights will also use the standard Glorot initialisation.

So let's feed in a unit normal tensor in both directions, and examine the result.

In [7]:
def analyse_mlp(mlp, config, batch_size=64, seq_len=16):
    stats = instrument(mlp)
    x = torch.normal(0.0, 1.0, size=(batch_size, seq_len, config.hidden_size))
    y = mlp(x)
    y.backward(torch.normal(0.0, 1.0, size=y.shape))
    visualise(stats)

In [8]:
def glorot_init(fan_in, fan_out):
    return ((fan_in + fan_out) / 2) ** -0.5

config = NanoGPTConfig()
mlp = MLP(config).apply(init_weights(glorot_init))
analyse_mlp(mlp, config)

By the end of both passes the scale has dropped by 0.5. This is due to Glorot initialisation under-scaling slightly and the GeLU dropping the scale further.

Even worse, the grad_w scales are around 0.03. This is largely because Glorot scaling (like other common weight-initialisation schemes) only accounts for the forward and grad_x scales.
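
The forward/grad_x trade-off made by Glorot initialisation can be worked through by hand for a single linear layer. The sketch below assumes unit-variance, independent inputs and gradients, and uses the hypothetical helper `linear_output_scales`; it is an approximation for intuition, not a reproduction of the plot.

```python
def glorot_std(fan_in: int, fan_out: int) -> float:
    """Weight standard deviation, matching the notebook's glorot_init."""
    return ((fan_in + fan_out) / 2) ** -0.5


def linear_output_scales(fan_in: int, fan_out: int, weight_std: float):
    """Approximate std of the output and of grad_x for unit-variance inputs."""
    fwd = fan_in ** 0.5 * weight_std      # each output sums fan_in terms
    grad_x = fan_out ** 0.5 * weight_std  # each input grad sums fan_out terms
    return fwd, grad_x


hidden = 384  # default hidden_size in NanoGPTConfig
std = glorot_std(hidden, 4 * hidden)
fwd_1, grad_x_1 = linear_output_scales(hidden, 4 * hidden, std)
print(f"linear_1: fwd scale ~ {fwd_1:.3f}, grad_x scale ~ {grad_x_1:.3f}")
```

Because Glorot averages `fan_in` and `fan_out`, a layer with unequal fan-in and fan-out is under-scaled in one direction and over-scaled in the other. Nothing in the formula depends on the batch size, which is the quantity that determines the grad_w scale.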

We'll now implement the equivalent layer using unit scaling, beginning with a basic linear operation:

In [9]:
# TODO: hide
class ScaledGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, alpha, beta):
        ctx.save_for_backward(torch.tensor(beta, dtype=X.dtype))
        return alpha * X

    @staticmethod
    def backward(ctx, grad_Y):
        beta, = ctx.saved_tensors
        return beta * grad_Y, None, None

def scaled(X, alpha=1.0, beta=1.0):
    # Forward: Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    return ScaledGrad.apply(X, alpha, beta)

def geometric_mean(xs):
    xs = np.array(xs)
    return xs.prod() ** (1 / xs.size)

In [10]:
class UnitScaledLinear(nn.Linear):
    def __init__(self, *args, scale_for="fwd, grad_x", **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scale_for = scale_for
    
    def get_scales(self, input):
        fwd_scale = self.weight.shape[1] ** -0.5
        grad_x_scale = self.weight.shape[0] ** -0.5
        grad_w_scale = np.prod(input.shape[:-1]) ** -0.5
        if self.scale_for == "fwd":
            grad_x_scale = fwd_scale
        elif self.scale_for == "grad_x":
            fwd_scale = grad_x_scale
        elif self.scale_for == "fwd, grad_x":
            fwd_scale = grad_x_scale = geometric_mean([fwd_scale, grad_x_scale])
        else:
            assert False, f"demo implementation has no {self.scale_for} scaling"
        return fwd_scale, grad_x_scale, grad_w_scale
           
    def forward(self, input):
        fwd_scale, grad_x_scale, grad_w_scale = self.get_scales(input)
        input = scaled(input, beta=grad_x_scale)
        weight = scaled(self.weight, beta=grad_w_scale)
        bias = scaled(self.bias, beta=grad_w_scale) if self.bias is not None else None
        output = F.linear(input, weight, bias)
        return scaled(output, alpha=fwd_scale)

Let's break this down a bit. The `forward()` method is still based on the fundamental `F.linear(input, weight, bias)` operation, but has each of its operands and its output scaled.

This is done via the `scaled(X, alpha, beta)` function. It is a special operation in that it behaves differently in the forward and backward pass, where we have

**Forward:** Y = X * alpha

**Backward:** grad_X = grad_Y * beta
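
A toy scalar version may help make the two rules concrete. The class `Scaled` below is a hypothetical, framework-free stand-in for the notebook's autograd-based `scaled()`; it only illustrates that the backward factor is chosen independently of the forward one.

```python
class Scaled:
    """Toy scalar version of scaled(): forward multiplies by alpha, backward by beta."""

    def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
        self.alpha = alpha
        self.beta = beta

    def forward(self, x: float) -> float:
        return self.alpha * x

    def backward(self, grad_y: float) -> float:
        # An ordinary multiplication by alpha would return alpha * grad_y here.
        return self.beta * grad_y


op = Scaled(alpha=0.5, beta=2.0)
y = op.forward(4.0)        # 0.5 * 4.0
grad_x = op.backward(1.0)  # 2.0 * 1.0, independent of alpha
print(y, grad_x)
```

With `alpha == beta` this reduces to a plain scalar multiplication; the flexibility only matters when the two differ.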

So what scaling factors do we choose? This is determined in `get_scales()`. The standard approach is to set the forward and grad_x scales as a compromise between their "ideal" scale-preserving values (via a geometric mean). See our paper for more details on how we arrive at these ideal values.

We also provide versions of the operation that select the forward and grad_x scales based on only one of their ideal values.

Note that (for the sake of valid gradients) we're obliged to always have `fwd_scale = grad_x_scale`. However, we're allowed to have a separate scale for grad_w (again, see the paper) which gets its own ideal value.
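
To see what these rules give for concrete shapes, here is a plain-Python mirror of `get_scales()` that takes sizes as integers. The helper name `unit_linear_scales` is hypothetical, and the resulting standard deviations assume unit-scale, independent inputs and weights.

```python
import math


def unit_linear_scales(fan_in: int, fan_out: int, batch_elems: int,
                       scale_for: str = "fwd, grad_x"):
    """Plain-Python mirror of UnitScaledLinear.get_scales (sizes passed as ints)."""
    fwd = fan_in ** -0.5
    grad_x = fan_out ** -0.5
    grad_w = batch_elems ** -0.5
    if scale_for == "fwd":
        grad_x = fwd
    elif scale_for == "grad_x":
        fwd = grad_x
    elif scale_for == "fwd, grad_x":
        fwd = grad_x = math.sqrt(fwd * grad_x)  # geometric mean of two values
    else:
        raise ValueError(f"unknown scale_for: {scale_for!r}")
    return fwd, grad_x, grad_w


# linear_1 of the MLP: 384 -> 1536, fed 64 * 16 input rows.
fwd, grad_x, grad_w = unit_linear_scales(384, 1536, 64 * 16)

# Approximate std of the output and of grad_x under the compromise scale.
out_std = 384 ** 0.5 * fwd
grad_x_std = 1536 ** 0.5 * grad_x
print(f"fwd = grad_x scale: {fwd:.4f}, grad_w scale: {grad_w:.5f}")
print(f"output std ~ {out_std:.3f}, grad_x std ~ {grad_x_std:.3f}")
```

For this 1:4 layer the compromise leaves the output a factor of √2 below unit scale and grad_x a factor of √2 above it, which is the "temporary non-unit scaling" that a following 4:1 layer undoes.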

Now for the GeLU:

In [11]:
class UnitScaledGELU(GELUActivation):
    def forward(self, input):
        fwd_scale = bwd_scale = geometric_mean([0.588, 0.675]) ** -1
        input = scaled(input, beta=bwd_scale)
        output = self.act(input)
        return scaled(output, alpha=fwd_scale)

This looks pretty similar to the linear op. As activation functions are nonlinear, our usual approach of calculating scaling values analytically (i.e. by working through the maths) isn't always possible.

Fortunately, for these elementwise ops we can calculate them empirically, like so:

In [12]:
def analyse_elementwise_fn(fn, num_samples=2**22):
    x = torch.normal(0.0, 1.0, size=(num_samples,)).requires_grad_()
    y = fn(x)
    y.backward(torch.normal(0.0, 1.0, size=(num_samples,)))
    print(f"fwd scale={y.std():.3f}, bwd scale={x.grad.std():.3f}")

print("GeLU:", end=" ")
analyse_elementwise_fn(GELUActivation())
print("Tanh:", end=" ")
analyse_elementwise_fn(torch.nn.Tanh())

GeLU: fwd scale=0.588, bwd scale=0.676
Tanh: fwd scale=0.628, bwd scale=0.681
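
As a cross-check on the sampled GeLU figures, the same two quantities can be approximated by numerical integration against the standard normal density. This is a sketch assuming the exact (erf-based) GeLU and an incoming gradient that is unit-normal and independent of the input; the helper names are hypothetical.

```python
import math


def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad(x: float) -> float:
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


def normal_expectation(f, lo: float = -10.0, hi: float = 10.0,
                       steps: int = 20_000) -> float:
    """E[f(x)] for x ~ N(0, 1), approximated with the trapezoidal rule."""
    h = (hi - lo) / steps
    total = 0.0
    for i in range(steps + 1):
        x = lo + i * h
        weight = 0.5 if i in (0, steps) else 1.0
        total += weight * f(x) * math.exp(-0.5 * x * x)
    return total * h / math.sqrt(2.0 * math.pi)


mean = normal_expectation(gelu)
fwd_scale = math.sqrt(normal_expectation(lambda x: gelu(x) ** 2) - mean ** 2)
# With an independent unit-normal incoming gradient g, grad_x = g * gelu'(x),
# so its standard deviation is sqrt(E[gelu'(x)^2]).
bwd_scale = math.sqrt(normal_expectation(lambda x: gelu_grad(x) ** 2))
print(f"GeLU: fwd scale={fwd_scale:.3f}, bwd scale={bwd_scale:.3f}")
```

The integration route is deterministic, so it avoids the small run-to-run differences in the last digit that sampling produces.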


Simple! Now we just need to unit scale the layernorm and dropout and we're ready.

These take a similar approach, so we won't go into too much detail here:

In [13]:
class UnitScaledLayerNorm(nn.LayerNorm):
    def forward(self, input: Tensor) -> Tensor:
        scale = (np.prod(self.normalized_shape) / input.nelement()) ** 0.5
        weight = scaled(self.weight, beta=scale)
        bias = scaled(self.bias, beta=scale)
        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

class UnitScaledDropout(nn.Dropout):
    def forward(self, input: Tensor) -> Tensor:
        # Dropout is typically implemented with a (1-p) ** -1 scaling
        # However, to preserve variance this ought to be (1-p) ** -0.5
        # We correct for this by multiplying by (1-p) ** 0.5
        scale = (1-self.p) ** 0.5
        input = scaled(input, beta=scale)
        output = F.dropout(input, self.p, self.training, self.inplace)
        return scaled(output, alpha=scale)
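
The dropout correction can be checked with a small simulation. The sketch below is a plain-Python stand-in for inverted dropout (the hypothetical `dropout` and `std` helpers are not the PyTorch implementation) and only models the forward pass.

```python
import math
import random

random.seed(0)


def std(values):
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


def dropout(values, p: float, extra_scale: float = 1.0):
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1-p)."""
    keep = 1.0 - p
    return [extra_scale * v / keep if random.random() < keep else 0.0
            for v in values]


p = 0.1  # default dropout rate in NanoGPTConfig
x = [random.gauss(0.0, 1.0) for _ in range(100_000)]

standard_std = std(dropout(x, p))
corrected_std = std(dropout(x, p, extra_scale=(1.0 - p) ** 0.5))
print(f"standard dropout std:       {standard_std:.3f} "
      f"(theory {(1.0 - p) ** -0.5:.3f})")
print(f"with (1-p)**0.5 correction: {corrected_std:.3f}")
```

The standard `1/(1-p)` factor preserves the mean rather than the variance, which is why the extra `(1-p) ** 0.5` factor is needed for unit scale.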

We're now ready to unit-scale the full MLP 🥳

All this requires is swapping out the old layers for our new ones:

In [14]:
class UnitScaledMLP(MLP):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.linear_1 = UnitScaledLinear(config.hidden_size, config.hidden_size * 4)
        self.act = UnitScaledGELU()
        self.linear_2 = UnitScaledLinear(config.hidden_size * 4, config.hidden_size)
        self.dropout = UnitScaledDropout(config.dropout)

We also change our initialisation to give our weights unit scale. Let's analyse our new unit-scaled MLP:

In [15]:
def unit_init(*args):
    return 1

mlp = UnitScaledMLP(config).apply(init_weights(unit_init))
analyse_mlp(mlp, config)

This is much better! The final scales are almost exactly 1 in both directions.

The compromise fwd & grad_x scaling factors in our linear layers do give temporary non-unit scaling, but this is minor and the second linear layer cancels out the first.

### The self-attention layer

We start with a basic implementation of a (self-)attention module.

(Note that we use [ALiBi](https://arxiv.org/abs/2108.12409) biases over positional embeddings here. This is primarily for simplicity, though these have also been [shown to perform](https://arxiv.org/abs/2210.15424) remarkably well!)
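
For reference, the per-head ALiBi slopes used by this notebook's `gen_slopes` helper can be reproduced without tensors. The function name `alibi_slopes` is hypothetical; the logic is intended to mirror `gen_slopes` one-to-one.

```python
import math


def alibi_slopes(num_heads: int) -> list:
    """Plain-Python mirror of gen_slopes: one slope per attention head."""
    n = 2 ** math.floor(math.log2(num_heads))  # largest power of two <= num_heads
    m_0 = 2.0 ** (-8.0 / n)
    slopes = [m_0 ** i for i in range(1, n + 1)]
    if n < num_heads:
        # Remaining heads take every other slope from the 2n-head sequence.
        m_hat_0 = 2.0 ** (-4.0 / n)
        slopes += [m_hat_0 ** i for i in range(1, 1 + 2 * (num_heads - n), 2)]
    return slopes


six_head_slopes = alibi_slopes(6)  # six heads, as in the default config
print(six_head_slopes)
```

Each head multiplies token distance by its own slope, so heads with larger slopes attend more locally.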

In [16]:
# TODO: hide
class Matmul(nn.Module):
    def forward(self, a: Tensor, b: Tensor, scale: float=1.0) -> Tensor:
        return (a @ b) * scale

def split_heads(tensor: Tensor, num_heads: int, head_size: int) -> Tensor:
    batch_size, seq_len, hidden_size = tensor.shape
    tensor = tensor.view(batch_size, seq_len, num_heads, head_size)
    return tensor.permute(0, 2, 1, 3)


def merge_heads(tensor: Tensor, num_heads: int) -> Tensor:
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    batch_size, seq_len, num_heads, head_size = tensor.shape
    return tensor.view(batch_size, seq_len, num_heads * head_size)


def causal_mask(seq_len: int, num_heads: int) -> Tensor:
    causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.float16))
    causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
    alibi_mask = gen_alibi_mask(causal_mask, num_heads)
    causal_mask = (1.0 - causal_mask) * -10_000
    return alibi_mask + causal_mask


# Based on https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/alibi/__init__.py
def gen_alibi_mask(causal_mask: Tensor, num_heads: int) -> Tensor:
    distances = causal_mask.to(torch.float32).cumsum(dim=-1)
    slopes = gen_slopes(num_heads)
    return distances.to(torch.float16) * slopes.view(1, num_heads, 1, 1)


def gen_slopes(num_heads: int) -> Tensor:
    n = 2 ** math.floor(math.log2(num_heads))
    m_0 = 2.0 ** (-8.0 / n)
    m = torch.pow(m_0, torch.arange(1, 1 + n))
    if n < num_heads:
        m_hat_0 = 2.0 ** (-4.0 / n)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        m = torch.cat([m, m_hat])
    return m

In [17]:
class Attention(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size)
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        
        self.w_qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.qk_matmul = Matmul()
        self.softmax = nn.Softmax(-1)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.qkv_matmul = Matmul()
        self.w_o = nn.Linear(config.hidden_size, config.hidden_size)
        self.residual_dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        seq_len = hidden_states.shape[1]
        head_size = self.hidden_size // self.num_heads

        hidden_states = self.ln(hidden_states)
        q_k_v = self.w_qkv(hidden_states)
        q, k, v = q_k_v.split(self.hidden_size, dim=-1)
        q, k, v = (split_heads(t, self.num_heads, head_size) for t in (q, k, v))

        qk = self.qk_matmul(q, k.transpose(-1, -2), scale=1 / head_size ** 0.5).clone()
        qk += causal_mask(seq_len, self.num_heads) + attention_mask
        qk = self.softmax(qk)
        qk = self.attn_dropout(qk)

        qkv = self.qkv_matmul(qk, v)
        qkv = merge_heads(qkv, self.num_heads)
        qkvo = self.w_o(qkv)
        return self.residual_dropout(qkvo).clone()

In [18]:
def analyse_attn(attention, config, batch_size=64, seq_len=16):
    stats = instrument(attention)
    x = torch.normal(0.0, 1.0, size=(batch_size, seq_len, config.hidden_size))
    attention_mask = torch.zeros(batch_size, 1, 1, seq_len)
    y = attention(x, attention_mask)
    y.backward(torch.normal(0.0, 1.0, size=y.shape))
    visualise(stats)

In [19]:
attention = Attention(config).apply(init_weights(glorot_init))
analyse_attn(attention, config)

There's a lot going on here. The key points are, again, that the output scale falls by just over half in both directions and that the grad_w scales are significantly under-scaled.

Let's fix this:

In [20]:
class UnitScaledAttention(Attention):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.w_qkv = UnitScaledLinear(
            config.hidden_size, 3 * config.hidden_size, scale_for="fwd"
        )
        self.qk_matmul = UnitScaledMatmul(scale_for="fwd")
        self.softmax = UnitScaledSoftmax(-1)
        self.attn_dropout = UnitScaledDropout(config.dropout)
        self.qkv_matmul = UnitScaledMatmul(scale_for="fwd")
        self.w_o = UnitScaledLinear(config.hidden_size, config.hidden_size)
        self.residual_dropout = UnitScaledDropout(config.dropout)

class UnitScaledMatmul(Matmul):
    def __init__(self, scale_for="fwd, grad_a, grad_b") -> None:
        super().__init__()
        self.scale_for = scale_for
    
    def get_scales(self, a: Tensor, b: Tensor):
        fwd_scale = a.shape[-1] ** -0.5
        grad_a_scale = b.shape[-1] ** -0.5
        grad_b_scale = a.shape[-2] ** -0.5
        if self.scale_for == "fwd":
            grad_a_scale = grad_b_scale = fwd_scale
        elif self.scale_for == "fwd, grad_a, grad_b":
            fwd_scale = grad_a_scale = grad_b_scale = geometric_mean([
                fwd_scale, grad_a_scale, grad_b_scale
            ])
        else:
            assert False, f"demo implementation has no {self.scale_for} scaling"
        return fwd_scale, grad_a_scale, grad_b_scale
    
    def forward(self, a: Tensor, b: Tensor, scale: float=1.0) -> Tensor:
        # ignores provided scale
        fwd_scale, grad_a_scale, grad_b_scale = self.get_scales(a, b)
        a = scaled(a, beta=grad_a_scale)
        b = scaled(b, beta=grad_b_scale)
        output = a @ b
        return scaled(output, alpha=fwd_scale)

class UnitScaledSoftmax(nn.Softmax):
    def forward(self, input: Tensor) -> Tensor:
        fwd_scale = bwd_scale = input.shape[-1] ** 0.5
        input = scaled(input, fwd_scale)
        output = F.softmax(input, self.dim, _stacklevel=5)
        return scaled(output, bwd_scale)

To unit scale the attention layer we again swap out regular layers for unit-scaled ones. This requires altering two more operations: matmul and softmax.

As with the linear layer, we provide several options for scaling the matmul. We find that scaling each matmul and linear layer here for the forward pass ensures good scaling in both directions. We see this empirically in our analysis:
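
The matmul scaling options can be evaluated for the attention shapes without running the model. Below is a plain-Python mirror of `UnitScaledMatmul.get_scales` that takes shapes as tuples; the name `matmul_scales` and the example shapes are assumptions for illustration.

```python
def matmul_scales(a_shape, b_shape, scale_for: str = "fwd, grad_a, grad_b"):
    """Plain-Python mirror of UnitScaledMatmul.get_scales, taking shapes as tuples."""
    fwd = a_shape[-1] ** -0.5     # a @ b contracts over a's last dimension
    grad_a = b_shape[-1] ** -0.5  # grad_out @ b.T contracts over b's last dimension
    grad_b = a_shape[-2] ** -0.5  # a.T @ grad_out contracts over a's rows
    if scale_for == "fwd":
        grad_a = grad_b = fwd
    elif scale_for == "fwd, grad_a, grad_b":
        fwd = grad_a = grad_b = (fwd * grad_a * grad_b) ** (1 / 3)
    else:
        raise ValueError(f"unknown scale_for: {scale_for!r}")
    return fwd, grad_a, grad_b


# q @ k.T for batch 64, 6 heads, seq_len 16, head_size 64.
q_shape = (64, 6, 16, 64)
k_t_shape = (64, 6, 64, 16)
fwd_only = matmul_scales(q_shape, k_t_shape, scale_for="fwd")
compromise = matmul_scales(q_shape, k_t_shape)
print("scale_for='fwd':", fwd_only)
print("geometric mean: ", compromise)
```

For the query-key product, the forward-only choice works out to `head_size ** -0.5`, which coincides with the usual attention scaling factor.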

In [21]:
attention = UnitScaledAttention(config).apply(init_weights(unit_init))
analyse_attn(attention, config, batch_size=64, seq_len=64)

Again, a great improvement for forward, grad_x and grad_w tensors, with all having approximate unit scale. Note that these scaling rules are robust to changes in hyperparameters too. You can experiment with the initial config to verify this.

We now have our key building-blocks 🧱 Let's put it all together!

### The transformer block

In [22]:
class TransformerBlock(nn.Module):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__()
        self.attention_layer = Residual(Attention(config))
        self.mlp_layer = Residual(MLP(config))
        
    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        hidden_states = self.attention_layer(hidden_states, attention_mask)
        return self.mlp_layer(hidden_states)


class Residual(nn.Module):
    def __init__(self, f: nn.Module):
        super().__init__()
        self.f = f
    
    def forward(self, x: Tensor, *args):
        return x + self.f(x, *args)

In [23]:
class UnitScaledTransformerBlock(TransformerBlock):
    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.attention_layer = UnitScaledResidual(UnitScaledAttention(config))
        self.mlp_layer = UnitScaledResidual(UnitScaledMLP(config))

class UnitScaledResidual(Residual):
    def __init__(self, f: nn.Module, tau: float=0.2):
        super().__init__(f)
        self.tau = tau
    
    def forward(self, x: Tensor, *args):
        _x = scaled(x, beta=self.tau ** 0.5)
        y = self.f(_x, *args)
        return x * (1 - self.tau) ** 0.5 + scaled(y, alpha=self.tau ** 0.5)

Note that the scaling in our residual layer includes an important trick.

It's necessary for unit-scaled models to use a weighted sum when doing their residual-add, to down-weight the residual/trunk branch. If this is not included, training can fail as too much signal comes from each residual (regular models avoid this as their residual implicitly reduces scale).

However, the naïve implementation of this breaks unit scale. We get around this by delaying the weighting until the end of the residual branch in the backward pass (see paper for more details).
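
The effect of the weighted residual sum on forward scale can be illustrated with a small simulation. This sketch assumes each residual branch produces fresh, independent unit-scale values, which real branches do not (they depend on the trunk), so treat it as intuition only. It does not model the delayed backward-pass weighting.

```python
import math
import random

random.seed(0)


def std(values):
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


# tau matches the UnitScaledResidual default; depth is 2 residuals x 6 blocks.
n, tau, depth = 50_000, 0.2, 12
plain = [random.gauss(0.0, 1.0) for _ in range(n)]
weighted = list(plain)

for _ in range(depth):
    # Stand-in for a residual branch: independent unit-scale output (assumption).
    branch = [random.gauss(0.0, 1.0) for _ in range(n)]
    plain = [x + y for x, y in zip(plain, branch)]
    weighted = [x * (1 - tau) ** 0.5 + y * tau ** 0.5
                for x, y in zip(weighted, branch)]

plain_std = std(plain)
weighted_std = std(weighted)
print(f"plain residual add:    std ~ {plain_std:.2f}")
print(f"weighted residual add: std ~ {weighted_std:.2f}")
```

Because `(1 - tau) + tau = 1`, the weighted sum keeps the trunk's variance at one regardless of depth, whereas a plain add lets it grow with every layer.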

### The full transformer

Note that our implementation contains a little extra boilerplate to make it 🤗-compliant.

In [24]:
class NanoGPTModel(PreTrainedModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.input_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.ln = nn.LayerNorm(config.hidden_size)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.loss_fn = nn.CrossEntropyLoss()
        self.apply(init_weights(glorot_init))

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        position_ids: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithCrossAttentions:
        if position_ids is None:  # True, except at test time
            position_ids = torch.arange(0, input_ids.shape[-1], device=input_ids.device)
            position_ids = position_ids.unsqueeze(0)

        attention_mask = attention_mask[:, None, None, :].to(dtype=self.dtype)
        attention_mask = (1.0 - attention_mask) * -10_000

        hidden_states = self.input_embeddings(input_ids)
        hidden_states = self.dropout(hidden_states)
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln(hidden_states)
        logits = self.lm_head(hidden_states)

        if labels is None:
            return CausalLMOutputWithCrossAttentions(logits=logits)
        else:
            labels = torch.roll(labels, -1, 1)
            labels[:, -1] = -100  # By default, ignore_index of CrossEntropyLoss is -100
            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            return CausalLMOutputWithCrossAttentions(loss=loss)

    def get_input_embeddings(self) -> torch.nn.Embedding:
        return self.input_embeddings

    def prepare_inputs_for_generation(
        self, input_ids: Tensor, **kwargs
    ) -> Dict[str, Tensor]:
        attention_mask = kwargs["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

In [25]:
def analyse_full_model(model, config, batch_size=64, seq_len=16):
    stats = instrument(model)
    input_ids = labels = torch.randint(0, config.vocab_size, size=(batch_size, seq_len))
    attention_mask = torch.zeros(batch_size, seq_len)
    y = model(input_ids, attention_mask, labels=labels).loss
    y.backward()
    visualise(stats)

Here's the scaling for the entire (non-unit-scaled) transformer:

In [26]:
model = NanoGPTModel(config)
analyse_full_model(model, config)

These results clearly demonstrate the inadequacy of the standard approach to model design when it comes to scale. The grad_x and grad_w values are very far from having unit scale, creating the potential for numerical issues.

A key issue introduced by this layer is the cross-entropy loss definition. This typically uses a `1 / batch_size` term to average over labels, which has the effect of dramatically under-scaling gradients in the backward pass.
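
The size of this effect can be estimated by hand. The sketch below assumes the softmax output is uniform over the vocabulary (a rough model of the state at initialisation) and ignores the masked final position; `ce_logit_grad_std` is a hypothetical helper.

```python
import math


def ce_logit_grad_std(num_classes: int, num_tokens: int,
                      reduction: str = "mean") -> float:
    """Std of d(loss)/d(logits), assuming a uniform softmax output."""
    p = 1.0 / num_classes
    # Each gradient row is softmax - one_hot: one entry p - 1, the rest p.
    row = [p - 1.0] + [p] * (num_classes - 1)
    row_std = math.sqrt(sum(g * g for g in row) / num_classes)  # row mean is zero
    return row_std / num_tokens if reduction == "mean" else row_std


vocab, tokens = 384, 64 * 16  # defaults used in analyse_full_model
mean_std = ce_logit_grad_std(vocab, tokens, "mean")
sum_std = ce_logit_grad_std(vocab, tokens, "sum")
print(f"mean reduction: std ~ {mean_std:.2e} (log2 ~ {math.log2(mean_std):.1f})")
print(f"sum reduction, times vocab ** 0.5: std ~ {sum_std * vocab ** 0.5:.3f}")
```

Under these assumptions the mean-reduced gradient starts near the bottom of FP16's normal range before it has passed through a single layer, while a sum reduction combined with a `vocab ** 0.5` factor brings it back to roughly unit scale.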

We can also see here that a single loss scaling factor for all gradients isn't ideal—what we really need is per-op scaling! Again, we'll fix this for the unit-scaled implementation:

In [27]:
class UnitScaledNanoGPTModel(NanoGPTModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig) -> None:
        super().__init__(config)
        self.input_embeddings = UnitScaledEmbedding(config.vocab_size, config.hidden_size)
        self.dropout = UnitScaledDropout(config.dropout)
        self.ln = UnitScaledLayerNorm(config.hidden_size)
        self.blocks = nn.ModuleList([
            UnitScaledTransformerBlock(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.lm_head = UnitScaledLinear(
            config.hidden_size, config.vocab_size, bias=False, scale_for="grad_x"
        )
        self.loss_fn = UnitScaledCrossEntropyLoss()
        self.apply(init_weights(unit_init))

class UnitScaledEmbedding(nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        batch_size = np.prod(input.shape)
        weight = scaled(self.weight, beta=self.num_embeddings / batch_size)
        return F.embedding(input, weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq, self.sparse)

class UnitScaledCrossEntropyLoss(nn.CrossEntropyLoss):
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # The logits are flattened by the caller, so input is (num_tokens, vocab_size)
        num_tokens, vocab_size = input.shape
        # The grad scale below is equivalent to vocab_size ** 0.5
        input = scaled(input, beta=vocab_size / (vocab_size) ** 0.5)
        loss = F.cross_entropy(input, target, reduction='sum')
        return scaled(loss, alpha=1 / num_tokens)

In [28]:
model = UnitScaledNanoGPTModel(config)
analyse_full_model(model, config)

Success! 🥂 🎊 🍾 We've managed to keep every tensor in the model at unit scale for the forward and backward pass. All that was required was to re-implement standard layers with the right scaling factors.

Of course, this only gives us the right scaling *at initialisation*. Scales inevitably drift throughout training, but we've given ourselves headroom by starting in the ideal place. The results in our paper show that this is sufficient to enable accurate, out-of-the-box training of many models, up to the size of BERT Large (we haven't tested anything larger yet!).

## Training

### Regular model

To show that this really does work, let's train both our models.

As in Karpathy's original NanoGPT, we'll train on the small Shakespeare dataset for a few minutes and see if we can get something that captures the rough style. Starting with the regular (non-unit-scaled) model:

In [29]:
# TODO: hide
from optimum.graphcore.generation_utils import IPUGenerationMixin
from optimum.graphcore.modeling_utils import (
    PipelineMixin,
    outline_attribute,
    register,
    tied_weight_model,
)

@tied_weight_model(NanoGPTModel)
@register(NanoGPTModel)
class PipelinedNanoGPTModel(NanoGPTModel, PipelineMixin, IPUGenerationMixin):
    def parallelize(self):
        self._hooks = [outline_attribute(self.ln, "LayerNorm")]

    def deparallelize(self):
        pass

@tied_weight_model(UnitScaledNanoGPTModel)
@register(UnitScaledNanoGPTModel)
class PipelinedUnitScaledNanoGPTModel(
    UnitScaledNanoGPTModel, PipelineMixin, IPUGenerationMixin
):
    def parallelize(self):
        self._hooks = [outline_attribute(self.ln, "UnitScaledLayerNorm")]

    def deparallelize(self):
        pass


In [None]:
import poptorch_experimental_addons as pea

def scaled(X, alpha=1.0, beta=1.0):
    # Forward: Y = X * alpha
    # Backward: grad_X = grad_Y * beta
    return pea.autograd_proxy(X * alpha, X * beta)

In [None]:
from functools import partial

from datasets.load import load_dataset
from optimum.graphcore import (
    IPUConfig,
    IPUTrainer,
    IPUTrainingArguments,
    pipeline,
    pipelines,
)
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

In [None]:
dataset = load_dataset("tiny_shakespeare")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

config = NanoGPTConfig(vocab_size=len(tokenizer), eos_token_id=tokenizer.eos_token_id)
batch_sz = 16
seq_len = 128
ipu_config = IPUConfig(
    gradient_accumulation_steps=5 * 64 // batch_sz,
    layers_per_ipu=[config.num_hidden_layers],
    executable_cache_dir="./exe_cache",
)

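def split_and_tokenize(data, seq_len, batch_sz):
    # Only the first entry of the batch is read, which assumes each split is
    # stored as one long string. We tokenize it once and slice the ids into
    # fixed-length sequences.
    tokens = tokenizer(data["text"])
    ids = tokens["input_ids"][0]
    seqs = [ids[i : i + seq_len] for i in range(0, len(ids), seq_len)]
    seqs = seqs[: len(seqs) // batch_sz * batch_sz]  # make divisible by batch size
    return {"input_ids": seqs}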


prep_data = partial(split_and_tokenize, seq_len=seq_len, batch_sz=batch_sz)
tokenized_dataset = dataset.map(
    prep_data, batched=True, remove_columns=dataset["train"].column_names
)

Found cached dataset tiny_shakespeare (/nethome/charlieb/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e)
100%|██████████| 3/3 [00:00<00:00, 1062.39it/s]
Loading cached processed dataset at /nethome/charlieb/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e/cache-bebf307cc6e838af.arrow
Loading cached processed dataset at /nethome/charlieb/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e/cache-c2f9b4618a37a1c8.arrow
Loading cached processed dataset at /nethome/charlieb/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e/cache-6df33a0afb6a9efa.arrow
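
To make the chunking step concrete, here is a small standalone sketch of the same slicing logic applied to a toy list of integers instead of real token ids. The helper name `chunk_tokens` and the toy values are made up for illustration.

```python
def chunk_tokens(ids, seq_len, batch_sz):
    # Slice the id stream into consecutive windows of at most seq_len tokens.
    seqs = [ids[i : i + seq_len] for i in range(0, len(ids), seq_len)]
    # Drop trailing windows so the count is a multiple of the batch size.
    return seqs[: len(seqs) // batch_sz * batch_sz]


toy_ids = list(range(11))
chunks = chunk_tokens(toy_ids, seq_len=2, batch_sz=4)
print(chunks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Eleven ids give six windows, and only the first four survive the trim, so a few trailing tokens are discarded in exchange for every batch being full.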


In [None]:
train_args = IPUTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=batch_sz,
    per_device_eval_batch_size=batch_sz,
    evaluation_strategy="steps",
    eval_steps=250,
    logging_steps=10,
    max_steps=1000,
    weight_decay=0.1,
    warmup_steps=100,
    lr_scheduler_type="linear",
    learning_rate=2e-3,
    report_to="wandb",
    loss_scaling=2**6,
)

model = NanoGPTModel(config)

trainer = IPUTrainer(
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    ipu_config=ipu_config,
)
trainer.train()
trainer.save_model("trained_model/")

max_steps is given, it will override any value given in num_train_epochs
Compiling Model...
Graph compilation: 100%|██████████| 100/100 [01:35<00:00]
2023-03-12T23:31:42.132222Z popart:session 1153668.1153668 W: Rng state buffer was not serialized.You did not load poplar Engine.Remember that if you would like to run the model using the model runtime then you have to create your own buffer and callback in your model runtime application for rngStateTensor.
Compiled/Loaded model in 101.21531434357166 secs
***** Running training *****
  Num examples = 7840
  Num Epochs = 42
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 320
  Gradient Accumulation steps = 20
  Total optimization steps = 1000

  1%|          | 11/1000 [00:01<02:33,  6.44it/s]

{'loss': 5.0271, 'learning_rate': 0.0002, 'epoch': 0.42}


  2%|▏         | 21/1000 [00:03<02:23,  6.81it/s]

{'loss': 3.4369, 'learning_rate': 0.0004, 'epoch': 0.83}


  3%|▎         | 31/1000 [00:04<02:10,  7.41it/s]

{'loss': 2.9324, 'learning_rate': 0.0006, 'epoch': 1.25}


  4%|▍         | 41/1000 [00:05<02:09,  7.38it/s]

{'loss': 2.6041, 'learning_rate': 0.0008, 'epoch': 1.67}


  5%|▌         | 51/1000 [00:07<02:23,  6.59it/s]

{'loss': 2.4713, 'learning_rate': 0.001, 'epoch': 2.08}


  6%|▌         | 61/1000 [00:08<02:09,  7.26it/s]

{'loss': 2.3746, 'learning_rate': 0.0012, 'epoch': 2.5}


  7%|▋         | 71/1000 [00:10<02:16,  6.79it/s]

{'loss': 2.2982, 'learning_rate': 0.0014, 'epoch': 2.92}


  8%|▊         | 81/1000 [00:11<02:22,  6.43it/s]

{'loss': 2.2604, 'learning_rate': 0.0016, 'epoch': 3.33}


  9%|▉         | 91/1000 [00:13<02:10,  6.96it/s]

{'loss': 2.1883, 'learning_rate': 0.0018000000000000002, 'epoch': 3.75}


 10%|█         | 101/1000 [00:14<02:09,  6.93it/s]

{'loss': 2.1166, 'learning_rate': 0.002, 'epoch': 4.17}


 11%|█         | 111/1000 [00:15<02:03,  7.19it/s]

{'loss': 2.0984, 'learning_rate': 0.001977777777777778, 'epoch': 4.58}


 12%|█▏        | 121/1000 [00:17<02:04,  7.08it/s]

{'loss': 2.0443, 'learning_rate': 0.0019555555555555554, 'epoch': 5.0}


 13%|█▎        | 131/1000 [00:18<02:05,  6.92it/s]

{'loss': 1.9692, 'learning_rate': 0.0019333333333333333, 'epoch': 5.42}


 14%|█▍        | 141/1000 [00:20<01:56,  7.39it/s]

{'loss': 1.9189, 'learning_rate': 0.0019111111111111113, 'epoch': 5.83}


 15%|█▌        | 151/1000 [00:21<01:58,  7.19it/s]

{'loss': 1.8759, 'learning_rate': 0.001888888888888889, 'epoch': 6.25}


 16%|█▌        | 161/1000 [00:23<02:02,  6.86it/s]

{'loss': 1.8571, 'learning_rate': 0.0018666666666666666, 'epoch': 6.67}


 17%|█▋        | 171/1000 [00:24<02:02,  6.75it/s]

{'loss': 1.7812, 'learning_rate': 0.0018444444444444446, 'epoch': 7.08}


 18%|█▊        | 181/1000 [00:25<01:51,  7.34it/s]

{'loss': 1.7725, 'learning_rate': 0.0018222222222222223, 'epoch': 7.5}


 19%|█▉        | 191/1000 [00:27<01:58,  6.80it/s]

{'loss': 1.7665, 'learning_rate': 0.0018000000000000002, 'epoch': 7.92}


 20%|██        | 201/1000 [00:28<02:02,  6.50it/s]

{'loss': 1.7496, 'learning_rate': 0.0017777777777777776, 'epoch': 8.33}


 21%|██        | 211/1000 [00:30<01:49,  7.21it/s]

{'loss': 1.7112, 'learning_rate': 0.0017555555555555556, 'epoch': 8.75}


 22%|██▏       | 221/1000 [00:31<01:44,  7.48it/s]

{'loss': 1.6651, 'learning_rate': 0.0017333333333333335, 'epoch': 9.17}


 23%|██▎       | 231/1000 [00:33<01:45,  7.27it/s]

{'loss': 1.6963, 'learning_rate': 0.0017111111111111112, 'epoch': 9.58}


 24%|██▍       | 241/1000 [00:34<01:49,  6.93it/s]

{'loss': 1.6472, 'learning_rate': 0.0016888888888888889, 'epoch': 10.0}


 25%|██▌       | 250/1000 [00:35<01:43,  7.25it/s]

{'loss': 1.6521, 'learning_rate': 0.0016666666666666668, 'epoch': 10.42}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:24<00:00]
Compiled/Loaded model in 28.387929940596223 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16
                                                  
 25%|██▌       | 250/1000 [01:07<01:43,  7.25it/s]

{'eval_loss': 1.68359375, 'eval_runtime': 0.6334, 'eval_samples_per_second': 682.07, 'eval_steps_per_second': 42.629, 'epoch': 10.42}


Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
 26%|██▌       | 261/1000 [01:15<05:40,  2.17it/s]  

{'loss': 1.6067, 'learning_rate': 0.0016444444444444445, 'epoch': 10.83}


 27%|██▋       | 271/1000 [01:16<01:44,  6.96it/s]

{'loss': 1.5835, 'learning_rate': 0.0016222222222222222, 'epoch': 11.25}


 28%|██▊       | 281/1000 [01:17<01:45,  6.82it/s]

{'loss': 1.5747, 'learning_rate': 0.0016, 'epoch': 11.67}


 29%|██▉       | 291/1000 [01:19<01:46,  6.65it/s]

{'loss': 1.5582, 'learning_rate': 0.0015777777777777778, 'epoch': 12.08}


 30%|███       | 301/1000 [01:20<01:42,  6.84it/s]

{'loss': 1.5634, 'learning_rate': 0.0015555555555555557, 'epoch': 12.5}


 31%|███       | 311/1000 [01:22<01:29,  7.74it/s]

{'loss': 1.5479, 'learning_rate': 0.0015333333333333334, 'epoch': 12.92}


 32%|███▏      | 321/1000 [01:23<01:34,  7.15it/s]

{'loss': 1.5112, 'learning_rate': 0.001511111111111111, 'epoch': 13.33}


 33%|███▎      | 331/1000 [01:24<01:27,  7.65it/s]

{'loss': 1.5323, 'learning_rate': 0.001488888888888889, 'epoch': 13.75}


 34%|███▍      | 341/1000 [01:26<01:25,  7.67it/s]

{'loss': 1.5062, 'learning_rate': 0.0014666666666666667, 'epoch': 14.17}


 35%|███▌      | 351/1000 [01:27<01:22,  7.88it/s]

{'loss': 1.4773, 'learning_rate': 0.0014444444444444444, 'epoch': 14.58}


 36%|███▌      | 361/1000 [01:28<01:34,  6.78it/s]

{'loss': 1.4915, 'learning_rate': 0.0014222222222222223, 'epoch': 15.0}


 37%|███▋      | 371/1000 [01:30<01:21,  7.67it/s]

{'loss': 1.4746, 'learning_rate': 0.0014, 'epoch': 15.42}


 38%|███▊      | 381/1000 [01:31<01:19,  7.74it/s]

{'loss': 1.4629, 'learning_rate': 0.001377777777777778, 'epoch': 15.83}


 39%|███▉      | 391/1000 [01:32<01:20,  7.60it/s]

{'loss': 1.4767, 'learning_rate': 0.0013555555555555556, 'epoch': 16.25}


 40%|████      | 401/1000 [01:34<01:28,  6.80it/s]

{'loss': 1.4582, 'learning_rate': 0.0013333333333333333, 'epoch': 16.67}


 41%|████      | 411/1000 [01:35<01:18,  7.52it/s]

{'loss': 1.4351, 'learning_rate': 0.0013111111111111112, 'epoch': 17.08}


 42%|████▏     | 421/1000 [01:37<01:14,  7.72it/s]

{'loss': 1.4083, 'learning_rate': 0.001288888888888889, 'epoch': 17.5}


 43%|████▎     | 431/1000 [01:38<01:19,  7.13it/s]

{'loss': 1.4168, 'learning_rate': 0.0012666666666666666, 'epoch': 17.92}


 44%|████▍     | 441/1000 [01:39<01:20,  6.91it/s]

{'loss': 1.4179, 'learning_rate': 0.0012444444444444445, 'epoch': 18.33}


 45%|████▌     | 451/1000 [01:41<01:16,  7.20it/s]

{'loss': 1.4348, 'learning_rate': 0.0012222222222222224, 'epoch': 18.75}


 46%|████▌     | 461/1000 [01:42<01:10,  7.62it/s]

{'loss': 1.4024, 'learning_rate': 0.0012, 'epoch': 19.17}


 47%|████▋     | 471/1000 [01:44<01:15,  7.02it/s]

{'loss': 1.3835, 'learning_rate': 0.0011777777777777778, 'epoch': 19.58}


 48%|████▊     | 481/1000 [01:45<01:13,  7.11it/s]

{'loss': 1.4061, 'learning_rate': 0.0011555555555555555, 'epoch': 20.0}


 49%|████▉     | 491/1000 [01:46<01:07,  7.51it/s]

{'loss': 1.3834, 'learning_rate': 0.0011333333333333334, 'epoch': 20.42}


 50%|█████     | 500/1000 [01:48<01:11,  6.98it/s]

{'loss': 1.3814, 'learning_rate': 0.0011111111111111111, 'epoch': 20.83}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:00<00:00]
Compiled/Loaded model in 3.5375875793397427 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16
                                                  
 50%|█████     | 500/1000 [01:54<01:11,  6.98it/s]Saving model checkpoint to out/checkpoint-500
Configuration saved in out/checkpoint-500/ipu_config.json


{'eval_loss': 1.517578125, 'eval_runtime': 0.4535, 'eval_samples_per_second': 952.581, 'eval_steps_per_second': 59.536, 'epoch': 20.83}


Graph compilation: 100%|██████████| 100/100 [01:27<00:00]
2023-03-12T23:35:14.319987Z popart:session 1153668.1153668 W: Rng state buffer was not serialized.You did not load poplar Engine.Remember that if you would like to run the model using the model runtime then you have to create your own buffer and callback in your model runtime application for rngStateTensor.
 51%|█████     | 511/1000 [03:29<08:05,  1.01it/s]  

{'loss': 1.3745, 'learning_rate': 0.0010888888888888888, 'epoch': 21.25}


 52%|█████▏    | 521/1000 [03:30<01:26,  5.51it/s]

{'loss': 1.3424, 'learning_rate': 0.0010666666666666667, 'epoch': 21.67}


 53%|█████▎    | 531/1000 [03:32<01:06,  7.04it/s]

{'loss': 1.3691, 'learning_rate': 0.0010444444444444446, 'epoch': 22.08}


 54%|█████▍    | 541/1000 [03:33<01:06,  6.85it/s]

{'loss': 1.3479, 'learning_rate': 0.0010222222222222221, 'epoch': 22.5}


 55%|█████▌    | 551/1000 [03:34<01:00,  7.37it/s]

{'loss': 1.3419, 'learning_rate': 0.001, 'epoch': 22.92}


 56%|█████▌    | 561/1000 [03:36<01:04,  6.85it/s]

{'loss': 1.3301, 'learning_rate': 0.0009777777777777777, 'epoch': 23.33}


 57%|█████▋    | 571/1000 [03:37<01:01,  6.98it/s]

{'loss': 1.3346, 'learning_rate': 0.0009555555555555556, 'epoch': 23.75}


 58%|█████▊    | 581/1000 [03:38<00:56,  7.40it/s]

{'loss': 1.3392, 'learning_rate': 0.0009333333333333333, 'epoch': 24.17}


 59%|█████▉    | 591/1000 [03:40<00:56,  7.29it/s]

{'loss': 1.3545, 'learning_rate': 0.0009111111111111111, 'epoch': 24.58}


 60%|██████    | 601/1000 [03:41<00:56,  7.09it/s]

{'loss': 1.3281, 'learning_rate': 0.0008888888888888888, 'epoch': 25.0}


 61%|██████    | 611/1000 [03:43<00:50,  7.72it/s]

{'loss': 1.2939, 'learning_rate': 0.0008666666666666667, 'epoch': 25.42}


 62%|██████▏   | 621/1000 [03:44<00:51,  7.41it/s]

{'loss': 1.3199, 'learning_rate': 0.0008444444444444444, 'epoch': 25.83}


 63%|██████▎   | 631/1000 [03:45<00:48,  7.65it/s]

{'loss': 1.2957, 'learning_rate': 0.0008222222222222222, 'epoch': 26.25}


 64%|██████▍   | 641/1000 [03:47<00:52,  6.89it/s]

{'loss': 1.2911, 'learning_rate': 0.0008, 'epoch': 26.67}


 65%|██████▌   | 651/1000 [03:48<00:55,  6.33it/s]

{'loss': 1.2979, 'learning_rate': 0.0007777777777777778, 'epoch': 27.08}


 66%|██████▌   | 661/1000 [03:50<00:48,  6.95it/s]

{'loss': 1.2705, 'learning_rate': 0.0007555555555555555, 'epoch': 27.5}


 67%|██████▋   | 671/1000 [03:51<00:50,  6.46it/s]

{'loss': 1.299, 'learning_rate': 0.0007333333333333333, 'epoch': 27.92}


 68%|██████▊   | 681/1000 [03:53<00:43,  7.34it/s]

{'loss': 1.2711, 'learning_rate': 0.0007111111111111111, 'epoch': 28.33}


 69%|██████▉   | 691/1000 [03:54<00:43,  7.18it/s]

{'loss': 1.245, 'learning_rate': 0.000688888888888889, 'epoch': 28.75}


 70%|███████   | 701/1000 [03:55<00:43,  6.86it/s]

{'loss': 1.2573, 'learning_rate': 0.0006666666666666666, 'epoch': 29.17}


 71%|███████   | 711/1000 [03:57<00:40,  7.07it/s]

{'loss': 1.2537, 'learning_rate': 0.0006444444444444444, 'epoch': 29.58}


 72%|███████▏  | 721/1000 [03:58<00:41,  6.79it/s]

{'loss': 1.2746, 'learning_rate': 0.0006222222222222223, 'epoch': 30.0}


 73%|███████▎  | 731/1000 [04:00<00:42,  6.33it/s]

{'loss': 1.2401, 'learning_rate': 0.0006, 'epoch': 30.42}


 74%|███████▍  | 741/1000 [04:01<00:40,  6.46it/s]

{'loss': 1.2556, 'learning_rate': 0.0005777777777777778, 'epoch': 30.83}


 75%|███████▌  | 750/1000 [04:02<00:37,  6.69it/s]

{'loss': 1.2387, 'learning_rate': 0.0005555555555555556, 'epoch': 31.25}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:24<00:00]
Compiled/Loaded model in 27.859491711482406 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16
                                                  
 75%|███████▌  | 750/1000 [04:34<00:37,  6.69it/s]

{'eval_loss': 1.4990234375, 'eval_runtime': 0.6159, 'eval_samples_per_second': 701.402, 'eval_steps_per_second': 43.838, 'epoch': 31.25}


Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
 76%|███████▌  | 761/1000 [04:41<01:48,  2.20it/s]

{'loss': 1.2353, 'learning_rate': 0.0005333333333333334, 'epoch': 31.67}


 77%|███████▋  | 771/1000 [04:42<00:35,  6.54it/s]

{'loss': 1.2286, 'learning_rate': 0.0005111111111111111, 'epoch': 32.08}


 78%|███████▊  | 781/1000 [04:43<00:30,  7.28it/s]

{'loss': 1.2085, 'learning_rate': 0.0004888888888888889, 'epoch': 32.5}


 79%|███████▉  | 791/1000 [04:45<00:30,  6.92it/s]

{'loss': 1.2312, 'learning_rate': 0.00046666666666666666, 'epoch': 32.92}


 80%|████████  | 801/1000 [04:46<00:29,  6.75it/s]

{'loss': 1.2042, 'learning_rate': 0.0004444444444444444, 'epoch': 33.33}


 81%|████████  | 811/1000 [04:47<00:25,  7.50it/s]

{'loss': 1.2057, 'learning_rate': 0.0004222222222222222, 'epoch': 33.75}


 82%|████████▏ | 821/1000 [04:49<00:24,  7.21it/s]

{'loss': 1.208, 'learning_rate': 0.0004, 'epoch': 34.17}


 83%|████████▎ | 831/1000 [04:50<00:25,  6.58it/s]

{'loss': 1.1915, 'learning_rate': 0.00037777777777777777, 'epoch': 34.58}


 84%|████████▍ | 841/1000 [04:52<00:23,  6.65it/s]

{'loss': 1.1996, 'learning_rate': 0.00035555555555555557, 'epoch': 35.0}


 85%|████████▌ | 851/1000 [04:53<00:20,  7.35it/s]

{'loss': 1.1902, 'learning_rate': 0.0003333333333333333, 'epoch': 35.42}


 86%|████████▌ | 861/1000 [04:54<00:20,  6.90it/s]

{'loss': 1.1891, 'learning_rate': 0.0003111111111111111, 'epoch': 35.83}


 87%|████████▋ | 871/1000 [04:56<00:16,  7.66it/s]

{'loss': 1.1698, 'learning_rate': 0.0002888888888888889, 'epoch': 36.25}


 88%|████████▊ | 881/1000 [04:57<00:17,  6.71it/s]

{'loss': 1.1554, 'learning_rate': 0.0002666666666666667, 'epoch': 36.67}


 89%|████████▉ | 891/1000 [04:59<00:16,  6.64it/s]

{'loss': 1.185, 'learning_rate': 0.00024444444444444443, 'epoch': 37.08}


 90%|█████████ | 901/1000 [05:00<00:15,  6.42it/s]

{'loss': 1.1438, 'learning_rate': 0.0002222222222222222, 'epoch': 37.5}


 91%|█████████ | 911/1000 [05:02<00:13,  6.79it/s]

{'loss': 1.1703, 'learning_rate': 0.0002, 'epoch': 37.92}


 92%|█████████▏| 921/1000 [05:03<00:11,  6.86it/s]

{'loss': 1.1346, 'learning_rate': 0.00017777777777777779, 'epoch': 38.33}


 93%|█████████▎| 931/1000 [05:04<00:09,  6.99it/s]

{'loss': 1.1595, 'learning_rate': 0.00015555555555555556, 'epoch': 38.75}


 94%|█████████▍| 941/1000 [05:06<00:08,  7.15it/s]

{'loss': 1.1626, 'learning_rate': 0.00013333333333333334, 'epoch': 39.17}


 95%|█████████▌| 951/1000 [05:07<00:06,  7.20it/s]

{'loss': 1.1399, 'learning_rate': 0.0001111111111111111, 'epoch': 39.58}


 96%|█████████▌| 961/1000 [05:09<00:05,  6.75it/s]

{'loss': 1.1473, 'learning_rate': 8.888888888888889e-05, 'epoch': 40.0}


 97%|█████████▋| 971/1000 [05:10<00:04,  7.13it/s]

{'loss': 1.1273, 'learning_rate': 6.666666666666667e-05, 'epoch': 40.42}


 98%|█████████▊| 981/1000 [05:12<00:02,  6.78it/s]

{'loss': 1.154, 'learning_rate': 4.4444444444444447e-05, 'epoch': 40.83}


 99%|█████████▉| 991/1000 [05:13<00:01,  6.39it/s]

{'loss': 1.1563, 'learning_rate': 2.2222222222222223e-05, 'epoch': 41.25}


100%|██████████| 1000/1000 [05:14<00:00,  7.23it/s]

{'loss': 1.1297, 'learning_rate': 0.0, 'epoch': 41.67}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:00<00:00]
Compiled/Loaded model in 3.576560833491385 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16
                                                   
100%|██████████| 1000/1000 [05:21<00:00,  7.23it/s]Saving model checkpoint to out/checkpoint-1000
Configuration saved in out/checkpoint-1000/ipu_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 1000/1000 [05:21<00:00,  3.11it/s]
Saving model checkpoint to trained_model/
Configuration saved in trained_model/ipu_config.json


{'eval_loss': 1.5205078125, 'eval_runtime': 0.4475, 'eval_samples_per_second': 965.284, 'eval_steps_per_second': 60.33, 'epoch': 41.67}
{'train_runtime': 326.0602, 'train_samples_per_second': 981.414, 'train_steps_per_second': 3.067, 'train_loss': 1.5380458984375, 'epoch': 41.67}


The above should give a Weights and Biases link, which you can click to see the training and evaluation loss curves.
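
The learning rates and epoch counts in the log can also be sanity-checked by hand. The sketch below assumes the usual linear warmup followed by linear decay to zero, and infers the number of optimizer steps per epoch from the logged example count and total batch size. `linear_lr` is a hypothetical helper, not part of `IPUTrainer`.

```python
def linear_lr(step, base_lr=2e-3, warmup_steps=100, max_steps=1000):
    # Linear warmup from 0 to base_lr, then linear decay back to 0.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * max(0, max_steps - step) / (max_steps - warmup_steps)


# Effective batch: 16 sequences x 20 accumulation steps = 320 sequences per update.
steps_per_epoch = 7840 // (16 * 20)

for step in (10, 100, 250, 1000):
    print(step, round(linear_lr(step), 6), round(step / steps_per_epoch, 2))
```

If the assumptions hold, the printed learning rates and epochs should line up with the values logged at the same steps.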

Note that we use loss scaling here, although this model is small enough to train without it. The same certainly doesn't hold for larger models.
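
As a rough illustration of why loss scaling matters in half precision, the sketch below round-trips a small number through IEEE 754 float16 using only the standard library. The gradient magnitude is an illustrative value, not something measured from this model, and `to_fp16` is a hypothetical helper.

```python
import struct


def to_fp16(x: float) -> float:
    # Round-trip through IEEE 754 half precision (struct format "e").
    return struct.unpack("<e", struct.pack("<e", x))[0]


tiny_grad = 1e-8   # illustrative gradient magnitude, not measured from the model
loss_scale = 2**6  # same value as loss_scaling in the training arguments

unscaled = to_fp16(tiny_grad)             # too small for fp16, flushes to zero
scaled = to_fp16(tiny_grad * loss_scale)  # survives as a subnormal fp16 value
recovered = scaled / loss_scale           # unscale in higher precision

print(unscaled, scaled, recovered)
```

Multiplying the loss by a constant shifts every gradient up by the same factor, so values that would otherwise underflow stay representable until they are unscaled again.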

To really get a sense of how the model's doing, we ought to get it to write some Shakespeare. Here are two attempts at *The Tempest*:

In [None]:
pipelines.check_model_type = lambda self, supported_models: ...

final_model = NanoGPTModel.from_pretrained("trained_model/")

TEST_INPUT = """PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits"""
pipe = pipeline(
    "text-generation",
    ipu_config=ipu_config.to_dict(),  # TODO: feature request -> no to_dict()
    model=final_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
)

outputs = pipe(TEST_INPUT, num_return_sequences=2, temperature=0.4)
for i, output in enumerate(outputs):
    print(f"\n===== Completion [{i+1}] =====\n")
    print(output["generated_text"])

IPUConfig {
  "auto_loss_scaling": false,
  "device_iterations": 1,
  "embedding_serialization_factor": 1,
  "enable_half_partials": true,
  "executable_cache_dir": "./exe_cache",
  "execute_encoder_on_cpu_for_generation": false,
  "gradient_accumulation_steps": 20,
  "inference_device_iterations": 1,
  "inference_replication_factor": 1,
  "ipus_per_replica": 1,
  "layers_per_ipu": [
    6
  ],
  "matmul_proportion": 0.6,
  "optimizer_state_offchip": true,
  "optimum_version": "1.6.1",
  "output_mode": "final",
  "recompute_checkpoint_every_layer": false,
  "replicated_tensor_sharding": false,
  "replication_factor": 1,
  "seed": null,
  "transformers_version": "4.25.1"
}

Graph compilation: 100%|██████████| 100/100 [00:25<00:00]



===== Completion [1] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of the dead;
And then the times of your honour to the sea
That bear the princes of the base that lives
That I may be a little palase you to him
That you have present made the common made them
And friends that you find the common people.

PRINCE:
Then, I shall be so, I do another them and
The sea of his eyes of his charged that makes me.

LADY ANNE:
Why, being it so?

GLOUCESTER:
I cannot be guarded thee to speak.

BUCK

===== Completion [2] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of the love
At the sun of his son worthy for your love:
I have been yourself and a good lords,
To say the perfect of the court, and the charity
Shall be so and break in the house of York.

QUEEN ELIZABETH:
The king is not a subject of send the world.

KING RICHARD III:
I cannot tell you for the court-house.

QUEEN ELIZABETH:
A goodly fri
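
The completions above were sampled with `temperature=0.4`. As a rough illustration of what that setting does, here is a standalone sketch (not the `transformers` generation code) that divides some made-up next-token scores by the temperature before applying softmax. The helper `softmax_with_temperature` and the toy logits are hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    # Divide logits by the temperature, then apply a numerically stable softmax.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.0]  # hypothetical next-token scores
probs_default = softmax_with_temperature(toy_logits, temperature=1.0)
probs_sharp = softmax_with_temperature(toy_logits, temperature=0.4)
print([round(p, 3) for p in probs_default])
print([round(p, 3) for p in probs_sharp])
```

A temperature below 1 concentrates probability on the highest-scoring tokens, which tends to give more conservative text at the cost of variety.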

These completions aren't going to win any literary awards, but they're pretty good for a few minutes of training. Now let's see if our unit-scaled model can do the same...

### Unit-scaled model

All we need to do is swap in our new model, remove the loss scaling (of course!), and increase the learning rate (as our weights are now larger due to unit-initialisation).

In [None]:
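# IPU-specific patch: enable the PopART pattern that prunes the unused forward
# pass of `autograd_proxy`, which our `scaled()` op relies on.
from types import MethodType

class _IPUConfig(IPUConfig):
    def to_options(self, *args, **kwargs):
        options = super().to_options(*args, **kwargs)
        options._popart.setPatterns(dict(AutogradProxyOpPattern=True))
        return options

    def for_pod_type(self, *args, **kwargs):
        # Keep the patched class when a pod-specific config is derived.
        config = super().for_pod_type(*args, **kwargs)
        config.__class__ = _IPUConfig
        return config

# Swap the class of the existing config in place so it picks up the overrides.
ipu_config.__class__ = _IPUConfig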

In [None]:
model = UnitScaledNanoGPTModel(config)

train_args.loss_scaling = 1.0
train_args.learning_rate = 2e-2

🤞 The moment of truth ... let's train our unit-scaled model 🤞

In [None]:
unit_scale_trainer = IPUTrainer(
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    ipu_config=ipu_config,
)

unit_scale_trainer.train()
unit_scale_trainer.save_model("trained_model/")

max_steps is given, it will override any value given in num_train_epochs
Compiling Model...
Graph compilation: 100%|██████████| 100/100 [02:02<00:00]
2023-03-12T23:39:55.384706Z popart:session 1153668.1153668 W: Rng state buffer was not serialized.You did not load poplar Engine.Remember that if you would like to run the model using the model runtime then you have to create your own buffer and callback in your model runtime application for rngStateTensor.
Compiled/Loaded model in 135.06661581527442 secs
***** Running training *****
  Num examples = 7840
  Num Epochs = 42
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 320
  Gradient Accumulation steps = 20
  Total optimization steps = 1000
  1%|          | 11/1000 [00:01<02:16,  7.23it/s]

{'loss': 5.8398, 'learning_rate': 0.002, 'epoch': 0.42}


  2%|▏         | 21/1000 [00:02<02:18,  7.05it/s]

{'loss': 3.726, 'learning_rate': 0.004, 'epoch': 0.83}


  3%|▎         | 31/1000 [00:04<02:28,  6.51it/s]

{'loss': 2.9838, 'learning_rate': 0.006, 'epoch': 1.25}


  4%|▍         | 41/1000 [00:05<02:28,  6.46it/s]

{'loss': 2.6725, 'learning_rate': 0.008, 'epoch': 1.67}


  5%|▌         | 51/1000 [00:07<02:19,  6.81it/s]

{'loss': 2.4875, 'learning_rate': 0.01, 'epoch': 2.08}


  6%|▌         | 61/1000 [00:08<02:23,  6.55it/s]

{'loss': 2.4027, 'learning_rate': 0.012, 'epoch': 2.5}


  7%|▋         | 71/1000 [00:10<02:16,  6.82it/s]

{'loss': 2.3176, 'learning_rate': 0.013999999999999999, 'epoch': 2.92}


  8%|▊         | 81/1000 [00:11<02:21,  6.51it/s]

{'loss': 2.2305, 'learning_rate': 0.016, 'epoch': 3.33}


  9%|▉         | 91/1000 [00:13<02:25,  6.23it/s]

{'loss': 2.1574, 'learning_rate': 0.018000000000000002, 'epoch': 3.75}


 10%|█         | 101/1000 [00:14<02:07,  7.03it/s]

{'loss': 2.1291, 'learning_rate': 0.02, 'epoch': 4.17}


 11%|█         | 111/1000 [00:16<02:17,  6.48it/s]

{'loss': 2.0818, 'learning_rate': 0.01977777777777778, 'epoch': 4.58}


 12%|█▏        | 121/1000 [00:17<02:20,  6.27it/s]

{'loss': 2.0438, 'learning_rate': 0.019555555555555555, 'epoch': 5.0}


 13%|█▎        | 131/1000 [00:19<02:10,  6.66it/s]

{'loss': 1.9888, 'learning_rate': 0.019333333333333334, 'epoch': 5.42}


 14%|█▍        | 141/1000 [00:20<02:06,  6.77it/s]

{'loss': 1.9536, 'learning_rate': 0.019111111111111113, 'epoch': 5.83}


 15%|█▌        | 151/1000 [00:21<02:01,  6.97it/s]

{'loss': 1.9127, 'learning_rate': 0.01888888888888889, 'epoch': 6.25}


 16%|█▌        | 161/1000 [00:23<01:54,  7.30it/s]

{'loss': 1.906, 'learning_rate': 0.018666666666666668, 'epoch': 6.67}


 17%|█▋        | 171/1000 [00:24<01:58,  6.99it/s]

{'loss': 1.8675, 'learning_rate': 0.018444444444444447, 'epoch': 7.08}


 18%|█▊        | 181/1000 [00:26<02:10,  6.29it/s]

{'loss': 1.8376, 'learning_rate': 0.018222222222222223, 'epoch': 7.5}


 19%|█▉        | 191/1000 [00:27<02:03,  6.54it/s]

{'loss': 1.8215, 'learning_rate': 0.018000000000000002, 'epoch': 7.92}


 20%|██        | 201/1000 [00:29<02:00,  6.62it/s]

{'loss': 1.7811, 'learning_rate': 0.017777777777777778, 'epoch': 8.33}


 21%|██        | 211/1000 [00:30<02:01,  6.50it/s]

{'loss': 1.7716, 'learning_rate': 0.017555555555555557, 'epoch': 8.75}


 22%|██▏       | 221/1000 [00:32<01:55,  6.75it/s]

{'loss': 1.7611, 'learning_rate': 0.017333333333333336, 'epoch': 9.17}


 23%|██▎       | 231/1000 [00:33<01:52,  6.84it/s]

{'loss': 1.757, 'learning_rate': 0.01711111111111111, 'epoch': 9.58}


 24%|██▍       | 241/1000 [00:35<01:50,  6.89it/s]

{'loss': 1.728, 'learning_rate': 0.01688888888888889, 'epoch': 10.0}


 25%|██▌       | 250/1000 [00:36<01:41,  7.42it/s]

{'loss': 1.6922, 'learning_rate': 0.016666666666666666, 'epoch': 10.42}


Compiling Model...
2023-03-12T23:41:00.767441Z popart:popart 1153668.1153668 W: `autograd_proxy(fwd, proxy)` has not pruned the forward pass of `proxy`, leading to inefficient execution - please use the setting: `PopTorchOptions._popart.setPatterns(dict(AutogradProxyOpPattern=True))`

{'eval_loss': 1.7333984375, 'eval_runtime': 1.1379, 'eval_samples_per_second': 379.656, 'eval_steps_per_second': 23.729, 'epoch': 10.42}


Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
 26%|██▌       | 261/1000 [01:58<10:04,  1.22it/s]  

{'loss': 1.7271, 'learning_rate': 0.016444444444444446, 'epoch': 10.83}


 27%|██▋       | 271/1000 [01:59<02:07,  5.72it/s]

{'loss': 1.6881, 'learning_rate': 0.01622222222222222, 'epoch': 11.25}


 28%|██▊       | 281/1000 [02:01<01:40,  7.14it/s]

{'loss': 1.679, 'learning_rate': 0.016, 'epoch': 11.67}


 29%|██▉       | 291/1000 [02:02<01:45,  6.72it/s]

{'loss': 1.6436, 'learning_rate': 0.015777777777777776, 'epoch': 12.08}


 30%|███       | 301/1000 [02:03<01:46,  6.56it/s]

{'loss': 1.6207, 'learning_rate': 0.015555555555555557, 'epoch': 12.5}


 31%|███       | 311/1000 [02:05<01:39,  6.89it/s]

{'loss': 1.6296, 'learning_rate': 0.015333333333333334, 'epoch': 12.92}


 32%|███▏      | 321/1000 [02:06<01:35,  7.14it/s]

{'loss': 1.6132, 'learning_rate': 0.015111111111111112, 'epoch': 13.33}


 33%|███▎      | 331/1000 [02:08<01:38,  6.77it/s]

{'loss': 1.6141, 'learning_rate': 0.014888888888888889, 'epoch': 13.75}


 34%|███▍      | 341/1000 [02:09<01:32,  7.16it/s]

{'loss': 1.6121, 'learning_rate': 0.014666666666666666, 'epoch': 14.17}


 35%|███▌      | 351/1000 [02:11<01:33,  6.93it/s]

{'loss': 1.572, 'learning_rate': 0.014444444444444444, 'epoch': 14.58}


 36%|███▌      | 361/1000 [02:12<01:39,  6.45it/s]

{'loss': 1.5707, 'learning_rate': 0.014222222222222223, 'epoch': 15.0}


 37%|███▋      | 371/1000 [02:14<01:31,  6.88it/s]

{'loss': 1.5278, 'learning_rate': 0.013999999999999999, 'epoch': 15.42}


 38%|███▊      | 381/1000 [02:15<01:32,  6.68it/s]

{'loss': 1.5702, 'learning_rate': 0.013777777777777778, 'epoch': 15.83}


 39%|███▉      | 391/1000 [02:17<01:31,  6.66it/s]

{'loss': 1.551, 'learning_rate': 0.013555555555555557, 'epoch': 16.25}


 40%|████      | 401/1000 [02:18<01:23,  7.20it/s]

{'loss': 1.5581, 'learning_rate': 0.013333333333333332, 'epoch': 16.67}


 41%|████      | 411/1000 [02:19<01:27,  6.71it/s]

{'loss': 1.5416, 'learning_rate': 0.013111111111111112, 'epoch': 17.08}


 42%|████▏     | 421/1000 [02:21<01:24,  6.86it/s]

{'loss': 1.5026, 'learning_rate': 0.01288888888888889, 'epoch': 17.5}


 43%|████▎     | 431/1000 [02:22<01:20,  7.08it/s]

{'loss': 1.5347, 'learning_rate': 0.012666666666666666, 'epoch': 17.92}


 44%|████▍     | 441/1000 [02:24<01:21,  6.88it/s]

{'loss': 1.5135, 'learning_rate': 0.012444444444444445, 'epoch': 18.33}


 45%|████▌     | 451/1000 [02:25<01:15,  7.24it/s]

{'loss': 1.4982, 'learning_rate': 0.012222222222222223, 'epoch': 18.75}


 46%|████▌     | 461/1000 [02:27<01:22,  6.51it/s]

{'loss': 1.5139, 'learning_rate': 0.012, 'epoch': 19.17}


 47%|████▋     | 471/1000 [02:28<01:19,  6.66it/s]

{'loss': 1.4805, 'learning_rate': 0.011777777777777778, 'epoch': 19.58}


 48%|████▊     | 481/1000 [02:30<01:17,  6.68it/s]

{'loss': 1.5014, 'learning_rate': 0.011555555555555555, 'epoch': 20.0}


 49%|████▉     | 491/1000 [02:31<01:14,  6.86it/s]

{'loss': 1.4911, 'learning_rate': 0.011333333333333332, 'epoch': 20.42}


 50%|█████     | 500/1000 [02:32<01:07,  7.37it/s]

{'loss': 1.4955, 'learning_rate': 0.011111111111111112, 'epoch': 20.83}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 11.303244029171765 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 50%|█████     | 500/1000 [02:48<01:07,  7.37it/s]Saving model checkpoint to out/checkpoint-500
Configuration saved in out/checkpoint-500/ipu_config.json


{'eval_loss': 1.548828125, 'eval_runtime': 0.7062, 'eval_samples_per_second': 611.732, 'eval_steps_per_second': 38.233, 'epoch': 20.83}


Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
 51%|█████     | 511/1000 [03:02<03:06,  2.63it/s]  

{'loss': 1.4493, 'learning_rate': 0.010888888888888889, 'epoch': 21.25}


 52%|█████▏    | 521/1000 [03:04<01:18,  6.13it/s]

{'loss': 1.4591, 'learning_rate': 0.010666666666666666, 'epoch': 21.67}


 53%|█████▎    | 531/1000 [03:05<01:09,  6.73it/s]

{'loss': 1.4536, 'learning_rate': 0.010444444444444445, 'epoch': 22.08}


 54%|█████▍    | 541/1000 [03:06<01:02,  7.36it/s]

{'loss': 1.4563, 'learning_rate': 0.010222222222222221, 'epoch': 22.5}


 55%|█████▌    | 551/1000 [03:08<01:06,  6.78it/s]

{'loss': 1.4385, 'learning_rate': 0.01, 'epoch': 22.92}


 56%|█████▌    | 561/1000 [03:09<01:04,  6.82it/s]

{'loss': 1.4509, 'learning_rate': 0.009777777777777778, 'epoch': 23.33}


 57%|█████▋    | 571/1000 [03:11<01:00,  7.03it/s]

{'loss': 1.4318, 'learning_rate': 0.009555555555555557, 'epoch': 23.75}


 58%|█████▊    | 581/1000 [03:12<00:57,  7.23it/s]

{'loss': 1.4503, 'learning_rate': 0.009333333333333334, 'epoch': 24.17}


 59%|█████▉    | 591/1000 [03:14<01:00,  6.77it/s]

{'loss': 1.396, 'learning_rate': 0.009111111111111111, 'epoch': 24.58}


 60%|██████    | 601/1000 [03:15<00:57,  6.95it/s]

{'loss': 1.4438, 'learning_rate': 0.008888888888888889, 'epoch': 25.0}


 61%|██████    | 611/1000 [03:17<01:01,  6.35it/s]

{'loss': 1.4076, 'learning_rate': 0.008666666666666668, 'epoch': 25.42}


 62%|██████▏   | 621/1000 [03:18<00:50,  7.53it/s]

{'loss': 1.4074, 'learning_rate': 0.008444444444444445, 'epoch': 25.83}


 63%|██████▎   | 631/1000 [03:19<00:50,  7.31it/s]

{'loss': 1.3782, 'learning_rate': 0.008222222222222223, 'epoch': 26.25}


 64%|██████▍   | 641/1000 [03:21<00:47,  7.53it/s]

{'loss': 1.3779, 'learning_rate': 0.008, 'epoch': 26.67}


 65%|██████▌   | 651/1000 [03:22<00:50,  6.88it/s]

{'loss': 1.3929, 'learning_rate': 0.007777777777777778, 'epoch': 27.08}


 66%|██████▌   | 661/1000 [03:23<00:46,  7.22it/s]

{'loss': 1.377, 'learning_rate': 0.007555555555555556, 'epoch': 27.5}


 67%|██████▋   | 671/1000 [03:25<00:48,  6.72it/s]

{'loss': 1.3975, 'learning_rate': 0.007333333333333333, 'epoch': 27.92}


 68%|██████▊   | 681/1000 [03:26<00:43,  7.35it/s]

{'loss': 1.3966, 'learning_rate': 0.0071111111111111115, 'epoch': 28.33}


 69%|██████▉   | 691/1000 [03:28<00:43,  7.10it/s]

{'loss': 1.3677, 'learning_rate': 0.006888888888888889, 'epoch': 28.75}


 70%|███████   | 701/1000 [03:29<00:42,  7.08it/s]

{'loss': 1.3507, 'learning_rate': 0.006666666666666666, 'epoch': 29.17}


 71%|███████   | 711/1000 [03:30<00:42,  6.84it/s]

{'loss': 1.3469, 'learning_rate': 0.006444444444444445, 'epoch': 29.58}


 72%|███████▏  | 721/1000 [03:32<00:37,  7.47it/s]

{'loss': 1.3599, 'learning_rate': 0.006222222222222223, 'epoch': 30.0}


 73%|███████▎  | 731/1000 [03:33<00:37,  7.08it/s]

{'loss': 1.3308, 'learning_rate': 0.006, 'epoch': 30.42}


 74%|███████▍  | 741/1000 [03:35<00:40,  6.38it/s]

{'loss': 1.3429, 'learning_rate': 0.0057777777777777775, 'epoch': 30.83}


 75%|███████▌  | 750/1000 [03:36<00:37,  6.69it/s]

{'loss': 1.3476, 'learning_rate': 0.005555555555555556, 'epoch': 31.25}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 10.792989015579224 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

 75%|███████▌  | 750/1000 [03:51<00:37,  6.69it/s]

{'eval_loss': 1.4873046875, 'eval_runtime': 0.7039, 'eval_samples_per_second': 613.734, 'eval_steps_per_second': 38.358, 'epoch': 31.25}


Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
 76%|███████▌  | 761/1000 [04:06<01:31,  2.62it/s]

{'loss': 1.3406, 'learning_rate': 0.005333333333333333, 'epoch': 31.67}


 77%|███████▋  | 771/1000 [04:07<00:36,  6.19it/s]

{'loss': 1.3567, 'learning_rate': 0.0051111111111111105, 'epoch': 32.08}


 78%|███████▊  | 781/1000 [04:08<00:32,  6.70it/s]

{'loss': 1.315, 'learning_rate': 0.004888888888888889, 'epoch': 32.5}


 79%|███████▉  | 791/1000 [04:10<00:30,  6.97it/s]

{'loss': 1.3331, 'learning_rate': 0.004666666666666667, 'epoch': 32.92}


 80%|████████  | 801/1000 [04:11<00:26,  7.40it/s]

{'loss': 1.3228, 'learning_rate': 0.0044444444444444444, 'epoch': 33.33}


 81%|████████  | 811/1000 [04:13<00:34,  5.42it/s]

{'loss': 1.3225, 'learning_rate': 0.004222222222222223, 'epoch': 33.75}


 82%|████████▏ | 821/1000 [04:14<00:25,  6.98it/s]

{'loss': 1.3359, 'learning_rate': 0.004, 'epoch': 34.17}


 83%|████████▎ | 831/1000 [04:16<00:23,  7.17it/s]

{'loss': 1.291, 'learning_rate': 0.003777777777777778, 'epoch': 34.58}


 84%|████████▍ | 841/1000 [04:17<00:22,  7.17it/s]

{'loss': 1.3017, 'learning_rate': 0.0035555555555555557, 'epoch': 35.0}


 85%|████████▌ | 851/1000 [04:19<00:19,  7.55it/s]

{'loss': 1.294, 'learning_rate': 0.003333333333333333, 'epoch': 35.42}


 86%|████████▌ | 861/1000 [04:20<00:21,  6.53it/s]

{'loss': 1.304, 'learning_rate': 0.0031111111111111114, 'epoch': 35.83}


 87%|████████▋ | 871/1000 [04:21<00:17,  7.51it/s]

{'loss': 1.3004, 'learning_rate': 0.0028888888888888888, 'epoch': 36.25}


 88%|████████▊ | 881/1000 [04:23<00:17,  6.97it/s]

{'loss': 1.2693, 'learning_rate': 0.0026666666666666666, 'epoch': 36.67}


 89%|████████▉ | 891/1000 [04:24<00:14,  7.37it/s]

{'loss': 1.3039, 'learning_rate': 0.0024444444444444444, 'epoch': 37.08}


 90%|█████████ | 901/1000 [04:26<00:13,  7.15it/s]

{'loss': 1.2831, 'learning_rate': 0.0022222222222222222, 'epoch': 37.5}


 91%|█████████ | 911/1000 [04:27<00:11,  7.54it/s]

{'loss': 1.2877, 'learning_rate': 0.002, 'epoch': 37.92}


 92%|█████████▏| 921/1000 [04:28<00:11,  7.08it/s]

{'loss': 1.2669, 'learning_rate': 0.0017777777777777779, 'epoch': 38.33}


 93%|█████████▎| 931/1000 [04:30<00:10,  6.53it/s]

{'loss': 1.2662, 'learning_rate': 0.0015555555555555557, 'epoch': 38.75}


 94%|█████████▍| 941/1000 [04:31<00:08,  6.99it/s]

{'loss': 1.281, 'learning_rate': 0.0013333333333333333, 'epoch': 39.17}


 95%|█████████▌| 951/1000 [04:33<00:07,  6.76it/s]

{'loss': 1.2652, 'learning_rate': 0.0011111111111111111, 'epoch': 39.58}


 96%|█████████▌| 961/1000 [04:34<00:06,  6.31it/s]

{'loss': 1.2788, 'learning_rate': 0.0008888888888888889, 'epoch': 40.0}


 97%|█████████▋| 971/1000 [04:36<00:04,  6.80it/s]

{'loss': 1.2463, 'learning_rate': 0.0006666666666666666, 'epoch': 40.42}


 98%|█████████▊| 981/1000 [04:37<00:02,  7.34it/s]

{'loss': 1.256, 'learning_rate': 0.00044444444444444447, 'epoch': 40.83}


 99%|█████████▉| 991/1000 [04:39<00:01,  6.65it/s]

{'loss': 1.2739, 'learning_rate': 0.00022222222222222223, 'epoch': 41.25}


100%|██████████| 1000/1000 [04:40<00:00,  7.35it/s]

{'loss': 1.2407, 'learning_rate': 0.0, 'epoch': 41.67}


Compiling Model...
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Compiled/Loaded model in 11.236214805394411 secs
***** Running Evaluation *****
  Num examples = 432
  Batch size = 16

100%|██████████| 1000/1000 [04:54<00:00,  7.35it/s]
Saving model checkpoint to out/checkpoint-1000
Configuration saved in out/checkpoint-1000/ipu_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 1000/1000 [04:55<00:00,  3.39it/s]
Saving model checkpoint to trained_model/
Configuration saved in trained_model/ipu_config.json


{'eval_loss': 1.46875, 'eval_runtime': 0.6953, 'eval_samples_per_second': 621.352, 'eval_steps_per_second': 38.834, 'epoch': 41.67}
{'train_runtime': 295.0799, 'train_samples_per_second': 1084.452, 'train_steps_per_second': 3.389, 'train_loss': 1.634490234375, 'epoch': 41.67}


The final evaluation loss (about 1.47) is similar to that of the regular model: a success!

Let's celebrate with some unit-scaled Shakespeare...

In [None]:
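# Stub out the pipeline's model-type check (presumably because our custom
# model class is not among the types the pipeline expects)
pipelines.check_model_type = lambda self, supported_models: ...

# Reload the unit-scaled model from the directory it was saved to after training
final_model = UnitScaledNanoGPTModel.from_pretrained("trained_model/")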


TEST_INPUT = """PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits"""
pipe = pipeline(
    "text-generation",
    ipu_config=ipu_config.to_dict(),  # TODO: feature request -> no to_dict()
    model=final_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
)

# Sample two completions of the prompt; a temperature below 1 makes sampling more conservative
outputs = pipe(TEST_INPUT, num_return_sequences=2, temperature=0.4)
for i, output in enumerate(outputs):
    print(f"\n===== Completion [{i+1}] =====\n")
    print(output["generated_text"])

IPUConfig {
  "auto_loss_scaling": false,
  "device_iterations": 1,
  "embedding_serialization_factor": 1,
  "enable_half_partials": true,
  "executable_cache_dir": "./exe_cache",
  "execute_encoder_on_cpu_for_generation": false,
  "gradient_accumulation_steps": 20,
  "inference_device_iterations": 1,
  "inference_replication_factor": 1,
  "ipus_per_replica": 1,
  "layers_per_ipu": [
    6
  ],
  "matmul_proportion": 0.6,
  "optimizer_state_offchip": true,
  "optimum_version": "1.6.1",
  "output_mode": "final",
  "recompute_checkpoint_every_layer": false,
  "replicated_tensor_sharding": false,
  "replication_factor": 1,
  "seed": null,
  "transformers_version": "4.25.1"
}

Graph compilation:  18%|█▊        | 18/100 [00:14<00:31]
2023-03-12T23:45:16.933967Z popart:popart 1153668.1153668 W: `autograd_proxy(fwd, proxy)` has not pruned the forward pass of `proxy`, leading to inefficient execution - please use the setting: `PopTorchOptions._popart.setPatterns(dict(AutogradProxyOpPattern=True


===== Completion [1] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits of the people,
Which then I speak with me to the streets,
And then I am so fair soul of a man
That I was slew them such sounded to the body.

DUKE VINCENTIO:
The king is the duke's death.

CAMILLO:
My lord,
Stand so I should be so.

AUTOLYCUS:
I have pass'd him to our son a prince:
The gentleman should be have the duke of his honour,
And presently and he private still stands
Which he hath pride him of your traitors to 

===== Completion [2] =====

PROSPERO:
Our revels now are ended. These our actors,
As I foretold you, were all spirits for this:
The manners of your honour of my country's head,
That I do not so long have the world is off,
And sent it of the house of York and more
Than the earth of his death.

KING RICHARD II:
Why, he was a traitor of all the common soul than to me?

KING RICHARD III:
Ay, thou wilt do thy soul thy death.

Second Keeper:
The county of thi
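
The completions above were sampled with `do_sample=True` and `temperature=0.4`. As a rough illustration of what the temperature setting does, here is a standard-library sketch of temperature-scaled sampling. It is not the pipeline's actual implementation; the function names and the example logits are hypothetical and chosen only for illustration.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Convert raw logits into probabilities after dividing by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature, rng=random):
    """Draw one token index from the temperature-adjusted distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Hypothetical logits over a four-token vocabulary (illustrative values only).
example_logits = [2.0, 1.0, 0.5, -1.0]
for t in (1.0, 0.4):
    probs = softmax_with_temperature(example_logits, t)
    print(f"temperature={t}:", [round(p, 3) for p in probs])
```

Dividing the logits by a temperature below 1 widens the gaps between them, so the most likely token takes a larger share of the probability mass and the samples stay closer to the model's top choices.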

Some neat verse, from a neatly scaled model. From which we can only conclude:

*Enter: unit scaling*

*Exeunt: loss scaling, automatic loss scaling*

FIN

---

We hope that practitioners will consider unit scaling for future projects, particularly if they are having difficulties with loss scaling or automatic mixed precision. With FP8 on the horizon, these issues are likely to become more prevalent, and we think unit scaling can help.

If you're interested in using unit scaling yourself, or have questions, please do reach out 🙏☎️. We're keen to hear from anyone who has a problem that unit scaling might help solve.

The definitions provided here are the closest we have to an "official" PyTorch implementation, but if there's demand for a library, tell us and we'll make one!