Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions django/db/models/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import inspect
from functools import wraps
import types
from functools import partial, wraps
from importlib import import_module

from django.db import router
Expand Down Expand Up @@ -90,7 +91,8 @@ def manager_method(self, *args, **kwargs):

new_methods = {}
for name, method in inspect.getmembers(
queryset_class, predicate=inspect.isfunction
queryset_class,
predicate=lambda member: isinstance(member, (types.FunctionType, partial)),
):
# Only copy missing methods.
if hasattr(cls, name):
Expand Down
30 changes: 28 additions & 2 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from django.db.models.expressions import Case, DatabaseDefault, F, Value, When
from django.db.models.fetch_modes import FETCH_ONE
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.query_utils import FilteredRelation, Q, class_or_instance_method
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, ROW_COUNT
from django.db.models.utils import (
AltersData,
Expand Down Expand Up @@ -300,6 +300,8 @@ def __iter__(self):
class QuerySet(AltersData):
"""Represent a lazy database lookup for a set of objects."""

_initial_filter = None

def __init__(self, model=None, query=None, using=None, hints=None):
self.model = model
self._db = using
Expand All @@ -323,6 +325,8 @@ def query(self):
negate, args, kwargs = self._deferred_filter
self._filter_or_exclude_inplace(negate, args, kwargs)
self._deferred_filter = None
if self._initial_filter is not None:
self._query.add_q(self._initial_filter)
return self._query

@query.setter
Expand Down Expand Up @@ -1618,14 +1622,36 @@ def all(self):
"""
return self._chain()

def filter(self, *args, **kwargs):
def _class_filter(cls, *args, **kwargs):
if invalid_kwargs := PROHIBITED_FILTER_KWARGS.intersection(kwargs):
invalid_kwargs_str = ", ".join(f"'{k}'" for k in sorted(invalid_kwargs))
raise TypeError(f"The following kwargs are invalid: {invalid_kwargs_str}")
initial_filter = Q(*args, **kwargs)
# Chain initial filters.
if cls._initial_filter is not None:
initial_filter = cls._initial_filter & initial_filter
bases = cls.__bases__
class_name = cls.__name__
else:
bases = (cls,)
class_name = f"{cls.__name__}WithFilter"
return type(
class_name,
bases,
{"_initial_filter": initial_filter},
)

def _instance_filter(self, *args, **kwargs):
"""
Return a new QuerySet instance with the args ANDed to the existing
set.
"""
self._not_support_combined_queries("filter")
return self._filter_or_exclude(False, args, kwargs)

filter = class_or_instance_method(_class_filter, _instance_filter)
_class_filter = classmethod(_class_filter)

def exclude(self, *args, **kwargs):
"""
Return a new QuerySet instance with NOT (args) ANDed to the existing
Expand Down
26 changes: 26 additions & 0 deletions docs/releases/6.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,32 @@ Models
:ref:`negative array indexing <key-index-and-path-transforms>` on Oracle
21c+.

* The :class:`.QuerySet` class now allows defining initial filters. For
example::

from django.db import models


class Book(models.Model):
published = models.BooleanField()
...

published_objects = models.QuerySet.filter(published=True).as_manager()

or::

from django.db import models


class BookQuerySet(models.QuerySet): ...


class Book(models.Model):
published = models.BooleanField()
...

published_objects = BookQuerySet.filter(published=True).as_manager()

Pagination
~~~~~~~~~~

Expand Down
17 changes: 17 additions & 0 deletions tests/custom_managers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ class FunPerson(models.Model):
objects = FunPeopleManager()


class BookQuerySet(models.QuerySet):
def authors_a(self):
return self.filter(author__istartswith="a")


class Book(models.Model):
title = models.CharField(max_length=50)
author = models.CharField(max_length=30)
Expand All @@ -159,6 +164,18 @@ class Book(models.Model):

published_objects = PublishedBookManager()
annotated_objects = AnnotatedBookManager()
# Custom querysets with initial filters.
published_objects_from_qs = BookQuerySet.filter(is_published=True).as_manager()
not_published_objects_from_qs = BookQuerySet.filter(is_published=False).as_manager()
not_published_objects_q_from_qs = models.QuerySet.filter(
~models.Q(is_published=True)
).as_manager()
# Chain initial filters
published_title_t_objects = (
BookQuerySet.filter(is_published=True)
.filter(title__istartswith="T")
.as_manager()
)

class Meta:
base_manager_name = "annotated_objects"
Expand Down
33 changes: 33 additions & 0 deletions tests/custom_managers/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,39 @@ def test_abstract_model_with_custom_manager_name(self):
lambda c: c.objects,
)

def test_queryset_initial_filter(self):
b3 = Book.published_objects.create(
title="The Dark Tower: The Gunslinger",
author="Stephen King",
is_published=True,
)

self.assertCountEqual(Book.published_objects_from_qs.all(), [self.b1, b3])
self.assertSequenceEqual(Book.not_published_objects_from_qs.all(), [self.b2])
self.assertSequenceEqual(Book.not_published_objects_q_from_qs.all(), [self.b2])
self.assertSequenceEqual(Book.published_objects_from_qs.authors_a(), [])
self.assertSequenceEqual(
Book.not_published_objects_from_qs.authors_a(), [self.b2]
)
self.assertSequenceEqual(Book.published_title_t_objects.all(), [b3])

def test_queryset_initial_filter_invalid_argument(self):
msg = "The following kwargs are invalid: '_connector', '_negated'"
with self.assertRaisesMessage(TypeError, msg):
models.QuerySet.filter(pk=1, _negated=True, _connector="evil")

def test_queryset_initial_filter_chained(self):
objects_published = models.QuerySet.filter(published=True)
objects_published_title_a = objects_published.filter(title__istartswith="A")
self.assertEqual(objects_published._initial_filter, models.Q(published=True))
self.assertTrue(issubclass(objects_published, models.QuerySet))
self.assertEqual(
objects_published_title_a._initial_filter,
models.Q(published=True) & models.Q(title__istartswith="A"),
)
self.assertTrue(issubclass(objects_published, models.QuerySet))
self.assertFalse(issubclass(objects_published_title_a, objects_published))


class TestCars(TestCase):
def test_managers(self):
Expand Down