Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request boriel-basic#674 from boriel/refact/convert_optimi…
Browse files Browse the repository at this point in the history
…zer_into_class

refact: convert optimizer module into class
  • Loading branch information
boriel committed Sep 9, 2023
2 parents 0e8017c + 473b17c commit fbebabf
Show file tree
Hide file tree
Showing 13 changed files with 453 additions and 331 deletions.
2 changes: 1 addition & 1 deletion src/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_absolute_filename_path(fname: str) -> str:
return os.path.realpath(os.path.expanduser(fname))


def get_relative_filename_path(fname: str, current_dir: str = None) -> str:
def get_relative_filename_path(fname: str, current_dir: str | None = None) -> str:
"""Given an absolute path, returns it relative to the current directory,
that is, if the file is in the same folder or any of it children, only
the path from the current folder onwards is returned. Otherwise, the
Expand Down
14 changes: 14 additions & 0 deletions src/arch/interface/optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod


class OptimizerInterface(ABC):
"""Implements the Peephole Optimizer"""

@abstractmethod
def init(self) -> None:
pass

@abstractmethod
def optimize(self, initial_memory: list[str]) -> str:
"""This will remove useless instructions"""
pass
4 changes: 2 additions & 2 deletions src/arch/z80/optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .main import init, optimize
from .main import Optimizer

__all__ = "init", "optimize"
__all__ = ("Optimizer",)
13 changes: 6 additions & 7 deletions src/arch/z80/optimizer/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from src.arch.z80.peephole import evaluator

from . import helpers
from .common import JUMP_LABELS, LABELS
from .cpustate import CPUState
from .helpers import ALL_REGS
from .labelinfo import LabelInfo
Expand Down Expand Up @@ -137,12 +136,12 @@ def labels(self) -> tuple[str, ...]:
memory"""
return tuple(cell.inst for cell in self.mem if cell.is_label)

def get_first_partition_idx(self) -> int | None:
def get_first_partition_idx(self, jump_labels: set[str]) -> int | None:
"""Returns the first position where this block can be
partitioned or None if there's no such point
"""
for i, mem in enumerate(self):
if i > 0 and mem.is_label and mem.inst in JUMP_LABELS:
if i > 0 and mem.is_label and mem.inst in jump_labels:
return i

if (mem.is_ender or mem.code in src.arch.z80.backend.common.ASMS) and i < len(self) - 1:
Expand Down Expand Up @@ -210,7 +209,7 @@ def add_goes_to(self, basic_block: BasicBlock | None) -> None:
self.goes_to.add(basic_block)
basic_block.comes_from.add(self)

