In [None]:
#| default_exp verification

In [None]:
# Reload edited library modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
#| export mocking.functions
from dataclasses import dataclass
from typing import Any

In [None]:
from pymoq.mocking.functions import FunctionMock, remove_self_parameter

from fastcore.basics import patch_to
from fastcore.test import test_fail

# Verification
> Verify that the mock was called the way you intended.

## Proof of concept

In [None]:
from pymoq.mocking import objects
from typing import Protocol

In [None]:
# Short alias so the examples below can write `Mock(IWeb)`.
Mock = objects.Mock

In [None]:
class IWeb(Protocol):
    def get(self, a: int, b:str, c:float|None=None):
        ...

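Calls to the mock are recorded with the full argument list: the positional arguments (led by a `None` placeholder in the `self` position) and the keyword arguments, filled up with their defaults.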

In [None]:
m = Mock(IWeb)

m.get(1,"2")

# Raw call record as `(args, kwargs)` pairs; the leading `None` occupies the `self` position.
m.get._calls

[((None, 1, '2'), {'c': None})]

These recorded calls can then be matched against a signature validator:

##### Constructing a signature validator

In [None]:
from pymoq.argument_validators import AnyArg
from pymoq.signature_validators import signature_validator_from_arguments
from pymoq.mocking.functions import add_self_parameter

In [None]:
# One entry per parameter of `IWeb.get`: `AnyArg()` for `self` and `c`,
# and the plain values 1 and "2" for `a` and `b`.
sign_val = signature_validator_from_arguments(['self', 'a', 'b', 'c'], AnyArg(), 1, "2", c=AnyArg())

##### Matching against a recorded call

In [None]:
args, kwargs = m.get._calls[0]
kwargs = m.get.fill_up_arg_list(args, kwargs)

# Assert (not just print) so the notebook's test run fails if the recorded call stops matching.
is_match = sign_val.is_valid(*args, **kwargs)
assert is_match
is_match

True


## Implementation

In [None]:
#| export mocking.functions
@dataclass
class VerifiedCalls:
    """Outcome of `FunctionMock.verify`, with assertions on the number of matched calls.

    Each assertion method raises `AssertionError` when its check fails; the message
    lists the matched calls and all recorded calls to ease debugging.
    """
    # Recorded calls that matched the verification arguments, as `(args, kwargs)` pairs.
    verified_calls: list[tuple[tuple[Any, ...], dict[str, Any]]]
    # Every call recorded on the mock, as `(args, kwargs)` pairs (shown in failure messages).
    all_calls: list[tuple[tuple[Any, ...], dict[str, Any]]]

    @property
    def verified(self) -> int:
        """Number of matched calls."""
        return len(self.verified_calls)

    def times(self, amount: int):
        """Asserts that the number of verified calls is exactly `amount`"""
        self._check(self.verified == amount,
                    f"Expected {amount} calls, got {self.verified}.")

    def never(self):
        """Asserts that no verified call was made"""
        self.times(0)

    def more_than(self, lower_bound: int):
        """Asserts that more than `lower_bound` verified calls were made"""
        self._check(lower_bound < self.verified,
                    f"Expected more than {lower_bound} calls, got {self.verified}.")

    def less_than(self, upper_bound: int):
        """Asserts that less than `upper_bound` verified calls were made"""
        self._check(self.verified < upper_bound,
                    f"Expected less than {upper_bound} calls, got {self.verified}.")

    def more_than_or_equal_to(self, lower_bound: int):
        """Asserts that more than or equal to `lower_bound` verified calls were made"""
        self._check(lower_bound <= self.verified,
                    f"Expected at least {lower_bound} calls, got {self.verified}.")

    def less_than_or_equal_to(self, upper_bound: int):
        """Asserts that less than or equal to `upper_bound` verified calls were made"""
        self._check(self.verified <= upper_bound,
                    f"Expected at maximum {upper_bound} calls, got {self.verified}.")

    def _check(self, condition: bool, general_msg: str) -> None:
        """Raise `AssertionError` with a detailed message unless `condition` holds.

        Uses an explicit `raise` instead of the `assert` statement: Python strips
        `assert` under `-O`, which would make every verification silently pass.
        The detailed message is only built when the check actually fails.
        """
        if not condition:
            raise AssertionError(self._build_error_msg(general_msg))

    def _build_error_msg(self, general_msg: str) -> str:
        """Combine `general_msg` with listings of the matched and of all recorded calls."""
        calls_str = "Matched Calls:\n\t" + "\n\t".join(map(str, self.verified_calls))
        total_calls_str = "All Calls:\n\t" + "\n\t".join(map(str, self.all_calls))
        return "\n".join((general_msg, calls_str, total_calls_str))

In [None]:
#| export mocking.functions
@patch_to(FunctionMock)
def verify(self, *args, **kwargs) -> VerifiedCalls:
    """Collect the recorded calls of this mock that match the given arguments.

    `args`/`kwargs` describe the expected call for every parameter except `self`:
    plain values, types such as `int`, or argument validators such as `AnyArg()`.
    The `self` position is always matched with `AnyArg()`.

    Returns a `VerifiedCalls` holding the matching calls and a snapshot of all
    recorded calls, on which e.g. `.times(n)` or `.never()` can be asserted.
    """
    # NOTE(review): the exported `mocking.functions` module needs `patch_to`, `AnyArg`
    # and `signature_validator_from_arguments` in scope; this notebook's export cell
    # only adds `dataclass`/`Any` - confirm the module already imports the rest.
    kwargs = self.fill_up_arg_list(add_self_parameter(args), kwargs)
    # Prepend a wildcard for the `self` position so any recorded receiver matches.
    args = (AnyArg(),) + args
    sign_val = signature_validator_from_arguments(self._argument_names, *args, **kwargs)

    matched = []
    for call_args, call_kwargs in self._calls:
        call_kwargs = self.fill_up_arg_list(call_args, call_kwargs)
        if sign_val.is_valid(*call_args, **call_kwargs):
            matched.append((call_args, call_kwargs))
    # Copy the call log so calls made after `verify` don't retroactively change this result.
    return VerifiedCalls(matched, list(self._calls))

In [None]:
m = Mock(IWeb)

# Two calls with an integer `a` and one with a float `a`.
m.get(1,"2")
m.get(2,"2")
m.get(2.3,"2")

# Passing the type `int` for `a` matches only the two integer calls.
calls = m.get.verify(int, "2")
assert calls.verified == 2
assert calls.verified_calls == [((None, 1, '2'), {'c': None}), ((None, 2, '2'), {'c': None})]

calls

VerifiedCalls(verified_calls=[((None, 1, '2'), {'c': None}), ((None, 2, '2'), {'c': None})], all_calls=[((None, 1, '2'), {'c': None}), ((None, 2, '2'), {'c': None}), ((None, 2.3, '2'), {'c': None})])

In [None]:
# Every bound check below holds for the two matching calls recorded above.
m.get.verify(int, "2").times(2)
m.get.verify(int, "2").more_than(1)
m.get.verify(int, "2").more_than_or_equal_to(2)
m.get.verify(int, "2").less_than(3)
m.get.verify(int, "2").less_than_or_equal_to(2)
# No recorded call had a `str` first argument and an `int` second argument.
m.get.verify(str, int).never()

A failing verification raises an `AssertionError` with the following message:

In [None]:
# Catch only the verification failure; any other error should surface normally.
try:
    m.get.verify(int, "2").times(1)
except AssertionError as e:
    print(e)
else:
    raise AssertionError("Expected the verification above to fail.")

Expected 1 calls, got 2.
Matched Calls:
	((None, 1, '2'), {'c': None})
	((None, 2, '2'), {'c': None})
All Calls:
	((None, 1, '2'), {'c': None})
	((None, 2, '2'), {'c': None})
	((None, 2.3, '2'), {'c': None})


In [None]:
# Each check must fail with the verification message (which reports the 2 matched
# calls), not with some unrelated exception raised by a broken `verify`.
test_fail(lambda: m.get.verify(int, "2").times(1), contains="got 2.")
test_fail(lambda: m.get.verify(int, "2").never(), contains="got 2.")
test_fail(lambda: m.get.verify(int, "2").more_than(3), contains="got 2.")
test_fail(lambda: m.get.verify(int, "2").more_than_or_equal_to(3), contains="got 2.")
test_fail(lambda: m.get.verify(int, "2").less_than(1), contains="got 2.")
test_fail(lambda: m.get.verify(int, "2").less_than_or_equal_to(1), contains="got 2.")

# Build library

In [None]:
#| hide
# Regenerate the library modules from the notebooks' exported cells.
import nbdev
nbdev.nbdev_export()