diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index cc148416a5..47f4b88707 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -39,7 +39,7 @@ def test_multi_entry_block_1(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" @@ -91,7 +91,7 @@ def test_multi_entry_block_2(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" @@ -132,7 +132,7 @@ def test_multi_entry_block_with_dynamic_jump(): assert ctx.normalized, "CFG should be normalized" finish_bb = ctx.get_basic_block(finish_label.value) - cfg_in = list(finish_bb.cfg_in.keys()) + cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish" diff --git a/vyper/utils.py b/vyper/utils.py index ba615e58d7..0fd74b8972 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -1,4 +1,5 @@ import binascii +import itertools import contextlib import decimal import enum @@ -15,7 +16,7 @@ _T = TypeVar("_T") -class OrderedSet(Generic[_T], dict[_T, None]): +class OrderedSet(Generic[_T]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -25,57 +26,82 @@ class OrderedSet(Generic[_T], dict[_T, None]): """ def __init__(self, iterable=None): - super().__init__() + self._data = dict() if iterable is not None: - for item in iterable: - self.add(item) + self.update(iterable) def __repr__(self): - keys = ", ".join(repr(k) for k in self.keys()) + keys = ", ".join(repr(k) for k in self) return f"{{{keys}}}" - def get(self, *args, **kwargs): - raise RuntimeError("can't call get() on OrderedSet!") + def __iter__(self): + return iter(self._data) + + def __contains__(self, item): + return self._data.__contains__(item) + + def __len__(self): + return len(self._data) def first(self): return next(iter(self)) def add(self, item: _T) -> None: - self[item] = None + self._data[item] = None def remove(self, item: _T) -> None: - del self[item] + del self._data[item] + + def drop(self, item: _T): + # friendly version of remove + self._data.pop(item, None) + + def dropmany(self, iterable): + for item in iterable: + self._data.pop(item, None) def difference(self, other): ret = self.copy() - for k in other.keys(): - if k in ret: - ret.remove(k) + ret.dropmany(other) return ret + def update(self, other): + # CMC 2024-03-22 for some reason, this is faster than dict.update? + # (maybe size dependent) + for item in other: + self._data[item] = None + def union(self, other): return self | other - def update(self, other): - super().update(self.__class__.fromkeys(other)) + def __ior__(self, other): + self.update(other) + return self def __or__(self, other): - return self.__class__(super().__or__(other)) + ret = self.copy() + ret.update(other) + return ret + + def __eq__(self, other): + return self._data == other._data def copy(self): - return self.__class__(super().copy()) + cls = self.__class__ + ret = cls.__new__(cls) + ret._data = self._data.copy() + return ret @classmethod def intersection(cls, *sets): - res = OrderedSet() if len(sets) == 0: raise ValueError("undefined: intersection of no sets") - if len(sets) == 1: - return sets[0].copy() - for e in sets[0].keys(): - if all(e in s for s in sets[1:]): - res.add(e) - return res + + ret = sets[0].copy() + for e in sets[0]: + if any(e not in s for s in sets[1:]): + ret.remove(e) + return ret class StringEnum(enum.Enum): diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 0bc1ec0fac..4be4e48f28 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -1,4 +1,5 @@ from typing import Optional +import itertools from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet @@ -53,12 +54,15 @@ def _calculate_liveness(bb: IRBasicBlock) -> bool: orig_liveness = bb.instructions[0].liveness.copy() liveness = bb.out_vars.copy() for instruction in reversed(bb.instructions): - ops = instruction.get_inputs() + ins = instruction.get_inputs() + outs = instruction.get_outputs() + + if ins or outs: + # perf: only copy if changed + liveness = liveness.copy() + liveness.update(ins) + liveness.dropmany(outs) - liveness = liveness.union(OrderedSet.fromkeys(ops)) - out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None - if out in liveness: - liveness.remove(out) instruction.liveness = liveness return orig_liveness != bb.instructions[0].liveness diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 3171fd0172..2bfeaea2a9 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,5 +1,6 @@ from enum import Enum, auto from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union +from functools import cached_property from vyper.utils import OrderedSet diff --git a/vyper/venom/dominators.py b/vyper/venom/dominators.py index ce6a4868cf..d11d61cec3 100644 --- a/vyper/venom/dominators.py +++ b/vyper/venom/dominators.py @@ -58,7 +58,7 @@ def _compute_dominators(self): """ basic_blocks = list(self.dfs_order.keys()) self.dominators = {bb: OrderedSet(basic_blocks) for bb in basic_blocks} - self.dominators[self.entry_block] = OrderedSet({self.entry_block}) + self.dominators[self.entry_block] = OrderedSet([self.entry_block]) changed = True count = len(basic_blocks) ** 2 # TODO: find a proper bound for this while changed: diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 81d46fec3e..e52343ca32 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -49,6 +49,7 @@ def __init__(self, name: IRLabel = None) -> None: self._ast_source_stack = [] self._error_msg_stack = [] + self._bb_index = {} self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) @@ -76,6 +77,15 @@ def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: return self.basic_blocks[-1] + def _get_basicblock_index(self, label: str): + ix = self._bb_index.get(label, -1) + if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: + return ix + # do a reindex + self._bb_index = dict((bb.label, ix) for ix, bb in enumerate(self.basic_blocks)) + return self._bb_index[label] + + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ Get basic block by label. @@ -83,18 +93,16 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ if label is None: return self.basic_blocks[-1] - for bb in self.basic_blocks: - if bb.label.value == label: - return bb - raise AssertionError(f"Basic block '{label}' not found") + ix = self._get_basicblock_index(label) + return self.basic_blocks[ix] def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: """ Get basic block after label. """ - for i, bb in enumerate(self.basic_blocks[:-1]): - if bb.label.value == label.value: - return self.basic_blocks[i + 1] + ix = self._get_basicblock_index(label.value) + if 0 <= ix < len(self.basic_blocks) - 1: + return self.basic_blocks[ix + 1] raise AssertionError(f"Basic block after '{label}' not found") def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index d702b8dc2b..0cb13becf2 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -407,7 +407,7 @@ def _generate_evm_for_instruction( # NOTE: stack in general can contain multiple copies of the same variable, # however we are safe in the case of jmp/djmp/jnz as it's not going to # have multiples. - target_stack_list = list(target_stack.keys()) + target_stack_list = list(target_stack) self._stack_reorder(assembly, stack, target_stack_list) # final step to get the inputs to this instruction ordered