Skip to content

Commit

Permalink
[FIX] core: cursor hooks API and implementation
Browse files Browse the repository at this point in the history
Python 3.8 changed the equality rules for bound methods to be based on
the *identity* of the receiver (`__self__`) rather than its *equality*.
This means that in 3.7, methods from different instances will compare
(and hash) equal, thereby landing in the same map "slot", but that isn't
the case in 3.8.

While it's usually not relevant, it's an issue for `GroupCalls` which is
indexed by a function: in 3.7, that being a method from recordsets
comparing equal will deduplicate them, but not anymore in 3.8, leading
to duplicated callbacks (exactly the thing GroupCalls aims to avoid).

Also, the API of `GroupCalls` turned out to be unusual and weird.  The
bug above is fixed by using a plain list for callbacks, thereby avoiding
comparisons between registered functions.  The API is now:

    callbacks.add(func)     # add func to callbacks
    callbacks.run()         # run all callbacks in addition order
    callbacks.clear()       # remove all callbacks

In order to handle aggregated data, the `callbacks` object provides a
dictionary `callbacks.data` that any callback function can freely use.
For the sake of consistency, the `callbacks.data` dict is automatically
cleared upon execution of callbacks.

Discovered by @william-andre

Related to odoo#56583

References:

* https://bugs.python.org/issue1617161
* python/cpython#7848
* https://docs.python.org/3/whatsnew/changelog.html#python-3-8-0-alpha-1
  (no direct link because individual entries are not linkable, look for
  bpo-1617161)

X-original-commit: a3a4d14
  • Loading branch information
rco-odoo committed Sep 2, 2020
1 parent 612a8f7 commit 68058da
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 93 deletions.
11 changes: 6 additions & 5 deletions addons/mail/models/mail_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ def _prepare_tracking(self, fields):
fnames = self._get_tracked_fields().intersection(fields)
if not fnames:
return
func = self.browse()._finalize_tracking
[initial_values] = self.env.cr.precommit.add(func, dict)
self.env.cr.precommit.add(self._finalize_tracking)
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
for record in self:
if not record.id:
continue
Expand All @@ -433,16 +433,17 @@ def _discard_tracking(self):
""" Prevent any tracking of fields on ``self``. """
if not self._get_tracked_fields():
return
func = self.browse()._finalize_tracking
[initial_values] = self.env.cr.precommit.add(func, dict)
self.env.cr.precommit.add(self._finalize_tracking)
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
# disable tracking by setting initial values to None
for id_ in self.ids:
initial_values[id_] = None

def _finalize_tracking(self, initial_values):
def _finalize_tracking(self):
""" Generate the tracking messages for the records that have been
prepared with ``_prepare_tracking``.
"""
initial_values = self.env.cr.precommit.data.pop(f'mail.tracking.{self._name}', {})
ids = [id_ for id_, vals in initial_values.items() if vals]
if not ids:
return
Expand Down
2 changes: 1 addition & 1 deletion addons/mail/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _reset_mail_context(cls, record):
def flush_tracking(self):
""" Force the creation of tracking values. """
self.env['base'].flush()
self.cr.precommit()
self.cr.precommit.run()

# ------------------------------------------------------------
# MAIL MOCKS
Expand Down
56 changes: 37 additions & 19 deletions odoo/addons/base/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,36 +272,54 @@ def test_01_code_and_format(self):
self.assertEqual(misc.format_time(lang.with_context(lang='zh_CN').env, time_part, time_format='medium', lang_code='fr_FR'), '16:30:22')


class TestGroupCalls(BaseCase):
def test_callbacks(self):
class TestCallbacks(BaseCase):
def test_callback(self):
log = []
callbacks = misc.Callbacks()

# add foo
def foo():
log.append("foo")

def bar(items):
log.extend(items)
callbacks.add(baz)
callbacks.add(foo)

def baz():
log.append("baz")
# add bar
@callbacks.add
def bar():
log.append("bar")

callbacks = misc.GroupCalls()
# add foo again
callbacks.add(foo)
callbacks.add(bar, list)[0].append(1)
callbacks.add(bar, list)[0].append(2)
self.assertEqual(log, [])

