diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 467e79f9b9f7..b8648d76cdb6 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -1,6 +1,7 @@ import copy import inspect -from functools import wraps +import types +from functools import partial, wraps from importlib import import_module from django.db import router @@ -90,7 +91,8 @@ def manager_method(self, *args, **kwargs): new_methods = {} for name, method in inspect.getmembers( - queryset_class, predicate=inspect.isfunction + queryset_class, + predicate=lambda member: isinstance(member, (types.FunctionType, partial)), ): # Only copy missing methods. if hasattr(cls, name): diff --git a/django/db/models/query.py b/django/db/models/query.py index 73d6717bce23..6a26a1f41e17 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -29,7 +29,7 @@ from django.db.models.expressions import Case, DatabaseDefault, F, Value, When from django.db.models.fetch_modes import FETCH_ONE from django.db.models.functions import Cast, Trunc -from django.db.models.query_utils import FilteredRelation, Q +from django.db.models.query_utils import FilteredRelation, Q, class_or_instance_method from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, ROW_COUNT from django.db.models.utils import ( AltersData, @@ -300,6 +300,8 @@ def __iter__(self): class QuerySet(AltersData): """Represent a lazy database lookup for a set of objects.""" + _initial_filter = None + def __init__(self, model=None, query=None, using=None, hints=None): self.model = model self._db = using @@ -323,6 +325,8 @@ def query(self): negate, args, kwargs = self._deferred_filter self._filter_or_exclude_inplace(negate, args, kwargs) self._deferred_filter = None + if self._initial_filter is not None: + self._query.add_q(self._initial_filter) return self._query @query.setter @@ -1618,7 +1622,26 @@ def all(self): """ return self._chain() - def filter(self, *args, **kwargs): + def _class_filter(cls, *args, **kwargs): + if invalid_kwargs := PROHIBITED_FILTER_KWARGS.intersection(kwargs): + invalid_kwargs_str = ", ".join(f"'{k}'" for k in sorted(invalid_kwargs)) + raise TypeError(f"The following kwargs are invalid: {invalid_kwargs_str}") + initial_filter = Q(*args, **kwargs) + # Chain initial filters. + if cls._initial_filter is not None: + initial_filter = cls._initial_filter & initial_filter + bases = cls.__bases__ + class_name = cls.__name__ + else: + bases = (cls,) + class_name = f"{cls.__name__}WithFilter" + return type( + class_name, + bases, + {"_initial_filter": initial_filter}, + ) + + def _instance_filter(self, *args, **kwargs): """ Return a new QuerySet instance with the args ANDed to the existing set. @@ -1626,6 +1649,9 @@ def filter(self, *args, **kwargs): self._not_support_combined_queries("filter") return self._filter_or_exclude(False, args, kwargs) + filter = class_or_instance_method(_class_filter, _instance_filter) + _class_filter = classmethod(_class_filter) + def exclude(self, *args, **kwargs): """ Return a new QuerySet instance with NOT (args) ANDed to the existing diff --git a/docs/releases/6.1.txt b/docs/releases/6.1.txt index 1c533f1341fe..358a03c38fd1 100644 --- a/docs/releases/6.1.txt +++ b/docs/releases/6.1.txt @@ -263,6 +263,32 @@ Models :ref:`negative array indexing ` on Oracle 21c+. +* The :class:`.QuerySet` class now allows defining initial filters. For + example:: + + from django.db import models + + + class Book(models.Model): + published = models.BooleanField() + ... + + published_objects = models.QuerySet.filter(published=True).as_manager() + + or:: + + from django.db import models + + + class BookQuerySet(models.QuerySet): ... + + + class Book(models.Model): + published = models.BooleanField() + ... + + published_objects = BookQuerySet.filter(published=True).as_manager() + Pagination ~~~~~~~~~~ diff --git a/tests/custom_managers/models.py b/tests/custom_managers/models.py index 53a07c462df6..29c9ef30a39d 100644 --- a/tests/custom_managers/models.py +++ b/tests/custom_managers/models.py @@ -140,6 +140,11 @@ class FunPerson(models.Model): objects = FunPeopleManager() +class BookQuerySet(models.QuerySet): + def authors_a(self): + return self.filter(author__istartswith="a") + + class Book(models.Model): title = models.CharField(max_length=50) author = models.CharField(max_length=30) @@ -159,6 +164,18 @@ class Book(models.Model): published_objects = PublishedBookManager() annotated_objects = AnnotatedBookManager() + # Custom querysets with initial filters. + published_objects_from_qs = BookQuerySet.filter(is_published=True).as_manager() + not_published_objects_from_qs = BookQuerySet.filter(is_published=False).as_manager() + not_published_objects_q_from_qs = models.QuerySet.filter( + ~models.Q(is_published=True) + ).as_manager() + # Chain initial filters + published_title_t_objects = ( + BookQuerySet.filter(is_published=True) + .filter(title__istartswith="T") + .as_manager() + ) class Meta: base_manager_name = "annotated_objects" diff --git a/tests/custom_managers/tests.py b/tests/custom_managers/tests.py index 3d9485c13b44..3a104a473489 100644 --- a/tests/custom_managers/tests.py +++ b/tests/custom_managers/tests.py @@ -636,6 +636,39 @@ def test_abstract_model_with_custom_manager_name(self): lambda c: c.objects, ) + def test_queryset_initial_filter(self): + b3 = Book.published_objects.create( + title="The Dark Tower: The Gunslinger", + author="Stephen King", + is_published=True, + ) + + self.assertCountEqual(Book.published_objects_from_qs.all(), [self.b1, b3]) + self.assertSequenceEqual(Book.not_published_objects_from_qs.all(), [self.b2]) + self.assertSequenceEqual(Book.not_published_objects_q_from_qs.all(), [self.b2]) + self.assertSequenceEqual(Book.published_objects_from_qs.authors_a(), []) + self.assertSequenceEqual( + Book.not_published_objects_from_qs.authors_a(), [self.b2] + ) + self.assertSequenceEqual(Book.published_title_t_objects.all(), [b3]) + + def test_queryset_initial_filter_invalid_argument(self): + msg = "The following kwargs are invalid: '_connector', '_negated'" + with self.assertRaisesMessage(TypeError, msg): + models.QuerySet.filter(pk=1, _negated=True, _connector="evil") + + def test_queryset_initial_filter_chained(self): + objects_published = models.QuerySet.filter(published=True) + objects_published_title_a = objects_published.filter(title__istartswith="A") + self.assertEqual(objects_published._initial_filter, models.Q(published=True)) + self.assertTrue(issubclass(objects_published, models.QuerySet)) + self.assertEqual( + objects_published_title_a._initial_filter, + models.Q(published=True) & models.Q(title__istartswith="A"), + ) + self.assertTrue(issubclass(objects_published, models.QuerySet)) + self.assertFalse(issubclass(objects_published_title_a, objects_published)) + class TestCars(TestCase): def test_managers(self):