Skip to content

Commit

Permalink
[External] [mojo-stdlib] Create InlineArray type (#38793)
Browse files Browse the repository at this point in the history
[External] [mojo-stdlib] Create `InlineArray` type

This PR creates an `InlineArray` type that takes any `CollectionElement`
rather than just `AnyRegType`. See also [this
thread](https://discord.com/channels/1087530497313357884/1228515158234501231)
on Discord.

---------

Co-authored-by: Lukas Hermann <lukashermann28@gmail.com>
Closes #2294
MODULAR_ORIG_COMMIT_REV_ID: d53b6de6f9ecdf035df951ce679cea2aed1b8e65
  • Loading branch information
lsh authored and JoeLoser committed Apr 30, 2024
1 parent 477791c commit 904ac4e
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 2 deletions.
2 changes: 1 addition & 1 deletion stdlib/src/utils/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ from .index import (
)
from .inlined_string import InlinedString
from .loop import unroll
from .static_tuple import StaticTuple
from .static_tuple import StaticTuple, InlineArray
from .stringref import StringRef
from .variant import Variant
160 changes: 160 additions & 0 deletions stdlib/src/utils/static_tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,163 @@ struct StaticTuple[element_type: AnyRegType, size: Int](Sized):
)
Pointer(ptr).store(val)
self = tmp


# ===----------------------------------------------------------------------===#
# Array
# ===----------------------------------------------------------------------===#


@value
struct InlineArray[ElementType: CollectionElement, size: Int](Sized):
"""A fixed-size sequence of size homogenous elements where size is a constant expression.
Parameters:
ElementType: The type of the elements in the array.
size: The size of the array.
"""

alias type = __mlir_type[
`!pop.array<`, size.value, `, `, Self.ElementType, `>`
]
var _array: Self.type
"""The underlying storage for the array."""

@always_inline
fn __init__(inout self):
"""This constructor will always cause a compile time error if used.
It is used to steer users away from uninitialized memory.
"""
constrained[
False,
(
"Initialize with either a variadic list of arguments or a"
" default fill element."
),
]()
self._array = __mlir_op.`kgen.undef`[_type = Self.type]()

@always_inline
fn __init__(inout self, fill: Self.ElementType):
"""Constructs an empty array where each element is the supplied `fill`.
Args:
fill: The element to fill each index.
"""
_static_tuple_construction_checks[size]()
self._array = __mlir_op.`kgen.undef`[_type = Self.type]()

@unroll
for i in range(size):
var ptr = self._get_reference_unsafe(i)
initialize_pointee_copy(UnsafePointer[Self.ElementType](ptr), fill)

@always_inline
fn __init__(inout self, *elems: Self.ElementType):
"""Constructs an array given a set of arguments.
Args:
elems: The element types.
"""
debug_assert(len(elems) == size, "Elements must be of length size")
_static_tuple_construction_checks[size]()
self._array = __mlir_op.`kgen.undef`[_type = Self.type]()

@unroll
for i in range(size):
var ref = self._get_reference_unsafe(i)
initialize_pointee_move(
UnsafePointer[Self.ElementType](ref), elems[i]
)

@always_inline("nodebug")
fn __len__(self) -> Int:
"""Returns the length of the array. This is a known constant value.
Returns:
The size of the list.
"""
return size

@always_inline("nodebug")
fn _get_reference_unsafe[
mutability: __mlir_type.i1,
self_life: AnyLifetime[mutability].type,
](
self: Reference[Self, mutability, self_life]._mlir_type, index: Int
) -> Reference[Self.ElementType, mutability, self_life]:
"""Get a reference to an element of self without checking index bounds.
Users should opt for `__refitem__` instead of this method.
"""
var ptr = __mlir_op.`pop.array.gep`(
Reference(Reference(self)[]._array).get_legacy_pointer().address,
index.value,
)
return Reference[Self.ElementType, mutability, self_life](
UnsafePointer(ptr)[]
)

@always_inline("nodebug")
fn __refitem__[
mutability: __mlir_type.i1,
self_life: AnyLifetime[mutability].type,
IntableType: Intable,
](
self: Reference[Self, mutability, self_life]._mlir_type,
index: IntableType,
) -> Reference[Self.ElementType, mutability, self_life]:
"""Get a `Reference` to the element at the given index.
Parameters:
mutability: The inferred mutability of the reference.
self_life: The inferred lifetime of the reference.
IntableType: The inferred type of an intable argument.
Args:
index: The index of the item.
Returns:
A reference to the item at the given index.
"""
debug_assert(-size <= int(index) < size, "Index must be within bounds.")
var normalized_idx = int(index)
if normalized_idx < 0:
normalized_idx += size

return Reference(self)[]._get_reference_unsafe[mutability, self_life](
normalized_idx
)

@always_inline("nodebug")
fn __refitem__[
mutability: __mlir_type.i1,
self_life: AnyLifetime[mutability].type,
IntableType: Intable,
index: IntableType,
](self: Reference[Self, mutability, self_life]._mlir_type) -> Reference[
Self.ElementType, mutability, self_life
]:
"""Get a `Reference` to the element at the given index.
Parameters:
mutability: The inferred mutability of the reference.
self_life: The inferred lifetime of the reference.
IntableType: The inferred type of an intable argument.
index: The index of the item.
Returns:
A reference to the item at the given index.
"""
alias i = int(index)
constrained[-size <= i < size, "Index must be within bounds."]()

var normalized_idx = i

@parameter
if i < 0:
normalized_idx += size

return Reference(self)[]._get_reference_unsafe[mutability, self_life](
normalized_idx
)
90 changes: 89 additions & 1 deletion stdlib/test/utils/test_tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from testing import assert_equal, assert_false, assert_true

from utils import StaticTuple, StaticIntTuple
from utils import StaticTuple, StaticIntTuple, InlineArray


def test_static_tuple():
Expand Down Expand Up @@ -80,7 +80,95 @@ def test_tuple_literal():
assert_equal(len(()), 0)


def test_array_int():
var arr = InlineArray[Int, 3](0, 0, 0)

assert_equal(arr[0], 0)
assert_equal(arr[1], 0)
assert_equal(arr[2], 0)

arr[0] = 1
arr[1] = 2
arr[2] = 3

assert_equal(arr[0], 1)
assert_equal(arr[1], 2)
assert_equal(arr[2], 3)

# test negative indexing
assert_equal(arr[-1], 3)
assert_equal(arr[-2], 2)

# test negative indexing with dynamic index
var i = -1
assert_equal(arr[i], 3)
i -= 1
assert_equal(arr[i], 2)

var copy = arr
assert_equal(arr[0], copy[0])
assert_equal(arr[1], copy[1])
assert_equal(arr[2], copy[2])

var move = arr^
assert_equal(copy[0], move[0])
assert_equal(copy[1], move[1])
assert_equal(copy[2], move[2])

# fill element initializer
var arr2 = InlineArray[Int, 3](5)
assert_equal(arr2[0], 5)
assert_equal(arr2[1], 5)
assert_equal(arr2[2], 5)

var arr3 = InlineArray[Int, 1](5)
assert_equal(arr2[0], 5)


def test_array_str():
var arr = InlineArray[String, 3]("hi", "hello", "hey")

assert_equal(arr[0], "hi")
assert_equal(arr[1], "hello")
assert_equal(arr[2], "hey")

# Test mutating an array through its __refitem__
arr[0] = "howdy"
arr[1] = "morning"
arr[2] = "wazzup"

assert_equal(arr[0], "howdy")
assert_equal(arr[1], "morning")
assert_equal(arr[2], "wazzup")

# test negative indexing
assert_equal(arr[-1], "wazzup")
assert_equal(arr[-2], "morning")

var copy = arr
assert_equal(arr[0], copy[0])
assert_equal(arr[1], copy[1])
assert_equal(arr[2], copy[2])

var move = arr^
assert_equal(copy[0], move[0])
assert_equal(copy[1], move[1])
assert_equal(copy[2], move[2])

# fill element initializer
var arr2 = InlineArray[String, 3]("hi")
assert_equal(arr2[0], "hi")
assert_equal(arr2[1], "hi")
assert_equal(arr2[2], "hi")

# size 1 array to prevent regressions in the constructors
var arr3 = InlineArray[String, 1]("hi")
assert_equal(arr3[0], "hi")


def main():
test_static_tuple()
test_static_int_tuple()
test_tuple_literal()
test_array_int()
test_array_str()

0 comments on commit 904ac4e

Please sign in to comment.