Skip to content

Commit

Permalink
[mojo-stdlib] Massively simplify parametric mutability reference usag…
Browse files Browse the repository at this point in the history
…e (#39369)

Now that we can declare `self` to be a `Reference` and not just a
`!lit.ref` we can use parameter inference to deduce the mutability
and lifetime and use `_` expressions to avoid having to explicitly
declare them.

This radically simplifies the usage across the stdlib.

MODULAR_ORIG_COMMIT_REV_ID: 123cd20425d4a4adcd609cc74a254e8ab116f29f
  • Loading branch information
lattner authored and JoeLoser committed May 6, 2024
1 parent 8985c12 commit 0fe4cb7
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 259 deletions.
24 changes: 9 additions & 15 deletions stdlib/src/builtin/reversed.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,17 @@ fn reversed[T: ReversibleRange](value: T) -> _StridedRangeIterator:


fn reversed[
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
T: CollectionElement,
T: CollectionElement
](
value: Reference[List[T], mutability, self_life]._mlir_type,
value: Reference[List[T], _, _],
) -> _ListIter[
T, mutability, self_life, False
T, value.is_mutable, value.lifetime, False
]:
"""Get a reversed iterator of the input list.
**Note**: iterators are currently non-raising.
Parameters:
mutability: Whether the reference to the list is mutable.
self_life: The lifetime of the list.
T: The type of the elements in the list.
Args:
Expand All @@ -100,24 +96,22 @@ fn reversed[
Returns:
The reversed iterator of the list.
"""
return Reference(value)[].__reversed__[mutability, self_life]()
return value[].__reversed__()


fn reversed[
mutability: __mlir_type.`i1`,
self_life: AnyLifetime[mutability].type,
K: KeyElement,
V: CollectionElement,
](
value: Reference[Dict[K, V], mutability, self_life]._mlir_type,
) -> _DictKeyIter[K, V, mutability, self_life, False]:
value: Reference[Dict[K, V], _, _],
) -> _DictKeyIter[
K, V, value.is_mutable, value.lifetime, False
]:
"""Get a reversed iterator of the input dict.
**Note**: iterators are currently non-raising.
Parameters:
mutability: Whether the reference to the dict is mutable.
self_life: The lifetime of the dict.
K: The type of the keys in the dict.
V: The type of the values in the dict.
Expand All @@ -127,4 +121,4 @@ fn reversed[
Returns:
The reversed iterator of the dict.
"""
return Reference(value)[].__reversed__[mutability, self_life]()
return value[].__reversed__()
14 changes: 5 additions & 9 deletions stdlib/src/builtin/tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,19 @@ struct Tuple[*element_types: CollectionElement](Sized, CollectionElement):

@always_inline("nodebug")
fn __refitem__[
idx: Int,
mutability: __mlir_type.i1,
self_life: AnyLifetime[mutability].type,
](self_lit: Reference[Self, mutability, self_life]._mlir_type) -> Reference[
element_types[idx.value], mutability, self_life
idx: Int
](self: Reference[Self, _, _]) -> Reference[
element_types[idx.value], self.is_mutable, self.lifetime
]:
# Return a reference to an element at the specified index, propagating
# mutability of self.
var storage_kgen_ptr = UnsafePointer.address_of(
Reference(self_lit)[].storage
).address
var storage_kgen_ptr = UnsafePointer.address_of(self[].storage).address

# KGenPointer to the element.
var elt_kgen_ptr = __mlir_op.`kgen.pack.gep`[index = idx.value](
storage_kgen_ptr
)
# Convert to an immortal mut reference, which conforms to self_life.
# Use an immortal mut reference, which converts to self's lifetime.
return UnsafePointer(elt_kgen_ptr)[]

