Skip to content

Commit

Permalink
Fix mypy errors in flash_attention
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke committed Oct 10, 2023
1 parent acb698e commit 3a7000e
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions jax/experimental/pallas/ops/tpu/flash_attention.py
Expand Up @@ -671,6 +671,8 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major))
if ab is not None else None)

q_segment_ids_spec = kv_segment_ids_spec = None
q_segment_ids = kv_segment_ids = None
if segment_ids is not None:

def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
Expand Down Expand Up @@ -714,9 +716,6 @@ def kv_segment_ids_index_map(
2,
),
)
else:
q_segment_ids_spec = kv_segment_ids_spec = None
q_segment_ids = kv_segment_ids = None

in_specs = [
pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)),
Expand Down Expand Up @@ -988,6 +987,8 @@ def ab_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
else None
)

q_segment_ids_spec = kv_segment_ids_spec = None
q_segment_ids = kv_segment_ids = None
if segment_ids is not None:

def q_segment_ids_index_map(
Expand Down Expand Up @@ -1033,9 +1034,6 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
2,
),
)
else:
q_segment_ids_spec = kv_segment_ids_spec = None
q_segment_ids = kv_segment_ids = None

in_specs = [
qo_spec,
Expand Down Expand Up @@ -1334,6 +1332,8 @@ def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
else None
)

q_segment_ids_spec = kv_segment_ids_spec = None
q_segment_ids = kv_segment_ids = None
if segment_ids is not None:

def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
Expand Down Expand Up @@ -1381,9 +1381,6 @@ def kv_segment_ids_index_map(
2,
),
)
else:
q_segment_ids_spec = kv_segment_ids_spec = None
q_segment_ids = kv_segment_ids = None

in_specs = [
qo_spec,
Expand Down

0 comments on commit 3a7000e

Please sign in to comment.