def update_next_block(self):
def update_next_block(self, labels: dict[str, LabelInfo]) -> None:
"""If the last instruction of this block is a JP, JR or RET (with no
conditions) then goes_to set contains just a
single block
Expand All @@ -232,11 +231,11 @@ def update_next_block(self):
if last.inst == "ret":
return

if last.opers[0] not in LABELS.keys():
if last.opers[0] not in labels.keys():
__DEBUG__("INFO: %s is not defined. No optimization is done." % last.opers[0], 2)
LABELS[last.opers[0]] = LabelInfo(last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
labels[last.opers[0]] = LabelInfo(last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))

n_block = LABELS[last.opers[0]].basic_block
n_block = labels[last.opers[0]].basic_block
self.add_goes_to(n_block)

def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:
Expand Down
20 changes: 0 additions & 20 deletions src/arch/z80/optimizer/common.py

This file was deleted.

58 changes: 33 additions & 25 deletions src/arch/z80/optimizer/flow_graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from src.api.debug import __DEBUG__

from .basicblock import BasicBlock, DummyBasicBlock
from .common import JUMP_LABELS, LABELS
from .helpers import ALL_REGS
from .labelinfo import LabelInfo
from .labels_dict import LabelsDict

__all__ = ("get_basic_blocks",)


def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock, BasicBlock]:
def _split_block(block: BasicBlock, start_of_new_block: int, labels: LabelsDict) -> tuple[BasicBlock, BasicBlock]:
assert 0 <= start_of_new_block < len(block), f"Invalid split pos: {start_of_new_block}"
new_block = BasicBlock([])
new_block.mem = block.mem[start_of_new_block:]
Expand All @@ -28,9 +28,9 @@ def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock,
block.add_goes_to(new_block)

for i, mem in enumerate(new_block):
if mem.is_label and mem.inst in LABELS:
LABELS[mem.inst].basic_block = new_block
LABELS[mem.inst].position = i
if mem.is_label and mem.inst in labels:
labels[mem.inst].basic_block = new_block
labels[mem.inst].position = i

if block[-1].is_ender:
if not block[-1].condition_flag: # If it's an unconditional jp, jr, call, ret
Expand All @@ -39,28 +39,32 @@ def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock,
return block, new_block


def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None:
def _compute_calls(
basic_blocks: list[BasicBlock],
labels: LabelsDict,
jump_labels: set[str],
) -> None:
calling_blocks: dict[BasicBlock, BasicBlock] = {}

# Compute which blocks use jump labels
for bb in basic_blocks:
if bb[-1].is_ender and (op := bb[-1].branch_arg) in LABELS:
LABELS[op].used_by.add(bb)
if bb[-1].is_ender and (op := bb[-1].branch_arg) in labels:
labels[op].used_by.add(bb)

# For these blocks, add the referenced block in the goes_to
for label in jump_labels:
for bb in LABELS[label].used_by:
bb.add_goes_to(LABELS[label].basic_block)
for bb in labels[label].used_by:
bb.add_goes_to(labels[label].basic_block)

# Annotate which blocks uses call (which should be the last instruction)
for bb in basic_blocks:
if bb[-1].inst != "call":
continue

op = bb[-1].branch_arg
if op in LABELS:
LABELS[op].basic_block.called_by.add(bb)
calling_blocks[bb] = LABELS[op].basic_block
if op in labels:
labels[op].basic_block.called_by.add(bb)
calling_blocks[bb] = labels[op].basic_block

# For the annotated blocks, trace their goes_to, and their goes_to from
# their goes_to and so on, until ret (unconditional or not) is found, and
Expand Down Expand Up @@ -91,7 +95,7 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
pending.add((caller, bb.next))


def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
def _get_jump_labels(main_basic_block: BasicBlock, labels: LabelsDict) -> set[str]:
"""Given the main basic block (which contain the entire program), populate
the global JUMP_LABEL set with LABELS used by CALL, JR, JP (i.e JP LABEL0)
Also updates the global LABELS index with the pertinent information.
Expand All @@ -103,8 +107,8 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:

for i, mem in enumerate(main_basic_block):
if mem.is_label:
LABELS.pop(mem.inst)
LABELS[mem.inst] = LabelInfo(
labels.pop(mem.inst)
labels[mem.inst] = LabelInfo(
label=mem.inst, addr=i, basic_block=main_basic_block, position=i # Unknown yet
)
continue
Expand All @@ -118,28 +122,32 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:

jump_labels.add(lbl)

if lbl not in LABELS:
if lbl not in labels:
__DEBUG__(f"INFO: {lbl} is not defined. No optimization is done.", 2)
LABELS[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
labels[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS))

return jump_labels


def get_basic_blocks(block: BasicBlock) -> list[BasicBlock]:
def get_basic_blocks(
block: BasicBlock,
labels: LabelsDict,
jump_labels: set[str],
) -> list[BasicBlock]:
"""If a block is not partitionable, returns a list with the same block.
Otherwise, returns a list with the resulting blocks.
"""
result: list[BasicBlock] = [block]
JUMP_LABELS.clear()
JUMP_LABELS.update(get_jump_labels(block))
jump_labels.clear()
jump_labels.update(_get_jump_labels(block, labels))

# Split basic blocks per label or branch instruction
split_pos = block.get_first_partition_idx()
split_pos = block.get_first_partition_idx(jump_labels)
while split_pos is not None:
_, block = split_block(block, split_pos)
_, block = _split_block(block, split_pos, labels)
result.append(block)
split_pos = block.get_first_partition_idx()
split_pos = block.get_first_partition_idx(jump_labels)

compute_calls(result, JUMP_LABELS)
_compute_calls(result, labels, jump_labels)

return result
130 changes: 97 additions & 33 deletions src/arch/z80/optimizer/helpers.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,117 @@
# -*- coding: utf-8 -*-

from typing import Any, Iterable, TypeVar
from typing import Any, Final, Iterable, TypeVar, cast

from . import patterns

__all__ = (
"ALL_REGS",
"END_PROGRAM_LABEL",
"init",
"new_tmp_val",
"new_tmp_val16",
"new_tmp_val16_from_label",
"is_unknown",
"is_unknown8",
"is_unknown16",
"get_orig_label_from_unknown16",
"get_L_from_unknown_value",
"get_H_from_unknown_value",
"is_mem_access",
"is_number",
"is_label",
"valnum",
"simplify_arg",
"simplify_asm_args",
"is_register",
"is_8bit_normal_register",
"is_8bit_idx_register",
"is_8bit_oper_register",
"is_16bit_normal_register",
"is_16bit_idx_register",
"is_16bit_composed_register",
"is_16bit_oper_register",
"LO16",
"HI16",
"single_registers",
"idx_args",
"LO16_val",
"HI16_val",
"dict_intersection",
)

from . import common, patterns

T = TypeVar("T")
K = TypeVar("K")


# All 'single' registers (even f FLAG one). SP is not decomposable so it's 'single' already
ALL_REGS = {"a", "b", "c", "d", "e", "f", "h", "l", "ixh", "ixl", "iyh", "iyl", "r", "i", "sp"}
ALL_REGS: Final[frozenset[str]] = frozenset(
[
"a",
"b",
"c",
"d",
"e",
"f",
"h",
"l",
"ixh",
"ixl",
"iyh",
"iyl",
"r",
"i",
"sp",
]
)

# The set of all registers as they can appear in any instruction as operands
REGS_OPER_SET = {
"a",
"b",
"c",
"d",
"e",
"h",
"l",
"bc",
"de",
"hl",
"sp",
"ix",
"iy",
"ixh",
"ixl",
"iyh",
"iyl",
"af",
"af'",
"i",
"r",
}
REGS_OPER_SET: Final[frozenset[str]] = frozenset(
[
"a",
"b",
"c",
"d",
"e",
"h",
"l",
"bc",
"de",
"hl",
"sp",
"ix",
"iy",
"ixh",
"ixl",
"iyh",
"iyl",
"af",
"af'",
"i",
"r",
]
)

# Instructions that marks the end of a basic block (any branching instruction)
BLOCK_ENDERS = {"jr", "jp", "call", "ret", "reti", "retn", "djnz", "rst"}
BLOCK_ENDERS: Final[frozenset[str]] = frozenset(["jr", "jp", "call", "ret", "reti", "retn", "djnz", "rst"])
UNKNOWN_PREFIX: Final[str] = "*UNKNOWN_"
END_PROGRAM_LABEL: Final[str] = "__END_PROGRAM" # Label for end program
HL_SEP: Final[str] = "|" # Hi/Low separator
_RAND_COUNT: int = 0 # Counter for unknown values

UNKNOWN_PREFIX = "*UNKNOWN_"
END_PROGRAM_LABEL = "__END_PROGRAM" # Label for end program
HL_SEP = "|" # Hi/Low separator

def init() -> None:
global _RAND_COUNT
_RAND_COUNT = 0


def new_tmp_val() -> str:
"""Generates an 8-bit unknown value"""
common.RAND_COUNT += 1
return f"{UNKNOWN_PREFIX}{common.RAND_COUNT}"
global _RAND_COUNT

_RAND_COUNT += 1
return f"{UNKNOWN_PREFIX}{_RAND_COUNT}"


def new_tmp_val16() -> str:
Expand Down Expand Up @@ -390,7 +454,7 @@ def LO16_val(x: int | str | None) -> str:
if not is_unknown(x):
return new_tmp_val()

return x.split(HL_SEP)[-1]
return cast(str, x).split(HL_SEP)[-1]


def HI16_val(x: int | str | None) -> str:
Expand Down

0 comments on commit fbebabf

Please sign in to comment.