diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 060e1be60576d..ab4c094c1cb9b 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -14,6 +14,7 @@ from django.core import checks, exceptions, validators from django.db import connection, connections, router from django.db.models.constants import LOOKUP_SEP +from django.db.models.enums import ChoicesMeta from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin from django.utils import timezone from django.utils.datastructures import DictWrapper @@ -216,6 +217,8 @@ def __init__( self.unique_for_date = unique_for_date self.unique_for_month = unique_for_month self.unique_for_year = unique_for_year + if isinstance(choices, ChoicesMeta): + choices = choices.choices if isinstance(choices, collections.abc.Iterator): choices = list(choices) self.choices = choices diff --git a/django/forms/fields.py b/django/forms/fields.py index 46de2f53a00ec..003fb5ca6be38 100644 --- a/django/forms/fields.py +++ b/django/forms/fields.py @@ -16,6 +16,7 @@ from django.core import validators from django.core.exceptions import ValidationError +from django.db.models.enums import ChoicesMeta from django.forms.boundfield import BoundField from django.forms.utils import from_current_timezone, to_current_timezone from django.forms.widgets import ( @@ -857,6 +858,8 @@ class ChoiceField(Field): def __init__(self, *, choices=(), **kwargs): super().__init__(**kwargs) + if isinstance(choices, ChoicesMeta): + choices = choices.choices self.choices = choices def __deepcopy__(self, memo): diff --git a/docs/ref/forms/fields.txt b/docs/ref/forms/fields.txt index 317a955a15df7..7d975a74d5480 100644 --- a/docs/ref/forms/fields.txt +++ b/docs/ref/forms/fields.txt @@ -431,7 +431,7 @@ For each field, we describe the default widget used if you don't specify .. attribute:: choices Either an :term:`iterable` of 2-tuples to use as choices for this - field, :ref:`enumeration ` choices, or a + field, :ref:`enumeration type `, or a callable that returns such an iterable. This argument accepts the same formats as the ``choices`` argument to a model field. See the :ref:`model field reference documentation on choices ` @@ -439,6 +439,11 @@ For each field, we describe the default widget used if you don't specify time the field's form is initialized, in addition to during rendering. Defaults to an empty list. + .. versionchanged:: 5.0 + + Support for using :ref:`enumeration types ` + directly in the ``choices`` was added. + ``DateField`` ------------- diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index c1267507b2b64..447668bbc556a 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -210,7 +210,7 @@ choices in a concise way:: year_in_school = models.CharField( max_length=2, - choices=YearInSchool.choices, + choices=YearInSchool, default=YearInSchool.FRESHMAN, ) @@ -235,8 +235,7 @@ modifications: * A ``.label`` property is added on values, to return the human-readable name. * A number of custom properties are added to the enumeration classes -- ``.choices``, ``.labels``, ``.values``, and ``.names`` -- to make it easier - to access lists of those separate parts of the enumeration. Use ``.choices`` - as a suitable value to pass to :attr:`~Field.choices` in a field definition. + to access lists of those separate parts of the enumeration. .. warning:: @@ -276,7 +275,7 @@ Django provides an ``IntegerChoices`` class. For example:: HEART = 3 CLUB = 4 - suit = models.IntegerField(choices=Suit.choices) + suit = models.IntegerField(choices=Suit) It is also possible to make use of the `Enum Functional API `_ with the caveat @@ -320,6 +319,10 @@ There are some additional caveats to be aware of: __empty__ = _("(Unknown)") +.. versionchanged:: 5.0 + + Support for using enumeration types directly in the ``choices`` was added. + ``db_column`` ------------- diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 7dee6c4575a11..f5a5c568bcb96 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -167,7 +167,9 @@ File Uploads Forms ~~~~~ -* ... +* :attr:`.ChoiceField.choices` now accepts + :ref:`Choices classes ` directly instead of + requiring expansion with the ``choices`` attribute. Generic Views ~~~~~~~~~~~~~ @@ -208,6 +210,10 @@ Models of ``ValidationError`` raised during :ref:`model validation `. +* :attr:`.Field.choices` now accepts + :ref:`Choices classes ` directly instead of + requiring expansion with the ``choices`` attribute. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/db/models.txt b/docs/topics/db/models.txt index fc0c927270025..33a515f14ff35 100644 --- a/docs/topics/db/models.txt +++ b/docs/topics/db/models.txt @@ -211,7 +211,7 @@ ones: class Runner(models.Model): MedalType = models.TextChoices("MedalType", "GOLD SILVER BRONZE") name = models.CharField(max_length=60) - medal = models.CharField(blank=True, choices=MedalType.choices, max_length=10) + medal = models.CharField(blank=True, choices=MedalType, max_length=10) Further examples are available in the :ref:`model field reference `. diff --git a/tests/forms_tests/field_tests/test_choicefield.py b/tests/forms_tests/field_tests/test_choicefield.py index bc580bbf02f1f..e7893abe57454 100644 --- a/tests/forms_tests/field_tests/test_choicefield.py +++ b/tests/forms_tests/field_tests/test_choicefield.py @@ -95,7 +95,8 @@ class FirstNames(models.TextChoices): JOHN = "J", "John" PAUL = "P", "Paul" - f = ChoiceField(choices=FirstNames.choices) + f = ChoiceField(choices=FirstNames) + self.assertEqual(f.choices, FirstNames.choices) self.assertEqual(f.clean("J"), "J") msg = "'Select a valid choice. 3 is not one of the available choices.'" with self.assertRaisesMessage(ValidationError, msg): diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index c6638b3083184..33b52bd5385d1 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -433,24 +433,20 @@ class DateChoices(datetime.date, models.Choices): DateChoices.DATE_1, ("datetime.date(1969, 7, 20)", {"import datetime"}), ) - field = models.CharField(default=TextChoices.B, choices=TextChoices.choices) + field = models.CharField(default=TextChoices.B, choices=TextChoices) string = MigrationWriter.serialize(field)[0] self.assertEqual( string, "models.CharField(choices=[('A', 'A value'), ('B', 'B value')], " "default='B')", ) - field = models.IntegerField( - default=IntegerChoices.B, choices=IntegerChoices.choices - ) + field = models.IntegerField(default=IntegerChoices.B, choices=IntegerChoices) string = MigrationWriter.serialize(field)[0] self.assertEqual( string, "models.IntegerField(choices=[(1, 'One'), (2, 'Two')], default=2)", ) - field = models.DateField( - default=DateChoices.DATE_2, choices=DateChoices.choices - ) + field = models.DateField(default=DateChoices.DATE_2, choices=DateChoices) string = MigrationWriter.serialize(field)[0] self.assertEqual( string, diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index c35dfc2ebeb81..424a78746d54d 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -69,11 +69,18 @@ class WhizIterEmpty(models.Model): class Choiceful(models.Model): + class Suit(models.IntegerChoices): + DIAMOND = 1, "Diamond" + SPADE = 2, "Spade" + HEART = 3, "Heart" + CLUB = 4, "Club" + no_choices = models.IntegerField(null=True) empty_choices = models.IntegerField(choices=(), null=True) with_choices = models.IntegerField(choices=[(1, "A")], null=True) empty_choices_bool = models.BooleanField(choices=()) empty_choices_text = models.TextField(choices=()) + choices_from_enum = models.IntegerField(choices=Suit) class BigD(models.Model): diff --git a/tests/model_fields/test_charfield.py b/tests/model_fields/test_charfield.py index e9c1444f163fd..782158d210e39 100644 --- a/tests/model_fields/test_charfield.py +++ b/tests/model_fields/test_charfield.py @@ -75,11 +75,11 @@ def test_charfield_with_choices_raises_error_on_invalid_choice(self): f.clean("not a", None) def test_enum_choices_cleans_valid_string(self): - f = models.CharField(choices=self.Choices.choices, max_length=1) + f = models.CharField(choices=self.Choices, max_length=1) self.assertEqual(f.clean("c", None), "c") def test_enum_choices_invalid_input(self): - f = models.CharField(choices=self.Choices.choices, max_length=1) + f = models.CharField(choices=self.Choices, max_length=1) msg = "Value 'a' is not a valid choice." with self.assertRaisesMessage(ValidationError, msg): f.clean("a", None) diff --git a/tests/model_fields/test_integerfield.py b/tests/model_fields/test_integerfield.py index 7698160678da1..6761589b7ebbd 100644 --- a/tests/model_fields/test_integerfield.py +++ b/tests/model_fields/test_integerfield.py @@ -301,11 +301,11 @@ def test_integerfield_validates_zero_against_choices(self): f.clean("0", None) def test_enum_choices_cleans_valid_string(self): - f = models.IntegerField(choices=self.Choices.choices) + f = models.IntegerField(choices=self.Choices) self.assertEqual(f.clean("1", None), 1) def test_enum_choices_invalid_input(self): - f = models.IntegerField(choices=self.Choices.choices) + f = models.IntegerField(choices=self.Choices) with self.assertRaises(ValidationError): f.clean("A", None) with self.assertRaises(ValidationError): diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index 6d4a91afa2257..fe8526a4800b9 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -156,6 +156,7 @@ def setUpClass(cls): cls.empty_choices_bool = Choiceful._meta.get_field("empty_choices_bool") cls.empty_choices_text = Choiceful._meta.get_field("empty_choices_text") cls.with_choices = Choiceful._meta.get_field("with_choices") + cls.choices_from_enum = Choiceful._meta.get_field("choices_from_enum") def test_choices(self): self.assertIsNone(self.no_choices.choices) @@ -192,6 +193,10 @@ def test_formfield(self): with self.subTest(field=field): self.assertIsInstance(field.formfield(), forms.ChoiceField) + def test_choices_from_enum(self): + # Choices class was transparently resolved when given as argument. + self.assertEqual(self.choices_from_enum.choices, Choiceful.Suit.choices) + class GetFieldDisplayTests(SimpleTestCase): def test_choices_and_field_display(self):