diff --git a/src/api/utils.py b/src/api/utils.py index 76c578eae..887be1991 100644 --- a/src/api/utils.py +++ b/src/api/utils.py @@ -6,7 +6,7 @@ import shelve import signal from functools import wraps -from typing import IO, Any, Callable, Iterable, List, Optional, Union, TypeVar +from typing import IO, Any, Callable, Iterable, List, Optional, TypeVar, Union from src.api import constants, errmsg, global_ diff --git a/src/arch/z80/optimizer/__init__.py b/src/arch/z80/optimizer/__init__.py index e5af6cba9..4f9c989f1 100644 --- a/src/arch/z80/optimizer/__init__.py +++ b/src/arch/z80/optimizer/__init__.py @@ -14,9 +14,6 @@ def init(): - global LABELS - global JUMP_LABELS - LABELS.clear() JUMP_LABELS.clear() @@ -172,7 +169,7 @@ def initialize_memory(basic_block): get_labels(basic_block) -def optimize(initial_memory): +def optimize(initial_memory: list[str]) -> str: """This will remove useless instructions""" global BLOCKS global PROC_COUNTER diff --git a/src/arch/z80/optimizer/basicblock.py b/src/arch/z80/optimizer/basicblock.py index da856c259..de17139e3 100644 --- a/src/arch/z80/optimizer/basicblock.py +++ b/src/arch/z80/optimizer/basicblock.py @@ -1,22 +1,19 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import Iterable, Iterator, List + +from typing import Final, Iterable, Iterator, List import src.api.config import src.arch.z80.backend.common from src.api.debug import __DEBUG__ +from src.api.utils import first from src.arch.z80.optimizer import helpers from src.arch.z80.optimizer.common import JUMP_LABELS, LABELS from src.arch.z80.optimizer.cpustate import CPUState -from src.arch.z80.optimizer.errors import ( - OptimizerError, - OptimizerInvalidBasicBlockError, -) -from src.arch.z80.optimizer.helpers import ALL_REGS, END_PROGRAM_LABEL +from src.arch.z80.optimizer.helpers import ALL_REGS from src.arch.z80.optimizer.labelinfo import LabelInfo from src.arch.z80.optimizer.memcell import MemCell from src.arch.z80.optimizer.patterns import RE_ID_OR_NUMBER -from src.api.utils import flatten_list, first from src.arch.z80.peephole import evaluator @@ -26,27 +23,33 @@ class BasicBlock(Iterable[MemCell]): __UNIQUE_ID = 0 clean_asm_args = False + def __new__(cls, *args, **kwargs): + cls.__UNIQUE_ID += 1 + return super().__new__(cls) + def __init__(self, memory: Iterable[str]): """Initializes the internal array of instructions.""" self.mem: List[MemCell] = [] - self.next = None # Which (if any) basic block follows this one in memory - self.prev = None # Which (if any) basic block precedes to this one in the code + self.next: BasicBlock | None = None # Which (if any) basic block follows this one in memory + self.prev: BasicBlock | None = None # Which (if any) basic block precedes to this one in the code self.lock = False # True if this block is being accessed by other subroutine self.comes_from: set[BasicBlock] = set() # A list/tuple containing possible jumps to this block self.goes_to: set[BasicBlock] = set() # A list/tuple of possible block to jump from here self.modified = False # True if something has been changed during optimization - self.calls: set[BasicBlock] = set() + self.called_by: set[BasicBlock] = set() self.label_goes = [] self.ignored = False # True if this block can be ignored (it's useless) - self.id = BasicBlock.__UNIQUE_ID + self.id: Final[int] = BasicBlock.__UNIQUE_ID self._bytes = None self._sizeof = None self._max_tstates = None self.optimized = False # True if this block was already optimized - BasicBlock.__UNIQUE_ID += 1 self.code = memory self.cpu = CPUState() + def __hash__(self) -> int: + return self.id + def __len__(self) -> int: return len(self.mem) @@ -56,7 +59,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return "<{}: id: {}, len: {}>".format(self.__class__.__name__, self.id, len(self)) - def __getitem__(self, key): + def __getitem__(self, key) -> MemCell | list[MemCell]: return self.mem[key] def __setitem__(self, key, value: MemCell): @@ -126,115 +129,83 @@ def max_tstates(self): return self._max_tstates @property - def labels(self): - """Returns a t-uple containing labels within this block""" - return [cell.inst for cell in self.mem if cell.is_label] + def labels(self) -> tuple[str, ...]: + """Returns a t-uple containing labels within this block, sorted by position in + memory""" + return tuple(cell.inst for cell in self.mem if cell.is_label) + + def get_first_partition_idx(self) -> int | None: + """Returns the first position where this block can be + partitioned or None if there's no such point + """ + for i, mem in enumerate(self): + if i > 0 and mem.is_label and mem.inst in JUMP_LABELS: + return i + + if (mem.is_ender or mem.code in src.arch.z80.backend.common.ASMS) and i < len(self) - 1: + return i + 1 + + return None @property - def is_partitionable(self): + def is_partitionable(self) -> bool: """Returns if this block can be partitions in 2 or more blocks, because if contains enders. """ - if len(self.mem) < 2: - return False # An atomic block - - if any(x.is_ender or x.code in src.arch.z80.backend.common.ASMS for x in self.mem[:]): - return True - - for label in JUMP_LABELS: - if LABELS[label].basic_block != self: - continue - - for i in range(len(self)): - if not self.mem[i].is_label: - return True # An instruction? Should start with a Jump Label - - if self.mem[i].inst == label: - break # found - else: - raise OptimizerInvalidBasicBlockError(self) # Label is pointing to the wrong block? not found - - return False - - def update_labels(self): - """Update global labels table so they point to the current block""" - for l in self.labels: - LABELS[l].basic_block = self + return self.get_first_partition_idx() is not None - def delete_comes_from(self, basic_block: BasicBlock) -> None: + def delete_comes_from(self, basic_block: BasicBlock | None) -> None: """Removes the basic_block ptr from the list for "comes_from" if it exists. It also sets self.prev to None if it is basic_block. """ if basic_block is None: return - if self.lock: + if basic_block not in self.comes_from: return - self.lock = True - - for elem in self.comes_from: - if elem.id == basic_block.id: - self.comes_from.remove(elem) - break - - self.lock = False + self.comes_from.remove(basic_block) + basic_block.goes_to.remove(self) - def delete_goes_to(self, basic_block: BasicBlock) -> None: + def delete_goes_to(self, basic_block: BasicBlock | None) -> None: """Removes the basic_block ptr from the list for "goes_to" if it exists. It also sets self.next to None if it is basic_block. """ if basic_block is None: return - if self.lock: + if basic_block not in self.goes_to: return - self.lock = True - - for elem in self.goes_to: - if elem.id is basic_block.id: - self.goes_to.remove(elem) - basic_block.delete_comes_from(self) - break - - self.lock = False + self.goes_to.remove(basic_block) + basic_block.comes_from.remove(self) - def add_comes_from(self, basic_block: BasicBlock) -> None: + def add_comes_from(self, basic_block: BasicBlock | None) -> None: """This simulates a set. Adds the basic_block to the comes_from list if not done already. """ if basic_block is None: return - if self.lock: - return - # Return if already added if basic_block in self.comes_from: return - self.lock = True self.comes_from.add(basic_block) - basic_block.add_goes_to(self) - self.lock = False + basic_block.goes_to.add(self) - def add_goes_to(self, basic_block: BasicBlock) -> None: + def add_goes_to(self, basic_block: BasicBlock | None) -> None: """This simulates a set. Adds the basic_block to the goes_to list if not done already. """ - assert basic_block is not None - - if self.lock: + if basic_block is None: return if basic_block in self.goes_to: return - self.lock = True self.goes_to.add(basic_block) - basic_block.add_comes_from(self) - self.lock = False + basic_block.comes_from.add(self) def update_next_block(self): """If the last instruction of this block is a JP, JR or RET (with no @@ -252,6 +223,8 @@ def update_next_block(self): if self.next is not None and last.condition_flag is None: # jp NNN, call NNN, rst, jr NNNN, ret self.next.delete_comes_from(self) + for blk in self.goes_to: + self.delete_goes_to(blk) if last.inst == "ret": return @@ -263,112 +236,15 @@ def update_next_block(self): n_block = LABELS[last.opers[0]].basic_block self.add_goes_to(n_block) - def update_used_by_list(self): - """Every label has a set containing - which blocks jumps (jp, jr, call) if any. - A block can "use" (call/jump) only another block - and only one""" - - # Searches all labels and remove this block out - # of their used_by set, since this might have changed - for label in LABELS.values(): - label.used_by.remove(self) # Delete this bblock - - def clean_up_goes_to(self): - for x in self.goes_to: - if x is not self.next: - self.delete_goes_to(x) - - def clean_up_comes_from(self): - for x in self.comes_from: - if x is not self.prev: - self.delete_comes_from(x) - - def update_goes_and_comes(self): - """Once the block is a Basic one, check the last instruction and updates - goes_to and comes_from set of the receivers. - Note: jp, jr and ret are already done in update_next_block() - """ - if not len(self): - return - - last = self.mem[-1] - inst = last.inst - oper = last.opers - cond = last.condition_flag - - if not last.is_ender: - return - - if cond is None: - self.delete_goes_to(self.next) - - if last.inst in {"ret", "reti", "retn"} and cond is None: - return # subroutine returns are updated from CALLer blocks - - if oper and oper[0]: - if oper[0] not in LABELS: - __DEBUG__("INFO: %s is not defined. No optimization is done." % oper[0], 1) - LABELS[oper[0]] = LabelInfo(oper[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS)) - - LABELS[oper[0]].used_by.add(self) - self.add_goes_to(LABELS[oper[0]].basic_block) - - if inst in {"djnz", "jp", "jr"}: - return - - assert inst in ("call", "rst") - - if self.next is None: - raise OptimizerError("Unexpected NULL next block") - - final_blk = self.next # The block all the final returns should go to - stack = [LABELS[oper[0]].basic_block] - bbset: set[BasicBlock] = set() - - while stack: - bb = stack.pop(0) - while True: - if bb is None: - bb = DummyBasicBlock(ALL_REGS, ALL_REGS) - - if bb in bbset: - break - - bbset.add(bb) - - if isinstance(bb, DummyBasicBlock): - bb.add_goes_to(final_blk) - break - - if bb: - bb1 = bb[-1] - if bb1.inst in {"ret", "reti", "retn"}: - bb.add_goes_to(final_blk) - if bb1.condition_flag is None: # 'ret' - break - elif bb1.inst in ("jp", "jr") and bb1.condition_flag is not None: # jp/jr nc/nz/.. LABEL - if bb1.opers[0] in LABELS: # some labels does not exist (e.g. immediate numeric addresses) - stack.append(LABELS[bb1.opers[0]].basic_block) - else: - raise OptimizerError("Unknown block label '{}'".format(bb1.opers[0])) - - bb = bb.next # next contiguous block - - def is_used(self, regs, i, top=None): + def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool: """Checks whether any of the given regs are required from the given point to the end or not. """ - if i < 0: - i = 0 - if self.lock: return True - if top is None: - top = len(self) - else: - top -= 1 + i = max(i, 0) + top = len(self) if top is None else top + 1 if regs and regs[0][0] == "(" and regs[0][-1] == ")": # A memory address r16 = helpers.single_registers(regs[0][1:-1]) if helpers.is_16bit_oper_register(regs[0][1:-1]) else [] @@ -414,33 +290,17 @@ def is_used(self, regs, i, top=None): return result - def safe_to_write(self, regs, i=0, end_=0): - """Given a list of registers (8 or 16 bits) returns a list of them - that are safe to modify from the given index until the position given - which, if omitted, defaults to the end of the block. - :param regs: register or iterable of registers (8 or 16 bit one) - :param i: initial position of the block to examine - :param end_: final position to examine - :returns: registers safe to write - """ - if helpers.is_register(regs): - regs = set(helpers.single_registers(regs)) - else: - regs = set(helpers.single_registers(x) for x in regs) - return not regs.intersection(self.requires(i, end_)) - - def requires(self, i=0, end_=None): + def requires(self, i: int = 0, end_: int | None = None) -> set[str]: """Returns a list of registers and variables this block requires. By default checks from the beginning (i = 0). :param i: initial position of the block to examine :param end_: final position to examine :returns: registers safe to write """ - if i < 0: - i = 0 + i = max(i, 0) end_ = len(self) if end_ is None or end_ > len(self) else end_ regs = {"a", "b", "c", "d", "e", "f", "h", "l", "i", "ixh", "ixl", "iyh", "iyl", "sp"} - result = set() + result: set[str] = set() for ii in range(i, end_): for r in self.mem[ii].requires: @@ -459,13 +319,13 @@ def requires(self, i=0, end_=None): return result - def destroys(self, i=0): + def destroys(self, i: int = 0) -> list[str]: """Returns a list of registers this block destroys By default checks from the beginning (i = 0). """ regs = {"a", "b", "c", "d", "e", "f", "h", "l", "i", "ixh", "ixl", "iyh", "iyl", "sp"} top = len(self) - result = [] + result: list[str] = [] for ii in range(i, top): for r in self.mem[ii].destroys: @@ -478,10 +338,6 @@ def destroys(self, i=0): return result - def swap(self, a, b): - """Swaps mem positions a and b""" - self.mem[a], self.mem[b] = self.mem[b], self.mem[a] - def goes_requires(self, regs): """Returns whether any of the goes_to block requires any of the given registers. @@ -492,16 +348,6 @@ def goes_requires(self, regs): return False - def get_label_idx(self, label): - """Returns the index of a label. - Returns None if not found. - """ - for i in range(len(self)): - if self.mem[i].is_label and self.mem[i].inst == label: - return i - - return None - def get_first_non_label_instruction(self): """Returns the memcell of the given block, which is not a LABEL. @@ -593,7 +439,7 @@ def optimize(self, patterns_list): if not p.cond.eval(match): continue - # all patterns applied successfully. Apply this pattern + # all patterns matched successfully. Apply this rule new_code = list(code) matched = new_code[i : i + len(p.patt)] new_code[i : i + len(p.patt)] = p.template.filter(match) @@ -620,130 +466,154 @@ class DummyBasicBlock(BasicBlock): about what registers uses an destroys """ - def __init__(self, destroys, requires): + def __init__(self, destroys: Iterable[str], requires: Iterable[str]): BasicBlock.__init__(self, []) - self.__destroys = [x for x in destroys] - self.__requires = [x for x in requires] + self.__destroys = tuple(destroys) + self.__requires = set(requires) + self.code = ["ret"] - def destroys(self, i: int = 0): - return [x for x in self.__destroys] + def destroys(self, i: int = 0) -> list[str]: + return list(self.__destroys) - def requires(self, i: int = 0, end_=None): - return [x for x in self.__requires] + def requires(self, i: int = 0, end_=None) -> set[str]: + return set(self.__requires) - def is_used(self, regs, i, top=None): + def is_used(self, regs: Iterable[str], i: int, top: int | None = None) -> bool: return len([x for x in regs if x in self.__requires]) > 0 -def block_partition(block, i): - """Returns two blocks, as a result of partitioning the given one at - i-th instruction. - """ - i += 1 +def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock, BasicBlock]: + assert 0 <= start_of_new_block < len(block), f"Invalid split pos: {start_of_new_block}" new_block = BasicBlock([]) - new_block.mem = block.mem[i:] - block.mem = block.mem[:i] - - for label, lbl_info in LABELS.items(): - if lbl_info.basic_block != block or lbl_info.position < len(block): - continue - - lbl_info.basic_block = new_block - lbl_info.position -= len(block) - - for b_ in list(block.goes_to): - block.delete_goes_to(b_) - new_block.add_goes_to(b_) - - new_block.label_goes = block.label_goes - block.label_goes = [] + new_block.mem = block.mem[start_of_new_block:] + block.mem = block.mem[:start_of_new_block] new_block.next = block.next - new_block.prev = block block.next = new_block - new_block.add_comes_from(block) + new_block.prev = block if new_block.next is not None: new_block.next.prev = new_block - if block in new_block.next.comes_from: - new_block.next.delete_comes_from(block) - new_block.next.add_comes_from(new_block) - block.update_next_block() + for blk in list(block.goes_to): + block.delete_goes_to(blk) + new_block.add_goes_to(blk) + + block.add_goes_to(new_block) + + for i, mem in enumerate(new_block): + if mem.is_label and mem.inst in LABELS: + LABELS[mem.inst].basic_block = new_block + LABELS[mem.inst].position = i + + if block[-1].is_ender: + if not block[-1].condition_flag: # If it's an unconditional jp, jr, call, ret + block.delete_goes_to(block.next) return block, new_block -def get_basic_blocks(block): - """If a block is not partitionable, returns a list with the same block. - Otherwise, returns a list with the resulting blocks, recursively. - """ - result = [] - EDP = END_PROGRAM_LABEL + ":" - - new_block = block - while new_block: - block = new_block - new_block = None - - for i, mem in enumerate(block): - if i and mem.code == EDP: # END_PROGRAM label always starts a basic block - block, new_block = block_partition(block, i - 1) - LABELS[END_PROGRAM_LABEL].basic_block = new_block - break +def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None: + calling_blocks: dict[BasicBlock, BasicBlock] = {} - if mem.is_ender: - block, new_block = block_partition(block, i) - if not mem.condition_flag: - block.delete_goes_to(new_block) + # Compute which blocks use jump labels + for bb in basic_blocks: + if bb[-1].is_ender and (op := bb[-1].branch_arg) in LABELS: + LABELS[op].used_by.add(bb) - for l in mem.opers: - if l in LABELS: - JUMP_LABELS.add(l) - block.label_goes.append(l) - break + # For these blocks, add the referenced block in the goes_to + for label in jump_labels: + for bb in LABELS[label].used_by: + bb.add_goes_to(LABELS[label].basic_block) - if mem.is_label and mem.code[:-1] not in LABELS: - raise OptimizerError("Missing label '{}' in labels list".format(mem.code[:-1])) + # Annotate which blocks uses call (which should be the last instruction) + for bb in basic_blocks: + if bb[-1].inst != "call": + continue - if mem.code in src.arch.z80.backend.common.ASMS: # An inline ASM block - block, new_block = block_partition(block, max(0, i - 1)) - break + op = bb[-1].branch_arg + if op in LABELS: + LABELS[op].basic_block.called_by.add(bb) + calling_blocks[bb] = LABELS[op].basic_block - result.append(block) + # For the annotated blocks, trace their goes_to, and their goes_to from + # their goes_to and so on, until ret (unconditional or not) is found, and + # save that block in a set for later + visited: set[tuple[BasicBlock, BasicBlock]] = set() + pending: set[tuple[BasicBlock, BasicBlock]] = set(calling_blocks.items()) - for label in JUMP_LABELS: - blk = LABELS[label].basic_block - if isinstance(blk, DummyBasicBlock): + while pending: + caller, bb = pending.pop() + if (caller, bb) in visited: continue - must_partition = False - # This label must point to the beginning of blk, just before the code - # Otherwise we must partition it (must_partition = True) - for i, cell in enumerate(blk): - if cell.inst == label: - break # already starts with this label + visited.add((caller, bb)) - if cell.is_label: - continue # It's another label + if not bb[-1].is_ender: # if it does not branch, search in the next block + pending.add((caller, bb.next)) + continue - if cell.is_ender: - raise OptimizerInvalidBasicBlockError(blk) + if bb[-1].inst in {"ret", "reti", "retn"}: + if bb[-1].condition_flag: + pending.add((caller, bb.next)) - must_partition = True - else: - __DEBUG__("Label {} not found in BasicBlock {}".format(label, blk.id)) + bb.add_goes_to(caller.next) continue - if must_partition: - j = result.index(blk) - block_, new_block_ = block_partition(blk, i - 1) - LABELS[label].basic_block = new_block_ - result.pop(j) - result.insert(j, block_) - result.insert(j + 1, new_block_) + if bb[-1].inst in {"call", "rst"}: # A call from this block + if bb[-1].condition_flag: # if it has conditions, it can return from the next block + pending.add((caller, bb.next)) + + +def get_jump_labels(main_basic_block: BasicBlock) -> set[str]: + """Given the main basic block (which contain the entire program), populate + the global JUMP_LABEL set with LABELS used by CALL, JR, JP (i.e JP LABEL0) + Also updates the global LABELS index with the pertinent information. + + Any BasicBlock containing a JUMP_LABEL in any position which is not the initial + one (0 position) must be split at that point into two basic blocks. + """ + jump_labels: set[str] = set() + + for i, mem in enumerate(main_basic_block): + if mem.is_label: + LABELS.pop(mem.inst) + LABELS[mem.inst] = LabelInfo( + label=mem.inst, addr=i, basic_block=main_basic_block, position=i # Unknown yet + ) + continue + + if not mem.is_ender: + continue + + lbl = mem.branch_arg + if lbl is None: + continue + + jump_labels.add(lbl) + + if lbl not in LABELS: + __DEBUG__(f"INFO: {lbl} is not defined. No optimization is done.", 2) + LABELS[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS)) + + return jump_labels + + +def get_basic_blocks(block: BasicBlock) -> list[BasicBlock]: + """If a block is not partitionable, returns a list with the same block. + Otherwise, returns a list with the resulting blocks. + """ + result: list[BasicBlock] = [block] + JUMP_LABELS.clear() + JUMP_LABELS.update(get_jump_labels(block)) + + # Split basic blocks per label or branch instruction + split_pos = block.get_first_partition_idx() + while split_pos is not None: + _, block = split_block(block, split_pos) + result.append(block) + split_pos = block.get_first_partition_idx() - for b in result: - b.update_goes_and_comes() + compute_calls(result, JUMP_LABELS) return result diff --git a/src/arch/z80/optimizer/common.py b/src/arch/z80/optimizer/common.py index 61385e2fd..e98428b6c 100644 --- a/src/arch/z80/optimizer/common.py +++ b/src/arch/z80/optimizer/common.py @@ -1,16 +1,18 @@ # -*- config: utf-8 -*- +from __future__ import annotations -from typing import Dict +from typing import TYPE_CHECKING -from . import labelinfo +if TYPE_CHECKING: + from .labelinfo import LabelInfo # counter for generating unique random fake values RAND_COUNT = 0 # Labels which must start a basic block, because they're used in a JP/CALL -LABELS: Dict[str, labelinfo.LabelInfo] = {} # Label -> LabelInfo object +LABELS: dict[str, LabelInfo] = {} # Label -> LabelInfo object -JUMP_LABELS = set([]) +JUMP_LABELS: set[str] = set([]) MEMORY = [] # Instructions emitted by the backend # PROC labels name space counter diff --git a/src/arch/z80/optimizer/labelinfo.py b/src/arch/z80/optimizer/labelinfo.py index 4494a9f14..0b1c5e5a1 100644 --- a/src/arch/z80/optimizer/labelinfo.py +++ b/src/arch/z80/optimizer/labelinfo.py @@ -1,24 +1,28 @@ # -*- coding: utf-8 -*- +from __future__ import annotations +from dataclasses import dataclass, field from typing import TYPE_CHECKING + from . import common, errors if TYPE_CHECKING: from .basicblock import BasicBlock -class LabelInfo(object): - """Class describing label information""" +@dataclass +class LabelInfo: + """Class describing label information + Stores the label name, the address counter into memory (rather useless) + and which basic block contains it. + """ - def __init__(self, label, addr, basic_block=None, position=0): - """Stores the label name, the address counter into memory (rather useless) - and which basic block contains it. - """ - self.label = label - self.addr = addr - self.basic_block = basic_block - self.position = position # Position within the block - self.used_by: set[BasicBlock] = set() # Which BB uses this label, if any + label: str + addr: int # Memory address or 0 + basic_block: BasicBlock | None = None # Basic Block this label is in + position: int = 0 # Position within the Basic Block + used_by: set[BasicBlock] = field(default_factory=set) # Which BB uses this label, if any - if label in common.LABELS: - raise errors.DuplicatedLabelError(label) + def __post_init__(self): + if self.label in common.LABELS: + raise errors.DuplicatedLabelError(self.label) diff --git a/src/arch/z80/optimizer/memcell.py b/src/arch/z80/optimizer/memcell.py index b0d95ec81..f488f6f2e 100644 --- a/src/arch/z80/optimizer/memcell.py +++ b/src/arch/z80/optimizer/memcell.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import re +from functools import cached_property from typing import List, Optional, Set, Union import src.arch.z80.backend.common @@ -16,7 +17,7 @@ class MemCell: the instruction. """ - __slots__ = "addr", "__instr" + __slots__ = "addr", "__instr", "__dict__" __instr: Asm def __init__(self, instr: str, addr: int): @@ -71,7 +72,7 @@ def is_ender(self) -> bool: """Returns if this instruction is a BLOCK ender""" return self.inst in helpers.BLOCK_ENDERS - @property + @cached_property def inst(self) -> str: """Returns just the asm instruction in lower case. E.g. 'ld', 'jp', 'pop' @@ -91,12 +92,19 @@ def condition_flag(self) -> Optional[str]: """ return self.__instr.cond - @property + @cached_property def opers(self) -> List[str]: """Returns a list of operands (i.e. register) this mnemonic uses""" return self.__instr.oper - @property + @cached_property + def branch_arg(self) -> str | None: + if self.__instr.inst not in {"jr", "jp", "call", "rst", "djnz"}: + return None + + return self.__instr.asm.split()[-1] + + @cached_property def destroys(self) -> Set[str]: """Returns which single registers (including f, flag) this instruction changes. @@ -165,7 +173,7 @@ def destroys(self) -> Set[str]: return res - @property + @cached_property def requires(self) -> Set[str]: """Returns the registers, operands, etc. required by an instruction.""" if self.code in src.arch.z80.backend.common.ASMS: diff --git a/src/arch/z80/peephole/engine.py b/src/arch/z80/peephole/engine.py index fc8ea7920..1e2f9784a 100644 --- a/src/arch/z80/peephole/engine.py +++ b/src/arch/z80/peephole/engine.py @@ -15,6 +15,7 @@ REG_IF, REG_REPLACE, REG_WITH, + DefineLine, ) from src.arch.z80.peephole.pattern import BlockPattern from src.arch.z80.peephole.template import BlockTemplate @@ -27,7 +28,7 @@ class OptPattern(NamedTuple): cond: Evaluator template: BlockTemplate parsed: Dict[str, Union[List[str], int]] - defines: int + defines: list[tuple[str, DefineLine]] fname: str diff --git a/src/arch/z80/peephole/opts/103_o3_or_and_a.opt b/src/arch/z80/peephole/opts/103_o3_or_and_a.opt index f2f008e44..fb00a768e 100644 --- a/src/arch/z80/peephole/opts/103_o3_or_and_a.opt +++ b/src/arch/z80/peephole/opts/103_o3_or_and_a.opt @@ -1,4 +1,4 @@ -;; Removes useless XOR a +;; Removes useless AND, OR a OLEVEL: 3 OFLAG: 103 diff --git a/tests/arch/zx48k/optimizer/test_basicblock.py b/tests/arch/zx48k/optimizer/test_basicblock.py index c0c2a3ea7..001072b72 100644 --- a/tests/arch/zx48k/optimizer/test_basicblock.py +++ b/tests/arch/zx48k/optimizer/test_basicblock.py @@ -194,3 +194,23 @@ def test_is_used_xor_ix(self): """ self.blk.code = [x for x in code.split("\n") if x.strip()] assert self.blk.is_used(["(ix-1)"], 0) + + def test_loop_goes_and_comes(self): + code = """ + ld (_dir), hl + .LABEL.__LABEL0: + ld a, (_n) + ld hl, (_dir) + ld (hl), a + ld hl, _n + inc (hl) + jp .LABEL.__LABEL0 + ld (ix-1), a + """ + self.blk.code = [x for x in code.split("\n") if x.strip()] + optimizer.initialize_memory(self.blk) + blks = basicblock.get_basic_blocks(self.blk) + assert len(blks) == 3 + b1, b2, b3 = blks + assert b1.goes_to == b2.goes_to == {b2} + assert not b3.comes_from diff --git a/tests/arch/zx48k/optimizer/test_optimizer.py b/tests/arch/zx48k/optimizer/test_optimizer.py new file mode 100644 index 000000000..c364751f5 --- /dev/null +++ b/tests/arch/zx48k/optimizer/test_optimizer.py @@ -0,0 +1,42 @@ +from contextlib import contextmanager + +from src.arch.z80 import optimizer +from src.arch.z80.peephole import engine + + +@contextmanager +def mock_options_level(level: int): + initial_level = optimizer.OPTIONS.optimization_level + + try: + optimizer.OPTIONS.optimization_level = level + yield + finally: + optimizer.OPTIONS.optimization_level = initial_level + + +class TestOptimizer: + def setup_class(cls): + engine.main() + + def test_unrequired_or_a(self): + code_src = """ + call .core.__LTI8 + or a + ld bc, 0 + di + ld hl, (.core.__CALL_BACK__) + ld sp, hl + exx + pop hl + pop iy + pop ix + exx + ei + ret + """ + code = [x.strip() for x in code_src.split("\n") if x.strip()] + + with mock_options_level(4): + optimized_code = optimizer.optimize(code) + assert optimized_code.split("\n")[:2] == ["call .core.__LTI8", "ld bc, 0"] diff --git a/tests/functional/zx48k/opt4_poke.asm b/tests/functional/zx48k/opt4_poke.asm new file mode 100644 index 000000000..246eea45e --- /dev/null +++ b/tests/functional/zx48k/opt4_poke.asm @@ -0,0 +1,48 @@ + org 32768 +.core.__START_PROGRAM: + di + push ix + push iy + exx + push hl + exx + ld hl, 0 + add hl, sp + ld (.core.__CALL_BACK__), hl + ei + jp .core.__MAIN_PROGRAM__ +.core.__CALL_BACK__: + DEFW 0 +.core.ZXBASIC_USER_DATA: + ; Defines USER DATA Length in bytes +.core.ZXBASIC_USER_DATA_LEN EQU .core.ZXBASIC_USER_DATA_END - .core.ZXBASIC_USER_DATA + .core.__LABEL__.ZXBASIC_USER_DATA_LEN EQU .core.ZXBASIC_USER_DATA_LEN + .core.__LABEL__.ZXBASIC_USER_DATA EQU .core.ZXBASIC_USER_DATA +_dir: + DEFB 00, 00 +_n: + DEFB 00 +.core.ZXBASIC_USER_DATA_END: +.core.__MAIN_PROGRAM__: + ld hl, 22528 + ld (_dir), hl +.LABEL.__LABEL0: + ld a, (_n) + ld hl, (_dir) + ld (hl), a + ld hl, _n + inc (hl) + jp .LABEL.__LABEL0 +.core.__END_PROGRAM: + di + ld hl, (.core.__CALL_BACK__) + ld sp, hl + exx + pop hl + pop iy + pop ix + exx + ei + ret + ;; --- end of user code --- + END diff --git a/tests/functional/zx48k/opt4_poke.bas b/tests/functional/zx48k/opt4_poke.bas new file mode 100644 index 000000000..f4205b084 --- /dev/null +++ b/tests/functional/zx48k/opt4_poke.bas @@ -0,0 +1,12 @@ +DIM dir as UInteger +DIM n as UByte +' Imprmimos una X +' Obtenemos la dirección del atributo +dir = 16384 + 6144 +' Bucle infinito +DO + ' Cambiamos el atributo con un poke (no borra la X) + poke dir,n + ' Incrementamos el atributo + n=n+1 +LOOP