callbacks()
self.assertEqual(log, ["foo", 1, 2, "baz"])
# this should call foo(), bar(), foo()
callbacks.run()
self.assertEqual(log, ["foo", "bar", "foo"])

# this should do nothing
callbacks.run()
self.assertEqual(log, ["foo", "bar", "foo"])

def test_aggregate(self):
log = []
callbacks = misc.Callbacks()

# register foo once
@callbacks.add
def foo():
log.append(callbacks.data["foo"])

# aggregate data
callbacks.data.setdefault("foo", []).append(1)
callbacks.data.setdefault("foo", []).append(2)
callbacks.data.setdefault("foo", []).append(3)

callbacks()
self.assertEqual(log, ["foo", 1, 2, "baz"])
# foo() is called once
callbacks.run()
self.assertEqual(log, [[1, 2, 3]])
self.assertFalse(callbacks.data)

callbacks.add(bar, list)[0].append(3)
callbacks.clear()
callbacks()
self.assertEqual(log, ["foo", 1, 2, "baz"])
callbacks.run()
self.assertEqual(log, [[1, 2, 3]])


class TestRemoveAccents(BaseCase):
Expand Down
2 changes: 1 addition & 1 deletion odoo/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def checked_call(___dbname, *a, **kw):
# flush here to avoid triggering a serialization error outside
# of this context, which would not retry the call
flush_env(self._cr)
self._cr.precommit()
self._cr.precommit.run()
return result

if self.db:
Expand Down
34 changes: 10 additions & 24 deletions odoo/sql_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,11 @@ def check(f, self, *args, **kwargs):


class BaseCursor:
""" Base class for cursors that manages pre/post commit/rollback hooks. """
""" Base class for cursors that manage pre/post commit hooks. """

def __init__(self):
self.precommit = tools.GroupCalls()
self.postcommit = tools.GroupCalls()
self.prerollback = tools.GroupCalls()
self.postrollback = tools.GroupCalls()
self.precommit = tools.Callbacks()
self.postcommit = tools.Callbacks()

@contextmanager
@check
Expand All @@ -111,20 +109,17 @@ def savepoint(self, flush=True):
name = uuid.uuid1().hex
if flush:
flush_env(self, clear=False)
self.precommit()
self.prerollback.clear()
self.precommit.run()
self.execute('SAVEPOINT "%s"' % name)
try:
yield
if flush:
flush_env(self, clear=False)
self.precommit()
self.prerollback.clear()
self.precommit.run()
except Exception:
if flush:
clear_env(self)
self.precommit.clear()
self.prerollback()
self.execute('ROLLBACK TO SAVEPOINT "%s"' % name)
raise
else:
Expand Down Expand Up @@ -428,17 +423,15 @@ def after(self, event, func):
if event == 'commit':
self.postcommit.add(func)
elif event == 'rollback':
self.postrollback.add(func)
raise NotImplementedError()

@check
def commit(self):
""" Perform an SQL `COMMIT` """
flush_env(self)
self.precommit()
self.precommit.run()
result = self._cnx.commit()
self.prerollback.clear()
self.postrollback.clear()
self.postcommit()
self.postcommit.run()
return result

@check
Expand All @@ -447,9 +440,7 @@ def rollback(self):
clear_env(self)
self.precommit.clear()
self.postcommit.clear()
self.prerollback()
result = self._cnx.rollback()
self.postrollback()
return result

@check
Expand Down Expand Up @@ -506,23 +497,18 @@ def autocommit(self, on):
def commit(self):
""" Perform an SQL `COMMIT` """
flush_env(self)
self.precommit()
self.precommit.run()
self._cursor.execute('SAVEPOINT "%s"' % self._savepoint)
self.prerollback.clear()
# ignore post-commit/rollback hooks
# ignore post-commit hooks
self.postcommit.clear()
self.postrollback.clear()

@check
def rollback(self):
""" Perform an SQL `ROLLBACK` """
clear_env(self)
self.precommit.clear()
self.prerollback()
self._cursor.execute('ROLLBACK TO SAVEPOINT "%s"' % self._savepoint)
# ignore post-commit/rollback hooks
self.postcommit.clear()
self.postrollback.clear()

