From 94084a3e5d655c0cd1b625582ca533810bcd8b7b Mon Sep 17 00:00:00 2001 From: bgreni Date: Mon, 22 Apr 2024 18:49:45 -0600 Subject: [PATCH] Use __index__ for __getitem__ and __setitem__ When indexing stdlib containers we should accept a generic type that calls on the __index__ method to allow types other than Int to be used but doesn't allow Intable types that should not be used for such purposes (such as Float) Signed-off-by: Brian Grenier --- stdlib/src/builtin/bool.mojo | 16 ++++++++++- stdlib/src/builtin/builtin_list.mojo | 10 ++++--- stdlib/src/builtin/builtin_slice.mojo | 4 +-- stdlib/src/builtin/int.mojo | 2 +- stdlib/src/builtin/int_literal.mojo | 2 +- stdlib/src/builtin/range.mojo | 12 ++++---- stdlib/src/builtin/simd.mojo | 40 ++++++++++++++++++--------- stdlib/src/builtin/value.mojo | 10 +++++++ stdlib/src/collections/vector.mojo | 16 +++++------ stdlib/src/memory/unsafe.mojo | 12 ++++---- stdlib/src/python/object.mojo | 8 +++++- stdlib/src/utils/index.mojo | 8 +++--- stdlib/src/utils/static_tuple.mojo | 14 +++++----- stdlib/src/utils/stringref.mojo | 4 +-- 14 files changed, 102 insertions(+), 56 deletions(-) diff --git a/stdlib/src/builtin/bool.mojo b/stdlib/src/builtin/bool.mojo index fb2eaef30..eca9585a9 100644 --- a/stdlib/src/builtin/bool.mojo +++ b/stdlib/src/builtin/bool.mojo @@ -47,7 +47,12 @@ trait Boolable: @value @register_passable("trivial") struct Bool( - Stringable, CollectionElement, Boolable, EqualityComparable, Intable + Stringable, + CollectionElement, + Boolable, + EqualityComparable, + Intable, + Indexer, ): """The primitive Bool scalar value used in Mojo.""" @@ -261,6 +266,15 @@ struct Bool( ) ) + @always_inline("nodebug") + fn __index__(self) -> Int: + """Convert this Bool to an integer for indexing purposes + + Returns: + Bool as Int + """ + return self.__int__() + @always_inline fn bool(value: None) -> Bool: diff --git a/stdlib/src/builtin/builtin_list.mojo b/stdlib/src/builtin/builtin_list.mojo index ab5aab3bc..67c7d659b 100644 --- a/stdlib/src/builtin/builtin_list.mojo +++ b/stdlib/src/builtin/builtin_list.mojo @@ -138,7 +138,7 @@ struct VariadicList[type: AnyRegType](Sized): return __mlir_op.`pop.variadic.size`(self.value) @always_inline - fn __getitem__(self, index: Int) -> type: + fn __getitem__[indexer: Indexer](self, index: indexer) -> type: """Gets a single element on the variadic list. Args: @@ -147,7 +147,7 @@ struct VariadicList[type: AnyRegType](Sized): Returns: The element on the list corresponding to the given index. """ - return __mlir_op.`pop.variadic.get`(self.value, index.value) + return __mlir_op.`pop.variadic.get`(self.value, index.__index__().value) @always_inline fn __iter__(self) -> Self.IterType: @@ -344,7 +344,9 @@ struct VariadicListMem[ # TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol # allowing us to get rid of this and make foreach iteration clean. @always_inline - fn __getitem__(self, index: Int) -> Self.reference_type: + fn __getitem__[ + indexer: Indexer + ](self, index: indexer) -> Self.reference_type: """Gets a single element on the variadic list. Args: @@ -355,7 +357,7 @@ struct VariadicListMem[ given index. """ return Self.reference_type( - __mlir_op.`pop.variadic.get`(self.value, index.value) + __mlir_op.`pop.variadic.get`(self.value, index.__index__().value) ) @always_inline diff --git a/stdlib/src/builtin/builtin_slice.mojo b/stdlib/src/builtin/builtin_slice.mojo index b6da005a8..ae00b5558 100644 --- a/stdlib/src/builtin/builtin_slice.mojo +++ b/stdlib/src/builtin/builtin_slice.mojo @@ -155,7 +155,7 @@ struct Slice(Sized, Stringable, EqualityComparable): return len(range(self.start, self.end, self.step)) @always_inline - fn __getitem__(self, idx: Int) -> Int: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: """Get the slice index. Args: @@ -164,7 +164,7 @@ struct Slice(Sized, Stringable, EqualityComparable): Returns: The slice index. """ - return self.start + idx * self.step + return self.start + idx.__index__() * self.step @always_inline("nodebug") fn _has_end(self) -> Bool: diff --git a/stdlib/src/builtin/int.mojo b/stdlib/src/builtin/int.mojo index 982fbbc58..5523c645c 100644 --- a/stdlib/src/builtin/int.mojo +++ b/stdlib/src/builtin/int.mojo @@ -169,7 +169,7 @@ fn int[T: IntableRaising](value: T) raises -> Int: @lldb_formatter_wrapping_type @value @register_passable("trivial") -struct Int(Intable, Stringable, KeyElement, Boolable): +struct Int(Intable, Stringable, KeyElement, Boolable, Indexer): """This type represents an integer value.""" var value: __mlir_type.index diff --git a/stdlib/src/builtin/int_literal.mojo b/stdlib/src/builtin/int_literal.mojo index 8b9be688a..680a245f2 100644 --- a/stdlib/src/builtin/int_literal.mojo +++ b/stdlib/src/builtin/int_literal.mojo @@ -16,7 +16,7 @@ @value @nonmaterializable(Int) @register_passable("trivial") -struct IntLiteral(Intable, Stringable, Boolable, EqualityComparable): +struct IntLiteral(Intable, Stringable, Boolable, EqualityComparable, Indexer): """This type represents a static integer literal value with infinite precision. They can't be materialized at runtime and must be lowered to other integer types (like Int), but allow for diff --git a/stdlib/src/builtin/range.mojo b/stdlib/src/builtin/range.mojo index edb58126c..54c63a6fd 100644 --- a/stdlib/src/builtin/range.mojo +++ b/stdlib/src/builtin/range.mojo @@ -92,8 +92,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange): return self.curr @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return idx + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return idx.__index__() @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: @@ -121,8 +121,8 @@ struct _SequentialRange(Sized, ReversibleRange): return _max(0, self.end - self.start) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return self.start + idx + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return self.start + idx.__index__() @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: @@ -192,8 +192,8 @@ struct _StridedRange(Sized, ReversibleRange): return _div_ceil_positive(_abs(self.start - self.end), _abs(self.step)) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return self.start + idx * self.step + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return self.start + idx.__index__() * self.step @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index 70741969b..7cb6a1d8c 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -113,12 +113,7 @@ fn _unchecked_zero[type: DType, size: Int]() -> SIMD[type, size]: @lldb_formatter_wrapping_type @register_passable("trivial") struct SIMD[type: DType, size: Int = simdwidthof[type]()]( - Sized, - Intable, - CollectionElement, - Stringable, - Hashable, - Boolable, + Sized, Intable, CollectionElement, Stringable, Hashable, Boolable, Indexer ): """Represents a small vector that is backed by a hardware vector element. @@ -513,6 +508,19 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( rebind[Scalar[type]](self).value ) + @always_inline("nodebug") + fn __index__(self) -> Int: + """Returns the value as an int if it is an integral value + + Contraints: + Must be an integral value + + Returns: + The value as an integer + """ + constrained[type.is_integral(), "expected integral type"]() + return self.__int__() + @always_inline fn __str__(self) -> String: """Get the SIMD as a string. @@ -1551,7 +1559,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( # ===-------------------------------------------------------------------===# @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Scalar[type]: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]: """Gets an element from the vector. Args: @@ -1562,10 +1570,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( """ return __mlir_op.`pop.simd.extractelement`[ _type = __mlir_type[`!pop.scalar<`, type.value, `>`] - ](self.value, idx.value) + ](self.value, idx.__index__().value) @always_inline("nodebug") - fn __setitem__(inout self, idx: Int, val: Scalar[type]): + fn __setitem__[ + indexer: Indexer + ](inout self, idx: indexer, val: Scalar[type]): """Sets an element in the vector. Args: @@ -1573,12 +1583,16 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( val: The value to set. """ self.value = __mlir_op.`pop.simd.insertelement`( - self.value, val.value, idx.value + self.value, val.value, idx.__index__().value ) @always_inline("nodebug") - fn __setitem__( - inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`] + fn __setitem__[ + indexer: Indexer + ]( + inout self, + idx: indexer, + val: __mlir_type[`!pop.scalar<`, type.value, `>`], ): """Sets an element in the vector. @@ -1587,7 +1601,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( val: The value to set. """ self.value = __mlir_op.`pop.simd.insertelement`( - self.value, val, idx.value + self.value, val, idx.__index__().value ) fn __hash__(self) -> Int: diff --git a/stdlib/src/builtin/value.mojo b/stdlib/src/builtin/value.mojo index a1e706773..b87028aed 100644 --- a/stdlib/src/builtin/value.mojo +++ b/stdlib/src/builtin/value.mojo @@ -121,3 +121,13 @@ trait StringableCollectionElement(CollectionElement, Stringable): """ pass + + +trait Indexer: + fn __index__(self) -> Int: + """Return the index value + + Returns: + The index value of the object + """ + ... diff --git a/stdlib/src/collections/vector.mojo b/stdlib/src/collections/vector.mojo index 51a9ff963..9b22f283f 100644 --- a/stdlib/src/collections/vector.mojo +++ b/stdlib/src/collections/vector.mojo @@ -182,7 +182,7 @@ struct InlinedFixedVector[ return self.current_size @always_inline - fn __getitem__(self, i: Int) -> type: + fn __getitem__[indexer: Indexer](self, i: indexer) -> type: """Gets a vector element at the given index. Args: @@ -191,12 +191,12 @@ struct InlinedFixedVector[ Returns: The element at the given index. """ + var normalized_idx = i.__index__() debug_assert( - -self.current_size <= i < self.current_size, + -self.current_size <= normalized_idx < self.current_size, "index must be within bounds", ) - var normalized_idx = i - if i < 0: + if normalized_idx < 0: normalized_idx += len(self) if normalized_idx < Self.static_size: @@ -205,20 +205,20 @@ struct InlinedFixedVector[ return self.dynamic_data[normalized_idx - Self.static_size] @always_inline - fn __setitem__(inout self, i: Int, value: type): + fn __setitem__[indexer: Indexer](inout self, i: indexer, value: type): """Sets a vector element at the given index. Args: i: The index of the element. value: The value to assign. """ + var normalized_idx = i.__index__() debug_assert( - -self.current_size <= i < self.current_size, + -self.current_size <= normalized_idx < self.current_size, "index must be within bounds", ) - var normalized_idx = i - if i < 0: + if normalized_idx < 0: normalized_idx += len(self) if normalized_idx < Self.static_size: diff --git a/stdlib/src/memory/unsafe.mojo b/stdlib/src/memory/unsafe.mojo index 08843cae0..6ce026ee6 100644 --- a/stdlib/src/memory/unsafe.mojo +++ b/stdlib/src/memory/unsafe.mojo @@ -288,7 +288,7 @@ struct LegacyPointer[ ) @always_inline("nodebug") - fn __getitem__[T: Intable](self, offset: T) -> type: + fn __getitem__[T: Indexer](self, offset: T) -> type: """Loads the value the LegacyPointer object points to with the given offset. Parameters: @@ -300,10 +300,10 @@ struct LegacyPointer[ Returns: The loaded value. """ - return self.load(offset) + return self.load(offset.__index__()) @always_inline("nodebug") - fn __setitem__[T: Intable](self, offset: T, val: type): + fn __setitem__[T: Indexer](self, offset: T, val: type): """Stores the specified value to the location the LegacyPointer object points to with the given offset. @@ -314,7 +314,7 @@ struct LegacyPointer[ offset: The offset to store to. val: The value to store. """ - return self.store(offset, val) + return self.store(offset.__index__(), val) # ===------------------------------------------------------------------=== # # Load/Store @@ -727,7 +727,7 @@ struct DTypePointer[ return arg.get_legacy_pointer() @always_inline("nodebug") - fn __getitem__[T: Intable](self, offset: T) -> Scalar[type]: + fn __getitem__[T: Indexer](self, offset: T) -> Scalar[type]: """Loads a single element (SIMD of size 1) from the pointer at the specified index. @@ -740,7 +740,7 @@ struct DTypePointer[ Returns: The loaded value. """ - return self.load(offset) + return self.load(offset.__index__()) @always_inline("nodebug") fn __setitem__[T: Intable](self, offset: T, val: Scalar[type]): diff --git a/stdlib/src/python/object.mojo b/stdlib/src/python/object.mojo index 44bd922bc..151984b99 100644 --- a/stdlib/src/python/object.mojo +++ b/stdlib/src/python/object.mojo @@ -101,7 +101,13 @@ struct _PyIter(Sized): @register_passable struct PythonObject( - Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement + Intable, + Stringable, + SizedRaising, + Boolable, + CollectionElement, + KeyElement, + Indexer, ): """A Python object.""" diff --git a/stdlib/src/utils/index.mojo b/stdlib/src/utils/index.mojo index 163d2a700..657c36095 100644 --- a/stdlib/src/utils/index.mojo +++ b/stdlib/src/utils/index.mojo @@ -335,11 +335,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable): return size @always_inline("nodebug") - fn __getitem__[intable: Intable](self, index: intable) -> Int: + fn __getitem__[indexer: Indexer](self, index: indexer) -> Int: """Gets an element from the tuple by index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The element index. @@ -362,11 +362,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable): self.data.__setitem__[index](val) @always_inline("nodebug") - fn __setitem__[intable: Intable](inout self, index: intable, val: Int): + fn __setitem__[indexer: Indexer](inout self, index: indexer, val: Int): """Sets an element in the tuple at the given index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The element index. diff --git a/stdlib/src/utils/static_tuple.mojo b/stdlib/src/utils/static_tuple.mojo index 6b2e0d233..cb1e8e596 100644 --- a/stdlib/src/utils/static_tuple.mojo +++ b/stdlib/src/utils/static_tuple.mojo @@ -197,11 +197,11 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): self = tmp @always_inline("nodebug") - fn __getitem__[intable: Intable](self, index: intable) -> Self.element_type: + fn __getitem__[indexer: Indexer](self, index: indexer) -> Self.element_type: """Returns the value of the tuple at the given dynamic index. Parameters: - intable: The intable type. + indexer: The index type. Args: index: The index into the tuple. @@ -209,7 +209,7 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): Returns: The value at the specified position. """ - var offset = int(index) + var offset = index.__index__() debug_assert(offset < size, "index must be within bounds") # Copy the array so we can get its address, because we can't take the # address of 'self' in a non-mutating method. @@ -221,18 +221,18 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): @always_inline("nodebug") fn __setitem__[ - intable: Intable - ](inout self, index: intable, val: Self.element_type): + indexer: Indexer + ](inout self, index: indexer, val: Self.element_type): """Stores a single value into the tuple at the specified dynamic index. Parameters: - intable: The intable type. + indexer: The intable type. Args: index: The index into the tuple. val: The value to store. """ - var offset = int(index) + var offset = index.__index__() debug_assert(offset < size, "index must be within bounds") var tmp = self var ptr = __mlir_op.`pop.array.gep`( diff --git a/stdlib/src/utils/stringref.mojo b/stdlib/src/utils/stringref.mojo index 9187ad8f0..35f5e96b8 100644 --- a/stdlib/src/utils/stringref.mojo +++ b/stdlib/src/utils/stringref.mojo @@ -194,7 +194,7 @@ struct StringRef( return not (self == rhs) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> StringRef: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> StringRef: """Get the string value at the specified position. Args: @@ -203,7 +203,7 @@ struct StringRef( Returns: The character at the specified position. """ - return StringRef {data: self.data + idx, length: 1} + return StringRef {data: self.data + idx.__index__(), length: 1} fn __hash__(self) -> Int: """Hash the underlying buffer using builtin hash.