Skip to content

Add support for complex, Decimal, and Fraction suffixes #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions basilisp/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import uuid
from collections import OrderedDict
from datetime import datetime
from decimal import Decimal
from enum import Enum
from fractions import Fraction
from itertools import chain
from typing import (Dict, Iterable, Pattern, Tuple, Optional, List, Union, Callable, Mapping, NamedTuple, cast, Deque,
Any)
Expand Down Expand Up @@ -393,6 +395,8 @@ def _expressionize(body: MixedNodeStream,
_NS_VAR_VALUE = f'{_NS_VAR}.value'

_NS_VAR_NAME = _load_attr(f'{_NS_VAR_VALUE}.name')
_NEW_DECIMAL_FN_NAME = _load_attr(f'{_UTIL_ALIAS}.decimal_from_str')
_NEW_FRACTION_FN_NAME = _load_attr(f'{_UTIL_ALIAS}.fraction')
_NEW_INST_FN_NAME = _load_attr(f'{_UTIL_ALIAS}.inst_from_str')
_NEW_KW_FN_NAME = _load_attr(f'{_KW_ALIAS}.keyword')
_NEW_LIST_FN_NAME = _load_attr(f'{_LIST_ALIAS}.list')
Expand Down Expand Up @@ -1570,16 +1574,29 @@ def _sym_ast(ctx: CompilerContext, form: sym.Symbol) -> ASTStream:
ctx=ast.Load()))


def _regex_ast(_: CompilerContext, form: Pattern) -> ASTStream:
def _decimal_ast(_: CompilerContext, form: Decimal) -> ASTStream:
yield _node(ast.Call(
func=_NEW_REGEX_FN_NAME, args=[ast.Str(form.pattern)], keywords=[]))
func=_NEW_DECIMAL_FN_NAME, args=[ast.Str(str(form))], keywords=[]))


def _fraction_ast(_: CompilerContext, form: Fraction) -> ASTStream:
yield _node(ast.Call(
func=_NEW_FRACTION_FN_NAME,
args=[ast.Num(form.numerator),
ast.Num(form.denominator)],
keywords=[]))


def _inst_ast(_: CompilerContext, form: datetime) -> ASTStream:
yield _node(ast.Call(
func=_NEW_INST_FN_NAME, args=[ast.Str(form.isoformat())], keywords=[]))


def _regex_ast(_: CompilerContext, form: Pattern) -> ASTStream:
yield _node(ast.Call(
func=_NEW_REGEX_FN_NAME, args=[ast.Str(form.pattern)], keywords=[]))


def _uuid_ast(_: CompilerContext, form: uuid.UUID) -> ASTStream:
yield _node(ast.Call(
func=_NEW_UUID_FN_NAME, args=[ast.Str(str(form))], keywords=[]))
Expand Down Expand Up @@ -1675,12 +1692,18 @@ def _to_ast(ctx: CompilerContext, form: LispForm) -> ASTStream: # pylint: disab
elif isinstance(form, float):
yield _node(ast.Num(form))
return
elif isinstance(form, int):
elif isinstance(form, (complex, int)):
yield _node(ast.Num(form))
return
elif isinstance(form, datetime):
yield from _inst_ast(ctx, form)
return
elif isinstance(form, Decimal):
yield from _decimal_ast(ctx, form)
return
elif isinstance(form, Fraction):
yield from _fraction_ast(ctx, form)
return
elif isinstance(form, uuid.UUID):
yield from _uuid_ast(ctx, form)
return
Expand Down
7 changes: 4 additions & 3 deletions basilisp/lang/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
from datetime import datetime
from decimal import Decimal
from fractions import Fraction
from typing import Union, Pattern

Expand All @@ -11,7 +12,7 @@
import basilisp.lang.vector as vec

