From 4e374d01768299385b982d81ea0e07841ad439b7 Mon Sep 17 00:00:00 2001 From: Ian Fisher Date: Thu, 7 Feb 2019 16:18:05 -0500 Subject: [PATCH] Add more type annotations and fix some existing ones --- hera/checker.py | 6 +++--- hera/lexer.py | 3 ++- hera/loader.py | 7 ++----- hera/main.py | 13 ++++++++----- hera/parser.py | 26 +++++++++++++------------- hera/utils.py | 28 ++++++++++++++-------------- 6 files changed, 42 insertions(+), 41 deletions(-) diff --git a/hera/checker.py b/hera/checker.py index 13e9271..5ea8e77 100644 --- a/hera/checker.py +++ b/hera/checker.py @@ -153,7 +153,7 @@ def get_labels( return (symbol_table, messages) -def operation_length(op): +def operation_length(op: AbstractOperation) -> int: if isinstance(op, RegisterBranch): if len(op.tokens) == 1 and op.tokens[0].type == Token.SYMBOL: return 3 @@ -180,7 +180,7 @@ def operation_length(op): return 1 -def looks_like_a_CONSTANT(op): +def looks_like_a_CONSTANT(op: AbstractOperation) -> bool: return ( op.name == "CONSTANT" and len(op.args) == 2 @@ -189,7 +189,7 @@ def looks_like_a_CONSTANT(op): ) -def out_of_range(n): +def out_of_range(n: int) -> bool: return n < -32768 or n >= 65536 diff --git a/hera/lexer.py b/hera/lexer.py index 91edc3b..1a6dd7d 100644 --- a/hera/lexer.py +++ b/hera/lexer.py @@ -6,6 +6,7 @@ Version: February 2019 """ import string +from typing import Optional from hera.data import HERAError, Location, Messages, Token from hera.utils import NAMED_REGISTERS @@ -14,7 +15,7 @@ class Lexer: """A lexer for HERA (and for the debugging mini-language).""" - def __init__(self, text, *, path=None): + def __init__(self, text: str, *, path: Optional[str] = None) -> None: self.text = text self.file_lines = text.splitlines() if self.text.endswith("\n"): diff --git a/hera/loader.py b/hera/loader.py index 4098f24..414abc6 100644 --- a/hera/loader.py +++ b/hera/loader.py @@ -5,7 +5,6 @@ Version: February 2019 """ import sys -from typing import Dict, Tuple from .checker import check from .data import HERAError, Messages, Program, Settings @@ -13,7 +12,7 @@ from .utils import handle_messages, read_file -def load_program(text: str, settings=Settings()) -> Tuple[Program, Dict[str, int]]: +def load_program(text: str, settings=Settings()) -> Program: """Parse the string into a program, type-check it, and preprocess it. A tuple (ops, symbol_table) is returned. @@ -23,9 +22,7 @@ def load_program(text: str, settings=Settings()) -> Tuple[Program, Dict[str, int return handle_messages(settings, check(oplist, settings)) -def load_program_from_file( - path: str, settings=Settings() -) -> Tuple[Program, Dict[str, int]]: +def load_program_from_file(path: str, settings=Settings()) -> Program: """Convenience function to a read a file and then invoke `load_program_from_str` on its contents. """ diff --git a/hera/main.py b/hera/main.py index 9f1e9bc..6858be5 100644 --- a/hera/main.py +++ b/hera/main.py @@ -5,6 +5,7 @@ """ import sys import functools +from typing import Optional from .data import Settings, VOLUME_QUIET, VOLUME_VERBOSE from .debugger import debug @@ -13,14 +14,14 @@ from .vm import VirtualMachine -def external_main(argv=None): +def external_main(argv=None) -> None: """A wrapper around main that ignores its return value, so it is not printed to the console when the program exits. """ main(argv) -def main(argv=None): +def main(argv=None) -> Optional[VirtualMachine]: """The main entry point into hera-py.""" arguments = parse_args(argv) path = arguments[""] @@ -45,19 +46,21 @@ def main(argv=None): if arguments["preprocess"]: settings.allow_interrupts = True main_preprocess(path, settings) + return None elif arguments["debug"]: main_debug(path, settings) + return None else: return main_execute(path, settings) -def main_debug(path, settings): +def main_debug(path: str, settings: Settings) -> None: """Debug the program.""" program = load_program_from_file(path, settings) debug(program, settings) -def main_execute(path, settings): +def main_execute(path: str, settings: Settings) -> VirtualMachine: """Execute the program.""" program = load_program_from_file(path, settings) @@ -71,7 +74,7 @@ def main_execute(path, settings): return vm -def main_preprocess(path, settings): +def main_preprocess(path: str, settings: Settings) -> None: """Preprocess the program and print it to standard output.""" program = load_program_from_file(path, settings) if program.data: diff --git a/hera/parser.py b/hera/parser.py index 759bd97..5c54dec 100644 --- a/hera/parser.py +++ b/hera/parser.py @@ -13,7 +13,7 @@ Version: February 2019 """ import os.path -from typing import List, Set, Tuple # noqa: F401 +from typing import List, Optional, Set, Tuple, Union # noqa: F401 from .data import HERAError, Messages, Settings, Token from .lexer import Lexer @@ -57,7 +57,7 @@ def parse(self) -> List[AbstractOperation]: self.messages.extend(self.lexer.messages) return ops - def match_program(self): + def match_program(self) -> List[AbstractOperation]: expecting_brace = False ops = [] while self.lexer.tkn.type != Token.EOF: @@ -94,7 +94,7 @@ def match_program(self): return ops - def match_op(self, name_tkn): + def match_op(self, name_tkn: Token) -> Optional[AbstractOperation]: """Match an operation, assuming that self.lexer.tkn is on the left parenthesis. """ self.lexer.next_token() @@ -110,7 +110,7 @@ def match_op(self, name_tkn): VALUE_TOKENS = {Token.INT, Token.REGISTER, Token.SYMBOL, Token.STRING, Token.CHAR} - def match_optional_arglist(self): + def match_optional_arglist(self) -> List[Token]: if self.lexer.tkn.type == Token.RPAREN: return [] @@ -142,7 +142,7 @@ def match_optional_arglist(self): self.lexer.next_token() return args - def match_value(self): + def match_value(self) -> Token: if self.lexer.tkn.type == Token.INT: # Detect zero-prefixed octal numbers. prefix = self.lexer.tkn.value[:2] @@ -176,7 +176,7 @@ def match_value(self): else: return self.lexer.tkn - def match_include(self): + def match_include(self) -> List[AbstractOperation]: root_path = self.lexer.path tkn = self.lexer.next_token() msg = "expected quote or angle-bracket delimited string" @@ -206,7 +206,7 @@ def match_include(self): else: return self.expand_angle_include(tkn) - def handle_cpp_boilerplate(self): + def handle_cpp_boilerplate(self) -> None: self.lexer.next_token() if self.expect(Token.LPAREN, "expected left parenthesis"): self.lexer.next_token() @@ -217,7 +217,7 @@ def handle_cpp_boilerplate(self): self.expect(Token.LBRACE, "expected left curly brace") self.lexer.next_token() - def expand_angle_include(self, include_path): + def expand_angle_include(self, include_path: Token) -> List[AbstractOperation]: # There is no check for recursive includes in this function, under the # assumption that system libraries do not have recursive includes. if include_path.value == "HERA.h": @@ -241,7 +241,7 @@ def expand_angle_include(self, include_path): self.lexer = old_lexer return ops - def expect(self, types, msg="unexpected token"): + def expect(self, types: Union[str, Set[str]], msg="unexpected token") -> bool: if isinstance(types, str): types = {types} @@ -255,23 +255,23 @@ def expect(self, types, msg="unexpected token"): else: return True - def skip_until(self, types): + def skip_until(self, types: Set[str]) -> None: types.add(Token.EOF) while self.lexer.tkn.type not in types: self.lexer.next_token() - def err(self, msg, tkn=None): + def err(self, msg: str, tkn: Optional[Token] = None) -> None: if tkn is None: tkn = self.lexer.tkn self.messages.err(msg, tkn.location) - def warn(self, msg, tkn=None): + def warn(self, msg: str, tkn: Optional[Token] = None) -> None: if tkn is None: tkn = self.lexer.tkn self.messages.warn(msg, tkn.location) -def get_canonical_path(fpath): +def get_canonical_path(fpath: str) -> str: if fpath == "-" or fpath == "": return fpath else: diff --git a/hera/utils.py b/hera/utils.py index 1e178c8..051da41 100644 --- a/hera/utils.py +++ b/hera/utils.py @@ -5,10 +5,10 @@ """ import sys -from .data import HERAError, Location, Messages, Token +from .data import HERAError, Location, Messages, Settings, Token -def to_u16(n): +def to_u16(n: int) -> int: """Reinterpret the signed integer `n` as a 16-bit unsigned integer. If `n` is too large for 16 bits, a HERAError is raised. @@ -24,7 +24,7 @@ def to_u16(n): return n -def from_u16(n): +def from_u16(n: int) -> int: """Reinterpret the unsigned 16-bit integer `n` as a signed integer.""" if n >= 2 ** 15: return -(2 ** 16 - n) @@ -32,7 +32,7 @@ def from_u16(n): return n -def to_u32(n): +def to_u32(n: int) -> int: """Reinterpret the signed integer `n` as an unsigned 32-bit integer. If `n` is too large for 32 bits, a HERAError is raised. @@ -51,7 +51,7 @@ def to_u32(n): NAMED_REGISTERS = {"rt": 11, "fp": 14, "sp": 15, "pc_ret": 13, "fp_alt": 12} -def register_to_index(rname): +def register_to_index(rname: str) -> int: """Return the index of the register with the given name in the register array.""" original = rname rname = rname.lower() @@ -64,7 +64,7 @@ def register_to_index(rname): raise HERAError("{} is not a valid register".format(original)) -def format_int(v, *, spec="xdsc"): +def format_int(v: int, *, spec="xdsc") -> str: """Return a string of the form "... = ... = ..." where each ellipsis stands for a formatted integer determined by a character in the `spec` parameter. The following formats are supported: d for decimal, x for hexadecimal, o for octal, b for binary, @@ -104,7 +104,7 @@ def format_int(v, *, spec="xdsc"): return " = ".join(ret) -def print_warning(settings, msg, *, loc=None): +def print_warning(settings: Settings, msg: str, *, loc=None) -> None: if settings.color: msg = ANSI_MAGENTA_BOLD + "Warning" + ANSI_RESET + ": " + msg else: @@ -112,7 +112,7 @@ def print_warning(settings, msg, *, loc=None): print_message(msg, loc=loc) -def print_error(settings, msg, *, loc=None): +def print_error(settings: Settings, msg: str, *, loc=None) -> None: if settings.color: msg = ANSI_RED_BOLD + "Error" + ANSI_RESET + ": " + msg else: @@ -120,7 +120,7 @@ def print_error(settings, msg, *, loc=None): print_message(msg, loc=loc) -def print_message(msg, *, loc=None): +def print_message(msg: str, *, loc=None) -> None: """Print a message to stderr. If `loc` is provided as either a Location object, or a Token object with a `location` field, then the line of code that the location indicates will be printed with the message. @@ -141,14 +141,14 @@ def print_message(msg, *, loc=None): sys.stderr.write(msg + "\n") -def align_caret(line, col): +def align_caret(line: str, col: int) -> str: """Return the whitespace necessary to align a caret to underline the desired column in the line of text. Mainly this means handling tabs. """ return "".join("\t" if c == "\t" else " " for c in line[: col - 1]) -def read_file(path) -> str: +def read_file(path: str) -> str: """Read a file and return its contents.""" try: with open(path, encoding="ascii") as f: @@ -163,11 +163,11 @@ def read_file(path) -> str: raise HERAError("non-ASCII byte in file") -def pad(s, n): +def pad(s: str, n: int) -> str: return (" " * (n - len(s))) + s -def handle_messages(settings, ret_messages_pair): +def handle_messages(settings: Settings, ret_messages_pair): if ( isinstance(ret_messages_pair, tuple) and len(ret_messages_pair) == 2 @@ -198,7 +198,7 @@ def handle_messages(settings, ret_messages_pair): # you can use them unconditionally in your code without worrying about --no-color. -def make_ansi(*params): +def make_ansi(*params) -> str: return "\033[" + ";".join(map(str, params)) + "m"