-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[External] [stdlib] Add the
normalize_index
function (#40280)
[External] [stdlib] Add the `normalize_index` function This PR is some kind of mix between #2386 and #2384 which have issues (aborting) or have too many conflicts because the PR is too big. This PR solves part of #2251 and #2337 We try here to give the ground work for indexing correctly. This function added can then be used wherever we work with sequences. Two things I noticed during development: 1) The `debug_assert` does not run in unit tests. Is there any way to enable it? We currently have out-of-bounds bugs in our test suite. 2) The null terminator is causing pain, again, again, and again. Do we have any plans to make it optional when working with String? I opened #2678 to discuss this. To avoid to fix those issues in this PR, I used the `normalize_index` on the `__refitem__` of `InlineArray` which doesn't have widespread use yet and isn't impacted by out-of-bounds bugs. My recommendation would be to merge this PR then to rebase #2386 and #2384 on it. We should also afterwards fix the out of bounds bugs that can be triggered in the test suite by enabling debug_assert. The diff might seem big, but no worries, it's mostly the licenses and docstrings :) Co-authored-by: Gabriel de Marmiesse <gabriel.demarmiesse@datadoghq.com> Closes #2677 MODULAR_ORIG_COMMIT_REV_ID: 66e7121a6a333c16284eb33a89eb85c034c296c3
- Loading branch information
1 parent
b74d39f
commit 216c340
Showing
5 changed files
with
200 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# ===----------------------------------------------------------------------=== # | ||
# Copyright (c) 2024, Modular Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
# https://llvm.org/LICENSE.txt | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ===----------------------------------------------------------------------=== # | ||
"""The utilities provided in this module help normalize the access | ||
to data elements in arrays.""" | ||
|
||
from sys import triple_is_nvidia_cuda | ||
|
||
|
||
fn get_out_of_bounds_error_message[ | ||
container_name: StringLiteral | ||
](i: Int, container_length: Int) -> String: | ||
if container_length == 0: | ||
return ( | ||
"The " | ||
+ container_name | ||
+ " has a length of 0. Thus it's not possible to access its values" | ||
" with an index but the index value " | ||
+ str(i) | ||
+ " was used. Aborting now to avoid an out-of-bounds access." | ||
) | ||
else: | ||
return ( | ||
"The " | ||
+ container_name | ||
+ " has a length of " | ||
+ str(container_length) | ||
+ ". Thus the index provided should be between " | ||
+ str(-container_length) | ||
+ " (inclusive) and " | ||
+ str(container_length) | ||
+ " (exclusive) but the index value " | ||
+ str(i) | ||
+ " was used. Aborting now to avoid an out-of-bounds access." | ||
) | ||
|
||
|
||
@always_inline | ||
fn normalize_index[ | ||
ContainerType: Sized, //, container_name: StringLiteral | ||
](idx: Int, container: ContainerType) -> Int: | ||
"""Normalize the given index value to a valid index value for the given container length. | ||
If the provided value is negative, the `index + container_length` is returned. | ||
Parameters: | ||
ContainerType: The type of the container. Must have a `__len__` method. | ||
container_name: The name of the container. Used for the error message. | ||
Args: | ||
idx: The index value to normalize. | ||
container: The container to normalize the index for. | ||
Returns: | ||
The normalized index value. | ||
""" | ||
var container_length = len(container) | ||
if not (-container_length <= idx < container_length): | ||
|
||
@parameter | ||
if triple_is_nvidia_cuda(): | ||
abort() | ||
else: | ||
abort( | ||
get_out_of_bounds_error_message[container_name]( | ||
idx, container_length | ||
) | ||
) | ||
return idx + int(idx < 0) * container_length |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# ===----------------------------------------------------------------------=== # | ||
# Copyright (c) 2024, Modular Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
# https://llvm.org/LICENSE.txt | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ===----------------------------------------------------------------------=== # | ||
# RUN: %mojo %s | ||
|
||
from collections._index_normalization import ( | ||
get_out_of_bounds_error_message, | ||
normalize_index, | ||
) | ||
from testing import assert_equal | ||
|
||
|
||
def test_out_of_bounds_message(): | ||
assert_equal( | ||
get_out_of_bounds_error_message["List"](5, 2), | ||
( | ||
"The List has a length of 2. Thus the index provided should be" | ||
" between -2 (inclusive) and 2 (exclusive) but the index value 5" | ||
" was used. Aborting now to avoid an out-of-bounds access." | ||
), | ||
) | ||
|
||
assert_equal( | ||
get_out_of_bounds_error_message["List"](0, 0), | ||
( | ||
"The List has a length of 0. Thus it's not possible to access its" | ||
" values with an index but the index value 0 was used. Aborting now" | ||
" to avoid an out-of-bounds access." | ||
), | ||
) | ||
assert_equal( | ||
get_out_of_bounds_error_message["InlineArray"](8, 0), | ||
( | ||
"The InlineArray has a length of 0. Thus it's not possible to" | ||
" access its values with an index but the index value 8 was used." | ||
" Aborting now to avoid an out-of-bounds access." | ||
), | ||
) | ||
|
||
|
||
def test_normalize_index(): | ||
container = List[Int](1, 1, 1, 1) | ||
assert_equal(normalize_index[""](-4, container), 0) | ||
assert_equal(normalize_index[""](-3, container), 1) | ||
assert_equal(normalize_index[""](-2, container), 2) | ||
assert_equal(normalize_index[""](-1, container), 3) | ||
assert_equal(normalize_index[""](0, container), 0) | ||
assert_equal(normalize_index[""](1, container), 1) | ||
assert_equal(normalize_index[""](2, container), 2) | ||
assert_equal(normalize_index[""](3, container), 3) | ||
|
||
|
||
def main(): | ||
test_out_of_bounds_message() | ||
test_normalize_index() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters