Skip to content

Commit

Permalink
[IMP] api: new decorators preupdate and postupdate
Browse files Browse the repository at this point in the history
  • Loading branch information
atp-odoo authored and rco-odoo committed Apr 17, 2019
1 parent defbb16 commit 3d2d6f9
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 42 deletions.
1 change: 1 addition & 0 deletions odoo/addons/test_new_api/ir.model.access.csv
Expand Up @@ -31,3 +31,4 @@ access_test_new_api_field_with_caps,access_test_new_api_field_with_caps,model_te
access_test_new_api_req_m2o,access_test_new_api_req_m2o,model_test_new_api_req_m2o,,1,1,1,1
access_test_new_api_attachment,access_test_new_api_attachment,model_test_new_api_attachment,,1,1,1,1
access_test_new_api_attachment_host,access_test_new_api_attachment_host,model_test_new_api_attachment_host,,1,1,1,1
access_test_new_api_update,access_test_new_api_update,model_test_new_api_update,,1,1,1,1
21 changes: 21 additions & 0 deletions odoo/addons/test_new_api/models.py
Expand Up @@ -559,3 +559,24 @@ class AttachmentHost(models.Model):
'test_new_api.attachment', 'res_id', auto_join=True,
domain=lambda self: [('res_model', '=', self._name)],
)


class TestAPIUpdate(models.Model):
_name = 'test_new_api.update'
_description = 'test api preupdate and postupdate'

name = fields.Char()
sent = fields.Boolean()
count = fields.Integer()

@api.preupdate('count')
def _preupdate_count(self, vals):
if vals['count'] == 0:
vals['name'] = 'Duck'
elif vals['count'] == 100:
vals['name'] = 'Century'

@api.postupdate('count')
def _postupdate_count(self, vals):
if vals['count'] == 100:
self.write({'sent': True})
30 changes: 30 additions & 0 deletions odoo/addons/test_new_api/tests/test_new_fields.py
Expand Up @@ -1523,3 +1523,33 @@ def test_explicit_set_null(self):

with self.assertRaises(ValueError):
field._setup_regular_base(Model)


class TestAPIUpdate(common.TransactionCase):

def test_api_update(self):
# create with 0, then write 42, then write 100
record = self.env['test_new_api.update'].create({'name': 'test', 'count': 0})
self.assertEqual(record.name, 'Duck', "preupdate not executed on create")
self.assertEqual(record.sent, False, "postupdate should do nothing")

record.write({'count': 42})
self.assertEqual(record.name, 'Duck', "preupdate should do nothing")
self.assertEqual(record.sent, False, "postupdate should do nothing")

record.write({'count': 100})
self.assertEqual(record.name, 'Century', "preupdate not executed on write")
self.assertEqual(record.sent, True, "postupdate not executed on write")

# create with 100, then write 42, then write 0
record = self.env['test_new_api.update'].create({'name': 'test', 'count': 100})
self.assertEqual(record.name, 'Century', "preupdate not executed on create")
self.assertEqual(record.sent, True, "postupdate not executed on create")

record.write({'count': 42})
self.assertEqual(record.name, 'Century', "preupdate should do nothing")
self.assertEqual(record.sent, True, "postupdate should do nothing")

record.write({'count': 0})
self.assertEqual(record.name, 'Duck', "preupdate not executed on write")
self.assertEqual(record.sent, True, "postupdate should do nothing")
48 changes: 47 additions & 1 deletion odoo/api.py
Expand Up @@ -42,7 +42,7 @@
'cr_uid_ids', 'cr_uid_ids_context',
'cr_uid_records', 'cr_uid_records_context',
'constrains', 'depends', 'onchange', 'returns',
'call_kw',
'call_kw', 'preupdate', 'postupdate',
]

import logging
Expand All @@ -66,6 +66,8 @@
# - method._returns: set by @returns, specifies return model
# - method._onchange: set by @onchange, specifies onchange fields
# - method.clear_cache: set by @ormcache, used to clear the cache
# - method._preupdate: set by @preupdate, specifies preupdate fields
# - method._postupdate: set by @postupdate, specifies postupdate fields
#
# On wrapping method only:
# - method._api: decorator function, used for re-applying decorator
Expand Down Expand Up @@ -199,6 +201,50 @@ def _onchange_partner(self):
return attrsetter('_onchange', args)


