diff --git a/hera/parser.py b/hera/parser.py index f316567..54c7b9c 100644 --- a/hera/parser.py +++ b/hera/parser.py @@ -140,57 +140,45 @@ def value(self, matches): ) -def parse(text, *, expand_includes=False): - """Parse a HERA program into a list of Op objects.""" +def parse(text, *, fpath=None, expand_includes=True, visited=None): + """Parse a HERA program from a string into a list of Op objects. + + `fpath` is the path of the file being parsed, as it will appear in error and + debugging messages. It defaults to "". + + `expand_includes` determines whether an #include statement should be executed during + parsing or not. + + `visited` is a set of file paths that have already been visited. If any #include + statement matches a path in this set, an error is raised. + """ + if visited is None: + visited = set() + + if fpath is not None: + visited.add(get_canonical_path(fpath)) + + linevector = text.splitlines() + loc = Location(fpath or "", linevector) + try: tree = _parser.parse(text) except UnexpectedCharacters as e: - raise HERAError("unexpected character", e.line, e.column) from None + raise HERAError("unexpected character", e.line, e.column, loc) from None except UnexpectedToken as e: if e.token.type == "$END": raise HERAError("unexpected end of file") from None else: - raise HERAError("unexpected character", e.line, e.column) from None + raise HERAError("unexpected character", e.line, e.column, loc) from None except LarkError as e: - raise HERAError("invalid syntax", e.line, e.column) from None + raise HERAError("invalid syntax", e.line, e.column, loc) from None if isinstance(tree, Tree): - return tree.children + ops = tree.children elif isinstance(tree, Op): - return [tree] + ops = [tree] else: - return tree - - -def parse_file(fpath, *, expand_includes=True, allow_stdin=False, visited=None): - """Parse a file containing a HERA program into a list of Op objects.""" - if visited is None: - visited = set() - - visited.add(get_canonical_path(fpath)) - - if allow_stdin and fpath == "-": - program = sys.stdin.read() - else: - try: - with open(fpath) as f: - program = f.read() - except FileNotFoundError: - raise HERAError('file "{}" does not exist.'.format(fpath)) - except PermissionError: - raise HERAError('permission denied to open file "{}".'.format(fpath)) - except OSError: - raise HERAError('could not open file "{}".'.format(fpath)) - - canonical_path = get_canonical_path(fpath) - linevector = program.splitlines() - loc = Location(fpath, linevector) - - try: - ops = parse(program) - except HERAError as e: - e.location = loc - raise e + ops = tree ops = [op._replace(location=loc) for op in ops] @@ -217,6 +205,30 @@ def parse_file(fpath, *, expand_includes=True, allow_stdin=False, visited=None): return ops +def parse_file(fpath, *, expand_includes=True, allow_stdin=False, visited=None): + """Convenience function for parsing a HERA file. Reads the contents of the file and + delegates parsing to the `parse` function. + + `allow_stdin` should be set to True if you wish the file path "-" to be interpreted + as standard input instead of a file with that actual name. See `parse` for the + meaning of `expand_includes` and `visited`. + """ + if allow_stdin and fpath == "-": + program = sys.stdin.read() + else: + try: + with open(fpath) as f: + program = f.read() + except FileNotFoundError: + raise HERAError('file "{}" does not exist.'.format(fpath)) + except PermissionError: + raise HERAError('permission denied to open file "{}".'.format(fpath)) + except OSError: + raise HERAError('could not open file "{}".'.format(fpath)) + + return parse(program, fpath=fpath, expand_includes=expand_includes, visited=visited) + + def replace_escapes(s): return re.sub(r"\\.", repl, s) diff --git a/hera/utils.py b/hera/utils.py index bb6f3be..af83f95 100644 --- a/hera/utils.py +++ b/hera/utils.py @@ -163,7 +163,7 @@ def _align_caret(line, col): def get_canonical_path(fpath): - if fpath == "-": + if fpath == "-" or fpath == "": return fpath else: return os.path.realpath(fpath) diff --git a/test/test_debugger.py b/test/test_debugger.py index dfbbf9d..2e82439 100644 --- a/test/test_debugger.py +++ b/test/test_debugger.py @@ -49,7 +49,7 @@ def test_set_breakpoint(debugger): assert should_continue assert len(debugger.breakpoints) == 1 assert 0 in debugger.breakpoints - assert debugger.breakpoints[0] == "2" + assert debugger.breakpoints[0] == ":2" def test_set_breakpoint_not_on_line_of_code(debugger, capsys): @@ -118,7 +118,10 @@ def test_execute_next_after_end_of_program(debugger, capsys): assert should_continue assert debugger.vm.pc == 9000 - assert capsys.readouterr().out == "Program has finished executing. Press 'r' to restart.\n" + assert ( + capsys.readouterr().out + == "Program has finished executing. Press 'r' to restart.\n" + ) def test_execute_next_with_too_many_args(debugger, capsys): @@ -265,11 +268,11 @@ def test_resolve_location_invalid_format(debugger): def test_get_breakpoint_name(debugger): # Zero'th instruction corresponds to second line. - assert debugger.get_breakpoint_name(0) == "2" + assert debugger.get_breakpoint_name(0) == ":2" def test_print_current_op(debugger, capsys): debugger.print_current_op() captured = capsys.readouterr() - assert captured.out == "0000 SET(R1, 10)\n" + assert captured.out == "[, line 2]\n\n0000 SET(R1, 10)\n" diff --git a/test/test_parser.py b/test/test_parser.py index 160c1dc..8b429ff 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -137,7 +137,10 @@ def test_parse_multiline_comment(): def test_parse_include_amidst_instructions(): program = 'SETLO(R1, 42)\n#include "whatever"\n' - assert parse(program) == [Op("SETLO", ["R1", 42]), Op("#include", ['"whatever"'])] + assert parse(program, expand_includes=False) == [ + Op("SETLO", ["R1", 42]), + Op("#include", ['"whatever"']), + ] def test_parse_missing_comma(): diff --git a/test/test_preprocessor.py b/test/test_preprocessor.py index 0a31e63..6fd13ae 100644 --- a/test/test_preprocessor.py +++ b/test/test_preprocessor.py @@ -5,7 +5,7 @@ from hera.parser import Op from hera.preprocessor import convert, convert_set, preprocess, substitute_label -from hera.utils import HERAError, IntToken +from hera.utils import IntToken def R(s): diff --git a/test/test_typechecker.py b/test/test_typechecker.py index 953b4dc..c024aac 100644 --- a/test/test_typechecker.py +++ b/test/test_typechecker.py @@ -16,7 +16,7 @@ U4, U16, ) -from hera.utils import HERAError, IntToken +from hera.utils import IntToken def R(s):