Skip to content

Commit

Permalink
[External] [stdlib] Add the normalize_index function (#40280)
Browse files Browse the repository at this point in the history
[External] [stdlib] Add the `normalize_index` function

This PR is some kind of mix between
#2386 and
#2384 which have issues (aborting)
or have too many conflicts because the PR is too big.

This PR solves part of #2251 and
#2337

We try here to give the ground work for indexing correctly. This
function added can then be used wherever we work with sequences.

Two things I noticed during development:
1) The `debug_assert` does not run in unit tests. Is there any way to
enable it? We currently have out-of-bounds bugs in our test suite.
2) The null terminator is causing pain, again, again, and again. Do we
have any plans to make it optional when working with String? I opened
#2678 to discuss this.

To avoid to fix those issues in this PR, I used the `normalize_index` on
the `__refitem__` of `InlineArray` which doesn't have widespread use yet
and isn't impacted by out-of-bounds bugs.

My recommendation would be to merge this PR then to rebase
#2386 and
#2384 on it. We should also
afterwards fix the out of bounds bugs that can be triggered in the test
suite by enabling debug_assert.

The diff might seem big, but no worries, it's mostly the licenses and
docstrings :)

Co-authored-by: Gabriel de Marmiesse <gabriel.demarmiesse@datadoghq.com>
Closes #2677
MODULAR_ORIG_COMMIT_REV_ID: 66e7121a6a333c16284eb33a89eb85c034c296c3
  • Loading branch information
gabrieldemarmiesse authored and modularbot committed Jun 7, 2024
1 parent b74d39f commit 216c340
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 35 deletions.
78 changes: 78 additions & 0 deletions stdlib/src/collections/_index_normalization.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""The utilities provided in this module help normalize the access
to data elements in arrays."""

from sys import triple_is_nvidia_cuda


fn get_out_of_bounds_error_message[
container_name: StringLiteral
](i: Int, container_length: Int) -> String:
if container_length == 0:
return (
"The "
+ container_name
+ " has a length of 0. Thus it's not possible to access its values"
" with an index but the index value "
+ str(i)
+ " was used. Aborting now to avoid an out-of-bounds access."
)
else:
return (
"The "
+ container_name
+ " has a length of "
+ str(container_length)
+ ". Thus the index provided should be between "
+ str(-container_length)
+ " (inclusive) and "
+ str(container_length)
+ " (exclusive) but the index value "
+ str(i)
+ " was used. Aborting now to avoid an out-of-bounds access."
)


@always_inline
fn normalize_index[
ContainerType: Sized, //, container_name: StringLiteral
](idx: Int, container: ContainerType) -> Int:
"""Normalize the given index value to a valid index value for the given container length.
If the provided value is negative, the `index + container_length` is returned.
Parameters:
ContainerType: The type of the container. Must have a `__len__` method.
container_name: The name of the container. Used for the error message.
Args:
idx: The index value to normalize.
container: The container to normalize the index for.
Returns:
The normalized index value.
"""
var container_length = len(container)
if not (-container_length <= idx < container_length):

@parameter
if triple_is_nvidia_cuda():
abort()
else:
abort(
get_out_of_bounds_error_message[container_name](
idx, container_length
)
)
return idx + int(idx < 0) * container_length
28 changes: 11 additions & 17 deletions stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -791,40 +791,34 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable):
return self[].unsafe_get(normalized_idx)

@always_inline
fn unsafe_get[
IndexerType: Indexer,
](self: Reference[Self, _, _], idx: IndexerType) -> Reference[
Self.T, self.is_mutable, self.lifetime
]:
fn unsafe_get(
self: Reference[Self, _, _], idx: Int
) -> Reference[Self.T, self.is_mutable, self.lifetime]:
"""Get a reference to an element of self without checking index bounds.
Users should consider using `__getitem__` instead of this method as it is unsafe.
If an index is out of bounds, this method will not abort, it will be considered
undefined behavior.
Users should consider using `__getitem__` instead of this method as it
is unsafe. If an index is out of bounds, this method will not abort, it
will be considered undefined behavior.
Note that there is no wraparound for negative indices, caution is advised.
Using negative indices is considered undefined behavior.
Never use `my_list.unsafe_get(-1)` to get the last element of the list. It will not work.
Note that there is no wraparound for negative indices, caution is
advised. Using negative indices is considered undefined behavior. Never
use `my_list.unsafe_get(-1)` to get the last element of the list.
Instead, do `my_list.unsafe_get(len(my_list) - 1)`.
Parameters:
IndexerType: The type of the argument used as index.
Args:
idx: The index of the element to get.
Returns:
A reference to the element at the given index.
"""
var idx_as_int = index(idx)
debug_assert(
0 <= idx_as_int < len(self[]),
0 <= idx < len(self[]),
(
"The index provided must be within the range [0, len(List) -1]"
" when using List.unsafe_get()"
),
)
return (self[].data + idx_as_int)[]
return (self[].data + idx)[]

fn count[T: ComparableCollectionElement](self: List[T], value: T) -> Int:
"""Counts the number of occurrences of a value in the list.
Expand Down
46 changes: 28 additions & 18 deletions stdlib/src/utils/static_tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ You can import these APIs from the `utils` package. For example:
from utils import StaticTuple
```
"""

from collections._index_normalization import normalize_index
from memory import Pointer

from utils import unroll
Expand Down Expand Up @@ -341,28 +341,20 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized):
# ===------------------------------------------------------------------===#

@always_inline("nodebug")
fn __getitem__[
IntableType: Intable,
](self: Reference[Self, _, _], index: IntableType) -> ref [
self.lifetime
] Self.ElementType:
fn __getitem__(
self: Reference[Self, _, _], idx: Int
) -> ref [self.lifetime] Self.ElementType:
"""Get a `Reference` to the element at the given index.
Parameters:
IntableType: The inferred type of an intable argument.
Args:
index: The index of the item.
idx: The index of the item.
Returns:
A reference to the item at the given index.
"""
debug_assert(-size <= int(index) < size, "Index must be within bounds.")
var normalized_idx = int(index)
if normalized_idx < 0:
normalized_idx += size
var normalized_index = normalize_index["InlineArray"](idx, self[])

return self[]._get_reference_unsafe(normalized_idx)[]
return self[]._get_reference_unsafe(normalized_index)[]

@always_inline("nodebug")
fn __getitem__[
Expand Down Expand Up @@ -408,15 +400,33 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized):

@always_inline("nodebug")
fn _get_reference_unsafe(
self: Reference[Self, _, _], index: Int
self: Reference[Self, _, _], idx: Int
) -> Reference[Self.ElementType, self.is_mutable, self.lifetime]:
"""Get a reference to an element of self without checking index bounds.
Users should opt for `__getitem__` instead of this method.
Users should opt for `__getitem__` instead of this method as it is
unsafe.
Note that there is no wraparound for negative indices. Using negative
indices is considered undefined behavior.
Args:
idx: The index of the element to get.
Returns:
A reference to the element at the given index.
"""
var idx_as_int = index(idx)
debug_assert(
0 <= idx_as_int < size,
(
"Index must be within bounds when using"
" `InlineArray.unsafe_get()`."
),
)
var ptr = __mlir_op.`pop.array.gep`(
UnsafePointer.address_of(self[]._array).address,
index.value,
idx_as_int.value,
)
return UnsafePointer(ptr)[]

Expand Down
64 changes: 64 additions & 0 deletions stdlib/test/collections/test_index_normalization.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %mojo %s

from collections._index_normalization import (
get_out_of_bounds_error_message,
normalize_index,
)
from testing import assert_equal


def test_out_of_bounds_message():
assert_equal(
get_out_of_bounds_error_message["List"](5, 2),
(
"The List has a length of 2. Thus the index provided should be"
" between -2 (inclusive) and 2 (exclusive) but the index value 5"
" was used. Aborting now to avoid an out-of-bounds access."
),
)

assert_equal(
get_out_of_bounds_error_message["List"](0, 0),
(
"The List has a length of 0. Thus it's not possible to access its"
" values with an index but the index value 0 was used. Aborting now"
" to avoid an out-of-bounds access."
),
)
assert_equal(
get_out_of_bounds_error_message["InlineArray"](8, 0),
(
"The InlineArray has a length of 0. Thus it's not possible to"
" access its values with an index but the index value 8 was used."
" Aborting now to avoid an out-of-bounds access."
),
)


def test_normalize_index():
container = List[Int](1, 1, 1, 1)
assert_equal(normalize_index[""](-4, container), 0)
assert_equal(normalize_index[""](-3, container), 1)
assert_equal(normalize_index[""](-2, container), 2)
assert_equal(normalize_index[""](-1, container), 3)
assert_equal(normalize_index[""](0, container), 0)
assert_equal(normalize_index[""](1, container), 1)
assert_equal(normalize_index[""](2, container), 2)
assert_equal(normalize_index[""](3, container), 3)


def main():
test_out_of_bounds_message()
test_normalize_index()
19 changes: 19 additions & 0 deletions stdlib/test/utils/test_tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ def test_tuple_literal():
assert_equal(len(()), 0)


def test_array_get_reference_unsafe():
# Negative indexing is undefined behavior with _get_reference_unsafe
# so there are not test cases for it.
var arr = InlineArray[Int, 3](0, 0, 0)

assert_equal(arr._get_reference_unsafe(0)[], 0)
assert_equal(arr._get_reference_unsafe(1)[], 0)
assert_equal(arr._get_reference_unsafe(2)[], 0)

arr[0] = 1
arr[1] = 2
arr[2] = 3

assert_equal(arr._get_reference_unsafe(0)[], 1)
assert_equal(arr._get_reference_unsafe(1)[], 2)
assert_equal(arr._get_reference_unsafe(2)[], 3)


def test_array_int():
var arr = InlineArray[Int, 3](0, 0, 0)

Expand Down Expand Up @@ -199,6 +217,7 @@ def main():
test_static_tuple()
test_static_int_tuple()
test_tuple_literal()
test_array_get_reference_unsafe()
test_array_int()
test_array_str()
test_array_int_pointer()
Expand Down

0 comments on commit 216c340

Please sign in to comment.