In [None]:
# Print the CUDA compiler (nvcc) version available in this environment.
!nvcc --version

In [None]:
# Optional: install a CUDA 11.8 build of PyTorch first if the environment lacks one.
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# `%pip` (not `!pip`) installs into the environment of the running kernel.
# --no-build-isolation lets the flash-attn build use the already-installed torch.
# NOTE(review): pin a flash-attn version matching the installed torch/CUDA for reproducibility.
%pip install flash_attn --no-build-isolation

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import flash_attn
# NOTE: flash_attn exports no top-level `FlashAttention`, and
# `flash_attn.modules.mha` has no `FlashAttention2`; importing either raised
# ImportError and stopped this cell. `FlashSelfAttention` is the attention
# module the model cell uses.
from flash_attn.modules.mha import FlashSelfAttention

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
# The installed flash-attn release determines which FlashAttention kernel runs.
print("flash-attn version:", flash_attn.__version__)

In [None]:
# Load the BERT tokenizer and prepare a small batch of example sentences.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = [
    "FlashAttention is a fast and memory-efficient attention mechanism.",
    "FlashAttention-2 optimizes parallelism and work partitioning.",
]

# padding=True pads the shorter sentence so the batch is rectangular;
# the resulting BatchEncoding is then moved to the models' device.
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
    truncation=True,
).to(device)


In [None]:
# Define BERT models that apply a FlashAttention self-attention layer to the
# encoder output.
#
# NOTE(review): `flash_attn.modules.mha` has no `FlashAttention2` class.
# `FlashSelfAttention` dispatches to the kernel of the installed flash-attn
# release (FlashAttention-2 for flash-attn 2.x), so both wrappers below run
# the same kernel; they are kept as separate names for the comparison cells.
class _BertWithFlashSelfAttention(nn.Module):
    """BERT encoder followed by a parameter-free FlashAttention self-attention layer.

    The encoder's last hidden state is split into BERT's attention heads and
    used directly as query, key and value (no learned projection, so the
    output is deterministic for given BERT weights). Padding tokens
    (``attention_mask == 0``) are dropped before the kernel runs, so real
    tokens never attend to them, and their rows in the output are zero.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.num_heads = self.bert.config.num_attention_heads
        self.flash_attention = FlashSelfAttention(causal=False)

    def forward(self, input_ids, attention_mask):
        """Return self-attention output of shape (batch, seq_len, hidden_size).

        Args:
            input_ids: (batch, seq_len) token ids. flash-attn kernels only run
                on CUDA, so inputs must live on a CUDA device.
            attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
        """
        hidden_states = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        batch, seq_len, hidden_size = hidden_states.shape
        head_dim = hidden_size // self.num_heads

        # Variable-length ("unpadded") layout: keep only real tokens and mark
        # where each sequence starts with int32 cumulative lengths.
        token_idx = attention_mask.flatten().nonzero(as_tuple=True)[0]
        seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
        cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0, dtype=torch.int32)])
        max_seqlen = int(seqlens.max())

        # FlashAttention accepts only fp16/bf16, packed as (tokens, 3, heads, head_dim).
        heads = hidden_states.half().reshape(batch * seq_len, self.num_heads, head_dim)[token_idx]
        qkv = torch.stack((heads, heads, heads), dim=1)
        attn = self.flash_attention(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

        # Scatter real-token results back into the padded (batch, seq_len) grid.
        out = attn.new_zeros(batch * seq_len, self.num_heads, head_dim)
        out[token_idx] = attn
        return out.view(batch, seq_len, hidden_size).to(hidden_states.dtype)


class BertWithFlashAttention(_BertWithFlashSelfAttention):
    """BERT followed by a FlashSelfAttention layer (first model in the comparison)."""


class BertWithFlashAttention2(_BertWithFlashSelfAttention):
    """BERT followed by a FlashSelfAttention layer (second model in the comparison).

    Behaves identically to ``BertWithFlashAttention``; see the NOTE(review) above.
    """

    @property
    def flash_attention2(self):
        # Backward-compatible alias for the attribute name used previously.
        return self.flash_attention


# Initialize both models, move them to the GPU and switch to inference mode.
model_flash = BertWithFlashAttention().to(device).eval()
model_flash2 = BertWithFlashAttention2().to(device).eval()


In [None]:
import time


def run_inference(model, inputs, num_runs=100, warmup=10):
    """Return total wall-clock seconds spent on `num_runs` forward passes.

    Args:
        model: callable invoked as ``model(input_ids, attention_mask)``.
        inputs: mapping with 'input_ids' and 'attention_mask' entries.
        num_runs: number of timed forward passes.
        warmup: untimed passes run first, so one-off costs (lazy CUDA init,
            kernel loading, allocator growth) are not charged to whichever
            model happens to be timed first.

    Returns:
        Elapsed seconds (float) for the timed passes only.
    """
    use_cuda = torch.cuda.is_available()
    input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']

    with torch.no_grad():
        for _ in range(warmup):
            model(input_ids, attention_mask)

        # CUDA launches are asynchronous: drain queued work before starting
        # and before stopping the clock. Skipped when CUDA is unavailable,
        # where torch.cuda.synchronize() would raise.
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            model(input_ids, attention_mask)
        if use_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

    return end_time - start_time

In [None]:
# Time both models on the same inputs with the same number of passes.
num_runs = 100

flash_time = run_inference(model_flash, inputs, num_runs=num_runs)
print(f"FlashAttention execution time: {flash_time:.4f} seconds")
print(f"  per pass: {1000 * flash_time / num_runs:.3f} ms")

flash2_time = run_inference(model_flash2, inputs, num_runs=num_runs)
print(f"FlashAttention-2 execution time: {flash2_time:.4f} seconds")
print(f"  per pass: {1000 * flash2_time / num_runs:.3f} ms")

# Verify the output shapes
with torch.no_grad():
    output_flash = model_flash(inputs['input_ids'], inputs['attention_mask'])
    output_flash2 = model_flash2(inputs['input_ids'], inputs['attention_mask'])

print("Output shape with FlashAttention:", output_flash.shape)
print("Output shape with FlashAttention-2:", output_flash2.shape)

# Equal shapes only show interface compatibility; compare the values too.
assert output_flash.shape == output_flash2.shape, "models disagree on output shape"
max_abs_diff = (output_flash.float() - output_flash2.float()).abs().max().item()
print(f"Max absolute difference between outputs: {max_abs_diff:.3e}")


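- FlashAttention-2 is reported to be faster than the original FlashAttention thanks to better parallelism and work partitioning, but these timings alone do not demonstrate it. The total time is likely dominated by the BERT encoder, which is identical in both models, and the extra attention layer runs whichever kernel the installed `flash-attn` release provides. Check any gap across repeated runs before interpreting it.
- Matching output shapes only show that the two models are interface-compatible. Functional equivalence requires comparing the output values themselves (for example with `torch.allclose`).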