Skip to content

cache: store StaticLayer.cumulative_length as a 0-dim scalar tensor#45997

Open
joaquinhuigomez wants to merge 1 commit into
huggingface:mainfrom
joaquinhuigomez:fix/static-cache-seq-length-scalar
Open

cache: store StaticLayer.cumulative_length as a 0-dim scalar tensor#45997
joaquinhuigomez wants to merge 1 commit into
huggingface:mainfrom
joaquinhuigomez:fix/static-cache-seq-length-scalar

Conversation

@joaquinhuigomez
Copy link
Copy Markdown
Contributor

StaticLayer.cumulative_length was initialised as torch.tensor([0]) (shape-(1,)), so StaticCache.get_seq_length() returned a shape-(1,) tensor instead of a value consistent with DynamicCache.get_seq_length(), which returns a plain int. The two cache types weren't safely interchangeable, and downstream int - past_len arithmetic promoted to shape-(1,) tensors that can propagate into slicing logic.

Storing the cumulative length as a 0-dim scalar tensor preserves the compile-friendly tensor semantics the existing comment calls out (the value still mutates in-place via add_() and is still attached to the static address for torch.compile), but makes arithmetic against ints behave the same as with DynamicCache.

Fixes #45987

StaticLayer.cumulative_length was initialised as torch.tensor([0])
(shape-(1,)), so StaticCache.get_seq_length() returned a shape-(1,)
tensor instead of a value consistent with DynamicCache.get_seq_length(),
which returns a plain int. The two cache types weren't safely
interchangeable, and downstream 'int - past_len' arithmetic
promoted to shape-(1,) tensors that propagated into slicing logic.

Storing the cumulative length as a 0-dim scalar tensor preserves the
compile-friendly tensor semantics that the existing comment calls
out (the value still mutates in-place via add_(), and is still
attached to the static address for torch.compile), but makes
arithmetic against ints behave the same as with DynamicCache.

Fixes huggingface#45987
@github-actions
Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45997&sha=519287

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] StaticCache.get_seq_length() returns shape-(1,) Tensor despite -> int contract

1 participant