LispNumber = Union[int, float, Fraction]
LispForm = Union[bool, datetime, int, float, kw.Keyword, llist.List,
lmap.Map, None, Pattern, lset.Set, str, sym.Symbol,
vec.Vector, uuid.UUID]
LispForm = Union[bool, complex, datetime, Decimal, int, float, Fraction,
kw.Keyword, llist.List, lmap.Map, None, Pattern, lset.Set,
str, sym.Symbol, vec.Vector, uuid.UUID]
IterableLispForm = Union[llist.List, lmap.Map, lset.Set, vec.Vector]
23 changes: 19 additions & 4 deletions basilisp/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import keyword
import re
import uuid
from decimal import Decimal
from fractions import Fraction
from typing import Pattern

Expand All @@ -21,16 +22,20 @@ def lrepr(f) -> str:
return "nil"
elif isinstance(f, str):
return f'"{f}"'
elif isinstance(f, complex):
return repr(f).upper()
elif isinstance(f, datetime.datetime):
inst_str = f.isoformat()
return f'#inst "{inst_str}"'
elif isinstance(f, Decimal):
return str(f)
elif isinstance(f, Fraction):
return f"{f.numerator}/{f.denominator}"
elif isinstance(f, Pattern):
return f'#"{f.pattern}"'
elif isinstance(f, uuid.UUID):
uuid_str = str(f)
return f'#uuid "{uuid_str}"'
elif isinstance(f, Pattern):
return f'#"{f.pattern}"'
elif isinstance(f, Fraction):
return f"{f.numerator}/{f.denominator}"
else:
return repr(f)

Expand Down Expand Up @@ -85,6 +90,16 @@ def genname(prefix: str) -> str:
return f"{prefix}_{i}"


def decimal_from_str(decimal_str: str) -> Decimal:
"""Create a Decimal from a numeric string."""
return Decimal(decimal_str)


def fraction(numerator: int, denominator: int) -> Fraction:
"""Create a Fraction from a numerator and denominator."""
return Fraction(numerator=numerator, denominator=denominator)


def inst_from_str(inst_str: str) -> datetime.datetime:
"""Create a datetime instance from an RFC 3339 formatted date string."""
return dateparser.parse(inst_str)
Expand Down
58 changes: 52 additions & 6 deletions basilisp/reader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import collections
import contextlib
import decimal
import functools
import io
import re
import uuid
from datetime import datetime
from fractions import Fraction
from typing import (Deque, List, Tuple, Optional, Collection, Callable, Any, Union, MutableMapping, Pattern, Iterable,
TypeVar, cast, Dict)

Expand Down Expand Up @@ -408,14 +410,19 @@ def _read_map(ctx: ReaderContext) -> lmap.Map:
# special keywords `true`, `false`, and `nil`, we have to have a looser
# type defined for the return from these reader functions.
MaybeSymbol = Union[bool, None, symbol.Symbol]
MaybeNumber = Union[float, int, MaybeSymbol]
MaybeNumber = Union[complex, decimal.Decimal, float, Fraction, int, MaybeSymbol]


def _read_num(ctx: ReaderContext) -> MaybeNumber:
"""Return a numeric (integer or float) from the input stream."""
def _read_num(ctx: ReaderContext) -> MaybeNumber: # noqa: C901 # pylint: disable=too-many-statements
"""Return a numeric (complex, Decimal, float, int, Fraction) from the input stream."""
chars: List[str] = []
reader = ctx.reader

is_complex = False
is_decimal = False
is_float = False
is_integer = False
is_ratio = False
while True:
token = reader.peek()
if token == '-':
Expand All @@ -435,16 +442,55 @@ def _read_num(ctx: ReaderContext) -> MaybeNumber:
raise SyntaxError(
"Found extra '.' in float; expected decimal portion")
is_float = True
elif token == 'J':
if is_complex:
raise SyntaxError("Found extra 'J' suffix in complex literal")
is_complex = True
elif token == 'M':
if is_decimal:
raise SyntaxError("Found extra 'M' suffix in decimal literal")
is_decimal = True
elif token == 'N':
if is_integer:
raise SyntaxError("Found extra 'N' suffix in integer literal")
is_integer = True
elif token == '/':
if is_ratio:
raise SyntaxError("Found extra '/' in ratio literal")
is_ratio = True
elif not num_chars.match(token):
break
reader.next_token()
chars.append(token)

