diff --git a/pythonwhat/check_funcs.py b/pythonwhat/check_funcs.py index 5ec63470..a8254c15 100644 --- a/pythonwhat/check_funcs.py +++ b/pythonwhat/check_funcs.py @@ -74,7 +74,8 @@ def check_node(name, index, typestr, missing_msg=MSG_MISSING, expand_msg=MSG_PRE # check if there are enough nodes for index fmt_kwargs = {'ordinal': get_ord(index+1) if isinstance(index, int) else "", - 'index': index} + 'index': index, + 'name': name} fmt_kwargs['typestr'] = typestr.format(**fmt_kwargs) # test if node can be indexed succesfully @@ -267,7 +268,8 @@ def check_args(name, missing_msg='FMT:Are you sure it is defined?', state=None): if name in ['*args', '**kwargs']: return check_part(name, name, state=state, missing_msg = missing_msg) else: - return check_part_index('args', name, "argument `%s`"%name, state=state, missing_msg = missing_msg) + arg_str = "%s argument"%get_ord(name+1) if isinstance(name, int) else "argument `%s`"%name + return check_part_index('args', name, arg_str, state=state, missing_msg = missing_msg) # CALL CHECK ================================================================== @@ -369,6 +371,18 @@ def call(args, # Expression tests ------------------------------------------------------------ from pythonwhat.tasks import ReprFail, UndefinedValue from pythonwhat import utils + +def has_equal_ast(incorrect_msg="FMT: Your code does not seem to match the solution.", state=None): + rep = Reporter.active_reporter + + stu_rep = ast.dump(state.student_tree) + sol_rep = ast.dump(state.solution_tree) + + _msg = state.build_message(incorrect_msg) + rep.do_test(EqualTest(stu_rep, sol_rep, Feedback(_msg, state.highlight))) + + return state + def has_expr(incorrect_msg="FMT:Unexpected expression {test}: expected `{sol_eval}`, got `{stu_eval}` with values{extra_env}.", error_msg="Running an expression in the student process caused an issue.", undefined_msg="FMT:Have you defined `{name}` without errors?", diff --git a/pythonwhat/check_function.py b/pythonwhat/check_function.py index 9f31a86b..7b734176 100644 --- a/pythonwhat/check_function.py +++ b/pythonwhat/check_function.py @@ -1,39 +1,72 @@ -from pythonwhat.check_funcs import check_node -from pythonwhat.test_funcs.test_function import mapped_name +from pythonwhat.Reporter import Reporter +from pythonwhat.check_funcs import part_to_child +from pythonwhat.test_funcs.test_function import bind_args from pythonwhat.tasks import getSignatureInProcess +from pythonwhat.utils import get_ord +from pythonwhat.Test import Test +from pythonwhat.Feedback import Feedback +from pythonwhat.parsing import IndexedDict from functools import partial -def check_function(name, index=0, - missing_msg = "Did you define {sol_part[name]}?", - expand_msg = "In your definition of {sol_part[name]}, ", +def bind_args(signature, args_part): + pos_args = []; kw_args = {} + for k, arg in args_part.items(): + if isinstance(k, int): pos_args.append(arg) + else: kw_args[k] = arg + + bound_args = signature.bind(*pos_args, **kw_args) + + return (IndexedDict(bound_args.arguments), signature) + +MSG_PREPEND = "__JINJA__:Check your code in the {{child['part']+ ' of the' if child['part']}} {{typestr}}. " +def check_function(name, index, + missing_msg = "FMT:Did you define {typestr}?", + params_not_matched_msg = "FMT:Something went wrong in figuring out how you specified the " + "arguments for `{name}`; have another look at your code and its output.", + expand_msg = MSG_PREPEND, + signature=None, + typestr = "{ordinal} function call", state=None): rep = Reporter.active_reporter stu_out = state.student_function_calls sol_out = state.solution_function_calls - # test if function exists - stud_name = get_mapped_name(name, state.student_mappings) - - func_list = check_node('function_calls', name, 'function call', missing_msg, expand_msg, state) - # get function state - if index is None: - return func_list - else: - # TODO make has_part more robust - # grab specific function call - child_func = check_part(index, "FUNCTION MSG", func_list, "not enough func calls") - stu_parts, sol_parts = child_func.student_parts, child_func.solution_parts - # Signatures + fmt_kwargs = {'ordinal': get_ord(index+1), + 'index': index, + 'name': name} + fmt_kwargs['typestr'] = typestr.format(**fmt_kwargs) + + # Get Parts ---- + try: + stu_parts = stu_out[name][index] + except (KeyError, IndexError): + _msg = state.build_message(missing_msg, fmt_kwargs) + rep.do_test(Test(Feedback(_msg, state.highlight))) + + sol_parts = sol_out[name][index] + + # Signatures ----- + if signature: + signature = None if isinstance(signature, bool) else signature get_sig = partial(getSignatureInProcess, name=name, signature=signature, - manual_sigs = state.get_manual_sigs()) + manual_sigs = state.get_manual_sigs()) - # TODO if can't parse, raise warnings - sol_sig = get_sig(mapped_name=sol_parts['name'], process=solution_process) - sol_parts['args'], _ = bind_ards(sol_sig, sol_parts['pos_args'], sol_parts['keywords']) + try: + sol_sig = get_sig(mapped_name=sol_parts['name'], process=state.solution_process) + sol_parts['args'], _ = bind_args(sol_sig, sol_parts['args']) + except: + raise ValueError("Something went wrong in matching call index {index} of {name} to its signature. " + "You might have to manually specify or correct the signature." + .format(index=index, name=name)) - # TODO if can't parse sig, send failed test msg - stu_sig = get_sig(mapped_name=stu_parts['name'], process=student_process) - stu_parts['args'], _ = bind_ards(stu_sig, stu_parts['pos_args'], stu_parts['keywords']) + try: + stu_sig = get_sig(mapped_name=stu_parts['name'], process=state.student_process) + stu_parts['args'], _ = bind_args(stu_sig, stu_parts['args']) + except Exception as e: + _msg = state.build_message(params_not_matched_msg, fmt_kwargs) + rep.do_test(Test(Feedback(_msg, state.highlight))) - # three types of parts: pos_args, keywords, args (e.g. these are bound to sig) - return child_func + # three types of parts: pos_args, keywords, args (e.g. these are bound to sig) + append_message = {'msg': expand_msg, 'kwargs': fmt_kwargs} + child = part_to_child(stu_parts, sol_parts, append_message, state, node_name='function_calls') + return child diff --git a/pythonwhat/check_wrappers.py b/pythonwhat/check_wrappers.py index dd80b133..c4561485 100644 --- a/pythonwhat/check_wrappers.py +++ b/pythonwhat/check_wrappers.py @@ -1,5 +1,6 @@ from pythonwhat.check_funcs import check_part, check_part_index, check_node, has_equal_part from pythonwhat import check_funcs, check_object +from pythonwhat.check_function import check_function from pythonwhat.test_funcs.test_data_frame import check_df from pythonwhat.test_funcs.test_dictionary import check_dict from pythonwhat import test_funcs @@ -53,9 +54,10 @@ for k, v in __NODE_WRAPPERS__.items(): scts['check_'+k] = partial(check_node, k+'s', typestr=v) +scts['check_function'] = check_function for k in ['set_context', - 'has_equal_value', 'has_equal_output', 'has_equal_error', 'call', + 'has_equal_value', 'has_equal_output', 'has_equal_error', 'has_equal_ast', 'call', 'extend', 'multi', 'test_not', 'fail', 'quiet', 'with_context', 'check_args', diff --git a/setup.py b/setup.py index 0c277ba3..890a497e 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name='pythonwhat', - version='2.1.2', + version='2.2.0', packages=['pythonwhat', 'pythonwhat.test_funcs'], install_requires=["dill", "IPython", "numpy", "pandas", "markdown2"] ) diff --git a/tests/test_test_function.py b/tests/test_test_function.py index 8b107d72..6afb47e9 100644 --- a/tests/test_test_function.py +++ b/tests/test_test_function.py @@ -34,6 +34,13 @@ def test_Pass(self): self.assertTrue(sct_payload['correct']) self.assertEqual(sct_payload['message'], "Great!") + def test_Pass_spec2(self): + self.data['DC_SCT'] = """ +Ex().check_function('print', 0).check_args(0).has_equal_ast() +""" + sct_payload = helper.run(self.data) + self.assertTrue(sct_payload['correct']) + class TestFunctionExerciseNumpy(unittest.TestCase): def setUp(self): @@ -599,9 +606,7 @@ def setUp(self): test_function("print", index = 3, highlight=True) ''' } - self.DC_SCT_SPEC2 = ''' -Ex().check_function("print", 0).check_arg(0).has_equal_value() - ''' + def test_multiple_1(self): self.data["DC_CODE"] = 'print("abc")' sct_payload = helper.run(self.data) @@ -659,6 +664,110 @@ def test_nohighlight_too_few_calls(self): self.assertFalse(sct_payload['correct']) self.assertEqual(sct_payload.get('line_start'), None) +class TestCheckFunction(unittest.TestCase): + def setUp(self): + self.data = { + "DC_PEC": "import numpy as np", + "DC_CODE": "np.array([1,2,3])", + "DC_SOLUTION": "np.array([1,2,3])", + "DC_SCT": "Ex().check_function('numpy.array', 0)" + } + + def run_append(self, sct): + self.data["DC_SCT"] += sct + return helper.run(self.data) + + def run_pass(self, sct): + sct_payload = self.run_append(sct) + print(sct_payload) + self.assertTrue(sct_payload['correct']) + return sct_payload + + def run_fail(self, sct): + self.assertFalse(self.run_append(sct)['correct']) + + def test_pass_np_call_exists(self): + sct_payload = helper.run(self.data) + self.assertTrue(sct_payload['correct']) + + def test_pass_test_student_typed(self): + self.run_pass(".test_student_typed(r'np\.array\(\[1,2,3\]\)')") + + def test_fail_test_student_typed(self): + self.data["DC_CODE"] = "np.array([1,2])" + self.run_fail(".test_student_typed(r'np\.array\(\[1,2,3\]\)')") + + def test_pass_func_has_equal_ast(self): + self.run_pass(".has_equal_ast()") + + def test_fail_func_has_equal_ast(self): + self.data["DC_CODE"] = "np.array([1,2])" + self.run_fail(".has_equal_ast()") + + def test_pass_check_args_pos_0(self): + self.run_pass(".check_args(0)") + + def test_fail_check_args_pos_0(self): + self.data["DC_CODE"] = "np.array()" + self.run_fail(".check_args(0)") + + def test_pass_pos_0_test_student_typed(self): + self.run_pass(".check_args(0).test_student_typed(r'\[1,2,3\]')") + + def test_fail_pos_0_test_student_typed(self): + self.data["DC_CODE"] = "np.array([1,2])" + self.run_fail(".check_args(0).test_student_typed(r'\[1,2,3\]')") + + def test_pass_pos_0_has_equal_ast(self): + self.run_pass(".check_args(0).has_equal_ast()") + + def test_fail_pos_0_has_equal_ast(self): + self.data["DC_CODE"] = "np.array([1,2])" + self.run_fail(".check_args(0).has_equal_ast()") + + def test_pass_pos_0_has_equal_value(self): + self.run_pass(".check_args(0).has_equal_value()") + + def test_fail_pos_0_has_equal_value(self): + self.data["DC_CODE"] = "np.array([1,2])" + self.run_fail(".check_args(0).has_equal_value()") + + def test_pass_pos_0_inline_if_body(self): + self.data["DC_CODE"] = "np.array([1,2,3] if True else [1])" + self.data["DC_SOLUTION"] = "np.array([1,2,3] if False else [1])" + self.run_pass(".check_args(0).check_if_exp(0).check_body().has_equal_ast()") + + def test_fail_pos_0_inline_if_body(self): + self.data["DC_CODE"] = "np.array([1,2,3] if True else [1])" + self.data["DC_SOLUTION"] = "np.array([1,2] if False else [1])" + self.run_fail(".check_args(0).check_if_exp(0).check_body().has_equal_ast()") + +class TestCheckFunctionCases(unittest.TestCase): + def setup_color(self): + self.data = { + 'DC_PEC': "def f(*args, **kwargs): pass", + 'DC_CODE': "f(color = 'blue')" + } + self.data["DC_SOLUTION"] = self.data["DC_CODE"] + + def test_pass_sig_false(self): + self.setup_color() + self.data['DC_SCT'] = "Ex().check_function('f', 0, signature=False).check_args('color').has_equal_ast()" + + sct_payload = helper.run(self.data) + self.assertTrue(sct_payload['correct']) + + @unittest.skip("TODO: implement override") + def test_pass_sig_false_override(self): + self.setup_color() + self.data["DC_SCT"].replace('color', 'c') + self.data['DC_SCT'] = """ +Ex().check_function('f', 0, signature=False).override("f(c = 'blue')").check_args('c').has_equal_ast() +""" + + sct_payload = helper.run(self.data) + self.assertTrue(sct_payload['correct']) + class TestFunctionComplexArgs(unittest.TestCase): def setUp(self): @@ -702,7 +811,5 @@ def test_fail_undillable_args(self): self.assertFalse(sct_payload['correct']) - - if __name__ == "__main__": unittest.main() diff --git a/tests/test_test_function_v2.py b/tests/test_test_function_v2.py index 47ef0faa..b1bf635d 100644 --- a/tests/test_test_function_v2.py +++ b/tests/test_test_function_v2.py @@ -669,6 +669,11 @@ def setUp(self): ''' } + self.SPEC2_SCT = """ +Ex().check_function('pandas.DataFrame', 0, missing_msg = "notcalledmsg", expand_msg="")\ + .check_args('data', missing_msg='paramsnotmatchedmsg') +""" + def test_step1(self): self.data["DC_CODE"] = "" sct_payload = helper.run(self.data) @@ -676,6 +681,10 @@ def test_step1(self): self.assertEqual('notcalledmsg', sct_payload['message']) helper.test_absent_lines(self, sct_payload) + def test_step1_spec2(self): + self.data["DC_SCT"] = self.SPEC2_SCT + self.test_step1() + def test_step2(self): self.data["DC_CODE"] = "df = pd.DataFrame(x=[1, 2, 3])" sct_payload = helper.run(self.data) @@ -683,6 +692,11 @@ def test_step2(self): self.assertEqual('paramsnotmatchedmsg', sct_payload['message']) helper.test_lines(self, sct_payload, 1, 1, 6, 30) + def test_step2_spec2(self): + self.data["DC_SCT"] = self.SPEC2_SCT + self.test_step2() + + def test_step3(self): self.data["DC_CODE"] = "df = pd.DataFrame(data=[1, 2, 3])" sct_payload = helper.run(self.data)