# Micro-Benchmarking for Transformers on Apple Silicon

This notebook benchmarks the key components of several transformer models (BERT, GPT-2, T5) on Apple Silicon. Please install PyTorch >= 1.13, which supports GPU acceleration on Apple Silicon (M1/M2/M3) through the MPS backend.

In [1]:
import torch

print('Pytorch version\t:', torch.__version__)
print('Apple Silicon\t:',torch.backends.mps.is_available())

Pytorch version	: 2.2.1
Apple Silicon	: True


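We first define a `walltime` function that benchmarks a PyTorch statement by running it for at least 5 seconds and returning the median run time, plus a `var_dict` helper that passes the statement's variables to the timer by name.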

In [2]:
import inspect
from collections import defaultdict
import pandas as pd
from torch.utils import benchmark 

pd.options.display.precision = 3

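def var_dict(*args):
    """Map each argument back to its variable name in the caller's scope,
    e.g. var_dict(a, b) -> {'a': a, 'b': b}, for use as Timer globals."""
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    return dict([(name, val) for name, val in callers_local_vars if val is arg][0]
                for arg in args)

def walltime(stmt, arg_dict, duration=5):
    """Median wall time (seconds) of `stmt`, measured for at least `duration` seconds."""
    return benchmark.Timer(stmt=stmt, globals=arg_dict).blocked_autorange(
        min_run_time=duration).median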
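The idea behind `walltime` can be sketched with the standard library alone. This is a simplified analogue, not what `torch.utils.benchmark` does internally: the hypothetical `simple_walltime` helper below times a callable in fixed-size blocks until at least `duration` seconds have elapsed, then returns the median per-call time. It performs no device synchronization.

```python
import statistics
import time


def simple_walltime(fn, duration=0.2, block_size=10):
    """Median seconds per call of fn(), measured in blocks for >= duration seconds."""
    per_call = []
    start = time.perf_counter()
    # Keep running blocks until the time budget is spent (and at least once)
    while time.perf_counter() - start < duration or not per_call:
        t0 = time.perf_counter()
        for _ in range(block_size):
            fn()
        per_call.append((time.perf_counter() - t0) / block_size)
    # The median is robust to occasional slow blocks (e.g. warm-up, interrupts)
    return statistics.median(per_call)


print(simple_walltime(lambda: sum(range(1000))))
```

Reporting the median of block averages, rather than the mean of all calls, keeps a few outlier blocks from skewing the result.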

## Matrix Multiplication

Matrix multiplication is the most frequently used operator in transformer models, so its performance is crucial. Let's measure the [FLOPS](https://en.wikipedia.org/wiki/FLOPS) we can achieve on square matrices.

In general, multiplying two matrices $M_{a \times c}$ and $N_{c \times b}$ takes on the order of

$$ a \cdot b \cdot (2c) = 2abc $$

floating-point operations (FLOP), since each element of the product $X_{a \times b} = M_{a \times c} N_{c \times b}$ needs $c$ multiplications and $c-1$ additions, i.e. $2c-1 \approx 2c$ operations:

$$X[i][j] = M[i][0] \cdot N[0][j] + M[i][1] \cdot N[1][j] + \cdots + M[i][c-1] \cdot N[c-1][j] $$
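As a quick sanity check of the $2abc$ approximation, the hypothetical helpers below count FLOP for an $a \times c$ by $c \times b$ product, either exactly ($2c-1$ per element) or approximately ($2c$), and convert a wall time into TFLOPS the same way the benchmark below does.

```python
def matmul_flop(a: int, b: int, c: int, exact: bool = False) -> int:
    """FLOP to multiply an (a x c) matrix by a (c x b) matrix."""
    # c multiplications and c - 1 additions per output element
    per_element = 2 * c - 1 if exact else 2 * c
    return a * b * per_element


def tflops(flop: float, seconds: float) -> float:
    """Throughput in TFLOPS given a FLOP count and a wall time."""
    return flop / seconds / 1e12


n = 2048
approx = matmul_flop(n, n, n)
exact = matmul_flop(n, n, n, exact=True)
print(approx, exact, exact / approx)  # the two counts differ by about 0.02%
# A 2048 x 2048 product that takes 2 ms runs at roughly 8.6 TFLOPS
print(tflops(approx, 2e-3))
```

For matrices of this size the $-1$ term is negligible, which is why the benchmark simply uses `2*n**3`.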


In [6]:
matmul_tflops = defaultdict(lambda: {})
for n in [128, 512, 2048]:
    for dtype in (torch.float32, torch.float16):
        a = torch.randn(n, n, dtype=dtype, device='mps')
        b = torch.randn(n, n, dtype=dtype, device='mps')
        t = walltime('a @ b', var_dict(a, b))
        # an (n x n) @ (n x n) product costs about 2*n^3 FLOP
        matmul_tflops[f'n={n}'][dtype] = 2*n**3 / t / 1e12
        del a, b

pd.DataFrame(matmul_tflops)

| TFLOPS        | n=128 | n=512 | n=2048 |
| ------------- | ----- | ----- | ------ |
| torch.float32 | 0.229 | 6.185 | 8.599  |
| torch.float16 | 0.229 | 7.153 | 10.571 |


You can find the theoretical peak TFLOPS of your accelerator on Wikipedia, for example for [Apple Silicon](https://en.wikipedia.org/wiki/Apple_silicon), [Nvidia Tesla](https://en.wikipedia.org/wiki/Ampere_(microarchitecture)) and [RTX 30xx](https://en.wikipedia.org/wiki/GeForce_30_series).

The following table lists several accelerators of interest, along with their memory specifications.

| Model       | Memory (GB) | Memory Bandwidth (GB/sec) | FP32 TFLOPS | FP16 TFLOPS |
| ----------- | ----------- | ------------------------- | ----------- | ----------- |
| Apple M3 Max (30-core GPU)       | Shared        | 300                      | 10.6        | 21.2        |
| Nvidia T4  (AWS G4)      | 16          | 320                       | 8.1        | 65.1         |
| Nvidia A10G  (AWS G5)      | 24          | 600                       | 31.5        | 31.5         |
| RTX 3090 | 24          | 936                      | 35.6          | 35.6         |


## Memory Bandwidth

If the best measured TFLOPS is far below the theoretical peak on the GPU spec sheet, performance is likely bottlenecked by memory bandwidth. To illustrate this, let's benchmark a simple element-wise multiplication and report both its TFLOPS and its memory bandwidth.

In [9]:
vector = defaultdict(lambda: {})
for n in [1024*64, 1024*256, 1024*512]:
    a = torch.randn(n, device='mps')  # float32: 4 bytes per element
    t = walltime('a * 1.2', var_dict(a))
    vector[n]['TFLOPS'] = n / t / 1e12  # one multiplication per element
    vector[n]['GB/s'] = 8 * n / t / 1e9  # 4 bytes read + 4 bytes written per element

pd.DataFrame(vector)

| n      | 65536  | 262144  | 524288  |
| ------ | ------ | ------- | ------- |
| TFLOPS | 0.003  | 0.013   | 0.027   |
| GB/s   | 26.875 | 105.185 | 214.918 |


Even for large vectors, the TFLOPS stays far below the GPU's peak performance, while the achieved bandwidth gets much closer to its theoretical number.
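The gap can be quantified with a simple roofline model: attainable throughput is capped either by peak compute or by bandwidth times arithmetic intensity (FLOP per byte moved). The sketch below is an idealized estimate. The peak and bandwidth values are assumed from the M3 Max (30-core GPU) row of the table above, and the matmul traffic assumes each matrix is read or written exactly once.

```python
def attainable_tflops(peak_tflops: float, bandwidth_gbs: float,
                      flop_per_byte: float) -> float:
    """Roofline bound: the lesser of peak compute and bandwidth * intensity."""
    # GB/s * FLOP/byte = GFLOP/s; divide by 1e3 to get TFLOPS
    return min(peak_tflops, bandwidth_gbs * flop_per_byte / 1e3)


# Element-wise `a * 1.2` in FP32: 1 FLOP per 8 bytes (4 read + 4 written)
vector_intensity = 1 / 8
# n x n FP16 matmul, assuming each matrix crosses memory once:
# 2n^3 FLOP over 3 * n^2 * 2 bytes = n / 3 FLOP per byte
n = 2048
matmul_intensity = (2 * n**3) / (3 * n * n * 2)

# Assumed peaks: M3 Max (30-core GPU) row of the accelerator table
print(attainable_tflops(10.6, 300, vector_intensity))  # memory bound
print(attainable_tflops(21.2, 300, matmul_intensity))  # compute bound
```

Under these assumptions the element-wise kernel cannot exceed about 0.04 TFLOPS no matter how fast the ALUs are, while a large matmul has enough reuse to be limited by compute instead.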

Matrix multiplication performance is a central topic in HPC, with a large body of research papers and many implementations, both open-source and proprietary.


## BERT Layer

The main body of a transformer model is a stack of transformer blocks. Let's benchmark the performance of a single block.

In BERT, a single block is called a BERT layer. We can construct one such layer from the configuration of the [BERT large model](https://huggingface.co/bert-large-uncased). We use 16-bit floating point (FP16) for better performance.

In [10]:
from transformers import AutoConfig, BertLayer

config = AutoConfig.from_pretrained("bert-large-uncased")
layer = BertLayer(config).half().to('mps')

We define a function that benchmarks both the forward pass and the forward plus backward pass for different hyperparameters (sequence lengths and batch sizes).

In [11]:
def layer_benchmark(layer, hidden_size, seq_lens, batch_sizes, cross_attention=False):
    h = hidden_size
    results = defaultdict(lambda: {})
    # Decoder blocks also attend to encoder outputs; X is reused for that input
    encoder_state = 'encoder_hidden_states=X' if cross_attention else ''
    for s in seq_lens:
        for b in batch_sizes:
            ffn = 16*b*s*h*h / 1e12  # TFLOP for the Feed-Forward Network
            atten = (4*b*h*s*s + 8*b*s*h*h) / 1e12  # TFLOP for attention
            forward = ffn + (2 if cross_attention else 1) * atten

            X = torch.randn(b, s, h, device='mps').half()
            results[f'batch={b}'][f'fwd seq_len={s}'] = forward / walltime(
                f'layer(X, {encoder_state})', var_dict(layer, X))
            # backward costs roughly 2x forward, so fwd+bwd is ~3x the forward FLOP
            results[f'batch={b}'][f'fwd+bwd seq_len={s}'] = 3 * forward / walltime(
                f'layer(X, {encoder_state})[0].sum().backward()', var_dict(layer, X))
    return pd.DataFrame(results)
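The FLOP formulas hard-coded in `layer_benchmark` can be unpacked term by term. The hypothetical `layer_flop` helper below reproduces them in plain Python, assuming an FFN intermediate size of 4h (as in `bert-large-uncased`, where 4096 = 4 × 1024) and, for cross-attention, an encoder sequence as long as the decoder's, matching how the benchmark reuses `X`. It ignores layer norms, softmax, and other element-wise work.

```python
def layer_flop(b: int, s: int, h: int, cross_attention: bool = False,
               ffn_mult: int = 4) -> dict:
    """Forward FLOP of one transformer block (dense layers and attention only)."""
    # FFN: h -> ffn_mult*h -> h, two dense layers of 2*b*s*h*(ffn_mult*h) each
    ffn = 2 * (2 * b * s * h * ffn_mult * h)
    # Q, K, V and output projections: four h x h dense layers
    projections = 4 * (2 * b * s * h * h)
    # Q @ K^T and attention_probs @ V: two products of 2*b*s*s*h each
    scores = 2 * (2 * b * s * s * h)
    attention = projections + scores
    forward = ffn + (2 if cross_attention else 1) * attention
    # backward is commonly estimated at ~2x forward, hence 3x for fwd+bwd
    return {'ffn': ffn, 'attention': attention,
            'forward': forward, 'fwd+bwd': 3 * forward}


# BERT large (h=1024) with batch 64 and sequence length 128
flop = layer_flop(b=64, s=128, h=1024)
print({k: v / 1e12 for k, v in flop.items()})  # in TFLOP
```

Writing the terms out this way shows that only the `scores` term grows with $s^2$; everything else is linear in the number of tokens $b \cdot s$.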

In BERT pre-training, we often train with a sequence length of 128 (stage 1) or 512 (stage 2). Let's test the performance for both.

In [12]:
layer_benchmark(layer, config.hidden_size, [128, 512], [2, 4, 8, 16, 32, 64, 128])

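| TFLOPS              | batch=2 | batch=4 | batch=8 | batch=16 | batch=32 | batch=64 | batch=128 |
| ------------------- | ------- | ------- | ------- | -------- | -------- | -------- | --------- |
| fwd seq_len=128     | 5.323   | 6.384   | 7.259   | 6.921    | 7.171    | 7.13     | 6.607     |
| fwd+bwd seq_len=128 | 5.919   | 6.861   | 7.535   | 7.921    | 8.174    | 7.917    | 7.772     |
| fwd seq_len=512     | 5.632   | 5.898   | 5.94    | 5.928    | 5.827    | 5.859    | 5.692     |
| fwd+bwd seq_len=512 | 6.789   | 7.017   | 7.03    | 7.101    | 7.219    | 7.038    | 6.972     |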


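As expected, a larger batch size helps. But the best number (8.174 TFLOPS, forward plus backward with sequence length 128) is still below the FP16 matrix multiplication result (10.571 TFLOPS for n=2048). Why is that?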

Let's benchmark the first dense layer of the Feed-Forward Network (FFN) in the layer.

In [15]:
h, b, s = config.hidden_size, 64, 128
X = torch.randn(b, s, h, device='mps').half()

'Dense layer TFLOPS: %.3f' % (8*b*s*h*h / 1e12 / walltime(    
    'layer.intermediate.dense(X)', var_dict(layer, X)))

'Dense layer TFLOPS: 9.370'

This number is reasonably close to the FP16 matrix multiplication result (10.571 TFLOPS for n=2048). Next, let's run this dense layer together with its GeLU activation.

In [14]:
'Dense+Activation TFLOPS: %.3f' % (8*b*s*h*h / 1e12 / walltime(
    'layer.intermediate(X)', var_dict(layer, X)))

'Dense+Activation TFLOPS: 8.761'

Even though the activation function adds a negligible number of FLOPs, it lowers the TFLOPS. As pointed out earlier, the reason is that an element-wise operation such as the activation function is bound by memory bandwidth.
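A back-of-the-envelope estimate supports this. The sketch below assumes the GeLU is a single element-wise kernel that reads and writes the FP16 intermediate tensor (shape b × s × 4h) exactly once, and that it runs at the roughly 215 GB/s measured for the largest vector earlier. Both are assumptions, and the vector test used FP32 data.

```python
b, s, h = 64, 128, 1024

dense_flop = 8 * b * s * h * h           # first FFN dense layer, h -> 4h
dense_time = dense_flop / 9.370e12       # measured dense-only TFLOPS above

# Assumed: GeLU reads and writes the FP16 (2-byte) b x s x 4h tensor once
gelu_bytes = 2 * (b * s * 4 * h) * 2
gelu_time = gelu_bytes / 214.918e9       # measured vector bandwidth above

estimate = dense_flop / (dense_time + gelu_time) / 1e12
print(f'estimated Dense+Activation TFLOPS: {estimate:.2f}')
```

The estimate lands near the measured 8.761 TFLOPS, which suggests the drop is explained by the extra memory traffic rather than by the activation's arithmetic.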

Now let's test the whole FFN.

In [16]:
ffn = 16*b*s*h*h / 1e12
'FFN TFLOPS: %.3f'%(ffn / walltime(
    'layer.output(layer.intermediate(X),X)', var_dict(layer, X)))

'FFN TFLOPS: 8.426'

The other part of the BERT layer is the multi-head self-attention.

In [17]:
att = (4*b*h*s*s + 8*b*s*h*h) / 1e12
'Attention TFLOPS: %.3f'%(
    att / walltime('layer.attention(X)', var_dict(layer, X)))

'Attention TFLOPS: 5.887'

Although the main computation of the attention block is still matrix multiplication, it contains more memory-bound operators than the FFN (for example scaling, softmax, dropout, and reshaping the heads), so its TFLOPS is lower.

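The ratio of FLOP between attention and FFN depends on the BERT configuration and the sequence length. Because the two components run one after the other, the overall TFLOPS is a FLOP-weighted harmonic mean of their individual TFLOPS, i.e. the total FLOP divided by the sum of the two components' run times.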
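The harmonic-mean combination can be checked against the component numbers measured above for batch 64 and sequence length 128. The hypothetical `combined_tflops` helper below adds up the time each component would take at its own TFLOPS; it is a rough model that treats the components as independent and ignores any work outside them.

```python
def combined_tflops(parts):
    """Overall TFLOPS of sequential parts given (flop, tflops) pairs."""
    total_flop = sum(flop for flop, _ in parts)
    total_time = sum(flop / rate for flop, rate in parts)
    return total_flop / total_time


b, s, h = 64, 128, 1024
ffn = 16 * b * s * h * h / 1e12                          # TFLOP
att = (4 * b * h * s * s + 8 * b * s * h * h) / 1e12     # TFLOP
# Measured component throughputs from the FFN and attention cells above
print(combined_tflops([(ffn, 8.426), (att, 5.887)]))
```

This gives about 7.3 TFLOPS. Estimates of this kind are approximate; for comparison, the full layer measured 7.13 TFLOPS (forward) at the same shape.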

## GPT-2 Block

Next, let's evaluate `gpt2-medium`, which has a similar architecture to `bert-large`, i.e. 24 layers with a hidden size of 1024. GPT-2 is trained with a sequence length of 1024.

In [18]:
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

config = AutoConfig.from_pretrained("gpt2-medium")
layer = GPT2Block(config, layer_idx=0).half().to('mps')
layer_benchmark(layer, config.n_embd, [512, 1024], [2, 4, 8, 16, 32, 64])

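| TFLOPS               | batch=2 | batch=4 | batch=8 | batch=16 | batch=32 | batch=64 |
| -------------------- | ------- | ------- | ------- | -------- | -------- | -------- |
| fwd seq_len=512      | 4.945   | 5.027   | 5.054   | 5.099    | 5.102    | 5.115    |
| fwd+bwd seq_len=512  | 5.535   | 5.699   | 5.813   | 5.869    | 5.872    | 5.924    |
| fwd seq_len=1024     | 4.326   | 4.341   | 4.397   | 4.404    | 4.414    | 4.15     |
| fwd+bwd seq_len=1024 | 5.073   | 5.166   | 5.227   | 5.243    | 5.294    | 5.06     |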


Although GPT-2 and BERT have the same complexity, GPT-2 achieves slightly lower TFLOPS at the same batch size and sequence length.

Increasing the sequence length to 1024 lowers the performance further.
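One plausible contributor is that the attention score tensor grows quadratically with sequence length: it has b × heads × s × s elements, and every element-wise pass over it (scaling, masking, softmax, dropout) is memory bound. The sketch below assumes `gpt2-medium`'s 16 heads, FP16 scores, and one read plus one write per pass, and compares that traffic with the attention FLOP count used in `layer_benchmark`.

```python
def attention_traffic_ratio(b: int, s: int, h: int, heads: int = 16,
                            bytes_per_elem: int = 2) -> float:
    """Bytes moved by one element-wise pass over the scores, per attention FLOP."""
    score_elems = b * heads * s * s
    pass_bytes = 2 * score_elems * bytes_per_elem   # one read + one write
    attention_flop = 4 * b * h * s * s + 8 * b * s * h * h
    return pass_bytes / attention_flop


for s in (128, 512, 1024):
    print(s, attention_traffic_ratio(b=64, s=s, h=1024))
```

The ratio does not depend on the batch size, but it rises by roughly 1.7× from s=512 to s=1024, consistent with longer sequences spending a larger share of their time on memory-bound work.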

## T5 Layer

T5 has both an encoder and a decoder. Let's first benchmark the encoder, whose block structure is similar to a BERT layer.

In [19]:
from transformers.models.t5.modeling_t5 import T5Block

config = AutoConfig.from_pretrained("t5-large")
config.use_cache = False
config.is_decoder = False
config.is_encoder_decoder = False

encoder = T5Block(config).half().to('mps')
layer_benchmark(encoder, config.d_model, [512], [2, 4, 8, 16, 32, 64, 128])

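| TFLOPS              | batch=2 | batch=4 | batch=8 | batch=16 | batch=32 | batch=64 | batch=128 |
| ------------------- | ------- | ------- | ------- | -------- | -------- | -------- | --------- |
| fwd seq_len=512     | 4.48    | 4.754   | 4.811   | 4.833    | 4.806    | 4.876    | 4.814     |
| fwd+bwd seq_len=512 | 5.543   | 5.698   | 5.79    | 5.826    | 5.8      | 5.85     | 5.696     |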


The decoder has an additional cross-attention sublayer, which increases the computational cost and further lowers the TFLOPS.

In [20]:
config.is_decoder = True
decoder = T5Block(config).half().to('mps')
layer_benchmark(decoder, config.d_model, [512], [2, 4, 8, 16, 32, 64, 128], cross_attention=True)

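| TFLOPS              | batch=2 | batch=4 | batch=8 | batch=16 | batch=32 | batch=64 | batch=128 |
| ------------------- | ------- | ------- | ------- | -------- | -------- | -------- | --------- |
| fwd seq_len=512     | 4.093   | 4.224   | 4.261   | 4.292    | 4.283    | 4.314    | 4.307     |
| fwd+bwd seq_len=512 | 5.085   | 5.203   | 5.287   | 5.321    | 5.335    | 5.326    | 5.312     |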


## Conclusion

In conclusion, to get the best performance from a Transformer layer, use a fast data type (FP16) and a large batch size.

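For further improvement, we may need to rewrite the code, for example by [fusing](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#fuse-pointwise-operations) multiple point-wise kernels into a single one with the `@torch.jit.script` decorator. Whether TorchScript actually fuses these kernels on the MPS backend should be confirmed by benchmarking.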

```python
import torch

@torch.jit.script
def fused_gelu(x):
    # Exact (erf-based) GeLU: x * Phi(x), with Phi the standard normal CDF
    return x * 0.5 * (1.0 + torch.erf(x / 1.4142135623730951))
```
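To see why fusion helps a memory-bound expression, the sketch below counts tensor-sized memory accesses for the erf-based GeLU, `x * 0.5 * (1.0 + torch.erf(x / c))` with `c` ≈ √2, in eager mode. It assumes each operator runs as a separate kernel that materializes its result, in Python's left-to-right evaluation order, whereas an ideal fused kernel reads `x` once and writes the result once.

```python
# Each eager op as (description, tensors read, tensors written),
# in Python's evaluation order for x * 0.5 * (1.0 + torch.erf(x / c))
unfused_ops = [
    ('t1 = x * 0.5', 1, 1),
    ('t2 = x / c', 1, 1),
    ('t3 = torch.erf(t2)', 1, 1),
    ('t4 = 1.0 + t3', 1, 1),
    ('out = t1 * t4', 2, 1),
]
fused_ops = [('out = fused_gelu(x)', 1, 1)]


def element_accesses(ops):
    """Total tensor-sized reads plus writes across a sequence of kernels."""
    return sum(reads + writes for _, reads, writes in ops)


unfused = element_accesses(unfused_ops)
fused = element_accesses(fused_ops)
print(unfused, fused, unfused / fused)  # 11 2 5.5
```

Under these assumptions, the fused version moves about 5.5× less data, which is what matters for a bandwidth-bound operation.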