Skip to content

Commit

Permalink
Fix a regression introduced in cocotbgh-727 and cocotbgh-723, caused …
Browse files Browse the repository at this point in the history
…by a misunderstanding of how __new__ works

In those patches, I declared classes that were roughly
```python
class SomeClass:
    def __new__(cls, arg):
        try:
            return existing[arg]
        except KeyError:
            return super(SomeClass, cls).__new__(cls, arg)

    def __init__(self, arg):
        self.arg = arg
        self.state = 0
```

This approach has a fatal flaw (cocotbgh-729), with function calls shown in the following code:
```python
A = SomeClass(1)
# SomeClass.__new__(SomeClass, 1) -> A
# A.__init__(1)
B = SomeClass(1)
# SomeClass.__new__(SomeClass, 1) -> A   # reusing the existing instance
# A.__init__(1)  # uh oh, we reset A.state
```

We need to override class-construction without allowing `__init__` to run a second time.

This introduces a simple metaclass to allow overriding the behavior of `RisingEdge.__call__(...)`
  • Loading branch information
eric-wieser committed Dec 20, 2018
1 parent 435f7c9 commit c293c99
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
36 changes: 30 additions & 6 deletions cocotb/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@
from cocotb.utils import get_sim_steps, get_time_from_sim_steps


class _CallableMeta(type):
"""
A metaclass that allows classes to override the call classmethod
"""

def __call__(cls, *args, **kwargs):
return cls.__class_call__(*args, **kwargs)


# Equivalent to `class _ClassCallableBase(metaclass=_CacheableType): pass`, but works
# on python 2 too.
# TODO: use six.with_metaclass instead
_ClassCallableBase = _CallableMeta('_ClassCallableBase', (), {})
class _ClassCallableBase(_ClassCallableBase):
@classmethod
def __class_call__(cls, *args, **kwargs):
"""
Overrides `cls(...)`.
"""
return type.__call__(cls, *args, **kwargs)


class TriggerException(Exception):
pass

Expand Down Expand Up @@ -206,7 +228,7 @@ def NextTimeStep():
return _nxts


class _EdgeBase(GPITrigger):
class _EdgeBase(GPITrigger, _ClassCallableBase):
"""
Execution will resume when an edge occurs on the provided signal
"""
Expand All @@ -222,13 +244,14 @@ def _edge_type(self):
# Using a weak dictionary ensures we don't create a reference cycle
_instances = weakref.WeakValueDictionary()

def __new__(cls, signal):
@classmethod
def __class_call__(cls, signal):
# find the existing instance, if possible - else create a new one
key = (signal, cls._edge_type)
try:
return cls._instances[key]
except KeyError:
instance = super(_EdgeBase, cls).__new__(cls)
instance = super(_EdgeBase, cls).__class_call__(signal)
cls._instances[key] = instance
return instance

Expand Down Expand Up @@ -507,20 +530,21 @@ def prime(self, callback):
callback(self)


class Join(PythonTrigger):
class Join(PythonTrigger, _ClassCallableBase):
"""
Join a coroutine, firing when it exits
"""
# Ensure that each coroutine has at most one join trigger.
# Using a weak dictionary ensures we don't create a reference cycle
_instances = weakref.WeakValueDictionary()

def __new__(cls, coroutine):
@classmethod
def __class_call__(cls, coroutine):
# find the existing instance, if possible - else create a new one
try:
return cls._instances[coroutine]
except KeyError:
instance = super(Join, cls).__new__(cls)
instance = super(Join, cls).__class_call__(coroutine)
cls._instances[coroutine] = instance
return instance

Expand Down
18 changes: 18 additions & 0 deletions tests/test_cases/test_cocotb/test_cocotb.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,24 @@ def test_join_identity(dut):
clk_gen.kill()


@cocotb.test()
def test_edge_identity(dut):
"""
Test that Edge triggers returns the same object each time
"""

re = RisingEdge(dut.clk)
fe = FallingEdge(dut.clk)
e = Edge(dut.clk)

assert re is RisingEdge(dut.clk)
assert fe is FallingEdge(dut.clk)
assert e is Edge(dut.clk)

# check they are all unique
assert len({re, fe, e}) == 3
yield Timer(1)


if sys.version_info[:2] >= (3, 3):
# this would be a syntax error in older python, so we do the whole
Expand Down

0 comments on commit c293c99

Please sign in to comment.