def preupdate(*args):
""" Return a decorator for a method to run before creating or modifying
records. The preupdate method can modify the dictionary of field values
before the actual creation/update. All arguments must be field names::
@preupdate('foo', 'bar')
def _preupdate_name(self, vals):
vals['baz'] = vals.get('foo', "") + vals.get('bar', "")
The method is invoked if some of the given field names are part of the
update. If no argument is given, the method is always invoked. The
method is invoked on an empty ``self`` upon create, and with the
modified records upon write.
.. warning::
``@preupdate`` only supports simple field names, dotted names (like
'partner_id.customer') are not supported and will be ignored.
"""
return attrsetter('_preupdate', args)


def postupdate(*args):
""" Return a decorator for a method to run after creating or modifying
records. All arguments must be field names::
@api.postupdate('parent_id')
def _postupdate_parent_id(self, vals):
for record in self:
if record.parent_id.customer:
record.customer = True
The method is invoked if some of the given field names are part of the
update. If no argument is given, the method is always invoked. The
method is invoked on the created/modified records.
.. warning::
``@postupdate`` only supports simple field names, dotted names (like
'partner_id.customer') are not supported and will be ignored.
"""
return attrsetter('_postupdate', args)


def depends(*args):
""" Return a decorator that specifies the field dependencies of a "compute"
method (for new-style function fields). Each argument must be a string
Expand Down
121 changes: 80 additions & 41 deletions odoo/models.py
Expand Up @@ -557,50 +557,45 @@ def _build_model_attributes(cls, pool):
child_class._build_model_attributes(pool)

@classmethod
def _init_constraints_onchanges(cls):
def _init_api_methods(cls):
# store sql constraint error messages
for (key, _, msg) in cls._sql_constraints:
cls.pool._sql_error[cls._table + '_' + key] = msg

# reset properties memoized on cls
cls._constraint_methods = BaseModel._constraint_methods
cls._onchange_methods = BaseModel._onchange_methods
cls._preupdate_methods = BaseModel._preupdate_methods
cls._postupdate_methods = BaseModel._postupdate_methods

@property
def _constraint_methods(self):
""" Return a list of methods implementing Python constraints. """
def is_constraint(func):
return callable(func) and hasattr(func, '_constrains')

cls = type(self)
methods = []
for attr, func in getmembers(cls, is_constraint):
for name in func._constrains:
field = cls._fields.get(name)
if not field:
_logger.warning("method %s.%s: @constrains parameter %r is not a field name", cls._name, attr, name)
elif not (field.store or field.inverse or field.inherited):
_logger.warning("method %s.%s: @constrains parameter %r is not writeable", cls._name, attr, name)
methods.append(func)

# optimization: memoize result on cls, it will not be recomputed
cls._constraint_methods = methods
return methods

@property
def _onchange_methods(self):
""" Return a dictionary mapping field names to onchange methods. """
def is_onchange(func):
return callable(func) and hasattr(func, '_onchange')

# collect onchange methods on the model's class
cls = type(self)
methods = defaultdict(list)
for attr, func in getmembers(cls, is_onchange):
for name in func._onchange:
@classmethod
def _retrieve_api_methods(cls):
# memoize the results on cls, in order to compute them once
cls._constraint_methods = []
cls._onchange_methods = defaultdict(list)
cls._preupdate_methods = []
cls._postupdate_methods = []

def check_field_names(func, decorator):
for name in getattr(func, '_' + decorator):
if name not in cls._fields:
_logger.warning("@onchange%r parameters must be field names", func._onchange)
methods[name].append(func)
_logger.warning("%s.%s: @%s parameter %r is not a field name",
cls._name, func.__name__, decorator, name)

for attr, func in getmembers(cls, callable):
if hasattr(func, '_constrains'):
check_field_names(func, 'constrains')
cls._constraint_methods.append(func)
if hasattr(func, '_onchange'):
check_field_names(func, 'onchange')
for name in func._onchange:
cls._onchange_methods[name].append(func)
if hasattr(func, '_preupdate'):
check_field_names(func, 'preupdate')
cls._preupdate_methods.append(func)
if hasattr(func, '_postupdate'):
check_field_names(func, 'postupdate')
cls._postupdate_methods.append(func)

# add onchange methods to implement "change_default" on fields
def onchange_default(field, self):
Expand All @@ -611,11 +606,32 @@ def onchange_default(field, self):

for name, field in cls._fields.items():
if field.change_default:
methods[name].append(functools.partial(onchange_default, field))
func = functools.partial(onchange_default, field)
cls._onchange_methods[name].append(func)

# optimization: memoize result on cls, it will not be recomputed
cls._onchange_methods = methods
return methods
@property
def _constraint_methods(self):
""" Return a list of methods implementing Python constraints. """
self._retrieve_api_methods()
return self._constraint_methods

@property
def _onchange_methods(self):
""" Return a dictionary mapping field names to onchange methods. """
self._retrieve_api_methods()
return self._onchange_methods

@property
def _preupdate_methods(self):
""" Return a list of preupdate methods. """
self._retrieve_api_methods()
return self._preupdate_methods

@property
def _postupdate_methods(self):
""" Return a list of postupdate methods. """
self._retrieve_api_methods()
return self._postupdate_methods

def __new__(cls):
# In the past, this method was registering the model class in the server.
Expand Down Expand Up @@ -1071,6 +1087,22 @@ def _log(base, record, field, exception):

yield dbid, xid, converted, dict(extras, record=stream.index)

def _process_preupdates(self, vals):
""" Run the preupdate methods on ``vals``. """
field_names = set(vals)
for func in self._preupdate_methods:
names = func._preupdate
if not names or any(name in field_names for name in names):
func(self, vals)

def _process_postupdates(self, vals):
""" Run the postupdate methods on ``vals``. """
field_names = set(vals)
for func in self._postupdate_methods:
names = func._postupdate
if not names or any(name in field_names for name in names):
func(self, vals)

@api.multi
def _validate_fields(self, field_names):
field_names = set(field_names)
Expand Down Expand Up @@ -2653,8 +2685,8 @@ def _setup_complete(self):
with tools.ignore(*exceptions):
field.setup_triggers(self)

# register constraints and onchange methods
cls._init_constraints_onchanges()
# register constraints, onchange, preupdate and postupdate methods
cls._init_api_methods()

# validate rec_name
if cls._rec_name:
Expand Down Expand Up @@ -3307,6 +3339,8 @@ def write(self, vals):
if not(self.env.uid == SUPERUSER_ID and not self.pool.ready):
bad_names.update(LOG_ACCESS_COLUMNS)

self._process_preupdates(vals)

# distribute fields into sets for various purposes
store_vals = {}
inverse_vals = {}
Expand Down Expand Up @@ -3396,6 +3430,8 @@ def write(self, vals):
# check Python constraints for inversed fields
self._validate_fields(set(inverse_vals) - set(store_vals))

self._process_postupdates(store_vals)

# recompute fields
if self.env.recompute and self._context.get('recompute', True):
self.recompute()
Expand Down Expand Up @@ -3552,6 +3588,8 @@ def create(self, vals_list):
# add missing defaults
vals = self._add_missing_default_values(vals)

self._process_preupdates(vals)

# distribute fields into sets for various purposes
data = {}
data['stored'] = stored = {}
Expand Down Expand Up @@ -3643,6 +3681,7 @@ def create(self, vals_list):
# check Python constraints for non-stored inversed fields
for data in data_list:
data['record']._validate_fields(set(data['inversed']) - set(data['stored']))
data['record']._process_postupdates(data['stored'])

# recompute fields
if self.env.recompute and self._context.get('recompute', True):
Expand Down

0 comments on commit 3d2d6f9

Please sign in to comment.