cleanup test scripts
lucidrains committed Apr 10, 2024
1 parent 66d2643 commit b539d13
Showing 4 changed files with 21 additions and 92 deletions.
8 changes: 6 additions & 2 deletions assert.py
@@ -38,6 +38,7 @@ def start(
     causal,
     striped_ring_attn,
     dim,
+    dim_head,
     use_cuda,
     compare_regular_attn
 ):
@@ -51,7 +52,7 @@ def start(
         dim = dim,
         causal = causal,
         depth = 2,
-        dim_head = 64,
+        dim_head = dim_head,
         ring_attn = True,
         striped_ring_attn = striped_ring_attn,
         ring_seq_size = ring_seq_size,
@@ -63,7 +64,7 @@ def start(
         dim = dim,
         causal = causal,
         depth = 2,
-        dim_head = 64,
+        dim_head = dim_head,
         ring_attn = False,
         ring_seq_size = ring_seq_size,
         bucket_size = bucket_size,
@@ -142,6 +143,7 @@ def start(
 @click.option('--num-buckets', default = 2, help = 'number of buckets per machine (each sharded sequence is further windowed for flash attention to achieve even greater context lengths)')
 @click.option('--seq-len', default = 31, help = 'sequence length to test')
 @click.option('--model-dim', default = 8, help = 'model dimensions for testing')
+@click.option('--dim-head', default = 16, help = 'attention head dimension')
 @click.option('--compare-regular-attn', is_flag = True, help = 'compare ring to regular attention')
 def test(
     world_size: int,
@@ -154,6 +156,7 @@ def test(
     num_buckets: int,
     seq_len: int,
     model_dim: int,
+    dim_head: int,
     compare_regular_attn: bool
 ):
     assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must be less than the number of cuda devices {torch.cuda.device_count()}'
@@ -170,6 +173,7 @@ def test(
             causal,
             striped_ring_attn,
             model_dim,
+            dim_head,
             use_cuda,
             compare_regular_attn
         ),
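For orientation, here is a minimal, self-contained sketch of the pattern this change follows in assert.py: a --dim-head click option is threaded through test() and into the arguments of each spawned worker, so every attention network in a worker is built with the same head dimension. The option names and the toy start() body below are illustrative, not code from the repository.

import click
import torch.multiprocessing as mp

def start(rank, world_size, dim, dim_head):
    # a worker would construct its ring / flash attention networks here,
    # both receiving the shared dim_head value
    print(f'rank {rank}: dim = {dim}, dim_head = {dim_head}')

@click.command()
@click.option('--world-size', default = 2)
@click.option('--model-dim', default = 8)
@click.option('--dim-head', default = 16)
def test(world_size: int, model_dim: int, dim_head: int):
    mp.spawn(
        start,
        args = (world_size, model_dim, dim_head),
        nprocs = world_size,
        join = True
    )

if __name__ == '__main__':
    test()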
12 changes: 8 additions & 4 deletions assert_attn.py
@@ -38,6 +38,7 @@ def start(
     causal,
     striped_ring_attn,
     dim,
+    dim_head,
     use_cuda,
     compare_regular_attn
 ):
@@ -49,7 +50,7 @@ def start(
     ring_attention = RingAttention(
         dim = dim,
         causal = causal,
-        dim_head = 8,
+        dim_head = dim_head,
         ring_attn = True,
         striped_ring_attn = striped_ring_attn,
         ring_seq_size = ring_seq_size,
@@ -61,7 +62,7 @@ def start(
     flash_attention = RingAttention(
         dim = dim,
         causal = causal,
-        dim_head = 8,
+        dim_head = dim_head,
         ring_attn = False,
         ring_seq_size = ring_seq_size,
         bucket_size = bucket_size,
@@ -80,8 +81,8 @@ def start(

     if use_cuda:
         seq = seq.cuda(rank)
-        flash_attention_net.cuda(rank)
-        ring_attention_net.cuda(rank)
+        flash_attention.cuda(rank)
+        ring_attention.cuda(rank)

     # separate inputs for ring vs flash

@@ -144,6 +145,7 @@ def start(
 @click.option('--num-buckets', default = 2, help = 'number of buckets per machine (each sharded sequence is further windowed for flash attention to achieve even greater context lengths)')
 @click.option('--seq-len', default = 31, help = 'sequence length to test')
 @click.option('--model-dim', default = 8, help = 'model dimensions for testing')
+@click.option('--dim-head', default = 16, help = 'attention head dimension')
 @click.option('--compare-regular-attn', is_flag = True, help = 'compare ring to regular attention')
 def test(
     world_size: int,
@@ -156,6 +158,7 @@ def test(
     num_buckets: int,
     seq_len: int,
     model_dim: int,
+    dim_head: int,
     compare_regular_attn: bool
 ):
     assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must be less than the number of cuda devices {torch.cuda.device_count()}'
@@ -172,6 +175,7 @@ def test(
             causal,
             striped_ring_attn,
             model_dim,
+            dim_head,
             use_cuda,
             compare_regular_attn
         ),
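Besides the dim_head plumbing, this file's diff also corrects the .cuda() calls: the modules in assert_attn.py are named flash_attention and ring_attention, so the old *_net names (carried over from assert.py) would raise a NameError as soon as --use-cuda is passed. A small sketch of that device-placement step, with the surrounding use_cuda flag, rank and helper name assumed for illustration:

import torch
import torch.nn as nn

def place_on_device(seq: torch.Tensor, flash_attention: nn.Module, ring_attention: nn.Module, rank: int, use_cuda: bool) -> torch.Tensor:
    # move the test sequence and both attention modules onto the worker's GPU
    if use_cuda:
        seq = seq.cuda(rank)
        flash_attention.cuda(rank)  # was flash_attention_net.cuda(rank), an undefined name in this file
        ring_attention.cuda(rank)   # was ring_attention_net.cuda(rank), an undefined name in this file
    return seq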
10 changes: 7 additions & 3 deletions assert_flash.py
@@ -12,19 +12,23 @@
 @click.command()
 @click.option('--causal', is_flag = True)
 @click.option('--seq-len', default = 62)
+@click.option('--dim-head', default = 16)
+@click.option('--heads', default = 2)
 @click.option('--bucket_size', default = 4)
 @click.option('--flash-cuda-kernel', is_flag = True)
 def test(
     causal: bool,
     seq_len: int,
+    dim_head: int,
+    heads: int,
     bucket_size: int,
     flash_cuda_kernel: bool
 ):
     # base qkv

-    q = torch.randn(2, seq_len, 2, 16)
-    k = torch.randn(2, seq_len, 2, 16)
-    v = torch.randn(2, seq_len, 2, 16)
+    q = torch.randn(2, seq_len, heads, dim_head)
+    k = torch.randn(2, seq_len, heads, dim_head)
+    v = torch.randn(2, seq_len, heads, dim_head)

     # flash and regular qkv's

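The net effect in assert_flash.py is that q, k and v follow a (batch, seq_len, heads, dim_head) layout whose last two sizes now come from --heads and --dim-head instead of being hardcoded to 2 and 16. A quick sketch using the defaults shown above (the batch size of 2 is kept from the original script):

import torch

batch, seq_len, heads, dim_head = 2, 62, 2, 16  # defaults of the options above

q = torch.randn(batch, seq_len, heads, dim_head)
k = torch.randn(batch, seq_len, heads, dim_head)
v = torch.randn(batch, seq_len, heads, dim_head)

assert q.shape == k.shape == v.shape == (2, 62, 2, 16)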
83 changes: 0 additions & 83 deletions ring_attention_pytorch/triton_flash_attn.py
@@ -540,89 +540,6 @@ def _bwd_kernel(
             BLOCK_N=BLOCK_N,
         )

-
-def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
-    # shape constraints
-    batch, seqlen_q, nheads, d = q.shape
-    _, seqlen_k, _, _ = k.shape
-    assert k.shape == (batch, seqlen_k, nheads, d)
-    assert v.shape == (batch, seqlen_k, nheads, d)
-    assert d <= 128, "FlashAttention only support head dimensions up to 128"
-    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
-    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.is_cuda and k.is_cuda and v.is_cuda
-    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
-
-    has_bias = bias is not None
-    bias_type = "none"
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-        assert bias.dim() == 4
-        if bias.stride(-1) != 1:
-            bias = bias.contiguous()
-        if bias.shape[2:] == (1, seqlen_k):
-            bias_type = "vector"
-        elif bias.shape[2:] == (seqlen_q, seqlen_k):
-            bias_type = "matrix"
-        else:
-            raise RuntimeError(
-                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
-            )
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
-    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    o = torch.empty_like(q)
-
-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    BLOCK = 128
-    num_warps = 4 if d <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
-    _fwd_kernel[grid](
-        q,
-        k,
-        v,
-        bias,
-        o,
-        lse,
-        tmp,
-        softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
-        *bias_strides,
-        o.stride(0),
-        o.stride(2),
-        o.stride(1),
-        nheads,
-        seqlen_q,
-        seqlen_k,
-        seqlen_q_rounded,
-        d,
-        seqlen_q // 32,
-        seqlen_k // 32, # key for triton cache (limit number of compilations)
-        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
-        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        bias_type,
-        causal,
-        BLOCK_HEADDIM,
-        BLOCK_M=BLOCK,
-        BLOCK_N=BLOCK,
-        num_warps=num_warps,
-        num_stages=1,
-    )
-    return o, lse, softmax_scale # softmax_scale could have been updated
-
-
 def _flash_attn_backward(
     do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, causal_mask_diagonal=False, softmax_scale=None
 ):
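One detail worth recording from the removed _flash_attn_forward wrapper is its default attention scale, softmax_scale = 1 / sqrt(d), where d is the head dimension; anything that previously relied on the wrapper would now have to compute this itself. A worked example of that default (the head dimension of 64 is purely illustrative):

import math

d = 64                              # illustrative head dimension
softmax_scale = 1.0 / math.sqrt(d)  # the wrapper's default when no scale was given
assert softmax_scale == 0.125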
