From 06e9052ae11dbb8d071b437b9415a1c9e211ec77 Mon Sep 17 00:00:00 2001 From: bgreni Date: Mon, 22 Apr 2024 18:49:45 -0600 Subject: [PATCH] Use __index__ for __getitem__ and __setitem__ When indexing stdlib containers we should accept a generic type that calls on the __index__ method to allow types other than Int to be used but doesn't allow Intable types that should not be used for such purposes (such as Float) Signed-off-by: Brian Grenier --- stdlib/src/builtin/bool.mojo | 21 +++++-- stdlib/src/builtin/builtin_list.mojo | 18 ++++-- stdlib/src/builtin/builtin_slice.mojo | 7 ++- stdlib/src/builtin/int.mojo | 1 + stdlib/src/builtin/int_literal.mojo | 1 + stdlib/src/builtin/range.mojo | 12 ++-- stdlib/src/builtin/simd.mojo | 46 +++++++++++--- stdlib/src/builtin/string.mojo | 14 +++-- stdlib/src/builtin/value.mojo | 23 +++++++ stdlib/src/collections/list.mojo | 77 +++++++++++++++++------- stdlib/src/collections/vector.mojo | 22 ++++--- stdlib/src/memory/unsafe.mojo | 16 ++--- stdlib/src/python/object.mojo | 8 ++- stdlib/src/utils/index.mojo | 8 +-- stdlib/src/utils/inlined_string.mojo | 7 ++- stdlib/src/utils/static_tuple.mojo | 18 +++--- stdlib/src/utils/stringref.mojo | 7 ++- stdlib/test/builtin/test_list.mojo | 2 + stdlib/test/builtin/test_object.mojo | 8 +++ stdlib/test/builtin/test_slice.mojo | 8 +++ stdlib/test/builtin/test_string.mojo | 3 + stdlib/test/builtin/test_stringref.mojo | 7 +++ stdlib/test/collections/test_list.mojo | 9 +++ stdlib/test/collections/test_vector.mojo | 5 ++ stdlib/test/utils/test_tuple.mojo | 3 + 25 files changed, 264 insertions(+), 87 deletions(-) diff --git a/stdlib/src/builtin/bool.mojo b/stdlib/src/builtin/bool.mojo index 70cb1c1e9..64aeee7f0 100644 --- a/stdlib/src/builtin/bool.mojo +++ b/stdlib/src/builtin/bool.mojo @@ -57,7 +57,12 @@ trait Boolable: @value @register_passable("trivial") struct Bool( - Stringable, CollectionElement, Boolable, EqualityComparable, Intable + Stringable, + CollectionElement, + Boolable, + EqualityComparable, + Intable, + Indexer, ): """The primitive Bool scalar value used in Mojo.""" @@ -299,10 +304,18 @@ struct Bool( """ return lhs ^ self + # ===----------------------------------------------------------------------=== # + # bool + # ===----------------------------------------------------------------------=== # -# ===----------------------------------------------------------------------=== # -# bool -# ===----------------------------------------------------------------------=== # + @always_inline("nodebug") + fn __index__(self) -> Int: + """Convert this Bool to an integer for indexing purposes + + Returns: + Bool as Int + """ + return self.__int__() @always_inline diff --git a/stdlib/src/builtin/builtin_list.mojo b/stdlib/src/builtin/builtin_list.mojo index 6888e978d..4e48a431a 100644 --- a/stdlib/src/builtin/builtin_list.mojo +++ b/stdlib/src/builtin/builtin_list.mojo @@ -138,16 +138,19 @@ struct VariadicList[type: AnyRegType](Sized): return __mlir_op.`pop.variadic.size`(self.value) @always_inline - fn __getitem__(self, index: Int) -> type: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> type: """Gets a single element on the variadic list. + Parameters: + indexer: The type of the indexing value. + Args: - index: The index of the element to access on the list. + idx: The index of the element to access on the list. Returns: The element on the list corresponding to the given index. """ - return __mlir_op.`pop.variadic.get`(self.value, index.value) + return __mlir_op.`pop.variadic.get`(self.value, index(idx).value) @always_inline fn __iter__(self) -> Self.IterType: @@ -344,18 +347,21 @@ struct VariadicListMem[ # TODO: Fix for loops + _VariadicListIter to support a __nextref__ protocol # allowing us to get rid of this and make foreach iteration clean. @always_inline - fn __getitem__(self, index: Int) -> Self.reference_type: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Self.reference_type: """Gets a single element on the variadic list. + Parameters: + indexer: The type of the indexing value. + Args: - index: The index of the element to access on the list. + idx: The index of the element to access on the list. Returns: A low-level pointer to the element on the list corresponding to the given index. """ return Self.reference_type( - __mlir_op.`pop.variadic.get`(self.value, index.value) + __mlir_op.`pop.variadic.get`(self.value, index(idx).value) ) @always_inline diff --git a/stdlib/src/builtin/builtin_slice.mojo b/stdlib/src/builtin/builtin_slice.mojo index ccb8ec01f..6e240cc24 100644 --- a/stdlib/src/builtin/builtin_slice.mojo +++ b/stdlib/src/builtin/builtin_slice.mojo @@ -149,16 +149,19 @@ struct Slice(Sized, Stringable, EqualityComparable): return len(range(self.start, self.end, self.step)) @always_inline - fn __getitem__(self, idx: Int) -> Int: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: """Get the slice index. + Parameters: + indexer: The type of the indexing value. + Args: idx: The index. Returns: The slice index. """ - return self.start + idx * self.step + return self.start + index(idx) * self.step @always_inline("nodebug") fn _has_end(self) -> Bool: diff --git a/stdlib/src/builtin/int.mojo b/stdlib/src/builtin/int.mojo index 662ee0b33..918ddbe49 100644 --- a/stdlib/src/builtin/int.mojo +++ b/stdlib/src/builtin/int.mojo @@ -204,6 +204,7 @@ struct Int( KeyElement, Roundable, Stringable, + Indexer, ): """This type represents an integer value.""" diff --git a/stdlib/src/builtin/int_literal.mojo b/stdlib/src/builtin/int_literal.mojo index 7a571d31d..6ab8b1e9b 100644 --- a/stdlib/src/builtin/int_literal.mojo +++ b/stdlib/src/builtin/int_literal.mojo @@ -28,6 +28,7 @@ struct IntLiteral( Intable, Roundable, Stringable, + Indexer, ): """This type represents a static integer literal value with infinite precision. They can't be materialized at runtime and diff --git a/stdlib/src/builtin/range.mojo b/stdlib/src/builtin/range.mojo index 7d58fff84..bc4709696 100644 --- a/stdlib/src/builtin/range.mojo +++ b/stdlib/src/builtin/range.mojo @@ -82,8 +82,8 @@ struct _ZeroStartingRange(Sized, ReversibleRange): return self.curr @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return idx + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return index(idx) @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: @@ -113,8 +113,8 @@ struct _SequentialRange(Sized, ReversibleRange): return self.end - self.start if self.start < self.end else 0 @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return self.start + idx + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return self.start + index(idx) @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: @@ -185,8 +185,8 @@ struct _StridedRange(Sized, ReversibleRange): return _div_ceil_positive(abs(self.start - self.end), abs(self.step)) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Int: - return self.start + idx * self.step + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Int: + return self.start + index(idx) * self.step @always_inline("nodebug") fn __reversed__(self) -> _StridedRangeIterator: diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index 46fef4559..7b681206b 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -132,6 +132,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( Roundable, Sized, Stringable, + Indexer, ): """Represents a small vector that is backed by a hardware vector element. @@ -502,6 +503,22 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( rebind[Scalar[type]](self).value ) + @always_inline("nodebug") + fn __index__(self) -> Int: + """Returns the value as an int if it is an integral value + + Contraints: + Must be an integral value + + Returns: + The value as an integer + """ + constrained[ + type.is_integral() or type.is_bool(), + "expected integral or bool type", + ]() + return self.__int__() + @always_inline fn __str__(self) -> String: """Get the SIMD as a string. @@ -1742,9 +1759,12 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( # ===-------------------------------------------------------------------===# @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> Scalar[type]: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Scalar[type]: """Gets an element from the vector. + Parameters: + indexer: The type of the indexing value. + Args: idx: The element index. @@ -1753,32 +1773,44 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()]( """ return __mlir_op.`pop.simd.extractelement`[ _type = __mlir_type[`!pop.scalar<`, type.value, `>`] - ](self.value, idx.value) + ](self.value, index(idx).value) @always_inline("nodebug") - fn __setitem__(inout self, idx: Int, val: Scalar[type]): + fn __setitem__[ + indexer: Indexer + ](inout self, idx: indexer, val: Scalar[type]): """Sets an element in the vector. + Parameters: + indexer: The type of the indexing value. + Args: idx: The index to set. val: The value to set. """ self.value = __mlir_op.`pop.simd.insertelement`( - self.value, val.value, idx.value + self.value, val.value, index(idx).value ) @always_inline("nodebug") - fn __setitem__( - inout self, idx: Int, val: __mlir_type[`!pop.scalar<`, type.value, `>`] + fn __setitem__[ + indexer: Indexer + ]( + inout self, + idx: indexer, + val: __mlir_type[`!pop.scalar<`, type.value, `>`], ): """Sets an element in the vector. + Parameters: + indexer: The type of the indexing value. + Args: idx: The index to set. val: The value to set. """ self.value = __mlir_op.`pop.simd.insertelement`( - self.value, val, idx.value + self.value, val, index(idx).value ) fn __hash__(self) -> Int: diff --git a/stdlib/src/builtin/string.mojo b/stdlib/src/builtin/string.mojo index e8a0087b7..9931b54d1 100644 --- a/stdlib/src/builtin/string.mojo +++ b/stdlib/src/builtin/string.mojo @@ -687,21 +687,25 @@ struct String( """ return len(self) > 0 - fn __getitem__(self, idx: Int) -> String: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> String: """Gets the character at the specified position. + Parameters: + indexer: The type of the indexing value. + Args: idx: The index value. Returns: A new string containing the character at the specified position. """ - if idx < 0: - return self.__getitem__(len(self) + idx) + var index_val = index(idx) + if index_val < 0: + return self.__getitem__(len(self) + index_val) - debug_assert(0 <= idx < len(self), "index must be in range") + debug_assert(0 <= index_val < len(self), "index must be in range") var buf = Self._buffer_type(capacity=1) - buf.append(self._buffer[idx]) + buf.append(self._buffer[index_val]) buf.append(0) return String(buf^) diff --git a/stdlib/src/builtin/value.mojo b/stdlib/src/builtin/value.mojo index 909a8197d..3b83bbabd 100644 --- a/stdlib/src/builtin/value.mojo +++ b/stdlib/src/builtin/value.mojo @@ -145,3 +145,26 @@ trait RepresentableCollectionElement(CollectionElement, Representable): """ pass + + +trait Indexer: + """This trait denotes a type that can be used to index a container that + handles integral index values. + + This solves the issue of being able to index data structures such as `List` with the various + integral types without being too broad and allowing types that should not be used such as float point + values. + """ + + fn __index__(self) -> Int: + """Return the index value + + Returns: + The index value of the object + """ + ... + + +@always_inline("nodebug") +fn index[indexer: Indexer](idx: indexer) -> Int: + return idx.__index__() diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 9959524b6..b3d6cb25c 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -59,10 +59,14 @@ struct _ListIter[ @parameter if forward: self.index += 1 - return self.src[].__get_ref(self.index - 1) + return self.src[].__get_ref[list_mutability, list_lifetime]( + self.index - 1 + ) else: self.index -= 1 - return self.src[].__get_ref(self.index) + return self.src[].__get_ref[list_mutability, list_lifetime]( + self.index + ) fn __len__(self) -> Int: @parameter @@ -485,17 +489,24 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): self.capacity = 0 return ptr - fn __setitem__(inout self, i: Int, owned value: T): + fn __setitem__[indexer: Indexer](inout self, i: indexer, owned value: T): """Sets a list element at the given index. + Parameters: + indexer: The type of the indexing value. + Args: i: The index of the element. value: The value to assign. """ - debug_assert(-self.size <= i < self.size, "index must be within bounds") + var normalized_idx = index(i) - var normalized_idx = i - if i < 0: + debug_assert( + -self.size <= normalized_idx < self.size, + "index must be within bounds", + ) + + if normalized_idx < 0: normalized_idx += len(self) destroy_pointee(self.data + normalized_idx) @@ -545,29 +556,39 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): return res^ @always_inline - fn __getitem__(self, i: Int) -> T: + fn __getitem__[indexer: Indexer](self, i: indexer) -> T: """Gets a copy of the list element at the given index. FIXME(lifetimes): This should return a reference, not a copy! + Parameters: + indexer: The type of the indexing value. + Args: i: The index of the element. Returns: A copy of the element at the given index. """ - debug_assert(-self.size <= i < self.size, "index must be within bounds") + var normalized_idx = index(i) - var normalized_idx = i - if i < 0: + debug_assert( + -self.size <= normalized_idx < self.size, + "index must be within bounds", + ) + + if normalized_idx < 0: normalized_idx += len(self) return (self.data + normalized_idx)[] # TODO(30737): Replace __getitem__ with this as __refitem__, but lots of places use it - fn __get_ref( - self: Reference[Self, _, _], i: Int - ) -> Reference[T, self.is_mutable, self.lifetime]: + fn __get_ref[ + mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type + ]( + self: Reference[Self, mutability, self_life]._mlir_type, + i: Int, + ) -> Reference[T, mutability, self_life]: """Gets a reference to the list element at the given index. Args: @@ -578,29 +599,39 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable): """ var normalized_idx = i if i < 0: - normalized_idx += self[].size + normalized_idx += Reference(self)[].size - return (self[].data + normalized_idx)[] + var offset_ptr = Reference(self)[].data + normalized_idx + return offset_ptr[] - fn __iter__( - self: Reference[Self, _, _], - ) -> _ListIter[T, self.is_mutable, self.lifetime]: + fn __iter__[ + mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type + ]( + self: Reference[Self, mutability, self_life]._mlir_type, + ) -> _ListIter[ + T, mutability, self_life + ]: """Iterate over elements of the list, returning immutable references. Returns: An iterator of immutable references to the list elements. """ - return _ListIter(0, self) + return _ListIter[T, mutability, self_life](0, Reference(self)) - fn __reversed__( - self: Reference[Self, _, _] - ) -> _ListIter[T, self.is_mutable, self.lifetime, False]: + fn __reversed__[ + mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type + ]( + self: Reference[Self, mutability, self_life]._mlir_type, + ) -> _ListIter[ + T, mutability, self_life, False + ]: """Iterate backwards over the list, returning immutable references. Returns: A reversed iterator of immutable references to the list elements. """ - return _ListIter[forward=False](len(self[]), self) + var ref = Reference(self) + return _ListIter[T, mutability, self_life, False](len(ref[]), ref) @staticmethod fn __str__[U: RepresentableCollectionElement](self: List[U]) -> String: diff --git a/stdlib/src/collections/vector.mojo b/stdlib/src/collections/vector.mojo index 1beb9eb2b..35967c6b1 100644 --- a/stdlib/src/collections/vector.mojo +++ b/stdlib/src/collections/vector.mojo @@ -181,21 +181,24 @@ struct InlinedFixedVector[ return self.current_size @always_inline - fn __getitem__(self, i: Int) -> type: + fn __getitem__[indexer: Indexer](self, i: indexer) -> type: """Gets a vector element at the given index. + Parameters: + indexer: The type of the indexing value. + Args: i: The index of the element. Returns: The element at the given index. """ + var normalized_idx = index(i) debug_assert( - -self.current_size <= i < self.current_size, + -self.current_size <= normalized_idx < self.current_size, "index must be within bounds", ) - var normalized_idx = i - if i < 0: + if normalized_idx < 0: normalized_idx += len(self) if normalized_idx < Self.static_size: @@ -204,20 +207,23 @@ struct InlinedFixedVector[ return self.dynamic_data[normalized_idx - Self.static_size] @always_inline - fn __setitem__(inout self, i: Int, value: type): + fn __setitem__[indexer: Indexer](inout self, i: indexer, value: type): """Sets a vector element at the given index. + Parameters: + indexer: The type of the indexing value. + Args: i: The index of the element. value: The value to assign. """ + var normalized_idx = index(i) debug_assert( - -self.current_size <= i < self.current_size, + -self.current_size <= normalized_idx < self.current_size, "index must be within bounds", ) - var normalized_idx = i - if i < 0: + if normalized_idx < 0: normalized_idx += len(self) if normalized_idx < Self.static_size: diff --git a/stdlib/src/memory/unsafe.mojo b/stdlib/src/memory/unsafe.mojo index f75912e49..bcdd79ba5 100644 --- a/stdlib/src/memory/unsafe.mojo +++ b/stdlib/src/memory/unsafe.mojo @@ -292,11 +292,11 @@ struct LegacyPointer[ ) @always_inline("nodebug") - fn __refitem__[T: Intable](self, offset: T) -> Self._mlir_ref_type: + fn __refitem__[T: Indexer](self, offset: T) -> Self._mlir_ref_type: """Enable subscript syntax `ref[idx]` to access the element. Parameters: - T: The Intable type of the offset. + T: The Indexer type of the offset. Args: offset: The offset to load from. @@ -304,7 +304,7 @@ struct LegacyPointer[ Returns: The MLIR reference for the Mojo compiler to use. """ - return (self + offset).__refitem__() + return (self + index(offset)).__refitem__() # ===------------------------------------------------------------------=== # # Load/Store @@ -717,7 +717,7 @@ struct DTypePointer[ return LegacyPointer.address_of(arg[]) @always_inline("nodebug") - fn __getitem__[T: Intable](self, offset: T) -> Scalar[type]: + fn __getitem__[T: Indexer](self, offset: T) -> Scalar[type]: """Loads a single element (SIMD of size 1) from the pointer at the specified index. @@ -730,20 +730,20 @@ struct DTypePointer[ Returns: The loaded value. """ - return self.load(offset) + return self.load(index(offset)) @always_inline("nodebug") - fn __setitem__[T: Intable](self, offset: T, val: Scalar[type]): + fn __setitem__[T: Indexer](self, offset: T, val: Scalar[type]): """Stores a single element value at the given offset. Parameters: - T: The Intable type of the offset. + T: The type of the indexing value. Args: offset: The offset to store to. val: The value to store. """ - return self.store(offset, val) + return self.store(index(offset), val) # ===------------------------------------------------------------------=== # # Comparisons diff --git a/stdlib/src/python/object.mojo b/stdlib/src/python/object.mojo index 44bd922bc..151984b99 100644 --- a/stdlib/src/python/object.mojo +++ b/stdlib/src/python/object.mojo @@ -101,7 +101,13 @@ struct _PyIter(Sized): @register_passable struct PythonObject( - Intable, Stringable, SizedRaising, Boolable, CollectionElement, KeyElement + Intable, + Stringable, + SizedRaising, + Boolable, + CollectionElement, + KeyElement, + Indexer, ): """A Python object.""" diff --git a/stdlib/src/utils/index.mojo b/stdlib/src/utils/index.mojo index 5f8abe86d..f506b2dea 100644 --- a/stdlib/src/utils/index.mojo +++ b/stdlib/src/utils/index.mojo @@ -335,11 +335,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable): return size @always_inline("nodebug") - fn __getitem__[intable: Intable](self, index: intable) -> Int: + fn __getitem__[indexer: Indexer](self, index: indexer) -> Int: """Gets an element from the tuple by index. Parameters: - intable: The intable type. + indexer: The type of the indexing value. Args: index: The element index. @@ -362,11 +362,11 @@ struct StaticIntTuple[size: Int](Sized, Stringable, EqualityComparable): self.data.__setitem__[index](val) @always_inline("nodebug") - fn __setitem__[intable: Intable](inout self, index: intable, val: Int): + fn __setitem__[indexer: Indexer](inout self, index: indexer, val: Int): """Sets an element in the tuple at the given index. Parameters: - intable: The intable type. + indexer: The type of the indexing value. Args: index: The element index. diff --git a/stdlib/src/utils/inlined_string.mojo b/stdlib/src/utils/inlined_string.mojo index 4df5fc9f1..79cb03600 100644 --- a/stdlib/src/utils/inlined_string.mojo +++ b/stdlib/src/utils/inlined_string.mojo @@ -524,9 +524,12 @@ struct _ArrayMem[ElementType: AnyRegType, SIZE: Int](Sized): """ return SIZE - fn __setitem__(inout self, index: Int, owned value: ElementType): + fn __setitem__[ + indexer: Indexer + ](inout self, idx: indexer, owned value: ElementType): var ptr = __mlir_op.`pop.array.gep`( - UnsafePointer(Reference(self.storage.array)).address, index.value + UnsafePointer(Reference(self.storage.array)).address, + index(idx).value, ) __mlir_op.`pop.store`(value, ptr) diff --git a/stdlib/src/utils/static_tuple.mojo b/stdlib/src/utils/static_tuple.mojo index ded021611..7330ee85c 100644 --- a/stdlib/src/utils/static_tuple.mojo +++ b/stdlib/src/utils/static_tuple.mojo @@ -197,19 +197,19 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): self = tmp @always_inline("nodebug") - fn __getitem__[intable: Intable](self, index: intable) -> Self.element_type: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> Self.element_type: """Returns the value of the tuple at the given dynamic index. Parameters: - intable: The intable type. + indexer: The type of the indexing value. Args: - index: The index into the tuple. + idx: The index into the tuple. Returns: The value at the specified position. """ - var offset = int(index) + var offset = index(idx) debug_assert(offset < size, "index must be within bounds") # Copy the array so we can get its address, because we can't take the # address of 'self' in a non-mutating method. @@ -221,18 +221,18 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized): @always_inline("nodebug") fn __setitem__[ - intable: Intable - ](inout self, index: intable, val: Self.element_type): + indexer: Indexer + ](inout self, idx: indexer, val: Self.element_type): """Stores a single value into the tuple at the specified dynamic index. Parameters: - intable: The intable type. + indexer: The type of the indexing value. Args: - index: The index into the tuple. + idx: The index into the tuple. val: The value to store. """ - var offset = int(index) + var offset = index(idx) debug_assert(offset < size, "index must be within bounds") var tmp = self var ptr = __mlir_op.`pop.array.gep`( diff --git a/stdlib/src/utils/stringref.mojo b/stdlib/src/utils/stringref.mojo index 2f85e1840..dd8346d86 100644 --- a/stdlib/src/utils/stringref.mojo +++ b/stdlib/src/utils/stringref.mojo @@ -257,16 +257,19 @@ struct StringRef( return not (self == rhs) @always_inline("nodebug") - fn __getitem__(self, idx: Int) -> StringRef: + fn __getitem__[indexer: Indexer](self, idx: indexer) -> StringRef: """Get the string value at the specified position. + Parameters: + indexer: The type of the indexing value. + Args: idx: The index position. Returns: The character at the specified position. """ - return StringRef {data: self.data + idx, length: 1} + return StringRef {data: self.data + index(idx), length: 1} fn __hash__(self) -> Int: """Hash the underlying buffer using builtin hash. diff --git a/stdlib/test/builtin/test_list.mojo b/stdlib/test/builtin/test_list.mojo index 79bc5666b..9c6af3e79 100644 --- a/stdlib/test/builtin/test_list.mojo +++ b/stdlib/test/builtin/test_list.mojo @@ -27,6 +27,8 @@ fn test_variadic_list() raises: assert_equal(nums[2], 6) assert_equal(len(nums), 3) + assert_equal(nums[False], 5) + assert_equal(nums[Int16(2)], 6) check_list(5, 8, 6) diff --git a/stdlib/test/builtin/test_object.mojo b/stdlib/test/builtin/test_object.mojo index 02cdc4ea3..648a20c24 100644 --- a/stdlib/test/builtin/test_object.mojo +++ b/stdlib/test/builtin/test_object.mojo @@ -310,6 +310,13 @@ def test_convert_to_string(): assert_equal(str(a), "{'foo' = 5, 'bar' = [1, 2], 'baz' = False}") +def test_indexing(): + a = object([1, 2, 3]) + assert_equal(a[0], 1) + assert_equal(a[True], 2) + assert_equal(a[Int16(2)], 3) + + def main(): test_object_ctors() test_comparison_ops() @@ -320,3 +327,4 @@ def main(): test_non_object_getattr() test_matrix() test_convert_to_string() + test_indexing() diff --git a/stdlib/test/builtin/test_slice.mojo b/stdlib/test/builtin/test_slice.mojo index 672d14e79..fb4b5ecf7 100644 --- a/stdlib/test/builtin/test_slice.mojo +++ b/stdlib/test/builtin/test_slice.mojo @@ -82,9 +82,17 @@ def test_slice_stringable(): assert_equal(s[:-1], "0:-1:1") +def test_indexing(): + var s = slice(1, 10) + assert_equal(s[True], 2) + assert_equal(s[UInt64(0)], 1) + + def main(): test_none_end_folds() test_slicable() test_has_end() test_slice_stringable() + + test_indexing() diff --git a/stdlib/test/builtin/test_string.mojo b/stdlib/test/builtin/test_string.mojo index 031bb1722..48dec07d9 100644 --- a/stdlib/test/builtin/test_string.mojo +++ b/stdlib/test/builtin/test_string.mojo @@ -225,6 +225,9 @@ fn test_string_indexing() raises: assert_equal("!jMolH", str[:-1:-2]) + assert_equal(str[True], "e") + assert_equal(str[Int8(0)], "H") + fn test_atol() raises: # base 10 diff --git a/stdlib/test/builtin/test_stringref.mojo b/stdlib/test/builtin/test_stringref.mojo index 257b59345..d9ed5395f 100644 --- a/stdlib/test/builtin/test_stringref.mojo +++ b/stdlib/test/builtin/test_stringref.mojo @@ -41,6 +41,13 @@ def test_intable(): int(StringRef("hi")) +def test_indexing(): + a = StringRef("abc") + assert_equal(a[False], "a") + assert_equal(a[Int16(1)], "b") + + def main(): test_strref_from_start() test_intable() + test_indexing() diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index a65d070c4..a44fe91dd 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -689,6 +689,14 @@ def test_list_count(): assert_equal(0, __type_of(list2).count(list2, 1)) +def test_indexer(): + var l = List[Int](1, 2, 3) + assert_equal(l[Int8(1)], 2) + assert_equal(l[UInt64(2)], 3) + assert_equal(l[False], 1) + assert_equal(l[True], 2) + + def main(): test_mojo_issue_698() test_list() @@ -714,3 +722,4 @@ def main(): test_constructor_from_other_list_through_pointer() test_converting_list_to_string() test_list_count() + test_indexer() diff --git a/stdlib/test/collections/test_vector.mojo b/stdlib/test/collections/test_vector.mojo index 9750141bd..7537e06e9 100644 --- a/stdlib/test/collections/test_vector.mojo +++ b/stdlib/test/collections/test_vector.mojo @@ -103,6 +103,11 @@ def test_inlined_fixed_vector_with_default(): vector[5] = -2 assert_equal(-2, vector[5]) + # check we can index with non Int or IntLiteral + assert_equal(1, vector[Int16(1)]) + assert_equal(1, vector[True]) + assert_equal(1, vector[Scalar[DType.bool](True)]) + vector.clear() assert_equal(0, len(vector)) diff --git a/stdlib/test/utils/test_tuple.mojo b/stdlib/test/utils/test_tuple.mojo index c28e68baf..43b955048 100644 --- a/stdlib/test/utils/test_tuple.mojo +++ b/stdlib/test/utils/test_tuple.mojo @@ -34,6 +34,9 @@ def test_static_tuple(): assert_equal(tup3[Int(0)], 1) assert_equal(tup3[Int64(0)], 1) + assert_equal(tup3[True], 2) + assert_equal(tup3[Int32(2)], 3) + def test_static_int_tuple(): assert_equal(str(StaticIntTuple[1](1)), "(1,)")