<a href="https://colab.research.google.com/github/felix0901/flat_prototype/blob/master/flat_prototype.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Pip Install

In [9]:
# !pip install --upgrade pip
# !pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

#### Import

In [10]:
from jax import custom_vjp
import jax
import jax.numpy as jnp
import numpy as np
import gc
import time
import pandas as pd

#### Logit-Softmax-Attend functions: un-fused baseline and fused FLAT


In [11]:

@jax.jit
def matmul(A, B):
  """Jitted batched matrix product ``A @ B`` (numpy ``matmul`` semantics).

  Used by ``einsum_layer`` for each tile of the un-fused baseline's
  Logit and Attend steps.
  """
  product = A @ B
  return product


@jax.jit
def matmul3(A, B, C):
  """Fused Logit -> Softmax -> Attend for one tile, as a single jitted function.

  Returns ``softmax(A @ B, axis=-1) @ C``: the logits ``A @ B`` are normalized
  row-wise over their last axis before being multiplied by ``C``.
  """
  probs = jax.nn.softmax(jnp.matmul(A, B), axis=-1)
  return jnp.matmul(probs, C)


def einsum_layer(query, key, output, B=1, H=1, R=1):
  '''
  Tiled batched matmul: output[b, h] = query[b, h] @ key[b, h].

  Despite the name this does not call einsum; each tile is computed with the
  jitted `matmul` helper and written into `output` in place.

  query:  4-D array (batch, head, rows, inner)
  key:    4-D array (batch, head, inner, cols)
  output: preallocated, writable 4-D array (batch, head, rows, cols);
          filled in place and also returned.
  B: Batch granularity (tile size along axis 0)
  H: Head granularity (tile size along axis 1)
  R: Row granularity (tile size along axis 2 of query/output)
  '''
  assert len(key.shape) == len(query.shape) == 4
  batch, head, inner_k, cols = key.shape
  q_batch, q_head, rows, inner_q = query.shape
  # Validate up front: jnp.matmul broadcasts size-1 axes, and numpy slice
  # assignment broadcasts a size-1 result into the destination, so a shape
  # mismatch could otherwise silently yield wrong values instead of an error.
  assert inner_q == inner_k, 'contraction dims differ: query[-1] != key[-2]'
  assert (q_batch, q_head) == (batch, head), 'query/key batch or head dims differ'
  assert tuple(output.shape) == (batch, head, rows, cols), 'output has the wrong shape'
  for bt in range(0, batch, B):
    for ht in range(0, head, H):
      for mt in range(0, rows, R):
        # Slices past the end are clipped, so a ragged final tile is handled.
        out = matmul(jnp.array(query[bt:bt+B, ht:ht+H, mt:mt+R, :]),
                     jnp.array(key[bt:bt+B, ht:ht+H, :, :]))
        output[bt:bt+B, ht:ht+H, mt:mt+R, :] = out
  return output


def L_softmax_A(query, key, value, output, B=1, H=1, R=1):
  '''
  Baseline: Un-fused Logit-Softmax-Attend

  Computes output = softmax(query @ key, axis=-1) @ value in three separate
  passes:
    1. a tiled Logit matmul into a full (batch, head, seq_q, seq_k)
       intermediate buffer,
    2. one softmax over that whole buffer,
    3. a tiled Attend matmul into `output`.

  query:  (batch, head, seq_q, hidden)
  key:    (batch, head, hidden, seq_k)  -- stored pre-transposed
  value:  (batch, head, seq_v, d_v), with seq_v == seq_k
  output: preallocated (batch, head, seq_q, d_v) array, written in place
          and returned.
  B: Batch granularity
  H: Head granularity
  R: Row granularity
  '''
  # Rank-check every operand, including `value`.
  assert len(key.shape) == len(query.shape) == len(value.shape) == 4
  batch, head, hidden, seq_k = key.shape
  batch, head, seq_q, hidden = query.shape
  batch, head, seq_v, hidden = value.shape
  # Validate before allocating the (potentially very large) logit buffer.
  assert seq_k == seq_v
  intermediate = np.zeros((batch, head, seq_q, seq_k))
  intermediate = einsum_layer(query, key, intermediate, B=B, H=H, R=R)
  # Softmax runs once over the fully materialized logits (the un-fused step).
  intermediate = jax.nn.softmax(jnp.array(intermediate), axis=-1)
  output = einsum_layer(np.array(intermediate), value, output, B=B, H=H, R=R)
  return output


