diff --git a/hera/data.py b/hera/data.py
index 587fc37..d53e230 100644
--- a/hera/data.py
+++ b/hera/data.py
@@ -80,6 +80,7 @@ class Token:
     FMT = "TOKEN_FMT"
     INCLUDE = "TOKEN_INCLUDE"
     EOF = "TOKEN_EOF"
+    ERROR = "TOKEN_ERROR"
     UNKNOWN = "TOKEN_UNKNOWN"
 
     def __init__(self, type_, value, location=None):
diff --git a/hera/lexer.py b/hera/lexer.py
index 1a6dd7d..c7646e5 100644
--- a/hera/lexer.py
+++ b/hera/lexer.py
@@ -8,7 +8,7 @@
 import string
 from typing import Optional
 
-from hera.data import HERAError, Location, Messages, Token
+from hera.data import Location, Messages, Token
 from hera.utils import NAMED_REGISTERS
 
 
@@ -28,10 +28,10 @@ def __init__(self, text: str, *, path: Optional[str] = None) -> None:
         # Set the current token.
         self.next_token()
 
-    def get_location(self):
+    def get_location(self) -> Location:
         return Location(self.line, self.column, self.path, self.file_lines)
 
-    def next_token(self):
+    def next_token(self) -> Token:
         self.skip()
 
         if self.position >= len(self.text):
@@ -48,47 +48,19 @@ def next_token(self):
             length = self.read_int()
             self.set_token(Token.INT, length=length)
         elif ch == '"':
-            loc = self.get_location()
-            s = self.consume_str()
-            self.tkn = Token(Token.STRING, s, loc)
+            self.consume_str()
         elif ch == "'":
-            if self.peek_char() == "\\":
-                if self.peek_char(3) == "'":
-                    ch = self.peek_char(2)
-                    escape = escape_char(ch)
-                    self.next_char()  # open quote
-                    self.next_char()  # backslash
-                    loc = self.get_location()
-                    self.next_char()  # character
-                    self.next_char()  # end quote
-                    if len(escape) == 2:
-                        self.tkn = Token(Token.CHAR, escape[1], loc)
-                        self.warn("unrecognized backslash escape", loc)
-                    else:
-                        self.tkn = Token(Token.CHAR, escape, loc)
-            else:
-                if self.peek_char(2) == "'":
-                    ch = self.peek_char()
-                    self.next_char()  # open quote
-                    loc = self.get_location()
-                    self.next_char()  # character
-                    self.next_char()  # end quote
-                    self.tkn = Token(Token.CHAR, ch, loc)
-                else:
-                    self.set_token(Token.UNKNOWN)
+            self.consume_char()
         elif self.text[self.position :].startswith("#include"):
             self.set_token(Token.INCLUDE, length=len("#include"))
         elif ch == "<":
-            self.next_char()
-            length = self.read_bracketed()
-            self.set_token(Token.BRACKETED, length=length)
-            if self.position < len(self.text):
-                self.next_char()
+            self.consume_bracketed()
         elif ch == ":":
            self.position += 1
            length = self.read_symbol()
            self.set_token(Token.FMT, length=length)
         elif ch == "-":
+            # TODO: This doesn't handle e.g. x-10.
             if self.peek_char().isdigit():
                 self.position += 1
                 length = self.read_int()
@@ -121,7 +93,7 @@ def next_token(self):
 
         return self.tkn
 
-    def read_int(self):
+    def read_int(self) -> int:
         length = 1
         digits = {str(i) for i in range(10)}
         peek = self.peek_char()
@@ -135,7 +107,7 @@ def read_int(self):
 
         return length
 
-    def read_symbol(self):
+    def read_symbol(self) -> int:
         length = 1
         while True:
             ch = self.peek_char(length)
@@ -144,22 +116,29 @@ def read_symbol(self):
             length += 1
         return length
 
-    def read_bracketed(self):
-        length = 1
-        while self.position + length < len(self.text) and self.peek_char(length) != ">":
-            length += 1
-        if self.position + length == len(self.text):
-            raise HERAError("unclosed bracketed expression", self.get_location())
-        return length
+    def consume_bracketed(self) -> None:
+        self.next_char()
+        loc = self.get_location()
+        start = self.position
+        while self.position < len(self.text) and self.text[self.position] != ">":
+            self.next_char()
 
-    def consume_str(self):
+        if self.position == len(self.text):
+            self.tkn = Token(Token.ERROR, "unclosed bracketed expression", loc)
+            return
+
+        self.tkn = Token(Token.BRACKETED, self.text[start : self.position], loc)
+        self.next_char()
+
+    def consume_str(self) -> None:
         sbuilder = []
         loc = self.get_location()
         self.next_char()
         while self.position < len(self.text) and self.text[self.position] != '"':
             if self.text[self.position] == "\\":
                 if self.position == len(self.text) - 1:
-                    raise HERAError("unclosed string literal", loc)
+                    self.next_char()
+                    break
                 escape = escape_char(self.text[self.position + 1])
                 sbuilder.append(escape)
                 self.next_char()
@@ -170,13 +149,47 @@ def consume_str(self):
             else:
                 sbuilder.append(self.text[self.position])
                 self.next_char()
-        if self.position < len(self.text):
+
+        if self.position == len(self.text):
+            self.tkn = Token(Token.ERROR, "unclosed string literal", loc)
+            return
+
+        self.next_char()
+        s = "".join(sbuilder)
+        self.tkn = Token(Token.STRING, s, loc)
+
+    def consume_char(self) -> None:
+        loc = self.get_location()
+        self.next_char()
+        start = self.position
+        while self.position < len(self.text) and self.text[self.position] != "'":
+            if self.text[self.position] == "\\":
+                self.next_char()
             self.next_char()
+
+        if self.position == len(self.text):
+            self.tkn = Token(Token.ERROR, "unclosed character literal", loc)
+            return
+
+        contents = self.text[start : self.position]
+
+        if len(contents) == 1:
+            loc = loc._replace(column=loc.column + 1)
+            self.tkn = Token(Token.CHAR, contents, loc)
+        elif len(contents) == 2 and contents[0] == "\\":
+            loc = loc._replace(column=loc.column + 2)
+            escape = escape_char(contents[1])
+            if len(escape) == 2:
+                self.tkn = Token(Token.CHAR, escape[1], loc)
+                self.warn("unrecognized backslash escape", loc)
+            else:
+                self.tkn = Token(Token.CHAR, escape, loc)
         else:
-            raise HERAError("unclosed string literal", loc)
-        return "".join(sbuilder)
+            self.tkn = Token(Token.ERROR, "over-long character literal", loc)
+
+        self.next_char()
 
-    def skip(self):
+    def skip(self) -> None:
         """Skip past whitespace and comments."""
         while True:
             # Skip whitespace.
@@ -206,7 +219,7 @@ def skip(self):
             else:
                 break
 
-    def next_char(self):
+    def next_char(self) -> None:
         if self.text[self.position] == "\n":
             self.line += 1
             self.column = 1
@@ -214,22 +227,22 @@ def next_char(self):
             self.column += 1
         self.position += 1
 
-    def peek_char(self, n=1):
+    def peek_char(self, n=1) -> str:
         return (
             self.text[self.position + n] if self.position + n < len(self.text) else ""
         )
 
-    def set_token(self, typ, *, length=1):
+    def set_token(self, typ: str, *, length=1) -> None:
         loc = self.get_location()
         value = self.text[self.position : self.position + length]
         for _ in range(length):
             self.next_char()
         self.tkn = Token(typ, value, loc)
 
-    def err(self, msg, loc):
+    def err(self, msg: str, loc) -> None:
         self.messages.err(msg, loc)
 
-    def warn(self, msg, loc):
+    def warn(self, msg: str, loc) -> None:
         self.messages.warn(msg, loc)
 
 
diff --git a/hera/parser.py b/hera/parser.py
index 5c54dec..526f82b 100644
--- a/hera/parser.py
+++ b/hera/parser.py
@@ -248,6 +248,8 @@ def expect(self, types: Union[str, Set[str]], msg="unexpected token") -> bool:
         if self.lexer.tkn.type not in types:
             if self.lexer.tkn.type == Token.EOF:
                 self.err("premature end of input")
+            elif self.lexer.tkn.type == Token.ERROR:
+                self.err(self.lexer.tkn.value)
             else:
                 self.err(msg)
 
diff --git a/test/test_parse_error.py b/test/test_parse_error.py
index 7ad0b1c..5312d1c 100644
--- a/test/test_parse_error.py
+++ b/test/test_parse_error.py
@@ -287,3 +287,21 @@ def test_parse_error_for_non_ASCII_byte_in_file(capsys):
 
 """
     )
+
+
+def test_parse_error_for_over_long_character_literal(capsys):
+    with pytest.raises(SystemExit):
+        execute_program_helper("SET(R1, 'abc')")
+
+    captured = capsys.readouterr().err
+    assert (
+        captured
+        == """\
+
+Error: over-long character literal, line 1 col 9 of <stdin>
+
+  SET(R1, 'abc')
+          ^
+
+"""
+    )
diff --git a/test/test_unit/test_lexer.py b/test/test_unit/test_lexer.py
index f784e76..fcf72d2 100644
--- a/test/test_unit/test_lexer.py
+++ b/test/test_unit/test_lexer.py
@@ -69,9 +69,10 @@ def test_lexer_with_character_literal_backslash_escape():
 
 
 def test_lexer_with_over_long_character_literal():
-    lexer = lex_helper("'abc'")
+    lexer = lex_helper("'abc' 10")
 
-    assert eq(lexer.tkn, Token(Token.UNKNOWN, "'"))
+    assert eq(lexer.tkn, Token(Token.ERROR, "over-long character literal"))
+    assert eq(lexer.next_token(), Token(Token.INT, "10"))
 
 
 def test_lexer_with_string():
@@ -92,6 +93,24 @@ def test_lexer_with_empty_string():
     assert eq(lexer.next_token(), Token(Token.EOF, ""))
 
 
+def test_lexer_with_unclosed_string_literal():
+    lexer = lex_helper('"hello')
+
+    assert eq(lexer.tkn, Token(Token.ERROR, "unclosed string literal"))
+
+
+def test_lexer_with_unclosed_character_literal():
+    lexer = lex_helper("'a")
+
+    assert eq(lexer.tkn, Token(Token.ERROR, "unclosed character literal"))
+
+
+def test_lexer_with_unclosed_bracketed_expression():
+    lexer = lex_helper("<abc")
+
+    assert eq(lexer.tkn, Token(Token.ERROR, "unclosed bracketed expression"))
+
+
 def test_lexer_with_include():
     lexer = lex_helper('#include <lib> #include "lib.hera"')
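
Illustration only, not part of the patch: a minimal sketch of how the new ERROR tokens surface, assuming the lexer class defined in hera/lexer.py is named Lexer and is importable as hera.lexer.Lexer; the constructor signature and the tkn/next_token() attributes are taken from the hunks above.

    from hera.data import Token
    from hera.lexer import Lexer  # class name assumed; the diff only shows its methods

    # A malformed literal no longer raises HERAError out of the lexer; it becomes
    # an ERROR token whose value is the message, which the parser's expect() reports.
    lexer = Lexer('"hello')  # unclosed string literal
    assert lexer.tkn.type == Token.ERROR
    assert lexer.tkn.value == "unclosed string literal"

    # Well-formed input is unaffected: a character literal, then an integer.
    lexer = Lexer("'a' 10")
    assert lexer.tkn.type == Token.CHAR
    assert lexer.next_token().type == Token.INT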