Skip to content
11 changes: 4 additions & 7 deletions addons/analytic/models/analytic_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.
from odoo import models, fields, api, _
from odoo.fields import Domain
from odoo.tools import SQL, Query, unique
from odoo.tools.float_utils import float_round, float_compare
from odoo.exceptions import UserError, ValidationError
Expand Down Expand Up @@ -141,16 +142,12 @@ def _get_count_id(self, query):
raise ValueError(f"{query.table} does not support analytic_distribution grouping.")
return SQL(ids.get(query.table))

def mapped(self, func):
# Get the related analytic accounts as a recordset instead of the distribution
if func == 'analytic_distribution' and self.env.context.get('distribution_ids'):
return self.distribution_analytic_account_ids
return super().mapped(func)

def filtered_domain(self, domain):
# Filter based on the accounts used (i.e. allowing a name_search) instead of the distribution
# A domain on a binary field doesn't make sense anymore outside of set or not; and it is still doable.
return super(AnalyticMixin, self.with_context(distribution_ids=True)).filtered_domain(domain)
# Hack to filter using another field.
domain = Domain(domain).map_conditions(lambda cond: Domain('distribution_analytic_account_ids', cond.operator, cond.value) if cond.field_expr == 'analytic_distribution' else cond)
return super().filtered_domain(domain)

def write(self, vals):
""" Format the analytic_distribution float value, so equality on analytic_distribution can be done """
Expand Down
4 changes: 2 additions & 2 deletions addons/hr_holidays/models/hr_leave.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,10 @@ def _compute_dashboard_warning_message(self):
])
for holiday in self:
conflicting_holidays = all_leaves.filtered_domain([
('employee_id', '=', holiday.employee_id.id),
('employee_id', 'in', holiday.employee_id.ids),
('date_from', '<', holiday.date_to),
('date_to', '>', holiday.date_from),
('id', '!=', holiday.id),
('id', 'not in', holiday.ids),
])
if not conflicting_holidays:
holiday.dashboard_warning_message = False
Expand Down
3 changes: 1 addition & 2 deletions odoo/addons/base/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,8 +1190,7 @@ def test_invalid(self):
with self.assertRaisesRegex(ValueError, r"^Invalid operator.*\('create_date', '>>', 'foo'\)$"):
Country.search([('create_date', '>>', 'foo')])

# TODO make it "Invalid operator"" for consistency
with self.assertRaisesRegex(ValueError, r"^stray % in format '%'$"):
with self.assertRaisesRegex(ValueError, r"^Invalid operator"):
Country.search([]).filtered_domain([('create_date', '>>', 'foo')])

with self.assertRaisesRegex(ValueError, r"Invalid isoformat string"):
Expand Down
98 changes: 49 additions & 49 deletions odoo/addons/test_orm/tests/test_properties.py

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions odoo/addons/test_orm/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from odoo.addons.base.tests.test_expression import TransactionExpressionCase
from odoo.fields import Command
from odoo.tests import TransactionCase

Expand Down Expand Up @@ -1130,7 +1131,7 @@ def test_depends_with_table_query_model_sql(self):
self.assertEqual(self.env['test_orm.custom.table_query_sql'].search([]).sum_quantity, 25)