def fused_L_softmax_A(query, key, value, output, B=1, H=1, R=1):
  '''
  FLAT: fused Logit-Softmax-Attend

  For each (B, H, R) tile of query rows, runs the jitted `matmul3`
  (softmax(q_tile @ k, axis=-1) @ v) and writes the result into `output`.
  The logits exist only inside each `matmul3` call, at tile size
  (B, H, R, seq_k), rather than in a separate full-size buffer as in the
  un-fused baseline.

  query:  (batch, head, seq_q, hidden)
  key:    (batch, head, hidden, seq_k)  -- stored pre-transposed
  value:  (batch, head, seq_v, d_v), with seq_v == seq_k
  output: preallocated (batch, head, seq_q, d_v) array, written in place
          and returned.
  B: Batch granularity
  H: Head granularity
  R: Row granularity
  '''
  # Rank-check every operand, including `value`.
  assert len(key.shape) == len(query.shape) == len(value.shape) == 4
  batch, head, hidden, seq_k = key.shape
  batch, head, seq_q, hidden = query.shape
  batch, head, seq_v, hidden = value.shape
  assert seq_k == seq_v
  for bt in range(0, batch, B):
    for ht in range(0, head, H):
      for mt in range(0, seq_q, R):
        # Slices past the end are clipped, so a ragged final tile is handled.
        out = matmul3(jnp.array(query[bt:bt+B, ht:ht+H, mt:mt+R, :]),
                      jnp.array(key[bt:bt+B, ht:ht+H, :, :]),
                      jnp.array(value[bt:bt+B, ht:ht+H, :, :]))
        output[bt:bt+B, ht:ht+H, mt:mt+R, :] = out
  return output



 

#### Sweep Batch size

In [12]:
#Baseline
head = 12
seq_k = seq_v = seq_q = seq =256
hidden = 768
print('=========Baseline============')
for batcht in range(8):
    batch = 2** batcht
    key = np.ones((batch, head, hidden, seq_k))
    query = np.ones((batch, head, seq_q, hidden))
    value = np.ones((batch, head, seq_v, hidden))
    output = np.zeros((batch, head, seq_q, hidden))
    B = batch
    H = head
    R = seq_q
    print(f'Running Model Batch-{batch}, Head-{head}, Seq-{seq_q}, Hidden-{hidden}, with granularity B-{B}, H-{H}, R-{R}')
    L_softmax_A(query, key, value, output, B, H, R)
    timing = %timeit -o L_softmax_A(query, key, value, output, B, H, R)
    del key, query, value, output
    gc.collect()


Running Model Batch-1, Head-12, Seq-256, Hidden-768, with granularity B-1, H-12, R-256
10 loops, best of 5: 30.9 ms per loop
Running Model Batch-2, Head-12, Seq-256, Hidden-768, with granularity B-2, H-12, R-256
10 loops, best of 5: 61.1 ms per loop
Running Model Batch-4, Head-12, Seq-256, Hidden-768, with granularity B-4, H-12, R-256
10 loops, best of 5: 119 ms per loop
Running Model Batch-8, Head-12, Seq-256, Hidden-768, with granularity B-8, H-12, R-256
1 loop, best of 5: 230 ms per loop
Running Model Batch-16, Head-12, Seq-256, Hidden-768, with granularity B-16, H-12, R-256
1 loop, best of 5: 458 ms per loop
Running Model Batch-32, Head-12, Seq-256, Hidden-768, with granularity B-32, H-12, R-256
1 loop, best of 5: 918 ms per loop
Running Model Batch-64, Head-12, Seq-256, Hidden-768, with granularity B-64, H-12, R-256
1 loop, best of 5: 1.8 s per loop
Running Model Batch-128, Head-12, Seq-256, Hidden-768, with granularity B-128, H-12, R-256
1 loop, best of 5: 4.38 s per loop


In [13]:
head = 12
seq_k = seq_v = seq_q = seq =256
hidden = 768

# ===hyperparameter of FLAT===
batch_tile = 64   # 1<=batch_tile<=batch
head_tile = head  # 1<=head_tile<=head
seq_tile = seq  # 1<=seq_tile<=seq
#=============================

print('=========FLAT============')

for batcht in range(8):
    batch = 2** batcht
    key = np.ones((batch, head, hidden, seq_k))
    query = np.ones((batch, head, seq_q, hidden))
    value = np.ones((batch, head, seq_v, hidden))
    output = np.zeros((batch, head, seq_q, hidden))
    B = min(batch_tile, batch)
    H = min(head_tile, head)
    R = min(seq_tile, seq_q)
    print(f'Running Model Batch-{batch}, Head-{head}, Seq-{seq_q}, Hidden-{hidden}, with granularity B-{B}, H-{H}, R-{R}')
    fused_L_softmax_A(query, key, value, output, B, H, R)
    timing = %timeit -o fused_L_softmax_A(query, key, value, output, B, H, R)
    gc.collect()

Running Model Batch-1, Head-12, Seq-256, Hidden-768, with granularity B-1, H-12, R-256
10 loops, best of 5: 22.4 ms per loop
Running Model Batch-2, Head-12, Seq-256, Hidden-768, with granularity B-2, H-12, R-256
10 loops, best of 5: 46.1 ms per loop
Running Model Batch-4, Head-12, Seq-256, Hidden-768, with granularity B-4, H-12, R-256
10 loops, best of 5: 94.3 ms per loop
Running Model Batch-8, Head-12, Seq-256, Hidden-768, with granularity B-8, H-12, R-256
1 loop, best of 5: 188 ms per loop
Running Model Batch-16, Head-12, Seq-256, Hidden-768, with granularity B-16, H-12, R-256
1 loop, best of 5: 373 ms per loop
Running Model Batch-32, Head-12, Seq-256, Hidden-768, with granularity B-32, H-12, R-256
1 loop, best of 5: 742 ms per loop
Running Model Batch-64, Head-12, Seq-256, Hidden-768, with granularity B-64, H-12, R-256
1 loop, best of 5: 1.44 s per loop
Running Model Batch-128, Head-12, Seq-256, Hidden-768, with granularity B-64, H-12, R-256
1 loop, best of 5: 2.71 s per loop


