In [29]:
import torch
from einops import einsum


class FlashAttentionPytorch(torch.autograd.Function):
    """Tiled attention forward pass with an online softmax, in plain PyTorch.

    Queries are processed in tiles of ``Q_TILE_SIZE`` rows and keys/values in
    tiles of ``K_TILE_SIZE`` rows. For each query tile a running row maximum
    ``m_i``, a running softmax denominator ``l_i`` and an un-normalised output
    ``O_i`` are updated tile by tile, so the full score matrix is never built.

    All sizes are read from the input tensors and every buffer is allocated on
    the device/dtype of ``Q``; nothing is taken from notebook globals.
    """

    # Number of query rows / key-value rows handled per tile.
    Q_TILE_SIZE = 16
    K_TILE_SIZE = 16

    @staticmethod
    def forward(ctx, Q, K, V, is_causal=False):
        """Compute ``softmax(Q K^T / sqrt(d)) V`` and its row-wise log-sum-exp.

        Args:
            ctx: autograd context; receives ``(Q, K, V, L, O)`` via
                ``save_for_backward`` and the ``is_causal`` flag.
            Q: tensor of shape ``(..., N_q, d)``.
            K: tensor of shape ``(..., N_k, d)``.
            V: tensor of shape ``(..., N_k, d_v)``.
            is_causal: when true, query row ``i`` only attends to key rows
                ``j <= i``.

        Returns:
            ``(O, L)`` where ``O`` has shape ``(..., N_q, d_v)`` and ``L`` has
            shape ``(..., N_q)`` and holds the log-sum-exp of the scaled scores.
        """
        B_q = FlashAttentionPytorch.Q_TILE_SIZE
        B_k = FlashAttentionPytorch.K_TILE_SIZE
        N_q, d = Q.shape[-2], Q.shape[-1]
        N_k = K.shape[-2]
        d_v = V.shape[-1]
        batch_shape = tuple(Q.shape[:-2])
        scale = d ** -0.5

        O = Q.new_empty((*batch_shape, N_q, d_v))
        L = Q.new_empty((*batch_shape, N_q))

        for q_start in range(0, N_q, B_q):
            # The last tile may be shorter than B_q when N_q % B_q != 0.
            q_end = min(q_start + B_q, N_q)
            rows = q_end - q_start
            Q_i = Q[..., q_start:q_end, :]

            O_i = Q.new_zeros((*batch_shape, rows, d_v))
            l_i = Q.new_zeros((*batch_shape, rows))
            m_i = Q.new_full((*batch_shape, rows), float('-inf'))

            for k_start in range(0, N_k, B_k):
                k_end = min(k_start + B_k, N_k)
                K_j = K[..., k_start:k_end, :]
                V_j = V[..., k_start:k_end, :]

                S_ij = (Q_i @ K_j.transpose(-1, -2)) * scale
                if is_causal:
                    q_idx = torch.arange(q_start, q_end, device=Q.device)
                    k_idx = torch.arange(k_start, k_end, device=Q.device)
                    # Key 0 is visible to every query row and lives in the
                    # first tile, so m_i is finite before any fully masked
                    # tile is reached and exp(-inf - m) is a clean 0.
                    S_ij = S_ij.masked_fill(k_idx[None, :] > q_idx[:, None], float('-inf'))

                m_new = torch.maximum(m_i, S_ij.amax(dim=-1))
                P_ij = torch.exp(S_ij - m_new[..., None])
                # Factor that re-bases the previous accumulators onto m_new.
                alpha = torch.exp(m_i - m_new)

                l_i = alpha * l_i + P_ij.sum(dim=-1)
                # Broadcasting replaces the dense diag(alpha) @ O_i product.
                O_i = alpha[..., None] * O_i + P_ij @ V_j
                m_i = m_new

            O[..., q_start:q_end, :] = O_i / l_i[..., None]
            L[..., q_start:q_end] = m_i + torch.log(l_i)

        ctx.is_causal = is_causal
        ctx.save_for_backward(Q, K, V, L, O)
        return O, L

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass placeholder.

        Autograd passes one gradient per forward output (``O`` and ``L``);
        accepting ``*grad_outputs`` lets the call reach this body.

        Raises:
            NotImplementedError: always; the backward pass is not written yet.
        """
        raise NotImplementedError
            



In [5]:
import torch
from einops import einsum
class FlashAttentionAutogradFunctionPytorch(torch.autograd.Function):
    """Autograd ``Function`` computing attention tile by tile with an online softmax.

    Shapes, tile bounds and the device/dtype of every buffer are derived from
    the inputs, so the class does not depend on variables defined in other
    notebook cells. Leading batch dimensions are supported.
    """

    # Tile heights for the query loop and the key/value loop.
    Q_TILE_SIZE = 16
    K_TILE_SIZE = 16

    @staticmethod
    def forward(ctx, Q, K, V, is_causal=False):
        """Return the attention output and the log-sum-exp of the scaled scores.

        Args:
            ctx: autograd context; ``(Q, K, V, L, O)`` are saved on it together
                with ``is_causal``.
            Q: queries, shape ``(..., N_q, d)``.
            K: keys, shape ``(..., N_k, d)``.
            V: values, shape ``(..., N_k, d_v)``.
            is_causal: if true, scores with key index greater than the query
                index are masked out before the softmax.

        Returns:
            Tuple ``(O, L)`` with shapes ``(..., N_q, d_v)`` and ``(..., N_q)``.
        """
        cls = FlashAttentionAutogradFunctionPytorch
        tile_q, tile_k = cls.Q_TILE_SIZE, cls.K_TILE_SIZE
        *lead, N_q, d = Q.shape
        N_k = K.shape[-2]
        d_v = V.shape[-1]
        inv_sqrt_d = d ** -0.5

        O = Q.new_empty((*lead, N_q, d_v))
        L = Q.new_empty((*lead, N_q))

        for i in range(0, N_q, tile_q):
            i_end = min(i + tile_q, N_q)  # final tile may be partial
            n_rows = i_end - i
            Q_i = Q[..., i:i_end, :]

            acc = Q.new_zeros((*lead, n_rows, d_v))                  # un-normalised output
            denom = Q.new_zeros((*lead, n_rows))                     # running softmax denominator
            row_max = Q.new_full((*lead, n_rows), float('-inf'))     # running row maximum

            for j in range(0, N_k, tile_k):
                j_end = min(j + tile_k, N_k)
                K_j = K[..., j:j_end, :]
                V_j = V[..., j:j_end, :]

                S_j = inv_sqrt_d * (Q_i @ K_j.transpose(-1, -2))
                if is_causal:
                    rows = torch.arange(i, i_end, device=Q.device)[:, None]
                    cols = torch.arange(j, j_end, device=Q.device)[None, :]
                    # The first key tile always contains key 0, which no query
                    # row masks, so row_max is finite from the first tile on.
                    S_j = S_j.masked_fill(cols > rows, float('-inf'))

                new_max = torch.maximum(row_max, S_j.amax(dim=-1))
                P_j = torch.exp(S_j - new_max[..., None])
                rescale = torch.exp(row_max - new_max)

                denom = rescale * denom + P_j.sum(dim=-1)
                # Element-wise scaling; equivalent to diag(rescale) @ acc.
                acc = rescale[..., None] * acc + P_j @ V_j
                row_max = new_max

            O[..., i:i_end, :] = acc / denom[..., None]
            L[..., i:i_end] = row_max + torch.log(denom)

        ctx.is_causal = is_causal
        ctx.save_for_backward(Q, K, V, L, O)
        return O, L

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass placeholder; receives one gradient per forward output.

        Raises:
            NotImplementedError: always; no backward pass is implemented.
        """
        raise NotImplementedError

