Skip to content

Conversation

@mctigger
Copy link
Owner

Fix TensorContainer Indexing Logic and Enhance Type Safety

Summary

This pull request addresses critical indexing bugs in the TensorContainer system and introduces comprehensive type safety improvements. The changes resolve issues with __setitem__ operations, enhance ellipsis handling, and significantly expand test coverage for all PyTorch indexing patterns.

Key Changes

Type Safety

  • Added IndexType type alias covering all PyTorch indexing patterns
  • Consolidated method signatures using IndexType for consistent type annotations
  • Simplified TensorDict indexing overloads

Bug Fixes

  • Fixed __setitem__ logic for non-tuple indices causing assignment failures
  • Corrected ellipsis processing order in conditional structure
  • Removed problematic scalar promotion interfering with PyTorch's cross-device assignment behavior

Documentation

  • Added comprehensive docstrings for get_number_of_consuming_dims and transform_ellipsis_index
  • Included examples explaining ellipsis expansion and batch/event dimension semantics
  • Documented boolean indexing behavior and dimension consumption

Testing

  • Expanded test coverage with systematic categories: basic, advanced, boolean, multi-dimensional, and edge cases
  • Added cross-device assignment tests demonstrating PyTorch's native behavior
  • Improved test parametrization and validation logic

Backward Compatibility

All changes maintain backward compatibility. The modifications improve correctness and type safety without breaking existing functionality.

Tim Joseph added 7 commits September 11, 2025 19:54
The __setitem__ method in TensorContainer had the indexing loop incorrectly indented inside the tuple check, causing it to skip for non-tuple indices. Moved the loop outside the if block to ensure it runs for all index types.

Expanded test cases in test_setitem.py to cover basic, advanced, boolean, multi-dimensional, and edge-case indices for more comprehensive validation.
…r type safety

- Add IndexType union type to cover all PyTorch indexing patterns
- Update __getitem__ and __setitem__ to use IndexType instead of Any
- Fix isinstance check in __setitem__ to use original index variable
Add comprehensive tests for TensorDict slice assignments across different devices (CPU to CUDA and vice versa), demonstrating PyTorch's automatic device transfer for slice operations while highlighting stricter requirements for boolean mask assignments. This ensures device consistency and validates expected behavior in cross-device scenarios.
…ype for type safety

Simplify __getitem__ and __setitem__ method signatures by replacing multiple overloads with a unified IndexType, improving code maintainability and reducing redundancy while preserving existing functionality. This change builds on prior fixes to indexing logic.
…e assignments

There seems to be no reason, why we would need scalar promotion here.
…indexing methods

Added detailed Google-style docstrings to `get_number_of_consuming_dims` and `transform_ellipsis_index`.
These docstrings clarify the purpose, arguments, returns, and behavior of these critical indexing methods.
`transform_ellipsis_index` is particularly important for ensuring consistent ellipsis expansion across
tensors within the container, and its documentation now includes a comprehensive example
to illustrate its necessity and function. Inline comments were also added to
`transform_ellipsis_index` to explain the step-by-step logic.
@mctigger mctigger changed the title Fix setitem Fix __setitem__ in TensorContainer Sep 12, 2025
The `types.EllipsisType` was introduced in Python 3.10. This change ensures that `EllipsisType` is correctly defined and available when running on Python 3.9 environments, preventing import errors.
@mctigger mctigger merged commit aa46357 into main Sep 12, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants