From 73ce5dfdffc495d842daa298b1bb05662f08ecc9 Mon Sep 17 00:00:00 2001 From: KarlLundengaard Date: Fri, 3 Nov 2023 18:38:11 +0000 Subject: [PATCH] Fixed bug in input symbol substitution - Input symbols substitutions did not take elementary function names into account - Added exception input symbols handling so that "I" is treated as the imaginary constant if "complexNumbers" is set to true, regardless of if there is an input symbol with code "I" or not. --- app/expression_utilities.py | 28 ++++++++++++++++----- app/symbolic_comparison_evaluation_tests.py | 27 +++++++++++++++++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/app/expression_utilities.py b/app/expression_utilities.py index e2ec82f..d0c1980 100644 --- a/app/expression_utilities.py +++ b/app/expression_utilities.py @@ -222,6 +222,17 @@ def substitute_input_symbols(exprs, params): substitutions = [(expr, expr) for expr in params.get("reserved_keywords",[])] + if params.get("elementary_functions", False) is True: + alias_substitutions = [] + for expr in exprs: + for (name, alias_list) in elementary_functions_names+special_symbols_names: + if name in expr: + alias_substitutions += [(name, " "+name)] + for alias in alias_list: + if alias in expr: + alias_substitutions += [(alias, " "+name)] + substitutions += alias_substitutions + input_symbols = params.get("symbols",dict()) if "symbols" in params.keys(): @@ -497,12 +508,11 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass parse_expression function. ''' + unsplittable_symbols = list(unsplittable_symbols) if "symbols" in params.keys(): - to_keep = [] for symbol in params["symbols"].keys(): if len(symbol) > 1: - to_keep.append(symbol) - unsplittable_symbols += tuple(to_keep) + unsplittable_symbols.append(symbol) if params.get("specialFunctions", False) is True: from sympy import beta, gamma, zeta @@ -512,6 +522,12 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass zeta = Symbol("zeta") if params.get("complexNumbers", False) is True: from sympy import I +# imaginary_constant_index = None +# for (k, symbol) in enumerate(unsplittable_symbols): +# if "I" == symbol[0]: +# imaginary_constant_index = k +# if imaginary_constant_index is not None: +# unsplittable_symbols = unsplittable_symbols[0:imaginary_constant_index]+unsplittable_symbols[imaginary_constant_index+1:] else: I = Symbol("I") if params.get("elementary_functions", False) is True: @@ -535,10 +551,10 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass "E": E } - for symbol in unsplittable_symbols: - symbol_dict.update({symbol: Symbol(symbol)}) +# for symbol in unsplittable_symbols: +# symbol_dict.update({symbol: Symbol(symbol)}) - symbol_dict.update(sympy_symbols(params.get("symbols", {}))) + symbol_dict.update(sympy_symbols(unsplittable_symbols)) strict_syntax = params.get("strict_syntax", True) diff --git a/app/symbolic_comparison_evaluation_tests.py b/app/symbolic_comparison_evaluation_tests.py index 8b56f0e..415c8ae 100644 --- a/app/symbolic_comparison_evaluation_tests.py +++ b/app/symbolic_comparison_evaluation_tests.py @@ -1075,7 +1075,32 @@ def test_no_reserved_keywords_in_old_format_input_symbol_alternatives(self): ("5*exp(lambda*x)/(1+5*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", True, [], {}), ("6*exp(lambda*x)/(1+7*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", False, [], {}), ("c*exp(lambda*x)/(1+c*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", True, [], {}), - ("-A/r^2*cos(omega*t-k*r)+k*A/r*sin(omega*t-k*r)", "(-A/(r**2))*exp(I*(omega*t-k*r))*(1+I*k*r)", "re(response)=re(answer)", True, [], {"complexNumbers": True, "symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')"}), + ("-A/r^2*cos(omega*t-k*r)+k*A/r*sin(omega*t-k*r)", "(-A/(r**2))*exp(i*(omega*t-k*r))*(1+i*k*r)", "re(response)=re(answer)", True, [], + { + "complexNumbers": True, + "symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')", + 'symbols': { + 'r': {'aliases': ['R'], 'latex': r'\(r\)'}, + 'A': {'aliases': ['a'], 'latex': r'\(A\)'}, + 'omega': {'aliases': ['OMEGA', 'Omega'], 'latex': r'\(\omega\)'}, + 'k': {'aliases': ['K'], 'latex': r'\(k\)'}, + 't': {'aliases': ['T'], 'latex': r'\(t\)'}, + 'I': {'aliases': ['i'], 'latex': r'\(i\)'}, + } + }), + ("-A/r^2*(cos(omega*t-kr)+I*sin(omega*t-kr))*(1+Ikr)", "(-A/(r**2))*exp(I*(omega*t-k*r))*(1+I*k*r)", "re(response)=re(answer)", True, [], + { + "complexNumbers": True, + "symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')", + 'symbols': { + 'r': {'aliases': ['R'], 'latex': r'\(r\)'}, + 'A': {'aliases': ['a'], 'latex': r'\(A\)'}, + 'omega': {'aliases': ['OMEGA', 'Omega'], 'latex': r'\(\omega\)'}, + 'k': {'aliases': ['K'], 'latex': r'\(k\)'}, + 't': {'aliases': ['T'], 'latex': r'\(t\)'}, + 'I': {'aliases': ['i'], 'latex': r'\(i\)'}, + } + }), ] ) def test_criteria_based_comparison(self, response, answer, criteria, value, feedback_tags, additional_params):