In [18]:
from tests.test_attention import _attention_and_lse, _make_attn_inputs
# Bind the autograd Function's `apply` entry point to a short name; the
# forward-check cell below calls it as `impl(q, k, v, ...)`.
impl = FlashAttentionAutogradFunctionPytorch.apply


In [None]:

# Forward-pass check: compare the tiled implementation against the reference
# attention / log-sum-exp helper imported from tests.test_attention.
# Both settings are defined here so the cell runs on a fresh kernel.
device = "cpu"
is_causal = False

q, k, v, _do = _make_attn_inputs(device)
# `forward` returns (O, L), so unpack the pair. `is_causal` is passed
# positionally because `Function.apply` is expected to reject keyword
# arguments (NOTE(review): confirm against the installed PyTorch version).
o, _lse = impl(q, k, v, is_causal)

# Extract L from the saved tensors
assert o.grad_fn.saved_tensors is not None, "No saved tensors found in the output tensor. Make sure your autograd forward is saving them using ctx.save_for_backward."
maybe_ls = [t for t in o.grad_fn.saved_tensors if t.shape == (q.shape[0], q.shape[1])]

assert len(maybe_ls) == 1, f"Expected one tensor of shape {q.shape[0], q.shape[1]} in saved tensors, but found {len(maybe_ls)}. The tests require you to save exactly one tensor of this shape, corresponding to the log-sum-exp of the attention scores."
l = maybe_ls[0]

