Conversation

@kylesayrs kylesayrs commented Oct 6, 2025

Purpose

Given an attention state of shape (batch_size, num_heads, seq_len, head_dim), the attention head strategy generates scales of shape (num_heads, 1, 1).
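
A minimal sketch (illustrative only, not the library implementation) of how scales of that shape could be derived, assuming a symmetric int8 scheme:

import torch

def per_head_scales(attn_state: torch.Tensor, quant_max: float = 127.0) -> torch.Tensor:
    # reduce over batch, sequence, and head_dim, keeping one value per head
    batch_size, num_heads, seq_len, head_dim = attn_state.shape
    abs_max = attn_state.abs().amax(dim=(0, 2, 3))  # shape: (num_heads,)
    return (abs_max / quant_max).reshape(num_heads, 1, 1)  # shape: (num_heads, 1, 1)

state = torch.randn(2, 8, 16, 64)  # (batch_size, num_heads, seq_len, head_dim)
assert per_head_scales(state).shape == (8, 1, 1)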

Prerequisites

Changes

  • Add attention head quantization strategy
  • Fix shapes of per-tensor attention flattening
  • Elaborate on attention calibration tests

Testing

  • Added attention head quantization test and validated that the generated scales and zero points make sense (see the sketch below)
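
A hedged sketch of what such a sanity check could look like (the helper logic and the asymmetric int8 math are illustrative assumptions, not copied from the added test):

import torch

def check_attn_head_quant_params():
    batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
    state = torch.randn(batch_size, num_heads, seq_len, head_dim)

    # per-head min/max over batch, sequence, and head_dim
    mins = state.amin(dim=(0, 2, 3))
    maxs = state.amax(dim=(0, 2, 3))

    # asymmetric int8 parameters, one per head
    scales = (maxs - mins) / 255.0
    zero_points = torch.round(-128 - mins / scales).to(torch.int)

    assert scales.shape == (num_heads,) and zero_points.shape == (num_heads,)
    assert torch.all(scales > 0)
    assert torch.all((zero_points >= -128) & (zero_points <= 127))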

@kylesayrs kylesayrs changed the base branch from main to kylesayrs/refactor-initialize-tests October 6, 2025 22:22
@kylesayrs kylesayrs force-pushed the kylesayrs/add-attn-head-strat branch 2 times, most recently from bf00a99 to 2ea692d Compare October 6, 2025 22:28
@kylesayrs kylesayrs force-pushed the kylesayrs/refactor-initialize-tests branch from 97a4d16 to 0fdfbd1 Compare October 6, 2025 22:31
@kylesayrs kylesayrs force-pushed the kylesayrs/add-attn-head-strat branch from 2ea692d to 326f802 Compare October 6, 2025 22:31
@kylesayrs kylesayrs force-pushed the kylesayrs/refactor-initialize-tests branch from 0fdfbd1 to 8973328 Compare October 7, 2025 22:05
@kylesayrs kylesayrs force-pushed the kylesayrs/add-attn-head-strat branch 2 times, most recently from 70da261 to 48875e2 Compare October 7, 2025 22:13
@kylesayrs kylesayrs marked this pull request as ready for review October 7, 2025 22:15
@brian-dellabetta brian-dellabetta left a comment

How are we expecting users to use QuantizationStrategy.ATTN_HEAD in a recipe? If I'm understanding correctly, it would look something like this?

quant_stage:
  quant_modifiers:
    QuantizationModifier:
      config_groups:
        group0:
          targets: ["re:.*self_attn$"]
          weights:
            strategy: attn_head
            ...
        group1:
          targets: ["re:.*(q|k|v)_proj$"]
          weights:
            strategy: group
            ...

kylesayrs commented Oct 7, 2025

@brian-dellabetta I’ve decided that giving per-attention-head quantization its own strategy (rather than reusing group) makes more sense.

quant_stage:
  quant_modifiers:
    QuantizationModifier:
      config_groups:
        group0:
          targets: ["re:.*self_attn$"]
          input_activations:
            strategy: attn_head
            ...

@brian-dellabetta brian-dellabetta left a comment

Overall format LGTM, but I'm struggling to understand how we're arriving at some of these expected_shapes.

@kylesayrs kylesayrs force-pushed the kylesayrs/add-attn-head-strat branch from 48875e2 to e1ca4fd Compare October 8, 2025 18:44
@brian-dellabetta brian-dellabetta left a comment

Thanks for updating!

Base automatically changed from kylesayrs/refactor-initialize-tests to main October 9, 2025 13:20
@kylesayrs kylesayrs dismissed brian-dellabetta’s stale review October 9, 2025 13:20

The base branch was changed.

@kylesayrs kylesayrs changed the base branch from main to transform_arg_support October 9, 2025 13:21
@kylesayrs kylesayrs changed the base branch from transform_arg_support to main October 9, 2025 13:21
@kylesayrs kylesayrs force-pushed the kylesayrs/add-attn-head-strat branch from d084c5e to e3f24d4 Compare October 9, 2025 14:19
@fynnsu fynnsu left a comment

Looks good! Do we need logic somewhere to reverse flatten_attention_for_quantization? It seems important to ensure that the unflattening process is implemented in parallel with the flattening function.

@kylesayrs

@fynnsu The inverse function would require extra metadata (for example, unflattening (batch_size * seq_len) requires knowing either batch_size or seq_len).

Calibration only requires the forward function. Implementing the backward function would let us share the utility across calibration and the quantization forward pass. That might be nice for standardization and potentially faster runtime, but it isn't a high priority right now.
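
An illustrative sketch of the point (the function names and exact target shape are assumptions, not the library's flatten_attention_for_quantization): once the batch and sequence dims are merged, the inverse needs one of them passed back in.

import torch

def flatten_attention(state: torch.Tensor) -> torch.Tensor:
    # (batch_size, num_heads, seq_len, head_dim) -> (batch_size * seq_len, num_heads, head_dim)
    batch_size, num_heads, seq_len, head_dim = state.shape
    return state.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_dim)

def unflatten_attention(flat: torch.Tensor, batch_size: int) -> torch.Tensor:
    # batch_size (or seq_len) cannot be recovered from the flattened shape alone,
    # so it has to be carried alongside the tensor as extra metadata
    merged, num_heads, head_dim = flat.shape
    seq_len = merged // batch_size
    return flat.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)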

@kylesayrs kylesayrs merged commit 3e4f164 into main Oct 9, 2025
2 checks passed
@kylesayrs kylesayrs deleted the kylesayrs/add-attn-head-strat branch October 9, 2025 20:11