#### Sweep Sequence Length

In [14]:
batch = 1
head = 12
hidden = 768

print('=========Baseline============')

for seqt in range(8):
    seq = 2** seqt
    seq_k = seq_v = seq_q =seq
    key = np.ones((batch, head, hidden, seq_k))
    query = np.ones((batch, head, seq_q, hidden))
    value = np.ones((batch, head, seq_v, hidden))
    output = np.zeros((batch, head, seq_q, hidden))
    B = min(batch_tile, batch)
    H = min(head_tile, head)
    R = min(seq_tile, seq_q)
    print(f'Running Model Batch-{batch}, Head-{head}, Seq-{seq_q}, Hidden-{hidden}, with granularity B-{B}, H-{H}, R-{R}')
    L_softmax_A(query, key, value, output, B, H, R)
    timing = %timeit -o L_softmax_A(query, key, value, output, B, H, R)
    del key, query, value, output
    gc.collect()

Running Model Batch-1, Head-12, Seq-1, Hidden-768, with granularity B-1, H-12, R-1
1000 loops, best of 5: 1.79 ms per loop
Running Model Batch-1, Head-12, Seq-2, Hidden-768, with granularity B-1, H-12, R-2
100 loops, best of 5: 1.86 ms per loop
Running Model Batch-1, Head-12, Seq-4, Hidden-768, with granularity B-1, H-12, R-4
1000 loops, best of 5: 1.99 ms per loop
Running Model Batch-1, Head-12, Seq-8, Hidden-768, with granularity B-1, H-12, R-8
100 loops, best of 5: 2.24 ms per loop
Running Model Batch-1, Head-12, Seq-16, Hidden-768, with granularity B-1, H-12, R-16
100 loops, best of 5: 2.84 ms per loop
Running Model Batch-1, Head-12, Seq-32, Hidden-768, with granularity B-1, H-12, R-32
100 loops, best of 5: 3.87 ms per loop
Running Model Batch-1, Head-12, Seq-64, Hidden-768, with granularity B-1, H-12, R-64
100 loops, best of 5: 6.62 ms per loop
Running Model Batch-1, Head-12, Seq-128, Hidden-768, with granularity B-1, H-12, R-128
100 loops, best of 5: 14.2 ms per loop


In [15]:
batch = 1
head = 12
hidden = 768

# ===hyperparameter of FLAT===
batch_tile = 64   # 1<=batch_tile<=batch
head_tile = head  # 1<=head_tile<=head
seq_tile = float('Inf')  # 1<=seq_tile<=seq
#=============================

print('=========FLAT============')

for seqt in range(8):
    seq = 2** seqt
    seq_k = seq_v = seq_q =seq
    key = np.ones((batch, head, hidden, seq_k))
    query = np.ones((batch, head, seq_q, hidden))
    value = np.ones((batch, head, seq_v, hidden))
    output = np.zeros((batch, head, seq_q, hidden))
    B = min(batch_tile, batch)
    H = min(head_tile, head)
    R = min(seq_tile, seq_q)
    print(f'Running Model Batch-{batch}, Head-{head}, Seq-{seq_q}, Hidden-{hidden}, with granularity B-{B}, H-{H}, R-{R}')
    fused_L_softmax_A(query, key, value, output, B, H, R)
    timing = %timeit -o fused_L_softmax_A(query, key, value, output, B, H, R)
    del key, query, value, output
    gc.collect()

Running Model Batch-1, Head-12, Seq-1, Hidden-768, with granularity B-1, H-12, R-1
1000 loops, best of 5: 779 µs per loop
Running Model Batch-1, Head-12, Seq-2, Hidden-768, with granularity B-1, H-12, R-2
1000 loops, best of 5: 873 µs per loop
Running Model Batch-1, Head-12, Seq-4, Hidden-768, with granularity B-1, H-12, R-4
1000 loops, best of 5: 976 µs per loop
Running Model Batch-1, Head-12, Seq-8, Hidden-768, with granularity B-1, H-12, R-8
1000 loops, best of 5: 1.35 ms per loop
Running Model Batch-1, Head-12, Seq-16, Hidden-768, with granularity B-1, H-12, R-16
1000 loops, best of 5: 1.86 ms per loop
Running Model Batch-1, Head-12, Seq-32, Hidden-768, with granularity B-1, H-12, R-32
100 loops, best of 5: 2.76 ms per loop
Running Model Batch-1, Head-12, Seq-64, Hidden-768, with granularity B-1, H-12, R-64
100 loops, best of 5: 5.01 ms per loop
Running Model Batch-1, Head-12, Seq-128, Hidden-768, with granularity B-1, H-12, R-128
100 loops, best of 5: 10.7 ms per loop