if len(chars) == 0:
raise SyntaxError("Expected integer or float")
assert len(chars) > 0, "Must have at least one digit in integer or float"

s = ''.join(chars)
return float(s) if is_float else int(s)
if sum([is_complex and is_decimal,
is_complex and is_integer,
is_complex and is_ratio,
is_decimal or is_float,
is_integer,
is_ratio]) > 1:
raise SyntaxError(f"Invalid number format: {s}")

if is_complex:
imaginary = float(s[:-1]) if is_float else int(s[:-1])
return complex(0, imaginary)
elif is_decimal:
try:
return decimal.Decimal(s[:-1])
except decimal.InvalidOperation:
raise SyntaxError(f"Invalid number format: {s}") from None
elif is_float:
return float(s)
elif is_ratio:
assert "/" in s, "Ratio must contain one '/' character"
num, denominator = s.split('/')
return Fraction(numerator=int(num), denominator=int(denominator))
elif is_integer:
return int(s[:-1])
return int(s)


def _read_str(ctx: ReaderContext) -> str:
Expand Down
44 changes: 34 additions & 10 deletions tests/compiler_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import decimal
import re
import types
import uuid
from fractions import Fraction
from typing import Optional
from unittest.mock import Mock

Expand Down Expand Up @@ -77,12 +79,30 @@ def test_string():


def test_int():
assert lcompile('1') == 1
assert lcompile('100') == 100
assert lcompile('99927273') == 99927273
assert lcompile('0') == 0
assert lcompile('-1') == -1
assert lcompile('-538282') == -538282
assert 1 == lcompile('1')
assert 100 == lcompile('100')
assert 99927273 == lcompile('99927273')
assert 0 == lcompile('0')
assert -1 == lcompile('-1')
assert -538282 == lcompile('-538282')

assert 1 == lcompile('1N')
assert 100 == lcompile('100N')
assert 99927273 == lcompile('99927273N')
assert 0 == lcompile('0N')
assert -1 == lcompile('-1N')
assert -538282 == lcompile('-538282N')


def test_decimal():
assert decimal.Decimal('0.0') == lcompile('0.0M')
assert decimal.Decimal('0.09387372') == lcompile('0.09387372M')
assert decimal.Decimal('1.0') == lcompile('1.0M')
assert decimal.Decimal('1.332') == lcompile('1.332M')
assert decimal.Decimal('-1.332') == lcompile('-1.332M')
assert decimal.Decimal('-1.0') == lcompile('-1.0M')
assert decimal.Decimal('-0.332') == lcompile('-0.332M')
assert decimal.Decimal('3.14') == lcompile('3.14M')


def test_float():
Expand Down Expand Up @@ -659,15 +679,19 @@ def test_var(ns_var: Var):
assert v.value == "a value"


def test_fraction(ns_var: Var):
assert Fraction('22/7') == lcompile('22/7')


def test_inst(ns_var: Var):
assert lcompile('#inst "2018-01-18T03:26:57.296-00:00"'
) == dateparser.parse('2018-01-18T03:26:57.296-00:00')
assert dateparser.parse('2018-01-18T03:26:57.296-00:00') == lcompile(
'#inst "2018-01-18T03:26:57.296-00:00"')


def test_regex(ns_var: Var):
assert lcompile('#"\s"') == re.compile('\s')


def test_uuid(ns_var: Var):
assert lcompile('#uuid "0366f074-a8c5-4764-b340-6a5576afd2e8"'
) == uuid.UUID('{0366f074-a8c5-4764-b340-6a5576afd2e8}')
assert uuid.UUID('{0366f074-a8c5-4764-b340-6a5576afd2e8}') == lcompile(
'#uuid "0366f074-a8c5-4764-b340-6a5576afd2e8"')
Loading