class TestDatePartNumber(TransactionCase):
class TestDatePartNumber(TransactionExpressionCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
Expand Down Expand Up @@ -1166,25 +1167,29 @@ def test_basic_cases(self):
result = Person.search([('birthday.iso_week_number', '=', '6')])
self.assertEqual(result, self.person)

def test_datetime_filtered(self):
Person = self.env["test_orm.person"].with_context(active_test=False)
self.assertEqual(self._search(Person, [('birthday.month_number', '=', 2)]), self.person)

def test_many2one(self):
result = self.env["test_orm.lesson"].search([('teacher_id.birthday.month_number', '=', 2)])
result = self._search(self.env["test_orm.lesson"], [('teacher_id.birthday.month_number', '=', 2)])
self.assertEqual(result, self.lesson)

def test_many2many(self):
result = self.env["test_orm.lesson"].search([('attendee_ids.birthday.month_number', '=', 2)])
result = self._search(self.env["test_orm.lesson"], [('attendee_ids.birthday.month_number', '=', 2)])
self.assertEqual(result, self.lesson)

def test_related_field(self):
result = self.env["test_orm.lesson"].search([('teacher_birthdate.month_number', '=', 2)])
result = self._search(self.env["test_orm.lesson"], [('teacher_birthdate.month_number', '=', 2)])
self.assertEqual(result, self.lesson)

def test_inherit(self):
account = self.env["test_orm.person.account"].create({"person_id": self.person.id, "activation_date": "2020-03-09"})

result = self.env["test_orm.person.account"].search([('activation_date.quarter_number', '=', 1)])
result = self._search(self.env["test_orm.person.account"], [('activation_date.quarter_number', '=', 1)])
self.assertEqual(result, account)

result = self.env["test_orm.person.account"].search([('person_id.birthday.month_number', '=', 2)])
result = self._search(self.env["test_orm.person.account"], [('person_id.birthday.month_number', '=', 2)])
self.assertEqual(result, account)


Expand Down
101 changes: 98 additions & 3 deletions odoo/orm/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,17 @@

from odoo.exceptions import UserError
from odoo.tools import SQL, OrderedSet, Query, classproperty, partition, str2bool

from .identifiers import NewId
from .utils import COLLECTION_TYPES
from .utils import COLLECTION_TYPES, parse_field_expr

if typing.TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterable
from odoo.fields import Field
from odoo.models import BaseModel

M = typing.TypeVar('M', bound=BaseModel)


_logger = logging.getLogger('odoo.domains')

Expand Down Expand Up @@ -364,6 +367,15 @@ def validate(self, model: BaseModel) -> None:
# just execute the optimization code that goes through all the fields
self.optimize(model, full=True)

def _as_predicate(self, records: M) -> Callable[[M], bool]:
"""Return a predicate function from the domain (bound to records).
The predicate function return whether its argument (a single record)
satisfies the domain.

This is used to implement ``Model.filtered_domain``.
"""
raise NotImplementedError

def optimize(self, model: BaseModel, *, full: bool = False) -> Domain:
"""Perform optimizations of the node given a model.

Expand Down Expand Up @@ -467,6 +479,9 @@ def __or__(self, other):
def __iter__(self):
yield _TRUE_LEAF if self.value else _FALSE_LEAF

def _as_predicate(self, records):
return lambda _: self.value

def _to_sql(self, model: BaseModel, alias: str, query: Query) -> SQL:
return SQL("TRUE") if self.value else SQL("FALSE")

Expand Down Expand Up @@ -512,6 +527,10 @@ def __eq__(self, other):
def __hash__(self):
return ~hash(self.child)

def _as_predicate(self, records):
predicate = self.child._as_predicate(records)
return lambda rec: not predicate(rec)

def _to_sql(self, model: BaseModel, alias: str, query: Query) -> SQL:
condition = self.child._to_sql(model, alias, query)
return SQL("(%s) IS NOT TRUE", condition)
Expand Down Expand Up @@ -631,6 +650,18 @@ def __and__(self, other):
return DomainAnd(self.children + other.children)
return super().__and__(other)

def _as_predicate(self, records):
# For the sake of performance, the list of predicates is generated
# lazily with a generator, which is memoized with `itertools.tee`
all_predicates = (child._as_predicate(records) for child in self.children)

def and_predicate(record):
nonlocal all_predicates
all_predicates, predicates = itertools.tee(all_predicates)
return all(pred(record) for pred in predicates)

return and_predicate


class DomainOr(DomainNary):
"""Domain: OR with multiple children"""
Expand All @@ -649,6 +680,18 @@ def __or__(self, other):
return DomainOr(self.children + other.children)
return super().__or__(other)

def _as_predicate(self, records):
# For the sake of performance, the list of predicates is generated
# lazily with a generator, which is memoized with `itertools.tee`
all_predicates = (child._as_predicate(records) for child in self.children)

def or_predicate(record):
nonlocal all_predicates
all_predicates, predicates = itertools.tee(all_predicates)
return any(pred(record) for pred in predicates)

return or_predicate


class DomainCondition(Domain):
"""Domain condition on field: (field, operator, value)
Expand Down Expand Up @@ -777,14 +820,14 @@ def _field(self, model: BaseModel) -> Field:

def __get_field(self, model: BaseModel) -> tuple[Field, str]:
"""Get the field or raise an exception"""
field_name, *props = self.field_expr.split('.', 1)
field_name, property_name = parse_field_expr(self.field_expr)
try:
field = model._fields[field_name]
except KeyError:
self._raise("Invalid field %s.%s", model._name, field_name)
# cache field value, with this hack to bypass immutability
object.__setattr__(self, '_field_instance', field)
return field, (props[0] if props else '')
return field, property_name or ''

def _optimize(self, model: BaseModel, full: bool) -> Domain:
"""Optimization step.
Expand Down Expand Up @@ -897,6 +940,56 @@ def _optimize_field_search_method(self, model: BaseModel) -> Domain:
model_label=f"{model.env['ir.model']._get(model._name).name!r} ({model._name})",
))

def _as_predicate(self, records):
if not records:
return lambda _: False

if self._opt_level < OptimizationLevel.BASIC:
return self.optimize(records, full=False)._as_predicate(records)

operator = self.operator
if operator in ('child_of', 'parent_of'):
# TODO have a specific implementation for these
return self.optimize(records, full=True)._as_predicate(records)

assert operator in STANDARD_CONDITION_OPERATORS, "Expecting a sub-set of operators"
field_expr, value = self.field_expr, self.value
positive_operator = NEGATIVE_CONDITION_OPERATORS.get(operator, operator)

if isinstance(value, SQL):
# transform into an Query value
if positive_operator == operator:
condition = self
operator = 'any'
else:
condition = ~self
operator = 'not any'
positive_operator = 'any'
field_expr = 'id'
value = records.with_context(active_test=False)._search(DomainCondition('id', 'in', OrderedSet(records.ids)) & condition)
assert isinstance(value, Query)

if isinstance(value, Query):
# rebuild a domain with an 'in' values
if positive_operator not in ('in', 'any'):
self._raise("Cannot filter using Query without the 'any' or 'in' operator")
if positive_operator == 'any':
operator = 'in' if positive_operator == operator else 'not in'
positive_operator = 'in'
value = set(value.get_result_ids())
return DomainCondition(field_expr, operator, value)._as_predicate(records)

field = self._field(records)
if field_expr == 'display_name':
# when searching by name, ignore AccessError
field_expr = 'display_name.no_error'
elif field_expr == 'id':
# for new records, compare to their origin
field_expr = 'id.origin'

func = field.filter_function(records, field_expr, positive_operator, value)
return func if positive_operator == operator else lambda rec: not func(rec)

def _to_sql(self, model: BaseModel, alias: str, query: Query) -> SQL:
return model._condition_to_sql(alias, self.field_expr, self.operator, self.value, query)

Expand Down Expand Up @@ -1104,6 +1197,8 @@ def _optimize_in_required(condition, model):
field.falsy_value is None
and field.required
and field in model.env.registry.not_null_fields
# only optimize if there are no NewId's
and all(model._ids)
):
value = OrderedSet(v for v in value if v is not False)
if len(value) == len(condition.value):
Expand Down
Loading