In [35]:
import argparse

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

In [31]:
# Create small tensors for testing.
# Layout is (batch, seq_len, heads, head_dim): the "bshd" order that
# attention_ourselves / attention_ourselves_causal index their inputs with.
# (Creating them as (batch, heads, seq_len, head_dim) would make those
# functions treat the sequence axis as the heads axis.)
BATCH = 1
HEADS = 1
SEQ_LEN = 8
HEAD_DIM = 4
QKV_SHAPE = (BATCH, SEQ_LEN, HEADS, HEAD_DIM)

print("Creating random tensors")
K = jax.random.normal(jax.random.key(0), QKV_SHAPE)
V = jax.random.normal(jax.random.key(1), QKV_SHAPE)
Q = jax.random.normal(jax.random.key(2), QKV_SHAPE)
print("Random tensors created")

In [7]:
K[0]  # every value of batch element 0 (small enough to display whole)

In [8]:
Q[0]  # every value of batch element 0 (small enough to display whole)

In [9]:
V[0]  # every value of batch element 0 (small enough to display whole)

https://www.youtube.com/watch?v=IvgV6QcsC64

![./einsum_overview.png](einsum_overview.png)
![./einsum_element_wise_multiplication.png](einsum_element_wise_multiplication.png)
![./einsum_rule_1.png](einsum_rule_1.png)
![./einsum_rule_3-4.png](einsum_rule_3-4.png)
![./einsum_summation_along_axis.png](einsum_summation_along_axis.png)


In [10]:
# Matrix a: a 3x3 matrix
a = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

# Matrix b
b = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

print("Method 1: Using nested for loops")
# result_loops[i][j] = sum_k a[i][k] * b[k][j]; start from an all-zero
# len(a) x len(b[0]) matrix.
result_loops = [[0 for _ in range(len(b[0]))] for _ in range(len(a))]

for i in range(len(a)):            # rows of a
    for j in range(len(b[0])):     # columns of b
        for k in range(len(b)):    # contracted index: columns of a / rows of b
            result_loops[i][j] += a[i][k] * b[k][j]

print("Result using for loops:")
for row in result_loops:
    print(row)

print("\nMethod 2: Using numpy einsum")
# 'ik,kj->ij': k appears in both inputs but not in the output, so it is summed.
a_np = np.array(a)
b_np = np.array(b)
result_einsum = np.einsum('ik,kj->ij', a_np, b_np)

print("Result using einsum:")
print(result_einsum)

# Both methods implement the same contraction, so they must agree exactly.
assert (result_einsum == np.array(result_loops)).all()


In [11]:
def attention_ourselves(_Q, _K, _V, sm_scale: float = 1.0):
    """Non-causal multi-head attention, plus a printed roofline-style analysis.

    Args:
        _Q, _K, _V: arrays of shape (batch, seq_len, heads, head_dim).
        sm_scale: factor applied to the Q.K^T logits before the softmax.
            Defaults to 1.0 (no scaling) so existing calls are unchanged; pass
            ``1 / sqrt(head_dim)`` for standard scaled dot-product attention.

    Returns:
        Array of shape (batch, seq_len, heads, head_dim).

    Side effects:
        Prints the input shapes and, for each of the two einsums, a memory
        traffic estimate (counted at 2 bytes per element, a bf16 convention),
        a FLOP count and the resulting arithmetic intensity.
    """
    print("Computing attention weights")
    print("Q shape:", _Q.shape)
    print("K shape:", _K.shape)
    print("V shape:", _V.shape)
    print("\n")

    # Dimensions (s = query position, t = key position):
    # _Q, _K, _V are (batch, seq_len, heads, head_dim)
    # _weights_unnormalized is (batch, heads, seq_len, seq_len), indexed [b, h, s, t]
    # _weights after softmax is (batch, heads, seq_len, seq_len)
    # output is (batch, seq_len, heads, head_dim)

    # Step 1: Q * K^T to get attention logits
    # Memory: 2 bytes * (read Q, K + write W) = 2*(2*batch*seq*heads*head_dim) + 2*batch*heads*seq^2
    # Flops: 2*batch*heads*seq*seq*head_dim
    _weights_unnormalized = sm_scale * jnp.einsum("bshd,bthd->bhst", _Q, _K)
    # Normalize over the key axis t (last axis) so each query's weights sum to 1.
    _weights = jax.nn.softmax(_weights_unnormalized, axis=-1)

    # Step 2: weights * V to get final output
    # V must be indexed by the key position t, the axis being contracted.
    # Indexing it by the query position s ("bshd") would only sum each softmax
    # row (which is 1) and return V unchanged.
    # Memory: 2*batch*heads*seq^2 for loading weights (seq^2 because weights matrix is seq_len x seq_len)
    # Flops: 2*batch*heads*seq*seq*head_dim
    output = jnp.einsum("bhst,bthd->bshd", _weights, _V)

    print("Weights computed successfully")
    print("Output shape:", output.shape)
    print("\nDimension Analysis:")
    print("Input shapes (Q,K,V): batch x seq_len x heads x head_dim")
    print("Weight matrix shape: batch x heads x seq_len x seq_len")
    print("Output shape: batch x seq_len x heads x head_dim")

    # Calculate and print memory bandwidth, flops, and arithmetic intensity
    batch, seq_len, heads, head_dim = _Q.shape

    print("\nDetailed Analysis:")
    print("Step 1: Q * K^T (Attention Weights Calculation)")
    mem_bandwidth_step1 = 2 * (2 * batch * seq_len * heads * head_dim) + (2 * batch * heads * seq_len * seq_len)
    flops_step1 = 2 * batch * heads * seq_len * seq_len * head_dim
    ai_step1 = flops_step1 / mem_bandwidth_step1
    print(f"  Memory Bandwidth: {mem_bandwidth_step1} units")
    print(f"  FLOPS: {flops_step1} operations")
    print(f"  Arithmetic Intensity: {ai_step1:.2f}")

    print("\nStep 2: weights * V (Final Output Calculation)")
    mem_bandwidth_step2 = (2 * batch * heads * seq_len * seq_len) + (2 * batch * seq_len * heads * head_dim)
    flops_step2 = 2 * batch * heads * seq_len * seq_len * head_dim
    ai_step2 = flops_step2 / mem_bandwidth_step2
    print(f"  Memory Bandwidth: {mem_bandwidth_step2} units")
    print(f"  FLOPS: {flops_step2} operations")
    print(f"  Arithmetic Intensity: {ai_step2:.2f}")

    print(f"\nAssuming seq_len ({seq_len}) >> head_dim ({head_dim}):")
    print(f"  Approximate Arithmetic Intensity: ~{head_dim}")

    return output


In [13]:
print("Calling attention function")
# Run the (non-causal) attention on the demo tensors; prints its own analysis.
result = attention_ourselves(_Q=Q, _K=K, _V=V)


# Analysis of Coding

## Inputs are:
* 3x Batch, Sequence, HeadDim (Q,K,V)

## Outputs are:
* Batch, Sequence, HeadDim.

## Intermediate output is:
* W = softmax(einsum(Q, K)), with dimensions Batch, Sequence, Sequence
* Memory bandwidth = 2 * (2* Batch * Sequence * HeadDim) + (2* Batch * Sequence^2)
* Flops are 2* Batch * Sequence * Sequence * HeadDim
* Assuming Seq >> HeadDim, Arithmetic intensity is ~HeadDim.

## Then W*V:
* Output is Batch, Sequence, HeadDim = einsum(W,V)
* Flops are 2* Batch * Sequence * Sequence * HeadDim
* Memory bandwidth again dominated by (2* Batch * Sequence^2) (loading W)
* So Arithmetic Intensity Again ~HeadDim


## Overall bandwidth:
* Inputs are 3*Batch*Sequence*HeadDim
* Outputs are 1*Batch*Sequence*HeadDim
* That is 4 tensors × 2 bytes per element (bf16) = 8*Batch*Sequence*HeadDim bytes.

## Overall flops:
* 4*Batch*Sequence*Sequence*HeadDim

## Overall ratio flops/byte:
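* 4*Batch*Sequence^2*HeadDim flops / 8*Batch*Sequence*HeadDim bytes = Sequence / 2. This is far higher than the ~HeadDim of the unfused version. It is only achievable if the Batch*Sequence^2 intermediate W never round-trips through HBM, which requires a fused kernel.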





# Many Many Solutions Emerged!

* The key problem is that tensors of size Batch, Sequence, Sequence are too big to write to HBM efficiently.
* What we need is a fused kernel that allows us to not write back to HBM
* The simplest fused schedule is actually to notice that with Sequence=2048, we can handle each example (and head) independently. The W tensor for one example and head is then only 2048*2048*2 bytes ≈ 8.4 MB.
    * We have 160MB of SRAM - no reason we should be writing anything back to HBM.
    * This trick doesn't actually work that well - at Sequence =16384 we'd need 536 MB so we need to use HBM.
* The most famous and widely used fused schedule is FlashAttention (Tri Dao et al., 2022)
    * This depends on some clever observations about softmax and actually fully breaks the memory dependence on Batch*Sequence^2.
    * Nowadays there are many variants of FlashAttention that all exploit the same observation about softmax.
    * (We can cover FlashAttention in detail at some point if folks want. I'll hold a poll once I cover all the basic topics!)



# Review of Attention Perf

* The key problem is that tensors of size Batch, Sequence, Sequence are too big to write to HBM efficiently.
* What we need is a fused kernel that allows us to not write back to HBM
* The simplest fused schedule is actually to notice that with Sequence=2048, we can handle each example (and head) independently. The W tensor for one example and head is then only 2048*2048*2 bytes ≈ 8.4 MB.
    * We have 160MB of SRAM - no reason we should be writing anything back to HBM.
    * This trick doesn't actually work that well - at Sequence =16384 we'd need 536 MB so we need to use HBM.
* The most famous and widely used fused schedule is FlashAttention (Tri Dao et al., 2022)
    * This depends on some clever observations about softmax and actually fully breaks the memory dependence on Batch*Sequence^2.
    * Nowadays there are many variants of FlashAttention that all exploit the same observation about softmax.
    * (We can cover FlashAttention in detail at some point if folks want. I'll hold a poll once I cover all the basic topics!)


# Review of Attention Usefulness

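* This version of attention has no notion of order. Permuting the input sequence just permutes the outputs the same way (permutation equivariance), and each output is unaffected by the order of the keys/values.
    * To be useful we will need to add positional encodings, so that each token's vector records where in the sequence it comes from.
        * (The attention operation itself stays order-agnostic; the position information travels inside the encoded vectors.)
        * This is a bizarre but useful trait of attention!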
* This version of attention is not causal!
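    * Easy to add: before the softmax, set the unwanted (future) entries of W_unnormalized to -inf (or a very large negative number), so their weights become exactly zero. Setting them to 0 would not work, because exp(0) = 1.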
* This version of attention doesn't support "multiprompt packing" - training on multiple sequences in one example.
    * Also easy to add: mask the same way, setting W_unnormalized to -inf wherever the query and key belong to different packed sequences.
* Endless more tricks in Attention! But these (and Flash variants) are the top 3


In [32]:
def attention_ourselves_causal(_Q, _K, _V, sm_scale: float = 1.0):
    """Causal multi-head attention: query position s only attends to keys t <= s.

    Args:
        _Q: array of shape (batch, q_len, heads, head_dim).
        _K, _V: arrays of shape (batch, kv_len, heads, head_dim).
        sm_scale: factor applied to the Q.K^T logits before the softmax.
            Defaults to 1.0 (no scaling); pass ``1 / sqrt(head_dim)`` for
            standard scaled dot-product attention.

    Returns:
        Array of shape (batch, q_len, heads, head_dim).

    The mask is lower-triangular and aligned at the top-left corner
    (query i may see keys 0..i). It is built from the inputs' own lengths, so
    the function no longer depends on a global SEQ_LEN.
    """
    # Step 1: Q * K^T to get attention logits, (batch, heads, q_len, kv_len) indexed [b, h, s, t].
    _weights_unnormalized = sm_scale * jnp.einsum("bshd,bthd->bhst", _Q, _K)

    q_len, kv_len = _Q.shape[1], _K.shape[1]
    allowed = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))  # True where t <= s
    # Masked logits get the most negative finite value, so their softmax weight
    # is exactly 0. Zeroing a logit would still leave it weight exp(0) = 1.
    mask_value = jnp.finfo(_weights_unnormalized.dtype).min
    _weights = jax.nn.softmax(
        jnp.where(allowed, _weights_unnormalized, mask_value), axis=-1
    )

    # Step 2: weights * V, contracting over the key position t (V indexed by t, not s).
    output = jnp.einsum("bhst,bthd->bshd", _weights, _V)

    return output


In [36]:
# Optional cross-check of the causal implementation against Pallas' reference (not run here).
# NOTE(review): `pl.attention` does not appear to be an attribute of `jax.experimental.pallas`;
# in recent JAX versions mha_reference is believed to live in
# `jax.experimental.pallas.ops.gpu.attention` -- confirm against the installed JAX version.
# Do not assign the result to `attention_ourselves_causal`: that would shadow the function.
# from jax.experimental.pallas.ops.gpu import attention as pallas_attention
# causal_result = attention_ourselves_causal(Q, K, V)
# attn_value = pallas_attention.mha_reference(Q, K, V, segment_ids=None, causal=True)