
Fix StaticLayer.get_seq_length return type annotation (#45987) #46173

Open
Sanjays2402 wants to merge 1 commit into huggingface:main from Sanjays2402:fix/get-seq-length-return-type-45987

@Sanjays2402

Fixes #45987.

StaticLayer.cumulative_length is initialized as a shape-(1,) torch.Tensor (so that in-place .add_() updates stay torch.compile-friendly), and StaticLayer.get_seq_length() returns it directly. The current -> int annotation lies about that.
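
The mismatch can be sketched in a few lines. This is a minimal illustration, assuming a hypothetical `ToyStaticLayer` with an `update` helper; it is not the actual transformers class, only the pattern the paragraph above describes (a shape-(1,) counter updated in place and returned as-is).

```python
import torch


class ToyStaticLayer:
    """Hypothetical stand-in for illustration; not the real StaticLayer."""

    def __init__(self) -> None:
        # A shape-(1,) tensor rather than a Python int, so updates can be in-place.
        self.cumulative_length = torch.zeros(1, dtype=torch.long)

    def update(self, num_new_tokens: int) -> None:
        # In-place add keeps the same tensor object across steps.
        self.cumulative_length.add_(num_new_tokens)

    def get_seq_length(self):
        # Returns the tensor itself, not an int.
        return self.cumulative_length


layer = ToyStaticLayer()
layer.update(3)
layer.update(2)
seq_len = layer.get_seq_length()
print(type(seq_len).__name__, tuple(seq_len.shape))  # Tensor (1,)
```

Under these assumptions the caller receives a `Tensor`, which is what a `-> int` annotation would misreport.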

Per @Rocketknight1's comment on the issue:

> I think this is mostly just a type-hint bug, right? There isn't really much difference between a tensor with shape (1,) and a 0-dim tensor, they'll both broadcast with any other tensor and work with item() etc. Also, we generally want to avoid item() where possible because it causes a CUDA sync. So maybe just make the return type int | Tensor or something but leave the rest of the code alone?

This PR does exactly that with a minimal, annotation-only change:

  • CacheLayerMixin.get_seq_length abstract: `-> int` → `-> int | torch.Tensor`
  • StaticLayer.get_seq_length: `-> int` → `-> int | torch.Tensor`, plus a docstring note explaining why a tensor is returned.

No runtime behavior change. No .item() calls added. Other concrete get_seq_length overrides (DynamicLayer, StaticSlidingWindowLayer, QuantizedLayer, ...) already return ints and don't need touching.
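
A rough sketch of the resulting signatures, using hypothetical `Toy*` classes rather than the real transformers ones. The class bodies and the docstring wording are assumptions for illustration; only the `int | torch.Tensor` return annotation is taken from the description above.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import torch


class ToyCacheLayerMixin(ABC):
    """Hypothetical stand-in for the abstract base."""

    @abstractmethod
    def get_seq_length(self) -> int | torch.Tensor:
        ...


class ToyDynamicLayer(ToyCacheLayerMixin):
    """Overrides that already return ints can keep the narrower `-> int`."""

    def __init__(self) -> None:
        self.num_tokens = 0

    def get_seq_length(self) -> int:
        return self.num_tokens


class ToyStaticLayer(ToyCacheLayerMixin):
    def __init__(self) -> None:
        self.cumulative_length = torch.zeros(1, dtype=torch.long)

    def get_seq_length(self) -> int | torch.Tensor:
        """Return the cached length as a shape-(1,) tensor (no `.item()` call)."""
        return self.cumulative_length


dynamic_len = ToyDynamicLayer().get_seq_length()
static_len = ToyStaticLayer().get_seq_length()
```

Narrowing the return type in a subclass is allowed by type checkers, which is why the int-returning overrides need no edit in this sketch.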

The four earlier PRs for this issue (#46005, #46010, #46081, #45997) all attempted heavier rewrites that added .item() or restructured cumulative_length, which is the opposite of the requested fix. This PR sticks to the maintainer-specified shape.
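
To illustrate why `.item()` is unnecessary here, the following sketch compares a shape-(1,) tensor, a 0-dim tensor, and a plain int in the same arithmetic. The variable names are made up for the example, and it runs on CPU; the comment about synchronization describes what `.item()` does for a tensor that lives on a CUDA device.

```python
import torch

seq_len_1d = torch.tensor([5])  # shape (1,), like cumulative_length
seq_len_0d = torch.tensor(5)    # 0-dim tensor
positions = torch.arange(3)

# All three forms broadcast to the same result against another tensor.
from_1d = positions + seq_len_1d
from_0d = positions + seq_len_0d
from_int = positions + 5
print(torch.equal(from_1d, from_int), torch.equal(from_0d, from_int))  # True True

# .item() produces a Python number, but for a CUDA tensor it copies the value
# to the host and makes the CPU wait for the device, so it is avoided here.
as_int = seq_len_1d.item()
```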

Who can review?

@Rocketknight1 (you commented on #45987 with the requested approach)

`StaticLayer.cumulative_length` is a shape-(1,) `torch.Tensor`, so
`StaticLayer.get_seq_length()` actually returns a `Tensor`, not the
`int` its annotation promised. Per maintainer guidance on huggingface#45987,
calling `.item()` to coerce would force a CUDA sync and isn't worth it
when the tensor broadcasts identically to an int at every call site. The
right fix is just to relax the type annotation.

- Update `CacheLayerMixin.get_seq_length` abstract signature to
  `int | torch.Tensor`.
- Update `StaticLayer.get_seq_length` to match, with a docstring note
  explaining why a tensor is returned.
@github-actions

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46173&sha=ece191


Development

Successfully merging this pull request may close these issues.

[Bug] StaticCache.get_seq_length() returns shape-(1,) Tensor despite -> int contract
