diff --git a/LICENSE b/LICENSE index 7d184d50..061938a8 100644 --- a/LICENSE +++ b/LICENSE @@ -17,3 +17,53 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +-------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF +hereby grants Licensee a nonexclusive, royalty-free, world-wide +license to reproduce, analyze, test, perform and/or display publicly, +prepare derivative works, distribute, and otherwise use Python +alone or in any derivative version, provided, however, that PSF's +License Agreement and PSF's notice of copyright, i.e., "Copyright (c) +2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation; All Rights +Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. diff --git a/byterun/abstractvm.py b/byterun/abstractvm.py new file mode 100644 index 00000000..6732fe6b --- /dev/null +++ b/byterun/abstractvm.py @@ -0,0 +1,226 @@ +"""Classes to ease the abstraction of pyvm2.VirtualMachine. + +This module provides 2 classes that provide different kinds of +abstraction. AbstractVirtualMachine abstracts operators and other magic method +uses. AncestorTraversalVirtualMachine changes the execution order of basic +blocks so that each only executes once. +""" + +import logging + + +from byterun import pycfg +from byterun import pyvm2 +import six + +log = logging.getLogger(__name__) + + +class AbstractVirtualMachine(pyvm2.VirtualMachine): + """A base class for abstract interpreters based on VirtualMachine. + + AbstractVirtualMachine replaces the default metacyclic implementation of + operators and other operations that actually forward to a python magic + method with a virtual machine level attribute get and a call to the + returned method. + """ + + def __init__(self): + super(AbstractVirtualMachine, self).__init__() + # The key is the instruction suffix and the value is the magic method + # name. + binary_operator_name_mapping = dict( + ADD="__add__", + AND="__and__", + DIVIDE="__div__", + FLOOR_DIVIDE="__floordiv__", + LSHIFT="__lshift__", + MODULO="__mod__", + MULTIPLY="__mul__", + OR="__or__", + POWER="__pow__", + RSHIFT="__rshift__", + SUBSCR="__getitem__", + SUBTRACT="__sub__", + TRUE_DIVIDE="__truediv__", + XOR="__xor__", + ) + # Use the above data to generate wrappers for each magic operators. This + # replaces the original dict since any operator that is not listed here + # will not work, so it is better to have it cause a KeyError. + self.binary_operators = dict((op, self.magic_operator(magic)) + for op, magic in + binary_operator_name_mapping.iteritems()) + # TODO(ampere): Add support for unary and comparison operators + + def magic_operator(self, name): + # TODO(ampere): Implement support for r-operators + def magic_operator_wrapper(x, y): + return self.call_function(self.load_attr(x, name), + [y], {}) + return magic_operator_wrapper + + reversable_operators = set([ + "__add__", "__sub__", "__mul__", + "__div__", "__truediv__", "__floordiv__", + "__mod__", "__divmod__", "__pow__", + "__lshift__", "__rshift__", "__and__", "__or__", "__xor__" + ]) + + @staticmethod + def reverse_operator_name(name): + if name in AbstractVirtualMachine.reversable_operators: + return "__r" + name[2:] + return None + + def build_slice(self, start, stop, step): + return slice(start, stop, step) + + def byte_GET_ITER(self): + self.push(self.load_attr(self.pop(), "__iter__")) + self.call_function_from_stack(0, [], {}) + + def byte_FOR_ITER(self, jump): + try: + self.push(self.load_attr(self.top(), "next")) + self.call_function_from_stack(0, [], {}) + self.jump(self.frame.f_lasti) + except StopIteration: + self.pop() + self.jump(jump) + + def byte_STORE_MAP(self): + # pylint: disable=unbalanced-tuple-unpacking + the_map, val, key = self.popn(3) + self.store_subscr(the_map, key, val) + self.push(the_map) + + def del_subscr(self, obj, key): + self.call_function(self.load_attr(obj, "__delitem__"), + [key], {}) + + def store_subscr(self, obj, key, val): + self.call_function(self.load_attr(obj, "__setitem__"), + [key, val], {}) + + def sliceOperator(self, op): # pylint: disable=invalid-name + start = 0 + end = None # we will take this to mean end + op, count = op[:-2], int(op[-1]) + if count == 1: + start = self.pop() + elif count == 2: + end = self.pop() + elif count == 3: + end = self.pop() + start = self.pop() + l = self.pop() + if end is None: + end = self.call_function(self.load_attr(l, "__len__"), [], {}) + if op.startswith('STORE_'): + self.call_function(self.load_attr(l, "__setitem__"), + [self.build_slice(start, end, 1), self.pop()], + {}) + elif op.startswith('DELETE_'): + self.call_function(self.load_attr(l, "__delitem__"), + [self.build_slice(start, end, 1)], + {}) + else: + self.push(self.call_function(self.load_attr(l, "__getitem__"), + [self.build_slice(start, end, 1)], + {})) + + def byte_UNPACK_SEQUENCE(self, count): + seq = self.pop() + itr = self.call_function(self.load_attr(seq, "__iter__"), [], {}) + values = [] + for _ in range(count): + # TODO(ampere): Fix for python 3 + values.append(self.call_function(self.load_attr(itr, "next"), + [], {})) + for value in reversed(values): + self.push(value) + + +class AncestorTraversalVirtualMachine(AbstractVirtualMachine): + """An abstract interpreter implementing a traversal of basic blocks. + + This class replaces run_frame with a traversal that executes all basic + blocks in ancestor first order starting with the entry block. This uses + pycfg.BlockTable.get_ancestors_first_traversal(); see it's documentation for + more information about the order. + + As the traversal is done there is no attempt to rollback the state, so + parallel paths in the CFG (even those that cannot be run in the same + execution) will often see each other's side-effects. Effectively this means + that the execution of each basic block needs to commute with the execution + of other blocks it is not ordered with. + """ + + def __init__(self): + super(AncestorTraversalVirtualMachine, self).__init__() + self.cfg = pycfg.CFG() + + def frame_traversal_setup(self, frame): + """Initialize a frame to allow ancestors first traversal. + + Args: + frame: The execution frame to update. + """ + frame.block_table = self.cfg.get_block_table(frame.f_code) + frame.order = frame.block_table.get_ancestors_first_traversal() + assert frame.f_lasti == 0 + + def frame_traversal_next(self, frame): + """Move the frame instruction pointer to the next instruction. + + This implements the next instruction operation on the ancestors first + traversal order. + + Args: + frame: The execution frame to update. + + Returns: + False if the traversal is done (every instruction in the frames code + has been executed. True otherwise. + """ + head = frame.order[0] + if frame.f_lasti < head.begin or frame.f_lasti > head.end: + frame.order.pop(0) + if not frame.order: + return False + head = frame.order[0] + if frame.f_lasti != head.begin: + log.debug("natural next %d, order next %d", + frame.f_lasti, head.begin) + frame.f_lasti = head.begin + return True + + def run_frame(self, frame): + """Run a frame until it returns (somehow). + + Exceptions are raised, the return value is returned. + + This implementation executes in ancestors first order. See + pycfg.BlockTable.get_ancestors_first_traversal(). + + Args: + frame: The execution frame. + + Returns: + The return value of the frame after execution. + """ + self.push_frame(frame) + self.frame_traversal_setup(frame) + while True: + why = self.run_instruction() + # TODO(ampere): Store various breaking "why"s so they can be handled + if not self.frame_traversal_next(frame): + break + self.pop_frame() + + # TODO(ampere): We don't really support exceptions. + if why == "exception": + six.reraise(*self.last_exception) + + return self.return_value diff --git a/byterun/execfile.py b/byterun/execfile.py index d157a2ab..d31ab98a 100644 --- a/byterun/execfile.py +++ b/byterun/execfile.py @@ -105,6 +105,7 @@ def run_python_file(filename, args, package=None): main_mod.__file__ = filename if package: main_mod.__package__ = package + # TODO(ampere): This may be incorrect if we are overriding builtins main_mod.__builtins__ = BUILTINS # Set sys.argv and the first path element properly. diff --git a/byterun/pycfg.py b/byterun/pycfg.py new file mode 100644 index 00000000..1df202e1 --- /dev/null +++ b/byterun/pycfg.py @@ -0,0 +1,599 @@ +"""Build a Control Flow Graph (CFG) from CPython bytecode. + +A class that builds and provides access to a CFG built from CPython bytecode. + +For a basic introduction to CFGs see the wikipedia article: +http://en.wikipedia.org/wiki/Control_flow_graph +""" + +import bisect +import dis +import itertools +import logging + + +import six + +PY3, PY2 = six.PY3, not six.PY3 + +if six.PY3: + byteint = lambda b: b +else: + byteint = ord + +log = logging.getLogger(__name__) + +# The following sets contain instructions with specific branching properties. + +# Untargetted unconditional jumps always jump, but do so to some statically +# unknown location. Examples include, raising exceptions and returning from +# functions: in both cases you are jumping but you cannot statically determine +# to where. +_UNTARGETTED_UNCONDITIONAL_JUMPS = frozenset([ + dis.opmap["BREAK_LOOP"], + dis.opmap["RETURN_VALUE"], + dis.opmap["RAISE_VARARGS"], + ]) + +# Untargetted conditional jumps may jump to a statically unknown location, but +# may allow control to continue to the next instruction. +_UNTARGETTED_CONDITIONAL_JUMPS = frozenset([ + dis.opmap["END_FINALLY"], + dis.opmap["EXEC_STMT"], + dis.opmap["WITH_CLEANUP"], + dis.opmap["IMPORT_NAME"], + dis.opmap["IMPORT_FROM"], + dis.opmap["IMPORT_STAR"], + dis.opmap["CALL_FUNCTION"], + dis.opmap["CALL_FUNCTION_VAR"], + dis.opmap["CALL_FUNCTION_KW"], + dis.opmap["CALL_FUNCTION_VAR_KW"], + dis.opmap["YIELD_VALUE"], # yield is treated as both branching somewhere + # unknown and to the next instruction. + ]) + +# Targetted unconditional jumps always jump to a statically known target +# instruction. +_TARGETTED_UNCONDITIONAL_JUMPS = frozenset([ + dis.opmap["CONTINUE_LOOP"], + dis.opmap["JUMP_FORWARD"], + dis.opmap["JUMP_ABSOLUTE"], + ]) + +# Targetted conditional jumps either jump to a statically known target or they +# continue to the next instruction. +_TARGETTED_CONDITIONAL_JUMPS = frozenset([ + dis.opmap["POP_JUMP_IF_TRUE"], + dis.opmap["POP_JUMP_IF_FALSE"], + dis.opmap["JUMP_IF_TRUE_OR_POP"], + dis.opmap["JUMP_IF_FALSE_OR_POP"], + dis.opmap["FOR_ITER"], + ]) + +_TARGETTED_JUMPS = (_TARGETTED_CONDITIONAL_JUMPS | + _TARGETTED_UNCONDITIONAL_JUMPS) + +_CONDITIONAL_JUMPS = (_TARGETTED_CONDITIONAL_JUMPS | + _UNTARGETTED_CONDITIONAL_JUMPS) + +_UNTARGETTED_JUMPS = (_UNTARGETTED_CONDITIONAL_JUMPS | + _UNTARGETTED_UNCONDITIONAL_JUMPS) + + +def _parse_instructions(code): + """A generator yielding each instruction in code. + + Args: + code: A bytecode string (not a code object). + + Yields: + A triple (opcode, argument or None, offset) for each instruction in code. + Where offset is the byte offset of the beginning of the instruction. + + This is derived from dis.findlabels in the Python standard library. + """ + n = len(code) + i = 0 + while i < n: + offset = i + op = byteint(code[i]) + i += 1 + oparg = None + if op >= dis.HAVE_ARGUMENT: + oparg = byteint(code[i]) + byteint(code[i+1])*256 + i += 2 + yield (op, oparg, offset) + + +class InstructionsIndex(object): + """An index of all the instructions in a code object. + + Attributes: + instruction_offsets: A list of instruction offsets. + """ + + def __init__(self, code): + self.instruction_offsets = [i for _, _, i in _parse_instructions(code)] + + def prev(self, offset): + """Return the offset of the previous instruction. + + Args: + offset: The offset of an instruction in the code. + + Returns: + The offset of the instruction immediately before the instruction specified + by the offset argument. + + Raises: + IndexError: If the offset is outside the range of valid instructions. + """ + if offset < 0: + raise IndexError("Instruction offset cannot be less than 0") + if offset > self.instruction_offsets[-1]: + raise IndexError("Instruction offset cannot be greater than " + "the offset of the last instruction") + # Find the rightmost instruction offset that is less than the offset + # argument, this will be the previous instruction because it is closest + # instruction that is before the offset. + return self.instruction_offsets[ + bisect.bisect_left(self.instruction_offsets, offset) - 1] + + def next(self, offset): + """Return the offset of the next instruction. + + Args: + offset: The offset of an instruction in the code. + + Returns: + The offset of the instruction immediately after the instruction specified + by the offset argument. + + Raises: + IndexError: If the offset is outside the range of valid instructions. + """ + if offset < 0: + raise IndexError("Instruction offset cannot be less than 0") + if offset > self.instruction_offsets[-1]: + raise IndexError("Instruction offset cannot be greater than " + "the offset of the last instruction") + # Find the leftmost instruction offset that is greater than the offset + # argument, this will be the next instruction because it is closest + # instruction that is after the offset. + return self.instruction_offsets[ + bisect.bisect_right(self.instruction_offsets, offset)] + + +def _find_jumps(code): + """Detect all offsets in a byte code which are instructions that can jump. + + Args: + code: A bytecode string (not a code object). + + Returns: + A pair of a dict and set. The dict mapping the offsets of jump instructions + to sets with the same semantics as outgoing in Block. The set of all the + jump targets it found. + """ + all_targets = set() + jumps = {} + for op, oparg, i in _parse_instructions(code): + targets = set() + is_jump = False + next_i = i + 1 if oparg is None else i + 3 + if oparg is not None: + if op in _TARGETTED_JUMPS: + # Add the known jump target + is_jump = True + if op in dis.hasjrel: + targets.add(next_i+oparg) + all_targets.add(next_i+oparg) + elif op in dis.hasjabs: + targets.add(oparg) + all_targets.add(oparg) + else: + targets.add(None) + + if op in _CONDITIONAL_JUMPS: + # The jump is conditional so add the next instruction as a target + is_jump = True + targets.add(next_i) + all_targets.add(next_i) + if op in _UNTARGETTED_JUMPS: + # The jump is untargetted so add None to mean unknown target + is_jump = True + targets.add(None) + + if is_jump: + jumps[i] = targets + return jumps, all_targets + + +class Block(object): + """A Block instance represents a basic block in the CFG. + + Each basic block has at most one jump instruction which is always at the + end. In this representation we will not add forward jumps to blocks that don't + have them and instead just take a block that has no jump instruction as + implicitly jumping to the next instruction when it reaches the end of the + block. Control may only jump to the beginning of a basic block, so if any + instruction in a basic block executes they all do and they do so in order. + + Attributes: + + begin, end: The beginning and ending (resp) offsets of the basic block in + bytes. + + outgoing: A set of blocks that the last instruction of this basic block can + branch to. A None in this set denotes that there are statically + unknown branch targets (due to exceptions, for instance). + + incoming: A set of blocks that can branch to the beginning of this + basic block. + + code: The code object that contains this basic block. + + This object uses the identity hash and equality. This is correct as there + should never be more than one block object that represents the same actual + basic block. + """ + + def __init__(self, begin, end, code, block_table): + self.outgoing = set() + self.incoming = set() + self._dominators = set() + self._reachable_from = None + self.begin = begin + self.end = end + self.code = code + self.block_table = block_table + + def reachable_from(self, other): + """Return true if self is reachable from other. + + Args: + other: A block. + + Returns: + A boolean + """ + return other in self._reachable_from + + def dominates(self, other): + """Return true if self dominates other. + + Args: + other: A block. + + Returns: + A boolean + """ + # This is an instance of my own class and this inversion makes the interface + # cleaner + # pylint: disable=protected-access + return self in other._dominators + + def get_name(self): + return "{}:{}-{}".format(self.block_table.get_filename(), + self.block_table.get_line(self.begin), + self.block_table.get_line(self.end)) + + def __repr__(self): + return "{}(outgoing={{{}}},incoming={{{}}})".format( + self.get_name(), + ", ".join(b.get_name() for b in self.outgoing), + ", ".join(b.get_name() for b in self.incoming)) + + +class BlockTable(object): + """A table of basic blocks in a single bytecode object. + + A None in an outgoing list means that that block can branch to an unknown + location (usually by returning or raising an exception). At the moment, + continue and break are also treated this way, however it will be possible to + remove them as the static target is known from the enclosing SETUP_LOOP + instruction. + + The algorithm to build the Control Flow Graph (CFG) is the naive algorithm + presented in many compilers classes and probably most compiler text books. We + simply find all the instructions where CFGs end and begin, make sure they + match up (there is a begin after every end), and then build a basic block for + ever range between a beginning and an end. This may not produce the smallest + possible CFG, but it will produce a correct one because every branch point + becomes the end of a basic block and every instruction that is branched to + becomes the beginning of a basic block. + """ + + def __init__(self, code): + """Construct a table with the blocks in the given code object. + + Args: + code: a code object (such as function.func_code) to process. + """ + self.code = code + self.line_offsets, self.lines = zip(*dis.findlinestarts(self.code)) + + instruction_index = InstructionsIndex(code.co_code) + + # Get a map from jump instructions to jump targets and a combined set of all + # targets. + jumps, all_targets = _find_jumps(code.co_code) + + # TODO(ampere): Using dis.findlabels may not be the right + # thing. Specifically it is not clear when the targets of SETUP_* + # instructions should be used to make basic blocks. + + # Make a list of all the directly obvious block begins from the jump targets + # found above and the labels found by dis. + direct_begins = all_targets.union(dis.findlabels(code.co_code)) + + # Any jump instruction must be the end of a basic block. + direct_ends = jumps.viewkeys() + + # The actual sorted list of begins is build using the direct_begins along + # with all instructions that follow a jump instruction. Also the beginning + # of the code is a begin. + begins = [0] + sorted(set(list(direct_begins) + + [instruction_index.next(i) for i in direct_ends + if i < len(code.co_code) - 1])) + # The actual ends are every instruction that proceeds a real block begin and + # the last instruction in the code. Since we included the instruction after + # every jump above this will include every jump and every instruction that + # comes before a target. + ends = ([instruction_index.prev(i) for i in begins if i > 0] + + [instruction_index.instruction_offsets[-1]]) + + # Add targets for the ends of basic blocks that don't have a real jump + # instruction. + for end in ends: + if end not in jumps: + jumps[end] = set([instruction_index.next(end)]) + + # Build a reverse mapping from jump targets to the instructions that jump to + # them. + reversemap = {0: set()} + for (jump, targets) in jumps.items(): + for target in targets: + reversemap.setdefault(target, set()).add(jump) + for begin in begins: + if begin not in reversemap: + reversemap[begin] = set() + + assert len(begins) == len(ends) + + # Build the actual basic blocks by pairing the begins and ends directly. + self._blocks = [Block(begin, end, code=code, block_table=self) + for begin, end in itertools.izip(begins, ends)] + # Build a begins list for use with bisect + self._block_begins = [b.begin for b in self._blocks] + # Fill in incoming and outgoing + for block in self._blocks: + block.outgoing = frozenset(self.get_basic_block(o) if + o is not None else None + for o in jumps[block.end]) + block.incoming = frozenset(self.get_basic_block(o) + for o in reversemap[block.begin]) + # TODO(ampere): Both _dominators and _reachable_from are O(n^2) where n is + # the number of blocks. This could be corrected by using a tree and + # searching down it for lookups. + self._compute_dominators() + # Compute all the reachability information by starting recursion from each + # node. + # TODO(ampere): This could be much more efficient, but graphs are small. + for block in self._blocks: + if not block.incoming: + self._compute_reachable_from(block, frozenset()) + + def get_basic_block(self, index): + """Get the basic block that contains the instruction at the given index.""" + return self._blocks[bisect.bisect_right(self._block_begins, index) - 1] + + def get_line(self, index): + """Get the line number for an instruction. + + Args: + index: The offset of the instruction. + + Returns: + The line number of the specified instruction. + """ + return self.lines[max(bisect.bisect_right(self.line_offsets, index)-1, 0)] + + def get_filename(self): + """Get the filename of the code object used in this table. + + Returns: + The string filename. + """ + return self.code.co_filename + + @staticmethod + def _compute_reachable_from(current, history): + """Compute reachability information starting from current. + + The performs a depth first traversal over the graph and adds information to + Block._reachable_from about what paths reach each node. + + Args: + current: The current node in the traversal. + history: A set of nodes that are on the current path to this node. + """ + orig = current._reachable_from # pylint: disable=protected-access + new = history | (orig or set()) + # The base case is that there is no new information, about this node. This + # comparison is why None is used above; we need to be able to distinguish + # nodes we have never touched from nodes with an empty reachable_from set. + if new != orig: + current._reachable_from = new # pylint: disable=protected-access + for child in current.outgoing: + if child: + # pylint: disable=protected-access + BlockTable._compute_reachable_from(child, history | {current}) + + def reachable_from(self, a, b): + """True if the instruction at a is reachable from the instruction at b.""" + block_a = self.get_basic_block(a) + block_b = self.get_basic_block(b) + if block_a == block_b: + return a >= b + else: + return block_a.reachable_from(block_b) + + def _compute_dominators(self): + """Compute dominators for all nodes by iteration. + """ + # pylint: disable=protected-access + # For accessing Block._dominators + entry = self._blocks[0] + # Initialize dominators for the entry node to itself + entry._dominators = frozenset([entry]) + # Initialize all other nodes to be dominated by all nodes + all_blocks_set = frozenset(self._blocks) + for block in self._blocks[1:]: # all but entry block + block._dominators = all_blocks_set + # Now we perform iteration to solve for the dominators. + while True: + # TODO(ampere): use a worklist here. But graphs are small. + changed = False + for block in self._blocks[1:]: # all but entry block + # Compute new dominator information for block by taking the intersection + # of the dominators of every incoming block and adding itself. + new_dominators = all_blocks_set + for pred in block.incoming: + new_dominators &= pred._dominators + new_dominators |= {block} + # Update only if something changed. + if new_dominators != block._dominators: + block._dominators = new_dominators + changed = True + # If we did a pass without changing anything exit. + if not changed: + break + + def dominates(self, a, b): + """True if the instruction at a dominates the instruction at b.""" + block_a = self.get_basic_block(a) + block_b = self.get_basic_block(b) + if block_a == block_b: + # if they are in the same block domination is the same as instruction + # ordering + return a <= b + else: + return block_a.dominates(block_b) + + def get_ancestors_first_traversal(self): + """Build an ancestors first traversal of the blocks in this table. + + Back edges are detected and handled specially. Specifically, the back edge + is ignored for blocks in the cycle (allowing the cycle to be processed), but + we do not allow the blocks after the loop to come before any block in the + loop. + + Returns: + A list of blocks in the proper order. + """ + # TODO(ampere): This assumes all loops are natural. This may be false, but I + # kinda doubt it. The python compiler is very well behaved. + order = [self._blocks[0]] + # A partially processed block has been added to the order, but is part of a + # loop that has not been fully processed. + partially_processed = set() + worklist = list(self._blocks) + while worklist: + block = worklist.pop(0) + # We can process a block if: + # 1) All forward incoming blocks are in the order + # 2) All partially processed blocks are reachable from this block + forward_incoming = set(b for b in block.incoming + if not block.dominates(b)) + all_forward_incoming_ordered = forward_incoming.issubset(order) + # TODO(ampere): Replace forward_incoming in order check with a counter + # that counts the remaining blocks not in order. Similarly below for + # incoming. + all_partially_processed_reachable = all(b.reachable_from(block) + for b in partially_processed) + if (not all_forward_incoming_ordered or + not all_partially_processed_reachable): + continue + # When a node is processed: + # If all incoming blocks (forward and backward) are in the order add to + # the order or remove from partially_processed as needed + # Otherwise, there are backward incoming blocks that are not in the + # order, and we add block to the order and to partially_processed + # We add children to the work list if we either removed block from + # partially_processed or added it to order. + all_incoming_ordered = block.incoming.issubset(order) + # When adding to the work list remove None outgoing edges since they + # represent unknown targets that we cannot handle. + children = filter(None, block.outgoing) + if all_incoming_ordered: + if block in partially_processed: + # block was waiting on a cycle it is part of, but now the cycle is + # processed. + partially_processed.remove(block) + worklist += children + elif block not in order: + # block is ready to add and is not in the order. + order.append(block) + worklist += children + elif block not in order: + # block is not in the order and is part of a cycle. + partially_processed.add(block) + order.append(block) + worklist += children + return order + + +class CFG(object): + """A Control Flow Graph object. + + The CFG may contain any number of code objects, but edges never go between + code objects. + """ + + def __init__(self): + """Initialize a CFG object.""" + self._block_tables = {} + + def get_block_table(self, code): + """Get (building if needed) the BlockTable for a given code object.""" + if code in self._block_tables: + ret = self._block_tables[code] + else: + ret = BlockTable(code) + self._block_tables[code] = ret + return ret + + def get_basic_block(self, code, index): + """Get a basic block by code object and index.""" + blocktable = self.get_block_table(code) + return blocktable.get_basic_block(index) + + +def _bytecode_repr(code): + """Generate a python expression that evaluates to the bytecode. + + Args: + code: A python code string. + Returns: + A human readable and python parsable expression that gives the bytecode. + """ + ret = [] + for op, oparg, i in _parse_instructions(code): + sb = "dis.opmap['" + dis.opname[op] + "']" + if oparg is not None: + sb += ", " + str(oparg & 255) + ", " + str((oparg >> 8) & 255) + sb += ", # " + str(i) + if oparg is not None: + if op in dis.hasjrel: + sb += ", dest=" + str(i+3+oparg) + elif op in dis.hasjabs: + sb += ", dest=" + str(oparg) + else: + sb += ", arg=" + str(oparg) + ret.append(sb) + return "pycfg._list_to_string([\n " + "\n ".join(ret) + "\n ])" + + +def _list_to_string(lst): + return "".join(chr(c) for c in lst) diff --git a/byterun/pyobj.py b/byterun/pyobj.py index cfd536c3..2b86b7d1 100644 --- a/byterun/pyobj.py +++ b/byterun/pyobj.py @@ -1,5 +1,9 @@ """Implementations of Python fundamental objects for Byterun.""" + +# TODO(ampere): Add doc strings and remove this. +# pylint: disable=missing-docstring + import collections import inspect import types @@ -28,6 +32,19 @@ class Function(object): '_vm', '_func', ] + CO_OPTIMIZED = 0x0001 + CO_NEWLOCALS = 0x0002 + CO_VARARGS = 0x0004 + CO_VARKEYWORDS = 0x0008 + CO_NESTED = 0x0010 + CO_GENERATOR = 0x0020 + CO_NOFREE = 0x0040 + CO_FUTURE_DIVISION = 0x2000 + CO_FUTURE_ABSOLUTE_IMPORT = 0x4000 + CO_FUTURE_WITH_STATEMENT = 0x8000 + CO_FUTURE_PRINT_FUNCTION = 0x10000 + CO_FUTURE_UNICODE_LITERALS = 0x20000 + def __init__(self, name, code, globs, defaults, closure, vm): self._vm = vm self.func_code = code @@ -61,24 +78,19 @@ def __get__(self, instance, owner): return self def __call__(self, *args, **kwargs): - if PY2 and self.func_name in ["", "", ""]: + if PY2 and self.func_name in ['', '', '']: # D'oh! http://bugs.python.org/issue19611 Py2 doesn't know how to # inspect set comprehensions, dict comprehensions, or generator # expressions properly. They are always functions of one argument, # so just do the right thing. - assert len(args) == 1 and not kwargs, "Surprising comprehension!" - callargs = {".0": args[0]} + assert len(args) == 1 and not kwargs, 'Surprising comprehension!' + callargs = {'.0': args[0]} else: - try: - callargs = inspect.getcallargs(self._func, *args, **kwargs) - except Exception as e: - import pudb;pudb.set_trace() # -={XX}=-={XX}=-={XX}=- - raise + callargs = inspect.getcallargs(self._func, *args, **kwargs) frame = self._vm.make_frame( self.func_code, callargs, self.func_globals, self.func_locals ) - CO_GENERATOR = 32 # flag for "this code uses yield" - if self.func_code.co_flags & CO_GENERATOR: + if self.func_code.co_flags & self.CO_GENERATOR: gen = Generator(frame, self._vm) frame.generator = gen retval = gen @@ -88,22 +100,86 @@ def __call__(self, *args, **kwargs): class Class(object): - def __init__(self, name, bases, methods): + """ + The VM level mirror of python class type objects. + """ + + def __init__(self, name, bases, methods, vm): + self._vm = vm self.__name__ = name self.__bases__ = bases + self.__mro__ = self._compute_mro(self) self.locals = dict(methods) + self.locals['__name__'] = self.__name__ + self.locals['__mro__'] = self.__mro__ + self.locals['__bases__'] = self.__bases__ + + @classmethod + def mro_merge(cls, seqs): + """ + Merge a sequence of MROs into a single resulting MRO. + This code is copied from the following URL with print statments removed. + https://www.python.org/download/releases/2.3/mro/ + """ + res = [] + while True: + nonemptyseqs = [seq for seq in seqs if seq] + if not nonemptyseqs: + return res + for seq in nonemptyseqs: # find merge candidates among seq heads + cand = seq[0] + nothead = [s for s in nonemptyseqs if cand in s[1:]] + if nothead: + cand = None # reject candidate + else: + break + if not cand: + raise TypeError("Illegal inheritance.") + res.append(cand) + for seq in nonemptyseqs: # remove candidate + if seq[0] == cand: + del seq[0] + + @classmethod + def _compute_mro(cls, c): + """ + Compute the class precedence list (mro) according to C3. + This code is copied from the following URL with print statments removed. + https://www.python.org/download/releases/2.3/mro/ + """ + return tuple(cls.mro_merge([[c]] + + [list(base.__mro__) for base in c.__bases__] + + [list(c.__bases__)])) def __call__(self, *args, **kw): - return Object(self, self.locals, args, kw) + return self._vm.make_instance(self, args, kw) def __repr__(self): # pragma: no cover return '' % (self.__name__, id(self)) + def resolve_attr(self, name): + """ + Find an attribute in self and return it raw. This does not handle + properties or method wrapping. + """ + for base in self.__mro__: + # The following code does a double lookup on the dict, however + # measurements show that this is faster than either a special + # sentinel value or catching KeyError. + # Handle both VM classes and python host environment classes. + if isinstance(base, Class): + if name in base.locals: + return base.locals[name] + else: + if name in base.__dict__: + # Avoid using getattr so we can handle method wrapping + return base.__dict__[name] + raise AttributeError( + "%r class has no attribute %r" % (self.__name__, name) + ) + def __getattr__(self, name): - try: - val = self.locals[name] - except KeyError: - raise AttributeError("Fooey: %r" % (name,)) + val = self.resolve_attr(name) # Check if we have a descriptor get = getattr(val, '__get__', None) if get: @@ -113,22 +189,29 @@ def __getattr__(self, name): class Object(object): - def __init__(self, _class, methods, args, kw): + + def __init__(self, _class, args, kw): + # pylint: disable=protected-access + self._vm = _class._vm self._class = _class - self.locals = methods - if '__init__' in methods: - methods['__init__'](self, *args, **kw) + self.locals = {} + if '__init__' in _class.locals: + _class.locals['__init__'](self, *args, **kw) def __repr__(self): # pragma: no cover return '<%s Instance at 0x%08x>' % (self._class.__name__, id(self)) def __getattr__(self, name): - try: + if name in self.locals: val = self.locals[name] - except KeyError: - raise AttributeError( - "%r object has no attribute %r" % (self._class.__name__, name) - ) + else: + try: + val = self._class.resolve_attr(name) + except AttributeError: + raise AttributeError( + "%r object has no attribute %r" % + (self._class.__name__, name) + ) # Check if we have a descriptor get = getattr(val, '__get__', None) if get: @@ -136,8 +219,11 @@ def __getattr__(self, name): # Not a descriptor, return the value. return val + # TODO(ampere): Does this need a __setattr__ and __delattr__ implementation? + class Method(object): + def __init__(self, obj, _class, func): self.im_self = obj self.im_class = _class @@ -176,6 +262,7 @@ class Cell(object): actual value. """ + def __init__(self, value): self.contents = value @@ -190,6 +277,28 @@ def set(self, value): class Frame(object): + """ + An interpreter frame. This contains the local value and block + stacks and the associated code and pointer. The most complex usage + is with generators in which a frame is stored and then repeatedly + reactivated. Other than that frames are created executed and then + discarded. + + Attributes: + f_code: The code object this frame is executing. + f_globals: The globals dict used for global name resolution. + f_locals: Similar for locals. + f_builtins: Similar for builtins. + f_back: The frame above self on the stack. + f_lineno: The first line number of the code object. + f_lasti: The instruction pointer. Despite its name (which matches actual + python frames) this points to the next instruction that will be executed. + block_stack: A stack of blocks used to manage exceptions, loops, and + "with"s. + data_stack: The value stack that is used for instruction operands. + generator: None or a Generator object if this frame is a generator frame. + """ + def __init__(self, f_code, f_globals, f_locals, f_back): self.f_code = f_code self.f_globals = f_globals @@ -205,16 +314,14 @@ def __init__(self, f_code, f_globals, f_locals, f_back): self.f_lineno = f_code.co_firstlineno self.f_lasti = 0 + self.cells = {} if f_code.co_cellvars: - self.cells = {} if not f_back.cells: f_back.cells = {} for var in f_code.co_cellvars: # Make a cell for the variable in our locals, or None. cell = Cell(self.f_locals.get(var)) f_back.cells[var] = self.cells[var] = cell - else: - self.cells = None if f_code.co_freevars: if not self.cells: @@ -224,9 +331,16 @@ def __init__(self, f_code, f_globals, f_locals, f_back): assert f_back.cells, "f_back.cells: %r" % (f_back.cells,) self.cells[var] = f_back.cells[var] + # The stack holding exception and generator handling information self.block_stack = [] + # The stack holding input and output of bytecode instructions + self.data_stack = [] self.generator = None + def push(self, *vals): + """Push values onto the value stack.""" + self.data_stack.extend(vals) + def __repr__(self): # pragma: no cover return '' % ( id(self), self.f_code.co_filename, self.f_lineno @@ -253,6 +367,7 @@ def line_number(self): class Generator(object): + def __init__(self, g_frame, vm): self.gi_frame = g_frame self.vm = vm @@ -263,9 +378,13 @@ def __iter__(self): return self def next(self): + if self.finished: + raise StopIteration + # Ordinary iteration is like sending None into a generator. + # Push the value onto the frame stack. if not self.first: - self.vm.push(None) + self.gi_frame.push(None) self.first = False # To get the next value from an iterator, push its frame onto the # stack, and let it run. diff --git a/byterun/pyvm2.py b/byterun/pyvm2.py index c1687673..576770a1 100644 --- a/byterun/pyvm2.py +++ b/byterun/pyvm2.py @@ -2,6 +2,14 @@ # Based on: # pyvm2 by Paul Swartz (z3p), from http://www.twistedmatrix.com/users/z3p/ + +# Disable because there are enough false positives to make it useless +# pylint: disable=unbalanced-tuple-unpacking +# pylint: disable=unpacking-non-sequence + +# TODO(ampere): Add doc strings and remove this. +# pylint: disable=missing-docstring + from __future__ import print_function, division import dis import inspect @@ -13,12 +21,12 @@ import six from six.moves import reprlib -PY3, PY2 = six.PY3, not six.PY3 - from .pyobj import Frame, Block, Method, Object, Function, Class, Generator log = logging.getLogger(__name__) +PY3, PY2 = six.PY3, not six.PY3 + if six.PY3: byteint = lambda b: b else: @@ -36,19 +44,58 @@ class VirtualMachineError(Exception): class VirtualMachine(object): + def __init__(self): # The call stack of frames. self.frames = [] # The current frame. self.frame = None - # The data stack. - self.stack = [] self.return_value = None self.last_exception = None + self.vmbuiltins = dict(__builtins__) + self.vmbuiltins["isinstance"] = self.isinstance + # Operator tables. These are overriden by subclasses to replace the + # meta-cyclic implementations. + self.unary_operators = { + 'POSITIVE': operator.pos, + 'NEGATIVE': operator.neg, + 'NOT': operator.not_, + 'CONVERT': repr, + 'INVERT': operator.invert, + } + self.binary_operators = { + 'POWER': pow, + 'MULTIPLY': operator.mul, + 'DIVIDE': getattr(operator, 'div', lambda x, y: None), + 'FLOOR_DIVIDE': operator.floordiv, + 'TRUE_DIVIDE': operator.truediv, + 'MODULO': operator.mod, + 'ADD': operator.add, + 'SUBTRACT': operator.sub, + 'SUBSCR': operator.getitem, + 'LSHIFT': operator.lshift, + 'RSHIFT': operator.rshift, + 'AND': operator.and_, + 'XOR': operator.xor, + 'OR': operator.or_, + } + self.compare_operators = [ + operator.lt, + operator.le, + operator.eq, + operator.ne, + operator.gt, + operator.ge, + lambda x, y: x in y, + lambda x, y: x not in y, + lambda x, y: x is y, + lambda x, y: x is not y, + lambda x, y: issubclass(x, Exception) and issubclass(x, y), + ] def top(self): """Return the value at the top of the stack, with no changes.""" - return self.stack[-1] + return self.frame.data_stack[-1] def pop(self, i=0): """Pop a value from the stack. @@ -57,11 +104,11 @@ def pop(self, i=0): instead. """ - return self.stack.pop(-1-i) + return self.frame.data_stack.pop(-1-i) def push(self, *vals): """Push values onto the value stack.""" - self.stack.extend(vals) + self.frame.push(*vals) def popn(self, n): """Pop a number of values from the value stack. @@ -70,30 +117,42 @@ def popn(self, n): """ if n: - ret = self.stack[-n:] - self.stack[-n:] = [] + ret = self.frame.data_stack[-n:] + self.frame.data_stack[-n:] = [] return ret else: return [] def peek(self, n): - """Get a value `n` entries down in the stack, without changing the stack.""" - return self.stack[-n] + """ + Get a value `n` entries down in the stack, without changing the stack. + """ + return self.frame.data_stack[-n] def jump(self, jump): - """Move the bytecode pointer to `jump`, so it will execute next.""" + """ + Move the bytecode pointer to `jump`, so it will execute next. + + Jump may be the very next instruction and hence already the value of + f_lasti. This is used to notify a subclass when a jump was not taken and + instead we continue to the next instruction. + """ self.frame.f_lasti = jump def push_block(self, type, handler=None, level=None): if level is None: - level = len(self.stack) + level = len(self.frame.data_stack) self.frame.block_stack.append(Block(type, handler, level)) def pop_block(self): return self.frame.block_stack.pop() def make_frame(self, code, callargs={}, f_globals=None, f_locals=None): - log.info("make_frame: code=%r, callargs=%s" % (code, repper(callargs))) + # The callargs default is safe because we never modify the dict. + # pylint: disable=dangerous-default-value + log.info("make_frame: code=%r, callargs=%s, f_globals=%r, f_locals=%r", + code, repper(callargs), (type(f_globals), id(f_globals)), + (type(f_locals), id(f_locals))) if f_globals is not None: f_globals = f_globals if f_locals is None: @@ -102,14 +161,21 @@ def make_frame(self, code, callargs={}, f_globals=None, f_locals=None): f_globals = self.frame.f_globals f_locals = {} else: + # TODO(ampere): __name__, __doc__, __package__ below are not correct f_globals = f_locals = { - '__builtins__': __builtins__, + '__builtins__': self.vmbuiltins, '__name__': '__main__', '__doc__': None, '__package__': None, } + + # Implement NEWLOCALS flag. See Objects/frameobject.c in CPython. + if code.co_flags & Function.CO_NEWLOCALS: + f_locals = {} + f_locals.update(callargs) - frame = Frame(code, f_globals, f_locals, self.frame) + frame = self.make_frame_with_dicts(code, f_globals, f_locals) + log.info("%r", frame) return frame def push_frame(self, frame): @@ -138,6 +204,7 @@ def print_frames(self): def resume_frame(self, frame): frame.f_back = self.frame + log.info("resume_frame: %r", frame) val = self.run_frame(frame) frame.f_back = None return val @@ -148,8 +215,9 @@ def run_code(self, code, f_globals=None, f_locals=None): # Check some invariants if self.frames: # pragma: no cover raise VirtualMachineError("Frames left over!") - if self.stack: # pragma: no cover - raise VirtualMachineError("Data left on stack! %r" % self.stack) + if self.frame is not None and self.frame.data_stack: # pragma: no cover + raise VirtualMachineError("Data left on stack! %r" % + self.frame.data_stack) return val @@ -159,7 +227,7 @@ def unwind_block(self, block): else: offset = 0 - while len(self.stack) > block.level + offset: + while len(self.frame.data_stack) > block.level + offset: self.pop() if block.type == 'except-handler': @@ -169,7 +237,13 @@ def unwind_block(self, block): def parse_byte_and_args(self): f = self.frame opoffset = f.f_lasti - byteCode = byteint(f.f_code.co_code[opoffset]) + try: + byteCode = byteint(f.f_code.co_code[opoffset]) + except IndexError: + raise VirtualMachineError( + "Bad bytecode offset %d in %s (len=%d)" % + (opoffset, str(f.f_code), len(f.f_code.co_code)) + ) f.f_lasti += 1 byteName = dis.opname[byteCode] arg = None @@ -201,11 +275,12 @@ def parse_byte_and_args(self): return byteName, arguments, opoffset def log(self, byteName, arguments, opoffset): + # pylint: disable=logging-not-lazy op = "%d: %s" % (opoffset, byteName) if arguments: op += " %r" % (arguments[0],) indent = " "*(len(self.frames)-1) - stack_rep = repper(self.stack) + stack_rep = repper(self.frame.data_stack) block_stack_rep = repper(self.frame.block_stack) log.info(" %sdata: %s" % (indent, stack_rep)) @@ -231,8 +306,7 @@ def dispatch(self, byteName, arguments): "unknown bytecode type: %s" % byteName ) why = bytecode_fn(*arguments) - - except: + except: # pylint: disable=bare-except # deal with exceptions encountered while executing the op. self.last_exception = sys.exc_info()[:2] + (None,) log.exception("Caught exception during execution") @@ -300,37 +374,44 @@ def manage_block_stack(self, why): return why + def run_instruction(self): + """Run one instruction in the current frame. + + Return None if the frame should continue executing otherwise return the + reason it should stop. + """ + frame = self.frame + byteName, arguments, opoffset = self.parse_byte_and_args() + if log.isEnabledFor(logging.INFO): + self.log(byteName, arguments, opoffset) + + # When unwinding the block stack, we need to keep track of why we + # are doing it. + why = self.dispatch(byteName, arguments) + if why == 'exception': + # TODO: ceval calls PyTraceBack_Here, not sure what that does. + pass + + if why == 'reraise': + why = 'exception' + + if why != 'yield': + while why and frame.block_stack: + # Deal with any block management we need to do. + why = self.manage_block_stack(why) + + return why def run_frame(self, frame): """Run a frame until it returns (somehow). Exceptions are raised, the return value is returned. - """ self.push_frame(frame) while True: - byteName, arguments, opoffset = self.parse_byte_and_args() - if log.isEnabledFor(logging.INFO): - self.log(byteName, arguments, opoffset) - - # When unwinding the block stack, we need to keep track of why we - # are doing it. - why = self.dispatch(byteName, arguments) - if why == 'exception': - # TODO: ceval calls PyTraceBack_Here, not sure what that does. - pass - - if why == 'reraise': - why = 'exception' - - if why != 'yield': - while why and frame.block_stack: - # Deal with any block management we need to do. - why = self.manage_block_stack(why) - + why = self.run_instruction() if why: break - self.pop_frame() if why == 'exception': @@ -338,10 +419,169 @@ def run_frame(self, frame): return self.return_value + ## Builders for objects that subclasses may want to replace with subclasses + + def make_instance(self, cls, args, kw): + """ + Create an instance of the given class with the given constructor args. + """ + return Object(cls, args, kw) + + def make_class(self, name, bases, methods): + """ + Create a class with the name bases and methods given. + """ + return Class(name, bases, methods, self) + + def make_function(self, name, code, globs, defaults, closure): + """ + Create a function or closure given the arguments. + """ + return Function(name, code, globs, defaults, closure, self) + + def make_frame_with_dicts(self, code, f_globals, f_locals): + """ + Create a frame with the given code, globals, and locals. + """ + return Frame(code, f_globals, f_locals, self.frame) + + ## Built-in overrides + + def isinstance(self, obj, cls): + if isinstance(obj, Object): + # pylint: disable=protected-access + return issubclass(obj._class, cls) + elif isinstance(cls, Class): + return False + else: + return isinstance(obj, cls) + + ## Abstraction hooks + + def load_constant(self, value): + """ + Called when the constant value is loaded onto the stack. + The returned value is pushed onto the stack instead. + """ + return value + + def get_locals_dict(self): + """Get a real python dict of the locals.""" + return self.frame.f_locals + + def get_locals_dict_bytecode(self): + """Get a possibly abstract bytecode level representation of the locals. + """ + return self.frame.f_locals + + def set_locals_dict_bytecode(self, lcls): + """Set the locals from a possibly abstract bytecode level dict. + """ + self.frame.f_locals = lcls + + def get_globals_dict(self): + """Get a real python dict of the globals.""" + return self.frame.f_globals + + def load_local(self, name): + """ + Called when a local is loaded onto the stack. + The returned value is pushed onto the stack instead of the actual loaded + value. + """ + return self.frame.f_locals[name] + + def load_builtin(self, name): + return self.frame.f_builtins[name] + + def load_global(self, name): + """ + Same as load_local except for globals. + """ + return self.frame.f_globals[name] + + def load_deref(self, name): + """ + Same as load_local except for closure cells. + """ + return self.frame.cells[name].get() + + def store_local(self, name, value): + """ + Called when a local is written. + The returned value is stored instead of the value on the stack. + """ + self.frame.f_locals[name] = value + + def store_deref(self, name, value): + """ + Same as store_local except for closure cells. + """ + self.frame.cells[name].set(value) + + def del_local(self, name): + """ + Called when a local is deleted. + """ + del self.frame.f_locals[name] + + def load_attr(self, obj, attr): + """ + Perform the actual attribute load on an object. This must support all + objects that may appear in the VM. This defaults to just get attr. + """ + return getattr(obj, attr) + + def store_attr(self, obj, attr, value): + """ + Same as load_attr except for setting attributes. Defaults to setattr. + """ + setattr(obj, attr, value) + + def del_attr(self, obj, attr): + """ + Same as load_attr except for deleting attributes. Defaults to delattr. + """ + delattr(obj, attr) + + def build_tuple(self, content): + """ + Create a VM tuple from the given sequence. + The returned object must support the tuple interface. + """ + return tuple(content) + + def build_list(self, content): + """ + Create a VM list from the given sequence. + The returned object must support the list interface. + """ + return list(content) + + def build_set(self, content): + """ + Create a VM set from the given sequence. + The returned object must support the set interface. + """ + return set(content) + + def build_map(self): + """ + Create an empty VM dict. + The returned object must support the dict interface. + """ + return dict() + + def store_subscr(self, obj, subscr, val): + obj[subscr] = val + + def del_subscr(self, obj, subscr): + del obj[subscr] + ## Stack manipulation def byte_LOAD_CONST(self, const): - self.push(const) + self.push(self.load_constant(const)) def byte_POP_TOP(self): self.pop() @@ -374,91 +614,68 @@ def byte_ROT_FOUR(self): ## Names def byte_LOAD_NAME(self, name): - frame = self.frame - if name in frame.f_locals: - val = frame.f_locals[name] - elif name in frame.f_globals: - val = frame.f_globals[name] - elif name in frame.f_builtins: - val = frame.f_builtins[name] - else: - raise NameError("name '%s' is not defined" % name) + try: + val = self.load_local(name) + except KeyError: + try: + val = self.load_global(name) + except KeyError: + try: + val = self.load_builtin(name) + except KeyError: + raise NameError("name '%s' is not defined" % name) self.push(val) def byte_STORE_NAME(self, name): - self.frame.f_locals[name] = self.pop() + self.store_local(name, self.pop()) def byte_DELETE_NAME(self, name): - del self.frame.f_locals[name] + self.del_local(name) def byte_LOAD_FAST(self, name): - if name in self.frame.f_locals: - val = self.frame.f_locals[name] - else: + try: + val = self.load_local(name) + log.info("LOAD_FAST: %s from %r -> %r", name, self.frame, val) + except KeyError: raise UnboundLocalError( "local variable '%s' referenced before assignment" % name ) self.push(val) def byte_STORE_FAST(self, name): - self.frame.f_locals[name] = self.pop() + self.byte_STORE_NAME(name) def byte_DELETE_FAST(self, name): - del self.frame.f_locals[name] + self.byte_DELETE_NAME(name) def byte_LOAD_GLOBAL(self, name): - f = self.frame - if name in f.f_globals: - val = f.f_globals[name] - elif name in f.f_builtins: - val = f.f_builtins[name] - else: - raise NameError("global name '%s' is not defined" % name) + try: + val = self.load_global(name) + except KeyError: + try: + val = self.load_builtin(name) + except KeyError: + raise NameError("global name '%s' is not defined" % name) self.push(val) def byte_LOAD_DEREF(self, name): - self.push(self.frame.cells[name].get()) + self.push(self.load_deref(name)) def byte_STORE_DEREF(self, name): - self.frame.cells[name].set(self.pop()) + self.store_deref(name, self.pop()) def byte_LOAD_LOCALS(self): - self.push(self.frame.f_locals) + self.push(self.get_locals_dict_bytecode()) ## Operators - UNARY_OPERATORS = { - 'POSITIVE': operator.pos, - 'NEGATIVE': operator.neg, - 'NOT': operator.not_, - 'CONVERT': repr, - 'INVERT': operator.invert, - } - def unaryOperator(self, op): x = self.pop() - self.push(self.UNARY_OPERATORS[op](x)) - - BINARY_OPERATORS = { - 'POWER': pow, - 'MULTIPLY': operator.mul, - 'DIVIDE': getattr(operator, 'div', lambda x, y: None), - 'FLOOR_DIVIDE': operator.floordiv, - 'TRUE_DIVIDE': operator.truediv, - 'MODULO': operator.mod, - 'ADD': operator.add, - 'SUBTRACT': operator.sub, - 'SUBSCR': operator.getitem, - 'LSHIFT': operator.lshift, - 'RSHIFT': operator.rshift, - 'AND': operator.and_, - 'XOR': operator.xor, - 'OR': operator.or_, - } + self.push(self.unary_operators[op](x)) def binaryOperator(self, op): x, y = self.popn(2) - self.push(self.BINARY_OPERATORS[op](x, y)) + self.push(self.binary_operators[op](x, y)) def inplaceOperator(self, op): x, y = self.popn(2) @@ -511,65 +728,52 @@ def sliceOperator(self, op): else: self.push(l[start:end]) - COMPARE_OPERATORS = [ - operator.lt, - operator.le, - operator.eq, - operator.ne, - operator.gt, - operator.ge, - lambda x, y: x in y, - lambda x, y: x not in y, - lambda x, y: x is y, - lambda x, y: x is not y, - lambda x, y: issubclass(x, Exception) and issubclass(x, y), - ] - def byte_COMPARE_OP(self, opnum): x, y = self.popn(2) - self.push(self.COMPARE_OPERATORS[opnum](x, y)) + self.push(self.compare_operators[opnum](x, y)) ## Attributes and indexing def byte_LOAD_ATTR(self, attr): obj = self.pop() - val = getattr(obj, attr) + log.info("LOAD_ATTR: %r %s", type(obj), attr) + val = self.load_attr(obj, attr) self.push(val) def byte_STORE_ATTR(self, name): val, obj = self.popn(2) - setattr(obj, name, val) + self.store_attr(obj, name, val) def byte_DELETE_ATTR(self, name): obj = self.pop() - delattr(obj, name) + self.del_attr(obj, name) def byte_STORE_SUBSCR(self): val, obj, subscr = self.popn(3) - obj[subscr] = val + self.store_subscr(obj, subscr, val) def byte_DELETE_SUBSCR(self): obj, subscr = self.popn(2) - del obj[subscr] + self.del_subscr(obj, subscr) ## Building def byte_BUILD_TUPLE(self, count): elts = self.popn(count) - self.push(tuple(elts)) + self.push(self.build_tuple(elts)) def byte_BUILD_LIST(self, count): elts = self.popn(count) - self.push(elts) + self.push(self.build_list(elts)) def byte_BUILD_SET(self, count): # TODO: Not documented in Py2 docs. elts = self.popn(count) - self.push(set(elts)) + self.push(self.build_set(elts)) def byte_BUILD_MAP(self, size): # size is ignored. - self.push({}) + self.push(self.build_map()) def byte_STORE_MAP(self): the_map, val, key = self.popn(3) @@ -660,21 +864,29 @@ def byte_JUMP_IF_TRUE(self, jump): val = self.top() if val: self.jump(jump) + else: + self.jump(self.frame.f_lasti) def byte_JUMP_IF_FALSE(self, jump): val = self.top() if not val: self.jump(jump) + else: + self.jump(self.frame.f_lasti) def byte_POP_JUMP_IF_TRUE(self, jump): val = self.pop() if val: self.jump(jump) + else: + self.jump(self.frame.f_lasti) def byte_POP_JUMP_IF_FALSE(self, jump): val = self.pop() if not val: self.jump(jump) + else: + self.jump(self.frame.f_lasti) def byte_JUMP_IF_TRUE_OR_POP(self, jump): val = self.top() @@ -682,6 +894,7 @@ def byte_JUMP_IF_TRUE_OR_POP(self, jump): self.jump(jump) else: self.pop() + self.jump(self.frame.f_lasti) def byte_JUMP_IF_FALSE_OR_POP(self, jump): val = self.top() @@ -689,6 +902,7 @@ def byte_JUMP_IF_FALSE_OR_POP(self, jump): self.jump(jump) else: self.pop() + self.jump(self.frame.f_lasti) ## Blocks @@ -703,6 +917,7 @@ def byte_FOR_ITER(self, jump): try: v = next(iterobj) self.push(v) + self.jump(self.frame.f_lasti) except StopIteration: self.pop() self.jump(jump) @@ -890,8 +1105,8 @@ def byte_MAKE_FUNCTION(self, argc): name = None code = self.pop() defaults = self.popn(argc) - globs = self.frame.f_globals - fn = Function(name, code, globs, defaults, None, self) + globs = self.get_globals_dict() + fn = self.make_function(name, code, globs, defaults, None) self.push(fn) def byte_LOAD_CLOSURE(self, name): @@ -905,34 +1120,26 @@ def byte_MAKE_CLOSURE(self, argc): name = None closure, code = self.popn(2) defaults = self.popn(argc) - globs = self.frame.f_globals - fn = Function(None, code, globs, defaults, closure, self) + globs = self.get_globals_dict() + fn = self.make_function(None, code, globs, defaults, closure) self.push(fn) def byte_CALL_FUNCTION(self, arg): - return self.call_function(arg, [], {}) + return self.call_function_from_stack(arg, [], {}) def byte_CALL_FUNCTION_VAR(self, arg): args = self.pop() - return self.call_function(arg, args, {}) + return self.call_function_from_stack(arg, args, {}) def byte_CALL_FUNCTION_KW(self, arg): kwargs = self.pop() - return self.call_function(arg, [], kwargs) + return self.call_function_from_stack(arg, [], kwargs) def byte_CALL_FUNCTION_VAR_KW(self, arg): args, kwargs = self.popn(2) - return self.call_function(arg, args, kwargs) - - def isinstance(self, obj, cls): - if isinstance(obj, Object): - return issubclass(obj._class, cls) - elif isinstance(cls, Class): - return False - else: - return isinstance(obj, cls) + return self.call_function_from_stack(arg, args, kwargs) - def call_function(self, arg, args, kwargs): + def call_function_from_stack(self, arg, args, kwargs): lenKw, lenPos = divmod(arg, 256) namedargs = {} for i in range(lenKw): @@ -941,8 +1148,22 @@ def call_function(self, arg, args, kwargs): namedargs.update(kwargs) posargs = self.popn(lenPos) posargs.extend(args) - func = self.pop() + self.push(self.call_function(func, posargs, namedargs)) + + def call_function(self, func, posargs, namedargs=None): + """Call a VM function with the given arguments and return the result. + + This is were subclass override should occur as well. + + Args: + func: The function to call. + posargs: The positional arguments. + namedargs: The keyword arguments (defaults to {}). + Returns: + The return value of the function. + """ + namedargs = namedargs or {} frame = self.frame if hasattr(func, 'im_func'): # Methods get self as an implicit first parameter. @@ -959,8 +1180,7 @@ def call_function(self, arg, args, kwargs): ) ) func = func.im_func - retval = func(*posargs, **namedargs) - self.push(retval) + return func(*posargs, **namedargs) def byte_RETURN_VALUE(self): self.return_value = self.pop() @@ -974,23 +1194,32 @@ def byte_YIELD_VALUE(self): ## Importing + def import_name(self, name, fromlist, level): + """Import the module and return the module object.""" + return __import__(name, self.get_globals_dict(), self.get_locals_dict(), + fromlist, level) + + def get_module_attributes(self, mod): + """Return the modules members as a dict.""" + return {name: getattr(mod, name) for name in dir(mod)} + def byte_IMPORT_NAME(self, name): level, fromlist = self.popn(2) frame = self.frame - self.push( - __import__(name, frame.f_globals, frame.f_locals, fromlist, level) - ) + self.push(self.import_name(name, fromlist, level)) def byte_IMPORT_STAR(self): # TODO: this doesn't use __all__ properly. mod = self.pop() - for attr in dir(mod): + attrs = self.get_module_attributes(mod) + for attr, val in attrs.iteritems(): if attr[0] != '_': - self.frame.f_locals[attr] = getattr(mod, attr) + self.store_local(attr, val) def byte_IMPORT_FROM(self, name): mod = self.top() - self.push(getattr(mod, name)) + attrs = self.get_module_attributes(mod) + self.push(attrs[name]) ## And the rest... @@ -1000,14 +1229,14 @@ def byte_EXEC_STMT(self): def byte_BUILD_CLASS(self): name, bases, methods = self.popn(3) - self.push(Class(name, bases, methods)) + self.push(self.make_class(name, bases, methods)) def byte_LOAD_BUILD_CLASS(self): # New in py3 self.push(__build_class__) def byte_STORE_LOCALS(self): - self.frame.f_locals = self.pop() + self.set_locals_dict_bytecode(self.pop()) if 0: # Not in py2.7 def byte_SET_LINENO(self, lineno): diff --git a/tests/test_abstractvm.py b/tests/test_abstractvm.py new file mode 100644 index 00000000..72cc9807 --- /dev/null +++ b/tests/test_abstractvm.py @@ -0,0 +1,195 @@ + + +import dis +import logging +import sys +import types +import unittest + + +from byterun import abstractvm +from byterun import pycfg +import mock + +# It does not accept any styling for several different members for some reason. +# pylint: disable=invalid-name + + +class MockVM(abstractvm.AbstractVirtualMachine): + + def __init__(self): + super(MockVM, self).__init__() + self.load_attr = mock.MagicMock(spec=self.load_attr) + + +class AbstractVirtualMachineTest(unittest.TestCase): + + def _run_code(self, code, consts, nlocals): + """Run the given raw byte code. + + Args: + code: A raw bytecode string to execute. + consts: The constants for the code. + nlocals: the number of locals the code uses. + """ + names = tuple("v" + str(i) for i in xrange(nlocals)) + code = types.CodeType(0, # argcount + nlocals, # nlocals + 16, # stacksize + 0, # flags + code, # codestring + consts, # constants + names, # names + names, # varnames + "<>", # filename + "", # name + 0, # firstlineno + "") # lnotab + self.vm.run_code(code) + + def setUp(self): + self.vm = MockVM() + + def testMagicOperator(self): + code = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 0, 0, # 0, arg=0 + dis.opmap["LOAD_CONST"], 1, 0, # 9, arg=1 + dis.opmap["BINARY_ADD"], # 12 + dis.opmap["STORE_NAME"], 0, 0, # 13, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 16, arg=2 + dis.opmap["RETURN_VALUE"], # 19 + ]) + + method = mock.MagicMock(spec=(1).__add__) + self.vm.load_attr.return_value = method + self._run_code(code, (1, 2, None), 1) + self.vm.load_attr.assert_called_once_with(1, "__add__") + method.assert_called_once_with(2) + + def testIter(self): + code = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 0, 0, # 3, arg=0 + dis.opmap["LOAD_CONST"], 1, 0, # 6, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 9, arg=2 + dis.opmap["BUILD_LIST"], 3, 0, # 12, arg=3 + dis.opmap["GET_ITER"], # 15 + dis.opmap["LOAD_CONST"], 3, 0, # 26, arg=3 + dis.opmap["RETURN_VALUE"], # 29 + ]) + + method = mock.MagicMock(spec=[1, 2, 3].__iter__) + self.vm.load_attr.return_value = method + self._run_code(code, (1, 2, 3, None), 0) + self.vm.load_attr.assert_called_once_with([1, 2, 3], "__iter__") + + +class TraceVM(abstractvm.AncestorTraversalVirtualMachine): + + def __init__(self): + super(TraceVM, self).__init__() + self.instructions_executed = set() + + def run_instruction(self): + self.instructions_executed.add(self.frame.f_lasti) + return super(TraceVM, self).run_instruction() + + +class AncestorTraversalVirtualMachineTest(unittest.TestCase): + + def setUp(self): + self.vm = TraceVM() + + srcNestedLoops = """ +y = [1,2,3] +z = 0 +for x in y: + for a in y: + if x: + z += x*a +""" + codeNestedLoops = compile(srcNestedLoops, "<>", "exec", 0, 1) + + codeNestedLoopsBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 0, 0, # 0, arg=0 + dis.opmap["LOAD_CONST"], 1, 0, # 3, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 6, arg=2 + dis.opmap["BUILD_LIST"], 3, 0, # 9, arg=3 + dis.opmap["STORE_NAME"], 0, 0, # 12, arg=0 + dis.opmap["LOAD_CONST"], 3, 0, # 15, arg=3 + dis.opmap["STORE_NAME"], 1, 0, # 18, arg=1 + dis.opmap["SETUP_LOOP"], 54, 0, # 21, dest=78 + dis.opmap["LOAD_NAME"], 0, 0, # 24, arg=0 + dis.opmap["GET_ITER"], # 27 + dis.opmap["FOR_ITER"], 46, 0, # 28, dest=77 + dis.opmap["STORE_NAME"], 2, 0, # 31, arg=2 + dis.opmap["SETUP_LOOP"], 37, 0, # 34, dest=74 + dis.opmap["LOAD_NAME"], 0, 0, # 37, arg=0 + dis.opmap["GET_ITER"], # 40 + dis.opmap["FOR_ITER"], 29, 0, # 41, dest=73 + dis.opmap["STORE_NAME"], 3, 0, # 44, arg=3 + dis.opmap["LOAD_NAME"], 2, 0, # 47, arg=2 + dis.opmap["POP_JUMP_IF_FALSE"], 41, 0, # 50, dest=41 + dis.opmap["LOAD_NAME"], 1, 0, # 53, arg=1 + dis.opmap["LOAD_NAME"], 2, 0, # 56, arg=2 + dis.opmap["LOAD_NAME"], 3, 0, # 59, arg=3 + dis.opmap["BINARY_MULTIPLY"], # 62 + dis.opmap["INPLACE_ADD"], # 63 + dis.opmap["STORE_NAME"], 1, 0, # 64, arg=1 + dis.opmap["JUMP_ABSOLUTE"], 41, 0, # 67, dest=41 + dis.opmap["JUMP_ABSOLUTE"], 41, 0, # 70, dest=41 + dis.opmap["POP_BLOCK"], # 73 + dis.opmap["JUMP_ABSOLUTE"], 28, 0, # 74, dest=28 + dis.opmap["POP_BLOCK"], # 77 + dis.opmap["LOAD_CONST"], 4, 0, # 78, arg=4 + dis.opmap["RETURN_VALUE"], # 81 + ]) + + def testEachInstructionOnceLoops(self): + self.assertEqual(self.codeNestedLoops.co_code, + self.codeNestedLoopsBytecode) + self.vm.run_code(self.codeNestedLoops) + # The number below are the instruction offsets in the above bytecode. + self.assertItemsEqual(self.vm.instructions_executed, + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 28, 31, 34, 37, + 40, 41, 44, 47, 50, 53, 56, 59, 62, 63, 64, 67, 70, + 73, 74, 77, 78, 81]) + + srcDeadCode = """ +if False: + x = 2 +raise RuntimeError +x = 42 +""" + codeDeadCode = compile(srcDeadCode, "<>", "exec", 0, 1) + + codeDeadCodeBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_NAME"], 0, 0, # 0, arg=0 + dis.opmap["POP_JUMP_IF_FALSE"], 15, 0, # 3, dest=15 + dis.opmap["LOAD_CONST"], 0, 0, # 6, arg=0 + dis.opmap["STORE_NAME"], 1, 0, # 9, arg=1 + dis.opmap["JUMP_FORWARD"], 0, 0, # 12, dest=15 + dis.opmap["LOAD_NAME"], 2, 0, # 15, arg=2 + dis.opmap["RAISE_VARARGS"], 1, 0, # 18, arg=1 + dis.opmap["LOAD_CONST"], 1, 0, # 21, arg=1 + dis.opmap["STORE_NAME"], 1, 0, # 24, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 27, arg=2 + dis.opmap["RETURN_VALUE"], # 30 + ]) + + def testEachInstructionOnceDeadCode(self): + self.assertEqual(self.codeDeadCode.co_code, + self.codeDeadCodeBytecode) + try: + self.vm.run_code(self.codeDeadCode) + except RuntimeError: + pass # Ignore the exception that gets out. + self.assertItemsEqual(self.vm.instructions_executed, + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30]) + + +if __name__ == "__main__": + # TODO(ampere): This is just a useful hack. Should be replaced with real + # argument handling. + if len(sys.argv) > 1: + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/test_basic.py b/tests/test_basic.py index 2fd02cd1..cffdd4c8 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,14 +1,17 @@ """Basic tests for Byterun.""" + from __future__ import print_function -from . import vmtest +import unittest +from tests import vmtest import six PY3, PY2 = six.PY3, not six.PY3 class TestIt(vmtest.VmTestCase): + def test_constant(self): self.assert_ok("17") @@ -186,6 +189,7 @@ def test_generator_expression(self): initial_indent=blanks, subsequent_indent=blanks) print(res) """) + def test_list_comprehension(self): self.assert_ok("""\ x = [z*z for z in range(5)] @@ -286,6 +290,28 @@ def meth(self, y): print(thing1.meth(4), thing2.meth(5)) """) + def test_class_mros(self): + self.assert_ok("""\ + class A(object): pass + class B(A): pass + class C(A): pass + class D(B, C): pass + class E(C, B): pass + print([c.__name__ for c in D.__mro__]) + print([c.__name__ for c in E.__mro__]) + """) + + def test_class_mro_method_calls(self): + self.assert_ok("""\ + class A(object): + def f(self): return 'A' + class B(A): pass + class C(A): + def f(self): return 'C' + class D(B, C): pass + print(D().f()) + """) + def test_calling_methods_wrong(self): self.assert_ok("""\ class Thing(object): @@ -310,6 +336,20 @@ class SubThing(Thing): print(st.foo()) """) + def test_other_class_methods(self): + self.assert_ok("""\ + class Thing(object): + def foo(self): + return 17 + + class SubThing(object): + def bar(self): + return 9 + + st = SubThing() + print(st.foo()) + """, raises=AttributeError) + def test_attribute_access(self): self.assert_ok("""\ class Thing(object): @@ -454,6 +494,7 @@ def __init__(self, x): if PY2: class TestPrinting(vmtest.VmTestCase): + def test_printing(self): self.assert_ok("print 'hello'") self.assert_ok("a = 3; print a+4") @@ -478,6 +519,7 @@ def test_printing_to_a_file(self): class TestLoops(vmtest.VmTestCase): + def test_for(self): self.assert_ok("""\ for i in range(10): @@ -530,6 +572,7 @@ def test_continue_in_try_finally(self): class TestComparisons(vmtest.VmTestCase): + def test_in(self): self.assert_ok("""\ assert "x" in "xyz" @@ -553,3 +596,6 @@ def test_greater(self): assert "z" > "a" assert "z" >= "a" and "z" >= "z" """) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 09e04361..d21fc86a 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,14 +1,17 @@ """Test exceptions for Byterun.""" + from __future__ import print_function -from . import vmtest +import unittest +from tests import vmtest import six PY3, PY2 = six.PY3, not six.PY3 class TestExceptions(vmtest.VmTestCase): + def test_catching_exceptions(self): # Catch the exception precisely self.assert_ok("""\ @@ -159,3 +162,6 @@ def test_coverage_issue_92(self): print(l) assert l == [0, 'f', 'e', 1, 'f', 'e', 2, 'f', 'e', 'r'] """) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_functions.py b/tests/test_functions.py index 19a615c0..9eab1343 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1,7 +1,10 @@ """Test functions etc, for Byterun.""" + from __future__ import print_function -from . import vmtest + +import unittest +from tests import vmtest class TestFunctions(vmtest.VmTestCase): @@ -17,6 +20,22 @@ def fn(a, b=17, c="Hello", d=[]): fn(5, "b", "c") """) + def test_function_locals(self): + self.assert_ok("""\ + def f(): + x = "Spite" + print(x) + def g(): + x = "Malice" + print(x) + x = "Humility" + f() + print(x) + g() + print(x) + """) + + def test_recursion(self): self.assert_ok("""\ def fact(n): @@ -228,6 +247,13 @@ def triples(): print(a, b, c) """) + def test_generator_reuse(self): + self.assert_ok("""\ + g = (x*x for x in range(5)) + print(list(g)) + print(list(g)) + """) + def test_generator_from_generator2(self): self.assert_ok("""\ g = (x*x for x in range(3)) @@ -258,3 +284,6 @@ def boom(self): print(Thing().boom()) """) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pycfg.py b/tests/test_pycfg.py new file mode 100644 index 00000000..17bb5daa --- /dev/null +++ b/tests/test_pycfg.py @@ -0,0 +1,654 @@ +"""Tests for pycfg. +""" + +import dis +import inspect +import logging +import unittest + + +from byterun import pycfg + +# Disable because pylint does not like any name for the nested test_code +# functions used to get the needed bytecode. +# pylint: disable=invalid-name + +# The bytecode constants used to check against the generated code are formatted +# as follows. Each line is one instruction. Blank lines separate basic blocks. +# +# dis.opmap[""], , , # , +# +# The is a decoded version of the argument. This is more useful for +# relative jumps. + + +def line_number(): + """Returns the line number of the call site.""" + return inspect.currentframe().f_back.f_lineno + + +class CFGTest(unittest.TestCase): + + def assertEndsWith(self, actual, expected): + self.assertTrue(actual.endswith(expected), + msg="'%s' does not end with '%s'" % (actual, expected)) + + # Copy this line into your test when developing it. It prints the formatted + # bytecode to use as the expected. + # print pycfg._bytecode_repr(test_code.func_code.co_code) + + def checkBlocks(self, table, expected): + self.assertEqual(len(table._blocks), len(expected)) + for block, (expected_begin, expected_end) in zip(table._blocks, expected): + self.assertEqual(block.begin, expected_begin) + self.assertEqual(block.end, expected_end) + + @staticmethod + def codeOneBlock(): + return x + 1 # pylint: disable=undefined-variable + + codeOneBlockBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_GLOBAL"], 0, 0, # 0 + dis.opmap["LOAD_CONST"], 1, 0, # 3 + dis.opmap["BINARY_ADD"], # 6 + dis.opmap["RETURN_VALUE"], # 7 + ]) + + def testOneBlock(self): + # Check the code to make sure the test will fail if the compilation changes. + self.assertEqual(self.codeOneBlock.func_code.co_code, + self.codeOneBlockBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeOneBlock.func_code) + # Should all be one basic block. + self.assertIs(table.get_basic_block(0), table.get_basic_block(3)) + self.assertIs(table.get_basic_block(0), table.get_basic_block(6)) + self.assertIs(table.get_basic_block(0), table.get_basic_block(7)) + # No incoming + self.assertItemsEqual(table.get_basic_block(0).incoming, []) + # Outgoing is an unknown return location + self.assertItemsEqual(table.get_basic_block(0).outgoing, [None]) + + @staticmethod + def codeTriangle(y): + x = y + if y > 10: + x -= 2 + return x + codeTriangleLineNumber = line_number() - 4 + # codeTriangleLineNumber is used to compute the correct line numbers for code + # in codeTriangle. This makes the tests less brittle if other tests in the + # file are changed. However the "- 4" will need to be changed if codeTriangle + # is changed or anything is inserted between the line_number() call and the + # definition of codeTriangle. + + codeTriangleBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_FAST"], 0, 0, # 0, arg=0 + dis.opmap["STORE_FAST"], 1, 0, # 3, arg=1 + dis.opmap["LOAD_FAST"], 0, 0, # 6, arg=0 + dis.opmap["LOAD_CONST"], 1, 0, # 9, arg=1 + dis.opmap["COMPARE_OP"], 4, 0, # 12, arg=4 + dis.opmap["POP_JUMP_IF_FALSE"], 31, 0, # 15, dest=31 + + dis.opmap["LOAD_FAST"], 1, 0, # 18, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 21, arg=2 + dis.opmap["INPLACE_SUBTRACT"], # 24 + dis.opmap["STORE_FAST"], 1, 0, # 25, arg=1 + dis.opmap["JUMP_FORWARD"], 0, 0, # 28, dest=31 + + dis.opmap["LOAD_FAST"], 1, 0, # 31, arg=1 + dis.opmap["RETURN_VALUE"], # 34 + ]) + + def testTriangle(self): + self.assertEqual(self.codeTriangle.func_code.co_code, + self.codeTriangleBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeTriangle.func_code) + expected = [(0, 15), + (18, 28), + (31, 34)] + self.checkBlocks(table, expected) + bb = table.get_basic_block + # Check the POP_JUMP_IF_FALSE conditional jump + self.assertItemsEqual(bb(0).outgoing, [bb(18), bb(31)]) + # Check the return + self.assertItemsEqual(bb(44).outgoing, [None]) + # Check the incoming of the entry block + self.assertItemsEqual(bb(0).incoming, []) + # Check incoming of the merge block. + self.assertItemsEqual(bb(44).incoming, [bb(28), bb(15)]) + self.assertEndsWith( + bb(21).get_name(), + "tests/test_pycfg.py:{0}-{0}".format(self.codeTriangleLineNumber+2)) + + def testTriangleDominators(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeTriangle.func_code) + + bb = table.get_basic_block + self.assertEqual(bb(0)._dominators, {bb(0)}) + self.assertEqual(bb(18)._dominators, {bb(0), bb(18)}) + self.assertEqual(bb(31)._dominators, {bb(0), bb(31)}) + self.assertEqual(bb(41)._dominators, {bb(0), bb(41)}) + + self.assertTrue(table.dominates(0, 37)) + self.assertFalse(table.dominates(24, 41)) + self.assertTrue(table.dominates(21, 28)) + self.assertFalse(table.dominates(28, 21)) + + def testTriangleOrder(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeTriangle.func_code) + + bb = table.get_basic_block + self.assertEqual(table.get_ancestors_first_traversal(), + [bb(o) for o in [0, 18, 31]]) + + @staticmethod + def codeDiamond(y): + x = y + if y > 10: + x -= 2 + else: + x += 2 + return x + + codeDiamondBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_FAST"], 0, 0, # 0, arg=0 + dis.opmap["STORE_FAST"], 1, 0, # 3, arg=1 + dis.opmap["LOAD_FAST"], 0, 0, # 6, arg=0 + dis.opmap["LOAD_CONST"], 1, 0, # 9, arg=1 + dis.opmap["COMPARE_OP"], 4, 0, # 12, arg=4 + dis.opmap["POP_JUMP_IF_FALSE"], 31, 0, # 15, dest=31 + + dis.opmap["LOAD_FAST"], 1, 0, # 18, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 21, arg=2 + dis.opmap["INPLACE_SUBTRACT"], # 24 + dis.opmap["STORE_FAST"], 1, 0, # 25, arg=1 + dis.opmap["JUMP_FORWARD"], 10, 0, # 28, dest=41 + + dis.opmap["LOAD_FAST"], 1, 0, # 31, arg=1 + dis.opmap["LOAD_CONST"], 2, 0, # 34, arg=2 + dis.opmap["INPLACE_ADD"], # 37 + dis.opmap["STORE_FAST"], 1, 0, # 38, arg=1 + + dis.opmap["LOAD_FAST"], 1, 0, # 41, arg=1 + dis.opmap["RETURN_VALUE"], # 44 + ]) + + def testDiamond(self): + self.assertEqual(self.codeDiamond.func_code.co_code, + self.codeDiamondBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeDiamond.func_code) + expected = [(0, 15), + (18, 28), + (31, 38), + (41, 44)] + self.checkBlocks(table, expected) + bb = table.get_basic_block + # Check the POP_JUMP_IF_FALSE conditional jump + self.assertItemsEqual(bb(0).outgoing, [bb(18), bb(31)]) + # Check the jumps at the end of the 2 of branches + self.assertItemsEqual(bb(18).outgoing, [bb(41)]) + self.assertItemsEqual(bb(38).outgoing, [bb(41)]) + # Check the return + self.assertItemsEqual(bb(44).outgoing, [None]) + # Check the incoming of the entry block + self.assertItemsEqual(bb(0).incoming, []) + # Check the incoming of the 2 if branches + self.assertItemsEqual(bb(18).incoming, [bb(15)]) + self.assertItemsEqual(bb(31).incoming, [bb(15)]) + # Check incoming of the merge block. + self.assertItemsEqual(bb(44).incoming, [bb(28), bb(38)]) + + def testDiamondDominators(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeDiamond.func_code) + + bb = table.get_basic_block + self.assertEqual(bb(0)._dominators, {bb(0)}) + self.assertEqual(bb(18)._dominators, {bb(0), bb(18)}) + self.assertEqual(bb(31)._dominators, {bb(0), bb(31)}) + self.assertEqual(bb(41)._dominators, {bb(0), bb(41)}) + + self.assertTrue(table.dominates(0, 37)) + self.assertFalse(table.dominates(24, 41)) + self.assertTrue(table.dominates(21, 28)) + self.assertFalse(table.dominates(28, 21)) + + def testDiamondOrder(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeDiamond.func_code) + + bb = table.get_basic_block + self.assertEqual(table.get_ancestors_first_traversal(), + [bb(o) for o in [0, 18, 31, 41]]) + + @staticmethod + def codeLoop(y): + z = 0 + for x in y: + z += x + return z + + codeLoopBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 1, 0, # 0, arg=1 + dis.opmap["STORE_FAST"], 1, 0, # 3, arg=1 + dis.opmap["SETUP_LOOP"], 24, 0, # 6, dest=33 + dis.opmap["LOAD_FAST"], 0, 0, # 9, arg=0 + dis.opmap["GET_ITER"], # 12 + + dis.opmap["FOR_ITER"], 16, 0, # 13, dest=32 + + dis.opmap["STORE_FAST"], 2, 0, # 16, arg=2 + dis.opmap["LOAD_FAST"], 1, 0, # 19, arg=1 + dis.opmap["LOAD_FAST"], 2, 0, # 22, arg=2 + dis.opmap["INPLACE_ADD"], # 25 + dis.opmap["STORE_FAST"], 1, 0, # 26, arg=1 + dis.opmap["JUMP_ABSOLUTE"], 13, 0, # 29, dest=13 + + dis.opmap["POP_BLOCK"], # 32 + + dis.opmap["LOAD_FAST"], 1, 0, # 33, arg=1 + dis.opmap["RETURN_VALUE"], # 36 + ]) + + def testLoop(self): + self.assertEqual(self.codeLoop.func_code.co_code, + self.codeLoopBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeLoop.func_code) + expected = [(0, 12), + (13, 13), + (16, 29), + (32, 32), + (33, 36)] + self.checkBlocks(table, expected) + bb = table.get_basic_block + # Check outgoing of the loop handler instruction. + self.assertItemsEqual(bb(13).outgoing, [bb(16), bb(32)]) + self.assertItemsEqual(bb(0).outgoing, [bb(13)]) + + def testLoopDominators(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeLoop.func_code) + + bb = table.get_basic_block + self.assertEqual(bb(0)._dominators, {bb(0)}) + self.assertEqual(bb(13)._dominators, {bb(0), bb(13)}) + self.assertEqual(bb(16)._dominators, {bb(0), bb(13), bb(16)}) + self.assertEqual(bb(32)._dominators, {bb(0), bb(13), bb(32)}) + self.assertEqual(bb(33)._dominators, + {bb(0), bb(13), bb(32), bb(33)}) + + def testLoopOrder(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeLoop.func_code) + + bb = table.get_basic_block + self.assertEqual(table.get_ancestors_first_traversal(), + [bb(o) for o in [0, 13, 16, 32, 33]]) + + @staticmethod + def codeNestedLoops(y): + z = 0 + for x in y: + for x in y: + z += x*x + return z + codeNestedLoopsLineNumber = line_number() - 5 + # See comment on codeTriangleLineNumber above. + + codeNestedLoopsBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 1, 0, # 0, arg=1 + dis.opmap["STORE_FAST"], 1, 0, # 3, arg=1 + dis.opmap["SETUP_LOOP"], 45, 0, # 6, dest=54 + dis.opmap["LOAD_FAST"], 0, 0, # 9, arg=0 + dis.opmap["GET_ITER"], # 12 + + dis.opmap["FOR_ITER"], 37, 0, # 13, dest=53 + + dis.opmap["STORE_FAST"], 2, 0, # 16, arg=2 + dis.opmap["SETUP_LOOP"], 28, 0, # 19, dest=50 + dis.opmap["LOAD_FAST"], 0, 0, # 22, arg=0 + dis.opmap["GET_ITER"], # 25 + + dis.opmap["FOR_ITER"], 20, 0, # 26, dest=49 + + dis.opmap["STORE_FAST"], 2, 0, # 29, arg=2 + dis.opmap["LOAD_FAST"], 1, 0, # 32, arg=1 + dis.opmap["LOAD_FAST"], 2, 0, # 35, arg=2 + dis.opmap["LOAD_FAST"], 2, 0, # 38, arg=2 + dis.opmap["BINARY_MULTIPLY"], # 41 + dis.opmap["INPLACE_ADD"], # 42 + dis.opmap["STORE_FAST"], 1, 0, # 43, arg=1 + dis.opmap["JUMP_ABSOLUTE"], 26, 0, # 46, dest=26 + + dis.opmap["POP_BLOCK"], # 49 + + dis.opmap["JUMP_ABSOLUTE"], 13, 0, # 50, dest=13 + + dis.opmap["POP_BLOCK"], # 53 + + dis.opmap["LOAD_FAST"], 1, 0, # 54, arg=1 + dis.opmap["RETURN_VALUE"], # 57 + ]) + + def testNestedLoops(self): + self.assertEqual(self.codeNestedLoops.func_code.co_code, + self.codeNestedLoopsBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeNestedLoops.func_code) + expected = [(0, 12), + (13, 13), + (16, 25), + (26, 26), + (29, 46), + (49, 49), + (50, 50), + (53, 53), + (54, 57)] + self.checkBlocks(table, expected) + bb = table.get_basic_block + self.assertItemsEqual(bb(13).incoming, [bb(12), bb(50)]) + self.assertItemsEqual(bb(13).outgoing, [bb(16), bb(53)]) + self.assertItemsEqual(bb(26).incoming, [bb(25), bb(46)]) + self.assertItemsEqual(bb(26).outgoing, [bb(29), bb(49)]) + self.assertEndsWith( + bb(43).get_name(), + "tests/test_pycfg.py:{}-{}".format(self.codeNestedLoopsLineNumber + 2, + self.codeNestedLoopsLineNumber + 3)) + + def testNestedLoopsDominators(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeNestedLoops.func_code) + bb = table.get_basic_block + + self.assertEqual(bb(0)._dominators, + {bb(0)}) + self.assertEqual(bb(13)._dominators, + {bb(0), bb(13)}) + self.assertEqual(bb(16)._dominators, + {bb(0), bb(13), bb(16)}) + self.assertEqual(bb(26)._dominators, + {bb(0), bb(13), bb(16), bb(26)}) + self.assertEqual(bb(29)._dominators, + {bb(0), bb(13), bb(16), bb(26), bb(29)}) + self.assertEqual(bb(49)._dominators, + {bb(0), bb(13), bb(16), bb(26), bb(49)}) + self.assertEqual(bb(50)._dominators, + {bb(0), bb(13), bb(16), bb(26), bb(49), bb(50)}) + self.assertEqual(bb(53)._dominators, + {bb(0), bb(13), bb(53)}) + self.assertEqual(bb(54)._dominators, + {bb(0), bb(13), bb(53), bb(54)}) + + def testNestedLoopsReachable(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeNestedLoops.func_code) + bb = table.get_basic_block + + self.assertEqual(bb(26)._reachable_from, + set([bb(0), bb(13), bb(16), bb(26), + bb(29), bb(49), bb(50)])) + + self.assertTrue(table.reachable_from(41, 50)) + self.assertTrue(table.reachable_from(50, 41)) + self.assertFalse(table.reachable_from(41, 53)) + + def testNestedLoopsOrder(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeNestedLoops.func_code) + + bb = table.get_basic_block + self.assertEqual(table.get_ancestors_first_traversal(), + [bb(o) for o in [0, 13, 16, 26, 29, 49, 50, 53, 54]]) + + @staticmethod + def codeContinue(y): + z = 0 + for x in y: + if x == 1: + continue + z += x*x + return z + + codeContinueBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 1, 0, # 0, arg=1 + dis.opmap["STORE_FAST"], 1, 0, # 3, arg=1 + dis.opmap["SETUP_LOOP"], 46, 0, # 6, dest=55 + dis.opmap["LOAD_FAST"], 0, 0, # 9, arg=0 + dis.opmap["GET_ITER"], # 12 + + dis.opmap["FOR_ITER"], 38, 0, # 13, dest=54 + + dis.opmap["STORE_FAST"], 2, 0, # 16, arg=2 + dis.opmap["LOAD_FAST"], 2, 0, # 19, arg=2 + dis.opmap["LOAD_CONST"], 2, 0, # 22, arg=2 + dis.opmap["COMPARE_OP"], 2, 0, # 25, arg=2 + dis.opmap["POP_JUMP_IF_FALSE"], 37, 0, # 28, dest=37 + + dis.opmap["JUMP_ABSOLUTE"], 13, 0, # 31, dest=13 + + dis.opmap["JUMP_FORWARD"], 0, 0, # 34, dest=37 + + dis.opmap["LOAD_FAST"], 1, 0, # 37, arg=1 + dis.opmap["LOAD_FAST"], 2, 0, # 40, arg=2 + dis.opmap["LOAD_FAST"], 2, 0, # 43, arg=2 + dis.opmap["BINARY_MULTIPLY"], # 46 + dis.opmap["INPLACE_ADD"], # 47 + dis.opmap["STORE_FAST"], 1, 0, # 48, arg=1 + dis.opmap["JUMP_ABSOLUTE"], 13, 0, # 51, dest=13 + + dis.opmap["POP_BLOCK"], # 54 + dis.opmap["LOAD_FAST"], 1, 0, # 55, arg=1 + dis.opmap["RETURN_VALUE"], # 58 + ]) + + def testContinue(self): + self.assertEqual(self.codeContinue.func_code.co_code, + self.codeContinueBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeContinue.func_code) + bb = table.get_basic_block + self.assertItemsEqual(bb(31).outgoing, [bb(13)]) + self.assertItemsEqual(bb(13).incoming, [bb(12), bb(51), bb(31)]) + self.assertItemsEqual(bb(13).outgoing, [bb(16), bb(54)]) + + @staticmethod + def codeBreak(y): + z = 0 + for x in y: + if x == 1: + break + z += x*x + return z + + codeBreakBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 1, 0, # 0, arg=1 + dis.opmap["STORE_FAST"], 1, 0, # 3, arg=1 + dis.opmap["SETUP_LOOP"], 44, 0, # 6, dest=53 + dis.opmap["LOAD_FAST"], 0, 0, # 9, arg=0 + dis.opmap["GET_ITER"], # 12 + + dis.opmap["FOR_ITER"], 36, 0, # 13, dest=52 + + dis.opmap["STORE_FAST"], 2, 0, # 16, arg=2 + dis.opmap["LOAD_FAST"], 2, 0, # 19, arg=2 + dis.opmap["LOAD_CONST"], 2, 0, # 22, arg=2 + dis.opmap["COMPARE_OP"], 2, 0, # 25, arg=2 + dis.opmap["POP_JUMP_IF_FALSE"], 35, 0, # 28, dest=35 + + dis.opmap["BREAK_LOOP"], # 31 + + dis.opmap["JUMP_FORWARD"], 0, 0, # 32, dest=35 + + dis.opmap["LOAD_FAST"], 1, 0, # 35, arg=1 + dis.opmap["LOAD_FAST"], 2, 0, # 38, arg=2 + dis.opmap["LOAD_FAST"], 2, 0, # 41, arg=2 + dis.opmap["BINARY_MULTIPLY"], # 44 + dis.opmap["INPLACE_ADD"], # 45 + dis.opmap["STORE_FAST"], 1, 0, # 46, arg=1 + dis.opmap["JUMP_ABSOLUTE"], 13, 0, # 49, dest=13 + + dis.opmap["POP_BLOCK"], # 52 + + dis.opmap["LOAD_FAST"], 1, 0, # 53, arg=1 + dis.opmap["RETURN_VALUE"], # 56 + ]) + + def testBreak(self): + self.assertEqual(self.codeBreak.func_code.co_code, + self.codeBreakBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeBreak.func_code) + bb = table.get_basic_block + self.assertItemsEqual(bb(13).incoming, [bb(12), bb(49)]) + self.assertItemsEqual(bb(31).incoming, [bb(28)]) + self.assertItemsEqual(bb(31).outgoing, [None]) + # TODO(ampere): This is correct, however more information would make the + # following succeed. + # self.assertItemsEqual(bb(31).incoming, [53]) + + @staticmethod + def codeYield(): + yield 1 + yield 2 + yield 3 + + codeYieldBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_CONST"], 1, 0, # 0, arg=1 + dis.opmap["YIELD_VALUE"], # 3 + + dis.opmap["POP_TOP"], # 4 + dis.opmap["LOAD_CONST"], 2, 0, # 5, arg=2 + dis.opmap["YIELD_VALUE"], # 8 + + dis.opmap["POP_TOP"], # 9 + dis.opmap["LOAD_CONST"], 3, 0, # 10, arg=3 + dis.opmap["YIELD_VALUE"], # 13 + + dis.opmap["POP_TOP"], # 14 + dis.opmap["LOAD_CONST"], 0, 0, # 15, arg=0 + dis.opmap["RETURN_VALUE"], # 18 + ]) + + def testYield(self): + self.assertEqual(self.codeYield.func_code.co_code, + self.codeYieldBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeYield.func_code) + expected = [(0, 3), + (4, 8), + (9, 13), + (14, 18)] + self.checkBlocks(table, expected) + bb = table.get_basic_block + # We both branch to unknown and to the best instruction for each yield. + self.assertItemsEqual(bb(0).outgoing, [None, bb(4)]) + self.assertItemsEqual(bb(4).outgoing, [None, bb(9)]) + self.assertItemsEqual(bb(9).incoming, [bb(8)]) + self.assertItemsEqual(bb(9).outgoing, [None, bb(14)]) + + @staticmethod + def codeRaise(): + raise ValueError() + return 0 # pylint: disable=unreachable + + codeRaiseBytecode = pycfg._list_to_string([ + dis.opmap["LOAD_GLOBAL"], 0, 0, # 0, arg=0 + dis.opmap["CALL_FUNCTION"], 0, 0, # 3, arg=0 + + dis.opmap["RAISE_VARARGS"], 1, 0, # 6, arg=1 + + dis.opmap["LOAD_CONST"], 1, 0, # 9, arg=1 + dis.opmap["RETURN_VALUE"], # 12 + ]) + + def testRaise(self): + self.assertEqual(self.codeRaise.func_code.co_code, + self.codeRaiseBytecode) + + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeRaise.func_code) + expected = [(0, 3), + (6, 6), + (9, 12)] + self.checkBlocks(table, expected) + bb = table.get_basic_block + # CALL_FUNCTION could either continue or raise + self.assertItemsEqual(bb(0).outgoing, [bb(6), None]) + # RAISE_VARARGS always raises + self.assertItemsEqual(bb(6).outgoing, [None]) + # This basic block is unreachable + self.assertItemsEqual(bb(9).incoming, []) + # We return to an unknown location + self.assertItemsEqual(bb(9).outgoing, [None]) + + def testRaiseOrder(self): + cfg = pycfg.CFG() + table = cfg.get_block_table(self.codeRaise.func_code) + + bb = table.get_basic_block + self.assertEqual(table.get_ancestors_first_traversal(), + [bb(o) for o in [0, 6, 9]]) + + +class InstructionsIndexTest(unittest.TestCase): + + @staticmethod + def simple_function(x): + x += 1 + y = 4 + x **= y + return x + y + + def setUp(self): + self.index = pycfg.InstructionsIndex(self.simple_function.func_code.co_code) + + def testNext(self): + self.assertEqual(self.index.next(0), 3) + self.assertEqual(self.index.next(6), 7) + self.assertEqual(self.index.next(23), 26) + + def testPrev(self): + self.assertEqual(self.index.prev(3), 0) + self.assertEqual(self.index.prev(7), 6) + self.assertEqual(self.index.prev(26), 23) + + def testRoundTrip(self): + offset = 3 + while offset < len(self.simple_function.func_code.co_code)-1: + self.assertEqual(self.index.prev(self.index.next(offset)), offset) + self.assertEqual(self.index.next(self.index.prev(offset)), offset) + offset = self.index.next(offset) + + +class BytecodeReprTest(unittest.TestCase): + + def checkRoundTrip(self, code): + self.assertEqual(eval(pycfg._bytecode_repr(code)), code) + + def testOtherTestMethods(self): + for method in CFGTest.__dict__: + if hasattr(method, "func_code"): + self.checkRoundTrip(method.func_code.co_code) + + def testThisTestMethods(self): + for method in BytecodeReprTest.__dict__: + if hasattr(method, "func_code"): + self.checkRoundTrip(method.func_code.co_code) + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/test_with.py b/tests/test_with.py index 6c18dae3..8c62e7f3 100644 --- a/tests/test_with.py +++ b/tests/test_with.py @@ -1,7 +1,9 @@ """Test the with statement for Byterun.""" from __future__ import print_function -from . import vmtest + +import unittest +from tests import vmtest class TestWithStatement(vmtest.VmTestCase): @@ -307,3 +309,6 @@ def my_context_manager(val): with my_context_manager(17) as x: assert x == 17 """) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vmtest.py b/tests/vmtest.py index 9e071838..da96c20e 100644 --- a/tests/vmtest.py +++ b/tests/vmtest.py @@ -8,14 +8,15 @@ import types import unittest -import six +from byterun.abstractvm import AbstractVirtualMachine from byterun.pyvm2 import VirtualMachine, VirtualMachineError +import six # Make this false if you need to run the debugger inside a test. -CAPTURE_STDOUT = ('-s' not in sys.argv) +CAPTURE_STDOUT = ("-s" not in sys.argv) # Make this false to see the traceback from a failure inside pyvm2. -CAPTURE_EXCEPTION = 1 +CAPTURE_EXCEPTION = True def dis_code(code): @@ -27,27 +28,17 @@ def dis_code(code): print("") print(code) dis.dis(code) + sys.stdout.flush() -class VmTestCase(unittest.TestCase): - - def assert_ok(self, code, raises=None): - """Run `code` in our VM and in real Python: they behave the same.""" - - code = textwrap.dedent(code) - code = compile(code, "<%s>" % self.id(), "exec", 0, 1) - - # Print the disassembly so we'll see it if the test fails. - dis_code(code) - - real_stdout = sys.stdout - +def run_with_byterun(code, vmclass=VirtualMachine): + real_stdout = sys.stdout + try: # Run the code through our VM. - vm_stdout = six.StringIO() if CAPTURE_STDOUT: # pragma: no branch sys.stdout = vm_stdout - vm = VirtualMachine() + vm = vmclass() vm_value = vm_exc = None try: @@ -66,9 +57,15 @@ def assert_ok(self, code, raises=None): finally: real_stdout.write("-- stdout ----------\n") real_stdout.write(vm_stdout.getvalue()) + return vm_value, vm_stdout.getvalue(), vm_exc + finally: + sys.stdout = real_stdout - # Run the code through the real Python interpreter, for comparison. +def run_with_eval(code): + real_stdout = sys.stdout + try: + # Run the code through the real Python interpreter, for comparison. py_stdout = six.StringIO() sys.stdout = py_stdout @@ -80,16 +77,39 @@ def assert_ok(self, code, raises=None): raise except Exception as e: py_exc = e - + return py_value, py_stdout.getvalue(), py_exc + finally: sys.stdout = real_stdout + +class VmTestCase(unittest.TestCase): + + def assert_ok(self, code, raises=None): + """Run `code` in our VM and in real Python: they behave the same.""" + + code = textwrap.dedent(code) + code = compile(code, "<%s>" % self.id(), "exec", 0, 1) + + # Print the disassembly so we'll see it if the test fails. + dis_code(code) + + vm_value, vm_stdout_value, vm_exc = run_with_byterun(code) + abstractvm_value, abstractvm_stdout_value, abstractvm_exc = ( + run_with_byterun(code, AbstractVirtualMachine)) + py_value, py_stdout_value, py_exc = run_with_eval(code) + self.assert_same_exception(vm_exc, py_exc) - self.assertEqual(vm_stdout.getvalue(), py_stdout.getvalue()) + self.assert_same_exception(abstractvm_exc, py_exc) + self.assertEqual(vm_stdout_value, py_stdout_value) + self.assertEqual(abstractvm_stdout_value, py_stdout_value) self.assertEqual(vm_value, py_value) + self.assertEqual(abstractvm_value, py_value) if raises: self.assertIsInstance(vm_exc, raises) + self.assertIsInstance(abstractvm_exc, raises) else: self.assertIsNone(vm_exc) + self.assertIsNone(abstractvm_exc) def assert_same_exception(self, e1, e2): """Exceptions don't implement __eq__, check it ourselves."""