o_ref, l_ref = _attention_and_lse(q, k, v, is_causal)

torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(l, l_ref, rtol=1e-2, atol=1e-2)

In [None]:
from torch import Tensor
import timeit
from functools import partial
import torch

def f(x, repeat):
    """Return ``x`` doubled ``repeat`` times, one multiplication per pass.

    Each pass rebinds the running value to a new product, so the argument
    object itself is never modified.
    """
    doubled = x
    for _pass in range(repeat):
        doubled = doubled * 2
    return doubled

# Compile the function for better performance.
# `f_compiled` is the variant the repeat-count sweep passes to `benchmark`.
f_compiled = torch.compile(f)

In [None]:
def benchmark(f_func, repeat, iters, sz=2**24, name="Function", device="cuda"):
    """Time ``f_func(x, repeat)`` and report achieved FLOPS and memory bandwidth.

    Args:
        f_func: callable taking ``(tensor, repeat)``.
        repeat: forwarded to ``f_func``; also the assumed number of operations
            per element when counting FLOPs.
        iters: number of timed calls.
        sz: number of elements in the random input tensor.
        name: label printed in the report and echoed in the result.
        device: device on which the input tensor is created.

    Returns:
        dict with keys ``name``, ``time`` (seconds for all ``iters`` calls),
        ``flops`` (operations per second) and
        ``memory_bandwidth_conservative`` (bytes per second).
    """
    input_tensor = torch.randn(sz, device=device)
    on_cuda = input_tensor.is_cuda

    def _wait_for_device():
        # CUDA kernels are launched asynchronously; without this barrier the
        # clock would stop before the queued work has actually finished.
        if on_cuda:
            torch.cuda.synchronize()

    # Warmup for compiled functions
    for _ in range(3):
        f_func(input_tensor.clone(), repeat)
    _wait_for_device()

    def _timed_region():
        # Each call gets a fresh copy of the input, as in the warmup above.
        for _ in range(iters):
            f_func(input_tensor.clone(), repeat)
        _wait_for_device()

    elapsed = timeit.timeit(_timed_region, number=1)

    # For compiled functions, PyTorch may optimize memory access patterns
    # so we should measure actual memory bandwidth rather than theoretical
    flop = sz * repeat * iters

    # Conservative estimate: at minimum we need to read input once and write output once
    # But compiled version might fuse operations and reduce intermediate memory accesses
    # (4 bytes per element, i.e. assuming torch.randn's default float32 dtype.)
    memory_conservative = 4 * 2 * sz * iters  # Just input read + final write per benchmark iteration

    flops = flop / elapsed
    mem_bd_conservative = memory_conservative / elapsed

    print(f"=== {name} ===")
    print(f"Size: {sz:,} elements ({sz*4/1e6:.1f} MB)")
    print(f"Repeat: {repeat}, Iterations: {iters}")
    print(f"Time: {elapsed:.4f} seconds")
    print(f"FLOPS: {flops/1e9:.2f} GFLOPS")
    print(f"Memory BW (conservative): {mem_bd_conservative/1e9:.2f} GB/s")
    print()

    return {
        "name": name,
        "time": elapsed,
        "flops": flops,
        "memory_bandwidth_conservative": mem_bd_conservative,
    }


In [None]:
# Sweep the compiled kernel over power-of-two repeat counts and record, for
# each run, the achieved throughput and its share of the assumed hardware peaks.
repeat_values = [2**exponent for exponent in range(13)]  # 1, 2, 4, ..., 4096
iters = 10
sz = 2**22  # Smaller size for faster sweep

