Skip to content

Commit

Permalink
Add index normalization
Browse files Browse the repository at this point in the history
Signed-off-by: gabrieldemarmiesse <gabrieldemarmiesse@gmail.com>
  • Loading branch information
gabrieldemarmiesse committed May 20, 2024
1 parent 510ab7c commit 5d859b1
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 9 deletions.
83 changes: 83 additions & 0 deletions stdlib/src/collections/_index_normalization.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""The utilities provided in this module help normalize the access
to data elements in arrays."""


fn get_out_of_bounds_error_message[
container_name: String
](i: Int, container_length: Int) -> String:
if container_length == 0:
return (
"The "
+ container_name
+ " has a length of 0. "
+ "Thus it's not possible to access its values with an index "
+ "but the index value "
+ str(i)
+ " was used. "
+ "Aborting now to avoid an out-of-bounds access."
)
else:
return (
"The "
+ container_name
+ " has a length of "
+ str(container_length)
+ ". "
+ "Thus the index provided should be between "
+ str(-container_length)
+ " (inclusive) and "
+ str(container_length)
+ " (exclusive) but the index value "
+ str(i)
+ " was used. "
+ "Aborting now to avoid an out-of-bounds access."
)


@always_inline
fn normalize_index[
inferred IndexType: Indexer,
inferred ContainerType: Sized,
container_name: StringLiteral,
](index_value: IndexType, container: ContainerType) -> Int:
"""Normalize the given index value to a valid index value for the given container length.
If the provided value is negative, the `index + container_length` is returned.
Parameters:
IndexType: The type of the index value. Must have an `__index__` method.
ContainerType: The type of the container. Must have a `__len__` method.
container_name: The name of the container. Used for the error message.
Args:
index_value: The index value to normalize.
container: The container to normalize the index for.
Returns:
The normalized index value.
"""
var index_as_int = index(index_value)
var container_length = len(container)

if not (-container_length <= index_as_int < container_length):
# TODO: Get the container_name from the ContainerType when the compiler allows it.
abort(
get_out_of_bounds_error_message[container_name](
index_as_int, container_length
)
)
if index_as_int < 0:
index_as_int += container_length
return index_as_int
17 changes: 8 additions & 9 deletions stdlib/src/utils/static_tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ You can import these APIs from the `utils` package. For example:
from utils import StaticTuple
```
"""

from collections._index_normalization import normalize_index
from memory import Pointer

from utils import unroll
Expand Down Expand Up @@ -369,27 +369,26 @@ struct InlineArray[ElementType: CollectionElement, size: Int](Sized):

@always_inline("nodebug")
fn __refitem__[
IntableType: Intable,
](self: Reference[Self, _, _], index: IntableType) -> Reference[
IndexerType: Indexer,
](self: Reference[Self, _, _], index: IndexerType) -> Reference[
Self.ElementType, self.is_mutable, self.lifetime
]:
"""Get a `Reference` to the element at the given index.
Parameters:
IntableType: The inferred type of an intable argument.
IndexerType: The inferred type of an indexer argument.
Args:
index: The index of the item.
Returns:
A reference to the item at the given index.
"""
debug_assert(-size <= int(index) < size, "Index must be within bounds.")
var normalized_idx = int(index)
if normalized_idx < 0:
normalized_idx += size
var normalized_index = normalize_index[container_name="InlineArray"](
index, self[]
)

return self[]._get_reference_unsafe(normalized_idx)
return self[]._get_reference_unsafe(normalized_index)

@always_inline("nodebug")
fn __refitem__[
Expand Down
64 changes: 64 additions & 0 deletions stdlib/test/collections/test_index_normalization.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %mojo %s

from collections._index_normalization import (
get_out_of_bounds_error_message,
normalize_index,
)
from testing import assert_equal


def test_out_of_bounds_message():
assert_equal(
get_out_of_bounds_error_message[container_name="List"](5, 2),
(
"The List has a length of 2. Thus the index provided should be"
" between -2 (inclusive) and 2 (exclusive) but the index value 5"
" was used. Aborting now to avoid an out-of-bounds access."
),
)

assert_equal(
get_out_of_bounds_error_message[container_name="List"](0, 0),
(
"The List has a length of 0. Thus it's not possible to access its"
" values with an index but the index value 0 was used. Aborting now"
" to avoid an out-of-bounds access."
),
)
assert_equal(
get_out_of_bounds_error_message[container_name="InlineArray"](8, 0),
(
"The InlineArray has a length of 0. Thus it's not possible to"
" access its values with an index but the index value 8 was used."
" Aborting now to avoid an out-of-bounds access."
),
)


def test_normalize_index():
container = List[Int](1, 1, 1, 1)
assert_equal(normalize_index[container_name=""](-4, container), 0)
assert_equal(normalize_index[container_name=""](-3, container), 1)
assert_equal(normalize_index[container_name=""](-2, container), 2)
assert_equal(normalize_index[container_name=""](-1, container), 3)
assert_equal(normalize_index[container_name=""](0, container), 0)
assert_equal(normalize_index[container_name=""](1, container), 1)
assert_equal(normalize_index[container_name=""](2, container), 2)
assert_equal(normalize_index[container_name=""](3, container), 3)


def main():
test_out_of_bounds_message()
test_normalize_index()

0 comments on commit 5d859b1

Please sign in to comment.