Skip to content

Commit

Permalink
Add more type annotations and fix some existing ones
Browse files Browse the repository at this point in the history
  • Loading branch information
iafisher committed Feb 7, 2019
1 parent d79031d commit 4e374d0
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 41 deletions.
6 changes: 3 additions & 3 deletions hera/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def get_labels(
return (symbol_table, messages)


def operation_length(op):
def operation_length(op: AbstractOperation) -> int:
if isinstance(op, RegisterBranch):
if len(op.tokens) == 1 and op.tokens[0].type == Token.SYMBOL:
return 3
Expand All @@ -180,7 +180,7 @@ def operation_length(op):
return 1


def looks_like_a_CONSTANT(op):
def looks_like_a_CONSTANT(op: AbstractOperation) -> bool:
return (
op.name == "CONSTANT"
and len(op.args) == 2
Expand All @@ -189,7 +189,7 @@ def looks_like_a_CONSTANT(op):
)


def out_of_range(n):
def out_of_range(n: int) -> bool:
return n < -32768 or n >= 65536


Expand Down
3 changes: 2 additions & 1 deletion hera/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Version: February 2019
"""
import string
from typing import Optional

from hera.data import HERAError, Location, Messages, Token
from hera.utils import NAMED_REGISTERS
Expand All @@ -14,7 +15,7 @@
class Lexer:
"""A lexer for HERA (and for the debugging mini-language)."""

def __init__(self, text, *, path=None):
def __init__(self, text: str, *, path: Optional[str] = None) -> None:
self.text = text
self.file_lines = text.splitlines()
if self.text.endswith("\n"):
Expand Down
7 changes: 2 additions & 5 deletions hera/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
Version: February 2019
"""
import sys
from typing import Dict, Tuple

from .checker import check
from .data import HERAError, Messages, Program, Settings
from .parser import parse
from .utils import handle_messages, read_file


def load_program(text: str, settings=Settings()) -> Tuple[Program, Dict[str, int]]:
def load_program(text: str, settings=Settings()) -> Program:
"""Parse the string into a program, type-check it, and preprocess it. A tuple
(ops, symbol_table) is returned.
Expand All @@ -23,9 +22,7 @@ def load_program(text: str, settings=Settings()) -> Tuple[Program, Dict[str, int
return handle_messages(settings, check(oplist, settings))


def load_program_from_file(
path: str, settings=Settings()
) -> Tuple[Program, Dict[str, int]]:
def load_program_from_file(path: str, settings=Settings()) -> Program:
"""Convenience function to a read a file and then invoke `load_program_from_str` on
its contents.
"""
Expand Down
13 changes: 8 additions & 5 deletions hera/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import sys
import functools
from typing import Optional

from .data import Settings, VOLUME_QUIET, VOLUME_VERBOSE
from .debugger import debug
Expand All @@ -13,14 +14,14 @@
from .vm import VirtualMachine


def external_main(argv=None):
def external_main(argv=None) -> None:
"""A wrapper around main that ignores its return value, so it is not printed to the
console when the program exits.
"""
main(argv)


def main(argv=None):
def main(argv=None) -> Optional[VirtualMachine]:
"""The main entry point into hera-py."""
arguments = parse_args(argv)
path = arguments["<path>"]
Expand All @@ -45,19 +46,21 @@ def main(argv=None):
if arguments["preprocess"]:
settings.allow_interrupts = True
main_preprocess(path, settings)
return None
elif arguments["debug"]:
main_debug(path, settings)
return None
else:
return main_execute(path, settings)


def main_debug(path, settings):
def main_debug(path: str, settings: Settings) -> None:
"""Debug the program."""
program = load_program_from_file(path, settings)
debug(program, settings)


def main_execute(path, settings):
def main_execute(path: str, settings: Settings) -> VirtualMachine:
"""Execute the program."""
program = load_program_from_file(path, settings)

Expand All @@ -71,7 +74,7 @@ def main_execute(path, settings):
return vm


def main_preprocess(path, settings):
def main_preprocess(path: str, settings: Settings) -> None:
"""Preprocess the program and print it to standard output."""
program = load_program_from_file(path, settings)
if program.data:
Expand Down
26 changes: 13 additions & 13 deletions hera/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Version: February 2019
"""
import os.path
from typing import List, Set, Tuple # noqa: F401
from typing import List, Optional, Set, Tuple, Union # noqa: F401

from .data import HERAError, Messages, Settings, Token
from .lexer import Lexer
Expand Down Expand Up @@ -57,7 +57,7 @@ def parse(self) -> List[AbstractOperation]:
self.messages.extend(self.lexer.messages)
return ops

def match_program(self):
def match_program(self) -> List[AbstractOperation]:
expecting_brace = False
ops = []
while self.lexer.tkn.type != Token.EOF:
Expand Down Expand Up @@ -94,7 +94,7 @@ def match_program(self):

return ops

def match_op(self, name_tkn):
def match_op(self, name_tkn: Token) -> Optional[AbstractOperation]:
"""Match an operation, assuming that self.lexer.tkn is on the left parenthesis.
"""
self.lexer.next_token()
Expand All @@ -110,7 +110,7 @@ def match_op(self, name_tkn):

VALUE_TOKENS = {Token.INT, Token.REGISTER, Token.SYMBOL, Token.STRING, Token.CHAR}

def match_optional_arglist(self):
def match_optional_arglist(self) -> List[Token]:
if self.lexer.tkn.type == Token.RPAREN:
return []

Expand Down Expand Up @@ -142,7 +142,7 @@ def match_optional_arglist(self):
self.lexer.next_token()
return args

def match_value(self):
def match_value(self) -> Token:
if self.lexer.tkn.type == Token.INT:
# Detect zero-prefixed octal numbers.
prefix = self.lexer.tkn.value[:2]
Expand Down Expand Up @@ -176,7 +176,7 @@ def match_value(self):
else:
return self.lexer.tkn

def match_include(self):
def match_include(self) -> List[AbstractOperation]:
root_path = self.lexer.path
tkn = self.lexer.next_token()
msg = "expected quote or angle-bracket delimited string"
Expand Down Expand Up @@ -206,7 +206,7 @@ def match_include(self):
else:
return self.expand_angle_include(tkn)

def handle_cpp_boilerplate(self):
def handle_cpp_boilerplate(self) -> None:
self.lexer.next_token()
if self.expect(Token.LPAREN, "expected left parenthesis"):
self.lexer.next_token()
Expand All @@ -217,7 +217,7 @@ def handle_cpp_boilerplate(self):
self.expect(Token.LBRACE, "expected left curly brace")
self.lexer.next_token()

def expand_angle_include(self, include_path):
def expand_angle_include(self, include_path: Token) -> List[AbstractOperation]:
# There is no check for recursive includes in this function, under the
# assumption that system libraries do not have recursive includes.
if include_path.value == "HERA.h":
Expand All @@ -241,7 +241,7 @@ def expand_angle_include(self, include_path):
self.lexer = old_lexer
return ops

def expect(self, types, msg="unexpected token"):
def expect(self, types: Union[str, Set[str]], msg="unexpected token") -> bool:
if isinstance(types, str):
types = {types}

Expand All @@ -255,23 +255,23 @@ def expect(self, types, msg="unexpected token"):
else:
return True

def skip_until(self, types):
def skip_until(self, types: Set[str]) -> None:
types.add(Token.EOF)
while self.lexer.tkn.type not in types:
self.lexer.next_token()

def err(self, msg, tkn=None):
def err(self, msg: str, tkn: Optional[Token] = None) -> None:
if tkn is None:
tkn = self.lexer.tkn
self.messages.err(msg, tkn.location)

def warn(self, msg, tkn=None):
def warn(self, msg: str, tkn: Optional[Token] = None) -> None:
if tkn is None:
tkn = self.lexer.tkn
self.messages.warn(msg, tkn.location)


def get_canonical_path(fpath):
def get_canonical_path(fpath: str) -> str:
if fpath == "-" or fpath == "<string>":
return fpath
else:
Expand Down
28 changes: 14 additions & 14 deletions hera/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
"""
import sys

from .data import HERAError, Location, Messages, Token
from .data import HERAError, Location, Messages, Settings, Token


def to_u16(n):
def to_u16(n: int) -> int:
"""Reinterpret the signed integer `n` as a 16-bit unsigned integer.
If `n` is too large for 16 bits, a HERAError is raised.
Expand All @@ -24,15 +24,15 @@ def to_u16(n):
return n


def from_u16(n):
def from_u16(n: int) -> int:
"""Reinterpret the unsigned 16-bit integer `n` as a signed integer."""
if n >= 2 ** 15:
return -(2 ** 16 - n)
else:
return n


def to_u32(n):
def to_u32(n: int) -> int:
"""Reinterpret the signed integer `n` as an unsigned 32-bit integer.
If `n` is too large for 32 bits, a HERAError is raised.
Expand All @@ -51,7 +51,7 @@ def to_u32(n):
NAMED_REGISTERS = {"rt": 11, "fp": 14, "sp": 15, "pc_ret": 13, "fp_alt": 12}


def register_to_index(rname):
def register_to_index(rname: str) -> int:
"""Return the index of the register with the given name in the register array."""
original = rname
rname = rname.lower()
Expand All @@ -64,7 +64,7 @@ def register_to_index(rname):
raise HERAError("{} is not a valid register".format(original))


def format_int(v, *, spec="xdsc"):
def format_int(v: int, *, spec="xdsc") -> str:
"""Return a string of the form "... = ... = ..." where each ellipsis stands for a
formatted integer determined by a character in the `spec` parameter. The following
formats are supported: d for decimal, x for hexadecimal, o for octal, b for binary,
Expand Down Expand Up @@ -104,23 +104,23 @@ def format_int(v, *, spec="xdsc"):
return " = ".join(ret)


def print_warning(settings, msg, *, loc=None):
def print_warning(settings: Settings, msg: str, *, loc=None) -> None:
if settings.color:
msg = ANSI_MAGENTA_BOLD + "Warning" + ANSI_RESET + ": " + msg
else:
msg = "Warning: " + msg
print_message(msg, loc=loc)


def print_error(settings, msg, *, loc=None):
def print_error(settings: Settings, msg: str, *, loc=None) -> None:
if settings.color:
msg = ANSI_RED_BOLD + "Error" + ANSI_RESET + ": " + msg
else:
msg = "Error: " + msg
print_message(msg, loc=loc)


def print_message(msg, *, loc=None):
def print_message(msg: str, *, loc=None) -> None:
"""Print a message to stderr. If `loc` is provided as either a Location object, or
a Token object with a `location` field, then the line of code that the location
indicates will be printed with the message.
Expand All @@ -141,14 +141,14 @@ def print_message(msg, *, loc=None):
sys.stderr.write(msg + "\n")


def align_caret(line, col):
def align_caret(line: str, col: int) -> str:
"""Return the whitespace necessary to align a caret to underline the desired
column in the line of text. Mainly this means handling tabs.
"""
return "".join("\t" if c == "\t" else " " for c in line[: col - 1])


def read_file(path) -> str:
def read_file(path: str) -> str:
"""Read a file and return its contents."""
try:
with open(path, encoding="ascii") as f:
Expand All @@ -163,11 +163,11 @@ def read_file(path) -> str:
raise HERAError("non-ASCII byte in file")


def pad(s, n):
def pad(s: str, n: int) -> str:
return (" " * (n - len(s))) + s


def handle_messages(settings, ret_messages_pair):
def handle_messages(settings: Settings, ret_messages_pair):
if (
isinstance(ret_messages_pair, tuple)
and len(ret_messages_pair) == 2
Expand Down Expand Up @@ -198,7 +198,7 @@ def handle_messages(settings, ret_messages_pair):
# you can use them unconditionally in your code without worrying about --no-color.


def make_ansi(*params):
def make_ansi(*params) -> str:
return "\033[" + ";".join(map(str, params)) + "m"


Expand Down

0 comments on commit 4e374d0

Please sign in to comment.