### Setup
First, load the tokenizer and the model, then move the model to the GPU.

In [1]:
from modeling_gpt2 import GPT2Model
from transformers import GPT2Tokenizer
# The tokenizer and the model weights come from the same checkpoint and share one local cache directory.
CHECKPOINT = "gpt2"
CACHE_DIR = "hf_home/"

tokenizer = GPT2Tokenizer.from_pretrained(CHECKPOINT, cache_dir=CACHE_DIR)
model = GPT2Model.from_pretrained(CHECKPOINT, cache_dir=CACHE_DIR)

# Kept as the last expression: moves the model to the GPU and displays its architecture.
model.cuda()

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

### Initial Testing
Now, let's do forward passes on some sample inputs. Starting with the baseline.

In [2]:
# Baseline forward pass using the "sdpa" attention implementation.
inputs = tokenizer(["Today is"], return_tensors="pt").to('cuda')
model.config._attn_implementation = "sdpa"  # NOTE: This is default, but we set manually here for emphasis.
# Call the module itself rather than `.forward()` so that any registered hooks are run.
out = model(inputs['input_ids'])
# Last expression: display the final hidden states for comparison with the next cell.
out.last_hidden_state

tensor([[[ 0.0502,  0.0018, -0.1750,  ..., -0.1020, -0.0257, -0.1292],
         [-0.2410, -0.0911,  0.2592,  ...,  0.4394,  0.3465,  0.1077]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

Now, repeat the forward pass using our own attention implementation (`minimal_attn`) instead of the default, `sdpa`.

In [3]:
# Same input as the baseline, but routed through the "minimal_attn" attention implementation.
inputs = tokenizer(["Today is"], return_tensors="pt").to('cuda')
model.config._attn_implementation = "minimal_attn"
# Call the module itself rather than `.forward()` so that any registered hooks are run.
out = model(inputs['input_ids'])
# Last expression: display the final hidden states for comparison with the baseline.
out.last_hidden_state

tensor([[[ 0.0562,  0.7767, -0.3577,  ..., -0.4087,  0.4106, -0.5051],
         [-0.2669,  0.8194,  0.1462,  ..., -0.1800, -0.3680, -0.1513]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

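The output shapes are the same, but the values are not: for example, the second element of the first row is 0.0018 with `sdpa` and 0.7767 with `minimal_attn`. So we cannot yet conclude that the attention implementation is correct; the two outputs should be compared numerically (e.g. with `torch.allclose`) and the discrepancy investigated. With that caveat in mind, let's see if it is faster.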

In [4]:
import torch

# Profile one forward pass per attention implementation and print the ten
# operators with the largest total CUDA time.
attn_implementations = ["sdpa", "minimal_attn"]
for attn_implementation in attn_implementations:
    print(f'=== profiling `{attn_implementation}` attention === ')
    model.config._attn_implementation = attn_implementation
    # Inference-only measurement: disable autograd so graph bookkeeping is not
    # part of the profile, and call the module (not `.forward()`) so hooks run.
    with torch.no_grad(), torch.autograd.profiler.profile(use_device='cuda') as prof:
        out = model(inputs['input_ids'])
    print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

=== profiling `sdpa` attention === 
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::addmm        11.59%       2.835ms        16.97%       4.151ms      86.476us       4.103ms        16.64%       4.884ms     101.750us            48  
                                 aten::layer_norm         0.85%     207.728us        14.03%       3.432ms     137.271us     323.000us         1.31%       3.532ms     141.280us            25  
   

### Generation
Let's try generation now before starting the benchmarking.

In [None]:
# NOTE: Generation is skipped for now because of the way the model is set up. It probably
# needs to be loaded through AutoModel, but it is not yet clear how to do that while still
# using our version of modeling_gpt2.py.
# NOTE(review): the module printed after loading shows no LM head; generation presumably
# needs a head-equipped class (e.g. GPT2LMHeadModel) -- confirm.
# outputs = model.generate(
#     **inputs,
#     max_new_tokens=5,
#     return_dict_in_generate=True,
#     output_scores=True,
#     do_sample=False,  # temperature = 0.0 so deterministic
# )

# May be needed for batching (reuses the EOS token id as the padding token id).
# tokenizer.pad_token_id = tokenizer.eos_token_id