diff --git a/fault/actions.py b/fault/actions.py index 6c81840d..9418d80a 100644 --- a/fault/actions.py +++ b/fault/actions.py @@ -2,31 +2,44 @@ class Action(ABC): - pass + @abstractmethod + def retarget(self, new_circuit, clock): + """ + Create a copy of the action for `new_circuit` with `clock` + """ + raise NotImplementedError() -class Poke(Action): +class PortAction(Action): def __init__(self, port, value): super().__init__() - if port.isinput(): - raise ValueError(f"Can only poke inputs: {port} {type(port)}") self.port = port self.value = value def __str__(self): - return f"Poke({self.port.debug_name}, {self.value})" + type_name = type(self).__name__ + return f"{type_name}({self.port.debug_name}, {self.value})" + + def retarget(self, new_circuit, clock): + cls = type(self) + new_port = new_circuit.interface.ports[str(self.port.name)] + return cls(new_port, self.value) -class Expect(Action): +class Poke(PortAction): def __init__(self, port, value): - super().__init__() - if port.isoutput(): - raise ValueError(f"Can only expect on outputs: {port} {type(port)}") - self.port = port - self.value = value + if port.isinput(): + raise ValueError(f"Can only poke inputs: {port.debug_name} " + f"{type(port)}") + super().__init__(port, value) - def __str__(self): - return f"Expect({self.port.debug_name}, {self.value})" + +class Expect(PortAction): + def __init__(self, port, value): + if port.isoutput(): + raise ValueError(f"Can only expect on outputs: {port.debug_name} " + f"{type(port)}") + super().__init__(port, value) class Eval(Action): @@ -34,14 +47,21 @@ def __init__(self): super().__init__() def __str__(self): - return f"Eval()" + return "Eval()" + + def retarget(self, new_circuit, clock): + return Eval() class Step(Action): def __init__(self, clock, steps): super().__init__() + # TODO(rsetaluri): Check if `clock` is a clock type? self.clock = clock self.steps = steps def __str__(self): return f"Step({self.clock.debug_name}, steps={self.steps})" + + def retarget(self, new_circuit, clock): + return Step(clock, self.steps) diff --git a/fault/circuit_utils.py b/fault/circuit_utils.py new file mode 100644 index 00000000..d66c0488 --- /dev/null +++ b/fault/circuit_utils.py @@ -0,0 +1,18 @@ +def check_interface_is_subset(circuit1, circuit2): + """ + Checks that the interface of circuit1 is a subset of circuit2 + + Subset is defined as circuit2 contains all the ports of circuit1. Ports are + matched by name comparison, then the types are checked to see if one could + be converted to another. + """ + circuit1_port_names = circuit1.interface.ports.keys() + for name in circuit1_port_names: + if name not in circuit2.interface.ports: + raise ValueError(f"{circuit2} (circuit2) does not have port {name}") + circuit1_kind = type(type(getattr(circuit1, name))) + circuit2_kind = type(type(getattr(circuit2, name))) + # Check that the type of one could be converted to the other + if not (issubclass(circuit2_kind, circuit1_kind) or + issubclass(circuit1_kind, circuit2_kind)): + raise ValueError("Types don't match") diff --git a/fault/tester.py b/fault/tester.py index 34cbbd52..010391d7 100644 --- a/fault/tester.py +++ b/fault/tester.py @@ -5,6 +5,9 @@ from fault.vector_builder import VectorBuilder from fault.value_utils import make_value from fault.verilator_target import VerilatorTarget +from fault.actions import Poke, Expect, Step +from fault.circuit_utils import check_interface_is_subset +import copy class Tester: @@ -20,11 +23,13 @@ def make_target(self, target, **kwargs): return VerilatorTarget(self.circuit, self.actions, **kwargs) if target == "coreir": return MagmaSimulatorTarget(self.circuit, self.actions, - backend='coreir', **kwargs) + clock=self.clock, backend='coreir', + **kwargs) if target == "python": warning("Python simulator is not actively supported") return MagmaSimulatorTarget(self.circuit, self.actions, - backend='python', **kwargs) + clock=self.clock, backend='python', + **kwargs) raise NotImplementedError(target) def poke(self, port, value): @@ -53,3 +58,18 @@ def serialize(self): def compile_and_run(self, target="verilator", **kwargs): target_inst = self.make_target(target, **kwargs) target_inst.run() + + def retarget(self, new_circuit, clock=None): + """ + Generates a new instance of the Tester object that targets + `new_circuit`. This allows you to copy a set of actions for a new + circuit with the same interface (or an interface that is a super set of + self.circuit) + """ + # Check that the interface of self.circuit is a subset of new_circuit + check_interface_is_subset(self.circuit, new_circuit) + + new_tester = Tester(new_circuit, clock) + new_tester.actions = [action.retarget(new_circuit, clock) for action in + self.actions] + return new_tester diff --git a/tests/common.py b/tests/common.py index b5ec9356..f3455484 100644 --- a/tests/common.py +++ b/tests/common.py @@ -22,3 +22,5 @@ def definition(io): TestNestedArraysCircuit = define_simple_circuit(m.Array(3, m.Bits(4)), "NestedArraysCircuit") TestBasicClkCircuit = define_simple_circuit(m.Bit, "BasicClkCircuit", True) +TestBasicClkCircuitCopy = define_simple_circuit(m.Bit, "BasicClkCircuitCopy", + True) diff --git a/tests/test_tester.py b/tests/test_tester.py index 0f370948..bff13794 100644 --- a/tests/test_tester.py +++ b/tests/test_tester.py @@ -46,3 +46,32 @@ def test_tester_nested_arrays(): expected.append(Expect(circ.O[i], val)) for i, exp in enumerate(expected): check(tester.actions[i], exp) + + +def test_copy_tester(): + circ = common.TestBasicClkCircuit + expected = [ + Poke(circ.I, 0), + Expect(circ.O, 0), + Poke(circ.CLK, 0), + Step(circ.CLK, 1) + ] + tester = fault.Tester(circ, circ.CLK) + tester.poke(circ.I, 0) + tester.expect(circ.O, 0) + tester.poke(circ.CLK, 0) + tester.step() + print(tester.actions) + for i, exp in enumerate(expected): + check(tester.actions[i], exp) + + circ_copy = common.TestBasicClkCircuitCopy + copy = tester.retarget(circ_copy, circ_copy.CLK) + copy_expected = [ + Poke(circ_copy.I, 0), + Expect(circ_copy.O, 0), + Poke(circ_copy.CLK, 0), + Step(circ_copy.CLK, 1) + ] + for i, exp in enumerate(copy_expected): + check(copy.actions[i], exp)