# Flash Attention Hardware Check

Run this notebook on a GPU compute node to verify Flash Attention compatibility before training.

**What this checks:**
1. GPU availability and CUDA version
2. Flash Attention 2 package availability
3. Model loading with different attention implementations
4. Memory usage with a small forward pass

In [1]:
import os
import sys
from pathlib import Path

import torch
from dotenv import load_dotenv

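load_dotenv()  # read MODELS_PATH from a local .env file, if one exists
MODELS_PATH = os.getenv('MODELS_PATH')
if MODELS_PATH is None:
    # Fail early: Path(None) would otherwise raise a less helpful TypeError later.
    raise RuntimeError("MODELS_PATH is not set; define it in the environment or in .env")
print(f"MODELS_PATH: {MODELS_PATH}")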

MODELS_PATH: /scratch/jt1955


In [2]:
!which python

/mnt/cup/labs/graziano/jack/newcomb/venv/bin/python


In [3]:
!pip freeze | grep flash

flash-attn==2.8.3
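
The same check can be made from Python without shelling out, which avoids picking up a `pip` that belongs to a different environment than the kernel. A minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_version` is an illustrative assumption, not part of the notebook:

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# "flash-attn" is the distribution name shown by `pip freeze` above.
for name in ("flash-attn", "torch", "transformers"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

This reads package metadata only, so it does not import `flash_attn` or touch CUDA.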


## 1. Check GPU and CUDA

In [4]:
print("=" * 50)
print("GPU & CUDA Information")
print("=" * 50)

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"cuDNN version: {torch.backends.cudnn.version()}")
    print(f"\nGPU count: {torch.cuda.device_count()}")
    
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"\nGPU {i}: {props.name}")
        print(f"  Memory: {props.total_memory / 1024**3:.1f} GB")
        print(f"  Compute capability: {props.major}.{props.minor}")
else:
    print("\nERROR: No GPU available! Run this on a compute node.")

GPU & CUDA Information

PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA version: 12.8
cuDNN version: 91002

GPU count: 2

GPU 0: NVIDIA L40S
  Memory: 44.4 GB
  Compute capability: 8.9

GPU 1: NVIDIA L40S
  Memory: 44.4 GB
  Compute capability: 8.9
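
Compute capability is the field that matters for Flash Attention 2. The flash-attn project documents support for Ampere, Ada, and Hopper GPUs, which corresponds to compute capability 8.0 and above; the L40S reported above is Ada at 8.9. A minimal sketch of that gate follows. The helper name and the plain 8.0 threshold are assumptions for illustration, and architectures newer than Hopper are not verified here:

```python
def supports_flash_attention_2(major: int, minor: int) -> bool:
    """Return True for compute capability 8.0 or newer (assumed FA2 threshold)."""
    # Tuple comparison handles (major, minor) ordering, e.g. (7, 5) < (8, 0).
    return (major, minor) >= (8, 0)


# The L40S in the output above reports compute capability 8.9.
print(supports_flash_attention_2(8, 9))  # True
# A Turing card such as the T4 reports 7.5 and falls below the threshold.
print(supports_flash_attention_2(7, 5))  # False
```

In the notebook, the arguments would come from `props.major` and `props.minor` in the loop above.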


## 2. Check Flash Attention Package

In [5]:
print("=" * 50)
print("Flash Attention Check")
print("=" * 50)

# Check if flash_attn package is installed
try:
    import flash_attn
    print(f"\nflash_attn version: {flash_attn.__version__}")
    flash_attn_installed = True
except ImportError:
    print("\nflash_attn package NOT installed")
    flash_attn_installed = False

# Check transformers utility
try:
    from transformers.utils import is_flash_attn_2_available
    fa2_available = is_flash_attn_2_available()
    print(f"is_flash_attn_2_available(): {fa2_available}")
except ImportError:
    print("Could not import is_flash_attn_2_available")
    fa2_available = False

if fa2_available:
    print("\n✓ Flash Attention 2 IS available")
    recommended = "flash_attention_2"
else:
    print("\n✗ Flash Attention 2 NOT available")
    print("  Will use SDPA (PyTorch native) instead")
    recommended = "sdpa"

print(f"\nRECOMMENDED: attn_implementation='{recommended}'")

Flash Attention Check

flash_attn version: 2.8.3


is_flash_attn_2_available(): True

✓ Flash Attention 2 IS available

RECOMMENDED: attn_implementation='flash_attention_2'


## 3. Test Model Loading

In [6]:
# Configuration
MODEL_NAME = "Llama-3.1-8B"  # Change this to test other models
MODEL_PATH = Path(MODELS_PATH) / MODEL_NAME

print(f"Testing model: {MODEL_PATH}")
print(f"Model exists: {MODEL_PATH.exists()}")

Testing model: /scratch/jt1955/Llama-3.1-8B
Model exists: True
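
`MODEL_PATH.exists()` only confirms that the directory is there. A slightly stricter pre-check, sketched below, also looks for a `config.json` and at least one weight file before spending time on a load. The helper `looks_like_hf_checkpoint` is hypothetical, and the file patterns assume the usual Hugging Face layout (`*.safetensors` or `*.bin` files):

```python
import tempfile
from pathlib import Path


def looks_like_hf_checkpoint(model_dir: Path) -> bool:
    """Heuristic: the directory has config.json and at least one weight file."""
    if not (model_dir / "config.json").is_file():
        return False
    # Either weight format is accepted; sharded checkpoints match the same globs.
    return any(model_dir.glob("*.safetensors")) or any(model_dir.glob("*.bin"))


# Demonstration on a throwaway directory rather than a real model.
with tempfile.TemporaryDirectory() as tmp:
    fake = Path(tmp)
    print(looks_like_hf_checkpoint(fake))  # False: empty directory
    (fake / "config.json").write_text("{}")
    (fake / "model-00001-of-00004.safetensors").touch()
    print(looks_like_hf_checkpoint(fake))  # True
```

This is a heuristic only; it does not validate the contents of the files.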


In [7]:
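import gc

from transformers import AutoModelForCausalLM, AutoTokenizer


def test_attention_impl(attn_impl: str) -> dict:
    """Load the model with one attention implementation and run a forward pass."""
    print(f"\nTesting attn_implementation='{attn_impl}'...")

    result = {
        'implementation': attn_impl,
        'load_success': False,
        'forward_success': False,
        'error': None,
        'memory_gb': None,
    }
    n_gpus = torch.cuda.device_count()
    # Pre-bind so the cleanup in `finally` works even if loading fails.
    model = tokenizer = inputs = outputs = None

    try:
        # Free anything left over from a previous run before measuring.
        gc.collect()
        torch.cuda.empty_cache()
        for i in range(n_gpus):
            torch.cuda.reset_peak_memory_stats(i)

        # Load model
        # NOTE: newer transformers releases warn that `torch_dtype` is
        # deprecated and ask for `dtype=` instead.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            local_files_only=True,
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_impl,
            device_map="auto",
        )
        result['load_success'] = True
        print("  ✓ Model loaded successfully")

        # Load tokenizer for forward pass test
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_PATH,
            local_files_only=True,
        )

        # Test forward pass
        inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        result['forward_success'] = True
        print("  ✓ Forward pass successful")

        # Record peak memory summed over all GPUs: device_map="auto" may shard
        # the model, and max_memory_allocated() covers a single device.
        result['memory_gb'] = sum(
            torch.cuda.max_memory_allocated(i) for i in range(n_gpus)
        ) / 1024**3
        print(f"  Memory used: {result['memory_gb']:.2f} GB")

    except Exception as e:
        result['error'] = str(e)
        print(f"  ✗ Error: {e}")

    finally:
        # Release the model even when loading or the forward pass failed.
        # gc.collect() clears reference cycles so empty_cache() can return memory.
        del model, tokenizer, inputs, outputs
        gc.collect()
        torch.cuda.empty_cache()

    return result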

In [8]:
# Test different implementations
print("=" * 50)
print("Testing Attention Implementations")
print("=" * 50)

implementations_to_test = ["sdpa", "eager"]
if fa2_available:
    implementations_to_test.insert(0, "flash_attention_2")

results = []
for impl in implementations_to_test:
    result = test_attention_impl(impl)
    results.append(result)

`torch_dtype` is deprecated! Use `dtype` instead!


Testing Attention Implementations

Testing attn_implementation='flash_attention_2'...


Loading checkpoint shards: 100%|██████████| 4/4 [01:15<00:00, 18.75s/it]


  ✓ Model loaded successfully
  ✓ Forward pass successful
  Memory used: 6.68 GB

Testing attn_implementation='sdpa'...


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.12it/s]


  ✓ Model loaded successfully
  ✓ Forward pass successful
  Memory used: 13.34 GB

Testing attn_implementation='eager'...


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.13it/s]


  ✓ Model loaded successfully
  ✓ Forward pass successful
  Memory used: 20.01 GB

**Note on the memory figures:** each run reports about 6.67 GB more than the one before it (6.68, 13.34, 20.01). That pattern is consistent with the previous model not having been released before the next one was loaded, rather than with a real difference between implementations on a single short prompt. `torch.cuda.max_memory_allocated()` also reports only the current device, while `device_map="auto"` can shard the model across both GPUs. Treat these numbers as confirmation that each implementation loads and runs, not as a like-for-like memory comparison.


## 4. Summary and Recommendation

In [9]:
print("\n" + "=" * 50)
print("SUMMARY")
print("=" * 50)

print("\n| Implementation    | Load | Forward | Memory (GB) |")
print("|-------------------|------|---------|-------------|")

best_impl = None
best_memory = float('inf')

for r in results:
    load = "✓" if r['load_success'] else "✗"
    forward = "✓" if r['forward_success'] else "✗"
    mem = f"{r['memory_gb']:.2f}" if r['memory_gb'] else "N/A"
    # Width 17 fits the longest name, "flash_attention_2", so columns stay aligned.
    print(f"| {r['implementation']:17} | {load:4} | {forward:7} | {mem:11} |")
    
    if r['forward_success'] and r['memory_gb'] and r['memory_gb'] < best_memory:
        best_memory = r['memory_gb']
        best_impl = r['implementation']

print("\n" + "=" * 50)
if best_impl:
    print(f"RECOMMENDATION: Use attn_implementation='{best_impl}'")
    print(f"\nTo use in training:")
    print(f"  python cogex/train_cogex.py --attn-impl {best_impl} ...")
else:
    print("WARNING: No attention implementation worked!")
    print("Check GPU availability and model path.")
print("=" * 50)


SUMMARY

| Implementation    | Load | Forward | Memory (GB) |
|-------------------|------|---------|-------------|
| flash_attention_2 | ✓    | ✓       | 6.68        |
| sdpa              | ✓    | ✓       | 13.34       |
| eager             | ✓    | ✓       | 20.01       |

RECOMMENDATION: Use attn_implementation='flash_attention_2'

To use in training:
  python cogex/train_cogex.py --attn-impl flash_attention_2 ...
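
The memory column can be sanity-checked against a back-of-the-envelope estimate of weight memory: parameter count times bytes per parameter. The sketch below assumes roughly 8.03 billion parameters for Llama-3.1-8B (a figure not stated in this notebook) and 2 bytes per parameter for bfloat16; the helper name is illustrative:

```python
def weight_memory_gib(n_params: float, bytes_per_param: int) -> float:
    """Estimate weight memory in GiB (1024**3 bytes), excluding activations."""
    return n_params * bytes_per_param / 1024**3


N_PARAMS = 8.03e9  # assumed approximate parameter count for an 8B Llama model

print(f"bf16 weights: {weight_memory_gib(N_PARAMS, 2):.2f} GiB")  # 14.96
print(f"fp32 weights: {weight_memory_gib(N_PARAMS, 4):.2f} GiB")  # 29.91
```

The estimate covers weights only, so a measured peak for a short forward pass should land slightly above it when every device holding part of the model is counted.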
