# Transformers Implementation

Here I'm going to implement the transformer model from [Attention is All You Need](https://arxiv.org/abs/1706.03762).

In [None]:
from typing import Union, Tuple, Optional

from math import sqrt

import xaby as xb
import xaby.nn as xn
import xaby.random as xr

from xaby import jnp
import numpy as np
import jax
from jax.ops import index, index_add, index_update

First up, we need an embedding function, a really common component of language models.

In [None]:
# Scratch check of the padding trick used by `embedding` below: prepend a column filled with
# `padding_val` to a (16, 20) weight matrix, then display column 0 (the pad column).
padding_val = 0
w = xb.random.normal((16, 20))
jnp.concatenate([jnp.zeros((16, 1), dtype=jnp.float32) + padding_val, w], axis=1)[:, 0]

In [None]:
class embedding(xb.Fn):
    """ Embedding lookup: maps integer token indices to vectors.

        The learned table holds `num_embeddings` columns. A constant padding column (filled with
        `padding_val`) is inserted at position `padding_idx`, so valid token indices run from 0 to
        `num_embeddings` inclusive:
          - token == padding_idx -> the padding vector
          - token <  padding_idx -> learned column `token`
          - token >  padding_idx -> learned column `token - 1`
        With the default padding_idx=0 this is "prepend a pad column": token i > 0 uses learned
        column i - 1.

        Expects one integer array of indices and returns one array of shape
        `indices.shape + (embedding_dim,)`. Out-of-range indices are not checked here (JAX
        indexing clamps them rather than raising).

        Raises ValueError if padding_idx is outside [0, num_embeddings], or (at trace time) if the
        input does not have an integer dtype.
    """
    def __init__(self, num_embeddings, embedding_dim, padding_idx: int=0, padding_val: float=0.):
        if not 0 <= padding_idx <= num_embeddings:
            raise ValueError(f"padding_idx must be in [0, {num_embeddings}], got {padding_idx}")

        def embedding(x: xb.ArrayList, params: dict) -> xb.ArrayList:
            indices = x[0]
            # dtypes are static under jit, so this runs once at trace time. issubdtype accepts
            # signed *and* unsigned integer dtypes.
            if not jnp.issubdtype(indices.dtype, jnp.integer):
                raise ValueError("Input array must have an integer dtype")

            weights = params["weights"]
            pad = jnp.full((embedding_dim, 1), padding_val, dtype=weights.dtype)
            # Insert the constant pad column at padding_idx (for padding_idx=0: prepend it)
            padded_weights = jnp.concatenate(
                [weights[:, :padding_idx], pad, weights[:, padding_idx:]], axis=1
            )
            # Gather rows of the transposed table -> indices.shape + (embedding_dim,)
            return xb.pack(padded_weights.T[indices])

        super().__init__(jax.jit(embedding), 1, 1, "embedding")

        # Stored as (embedding_dim, num_embeddings): one column per learned token
        self.params["weights"] = xr.normal((embedding_dim, num_embeddings))

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.padding_val = padding_val

    def __repr__(self):
        return f"embedding({self.num_embeddings}, {self.embedding_dim})"

In [None]:
# Embed a token sequence whose trailing 0s are padding (index 0) into 5-dimensional vectors.
a = xb.array([2,4,9,1,3,0,0,0])
xb.pack(a) >> embedding(20, 5)

Next up is to implement the scaled dot-product attention. It's defined in the paper as

$$ 
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{QK^\top}{\sqrt{d_k}}\right) V
$$

There's also an optional mask before the softmax. Since this doesn't require any parameters, I can implement it completely with (JAX) numpy.

In [None]:
@xb.fn
def attention(queries: jnp.DeviceArray, keys: jnp.DeviceArray, values: jnp.DeviceArray, mask: jnp.DeviceArray):
    """ Computes the masked scaled dot-product attention

        softmax(Q K^T / sqrt(d_k)) V, where query/key pairs with mask == 0 are excluded from the
        softmax.

        Arguments
        ---------
        queries: [..., s_q, d_k]
        keys:    [..., s_k, d_k]
        values:  [..., s_k, d_v]
        mask:    broadcastable to [..., s_q, s_k]; nonzero entries mark allowed query/key pairs

        Returns an array of shape [..., s_q, d_v]. A query row whose mask row is entirely zero
        produces an all-zero output row.
    """
    d_k = keys.shape[-1]

    # alignment scores, [..., s_q, s_k]
    scores = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_k)

    # Decide masking from the mask alone; testing `mask * scores != 0` would also throw away
    # scores that are legitimately exactly zero.
    allowed = mask != 0

    # Fill disallowed positions with a large *finite* negative instead of -inf. A row that is
    # entirely -inf makes softmax return NaN, and those NaNs propagate into the gradients even
    # if the forward output is cleaned up afterwards.
    scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)

    # A fully-masked row comes out of the softmax as uniform weights; zero all disallowed weights
    # so such rows contribute nothing (partially-masked rows are unaffected, their masked weights
    # are already 0).
    weights = jnp.where(allowed, weights, 0.)

    return jnp.matmul(weights, values)

In [None]:
# Benchmark unbatched attention: 9 query steps attending over 10 key/value steps with d_k = 512
# and an all-ones mask (nothing masked out).
enc_steps = 10
dec_steps = 9
embedding_dim = 512

q, k = xb.random.normal((dec_steps, embedding_dim)), xb.random.normal((enc_steps, embedding_dim))
v, mask = k, jnp.ones((dec_steps, enc_steps))

%timeit xb.pack(q, k, v, mask) >> attention

In [None]:
def attention_head(embedding_dim: int, model_dim: int):
    """ Builds one attention head.

        Expects four arrays, [Queries, Keys, Values, Mask]. Queries, keys and values each go
        through their own bias-free linear projection (embedding_dim -> model_dim); the mask is
        passed through unchanged. The four results are fed to `attention`.
    """
    def projection():
        return xn.linear(embedding_dim, model_dim, bias=False)

    # Projections are created in Q, K, V order (argument evaluation is left to right)
    project_inputs = xb.split(projection(), projection(), projection(), xb.skip)

    head = project_inputs >> attention
    head.name = "attention_head"
    return head

In [None]:
def multihead_attention(embedding_dim: int, model_dim: int, n_heads: int, masked=False):
    """ Builds multi-head attention.

        Expects four arrays: queries, keys, values, mask. Every head receives all four inputs and
        projects to model_dim // n_heads dimensions. The head outputs are concatenated along
        axis 1 (the feature axis for 2-D, per-example inputs) and passed through a final
        bias-free model_dim -> model_dim linear layer.

        Arguments
        ---------
        embedding_dim: size of the incoming query/key/value vectors
        model_dim: total output size; must be divisible by n_heads
        n_heads: number of attention heads
        masked: currently unused. Masking is controlled entirely by the mask array passed as the
                fourth input. Kept for backward compatibility.

        Raises ValueError if model_dim is not divisible by n_heads.
    """
    if model_dim % n_heads != 0:
        raise ValueError("Model dimension must be evenly divisible by the number of heads")

    head_dim = model_dim // n_heads
    heads = [attention_head(embedding_dim, head_dim) for _ in range(n_heads)]

    # Named `mha` so the module-level `attention` function is not shadowed here.
    # bias=False matches the other bias-free projections in this file.
    mha = xb.parallel(*heads) >> xb.concatenate(axis=1) >> xn.linear(model_dim, model_dim, bias=False)
    mha.name = "multihead_attention"

    return mha

In [None]:
# Smoke test: multi-head attention on one unbatched example (queries and keys both have
# enc_steps rows) with an all-ones mask.
embedding_dim = model_dim = 16
n_heads = 4

q, k = xb.random.normal((enc_steps, embedding_dim)), xb.random.normal((enc_steps, embedding_dim))
v, mask = k, jnp.ones((enc_steps, enc_steps))

xb.pack(q, k, v, mask) >> multihead_attention(embedding_dim, model_dim, n_heads)

So far I've only been developing this to work on one training example at a time, but typically you'd want to train on a batch of examples. Luckily, JAX makes it pretty easy to map a function over a batch dimension with `vmap`. Let's see if I can get it to work.

In [None]:
# Testing out using JAX to map the attention model over batches
atten = multihead_attention(embedding_dim, 16, 4)

# in_axes must mirror the structure of forward's arguments. The first argument is an ArrayList,
# so build a matching ArrayList with xb.pack: one entry per input array (batch axis 0 for each
# of Q, K, V, mask). None for the second argument (presumably the params dict) leaves it unmapped.
atten.forward = jax.vmap(atten.forward, in_axes=(xb.pack(0,0,0,0), None))

# Add batch dimension, batch size = 10
q, k = xb.random.normal((10, dec_steps, embedding_dim)), xb.random.normal((10, enc_steps, embedding_dim))
v, mask = k, jnp.ones((10, dec_steps, enc_steps))

xb.pack(q, k, v, mask) >> atten >> xb.shapes

It works 😆. You can use `jax.pmap` instead of `jax.vmap` to run the batches in parallel over multiple GPUs. At some point, I'll build this into XABY to make it easy.

Okay, multi-headed attention is done. Each sub-module applies [layer normalization](https://arxiv.org/abs/1607.06450) to its output. I haven't implemented that in XABY yet, so I'll add it here.

In [None]:
class layernorm(xb.Fn):
    """ Layer normalization (https://arxiv.org/abs/1607.06450).

        Normalizes the input over its trailing len(normalized_shape) axes to zero mean and unit
        variance. If elementwise_affine is set, it then applies a learned element-wise scale and
        bias. Expects one input array and returns one array of the same shape.

        Arguments
        ---------
        normalized_shape: int or tuple of ints, shape of the trailing axes to normalize over
        epsilon: added to the variance for numerical stability
        elementwise_affine: if True, learn `scale` (init ones) and `bias` (init zeros) params
                            of shape normalized_shape
    """
    def __init__(self, normalized_shape: Union[Tuple[int], int], epsilon=1e-5, elementwise_affine=True):

        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        normalized_shape = tuple(normalized_shape)
        n_norm_axes = len(normalized_shape)

        def layernorm_no_affine(x: xb.ArrayList, params: dict) -> xb.ArrayList:
            inputs, = x
            n_dims = len(inputs.shape)
            norm_axes = tuple(range(n_dims - n_norm_axes, n_dims))

            mean = inputs.mean(axis=norm_axes, keepdims=True)
            var = jnp.mean((inputs - mean)**2, axis=norm_axes, keepdims=True)
            normed = (inputs - mean) / jnp.sqrt(var + epsilon)

            return xb.pack(normed)

        def layernorm_affine(x: xb.ArrayList, params: dict) -> xb.ArrayList:
            # Read scale/bias from the `params` argument, not from `self.params`. Under jax.jit a
            # closed-over `self.params` is baked into the trace as a constant, so parameter
            # updates would be ignored and gradients w.r.t. scale/bias would be zero.
            scale, bias = params["scale"], params["bias"]
            normed, = layernorm_no_affine(x, params)
            return xb.pack(normed * scale + bias)

        if elementwise_affine:
            super().__init__(jax.jit(layernorm_affine), 1, 1, "layernorm")
            self.params["scale"] = jnp.ones(normalized_shape)
            self.params["bias"] = jnp.zeros(normalized_shape)
        else:
            super().__init__(jax.jit(layernorm_no_affine), 1, 1, "layernorm")

        self.normalized_shape = normalized_shape
        self.elementwise_affine = elementwise_affine

    def __repr__(self):
        return f"layernorm({self.normalized_shape}, elementwise_affine={self.elementwise_affine})"

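Finally, we need to put together a mask for the attention operations. The mask does two things. First, it tells the attention modules to ignore padding tokens. We can have padding in both the encoder and decoder sequences, so the mask has to block a query–key pair whenever either the query step or the key step is padding. Second, in the decoder's self-attention, it blocks each step from attending to later steps ("leftward" information flow), so the decoder can't peek at tokens it hasn't produced yet.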

In [None]:
def append_mask(decoder: bool = False):
    """ Returns a function that appends a mask for attention to an input ArrayList.
        This function expects either one or two arrays: [queries] or [queries, keys].
        With a single array, it is used as both the queries and the keys (self-attention).

        The default mask blocks the attention modules from using padding steps. It assumes the
        embedding vector for a padding input is all zeros: a step counts as padding only when
        every element of its row is exactly zero.

        The appended mask has shape [s_q, s_k], the same layout as the queries @ keys^T score
        matrix, where s_q and s_k are the number of steps in the queries and keys.

        NOTE(review): once positional encodings are added to the embeddings (as `transformer`
        does), padding rows are no longer all zeros and will not be masked. Confirm whether a
        token-based mask is needed there.

        Arguments
        ---------
        decoder: optional, bool, returns a mask that blocks pad embeddings and leftward flowing information

    """
    def append_mask(x: xb.ArrayList, p: dict) -> xb.ArrayList:
        """ Builds the [s_q, s_k] mask for the input arrays and appends it to the ArrayList. """
        if len(x) == 1:
            q, k = x[0], x[0]
        else:
            q, k = x

        # A step is real unless its embedding row is entirely zero. (Testing the row *sum* would
        # also flag real rows whose entries happen to cancel out.)
        q_valid = jnp.any(q != 0, axis=1).astype(jnp.int8)
        k_valid = jnp.any(k != 0, axis=1).astype(jnp.int8)

        # Outer product laid out as [s_q, s_k]: entry (i, j) is 1 only if query step i and key
        # step j are both real. The transposed [s_k, s_q] layout only works when q is k.
        padding_mask = q_valid.reshape(-1, 1) * k_valid

        if decoder:
            # The decoder needs to mask out illegal connections to prevent leftward flowing
            # information: jnp.tri keeps (i, j) only for j <= i, so query step i sees keys up to i.
            mask = jnp.tri(q.shape[0], k.shape[0]) * padding_mask
        else:
            mask = padding_mask
        return xb.pack(*x, mask)

    # NOTE(review): n_inputs=3 / n_outputs=4 do not match what append_mask actually takes
    # (1 or 2 arrays) and returns (2 or 3 arrays). Left unchanged; confirm how xaby uses them.
    return xb.Fn(jax.jit(append_mask), 3, 4, name="append_mask")

Now to finish up the encoder.

In [None]:
def residual(func):
    """ Wraps `func` in a residual connection.

        `func` and xb.skip run side by side on the same inputs, and their outputs are combined
        with xb.add, i.e. func(x) + x.
    """
    wrapped = xb.parallel(func, xb.skip)
    wrapped = wrapped >> xb.add
    wrapped.name = "residual"
    return wrapped

def feedforward(model_dim, internal_dim):
    """ Position-wise feed-forward sub-layer: linear -> ReLU -> linear.

        Expands model_dim -> internal_dim, applies xn.relu, then projects back to model_dim.
        Both projections are bias-free.
    """
    expand = xn.linear(model_dim, internal_dim, bias=False)
    contract = xn.linear(internal_dim, model_dim, bias=False)

    block = expand >> xn.relu >> contract
    block.name = "feedforward"
    return block

def encoder_layer(embedding_dim: int, model_dim: int, n_heads: int, ff_dim: int, dropout: float=0.1) -> xb.Fn:
    """ Encoder layer for a transformer model. Expects one input array with shape [n_steps, embedding_dim]

        append_mask -> (self-attention + dropout) + input -> layernorm
                    -> residual(feedforward + dropout) -> layernorm
    """
    self_attention = multihead_attention(embedding_dim, model_dim, n_heads)

    # After append_mask the arrays are [x, mask]. select(0,0,0,1) feeds x as queries, keys and
    # values plus the mask; the select(0) branch carries x alone for the residual add.
    attn_path = xb.select(0,0,0,1) >> self_attention >> xn.dropout(dropout)
    attn_block = xb.parallel(attn_path, xb.select(0)) >> xb.add >> layernorm(model_dim)

    ff_path = feedforward(model_dim, ff_dim) >> xn.dropout(dropout)

    layer = append_mask() >> attn_block >> residual(ff_path) >> layernorm(model_dim)
    layer.name = "encoder_layer"
    return layer

def encoder(embedding_dim: int, model_dim: int, n_heads: int, ff_dim:int, n_layers: int, dropout: float=0.1) -> xb.Fn:
    """ Encoder for a transformer model. Expects one input array with shape [n_steps, embedding_dim]

        Applies `n_layers` separately-constructed encoder layers one after another.
    """
    stack = [encoder_layer(embedding_dim, model_dim, n_heads, ff_dim, dropout) for _ in range(n_layers)]

    enc = xb.sequential(*stack)
    enc.name = "encoder"
    return enc

In [None]:
# Smoke test: run the encoder on one unbatched, padded token sequence.
enc_steps = 10
embedding_dim = model_dim = 16
ff_dim = 32
n_heads = 4
n_layers = 2

# Testing with padding tokens (0 by default)
tokens = xb.array([1,2,3,4,5,0,0,0])
xb.pack(tokens) >> embedding(32, embedding_dim) >> encoder(embedding_dim, model_dim, n_heads, ff_dim, n_layers)

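There we go, I have output from the encoder! Time for the decoder. A decoder layer is more complicated because there are two attention sub-layers, and each needs its own mask: the self-attention mask blocks padding and future steps, while the encoder-attention mask blocks padding in both the decoder and encoder sequences.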

In [None]:
def decoder_layer(embedding_dim: int, model_dim: int, n_heads: int, ff_dim: int, dropout=0.1) -> xb.Fn:
    """ Expects two input arrays: target embeddings and the encoder output.

        Three sub-layers, each followed by layernorm:
          1. masked self-attention over the targets (causal + padding mask), with residual
          2. attention with the self-attention output as queries and the encoder output as
             keys/values (padding mask), with residual
          3. feedforward, with residual
        Returns a single array.
    """
    self_mha = multihead_attention(embedding_dim, model_dim, n_heads)
    cross_mha = multihead_attention(embedding_dim, model_dim, n_heads)

    # 1. Self-attention on the targets alone: append_mask(True) yields [targets, mask] and
    #    select(0,0,0,1) turns that into [Q, K, V, mask].
    causal_attn = append_mask(True) >> xb.select(0,0,0,1) >> self_mha >> xn.dropout(dropout)
    self_block = residual(causal_attn) >> layernorm(model_dim)

    # 2. Input is [self-attention output, encoder output]. append_mask() adds the padding mask ->
    #    [sa, enc, mask]; select(0,1,1,2) gives [Q=sa, K=enc, V=enc, mask]. The select(0) branch
    #    carries the self-attention output for the residual add.
    cross_attn = append_mask() >> xb.select(0,1,1,2) >> cross_mha >> xn.dropout(dropout)
    cross_block = xb.parallel(cross_attn, xb.select(0)) >> xb.add >> layernorm(model_dim)

    # 3. Feedforward on the single array produced by the cross-attention block
    ff_block = residual(feedforward(model_dim, ff_dim) >> xn.dropout(dropout)) >> layernorm(model_dim)

    # Only the first input (targets) goes through self-attention; the encoder output is passed
    # through untouched so the cross-attention block receives both arrays.
    layer = xb.split(self_block, xb.skip) >> cross_block >> ff_block
    layer.name = "decoder_layer"
    return layer

In [None]:
def decoder(embedding_dim: int, model_dim: int, n_heads: int, ff_dim:int, n_layers: int, dropout = 0.1) -> xb.Fn:
    """ Decoder for a transformer model. Expects two input arrays: target embeddings and the encoder output.

        Stacks `n_layers` decoder layers and returns one array: the output of the last layer.
    """
    def stage():
        # Each layer needs the encoder output, so run the decoder layer on
        # [targets, encoder output] next to select(1), which re-emits the encoder output.
        # The following stage then receives [layer output, encoder output] again.
        return xb.parallel(decoder_layer(embedding_dim, model_dim, n_heads, ff_dim, dropout), xb.select(1))

    stages = [stage() for _ in range(n_layers)]

    # Drop the carried-along encoder output at the end
    dec = xb.sequential(*stages) >> xb.select(0)
    dec.name = "decoder"
    return dec

Okay, decoder is done. Now for the positional encodings...

In [None]:
class positional_encoding(xb.Fn):
    """ Adds the fixed sinusoidal positional encodings from "Attention is All You Need".

        PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim))

        Expects one input array with shape [n_steps, embedding_dim] and returns it with the first
        n_steps rows of the encoding table added. Has no learned parameters.

        Arguments
        ---------
        embedding_dim: size of the embedding vectors (odd sizes are supported)
        max_len: number of positions in the precomputed table; longer inputs raise ValueError
    """
    def __init__(self, embedding_dim: int, max_len=5000):

        position = jnp.expand_dims(np.arange(0, max_len), 1)
        div_term = jnp.exp(jnp.arange(0, embedding_dim, 2) * (-jnp.log(10000) / embedding_dim))
        angles = position * div_term  # [max_len, ceil(embedding_dim / 2)]

        # Interleave sin (even columns) and cos (odd columns): stack them on a new trailing axis,
        # flatten it, then trim the extra cos column that appears when embedding_dim is odd.
        # This avoids jax.ops.index_update (removed in newer JAX releases). It also avoids the
        # shape mismatch the cos column assignment hit for odd embedding_dim.
        pe = jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1)
        pe = pe.reshape(max_len, -1)[:, :embedding_dim]

        @jax.jit
        def pos_encoding(x: xb.ArrayList, params: dict) -> xb.ArrayList:
            embeddings, = x
            n_steps = embeddings.shape[0]
            # Shapes are static under jit, so this check runs at trace time
            if n_steps > max_len:
                raise ValueError(
                    f"Input has {n_steps} steps but positional_encoding max_len is {max_len}"
                )
            encoded = embeddings + pe[:n_steps, :]
            return xb.pack(encoded)

        super().__init__(pos_encoding, 1, 1, "positional_encoding")

Cool cool, now we have the parts ready to put the whole transformer model together.

In [None]:
def transformer(num_embeddings: int,
                embedding_dim: int,
                n_heads: int,
                ff_dim: int,
                encoder_layers: int,
                decoder_layers: int,
                dropout: float=0.1) -> xb.Fn:
    """ Returns the transformer model which takes two arrays as input: [encoder inputs, decoder inputs].
        These inputs should be integer tokens.

        Both token arrays go through one shared embedding and positional encoding, each followed by
        its own dropout. The encoder runs on the source tokens. The decoder runs on
        [target embeddings, encoder output]. A linear layer plus softmax over axis 1 then gives,
        for a single unbatched example, per-step token probabilities of shape
        [n_target_steps, num_embeddings].
    """

    # Use one embedding for both inputs
    embed = embedding(num_embeddings, embedding_dim)
    pos_encode = positional_encoding(embedding_dim)

    encode = encoder(embedding_dim, embedding_dim, n_heads, ff_dim, encoder_layers, dropout=dropout)
    decode = decoder(embedding_dim, embedding_dim, n_heads, ff_dim, decoder_layers, dropout=dropout)

    enc_input = embed >> pos_encode >> xn.dropout(dropout)
    dec_input = embed >> pos_encode >> xn.dropout(dropout)

    probabilities = xn.linear(embedding_dim, num_embeddings) >> xn.softmax(axis=1)

    # Weight tying: embedding weights = output projection weights scaled by sqrt(d_model), as in
    # the paper.
    # NOTE(review): this assigns a scaled *copy* at construction time. Unless xaby shares
    # parameters between modules, the two matrices are trained independently afterwards; verify.
    embed.params["weights"] = probabilities.linear.params["weights"] * sqrt(embedding_dim)

    # Input is [encoder inputs, decoder inputs], so split yields [encoder output, target
    # embeddings]. The decoder expects [target embeddings, encoder output]: its layers run
    # self-attention on the first array and pass the second one along. Swap the pair first.
    # Without the swap, the decoder attends over the source and emits one row per source step.
    model = xb.split(enc_input >> encode, dec_input) >> xb.select(1, 0) >> decode >> probabilities
    model.name = "transformer"
    return model

In [None]:
# Setting this to True will compile all the function control flow (sequential, parallel, etc)
# in the model. Compilation takes much longer but the model runs 3x faster.
xb.jit_combinators(False)

# 100-token vocabulary, embedding size 512, 8 heads, feed-forward size 1024,
# 6 encoder layers and 6 decoder layers
model = transformer(100, 512, 8, 1024, 6, 6)

Now we can test it out by passing in some test data. The function compiles the first time it's run, it can take a bit. But once that's done it runs fast. However, if the input shapes change, it needs to re-compile. This means in practice we would need to set a static size on the inputs and pad sequences that are shorter than this size. Since we're padding, we need to mask out padded entries.

In [None]:
# One unbatched example; the trailing 0s are padding tokens
source = xb.array([1, 3, 3, 7, 5, 7, 0, 0, 0])
targets = xb.array([12, 8, 10, 12, 11, 0, 0, 0, 0])

# Model compiles on the first run, can take a bit
xb.pack(source, targets) >> model

In [None]:
%%timeit
# Times one forward pass on random length-9 token sequences (generating the tokens is timed too).
# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the timing covers the
# actual computation.
source = xb.random.randint((9,), minval=0, maxval=20)
targets = xb.random.randint((9,), minval=0, maxval=20)

(xb.pack(source, targets) >> model)[0].block_until_ready()

Okay, again, let's see if we can get this operating over a batch of input data...

In [None]:
# Map the model over a leading batch axis of both token arrays. None leaves the second argument
# (presumably the params dict) unmapped, so every example uses the same weights.
model.forward = jax.vmap(model.forward, in_axes=(xb.pack(0, 0), None))

In [None]:
# A batch of 4 source sequences padded with 0 to length 9; the targets are dummy data (the
# source doubled).
source = xb.array([[1, 3, 3, 7, 5, 7, 0, 0, 0],
                   [2, 3, 3, 4, 5, 7, 2, 4, 0],
                   [1, 3, 3, 7, 5, 7, 1, 0, 0],
                   [1, 3, 3, 7, 5, 0, 0, 0, 0]])
target = source * 2

# This takes a bit to compile. Uncomment to run it
# xb.pack(source, target) >> model >> xb.shapes

Alright, next up is putting together the loss function and writing the training loop. The loss is a pretty straightforward cross-entropy loss, but there are a few things I need to take care of here. The model needs the target tokens shifted over by one place, with a "start-of-sentence" token at the beginning, while the loss needs the original, unshifted tokens.

In [None]:
class shift_targets(xb.Fn):
    """ Shifts the target tokens one step to the right.

        Expects two arrays, [source, targets], and returns [source, shifted], where
        shifted = [start_token, targets[0], ..., targets[-2]]. The result has the same length as
        targets: the last target token is dropped. The source array passes through unchanged.
    """
    def __init__(self, start_token=1):

        def shift(x: xb.ArrayList, params: dict) -> xb.ArrayList:
            src, tgt = x

            # Built by concatenation rather than an in-place (index_update) write; the in-place
            # version caused an error when the top-level function was vmapped.
            prefix = jnp.array([start_token])
            return xb.pack(src, jnp.concatenate([prefix, tgt[:-1]]))

        super().__init__(jax.jit(shift), 2, 2, "shift_targets")

def model_loss(model, smoothing=0.1):
    """ Wraps `model` into a loss function.

        The returned function expects two input arrays: [source tokens, target tokens]. One branch
        shifts the targets and runs the model. The other branch passes the original target tokens
        through (select(1)). Both go to xn.cross_entropy_loss, with `smoothing` forwarded to it
        (presumably label smoothing).
    """
    predict = shift_targets() >> model
    originals = xb.select(1)
    objective = xb.parallel(predict, originals) >> xn.cross_entropy_loss(smoothing=smoothing)
    return xb.set_meta(objective, name="model loss", n_inputs=2)

In [None]:
source = xb.array([1, 3, 3, 7, 5, 7, 0, 0, 0])
targets = xb.array([12, 8, 10, 12, 11, 0, 0, 0, 0])

# Build a fresh model: the previous `model` had its forward replaced by a vmapped version
model = transformer(100, 512, 8, 1024, 6, 6)
loss = model_loss(model)

# This takes a bit to compile (the line below runs it)
xb.pack(source, targets) >> loss

In [None]:
# Batch of 4 padded sequences; the targets are dummy data (the source doubled)
source = xb.array([[1, 3, 3, 7, 5, 7, 0, 0, 0],
                   [2, 3, 3, 4, 5, 7, 2, 4, 0],
                   [1, 3, 3, 7, 5, 7, 1, 0, 0],
                   [1, 3, 3, 7, 5, 0, 0, 0, 0]])
targets = source * 2

model = transformer(100, 512, 8, 1024, 6, 6)
loss = model_loss(model)
# xb.batchify presumably maps the loss over the leading batch axis (like the manual jax.vmap
# above); xb.mean(axis=0) then averages the per-example losses.
loss = xb.batchify(loss) >> xb.mean(axis=0)

xb.pack(source, targets) >> loss