def __getattr__(self, name):
value = getattr(self._cursor, name)
Expand Down
8 changes: 4 additions & 4 deletions odoo/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,12 @@ def assertQueryCount(self, default=0, flush=True, **counters):
expected = counters.get(login, default)
if flush:
self.env.user.flush()
self.env.cr.precommit()
self.env.cr.precommit.run()
count0 = self.cr.sql_log_count
yield
if flush:
self.env.user.flush()
self.env.cr.precommit()
self.env.cr.precommit.run()
count = self.cr.sql_log_count - count0
if count != expected:
# add some info on caller to allow semi-automatic update of query count
Expand All @@ -455,11 +455,11 @@ def assertQueryCount(self, default=0, flush=True, **counters):
# same operations, otherwise the caches might not be ready!
if flush:
self.env.user.flush()
self.env.cr.precommit()
self.env.cr.precommit.run()
yield
if flush:
self.env.user.flush()
self.env.cr.precommit()
self.env.cr.precommit.run()

def assertRecordValues(self, records, expected_values):
''' Compare a recordset with a list of dictionaries representing the expected results.
Expand Down
96 changes: 57 additions & 39 deletions odoo/tools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,54 +1084,72 @@ def add(self, elem):
OrderedSet.add(self, elem)


class GroupCalls:
""" A collection of callbacks with support for aggregated arguments. Upon
call, every registered function is called once with positional arguments.
When registering a function, a tuple of positional arguments is returned, so
that the caller can modify the arguments in place. This allows to
accumulate some data to process once::
class Callbacks:
""" A simple queue of callback functions. Upon run, every function is
called (in addition order), and the queue is emptied.
callbacks = GroupCalls()
callbacks = Callbacks()
# register print (by default with a list)
[args] = callbacks.register(print, list)
args.append(42)
# add foo
def foo():
print("foo")
# add an element to the list to print
[args] = callbacks.register(print, list)
args.append(43)
callbacks.add(foo)
# print "[42, 43]"
callbacks()
# add bar
callbacks.add
def bar():
print("bar")
# add foo again
callbacks.add(foo)
# call foo(), bar(), foo(), then clear the callback queue
callbacks.run()
The queue also provides a ``data`` dictionary, that may be freely used to
store anything, but is mostly aimed at aggregating data for callbacks. The
dictionary is automatically cleared by ``run()`` once all callback functions
have been called.
# register foo to process aggregated data
@callbacks.add
def foo():
print(sum(callbacks.data['foo']))
callbacks.data.setdefault('foo', []).append(1)
...
callbacks.data.setdefault('foo', []).append(2)
...
callbacks.data.setdefault('foo', []).append(3)
# call foo(), which prints 6
callbacks.run()
Given the global nature of ``data``, the keys should identify in a unique
way the data being stored. It is recommended to use strings with a
structure like ``"{module}.{feature}"``.
"""
__slots__ = ['_funcs', 'data']

def __init__(self):
self._func_args = {} # {func: args}
self._funcs = []
self.data = {}

def __call__(self):
""" Call all the registered functions (in first addition order) with
their respective arguments. Only recurrent functions remain registered
after the call.
"""
func_args = self._func_args
while func_args:
func = next(iter(func_args))
args = func_args.pop(func)
func(*args)

def add(self, func, *types):
""" Register the given function, and return the tuple of positional
arguments to call the function with. If the function is not registered
yet, the list of arguments is made up by invoking the given types.
"""
try:
return self._func_args[func]
except KeyError:
args = self._func_args[func] = [type_() for type_ in types]
return args
def add(self, func):
""" Add the given function. """
self._funcs.append(func)

def run(self):
""" Call all the functions (in addition order), then clear. """
for func in self._funcs:
func()
self.clear()

def clear(self):
""" Remove all callbacks from self. """
self._func_args.clear()
""" Remove all callbacks and data from self. """
self._funcs.clear()
self.data.clear()


class IterableGenerator:
Expand Down

0 comments on commit 68058da

Please sign in to comment.