Skip to content

Commit f763e20

Browse files
fix(get_terms_ex): support negated implicit coefficients
- e.g. "-x" -> (-1, "x", None) - "-x^3" -> (-1, "x", 3)
1 parent ad23139 commit f763e20

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

libraries/mathy_python/mathy/util.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
1-
import os
2-
import numpy as np
3-
from . import time_step
4-
from pydantic import BaseModel
51
import json
62
import math
7-
from pathlib import Path
8-
from typing import Dict, List, NamedTuple, Optional, Union, Any
3+
import os
94
import re
10-
from typing import Dict, List, Tuple
11-
from wasabi import TracebackPrinter
5+
from pathlib import Path
6+
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union, cast
127

138
import numpy as np
9+
from pydantic import BaseModel
10+
from wasabi import TracebackPrinter
1411

15-
from .types import EnvRewards, Literal
16-
12+
from . import time_step
1713
from .core.expressions import (
1814
AddExpression,
1915
BinaryExpression,
@@ -28,6 +24,7 @@
2824
)
2925
from .core.parser import ExpressionParser
3026
from .core.tree import LEFT
27+
from .types import EnvRewards, Literal
3128

3229

3330
def is_debug_mode() -> bool:
@@ -523,6 +520,18 @@ def get_term_ex(node: Optional[MathExpression]) -> Optional[TermEx]:
523520
TermEx(coefficient=4, variable="x", exponent=7)
524521
```
525522
"""
523+
if isinstance(node, NegateExpression):
524+
child = cast(NegateExpression, node).get_child()
525+
526+
# "-x"
527+
if isinstance(child, VariableExpression):
528+
return TermEx(-1, child.identifier, None)
529+
# "-x^2"
530+
if isinstance(child, PowerExpression):
531+
if isinstance(child.left, VariableExpression) and isinstance(
532+
child.right, ConstantExpression
533+
):
534+
return TermEx(-1, child.left.identifier, child.right.value)
526535

527536
# "4"
528537
if isinstance(node, ConstantExpression):
@@ -818,4 +827,3 @@ def print_error(error, text, print_error=True):
818827
print(caught_at + caught_error)
819828
else:
820829
raise ValueError(caught_at + caught_error)
821-

libraries/mathy_python/tests/rules/distributive_factor_out.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
"valid": "two nodes connected by a plus can have common units extracted out"
44
},
55
"valid": [
6+
{
7+
"input": "g + -x^3 + 4x^3 + 19p^4 + -1y",
8+
"output": "g + (-1 + 4) * x^3 + 19p^4 + -1y",
9+
"why": "invariance to sibling grouping"
10+
},
611
{
712
"input": "5.8c + (3393c + 6o + -8614k)",
813
"output": "(5.8 + 3393) * c + (6o + -8614k)",

libraries/mathy_python/tests/test_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def test_get_sub_terms():
4747

4848
def test_get_term_ex():
4949
examples = [
50+
("-y", TermEx(-1, "y", None)),
51+
("-x^3", TermEx(-1, "x", 3)),
52+
("-2x^3", TermEx(-2, "x", 3)),
5053
("4x^2", TermEx(4, "x", 2)),
5154
("4x", TermEx(4, "x", None)),
5255
("x", TermEx(None, "x", None)),

0 commit comments

Comments
 (0)