The goal is to apply the minimal FlashAttention implementation to LLMs in PyTorch/Hugging Face and benchmark the resulting performance.

We start from `demo-flash-attention-minimal.ipynb`, but note that it does not use an LLM at all; it benchmarks only the standalone kernel.

We want to benchmark the kernel **within** LLMs and compare it with the default attention implementation. Let's start with some setup:

### Setup
Note that we need Python 3.10 so that pybind works; the current environment has Python 3.12. So, in a terminal:
```bash
source /project/engineering/anaconda3/etc/profile.d/conda.sh
conda create -n py310 python=3.10
conda activate py310
conda install jupyter ipykernel
pip3 install torch torchvision torchaudio
pip install transformers
pip install ninja   # required by torch.utils.cpp_extension to JIT-build the CUDA code
python -m ipykernel install --user --name py310 --display-name "Python 3.10 (py310)"
```

In [3]:
!jupyter kernelspec list

Available kernels:
  myenv      /home/warehouse/cnicholas/.local/share/jupyter/kernels/myenv
  py310      /home/warehouse/cnicholas/.local/share/jupyter/kernels/py310
  python3    /opt/conda/share/jupyter/kernels/python3


Now switch the notebook kernel to `Python 3.10 (py310)` so that we work in the fresh environment. You may need to close and reopen the notebook.

In [1]:
import torch
assert torch.cuda.is_available(), "You must have a GPU to run this notebook."
print("GPU available.")

GPU available.


### Reused Code
The code below is copied from `demo-flash-attention-minimal.ipynb`; we run it once here to make sure it works. There is no need to rerun it.

In [3]:
### Default Minimal Kernel Implementation copied from the notebook:
import os
import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load_inline

cuda_src = '''
__global__
void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
                    const int Tc, const int Tr, const int Bc, const int Br, const float softmax_scale,
                    float* l, float *m, float* O) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {

        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++)  {

            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // P = exp(S - row_m), row_l = rowsum(P)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop
    }
}

torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
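    // TODO: determine Bc, Br dynamically
    // Note: the kernel assumes Bc == Br and that N is a multiple of Bc; tile loads are not bounds-checked.
    const int Bc = 32; const int Br = 32;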

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    // Initialize O, l, m to HBM
    auto O = torch::zeros_like(Q);
    auto l = torch::zeros({B, nh, N});
    auto m = torch::full({B, nh, N}, -INFINITY);
    torch::Device device(torch::kCUDA);
    l = l.to(device); m = m.to(device);

    // Calculate SRAM size needed per block
    const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);

    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Bc);  // Bc threads per block

    forward_kernel<<<grid_dim, block_dim, sram_size>>>(
        Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
    );
    return O;
}
'''
cpp_src = 'torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V);'

build_dir = 'cuda'
if not os.path.exists(build_dir):
    os.mkdir(build_dir)

minimal_attn = load_inline(
    name='minimal_attn',
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['forward'],
    with_cuda=True,
    extra_cuda_cflags=['-O2'],
    build_directory=f'./{build_dir}'
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
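The subtle part of `forward_kernel` is the running-softmax update (`row_m_new`, `row_l_new`, and the rescaling of `O`). The NumPy sketch below mirrors that recurrence for a single (batch, head) slice so it can be checked against ordinary softmax attention on the CPU. Two deliberate differences from the kernel: it processes all query rows at once instead of in `Br`-row tiles (rows are independent, so the result is the same), and it tolerates a ragged last K/V tile. `tiled_attention_reference` and `naive_attention` are names made up for this sketch.

```python
import numpy as np

def tiled_attention_reference(Q, K, V, Bc=32):
    """NumPy mirror of forward_kernel for one (batch, head) slice; Q, K, V are (N, d)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    m = np.full(N, -np.inf)  # running row max (the kernel's m)
    l = np.zeros(N)          # running row sum of exp(S - m) (the kernel's l)
    for j in range(0, N, Bc):                # outer loop over K/V tiles
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        S = (Q @ Kj.T) * scale               # scores against this tile
        row_m = S.max(axis=1)                # tile-local row max
        P = np.exp(S - row_m[:, None])       # tile-local unnormalized probabilities
        row_l = P.sum(axis=1)
        m_new = np.maximum(m, row_m)
        alpha = np.exp(m - m_new)            # rescales the previous accumulators
        beta = np.exp(row_m - m_new)         # rescales this tile's contribution
        l_new = alpha * l + beta * row_l
        O = ((l * alpha)[:, None] * O + beta[:, None] * (P @ Vj)) / l_new[:, None]
        m, l = m_new, l_new
    return O

def naive_attention(Q, K, V):
    """Ordinary softmax(Q K^T / sqrt(d)) V for comparison."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V
```

On the first tile `m` is `-inf`, so `alpha` is 0 and the zero-initialized `O` and `l` drop out, which is why the kernel can initialize `m` to `-INFINITY` and `l`, `O` to zero.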


In [5]:
# Notebook code test. Note this does not use an LLM.
batch_size = 32
n_head = 12
seq_len = 64
head_embd = 32

q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()

print('=== profiling manual attention ===')

# Our minimal flash attention needs to be faster than this.
def manual_attn(q, k, v):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling minimal flash attention === ')

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))

=== profiling manual attention ===




-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                     aten::matmul         2.92%       6.357ms        68.67%     149.606ms      74.803ms       6.312ms         2.90%     149.696ms      74.848ms             2  
                                        aten::bmm        26.62%      57.984ms        62.81%     136.839ms      68.420ms     137.021ms        62.87%     137.021ms      68.510ms             2  
...



attn values sanity check: True
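Moving from the notebook's toy shapes to GPT-2 changes the head dimension, and `forward()` sizes its shared-memory request from it: `3 * Bc * d` floats for the `Qi`, `Kj`, `Vj` tiles plus `Bc * Br` floats for `S`. The plain-Python check below recomputes that request. The 48 KiB limit is an assumption (the typical per-block default without an explicit opt-in); the `printf` in `forward()` prints the actual limit for your GPU. With Bc = Br = 32 the request is 16,384 bytes for the notebook's d = 32, 28,672 bytes for GPT-2's d = 64, and 53,248 bytes for d = 128, so 128-dimensional heads would not fit without smaller tiles or an opt-in. Since `forward()` does not check for launch errors, this is worth confirming before benchmarking.

```python
FLOAT_BYTES = 4  # the kernel works on float32 only

def sram_bytes(d: int, Bc: int = 32, Br: int = 32) -> int:
    """Dynamic shared memory requested per block by forward():
    Qi, Kj, Vj tiles (Bc x d floats each) plus the score tile S (Bc x Br floats)."""
    return (3 * Bc * d + Bc * Br) * FLOAT_BYTES

# Assumed per-block limit without opting in via cudaFuncSetAttribute;
# the printf in forward() reports the real value.
DEFAULT_LIMIT = 48 * 1024

gpt2_head_dim = 768 // 12  # n_embd / n_head for the base GPT-2 model

for d in (32, gpt2_head_dim, 128):
    need = sram_bytes(d)
    print(f"d={d}: {need} bytes, fits in 48 KiB: {need <= DEFAULT_LIMIT}")
```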


### Our Code
This is our new code for benchmarking the kernel inside LLMs. We start with GPT-2. Newer models often modify the attention computation (e.g., models that use rotary position embeddings, RoPE), so we avoid them for now.
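One more constraint matters even for GPT-2: its self-attention is causal (each token attends only to itself and earlier tokens), whereas `forward_kernel` above applies no mask at all, so dropping the kernel into GPT-2 unchanged would change the model's outputs. The sketch below makes the kernel's constraints explicit with a hypothetical dispatcher, `attention`, that calls the inline-loaded kernel only when its assumptions hold (read off the CUDA code, not documented anywhere) and otherwise falls back to PyTorch's `F.scaled_dot_product_attention`. The helper names and the `TILE` constant are ours.

```python
import torch
import torch.nn.functional as F

TILE = 32  # Bc == Br, hard-coded in forward()

def kernel_supports(q, k, v, causal):
    """Whether the minimal kernel's output can be trusted for these inputs.

    Conditions inferred from forward_kernel (assumptions, not an official spec):
    contiguous float32 CUDA tensors in (B, nh, N, d) layout, no attention mask,
    and N a multiple of the tile size because tile loads are not bounds-checked.
    """
    return (
        not causal
        and all(t.is_cuda and t.dtype == torch.float32 and t.is_contiguous() for t in (q, k, v))
        and q.size(2) % TILE == 0
    )

def attention(q, k, v, causal, kernel=None):
    """Use the minimal kernel when it applies, otherwise fall back to PyTorch SDPA.

    `kernel` is the load_inline module (minimal_attn); None forces the fallback.
    """
    if kernel is not None and kernel_supports(q, k, v, causal):
        return kernel.forward(q, k, v)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```

With `causal=True`, as GPT-2 needs, this always takes the fallback, so a meaningful GPT-2 benchmark would first require adding a causal mask to `forward_kernel` (for example, treating scores where the key index exceeds the query index as `-INFINITY`).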

We will start by defining a custom operator in PyTorch, [as explained in the docs](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial) and [examples](https://github.com/pytorch/extension-cpp).

First, we define a `setup.py` based on [the PyTorch `extension-cpp` examples](https://github.com/pytorch/extension-cpp/blob/master/setup.py). Then we register an operator in our `flash.cu` code: we first define a namespace, `minimal_attn`, and then implement `mha_forward` as an operator in it (this is the part written in CUDA and C++). We test the operator below. Note that the directory structure must be as follows:

```
.
├── minimal_attn/
│   └── csrc/
│       ├── your_cpp_code.cpp
│       └── cuda/
│           └── flash.cu
├── setup.py
```

With this, we can `pip install --no-build-isolation -e .` to get our `minimal_attn` extension.

In [2]:
# %env TORCH_CUDA_ARCH_LIST=7.5;8.0;8.6   # use %env: "!export" only affects a throwaway subshell
!pip install --no-build-isolation -e .

Defaulting to user installation because normal site-packages is not writeable
Obtaining file:///home/warehouse/cnicholas/cse4059/final_project/flash-attention-minimal/operator
  Preparing metadata (setup.py) ... done
Installing collected packages: minimal_attn
  Attempting uninstall: minimal_attn
    Found existing installation: minimal_attn 0.0.1
    Uninstalling minimal_attn-0.0.1:
      Successfully uninstalled minimal_attn-0.0.1
  Running setup.py develop for minimal_attn
    error: subprocess-exited-with-error
    
    × python setup.py develop did not run successfully.
    │ exit code: 1
    ╰─> [1311 lines of output]
           running develop
           !!
           
                   ********************************************************************************
                   Please avoid running ``setup.py`` and ``easy_install``.
                   

In [10]:
import math

import torch
from torch.nn import functional as F

import minimal_attn  # the pip-installed package; importing it should register torch.ops.minimal_attn.*

batch_size = 32
n_head = 12
seq_len = 64
head_embd = 32

def sample_inputs(device, *, requires_grad=False):
    def make_qkv(batch_size, n_head, seq_len, head_embd):
        # One (q, k, v) triple of shape (batch, heads, seq_len, head_dim)
        return tuple(
            torch.randn(batch_size, n_head, seq_len, head_embd,
                        device=device, requires_grad=requires_grad)
            for _ in range(3)
        )

    return [make_qkv(batch_size, n_head, seq_len, head_embd)]

# Our minimal flash attention needs to be faster than this.
def manual_attn(q, k, v):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y

device = torch.device('cuda')
samples = sample_inputs(device, requires_grad=True)
samples.extend(sample_inputs(device, requires_grad=False))
for args in samples:
    # Correctness test. The kernel uses __expf and a different summation order,
    # so use the same loose tolerance as the earlier sanity check.
    result = torch.ops.minimal_attn.mha_forward(*args)
    expected = manual_attn(*args)
    torch.testing.assert_close(result, expected, rtol=0, atol=1e-2)

    # Use opcheck to check for incorrect usage of operator registration APIs
    torch.library.opcheck(torch.ops.minimal_attn.mha_forward.default, args)

AttributeError: '_OpNamespace' 'minimal_attn' object has no attribute 'mha_forward'

### HF

In [6]:
# Load GPT-2 on the GPU. For the HF experiments we use the inline-loaded kernel (minimal_attn via load_inline above).
from transformers import GPT2Tokenizer, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")
model.cuda()


GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
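The printed module tree gives the shapes the kernel will see: `c_attn` is `Conv1D(nf=2304, nx=768)`, one fused projection producing Q, K and V (3 * 768 = 2304), and there are 12 heads of 768 / 12 = 64 dimensions each. The kernel indexes raw pointers and assumes a contiguous `(B, nh, N, d)` float32 layout, so the fused output has to be split and reshaped first. The sketch below does this with a hypothetical helper, `split_qkv`; as far as we know it mirrors the split GPT-2's attention module performs internally, but it is not taken from the `transformers` source.

```python
import torch

n_embd, n_head = 768, 12     # from the printed GPT2Model above
head_dim = n_embd // n_head  # 64

def split_qkv(c_attn_out: torch.Tensor, n_head: int = n_head):
    """Split a fused (B, N, 3 * n_embd) c_attn output into q, k, v of shape (B, nh, N, d).

    Each tensor is made contiguous after the head transpose because the
    kernel assumes a dense (B, nh, N, d) memory layout.
    """
    B, N, three_embd = c_attn_out.shape
    embd = three_embd // 3
    d = embd // n_head
    q, k, v = c_attn_out.split(embd, dim=2)  # each (B, N, n_embd)
    return tuple(t.view(B, N, n_head, d).transpose(1, 2).contiguous() for t in (q, k, v))
```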

### Other Ideas
We could train a regression model or small neural network to predict the optimal block size (Bc, Br) for given conditions (e.g., sequence length and head dimension). This would give us a concrete result to show: run GPT-2 many times across configurations and record which block size is fastest.
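One way to collect data for such a model is a timing sweep over candidate block sizes. The sketch below assumes a hypothetical `run(config)` callable that launches a variant of the kernel taking `(Bc, Br)` as arguments (the current `forward()` hard-codes both to 32, per its TODO). `time_config` and `sweep_block_sizes` are illustrative names; pass `sync=torch.cuda.synchronize` when timing GPU work, since kernel launches return before the kernel finishes.

```python
import time
from typing import Callable, Dict, Iterable, Tuple

BlockSize = Tuple[int, int]  # (Bc, Br)

def time_config(run: Callable[[BlockSize], object], config: BlockSize,
                repeats: int = 5, sync: Callable[[], None] = lambda: None) -> float:
    """Best-of-`repeats` wall-clock seconds for run(config), after one warm-up call."""
    run(config)  # warm-up: first calls can include compilation or caching costs
    sync()
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        run(config)
        sync()  # wait for asynchronous work before stopping the clock
        best = min(best, time.perf_counter() - start)
    return best

def sweep_block_sizes(run: Callable[[BlockSize], object], configs: Iterable[BlockSize],
                      repeats: int = 5, sync: Callable[[], None] = lambda: None):
    """Time every candidate block size; return (fastest config, {config: seconds})."""
    timings: Dict[BlockSize, float] = {c: time_config(run, c, repeats, sync) for c in configs}
    return min(timings, key=timings.get), timings
```

Recording the fastest configuration together with the input shape (batch size, heads, sequence length, head dimension) across many GPT-2 runs would give the (features, label) pairs a regression model could be trained on.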

*In the future, we can add `torch.compile` support for the operator.*