Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request boriel-basic#669 from boriel/bugfix/next_O4_mul16
Browse files Browse the repository at this point in the history
fix: implement ZX Next MUL dependencies
  • Loading branch information
boriel committed Sep 9, 2023
2 parents fbebabf + a322391 commit 9afbb6d
Show file tree
Hide file tree
Showing 26 changed files with 396 additions and 208 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import pathlib

from setuptools import setup

packages = ["src"]
Expand Down
3 changes: 3 additions & 0 deletions src/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import importlib
from types import ModuleType

from src.api.debug import __DEBUG__

__all__ = (
"zx48k",
"zxnext",
Expand All @@ -23,6 +25,7 @@ def set_target_arch(target_arch: str):
global target
assert target_arch in AVAILABLE_ARCHITECTURES
target = importlib.import_module(f".{target_arch}", "src.arch")
__DEBUG__(f"Target architecture set to {target_arch}")


set_target_arch(AVAILABLE_ARCHITECTURES[0])
109 changes: 64 additions & 45 deletions src/arch/z80/optimizer/basicblock.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Final, Iterable, Iterator, List
from typing import TYPE_CHECKING, Final, Iterable, Iterator, Sequence

import src.api.config
import src.arch.z80.backend.common
from src.api import errmsg
from src.api.config import OPTIONS
from src.api.debug import __DEBUG__
from src.api.utils import sfirst
from src.api.utils import flatten_list, sfirst
from src.arch.z80.backend.common import ASMS
from src.arch.z80.peephole import evaluator

from . import helpers
from .cpustate import CPUState
from .helpers import ALL_REGS
from .helpers import (
ALL_REGS,
dict_intersection,
idx_args,
is_16bit_oper_register,
new_tmp_val,
simplify_asm_args,
single_registers,
)
from .labelinfo import LabelInfo
from .labels_dict import LabelsDict
from .memcell import MemCell
from .patterns import RE_ID_OR_NUMBER

if TYPE_CHECKING:
from .main import Optimizer

__all__ = "BasicBlock", "DummyBasicBlock"


class BasicBlock(Iterable[MemCell]):
class BasicBlock(Sequence[MemCell]):
"""A Class describing a basic block"""

__UNIQUE_ID = 0
Expand All @@ -29,24 +41,24 @@ def __new__(cls, *args, **kwargs):
cls.__UNIQUE_ID += 1
return super().__new__(cls)

def __init__(self, memory: Iterable[str]):
def __init__(self, memory: Iterable[str], optimizer: Optimizer) -> None:
"""Initializes the internal array of instructions."""
self.mem: List[MemCell] = []
self.optimizer = optimizer
self.mem: list[MemCell] = []
self.next: BasicBlock | None = None # Which (if any) basic block follows this one in memory
self.prev: BasicBlock | None = None # Which (if any) basic block precedes to this one in the code
self.lock = False # True if this block is being accessed by other subroutine
self.comes_from: set[BasicBlock] = set() # A list/tuple containing possible jumps to this block
self.goes_to: set[BasicBlock] = set() # A list/tuple of possible block to jump from here
self.modified = False # True if something has been changed during optimization
self.called_by: set[BasicBlock] = set()
self.label_goes = []
self.ignored = False # True if this block can be ignored (it's useless)
self.id: Final[int] = BasicBlock.__UNIQUE_ID
self._bytes = None
self._sizeof = None
self._max_tstates = None
self.optimized = False # True if this block was already optimized
self.code = memory
self.code = list(memory)
self.cpu = CPUState()

def __hash__(self) -> int:
Expand All @@ -61,7 +73,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return "<{}: id: {}, len: {}>".format(self.__class__.__name__, self.id, len(self))

def __getitem__(self, key) -> MemCell | list[MemCell]:
def __getitem__(self, key):
return self.mem[key]

def __setitem__(self, key, value: MemCell):
Expand All @@ -74,31 +86,36 @@ def __iter__(self) -> Iterator[MemCell]:
for mem in self.mem:
yield mem

@property
def jump_labels(self) -> set[str]:
return self.optimizer.JUMP_LABELS

@property
def opt_labels(self) -> LabelsDict:
return self.optimizer.LABELS

def pop(self, i: int) -> MemCell:
self._bytes = None
self._sizeof = None
self._max_tstates = None
return self.mem.pop(i)

def insert(self, i: int, value: str):
memcell = MemCell(value, i)
self.mem.insert(i, memcell)
self._bytes = None
self._sizeof = None
self._max_tstates = None

@property
def code(self) -> List[str]:
def code(self) -> list[str]:
return [x.code for x in self.mem]

@code.setter
def code(self, value: Iterable[str]):
self._set_code(value)

def _set_code(self, value: Iterable[str]) -> None:
assert isinstance(value, Iterable)
assert all(isinstance(x, str) for x in value)
mems = tuple(value)
assert all(isinstance(x, str) for x in mems)
if self.clean_asm_args:
self.mem = [MemCell(helpers.simplify_asm_args(asm), i) for i, asm in enumerate(value)]
self.mem = [MemCell(simplify_asm_args(asm), i) for i, asm in enumerate(mems)]
else:
self.mem = [MemCell(asm, i) for i, asm in enumerate(value)]
self.mem = [MemCell(asm, i) for i, asm in enumerate(mems)]

self._bytes = None
self._sizeof = None
Expand Down Expand Up @@ -136,15 +153,15 @@ def labels(self) -> tuple[str, ...]:
memory"""
return tuple(cell.inst for cell in self.mem if cell.is_label)

def get_first_partition_idx(self, jump_labels: set[str]) -> int | None:
def get_first_partition_idx(self) -> int | None:
"""Returns the first position where this block can be
partitioned or None if there's no such point
"""
for i, mem in enumerate(self):
if i > 0 and mem.is_label and mem.inst in jump_labels:
if i > 0 and mem.is_label and mem.inst in self.jump_labels:
return i

if (mem.is_ender or mem.code in src.arch.z80.backend.common.ASMS) and i < len(self) - 1:
if (mem.is_ender or mem.code in ASMS) and i < len(self) - 1:
return i + 1

return None
Expand Down Expand Up @@ -209,7 +226,7 @@ def add_goes_to(self, basic_block: BasicBlock | None) -> None:
self.goes_to.add(basic_block)
basic_block.comes_from.add(self)

def update_next_block(self, labels: dict[str, LabelInfo]) -> None:
def update_next_block(self) -> None:
"""If the last instruction of this block is a JP, JR or RET (with no
conditions) then goes_to set contains just a
single block
Expand All @@ -231,14 +248,16 @@ def update_next_block(self, labels: dict[str, LabelInfo]) -> None:
if last.inst == "ret":
return

if last.opers[0] not in labels.keys():
if last.opers[0] not in self.opt_labels:
__DEBUG__("INFO: %s is not defined. No optimization is done." % last.opers[0], 2)
labels[last.opers[0]] = LabelInfo(last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
self.opt_labels[last.opers[0]] = LabelInfo(
last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS, self.optimizer)
)

n_block = labels[last.opers[0]].basic_block
n_block = self.opt_labels[last.opers[0]].basic_block
self.add_goes_to(n_block)

def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:
def is_used(self, regs: Sequence[str], i: int, top: int | None = None) -> bool:
"""Checks whether any of the given regs are required from the given point
to the end or not.
"""
Expand All @@ -249,8 +268,8 @@ def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:
top = len(self) if top is None else top + 1

if regs and regs[0][0] == "(" and regs[0][-1] == ")": # A memory address
r16 = helpers.single_registers(regs[0][1:-1]) if helpers.is_16bit_oper_register(regs[0][1:-1]) else []
ix = helpers.single_registers(helpers.idx_args(regs[0][1:-1])[0]) if helpers.idx_args(regs[0][1:-1]) else []
r16 = single_registers(regs[0][1:-1]) if is_16bit_oper_register(regs[0][1:-1]) else []
ix = single_registers(idx_args(regs[0][1:-1])[0]) if idx_args(regs[0][1:-1]) else [] # type: ignore

rr = set(r16 + ix)
mem_vars = set([] if rr else RE_ID_OR_NUMBER.findall(regs[0]))
Expand All @@ -274,7 +293,7 @@ def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:

return True

regs = src.api.utils.flatten_list([helpers.single_registers(x) for x in regs]) # make a copy
regs = flatten_list([single_registers(x) for x in regs]) # make a copy
for ii in range(i, top):
if any(r in regs for r in self.mem[ii].requires):
return True
Expand Down Expand Up @@ -385,11 +404,11 @@ def guesses_initial_state_from_origin_blocks(self) -> tuple[dict[str, str], dict
return {}, {}

regs = sfirst(self.comes_from).cpu.regs
mems = sfirst(self.comes_from).cpu.mem
mems = dict(sfirst(self.comes_from).cpu.mem)

for blk in list(self.comes_from)[1:]:
regs = helpers.dict_intersection(regs, blk.cpu.regs)
mems = helpers.dict_intersection(mems, blk.cpu.mem)
regs = dict_intersection(regs, blk.cpu.regs)
mems = dict_intersection(mems, blk.cpu.mem)

return regs, mems

Expand Down Expand Up @@ -417,12 +436,12 @@ def optimize(self, patterns_list):
# monkey-patches some functions in this optimizer level (> 2)
evaluator.UNARY["GVAL"] = lambda x: self.cpu.get(x)
evaluator.UNARY["FLAGVAL"] = lambda x: {
"c": str(self.cpu.C) if self.cpu.C is not None else helpers.new_tmp_val(),
"z": str(self.cpu.Z) if self.cpu.Z is not None else helpers.new_tmp_val(),
}.get(x.lower(), helpers.new_tmp_val())
"c": str(self.cpu.C) if self.cpu.C is not None else new_tmp_val(),
"z": str(self.cpu.Z) if self.cpu.Z is not None else new_tmp_val(),
}.get(x.lower(), new_tmp_val())
evaluator.UNARY["IS_REQUIRED"] = lambda x: self.is_used([x], i + len(p.patt))

if src.api.config.OPTIONS.optimization_level > 3:
if OPTIONS.optimization_level > 3:
regs, mems = self.guesses_initial_state_from_origin_blocks()
else:
regs, mems = {}, {}
Expand All @@ -447,8 +466,8 @@ def optimize(self, patterns_list):
new_code = list(code)
matched = new_code[i : i + len(p.patt)]
new_code[i : i + len(p.patt)] = p.template.filter(match)
src.api.errmsg.info("pattern applied [{}:{}]".format("%03i" % p.flag, p.fname))
src.api.debug.__DEBUG__("matched: \n {}".format("\n ".join(matched)), level=1)
errmsg.info("pattern applied [{}:{}]".format("%03i" % p.flag, p.fname))
__DEBUG__("matched: \n {}".format("\n ".join(matched)), level=1)
changed = new_code != code
if changed:
code = new_code
Expand All @@ -470,8 +489,8 @@ class DummyBasicBlock(BasicBlock):
about what registers uses an destroys
"""

def __init__(self, destroys: Iterable[str], requires: Iterable[str]):
BasicBlock.__init__(self, [])
def __init__(self, destroys: Iterable[str], requires: Iterable[str], optimizer: Optimizer) -> None:
BasicBlock.__init__(self, [], optimizer)
self.__destroys = tuple(destroys)
self.__requires = set(requires)
self.code = ["ret"]
Expand Down

0 comments on commit 9afbb6d

Please sign in to comment.