Skip to content

Commit

Permalink
Merge pull request #140 from lambda-feedback/tr114-comparing-real-and…
Browse files Browse the repository at this point in the history
…-imaginary-as-criteria

Fixed bug in input symbol substitution
  • Loading branch information
KarlLundengaard committed Nov 3, 2023
2 parents 1089316 + 73ce5df commit cb817bd
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
28 changes: 22 additions & 6 deletions app/expression_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@ def substitute_input_symbols(exprs, params):

substitutions = [(expr, expr) for expr in params.get("reserved_keywords",[])]

if params.get("elementary_functions", False) is True:
alias_substitutions = []
for expr in exprs:
for (name, alias_list) in elementary_functions_names+special_symbols_names:
if name in expr:
alias_substitutions += [(name, " "+name)]
for alias in alias_list:
if alias in expr:
alias_substitutions += [(alias, " "+name)]
substitutions += alias_substitutions

input_symbols = params.get("symbols",dict())

if "symbols" in params.keys():
Expand Down Expand Up @@ -497,12 +508,11 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass
parse_expression function.
'''

unsplittable_symbols = list(unsplittable_symbols)
if "symbols" in params.keys():
to_keep = []
for symbol in params["symbols"].keys():
if len(symbol) > 1:
to_keep.append(symbol)
unsplittable_symbols += tuple(to_keep)
unsplittable_symbols.append(symbol)

if params.get("specialFunctions", False) is True:
from sympy import beta, gamma, zeta
Expand All @@ -512,6 +522,12 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass
zeta = Symbol("zeta")
if params.get("complexNumbers", False) is True:
from sympy import I
# imaginary_constant_index = None
# for (k, symbol) in enumerate(unsplittable_symbols):
# if "I" == symbol[0]:
# imaginary_constant_index = k
# if imaginary_constant_index is not None:
# unsplittable_symbols = unsplittable_symbols[0:imaginary_constant_index]+unsplittable_symbols[imaginary_constant_index+1:]
else:
I = Symbol("I")
if params.get("elementary_functions", False) is True:
Expand All @@ -535,10 +551,10 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass
"E": E
}

for symbol in unsplittable_symbols:
symbol_dict.update({symbol: Symbol(symbol)})
# for symbol in unsplittable_symbols:
# symbol_dict.update({symbol: Symbol(symbol)})

symbol_dict.update(sympy_symbols(params.get("symbols", {})))
symbol_dict.update(sympy_symbols(unsplittable_symbols))

strict_syntax = params.get("strict_syntax", True)

Expand Down
27 changes: 26 additions & 1 deletion app/symbolic_comparison_evaluation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,32 @@ def test_no_reserved_keywords_in_old_format_input_symbol_alternatives(self):
("5*exp(lambda*x)/(1+5*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", True, [], {}),
("6*exp(lambda*x)/(1+7*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", False, [], {}),
("c*exp(lambda*x)/(1+c*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", True, [], {}),
("-A/r^2*cos(omega*t-k*r)+k*A/r*sin(omega*t-k*r)", "(-A/(r**2))*exp(I*(omega*t-k*r))*(1+I*k*r)", "re(response)=re(answer)", True, [], {"complexNumbers": True, "symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')"}),
("-A/r^2*cos(omega*t-k*r)+k*A/r*sin(omega*t-k*r)", "(-A/(r**2))*exp(i*(omega*t-k*r))*(1+i*k*r)", "re(response)=re(answer)", True, [],
{
"complexNumbers": True,
"symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')",
'symbols': {
'r': {'aliases': ['R'], 'latex': r'\(r\)'},
'A': {'aliases': ['a'], 'latex': r'\(A\)'},
'omega': {'aliases': ['OMEGA', 'Omega'], 'latex': r'\(\omega\)'},
'k': {'aliases': ['K'], 'latex': r'\(k\)'},
't': {'aliases': ['T'], 'latex': r'\(t\)'},
'I': {'aliases': ['i'], 'latex': r'\(i\)'},
}
}),
("-A/r^2*(cos(omega*t-kr)+I*sin(omega*t-kr))*(1+Ikr)", "(-A/(r**2))*exp(I*(omega*t-k*r))*(1+I*k*r)", "re(response)=re(answer)", True, [],
{
"complexNumbers": True,
"symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')",
'symbols': {
'r': {'aliases': ['R'], 'latex': r'\(r\)'},
'A': {'aliases': ['a'], 'latex': r'\(A\)'},
'omega': {'aliases': ['OMEGA', 'Omega'], 'latex': r'\(\omega\)'},
'k': {'aliases': ['K'], 'latex': r'\(k\)'},
't': {'aliases': ['T'], 'latex': r'\(t\)'},
'I': {'aliases': ['i'], 'latex': r'\(i\)'},
}
}),
]
)
def test_criteria_based_comparison(self, response, answer, criteria, value, feedback_tags, additional_params):
Expand Down

0 comments on commit cb817bd

Please sign in to comment.