### Setup
First, load the model.

In [1]:
from modeling_gpt2 import GPT2Model, GPT2LMHeadModel
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="hf_home/")
# model = GPT2Model.from_pretrained("gpt2", cache_dir="hf_home/")  # for forward pass
model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir="hf_home/")  # for generation
model.cuda()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Initial Testing for Correctness
Now, let's run a forward pass on a sample input, starting with the baseline.

In [2]:
inputs = tokenizer(["Today is"], return_tensors="pt").to('cuda')
model.transformer.config._attn_implementation = "sdpa"  # NOTE: This is default, but we set manually here for emphasis.
out = model.forward(inputs['input_ids'])
out.logits

tensor([[[ -36.3292,  -36.3402,  -40.4228,  ...,  -46.0234,  -44.5284,
           -37.1276],
         [-122.8355, -122.5403, -127.6362,  ..., -133.4906, -131.9769,
          -125.4615]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>)

Now, using the minimal-flash-attn attention implementation (`mha_forward`) in place of the default, `sdpa`.

In [3]:
inputs = tokenizer(["Today is"], return_tensors="pt").to('cuda')
model.transformer.config._attn_implementation = "mha_forward"
out = model.forward(inputs['input_ids'])
out.logits

Before:
523.113984 MB allocated
572.522496 MB reserved
After allocating:
523.115008 MB allocated
572.522496 MB reserved
After freeing:
523.120128 MB allocated
570.425344 MB reserved
Before:
523.288064 MB allocated
572.522496 MB reserved
After allocating:
523.289088 MB allocated
572.522496 MB reserved
After freeing:
523.294208 MB allocated
570.425344 MB reserved
Before:
523.462144 MB allocated
572.522496 MB reserved
After allocating:
523.463168 MB allocated
572.522496 MB reserved
After freeing:
523.468288 MB allocated
570.425344 MB reserved
Before:
523.636224 MB allocated
572.522496 MB reserved
After allocating:
523.637248 MB allocated
572.522496 MB reserved
After freeing:
523.642368 MB allocated
570.425344 MB reserved
Before:
523.810304 MB allocated
572.522496 MB reserved
After allocating:
523.811328 MB allocated
572.522496 MB reserved
After freeing:
523.816448 MB allocated
570.425344 MB reserved
Before:
523.984384 MB allocated
572.522496 MB reserved
After allocating:
523.985408 MB all

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<UnsafeViewBackward0>)

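*Note*: These logits are all `nan`, unlike the baseline, so something is wrong, most likely with GPU memory handling. The log above shows allocated memory creeping up by roughly 0.17 MB per call even after freeing, and re-running the cell keeps increasing GPU memory without it being cleaned up.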
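A quick way to catch this kind of failure early is to summarize the logits and the CUDA memory counters after each forward pass. The sketch below is illustrative only: `summarize_logits` and `cuda_memory_mb` are hypothetical helper names, and the toy tensors stand in for `out.logits`.

```python
import torch


def summarize_logits(logits: torch.Tensor) -> dict:
    """Return basic health statistics for a logits tensor."""
    return {
        "shape": tuple(logits.shape),
        "num_nan": int(torch.isnan(logits).sum()),
        "num_inf": int(torch.isinf(logits).sum()),
        "all_finite": bool(torch.isfinite(logits).all()),
    }


def cuda_memory_mb() -> dict:
    """Report allocated/reserved CUDA memory in MB (zeros when CUDA is unavailable)."""
    if not torch.cuda.is_available():
        return {"allocated_mb": 0.0, "reserved_mb": 0.0}
    return {
        "allocated_mb": torch.cuda.memory_allocated() / 1e6,
        "reserved_mb": torch.cuda.memory_reserved() / 1e6,
    }


# Toy stand-ins for `out.logits` (assumption: real logits would be passed instead).
healthy = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
broken = torch.tensor([[float("nan"), 2.0], [3.0, float("inf")]])

print(summarize_logits(healthy))
print(summarize_logits(broken))
print(cuda_memory_mb())
```

Calling `cuda_memory_mb()` before and after each forward pass makes a steady increase in allocated memory easy to spot.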

Now using our improved version.

In [19]:
inputs = tokenizer(["Today is"], return_tensors="pt").to('cuda')
model.transformer.config._attn_implementation = "improved_mha_forward"
out = model.forward(inputs['input_ids'])
out.logits

tensor([[[ -81.5403,  -78.8223,  -84.0822,  ...,  -91.6164,  -90.7842,
           -80.2482],
         [ -96.5142,  -93.4523,  -99.9717,  ..., -116.0871, -115.3968,
           -99.3011]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>)

### Initial Testing for Timing
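The shapes match the baseline and the `nan`s are gone. Note, however, that the logits printed above differ from the `sdpa` baseline (for example, -81.5403 vs. -36.3292 for the first entry), so correctness still needs to be confirmed with a numerical comparison. With that caveat, let's see if it is faster.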
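Eyeballing printed tensors is error-prone, so a numerical comparison against the baseline is a safer correctness check. The following is a minimal sketch, assuming both logits tensors are available; `compare_logits` is a hypothetical helper and the tolerances are arbitrary choices. The toy values are the first three logits from the printed `sdpa` and `improved_mha_forward` outputs above.

```python
import torch


def compare_logits(reference: torch.Tensor, candidate: torch.Tensor,
                   rtol: float = 1e-3, atol: float = 1e-3) -> dict:
    """Compare two logits tensors by shape, max absolute difference, and allclose."""
    if reference.shape != candidate.shape:
        return {"same_shape": False, "max_abs_diff": None, "allclose": False}
    return {
        "same_shape": True,
        "max_abs_diff": (reference - candidate).abs().max().item(),
        "allclose": torch.allclose(reference, candidate, rtol=rtol, atol=atol),
    }


# First three baseline logits from the `sdpa` cell.
ref = torch.tensor([[-36.3292, -36.3402, -40.4228]])
# A candidate within tolerance of the baseline.
close = ref + 1e-4
# First three logits from the `improved_mha_forward` cell.
far = torch.tensor([[-81.5403, -78.8223, -84.0822]])

print(compare_logits(ref, close))
print(compare_logits(ref, far))
```

In the notebook, `ref` and the candidate would be the `out.logits` tensors from two runs on the same input, detached from the graph before comparison.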

In [20]:
import os
os.makedirs("traces", exist_ok=True)

In [23]:
import torch
import time
attn_implementations = ["sdpa", "mha_forward", "improved_mha_forward"]
for attn_implementation in attn_implementations:
    print(f'=== profiling `{attn_implementation}` attention === ')
    model.transformer.config._attn_implementation = attn_implementation
    with torch.autograd.profiler.profile(use_device='cuda') as prof:
        torch.cuda.synchronize()  # drain pending GPU work before starting the clock
        start_time = time.perf_counter()
        out = model.forward(inputs['input_ids'])
        torch.cuda.synchronize()  # CUDA kernels launch asynchronously; wait for them to finish
        end_time = time.perf_counter()
    prof.export_chrome_trace(f"traces/{attn_implementation}_trace.json")  # note we can inspect these with chrome://tracing/
    print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
    print(f"Total time taken: {end_time - start_time}\n\n")

=== profiling `sdpa` attention === 
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::addmm        10.25%       2.237ms        14.33%       3.129ms      65.189us       3.527ms        17.82%       4.001ms      83.354us            48  
                                 aten::layer_norm         0.73%     160.395us        11.93%       2.605ms     104.182us     258.000us         1.30%       2.691ms     107.640us            25  
   

Note that the default minimal-flash-attn implementation is faster by ~7 ms! The profile also shows that `aten::scaled_dot_product_attention` runs for the baseline and our `minimal_attn::mha_forward` kernel replaces it, as desired. For now, the custom kernel takes slightly more CUDA time (3.323 ms vs. 2.978 ms CUDA total), which makes sense as we have to use more memory, while the original version probably has more to do on the CPU (2.931 ms vs. 1.285 ms CPU total). In the future we will test how this varies with input size. The two relevant rows are:

| Name                                     | Self CPU %   | Self CPU    | CPU total % | CPU total   | CPU time avg | Self CUDA   | Self CUDA % | CUDA total  | CUDA time avg | # of Calls |
| :--------------------------------------- | :----------- | :---------- | :---------- | :---------- | :----------- | :---------- | :---------- | :---------- | :------------ | :--------- |
| `aten::scaled_dot_product_attention`     | 0.67%        | 213.412us   | 9.23%       | 2.931ms     | 244.211us    | 302.000us   | 1.27%       | 2.978ms     | 248.167us     | 12         |
| `minimal_attn::mha_forward`              | 1.01%        | 162.292us   | 8.00%       | 1.285ms     | 107.096us    | 2.465ms     | 14.54%      | 3.323ms     | 276.917us     | 12         |
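
Because CUDA kernels launch asynchronously, a single wall-clock measurement can be dominated by launch overhead and one-off warm-up costs. A minimal sketch of a steadier measurement follows, using warm-up iterations and explicit synchronization. `time_forward` is a hypothetical helper, and the matrix multiply is only a stand-in for `model.forward`.

```python
import time

import torch


def time_forward(fn, warmup: int = 3, iters: int = 10) -> float:
    """Average wall-clock seconds per call of fn(), synchronizing CUDA if present."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):  # warm-up: exclude one-off setup costs from the timing
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters


# Stand-in workload (assumption: in the notebook this would call model.forward).
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(256, 256, device=device)
avg_s = time_forward(lambda: x @ x)
print(f"avg time per call: {avg_s * 1e3:.3f} ms")
```

In the notebook, the lambda would wrap `model.forward(inputs['input_ids'])` for each attention implementation in turn.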

In [None]:
# Create a bunch of inputs with random tokens.
# Note we are not using real words because we are not dealing with masking
# right now, so there should be no need.
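
One possible way to build such inputs is sketched below. It assumes the vocabulary size (50257) and maximum position count (1024) shown in the model printout above; `make_random_inputs` is a hypothetical helper name.

```python
import torch

VOCAB_SIZE = 50257     # from `wte: Embedding(50257, 768)` in the model printout
MAX_POSITIONS = 1024   # from `wpe: Embedding(1024, 768)` in the model printout


def make_random_inputs(batch_size: int, seq_len: int,
                       vocab_size: int = VOCAB_SIZE,
                       device: str = "cpu", seed: int = 0) -> torch.Tensor:
    """Return a (batch_size, seq_len) tensor of random token ids in [0, vocab_size)."""
    if seq_len > MAX_POSITIONS:
        raise ValueError(f"seq_len {seq_len} exceeds {MAX_POSITIONS} positions")
    gen = torch.Generator(device="cpu").manual_seed(seed)  # reproducible sampling
    ids = torch.randint(0, vocab_size, (batch_size, seq_len),
                        generator=gen, dtype=torch.long)
    return ids.to(device)


batch = make_random_inputs(batch_size=4, seq_len=16)
print(batch.shape, batch.dtype)
```

Every row has the same length, so no padding or attention mask is needed for these inputs.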


### Generation
Let's try generation now before starting the benchmarking.

In [26]:
# Generate text
import torch
import time

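def generate_and_decode(model, inputs):
    """Greedily generate up to 100 new tokens, print the decoded text, and return it."""
    with torch.no_grad():
        output_ids = model.generate(
            inputs['input_ids'],
            max_new_tokens=100,
            num_return_sequences=1,
            do_sample=False  # greedy decoding, so repeated runs are comparable
        )
    text = tokenizer.decode(output_ids[0])
    print(text)
    return text  # previously nothing was returned, so `out` in the caller was None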

def run_with_tracing(inputs, attn_implementation):
    print(f'=== profiling `{attn_implementation}` attention === ')
    model.transformer.config._attn_implementation = attn_implementation
    with torch.autograd.profiler.profile(use_device='cuda') as prof:
        start_time = time.time()
        out = generate_and_decode(model, inputs)
        end_time = time.time()
    # prof.export_chrome_trace(f"traces/{attn_implementation}_generation_trace.json")  # note we can inspect these with chrome://tracing/
    print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
    print(f"Total time taken: {end_time - start_time}\n\n")

In [27]:
inputs = tokenizer(["Today is"], return_tensors="pt").to('cuda')
attn_implementations = ["sdpa", "mha_forward", "improved_mha_forward"]
for attn_implementation in attn_implementations:
    run_with_tracing(inputs, attn_implementation)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


=== profiling `sdpa` attention === 
Today is!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::addmm         8.69%     218.122ms        38.64%     969.393ms     201.957us     356.760ms        13.96%        1.052s     219.115us          4800  
                                            aten::empty         4.17%     104.695ms        29.86%     749.054ms      41.756us     820.460ms        32.10%     820.460ms      45.736us         17939  
         

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [5,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [5,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [5,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [5,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [5,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [5,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelect

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
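
The `srcIndex < srcSelectDimSize` assertion above comes from an index-select kernel and typically means an index was out of range for the table it indexes, for example a token id of 50257 or more for `wte`, or a position of 1024 or more for `wpe`. A minimal sketch of a host-side check follows, assuming the ids can be inspected before the lookup. `check_index_range` is a hypothetical helper, not part of `transformers`.

```python
import torch


def check_index_range(ids: torch.Tensor, num_embeddings: int,
                      name: str = "input_ids") -> None:
    """Raise a readable error if any index falls outside [0, num_embeddings)."""
    ids_cpu = ids.detach().cpu()  # inspect on the host to avoid a device-side assert
    lo, hi = int(ids_cpu.min()), int(ids_cpu.max())
    if lo < 0 or hi >= num_embeddings:
        raise ValueError(
            f"{name} has values in [{lo}, {hi}], "
            f"but valid indices are [0, {num_embeddings - 1}]"
        )


good_ids = torch.tensor([[0, 42, 50256]])
bad_ids = torch.tensor([[0, 42, 50257]])  # one past the last valid GPT-2 token id

check_index_range(good_ids, 50257)  # passes silently
try:
    check_index_range(bad_ids, 50257)
except ValueError as err:
    print(f"caught: {err}")
```

Running such a check on the ids produced at each generation step gives a normal Python error at the point of failure, instead of an asynchronous device-side assert.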