# TODO(#38268): Remove this method when references and parameter expressions
Expand Down
160 changes: 38 additions & 122 deletions stdlib/src/collections/dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,8 @@ struct _DictEntryIter[
else:
debug_assert(self.index >= 0, "dict iter bounds")

if self.src[]._entries.__get_ref(self.index)[]:
var opt_entry_ref = self.src[]._entries.__get_ref[
__mlir_attr.`0: i1`,
Self.imm_dict_lifetime,
](self.index)
var opt_entry_ref = self.src[]._entries.__get_ref(self.index)
if opt_entry_ref[]:

@parameter
if forward:
Expand All @@ -109,9 +106,7 @@ struct _DictEntryIter[
self.index -= 1

self.seen += 1
# Super unsafe, but otherwise we have to do a bunch of super
# unsafe reference lifetime casting.
return opt_entry_ref.unsafe_bitcast[DictEntry[K, V]]()
return opt_entry_ref[].value()[]

@parameter
if forward:
Expand Down Expand Up @@ -644,67 +639,39 @@ struct Dict[K: KeyElement, V: CollectionElement](
return default.value()[]
raise "KeyError"

fn __iter__[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictKeyIter[
K, V, mutability, self_life
]:
fn __iter__(
self: Reference[Self, _, _],
) -> _DictKeyIter[K, V, self.is_mutable, self.lifetime]:
"""Iterate over the dict's keys as immutable references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary keys.
"""
return _DictKeyIter(
_DictEntryIter[K, V, mutability, self_life](0, 0, Reference(self))
)
return _DictKeyIter(_DictEntryIter(0, 0, self))

fn keys[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictKeyIter[
K, V, mutability, self_life
]:
fn keys(
self: Reference[Self, _, _]
) -> _DictKeyIter[K, V, self.is_mutable, self.lifetime]:
"""Iterate over the dict's keys as immutable references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary keys.
"""
return Self.__iter__(self)

fn values[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictValueIter[K, V, mutability, self_life]:
fn values(
self: Reference[Self, _, _]
) -> _DictValueIter[K, V, self.is_mutable, self.lifetime]:
"""Iterate over the dict's values as references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of references to the dictionary values.
"""
return _DictValueIter(
_DictEntryIter[K, V, mutability, self_life](0, 0, Reference(self))
)
return _DictValueIter(_DictEntryIter(0, 0, self))

fn items[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictEntryIter[K, V, mutability, self_life]:
fn items(
self: Reference[Self, _, _]
) -> _DictEntryIter[K, V, self.is_mutable, self.lifetime]:
"""Iterate over the dict's entries as immutable references.
These can't yet be unpacked like Python dict items, but you can
Expand All @@ -715,16 +682,10 @@ struct Dict[K: KeyElement, V: CollectionElement](
print(e[].key, e[].value)
```
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary entries.
"""
return _DictEntryIter[K, V, mutability, self_life](
0, 0, Reference(self)
)
return _DictEntryIter(0, 0, self)

fn update(inout self, other: Self, /):
"""Update the dictionary with the key/value pairs from other, overwriting existing keys.
Expand Down Expand Up @@ -838,23 +799,16 @@ struct Dict[K: KeyElement, V: CollectionElement](

self._n_entries = self.size

fn __reversed__[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictKeyIter[
K, V, mutability, self_life, False
]:
fn __reversed__(
self: Reference[Self, _, _]
) -> _DictKeyIter[K, V, self.is_mutable, self.lifetime, False]:
"""Iterate backwards over the dict keys, returning immutable references.
Returns:
A reversed iterator of immutable references to the dict keys.
"""
var ref = Reference(self)
return _DictKeyIter(
_DictEntryIter[K, V, mutability, self_life, False](
ref[]._reserved - 1, 0, ref
)
_DictEntryIter[forward=False](self[]._reserved - 1, 0, self)
)


Expand Down Expand Up @@ -970,77 +924,45 @@ struct OwnedKwargsDict[V: CollectionElement](Sized, CollectionElement):
"""
return self._dict.pop(key, default^)

fn __iter__[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictKeyIter[
Self.key_type, V, mutability, self_life
]:
fn __iter__(
self: Reference[Self, _, _]
) -> _DictKeyIter[Self.key_type, V, self.is_mutable, self.lifetime]:
"""Iterate over the keyword dict's keys as immutable references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary keys.
"""
# TODO(#36448): Use this instead of the current workaround
# return self._dict.__iter__()
return _DictKeyIter(
_DictEntryIter[Self.key_type, V, mutability, self_life](
0, 0, Reference(self)[]._dict
)
)
return _DictKeyIter(_DictEntryIter(0, 0, self[]._dict))

fn keys[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictKeyIter[
Self.key_type, V, mutability, self_life
]:
fn keys(
self: Reference[Self, _, _],
) -> _DictKeyIter[Self.key_type, V, self.is_mutable, self.lifetime]:
"""Iterate over the keyword dict's keys as immutable references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary keys.
"""
# TODO(#36448): Use this instead of the current workaround
# return self._dict.keys()
return Self.__iter__(self)

fn values[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictValueIter[Self.key_type, V, mutability, self_life]:
fn values(
self: Reference[Self, _, _],
) -> _DictValueIter[Self.key_type, V, self.is_mutable, self.lifetime]:
"""Iterate over the keyword dict's values as references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of references to the dictionary values.
"""
# TODO(#36448): Use this instead of the current workaround
# return self._dict.values()
return _DictValueIter(
_DictEntryIter[Self.key_type, V, mutability, self_life](
0, 0, Reference(self)[]._dict
)
)
return _DictValueIter(_DictEntryIter(0, 0, self[]._dict))

fn items[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life]._mlir_type,
) -> _DictEntryIter[Self.key_type, V, mutability, self_life]:
fn items(
self: Reference[Self, _, _]
) -> _DictEntryIter[Self.key_type, V, self.is_mutable, self.lifetime]:
"""Iterate over the keyword dictionary's entries as immutable references.
These can't yet be unpacked like Python dict items, but you can
Expand All @@ -1051,19 +973,13 @@ struct OwnedKwargsDict[V: CollectionElement](Sized, CollectionElement):
print(e[].key, e[].value)
```
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary entries.
"""

# TODO(#36448): Use this instead of the current workaround
# return Reference(self)[]._dict.items()
return _DictEntryIter[Self.key_type, V, mutability, self_life](
0, 0, Reference(self)[]._dict
)
# return self[]._dict.items()
return _DictEntryIter(0, 0, self[]._dict)

@always_inline("nodebug")
fn _insert(inout self, owned key: Self.key_type, owned value: V):
Expand Down
Loading

0 comments on commit 0fe4cb7

Please sign in to comment.