diff --git a/api/optimize.py b/api/optimize.py index e361fa9a0..c1f76cefa 100644 --- a/api/optimize.py +++ b/api/optimize.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from typing import NamedTuple from ast_ import NodeVisitor from .config import OPTIONS import api.errmsg @@ -24,28 +25,8 @@ def __init__(self, obj): self.obj = obj -class OptimizerVisitor(NodeVisitor): - """ Implements some optimizations - """ - NOP = symbols.NOP() # Return this for "erased" nodes - - @staticmethod - def TYPE(type_): - """ Converts a backend type (from api.constants) - to a SymbolTYPE object (taken from the SYMBOL_TABLE). - If type_ is already a SymbolTYPE object, nothing - is done. - """ - if isinstance(type_, symbols.TYPE): - return type_ - - assert TYPE.is_valid(type_) - return gl.SYMBOL_TABLE.basic_types[type_] - +class GenericVisitor(NodeVisitor): def visit(self, node): - if self.O_LEVEL < 1: # Optimize only if O1 or above - return node - stack = [ToVisit(node)] last_result = None @@ -76,6 +57,31 @@ def _visit(self, node): return meth(node.obj) + +class OptimizerVisitor(GenericVisitor): + """ Implements some optimizations + """ + NOP = symbols.NOP() # Return this for "erased" nodes + + @staticmethod + def TYPE(type_): + """ Converts a backend type (from api.constants) + to a SymbolTYPE object (taken from the SYMBOL_TABLE). + If type_ is already a SymbolTYPE object, nothing + is done. + """ + if isinstance(type_, symbols.TYPE): + return type_ + + assert TYPE.is_valid(type_) + return gl.SYMBOL_TABLE.basic_types[type_] + + def visit(self, node): + if self.O_LEVEL < 1: # Optimize only if O1 or above + return node + + return super().visit(node) + @property def O_LEVEL(self): return OPTIONS.optimization.value @@ -196,7 +202,7 @@ def visit_IF(self, node): if not block_accessed and chk.is_number(expr_): # constant condition if expr_.value: # always true (then_) yield then_ - else: # always false (else_) + else: # always false (else_) yield else_ return @@ -282,3 +288,70 @@ def _update_bound_status(self, arg: symbols.VARARRAY): if arg.scope == SCOPE.local and not arg.byref: arg.scopeRef.owner.locals_size = api.symboltable.SymbolTable.compute_offsets(arg.scopeRef) + + +class VarDependency(NamedTuple): + parent: symbols.VAR + dependency: symbols.VAR + + +class VariableVisitor(GenericVisitor): + _original_variable = None + _parent_variable = None + _visited = set() + + @staticmethod + def generic_visit(node): + if node not in VariableVisitor._visited: + VariableVisitor._visited.add(node) + for i in range(len(node.children)): + node.children[i] = yield ToVisit(node.children[i]) + yield node + + def has_circular_dependency(self, var_dependency: VarDependency) -> bool: + if var_dependency.dependency == VariableVisitor._original_variable: + api.errmsg.error(VariableVisitor._original_variable.lineno, + "Circular dependency between '{}' and '{}'".format( + VariableVisitor._original_variable.name, var_dependency.parent)) + return True + + return False + + def get_var_dependencies(self, var_entry: symbols.VAR): + visited = set() + result = set() + + def visit_var(entry): + if entry in visited: + return + + visited.add(entry) + if not isinstance(entry, symbols.VAR): + for child in entry.children: + visit_var(child) + if isinstance(child, symbols.VAR): + result.add(VarDependency(parent=VariableVisitor._parent_variable, dependency=child)) + return + + VariableVisitor._parent_variable = entry + if entry.alias is not None: + result.add(VarDependency(parent=entry, dependency=entry.alias)) + visit_var(entry.alias) + elif entry.addr is not None: + visit_var(entry.addr) + + visit_var(var_entry) + return result + + def visit_VARDECL(self, node: symbols.VARDECL): + """ Checks for cyclic dependencies in aliasing variables + """ + VariableVisitor._visited = set() + VariableVisitor._original_variable = node.entry + for dependency in self.get_var_dependencies(node.entry): + if self.has_circular_dependency(dependency): + break + + VariableVisitor._visited = set() + VariableVisitor._original_variable = VariableVisitor._parent_variable = None + yield node diff --git a/libzxbc/zxb.py b/libzxbc/zxb.py index f3aa528e3..39bff4556 100755 --- a/libzxbc/zxb.py +++ b/libzxbc/zxb.py @@ -331,6 +331,8 @@ def main(args=None, emitter=None): backend.MEMORY[:] = [] # This will fill MEMORY with global declared variables + var_checker = api.optimize.VariableVisitor() + var_checker.visit(zxbparser.data_ast) translator = arch.zx48k.VarTranslator() translator.visit(zxbparser.data_ast) if gl.has_errors: diff --git a/libzxbc/zxbparser.py b/libzxbc/zxbparser.py index 3696dca36..bb7ea3b2f 100755 --- a/libzxbc/zxbparser.py +++ b/libzxbc/zxbparser.py @@ -669,7 +669,7 @@ def p_var_decl_at(p): if entry is None: return - if p[5].token == 'CONST': + if p[5].token in 'CONST': tmp = p[5].expr if tmp.token == 'UNARY' and tmp.operator == 'ADDRESS': # Must be an ID if tmp.operand.token in ('VAR', 'LABEL'): @@ -691,7 +691,7 @@ def p_var_decl_at(p): api.errmsg.syntax_error_address_must_be_constant(p.lineno(4)) return else: - entry.addr = str(make_typecast(_TYPE(gl.PTR_TYPE), p[5], p.lineno(4)).value) + entry.addr = make_typecast(_TYPE(gl.PTR_TYPE), p[5], p.lineno(4)) entry.accessed = True if entry.scope == SCOPE.local: SYMBOL_TABLE.make_static(entry.name) diff --git a/tests/functional/dim_at_label4.bas b/tests/functional/dim_at_label4.bas new file mode 100644 index 000000000..e80a5e22d --- /dev/null +++ b/tests/functional/dim_at_label4.bas @@ -0,0 +1,7 @@ +REM Error: circular dependency + +DIM a at @b +DIM b at @c +DIM c at @a + + diff --git a/tests/functional/dim_at_label5.bas b/tests/functional/dim_at_label5.bas new file mode 100644 index 000000000..0c75e71a8 --- /dev/null +++ b/tests/functional/dim_at_label5.bas @@ -0,0 +1,4 @@ +REM Error: circular dependency + +DIM x at @x + diff --git a/tests/functional/dim_at_label6.bas b/tests/functional/dim_at_label6.bas new file mode 100644 index 000000000..bfb859298 --- /dev/null +++ b/tests/functional/dim_at_label6.bas @@ -0,0 +1,4 @@ + +DIM a at @b + 1 +DIM b at @c +DIM c at @a diff --git a/tests/functional/dim_at_label7.bas b/tests/functional/dim_at_label7.bas new file mode 100644 index 000000000..bfb859298 --- /dev/null +++ b/tests/functional/dim_at_label7.bas @@ -0,0 +1,4 @@ + +DIM a at @b + 1 +DIM b at @c +DIM c at @a