Fast Attention: Causal Implementation Experiments

Having looked at google-research's blog post and TensorFlow implementation of fast attention (FAVOR+), I was left scratching my head about the causal attention implementation. This repository investigates a simpler version.

TL;DR

  • Causal attention can be concisely expressed mathematically using properties of low-rank matrices and Hadamard products (see the Theory section below).
  • ops/v1.py provides significantly simpler implementations that use neither loops over tensors nor custom gradients.
  • The implementations are much shorter (3 and 2 lines vs. 25 and 22 in the original), making them much easier to reason about.
  • Jit-compiling these operations is significantly faster than jit-compiling the originals (ops/v0.py).
  • Results are the same for the two implementations.
  • Computation time measured with google-benchmark is indistinguishable.

Theory

The google-ai blog post provides a visualisation of causal attention.

[Figure: causal attention visualisation from the blog post]

It's not immediately apparent to me what's going on here, and looking at the code (originally here, but with the relevant part included here for convenience) doesn't make things any clearer.

My implementation (v1) takes a different approach.

The task we consider is to compute the causal numerator $N$, where

$N = \left[(Q K^T) \circ L\right] V$

where $Q$, $K$ and $V$ are the query, key and value matrices used in fast attention, $L$ is a lower triangular matrix with values of $1$ on and below the diagonal, and $\circ$ is the Hadamard product (elementwise product). Noting that $Q K^T$ is low rank (that's the whole point of performers/FAVOR), we can use the following handy dandy property of Hadamard products (Property 1):

$\left[A \circ \sum_j \mathbf{u}_j \mathbf{v}_j^T\right]\mathbf{x} = \sum_j D(\mathbf{u}_j) A D(\mathbf{v}_j) \mathbf{x}$

where $D(\mathbf{z})$ is the diagonal matrix with diagonal values $\mathbf{z}$. This means we can express the causal numerator as

$N = \sum_m D(\mathbf{q}_m) L D(\mathbf{k}_m) V$

where $\mathbf{q}_m$ and $\mathbf{k}_m$ are the $m^\text{th}$ columns of $Q$ and $K$ respectively.
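
For completeness (this derivation is mine rather than something from the blog post), Property 1 follows by expanding entries for a single rank-one term:

$\left[\left(A \circ \mathbf{u} \mathbf{v}^T\right) \mathbf{x}\right]_i = \sum_j A_{ij} u_i v_j x_j = u_i \sum_j A_{ij} v_j x_j = \left[D(\mathbf{u}) A D(\mathbf{v}) \mathbf{x}\right]_i$

The expression for $N$ then follows by writing $Q K^T = \sum_m \mathbf{q}_m \mathbf{k}_m^T$, taking $A = L$, and applying the identity to each column of $V$.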

Note that it is neither efficient nor necessary to materialise any of the matrices above: $D(\mathbf{k}_m) Z$ is just a scaling of the rows of $Z$ by $\mathbf{k}_m$, while $L Z$ is the cumulative sum of $Z$ along the leading dimension. This results in a significantly simpler TensorFlow implementation with no need for custom gradients or Python loops.

The implementation looks slightly different to the maths above because we compute $D(\mathbf{k}_m) V$ for all $m$ simultaneously and then combine the scaling by $\mathbf{q}_m$ with the reduction over $m$ in a single tf.linalg.matvec.

import tensorflow as tf


def causal_numerator(qs: tf.Tensor, ks: tf.Tensor, vs: tf.Tensor):
    """Computes not-normalized FAVOR causal attention A_{masked}V.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Not-normalized FAVOR causal attention A_{masked}V.
    """
    # rhs = tf.einsum('lbhm,lbhd->lbhdm', ks, vs)
    rhs = tf.expand_dims(ks, axis=-2) * tf.expand_dims(vs, axis=-1)  # [L,B,H,D,M]
    rhs = tf.cumsum(rhs, axis=0)
    # return tf.einsum('lbhm,lbhdm->lbhd', qs, rhs)
    return tf.linalg.matvec(rhs, qs)

That's a 3-line implementation, as opposed to the 25 used in the original.
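
As a quick sanity check (this snippet is mine and not part of the repo), the fast version can be compared against a naive implementation that materialises the masked attention matrix explicitly; the two should agree up to floating-point rounding error.

def naive_causal_numerator(qs, ks, vs):
    # Materialises the full [L, L] masked attention matrix [(Q K^T) o L]
    # per batch/head, then applies it to V.
    scores = tf.einsum("lbhm,tbhm->bhlt", qs, ks)            # [B,H,L,L]
    mask = tf.linalg.band_part(tf.ones_like(scores), -1, 0)  # lower triangle incl. diagonal
    return tf.einsum("bhlt,tbhd->lbhd", scores * mask, vs)   # [L,B,H,D]


# e.g. with small random inputs:
qs = tf.random.uniform((8, 2, 4, 16))   # [L,B,H,M]
ks = tf.random.uniform((8, 2, 4, 16))   # [L,B,H,M]
vs = tf.random.uniform((8, 2, 4, 32))   # [L,B,H,D]
print(tf.reduce_max(tf.abs(causal_numerator(qs, ks, vs) - naive_causal_numerator(qs, ks, vs))))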

Denominator

The causal denominator function is conceptually the same as the numerator except that $V$ is replaced by a vector of ones. Since the first operation just scales $V$ by the keys, we can skip it entirely and use the keys ks directly:

def causal_denominator(qs, ks):
    """Computes FAVOR normalizer in causal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in causal attention.
    """
    rhs = tf.cumsum(ks, axis=0)
    return tf.einsum("lbhm,lbhm->lbh", qs, rhs)

That's 2 lines compared to 22 in the original.
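
For context, the following is a rough sketch (mine, not code from the repo) of how the two ops combine into normalised causal attention; in the full FAVOR+ pipeline the random-feature map has already been applied to produce qs and ks before these ops are called.

def causal_attention(qs, ks, vs):
    """Normalised causal FAVOR attention built from the two ops above."""
    num = causal_numerator(qs, ks, vs)           # [L,B,H,D]
    den = causal_denominator(qs, ks)             # [L,B,H]
    return num / tf.expand_dims(den, axis=-1)    # [L,B,H,D]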

Benchmarks

The following benchmarks were run on my fairly old laptop with a 1050-Ti. Note that I trust the timings from google-benchmark a lot more than I do those from tfbm / tf.test.Benchmark, but maybe I'm just misinterpreting them.

google-benchmark

gbenchmark.py uses google-benchmark. The output is a lot simpler than that of tf.test.Benchmark, which means I trust these timings a lot more.
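
For reference, benchmarks registered through the google-benchmark Python bindings follow roughly this pattern (a minimal sketch; the actual gbenchmark.py may be organised differently, and run_v1_forward below is a hypothetical stand-in for the op under test):

import google_benchmark as benchmark

@benchmark.register
def v1_forward_cpu(state):
    # google-benchmark drives the timing loop and chooses the iteration count.
    while state:
        run_v1_forward()  # hypothetical helper that calls the v1 op

if __name__ == "__main__":
    benchmark.main()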

Take-aways:

  • There isn't much difference between the v0 and v1 implementations in terms of computation time.
  • The v1 implementations warm up significantly faster.
  • Jit compilation significantly reduces forward time on cpu, but makes a negligible difference on gpu.

python gbenchmark.py
--------------------------------------------------------------
Benchmark                    Time             CPU   Iterations
--------------------------------------------------------------
v0_forward-cpu         5403096 ns       364764 ns         1000
v1_forward-cpu         5419832 ns       365650 ns         1000
v0_backward-cpu         268558 ns       238634 ns         2896
v1_backward-cpu         267089 ns       235842 ns         2937
v0_forward-gpu          288531 ns       241580 ns         2874
v1_forward-gpu          285695 ns       238078 ns         2908
v0_backward-gpu         268220 ns       237309 ns         2869
v1_backward-gpu         268324 ns       240429 ns         2751
v0_forward-cpu-jit      299143 ns       271613 ns         2516
v1_forward-cpu-jit      291873 ns       269618 ns         2538
v0_backward-cpu-jit     303150 ns       275359 ns         2483
v1_backward-cpu-jit     303948 ns       276806 ns         2482
v0_forward-gpu-jit      278147 ns       277842 ns         2450
v1_forward-gpu-jit      276128 ns       275956 ns         2523
v0_backward-gpu-jit     256809 ns       256798 ns         2706
v1_backward-gpu-jit     252543 ns       252537 ns         2769

Warmup time for v0_forward-cpu: 6.56445574760437
Warmup time for v1_forward-cpu: 0.1015627384185791
Warmup time for v0_backward-cpu: 22.0670325756073
Warmup time for v1_backward-cpu: 0.08140373229980469
Warmup time for v0_forward-gpu: 6.233572244644165
Warmup time for v1_forward-gpu: 0.028412342071533203
Warmup time for v0_backward-gpu: 22.226712226867676
Warmup time for v1_backward-gpu: 0.051419734954833984
Warmup time for v0_forward-cpu-jit: 6.481787443161011
Warmup time for v1_forward-cpu-jit: 0.05790424346923828
Warmup time for v0_backward-cpu-jit: 24.72081184387207
Warmup time for v1_backward-cpu-jit: 0.09151363372802734
Warmup time for v0_forward-gpu-jit: 8.328083515167236
Warmup time for v1_forward-gpu-jit: 0.08592033386230469
Warmup time for v0_backward-gpu-jit: 24.7033634185791
Warmup time for v1_backward-gpu-jit: 0.12377095222473145

tfbm

benchmark.py requires tfbm, a wrapper around tf.test.Benchmark implementations.

pip install git+https://github.com/jackd/tfbm
Results for cls=Forward
Uniform results:
+--------+---------+-------+
| run_id |     cls | iters |
+--------+---------+-------+
|    NOW | Forward |    10 |
+--------+---------+-------+
Varied results:
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|               test | wall_time (us) | device | max_mem_GPU_0_bfc (Mb) | max_mem_cpu (b) | max_mem_gpu_host_bfc (b) | xla_jit |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_forward_xla_gpu |        218.511 |    gpu |                    --- |         192.000 |                   49.000 |    True |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_forward_xla_gpu |        260.711 |    gpu |                    --- |         192.000 |                   49.000 |    True |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_forward_xla_cpu |        272.274 |    cpu |                    --- |         192.000 |                   49.000 |    True |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_forward_xla_cpu |        284.910 |    cpu |                    --- |         192.000 |                   49.000 |    True |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v0_forward_gpu |      20148.039 |    gpu |               2563.000 |         192.000 |                 1088.000 |   False |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v1_forward_gpu |      58524.966 |    gpu |                516.000 |         192.000 |                   64.000 |   False |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v0_forward_cpu |     100526.690 |    cpu |                    --- |  1776322300.000 |                      --- |   False |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v1_forward_cpu |     233323.097 |    cpu |                    --- |   541065220.000 |                      --- |   False |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results for cls=Backward
Uniform results:
+--------+----------+-------+
| run_id |      cls | iters |
+--------+----------+-------+
|    NOW | Backward |    10 |
+--------+----------+-------+
Varied results:
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|                test | wall_time (us) | device | max_mem_GPU_0_bfc (Mb) | max_mem_cpu (b) | max_mem_gpu_host_bfc (b) | xla_jit |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_backward_xla_gpu |        225.782 |    gpu |                    --- |         192.000 |                   49.000 |    True |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_backward_xla_gpu |        231.147 |    gpu |                    --- |         192.000 |                   49.000 |    True |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_backward_xla_cpu |        239.730 |    cpu |                    --- |         192.000 |                   49.000 |    True |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v0_backward_gpu |      47311.544 |    gpu |               2641.594 |         192.000 |                 1084.000 |   False |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_backward_xla_cpu |      77256.083 |    cpu |                    --- |   288358548.000 |                   49.000 |    True |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v1_backward_gpu |     110480.189 |    gpu |                788.031 |         192.000 |                   68.000 |   False |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v0_backward_cpu |     260773.897 |    cpu |                    --- |  1936089700.000 |                      --- |   False |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|     v1_backward_cpu |     524332.523 |    cpu |                    --- |  1092648988.000 |                      --- |   False |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
