## Misc from the Essential Interview

- [ ] Flash Attn https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad
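    - Worth knowing the key takeaways:
        - Memory efficient: extra memory is O(N) in the sequence length N, i.e. linear space complexity (vanilla attention materializes the full N×N score matrix, which is O(N²))
        - Vanilla attention is memory-bound: its runtime is dominated by reads/writes to HBM rather than by FLOPs
        - There's a memory hierarchy on GPUs (smallest/fastest first):
            - SRAM (on-chip, on the GPU)
            - HBM (off-chip GPU memory, larger but slower)
            - DRAM (CPU memory, largest and slowest)
        - Via kernel fusion, flash attn combines the attention ops (QKᵀ, softmax, multiply by V) into a single kernel, so the N×N intermediate scores never round-trip through HBM; only Q, K, V, O and the O(N) softmax statistics are read from / written to HBM
        - Flash attn leverages both:
            - Tiling (splitting Q, K, V into blocks that fit in SRAM and computing the softmax incrementally, rescaling by a running row max and row sum)
            - Recomputation in the backward pass (attention scores are recomputed from Q, K and the saved softmax statistics instead of storing the N×N matrix)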
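- [ ] MinHash: estimates the Jaccard similarity J(A, B) = |A ∩ B| / |A ∪ B| of two sets (under a random hash/permutation, the probability that A and B have the same minimum hash equals J(A, B)); Jaccard distance is 1 - J(A, B)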
    - https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.MinHashLSH.html
- [ ] Local vs. global attn
    - How does this work?
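
The MinHash item above can be sketched from scratch in plain Python. Everything here is an illustrative assumption, not the `pyspark.ml.feature.MinHashLSH` API: the helper names (`make_hash_funcs`, `minhash_signature`, `estimate_jaccard`, `jaccard`) are hypothetical, random linear hashes modulo the Mersenne prime 2^61 - 1 stand in for random permutations, and `zlib.crc32` maps each string to an integer. For each hash function the signature keeps the minimum hash over the set; the fraction of matching signature slots estimates the Jaccard similarity.

```python
import random
import zlib

MERSENNE_PRIME = (1 << 61) - 1  # prime modulus for the hash family h(x) = (a*x + b) mod p


def make_hash_funcs(num_perm, seed=0):
    """Sample num_perm (a, b) pairs; each defines one random hash h(x) = (a*x + b) mod p."""
    rng = random.Random(seed)
    return [(rng.randrange(1, MERSENNE_PRIME), rng.randrange(0, MERSENNE_PRIME))
            for _ in range(num_perm)]


def minhash_signature(items, hash_funcs):
    """One slot per hash function: the minimum hash value over all items in the set."""
    # crc32 gives a stable integer per string (Python's built-in hash() of str is salted per process)
    ids = [zlib.crc32(item.encode("utf-8")) for item in items]
    return [min((a * x + b) % MERSENNE_PRIME for x in ids) for a, b in hash_funcs]


def estimate_jaccard(sig_a, sig_b):
    """Fraction of slots where the two minimums agree; this estimates J(A, B)."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)


def jaccard(a, b):
    """Exact Jaccard similarity |A ∩ B| / |A ∪ B|, for comparison."""
    return len(a & b) / len(a | b)


A = set("the quick brown fox jumps over the lazy dog".split())
B = set("the quick brown fox sleeps under the lazy cat".split())
funcs = make_hash_funcs(num_perm=256, seed=42)
sig_a, sig_b = minhash_signature(A, funcs), minhash_signature(B, funcs)

print(jaccard(A, B))                   # exact: 5 shared words / 11 distinct words
print(estimate_jaccard(sig_a, sig_b))  # MinHash estimate from 256 hash functions
```

More hash functions shrink the estimate's spread (its standard error is roughly sqrt(J(1 - J) / num_perm)). The LSH step built on top of these signatures (bucketing them so that similar sets tend to collide) is not shown.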

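One common reading of the local vs. global question, assumed here: global (full) attention lets every query score every key, so there are T² scores, while local (sliding-window) attention restricts query i to keys j with |i - j| ≤ w, so only about T·(2w + 1) scores matter. Hybrids such as Longformer add a few global tokens that attend to, and are attended by, every position. The sketch below uses hypothetical helper names (`local_attention_mask`, `attention`) and builds the mask densely for clarity, so it does not actually save compute; efficient implementations only compute the band.

```python
import torch
import torch.nn.functional as F


def local_attention_mask(T, window, global_tokens=()):
    """Boolean (T, T) mask, True where query i may attend to key j.

    Sliding window: |i - j| <= window. Each index in global_tokens
    attends to every position and is attended to by every position.
    """
    idx = torch.arange(T)
    mask = (idx[:, None] - idx[None, :]).abs() <= window
    for g in global_tokens:
        mask[g, :] = True  # the global token's query sees all keys
        mask[:, g] = True  # every query sees the global token's key
    return mask


def attention(q, k, v, mask=None):
    """Softmax attention over (T, D) inputs; disallowed scores become -inf before the softmax."""
    scores = q @ k.T / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
T, D = 8, 4
q, k, v = torch.rand(T, D), torch.rand(T, D), torch.rand(T, D)

global_out = attention(q, k, v)                    # full attention: all T x T scores
band = local_attention_mask(T, window=1)           # each query sees itself and one neighbour per side
local_out = attention(q, k, v, band)
hybrid = local_attention_mask(T, window=1, global_tokens=(0,))  # window plus token 0 as a global token
hybrid_out = attention(q, k, v, hybrid)

# cross-check against PyTorch's fused kernel (boolean attn_mask: True = may attend)
ref = F.scaled_dot_product_attention(q[None], k[None], v[None], attn_mask=band)[0]
print(band.int())
print(torch.allclose(local_out, ref, atol=1e-6))
```

Every row of the band keeps its own diagonal entry, so no query ends up with all scores at -inf (which would make the softmax return NaN).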
```python
import math

import torch
import torch.nn.functional as F
```

```python
def flash_attn(Q, K, V, M):
    """Tiled attention forward pass (FlashAttention-1 style), simulated on one device.

    Q, K, V: (T, D) tensors. M: SRAM budget, in number of elements.
    """
    T, D = Q.shape
    # 4 b/c SRAM has to hold blocks of q, k, v & o
    # D b/c each of the vectors is D-dim
    # block size for K & V
    b_col = math.ceil(M / (4 * D))
    # block size for Q & O
    b_row = min(b_col, D)

    # num blocks for Q, O, l & m
    T_row = math.ceil(T / b_row)
    # num blocks for K & V
    T_col = math.ceil(T / b_col)

    # init output on HBM
    O = torch.zeros_like(Q)
    # running softmax denominator per query row
    l = torch.zeros(T, dtype=Q.dtype)
    # running row-wise max of the scores
    m = torch.full((T,), float("-inf"), dtype=Q.dtype)

    z = 1 / math.sqrt(D)

    # outer loop over key / value blocks
    for j in range(T_col):
        # load K_j, V_j from HBM to SRAM
        start_col, end_col = j * b_col, (j + 1) * b_col
        K_j = K[start_col:end_col, :]
        V_j = V[start_col:end_col, :]

        # inner loop over query blocks
        for i in range(T_row):
            start_row, end_row = i * b_row, (i + 1) * b_row
            # load Q_i, O_i, l_i, m_i from HBM to SRAM
            Q_i = Q[start_row:end_row, :]
            O_i = O[start_row:end_row, :]
            l_i = l[start_row:end_row]
            m_i = m[start_row:end_row]

            scores_ij = (Q_i @ K_j.T) * z
            # block-local row max (dim=1 b/c we want the max across the key columns)
            m_ij = scores_ij.max(dim=1).values
            # subtract the max before exp for numerical stability
            P_ij = torch.exp(scores_ij - m_ij[:, None])
            # block-local denominator of the softmax
            l_ij = P_ij.sum(dim=1)

            # new running max over all key blocks seen so far
            m_i_new = torch.maximum(m_i, m_ij)
            # rescale both denominators to the new max before adding them
            alpha = torch.exp(m_i - m_i_new)  # correction for the old statistics
            beta = torch.exp(m_ij - m_i_new)  # correction for this block
            l_i_new = alpha * l_i + beta * l_ij

            # alpha / beta are per-row factors, so broadcast them over rows with [:, None]
            # (a bare (b_row,) vector would broadcast over columns); then renormalise by l_i_new
            O_i = (alpha[:, None] * l_i[:, None] * O_i + beta[:, None] * (P_ij @ V_j)) / l_i_new[:, None]

            # write O_i, l_i & m_i back to HBM
            O[start_row:end_row, :] = O_i
            l[start_row:end_row] = l_i_new
            m[start_row:end_row] = m_i_new
    return O
```

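For reference, the per-block update inside `flash_attn` written as equations, following the FlashAttention forward pass. Here $S_{ij} = Q_i K_j^\top / \sqrt{d}$ is the block of scaled scores, tildes mark block-local quantities (the code's `m_ij`, `l_ij`, and the exponentiated block scores), and $m_i, \ell_i$ are the running statistics. Every exponential factor is a per-row scalar, so each one scales whole rows of the term it multiplies.

$$
\begin{aligned}
\tilde{m}_{ij} &= \operatorname{rowmax}(S_{ij}), \qquad \tilde{P}_{ij} = \exp\!\left(S_{ij} - \tilde{m}_{ij}\right), \qquad \tilde{\ell}_{ij} = \operatorname{rowsum}(\tilde{P}_{ij}) \\
m_i^{\text{new}} &= \max\!\left(m_i, \tilde{m}_{ij}\right) \\
\ell_i^{\text{new}} &= e^{m_i - m_i^{\text{new}}}\,\ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}}\,\tilde{\ell}_{ij} \\
O_i &\leftarrow \operatorname{diag}\!\left(\ell_i^{\text{new}}\right)^{-1}\left(\operatorname{diag}(\ell_i)\,e^{m_i - m_i^{\text{new}}}\,O_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}}\,\tilde{P}_{ij} V_j\right)
\end{aligned}
$$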
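```python
seq_len = 4
d_model = 10
# SRAM budget M in elements: b_col = ceil(400 / (4 * 10)) = 10 and b_row = min(10, 10) = 10,
# so with seq_len = 4 both loops in flash_attn run exactly once
sram_size = 4 * 10 * 10

# these are on HBM when passed in
q = torch.rand((seq_len, d_model))
k = torch.rand((seq_len, d_model))
v = torch.rand((seq_len, d_model))
```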

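```python
# reference result; SDPA takes (batch, heads, seq_len, dim), hence the two leading singleton dims
gold_out = F.scaled_dot_product_attention(q[None, None, ...], k[None, None, ...], v[None, None, ...], is_causal=False)
```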

```python
flash_out = flash_attn(q, k, v, sram_size)
```

```python
torch.allclose(gold_out, flash_out)
```

```
True
```
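The check above runs with seq_len = 4 and b_col = b_row = 10, so each loop executes once and the cross-block rescaling (the part that makes tiling work) is never tested. Below is a stripped-down, self-contained sketch of the same online-softmax recurrence, exercised with several blocks. Simplifying assumptions: `tiled_attention` is a hypothetical name, only K/V are tiled (all query rows are processed together), and the accumulator is left unnormalised and divided once at the end (the ordering FlashAttention-2 uses, rather than renormalising every step). It also returns the per-row logsumexp L = m + log l, the O(N) statistic that makes the backward-pass recomputation possible.

```python
import math

import torch
import torch.nn.functional as F


def tiled_attention(Q, K, V, block):
    """Attention computed one K/V block at a time with an online softmax.

    Returns the (T, D) output and the per-row logsumexp of the scaled scores.
    """
    T, D = Q.shape
    scale = 1 / math.sqrt(D)
    O = torch.zeros_like(Q)                             # unnormalised output accumulator
    l = torch.zeros(T, dtype=Q.dtype)                   # running softmax denominator
    m = torch.full((T,), float("-inf"), dtype=Q.dtype)  # running row max
    for start in range(0, K.shape[0], block):
        K_j, V_j = K[start:start + block], V[start:start + block]
        S = (Q @ K_j.T) * scale
        m_new = torch.maximum(m, S.max(dim=1).values)
        P = torch.exp(S - m_new[:, None])               # block scores relative to the new max
        alpha = torch.exp(m - m_new)                    # rescales everything accumulated so far
        l = alpha * l + P.sum(dim=1)
        O = alpha[:, None] * O + P @ V_j
        m = m_new
    return O / l[:, None], m + torch.log(l)


torch.manual_seed(0)
T, D = 37, 16  # T deliberately not a multiple of the block size
Q, K, V = torch.rand(T, D), torch.rand(T, D), torch.rand(T, D)

out, lse = tiled_attention(Q, K, V, block=8)
ref = F.scaled_dot_product_attention(Q[None], K[None], V[None])[0]
print(torch.allclose(out, ref, atol=1e-6))

# Backward-pass recomputation: keeping only lse (O(T) memory), the softmax matrix
# can be rebuilt as exp(S - lse) instead of storing all T x T probabilities.
S_full = (Q @ K.T) / math.sqrt(D)
P_recomputed = torch.exp(S_full - lse[:, None])
print(torch.allclose(P_recomputed, torch.softmax(S_full, dim=-1), atol=1e-6))
```

In a real kernel the rebuild happens block by block inside the backward pass, so the full T x T matrix never exists in HBM; building `S_full` here is only for the comparison.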