print("Running sweep across repeat values:\n")

# RTX 3060 specs for comparison
RTX_3060_PEAK_FLOPS = 13e12  # 13 TFLOPS
RTX_3060_MEMORY_BW = 360e9   # 360 GB/s

results = []
for repeat in repeat_values:
    print(f"Testing repeat = {repeat}")
    result = benchmark(f_compiled, repeat, iters, sz, f"Repeat-{repeat}")

    achieved_flops = result['flops']
    achieved_bw = result['memory_bandwidth_conservative']

    # One record per repeat count; utilisation is a percentage of the peak.
    record = {
        'repeat': repeat,
        'time': result['time'],
        'flops': achieved_flops,
        'memory_bw': achieved_bw,
        'flops_util': (achieved_flops / RTX_3060_PEAK_FLOPS) * 100,
        'mem_util': (achieved_bw / RTX_3060_MEMORY_BW) * 100,
    }

    print(f"  FLOPS Utilization: {record['flops_util']:.2f}%")
    print(f"  Memory BW Utilization: {record['mem_util']:.2f}%")
    print("-" * 50)

    results.append(record)

print(f"\nSummary of {len(results)} experiments completed!")

In [None]:
# Plot the results
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(ax1, ax2), (ax3, ax4) = axes

# Pull each plotted column out of the sweep records.
repeats = [row['repeat'] for row in results]
times = [row['time'] for row in results]
flops = [row['flops'] / 1e9 for row in results]  # Convert to GFLOPS
mem_bw = [row['memory_bw'] / 1e9 for row in results]  # Convert to GB/s
flops_util = [row['flops_util'] for row in results]
mem_util = [row['mem_util'] for row in results]


def _draw_panel(ax, series, ylabel, title, peak=None, log_y=False, legend=True):
    """Draw one subplot against the repeat counts on a log-base-2 x axis.

    ``series`` is a list of ``(values, fmt, plot_kwargs)`` triples; ``peak``
    is an optional ``(y, label)`` pair rendered as a dashed red reference line.
    """
    for values, fmt, plot_kwargs in series:
        ax.plot(repeats, values, fmt, **plot_kwargs)
    ax.set_xscale('log', base=2)
    if log_y:
        ax.set_yscale('log')
    if peak is not None:
        peak_y, peak_label = peak
        ax.axhline(y=peak_y, color='r', linestyle='--', label=peak_label)
    ax.set_xlabel('Number of Repeats')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if legend:
        ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xticks(repeats)
    ax.set_xticklabels(repeats)


# Runtime vs Repeat (log-log)
_draw_panel(
    ax1,
    [(times, 'b-o', {'linewidth': 2})],
    'Runtime (seconds)',
    'Runtime vs Repeat Count',
    log_y=True,
    legend=False,
)

# FLOPS vs Repeat
_draw_panel(
    ax2,
    [(flops, 'g-o', {'linewidth': 2})],
    'GFLOPS',
    'FLOPS vs Repeat Count',
    peak=(13000, 'RTX 3060 Peak (13 TFLOPS)'),
)

# Memory Bandwidth vs Repeat
_draw_panel(
    ax3,
    [(mem_bw, 'purple', {'marker': 'o', 'linewidth': 2})],
    'Memory Bandwidth (GB/s)',
    'Memory Bandwidth vs Repeat Count',
    peak=(360, 'RTX 3060 Peak (360 GB/s)'),
)

# Utilization percentages
_draw_panel(
    ax4,
    [
        (flops_util, 'g-o', {'label': 'FLOPS Utilization', 'linewidth': 2}),
        (mem_util, 'purple', {'marker': 's', 'label': 'Memory BW Utilization', 'linewidth': 2}),
    ],
    'Utilization (%)',
    'Hardware Utilization vs Repeat Count',
)

plt.tight_layout()
plt.show()

# Print best performers
best_flops = max(results, key=lambda row: row['flops_util'])
best_memory = max(results, key=lambda row: row['mem_util'])

print(f"Best FLOPS utilization: {best_flops['flops_util']:.2f}% at repeat={best_flops['repeat']}")
print(f"Best Memory utilization: {best_memory['mem_util']:.2f}% at repeat={best_memory['repeat']}")