In [None]:
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Check if Flash Attention backend is available (PyTorch 2.0+)
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    print("✓ F.scaled_dot_product_attention is available")

    # Check available backends
    from torch.nn.attention import SDPBackend
    print(f"\nAvailable SDPA backends:")
    print(f"  FLASH_ATTENTION: {SDPBackend.FLASH_ATTENTION}")
    print(f"  EFFICIENT_ATTENTION: {SDPBackend.EFFICIENT_ATTENTION}")
    print(f"  MATH: {SDPBackend.MATH}")
else:
    print("✗ scaled_dot_product_attention NOT available (need PyTorch 2.0+)")

PyTorch version: 2.9.0+cu126
CUDA available: True
✓ F.scaled_dot_product_attention is available

Available SDPA backends:
  FLASH_ATTENTION: SDPBackend.FLASH_ATTENTION
  EFFICIENT_ATTENTION: SDPBackend.EFFICIENT_ATTENTION
  MATH: SDPBackend.MATH


In [None]:
import timm
from hydra import initialize, compose

print(f"timm version: {timm.__version__}")

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="model/dino_v2")
    vit_cfg = cfg.model.vit

    # First, let's see what parameters are accepted
    print(f"\nModel architecture: {cfg.model.architecture}")

    # Try creating without attn_impl first
    try:
        model = timm.create_model(
            cfg.model.architecture,
            pretrained=False,
            num_classes=0,
            img_size=vit_cfg.image_size,
            patch_size=vit_cfg.patch_size,
            embed_dim=vit_cfg.embed_dim,
            depth=vit_cfg.depth,
            num_heads=vit_cfg.num_heads,
            mlp_ratio=vit_cfg.mlp_ratio,
            drop_path_rate=vit_cfg.drop_path_rate,
        )
        print("✓ Model created successfully")

        # Check what attention implementation is being used
        if hasattr(model, 'blocks'):
            first_block = model.blocks[0]
            if hasattr(first_block, 'attn'):
                attn = first_block.attn
                print(f"\nAttention module type: {type(attn).__name__}")
                print(f"Attention class: {attn.__class__.__module__}.{attn.__class__.__name__}")

                # Check if it has fused_attn attribute (indicates SDPA support)
                if hasattr(attn, 'fused_attn'):
                    print(f"fused_attn attribute: {attn.fused_attn}")

    except Exception as e:
        print(f"Error: {e}")

timm version: 1.0.22

Model architecture: vit_base_patch8_224
✓ Model created successfully

Attention module type: Attention
Attention class: timm.layers.attention.Attention
fused_attn attribute: True


In [None]:

print("=" * 80)
print("Testing Flash Attention with actual model")
print("=" * 80)
print("\nBoth settings will be applied below (as they are in your training code)")

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="model/dino_v2")
    vit_cfg = cfg.model.vit

    model = timm.create_model(
        cfg.model.architecture,
        pretrained=False,
        num_classes=0,
        img_size=vit_cfg.image_size,
        patch_size=vit_cfg.patch_size,
        embed_dim=vit_cfg.embed_dim,
        depth=vit_cfg.depth,
        num_heads=vit_cfg.num_heads,
        mlp_ratio=vit_cfg.mlp_ratio,
        drop_path_rate=vit_cfg.drop_path_rate,
    )

    # Force enable fused_attn for all attention blocks
    for block in model.blocks:
        if hasattr(block.attn, 'fused_attn'):
            block.attn.fused_attn = True
            print(f"Enabled fused_attn for block")

    print(f"\nModel created with {len(model.blocks)} blocks")
    print(f"First block attention fused_attn: {model.blocks[0].attn.fused_attn if hasattr(model.blocks[0].attn, 'fused_attn') else 'N/A'}")

    # Test forward pass to ensure it works
    model = model.cuda()
    x = torch.randn(2, 3, vit_cfg.image_size, vit_cfg.image_size).cuda()

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
            try:
                output = model(x)
                print(f"\n✓ Forward pass successful with Flash Attention context!")
                print(f"Output shape: {output.shape}")
            except Exception as e:
                print(f"\n✗ Forward pass failed: {e}")
                # Try with fallback
                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True):
                    output = model(x)
                    print(f"✓ Forward pass successful with fallback enabled")
                    print(f"Output shape: {output.shape}")


Testing Flash Attention with actual model
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block
Enabled fused_attn for block

Model created with 12 blocks
First block attention fused_attn: True


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  self.gen = func(*args, **kwds)



✓ Forward pass successful with Flash Attention context!
Output shape: torch.Size([2, 384])


In [11]:
# Benchmark to verify Flash Attention is actually being used
print("=" * 80)
print("Benchmarking: Flash Attention vs Standard Attention")
print("=" * 80)

import time

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="model/dino_v2")
    vit_cfg = cfg.model.vit

    model = timm.create_model(
        cfg.model.architecture,
        pretrained=False,
        num_classes=0,
        img_size=vit_cfg.image_size,
        patch_size=vit_cfg.patch_size,
        embed_dim=vit_cfg.embed_dim,
        depth=vit_cfg.depth,
        num_heads=vit_cfg.num_heads,
        mlp_ratio=vit_cfg.mlp_ratio,
        drop_path_rate=vit_cfg.drop_path_rate,
    ).cuda()

    # Enable fused attention
    for block in model.blocks:
        if hasattr(block.attn, 'fused_attn'):
            block.attn.fused_attn = True

    model.eval()

    # Create test input
    batch_size = 16
    x = torch.randn(batch_size, 3, vit_cfg.image_size, vit_cfg.image_size).cuda()

    # Warmup
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        for _ in range(5):
            _ = model(x)

    torch.cuda.synchronize()

    # Test 1: With Flash Attention
    print("\n[Test 1] With Flash Attention enforced:")
    times_flash = []
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
            for i in range(20):
                torch.cuda.synchronize()
                start = time.time()
                output = model(x)
                torch.cuda.synchronize()
                times_flash.append(time.time() - start)

    avg_flash = sum(times_flash[5:]) / len(times_flash[5:])  # Skip first 5 for warmup
    print(f"  Average time: {avg_flash*1000:.2f} ms")

    # Test 2: With standard math attention only
    print("\n[Test 2] With standard math attention only:")
    times_math = []
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            for i in range(20):
                torch.cuda.synchronize()
                start = time.time()
                output = model(x)
                torch.cuda.synchronize()
                times_math.append(time.time() - start)

    avg_math = sum(times_math[5:]) / len(times_math[5:])
    print(f"  Average time: {avg_math*1000:.2f} ms")

    # Compare
    print(f"\n{'='*80}")
    if avg_flash < avg_math:
        speedup = avg_math / avg_flash
        print(f"✓ Flash Attention IS WORKING! Speedup: {speedup:.2f}x faster")
        print(f"  Flash Attention: {avg_flash*1000:.2f} ms")
        print(f"  Standard Math:   {avg_math*1000:.2f} ms")
    else:
        print(f"⚠ Flash Attention may not be working properly")
        print(f"  Flash Attention: {avg_flash*1000:.2f} ms")
        print(f"  Standard Math:   {avg_math*1000:.2f} ms")
    print(f"{'='*80}")


Benchmarking: Flash Attention vs Standard Attention


  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):



[Test 1] With Flash Attention enforced:
  Average time: 3.73 ms

[Test 2] With standard math attention only:
  Average time: 5.88 ms

✓ Flash Attention IS WORKING! Speedup: 1.58x faster
  Flash Attention: 3.73 ms
  Standard Math:   5.88 ms
