From 3a7000ee1e111b03551dc4377092134a0203e382 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 10 Oct 2023 13:05:48 +0000 Subject: [PATCH] Fix mypy errors in flash_attention --- .../pallas/ops/tpu/flash_attention.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 6c7f791e88b6..7fbf006626d6 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -671,6 +671,8 @@ def lm_index_map(batch_index, head_index, q_seq_index, _): pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major)) if ab is not None else None) + q_segment_ids_spec = kv_segment_ids_spec = None + q_segment_ids = kv_segment_ids = None if segment_ids is not None: def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _): @@ -714,9 +716,6 @@ def kv_segment_ids_index_map( 2, ), ) - else: - q_segment_ids_spec = kv_segment_ids_spec = None - q_segment_ids = kv_segment_ids = None in_specs = [ pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)), @@ -988,6 +987,8 @@ def ab_index_map(batch_index, head_index, kv_seq_index, q_seq_index): else None ) + q_segment_ids_spec = kv_segment_ids_spec = None + q_segment_ids = kv_segment_ids = None if segment_ids is not None: def q_segment_ids_index_map( @@ -1033,9 +1034,6 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _): 2, ), ) - else: - q_segment_ids_spec = kv_segment_ids_spec = None - q_segment_ids = kv_segment_ids = None in_specs = [ qo_spec, @@ -1334,6 +1332,8 @@ def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index): else None ) + q_segment_ids_spec = kv_segment_ids_spec = None + q_segment_ids = kv_segment_ids = None if segment_ids is not None: def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _): @@ -1381,9 +1381,6 @@ def kv_segment_ids_index_map( 2, ), ) - else: - q_segment_ids_spec = kv_segment_ids_spec = None - q_segment_ids = kv_segment_ids = None in_specs = [ qo_spec,