Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve venom performance #5

Merged
merged 6 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/unit/compiler/venom/test_multi_entry_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_multi_entry_block_1():
assert ctx.normalized, "CFG should be normalized"

finish_bb = ctx.get_basic_block(finish_label.value)
cfg_in = list(finish_bb.cfg_in.keys())
cfg_in = list(finish_bb.cfg_in)
assert cfg_in[0].label.value == "target", "Should contain target"
assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish"
assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish"
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_multi_entry_block_2():
assert ctx.normalized, "CFG should be normalized"

finish_bb = ctx.get_basic_block(finish_label.value)
cfg_in = list(finish_bb.cfg_in.keys())
cfg_in = list(finish_bb.cfg_in)
assert cfg_in[0].label.value == "target", "Should contain target"
assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish"
assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish"
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_multi_entry_block_with_dynamic_jump():
assert ctx.normalized, "CFG should be normalized"

finish_bb = ctx.get_basic_block(finish_label.value)
cfg_in = list(finish_bb.cfg_in.keys())
cfg_in = list(finish_bb.cfg_in)
assert cfg_in[0].label.value == "target", "Should contain target"
assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish"
assert cfg_in[2].label.value == "block_1_split_finish", "Should contain block_1_split_finish"
72 changes: 49 additions & 23 deletions vyper/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import binascii
import itertools
import contextlib
import decimal
import enum
Expand All @@ -15,7 +16,7 @@
_T = TypeVar("_T")


class OrderedSet(Generic[_T], dict[_T, None]):
class OrderedSet(Generic[_T]):
"""
a minimal "ordered set" class. this is needed in some places
because, while dict guarantees you can recover insertion order
Expand All @@ -25,57 +26,82 @@ class OrderedSet(Generic[_T], dict[_T, None]):
"""

def __init__(self, iterable=None):
super().__init__()
self._data = dict()
if iterable is not None:
for item in iterable:
self.add(item)
self.update(iterable)

def __repr__(self):
keys = ", ".join(repr(k) for k in self.keys())
keys = ", ".join(repr(k) for k in self)
return f"{{{keys}}}"

def get(self, *args, **kwargs):
raise RuntimeError("can't call get() on OrderedSet!")
def __iter__(self):
return iter(self._data)

def __contains__(self, item):
return self._data.__contains__(item)

def __len__(self):
return len(self._data)

def first(self):
return next(iter(self))

def add(self, item: _T) -> None:
self[item] = None
self._data[item] = None

def remove(self, item: _T) -> None:
del self[item]
del self._data[item]

def drop(self, item: _T):
# friendly version of remove
self._data.pop(item, None)

def dropmany(self, iterable):
for item in iterable:
self._data.pop(item, None)

def difference(self, other):
ret = self.copy()
for k in other.keys():
if k in ret:
ret.remove(k)
ret.dropmany(other)
return ret

def update(self, other):
# CMC 2024-03-22 for some reason, this is faster than dict.update?
# (maybe size dependent)
for item in other:
self._data[item] = None

def union(self, other):
return self | other

def update(self, other):
super().update(self.__class__.fromkeys(other))
def __ior__(self, other):
self.update(other)
return self

def __or__(self, other):
return self.__class__(super().__or__(other))
ret = self.copy()
ret.update(other)
return ret

def __eq__(self, other):
return self._data == other._data

def copy(self):
return self.__class__(super().copy())
cls = self.__class__
ret = cls.__new__(cls)
ret._data = self._data.copy()
return ret

@classmethod
def intersection(cls, *sets):
res = OrderedSet()
if len(sets) == 0:
raise ValueError("undefined: intersection of no sets")
if len(sets) == 1:
return sets[0].copy()
for e in sets[0].keys():
if all(e in s for s in sets[1:]):
res.add(e)
return res

ret = sets[0].copy()
for e in sets[0]:
if any(e not in s for s in sets[1:]):
ret.remove(e)
return ret


class StringEnum(enum.Enum):
Expand Down
14 changes: 9 additions & 5 deletions vyper/venom/analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
import itertools

from vyper.exceptions import CompilerPanic
from vyper.utils import OrderedSet
Expand Down Expand Up @@ -53,12 +54,15 @@ def _calculate_liveness(bb: IRBasicBlock) -> bool:
orig_liveness = bb.instructions[0].liveness.copy()
liveness = bb.out_vars.copy()
for instruction in reversed(bb.instructions):
ops = instruction.get_inputs()
ins = instruction.get_inputs()
outs = instruction.get_outputs()

if ins or outs:
# perf: only copy if changed
liveness = liveness.copy()
liveness.update(ins)
liveness.dropmany(outs)

liveness = liveness.union(OrderedSet.fromkeys(ops))
out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None
if out in liveness:
liveness.remove(out)
instruction.liveness = liveness

return orig_liveness != bb.instructions[0].liveness
Expand Down
1 change: 1 addition & 0 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union
from functools import cached_property

from vyper.utils import OrderedSet

Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/dominators.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _compute_dominators(self):
"""
basic_blocks = list(self.dfs_order.keys())
self.dominators = {bb: OrderedSet(basic_blocks) for bb in basic_blocks}
self.dominators[self.entry_block] = OrderedSet({self.entry_block})
self.dominators[self.entry_block] = OrderedSet([self.entry_block])
changed = True
count = len(basic_blocks) ** 2 # TODO: find a proper bound for this
while changed:
Expand Down
22 changes: 15 additions & 7 deletions vyper/venom/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, name: IRLabel = None) -> None:

self._ast_source_stack = []
self._error_msg_stack = []
self._bb_index = {}

self.add_entry_point(name)
self.append_basic_block(IRBasicBlock(name, self))
Expand Down Expand Up @@ -76,25 +77,32 @@ def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock:

return self.basic_blocks[-1]

def _get_basicblock_index(self, label: str):
ix = self._bb_index.get(label, -1)
if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label:
return ix
# do a reindex
self._bb_index = dict((bb.label, ix) for ix, bb in enumerate(self.basic_blocks))
return self._bb_index[label]


def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock:
"""
Get basic block by label.
If label is None, return the last basic block.
"""
if label is None:
return self.basic_blocks[-1]
for bb in self.basic_blocks:
if bb.label.value == label:
return bb
raise AssertionError(f"Basic block '{label}' not found")
ix = self._get_basicblock_index(label)
return self.basic_blocks[ix]

def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock:
"""
Get basic block after label.
"""
for i, bb in enumerate(self.basic_blocks[:-1]):
if bb.label.value == label.value:
return self.basic_blocks[i + 1]
ix = self._get_basicblock_index(label.value)
if 0 <= ix < len(self.basic_blocks) - 1:
return self.basic_blocks[ix + 1]
raise AssertionError(f"Basic block after '{label}' not found")

def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]:
Expand Down
2 changes: 1 addition & 1 deletion vyper/venom/venom_to_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _generate_evm_for_instruction(
# NOTE: stack in general can contain multiple copies of the same variable,
# however we are safe in the case of jmp/djmp/jnz as it's not going to
# have multiples.
target_stack_list = list(target_stack.keys())
target_stack_list = list(target_stack)
self._stack_reorder(assembly, stack, target_stack_list)

# final step to get the inputs to this instruction ordered
Expand Down
Loading