Skip to content

Commit

Permalink
Update to Parsita 1.8.0b1
Browse files Browse the repository at this point in the history
  • Loading branch information
drhagen committed Jun 15, 2023
1 parent 58a5f8e commit ed72e92
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
cffi >= 1.11.0
parsita >= 1.2.0
parsita == 1.8.0b2
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def finalize_options(self):
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
],

cmdclass={
Expand Down
4 changes: 2 additions & 2 deletions src/tensora/expression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from functools import reduce

from parsita import TextParsers, lit, reg, rep, rep1sep, Result
from parsita import ParserContext, lit, reg, rep, rep1sep, Result
from parsita.util import splat

from .ast import Assignment, Add, Subtract, Multiply, Tensor, Scalar, Integer, Float
Expand All @@ -18,7 +18,7 @@ def make_expression(first, rest):
return value


class TensorExpressionParsers(TextParsers):
class TensorExpressionParsers(ParserContext, whitespace=r'[ ]*'):
name = reg(r'[A-Za-z][A-Za-z0-9]*')

# taco does not support negatives or exponents
Expand Down
10 changes: 5 additions & 5 deletions src/tensora/format/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ['parse_format']

from parsita import TextParsers, reg, lit, rep, eof, Result, Success, Failure
from parsita import ParserContext, reg, lit, rep, eof, Result, Success, Failure, ParseError
from parsita.util import constant

from .format import Mode, Format
Expand All @@ -15,7 +15,7 @@ def make_format_with_orderings(dims):
return Format(tuple(modes), tuple(orderings))


class FormatTextParsers(TextParsers, whitespace=None):
class FormatTextParsers(ParserContext):
integer = reg(r'[0-9]+') > int
dense = lit('d') > constant(Mode.dense)
compressed = lit('s') > constant(Mode.compressed)
Expand All @@ -33,9 +33,9 @@ def parse_format(format: str) -> Result[Format]:
if isinstance(parse_result, Failure):
return parse_result
elif isinstance(parse_result, Success):
parse_value = parse_result.value
parse_value = parse_result.unwrap()
if set(range(parse_value.order)) != set(parse_value.ordering):
return Failure(f'Format ordering must be some order of the set {set(range(parse_value.order))} not '
f'{parse_value.ordering}')
return Failure(ParseError(f'Format ordering must be some order of the set {set(range(parse_value.order))} '
f'not {parse_value.ordering}'))
else:
return parse_result
6 changes: 3 additions & 3 deletions src/tensora/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ def tensor_method(assignment: str, input_formats: Dict[str, str], output_format:
def cachable_tensor_method(assignment: str, input_formats: Tuple[Tuple[str, str], ...], output_format: str
) -> PureTensorMethod:
from .expression.parser import parse_assignment
parsed_assignment = parse_assignment(assignment).or_die()
parsed_assignment = parse_assignment(assignment).unwrap()

parsed_input_formats = {name: parse_format(format).or_die() for name, format in input_formats}
parsed_input_formats = {name: parse_format(format).unwrap() for name, format in input_formats}

parsed_output = parse_format(output_format).or_die()
parsed_output = parse_format(output_format).unwrap()

if parsed_assignment.is_mutating():
raise NotImplementedError(f'Mutating tensor assignments like {assignment} not implemented yet.')
Expand Down
2 changes: 1 addition & 1 deletion src/tensora/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def from_aos(coordinates: Iterable[Tuple[int, ...]], values: Iterable[float], *,
coordinates = list(coordinates)
format = default_format_given_nnz(dimensions, len(coordinates))
elif isinstance(format, str):
format = parse_format(format).or_die()
format = parse_format(format).unwrap()

# Reorder with first level first, etc.
level_dimensions = tuple(dimensions[i] for i in format.ordering)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@pytest.mark.parametrize('string,assignment', assignment_strings)
def test_assignment_parsing(string, assignment):
actual = parse_assignment(string).or_die()
actual = parse_assignment(string).unwrap()
assert actual == assignment


Expand All @@ -34,7 +34,7 @@ def test_assignment_deparsing(string, assignment):


def parse(string):
return parse_assignment(string).or_die()
return parse_assignment(string).unwrap()


def test_assignment_to_string():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@pytest.mark.parametrize('string,format', format_strings)
def test_parse_format(string, format):
actual = parse_format(string).or_die()
actual = parse_format(string).unwrap()
assert actual == format


Expand Down

0 comments on commit ed72e92

Please sign in to comment.