# Self-attention does not need O(n^2) memory

This colab accompanies our [paper](https://arxiv.org/abs/2112.05682) on a memory-efficient implementation of (self-)attention. It contains the standard attention implementation, our memory-efficient attention implementation, and evaluation code to determine and compare their runtime performance.

Please remember to connect to a colab runtime with a TPU. You can check whether your runtime has a TPU by hovering over the RAM and Disk indicator in the top right.

If you have questions, please email us at mrabe@google.com and cstaats@google.com.


## License

Copyright 2021 Google LLC.

SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

## Imports and utilities

In [None]:
from jax import lax
from typing import Any, Optional
from jax import numpy as jnp
import numpy as np
import time
import jax
import timeit
import functools

import jax.tools.colab_tpu
# Try to attach to a Colab TPU; fall back to whatever backend JAX picks by
# default when that is not possible.
try:
  jax.tools.colab_tpu.setup_tpu()
except (KeyError, RuntimeError):
  # KeyError: presumably raised when the Colab TPU address is not present in
  # the environment (no TPU runtime attached) -- TODO confirm.
  # RuntimeError: NOTE(review): newer JAX releases appear to raise this
  # unconditionally from setup_tpu() because legacy Colab TPUs are no longer
  # supported; without catching it the notebook would stop here.
  print('Could not find a TPU; running this colab on CPU or GPU.')

In [None]:
# Utility to create random data

# Module-level PRNG state; every call to fresh() replaces it.
_cur_key = jax.random.PRNGKey(4)

def fresh():
  """Returns a new PRNG key and advances the module-level key state."""
  global _cur_key
  next_state, subkey = jax.random.split(_cur_key)
  _cur_key = next_state
  return subkey

# Shape parameters shared by all randomly generated query/key/value arrays.
num_heads = 1
feature_dims = 64

def fresh_qkv(size, dtype=jnp.bfloat16):
  """Draws a standard-normal array of shape (size, num_heads, feature_dims).

  Args:
    size: length of the leading axis of the returned array.
    dtype: element type of the returned array.

  Returns:
    A freshly sampled array; each call consumes a new key from fresh().
  """
  return jax.random.normal(
      fresh(), (size, num_heads, feature_dims), dtype=dtype)

fresh_qkv(2)

DeviceArray([[[-1, 0.201172, -1, 0.984375, -1.40625, 0.34375, 0.0439453,
               0.65625, 0.679688, 0.894531, -1, 0.261719, 0.515625,
               -0.824219, -0.192383, 0.494141, -0.0932617, 1.01562,
               0.162109, 0.0439453, 1.19531, 0.8125, -0.3125, -1.17969,
               0.222656, -1.10156, 1.54688, 0.283203, 0.261719,
               0.515625, 1.54688, -0.597656, -1.65625, -0.251953,
               -0.851562, -1.03125, -1.73438, 0.632812, 1.49219,
               -0.398438, -0.3125, 0.679688, -0.851562, 0.103027,
               1.70312, 0.957031, 0.322266, 0.103027, 0.921875,
               -2.10938, 1.125, -0.0932617, -0.152344, -0.271484,
               0.0439453, 0.241211, -0.192383, 0.0634766, 0.283203,
               1.78906, -0.353516, 0.957031, -0.851562, 0.957031]],

             [[-1.14062, -0.333984, -0.667969, -2.32812, -1.17969,
               0.408203, -1.10156, -0.353516, 0.757812, -0.824219,
               -0.527344, 0.609375, 0.8125, -0.769531, -1

## Standard attention

In [None]:
# Based on flax/linen/attention.py

def standard_attention(query, key, value, dtype=jnp.float32, precision=None):
  """Computes softmax dot-product attention, materializing all weights.

  Args:
    query: array of shape (batch..., q_length, num_heads, depth).
    key: array of shape (batch..., kv_length, num_heads, depth).
    value: array of shape (batch..., kv_length, num_heads, v_depth).
    dtype: type of the scaling factor and of the normalized weights.
    precision: forwarded to both einsum contractions.

  Returns:
    Array of shape (batch..., q_length, num_heads, v_depth).
  """
  scale = jnp.sqrt(query.shape[-1]).astype(dtype)
  scaled_query = query / scale

  # Logits have shape (batch..., num_heads, q_length, kv_length).
  logits = jnp.einsum(
      '...qhd,...khd->...hqk', scaled_query, key, precision=precision)

  # Normalize over the key axis, then weight the values.
  probabilities = jax.nn.softmax(logits).astype(dtype)
  return jnp.einsum(
      '...hqk,...khd->...qhd', probabilities, value, precision=precision)

## Memory-efficient Self-Attention

In [None]:
# This cell is self-contained; the following imports suffice to run the
# memory-efficient attention implementation.
import functools, jax, math
from jax import lax
from jax import numpy as jnp


def _query_chunk_attention(query,
                           key,
                           value,
                           key_chunk_size=4096,
                           precision=lax.Precision.HIGHEST,
                           dtype=jnp.float32):
  """Attends a block of queries to all keys, one key chunk at a time.

  Each key/value chunk is summarized by its unnormalized weighted values, the
  sum of its exponentiated scores and its maximum score. The summaries are
  then rescaled to a common maximum and combined, so the full
  (num_q, num_kv) score matrix is never held at once.

  Args:
    query: array of shape (num_q, num_heads, k_features).
    key: array of shape (num_kv, num_heads, k_features).
    value: array of shape (num_kv, num_heads, v_features).
    key_chunk_size: maximum number of keys processed per chunk. num_kv does
      not have to be a multiple of it.
    precision: forwarded to the einsum contractions.
    dtype: type used for the scores and the accumulated values.

  Returns:
    Array of shape (num_q, num_heads, v_features).

  Raises:
    ValueError: if key_chunk_size is not positive or there are no keys.
  """
  num_kv, num_heads, k_features = key.shape
  v_features = value.shape[-1]
  if key_chunk_size < 1:
    raise ValueError(
        f'key_chunk_size must be positive, got {key_chunk_size}')
  if num_kv == 0:
    raise ValueError('attention requires at least one key/value pair')
  key_chunk_size = min(key_chunk_size, num_kv)
  num_full_chunks, tail_size = divmod(num_kv, key_chunk_size)
  query = query / jnp.sqrt(k_features).astype(dtype)

  # Checkpointing makes the backward pass recompute the per-chunk scores
  # instead of storing them.
  @functools.partial(jax.checkpoint, prevent_cse=False)
  def summarize_chunk(query, key, value):
    attn_weights = jnp.einsum(
        'qhd,khd->qhk', query, key, precision=precision).astype(dtype)
    # Subtract the per-chunk maximum for numerical stability; it is treated
    # as a constant under differentiation.
    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
    max_score = jax.lax.stop_gradient(max_score)
    exp_weights = jnp.exp(attn_weights - max_score)
    exp_values = jnp.einsum(
        'vhf,qhv->qhf', value, exp_weights, precision=precision).astype(dtype)
    return (exp_values, exp_weights.sum(axis=-1),
            max_score.reshape((query.shape[0], num_heads)))

  def chunk_scanner(chunk_idx):
    key_chunk = lax.dynamic_slice(
        key, (chunk_idx, 0, 0),
        slice_sizes=(key_chunk_size, num_heads, k_features))
    value_chunk = lax.dynamic_slice(
        value, (chunk_idx, 0, 0),
        slice_sizes=(key_chunk_size, num_heads, v_features))
    return summarize_chunk(query, key_chunk, value_chunk)

  # Only map over chunks that fit entirely: dynamic_slice clamps start
  # indices, so a partial last chunk would silently re-read (and double
  # count) keys that belong to the previous chunk.
  chunk_values, chunk_weights, chunk_max = lax.map(
      chunk_scanner, xs=jnp.arange(num_full_chunks) * key_chunk_size)

  if tail_size:
    # Summarize the remaining keys with a static slice of the exact size.
    tail_start = num_full_chunks * key_chunk_size
    tail_values, tail_weights, tail_max = summarize_chunk(
        query, key[tail_start:], value[tail_start:])
    chunk_values = jnp.concatenate([chunk_values, tail_values[None]], axis=0)
    chunk_weights = jnp.concatenate(
        [chunk_weights, tail_weights[None]], axis=0)
    chunk_max = jnp.concatenate([chunk_max, tail_max[None]], axis=0)

  # Rescale every chunk summary from its own maximum to the global maximum.
  global_max = jnp.max(chunk_max, axis=0, keepdims=True)
  max_diffs = jnp.exp(chunk_max - global_max)
  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
  chunk_weights *= max_diffs

  all_values = chunk_values.sum(axis=0)
  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
  return all_values / all_weights


def mefficient_attention(query,
                         key,
                         value,
                         query_chunk_size=1024,
                         precision=jax.lax.Precision.HIGHEST,
                         dtype=jnp.float32,
                         key_chunk_size=4096):
  """Memory-efficient attention that processes queries and keys in chunks.

  Args:
    query: array of shape (num_q, num_heads, q_features).
    key: array of shape (num_kv, num_heads, q_features).
    value: array of shape (num_kv, num_heads, v_features).
    query_chunk_size: maximum number of queries processed per chunk. num_q
      does not have to be a multiple of it.
    precision: forwarded to the einsum contractions.
    dtype: type used for the scores and the accumulated values.
    key_chunk_size: maximum number of keys processed per chunk.

  Returns:
    Array of shape (num_q, num_heads, v_features).

  Raises:
    ValueError: if a chunk size is not positive or there are no keys.
  """
  num_q, num_heads, q_features = query.shape
  v_features = value.shape[-1]
  if query_chunk_size < 1:
    raise ValueError(
        f'query_chunk_size must be positive, got {query_chunk_size}')
  query_chunk_size = min(query_chunk_size, num_q)
  if query_chunk_size:
    num_full_chunks, tail_size = divmod(num_q, query_chunk_size)
  else:
    # No queries at all: the scan below runs zero times.
    num_full_chunks, tail_size = 0, 0

  def attend(query_chunk):
    return _query_chunk_attention(
        query_chunk, key, value, key_chunk_size=key_chunk_size,
        precision=precision, dtype=dtype)

  def chunk_scanner(chunk_idx, _):
    query_chunk = lax.dynamic_slice(
        query, (chunk_idx, 0, 0),
        slice_sizes=(query_chunk_size, num_heads, q_features))
    return (chunk_idx + query_chunk_size, attend(query_chunk))

  # Scan only over the chunks that fit entirely; a partial last chunk would
  # be clamped by dynamic_slice and produce a result of the wrong size.
  _, res = lax.scan(
      chunk_scanner,
      init=0,
      xs=None,
      length=num_full_chunks)
  res = res.reshape(num_full_chunks * query_chunk_size, num_heads, v_features)

  if tail_size:
    # Handle the remaining queries with a static slice of the exact size.
    tail_start = num_full_chunks * query_chunk_size
    res = jnp.concatenate([res, attend(query[tail_start:])], axis=0)
  return res

## Evaluation

In [None]:
# Compare inference performance

# The evaluation uses mixed precision by default: bfloat16 for the inputs and
# float32 for certain internal representations.
input_dtype = jnp.bfloat16
dtype = jnp.float32
# Using HIGHEST means that we use full float32 precision. For neural network
# training we can often use lax.Precision.DEFAULT instead.
precision = lax.Precision.HIGHEST

execute_standard_att = True
execute_memory_efficient_att = True
repeats = 50


def _time_inference(attn_fn, query, key, value, repeats):
  """Times the first (compiling) call and the average of `repeats` calls.

  Uses time.perf_counter, a monotonic high-resolution clock, and blocks on
  the result of every call so asynchronous dispatch does not hide work.

  Returns:
    (first_call_seconds, mean_seconds_per_call, result_of_last_call)
  """
  compilation_start = time.perf_counter()
  result = attn_fn(query, key, value)
  result.block_until_ready()
  compilation_time = time.perf_counter() - compilation_start

  total = 0.0
  for _ in range(repeats):
    start = time.perf_counter()
    result = attn_fn(query, key, value)
    result.block_until_ready()
    total += time.perf_counter() - start
  return compilation_time, total / repeats, result


# The jitted wrappers do not depend on the problem size, so build them once;
# jit re-traces by itself whenever the input shapes change.
standard_attn = jax.jit(functools.partial(
    standard_attention, precision=precision, dtype=dtype))
mefficient_attn = jax.jit(functools.partial(
    mefficient_attention, precision=precision, dtype=dtype))

for i in range(8, 15, 2):
  q_size = 2**i
  memsize = 2**i
  print('\nAttention size:', q_size, 'x', memsize)
  query = fresh_qkv(q_size, input_dtype)
  key = fresh_qkv(memsize, input_dtype)
  value = fresh_qkv(memsize, input_dtype)

  if execute_standard_att:
    compilation_time, total_time, res_std = _time_inference(
        standard_attn, query, key, value, repeats)
    print('Standard compilation time:', compilation_time)
    print('Standard attention took:', total_time)

  if execute_memory_efficient_att:
    compilation_time, total_time_mem, res = _time_inference(
        mefficient_attn, query, key, value, repeats)
    print('Memory-efficient attention compilation time:', compilation_time)
    print('Memory-efficient attention took:', total_time_mem)

  if execute_standard_att and execute_memory_efficient_att:
    # Positive values mean the memory-efficient implementation was faster.
    print('Performance advantage:', (total_time / total_time_mem) - 1.0)
    diff = res - res_std
    print('avg difference', jnp.average(jnp.abs(diff)))
    print('max difference', jnp.max(jnp.abs(diff)))
    np.testing.assert_allclose(
        res.astype(jnp.float32), res_std.astype(jnp.float32), atol=1e-2)


Attention size: 256 x 256
Standard compilation time: 0.10895419120788574
Standard attention took: 0.0018677806854248047
Memory-efficient attention compilation time: 0.31549072265625
Memory-efficient attention took: 0.001740708351135254
Performance advantage: 0.07300035885200229
avg difference 7.912317e-09
max difference 8.940697e-08

Attention size: 1024 x 1024
Standard compilation time: 0.18660283088684082
Standard attention took: 0.0019140434265136719
Memory-efficient attention compilation time: 0.3578367233276367
Memory-efficient attention took: 0.0018649005889892578
Performance advantage: 0.02635145155434193
avg difference 5.8978773e-09
max difference 8.940697e-08

Attention size: 4096 x 4096
Standard compilation time: 0.5202584266662598
Standard attention took: 0.0029724836349487305
Memory-efficient attention compilation time: 0.4849257469177246
Memory-efficient attention took: 0.0031867408752441406
Performance advantage: -0.06723396996594388
avg difference 5.1832956e-09
max diff

In [None]:
# Compare differentiation performance

repeats = 50

execute_standard_att = True
execute_memory_efficient_att = True

input_dtype = jnp.bfloat16
dtype = jnp.float32
precision = lax.Precision.HIGHEST


def loss_ckpt(query, key, value):
  """Scalar loss: sum of the memory-efficient attention output."""
  return jnp.sum(mefficient_attention(
      query, key, value, precision=precision, dtype=dtype))

def loss_simp(query, key, value):
  """Scalar loss: sum of the standard attention output."""
  return jnp.sum(standard_attention(
      query, key, value, precision=precision, dtype=dtype))

# Gradients with respect to query, key and value.
diff_mefficient_attention = jax.jit(jax.grad(loss_ckpt, argnums=[0,1,2]))
diff_attention_simp = jax.jit(jax.grad(loss_simp, argnums=[0,1,2]))


def _time_gradient(grad_fn, query, key, value, repeats):
  """Times the first (compiling) call and the average of `repeats` calls.

  Uses time.perf_counter, a monotonic high-resolution clock, and blocks on
  every returned gradient so asynchronous dispatch does not hide work.

  Returns:
    (first_call_seconds, mean_seconds_per_call, gradients_of_last_call)
  """
  compilation_start = time.perf_counter()
  grads = grad_fn(query, key, value)
  for t in grads:
    t.block_until_ready()
  compilation_time = time.perf_counter() - compilation_start

  total = 0.0
  for _ in range(repeats):
    start = time.perf_counter()
    grads = grad_fn(query, key, value)
    for t in grads:
      t.block_until_ready()
    total += time.perf_counter() - start
  return compilation_time, total / repeats, grads


for i in range(8, 15, 1):
  q_size = 2**i
  memsize = 2**i
  print('\nAttention size:', q_size, 'x', memsize)

  query = fresh_qkv(q_size, input_dtype)
  key = fresh_qkv(memsize, input_dtype)
  value = fresh_qkv(memsize, input_dtype)

  if execute_standard_att:
    compilation_time, total_time_simp, res_std = _time_gradient(
        diff_attention_simp, query, key, value, repeats)
    print('Diff simp compilation time:', compilation_time)
    print('Standard attention took:', total_time_simp)

  if execute_memory_efficient_att:
    compilation_time, total_time_mem, res = _time_gradient(
        diff_mefficient_attention, query, key, value, repeats)
    print('Diff mem ckpt compilation time:', compilation_time)
    print('Memory-efficient attention took:', total_time_mem)

  if execute_standard_att and execute_memory_efficient_att:
    # Positive values mean the memory-efficient implementation was faster.
    print('Performance advantage:', (total_time_simp / total_time_mem) - 1.0)
    # The printed differences cover the query gradient only; the assertion
    # below checks all three gradients.
    diff = res[0] - res_std[0]
    print('avg difference', jnp.average(jnp.abs(diff)))
    print('max difference', jnp.max(jnp.abs(diff)))
    for tuple_idx in range(3):
      np.testing.assert_allclose(
          res[tuple_idx].astype(jnp.float32),
          res_std[tuple_idx].astype(jnp.float32), atol=1e-2, rtol=1e-2)


Attention size: 256 x 256
Diff simp compilation time: 0.2330009937286377
Standard attention took: 0.0023757553100585936
Diff mem ckpt compilation time: 1.592576503753662
Memory-efficient attention took: 0.0025052547454833983
Performance advantage: -0.051691124688326706
avg difference 3.00352e-08
max difference 0.000488281

Attention size: 512 x 512
Diff simp compilation time: 0.2673332691192627
Standard attention took: 0.003413748741149902
Diff mem ckpt compilation time: 1.7921810150146484
Memory-efficient attention took: 0.002515106201171875
Performance advantage: 0.3572980495055511
avg difference 9.66338e-13
max difference 2.98023e-08

Attention size: 1024 x 1024
Diff simp compilation time: 0.37073850631713867
Standard attention took: 0.002552280426025391
Diff mem ckpt compilation time: 1.660691261291504
Memory-efficient attention took: 0.002996225357055664
Performance advantage: -0.14816807086451256
avg difference 9.38599e-10
max difference 6.10352e-05

Attention size: 2048 x 2048
