Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed #20348 -- Consistently handle Promise objects in model fields.

All Promise objects were passed to force_text() deep in ORM query code.
Not only does this make it difficult or impossible for developers to
prevent or alter this behaviour, but it is also wrong for non-text
fields.

This commit changes `Field.get_prep_value()` from a no-op to one that
resolved Promise objects. All subclasses now call super() method first
to ensure that they have a real value to work with.
  • Loading branch information...
commit 31e6d58d46894ca35080b4eab7967e4c6aae82d4 1 parent 8f5533a
Tai Lee authored akaariai committed
View
1  django/contrib/gis/db/models/fields.py
@@ -148,6 +148,7 @@ def get_prep_value(self, value):
value properly, and preserve any other lookup parameters before
returning to the caller.
"""
+ value = super(GeometryField, self).get_prep_value(value)
if isinstance(value, SQLEvaluator):
return value
elif isinstance(value, (tuple, list)):
View
16 django/db/models/fields/__init__.py
@@ -17,7 +17,7 @@
from django.core import exceptions, validators
from django.utils.datastructures import DictWrapper
from django.utils.dateparse import parse_date, parse_datetime, parse_time
-from django.utils.functional import curry, total_ordering
+from django.utils.functional import curry, total_ordering, Promise
from django.utils.text import capfirst
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
@@ -421,6 +421,8 @@ def get_prep_value(self, value):
"""
Perform preliminary non-db specific value checks and conversions.
"""
+ if isinstance(value, Promise):
+ value = value._proxy____cast()
return value
def get_db_prep_value(self, value, connection, prepared=False):
@@ -704,6 +706,7 @@ def get_db_prep_value(self, value, connection, prepared=False):
return value
def get_prep_value(self, value):
+ value = super(AutoField, self).get_prep_value(value)
if value is None:
return None
return int(value)
@@ -763,6 +766,7 @@ def get_prep_lookup(self, lookup_type, value):
return super(BooleanField, self).get_prep_lookup(lookup_type, value)
def get_prep_value(self, value):
+ value = super(BooleanField, self).get_prep_value(value)
if value is None:
return None
return bool(value)
@@ -796,6 +800,7 @@ def to_python(self, value):
return smart_text(value)
def get_prep_value(self, value):
+ value = super(CharField, self).get_prep_value(value)
return self.to_python(value)
def formfield(self, **kwargs):
@@ -911,6 +916,7 @@ def get_prep_lookup(self, lookup_type, value):
return super(DateField, self).get_prep_lookup(lookup_type, value)
def get_prep_value(self, value):
+ value = super(DateField, self).get_prep_value(value)
return self.to_python(value)
def get_db_prep_value(self, value, connection, prepared=False):
@@ -1008,6 +1014,7 @@ def pre_save(self, model_instance, add):
# get_prep_lookup is inherited from DateField
def get_prep_value(self, value):
+ value = super(DateTimeField, self).get_prep_value(value)
value = self.to_python(value)
if value is not None and settings.USE_TZ and timezone.is_naive(value):
# For backwards compatibility, interpret naive datetimes in local
@@ -1096,6 +1103,7 @@ def get_db_prep_save(self, value, connection):
self.max_digits, self.decimal_places)
def get_prep_value(self, value):
+ value = super(DecimalField, self).get_prep_value(value)
return self.to_python(value)
def formfield(self, **kwargs):
@@ -1185,6 +1193,7 @@ class FloatField(Field):
description = _("Floating point number")
def get_prep_value(self, value):
+ value = super(FloatField, self).get_prep_value(value)
if value is None:
return None
return float(value)
@@ -1218,6 +1227,7 @@ class IntegerField(Field):
description = _("Integer")
def get_prep_value(self, value):
+ value = super(IntegerField, self).get_prep_value(value)
if value is None:
return None
return int(value)
@@ -1326,6 +1336,7 @@ def get_db_prep_value(self, value, connection, prepared=False):
return value or None
def get_prep_value(self, value):
+ value = super(GenericIPAddressField, self).get_prep_value(value)
if value and ':' in value:
try:
return clean_ipv6_address(value, self.unpack_ipv4)
@@ -1391,6 +1402,7 @@ def get_prep_lookup(self, lookup_type, value):
value)
def get_prep_value(self, value):
+ value = super(NullBooleanField, self).get_prep_value(value)
if value is None:
return None
return bool(value)
@@ -1473,6 +1485,7 @@ def get_internal_type(self):
return "TextField"
def get_prep_value(self, value):
+ value = super(TextField, self).get_prep_value(value)
if isinstance(value, six.string_types) or value is None:
return value
return smart_text(value)
@@ -1549,6 +1562,7 @@ def pre_save(self, model_instance, add):
return super(TimeField, self).pre_save(model_instance, add)
def get_prep_value(self, value):
+ value = super(TimeField, self).get_prep_value(value)
return self.to_python(value)
def get_db_prep_value(self, value, connection, prepared=False):
View
1  django/db/models/fields/files.py
@@ -253,6 +253,7 @@ def get_prep_lookup(self, lookup_type, value):
def get_prep_value(self, value):
"Returns field's value prepared for saving into a database."
+ value = super(FileField, self).get_prep_value(value)
# Need to convert File objects provided via a form to unicode for database insertion
if value is None:
return None
View
12 django/db/models/sql/subqueries.py
@@ -11,8 +11,6 @@
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint
-from django.utils.functional import Promise
-from django.utils.encoding import force_text
from django.utils import six
from django.utils import timezone
@@ -147,10 +145,6 @@ def add_update_fields(self, values_seq):
Used by add_update_values() as well as the "fast" update path when
saving models.
"""
- # Check that no Promise object passes to the query. Refs #10498.
- values_seq = [(value[0], value[1], force_text(value[2]))
- if isinstance(value[2], Promise) else value
- for value in values_seq]
self.values.extend(values_seq)
def add_related_update(self, model, field, value):
@@ -210,12 +204,6 @@ def insert_values(self, fields, objs, raw=False):
into the query, for example.
"""
self.fields = fields
- # Check that no Promise object reaches the DB. Refs #10498.
- for field in fields:
- for obj in objs:
- value = getattr(obj, field.attname)
- if isinstance(value, Promise):
- setattr(obj, field.attname, force_text(value))
self.objs = objs
self.raw = raw
View
191 tests/model_fields/tests.py
@@ -8,11 +8,21 @@
from django import forms
from django.core.exceptions import ValidationError
from django.db import connection, models, IntegrityError
-from django.db.models.fields.files import FieldFile
+from django.db.models.fields import (
+ AutoField, BigIntegerField, BinaryField, BooleanField, CharField,
+ CommaSeparatedIntegerField, DateField, DateTimeField, DecimalField,
+ EmailField, FilePathField, FloatField, IntegerField, IPAddressField,
+ GenericIPAddressField, NullBooleanField, PositiveIntegerField,
+ PositiveSmallIntegerField, SlugField, SmallIntegerField, TextField,
+ TimeField, URLField)
+from django.db.models.fields.files import FileField, ImageField
from django.utils import six
+from django.utils.functional import lazy
+from django.utils.unittest import skipIf
-from .models import (Foo, Bar, Whiz, BigD, BigS, Image, BigInt, Post,
- NullBooleanModel, BooleanModel, DataModel, Document, RenamedField,
+from .models import (
+ Foo, Bar, Whiz, BigD, BigS, BigInt, Post, NullBooleanModel,
+ BooleanModel, DataModel, Document, RenamedField,
VerboseNameField, FksToBooleans)
@@ -64,7 +74,7 @@ def test_field_verbose_name(self):
m = VerboseNameField
for i in range(1, 23):
self.assertEqual(m._meta.get_field('field%d' % i).verbose_name,
- 'verbose field%d' % i)
+ 'verbose field%d' % i)
self.assertEqual(m._meta.get_field('id').verbose_name, 'verbose pk')
@@ -290,9 +300,9 @@ def test_slugfield_max_length(self):
"""
Make sure SlugField honors max_length (#9706)
"""
- bs = BigS.objects.create(s = 'slug'*50)
+ bs = BigS.objects.create(s='slug' * 50)
bs = BigS.objects.get(pk=bs.pk)
- self.assertEqual(bs.s, 'slug'*50)
+ self.assertEqual(bs.s, 'slug' * 50)
class ValidationTest(test.TestCase):
@@ -313,15 +323,17 @@ def test_integerfield_raises_error_on_invalid_intput(self):
self.assertRaises(ValidationError, f.clean, "a", None)
def test_charfield_with_choices_cleans_valid_choice(self):
- f = models.CharField(max_length=1, choices=[('a','A'), ('b','B')])
+ f = models.CharField(max_length=1,
+ choices=[('a', 'A'), ('b', 'B')])
self.assertEqual('a', f.clean('a', None))
def test_charfield_with_choices_raises_error_on_invalid_choice(self):
- f = models.CharField(choices=[('a','A'), ('b','B')])
+ f = models.CharField(choices=[('a', 'A'), ('b', 'B')])
self.assertRaises(ValidationError, f.clean, "not a", None)
def test_choices_validation_supports_named_groups(self):
- f = models.IntegerField(choices=(('group',((10,'A'),(20,'B'))),(30,'C')))
+ f = models.IntegerField(
+ choices=(('group', ((10, 'A'), (20, 'B'))), (30, 'C')))
self.assertEqual(10, f.clean(10, None))
def test_nullable_integerfield_raises_error_with_blank_false(self):
@@ -370,7 +382,7 @@ def test_limits(self):
self.assertEqual(qs[0].value, minval)
def test_types(self):
- b = BigInt(value = 0)
+ b = BigInt(value=0)
self.assertIsInstance(b.value, six.integer_types)
b.save()
self.assertIsInstance(b.value, six.integer_types)
@@ -378,8 +390,8 @@ def test_types(self):
self.assertIsInstance(b.value, six.integer_types)
def test_coercing(self):
- BigInt.objects.create(value ='10')
- b = BigInt.objects.get(value = '10')
+ BigInt.objects.create(value='10')
+ b = BigInt.objects.get(value='10')
self.assertEqual(b.value, 10)
class TypeCoercionTests(test.TestCase):
@@ -466,7 +478,7 @@ def test_set_and_retrieve(self):
test_set_and_retrieve = unittest.expectedFailure(test_set_and_retrieve)
def test_max_length(self):
- dm = DataModel(short_data=self.binary_data*4)
+ dm = DataModel(short_data=self.binary_data * 4)
self.assertRaises(ValidationError, dm.full_clean)
class GenericIPAddressFieldTests(test.TestCase):
@@ -481,3 +493,156 @@ def test_genericipaddressfield_formfield_protocol(self):
model_field = models.GenericIPAddressField(protocol='IPv6')
form_field = model_field.formfield()
self.assertRaises(ValidationError, form_field.clean, '127.0.0.1')
+
+
+class PromiseTest(test.TestCase):
+ def test_AutoField(self):
+ lazy_func = lazy(lambda: 1, int)
+ self.assertIsInstance(
+ AutoField(primary_key=True).get_prep_value(lazy_func()),
+ int)
+
+ @skipIf(six.PY3, "Python 3 has no `long` type.")
+ def test_BigIntegerField(self):
+ lazy_func = lazy(lambda: long(9999999999999999999), long)
+ self.assertIsInstance(
+ BigIntegerField().get_prep_value(lazy_func()),
+ long)
+
+ def test_BinaryField(self):
+ lazy_func = lazy(lambda: b'', bytes)
+ self.assertIsInstance(
+ BinaryField().get_prep_value(lazy_func()),
+ bytes)
+
+ def test_BooleanField(self):
+ lazy_func = lazy(lambda: True, bool)
+ self.assertIsInstance(
+ BooleanField().get_prep_value(lazy_func()),
+ bool)
+
+ def test_CharField(self):
+ lazy_func = lazy(lambda: '', six.text_type)
+ self.assertIsInstance(
+ CharField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_CommaSeparatedIntegerField(self):
+ lazy_func = lazy(lambda: '1,2', six.text_type)
+ self.assertIsInstance(
+ CommaSeparatedIntegerField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_DateField(self):
+ lazy_func = lazy(lambda: datetime.date.today(), datetime.date)
+ self.assertIsInstance(
+ DateField().get_prep_value(lazy_func()),
+ datetime.date)
+
+ def test_DateTimeField(self):
+ lazy_func = lazy(lambda: datetime.datetime.now(), datetime.datetime)
+ self.assertIsInstance(
+ DateTimeField().get_prep_value(lazy_func()),
+ datetime.datetime)
+
+ def test_DecimalField(self):
+ lazy_func = lazy(lambda: Decimal('1.2'), Decimal)
+ self.assertIsInstance(
+ DecimalField().get_prep_value(lazy_func()),
+ Decimal)
+
+ def test_EmailField(self):
+ lazy_func = lazy(lambda: 'mailbox@domain.com', six.text_type)
+ self.assertIsInstance(
+ EmailField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_FileField(self):
+ lazy_func = lazy(lambda: 'filename.ext', six.text_type)
+ self.assertIsInstance(
+ FileField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_FilePathField(self):
+ lazy_func = lazy(lambda: 'tests.py', six.text_type)
+ self.assertIsInstance(
+ FilePathField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_FloatField(self):
+ lazy_func = lazy(lambda: 1.2, float)
+ self.assertIsInstance(
+ FloatField().get_prep_value(lazy_func()),
+ float)
+
+ def test_ImageField(self):
+ lazy_func = lazy(lambda: 'filename.ext', six.text_type)
+ self.assertIsInstance(
+ ImageField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_IntegerField(self):
+ lazy_func = lazy(lambda: 1, int)
+ self.assertIsInstance(
+ IntegerField().get_prep_value(lazy_func()),
+ int)
+
+ def test_IPAddressField(self):
+ lazy_func = lazy(lambda: '127.0.0.1', six.text_type)
+ self.assertIsInstance(
+ IPAddressField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_GenericIPAddressField(self):
+ lazy_func = lazy(lambda: '127.0.0.1', six.text_type)
+ self.assertIsInstance(
+ GenericIPAddressField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_NullBooleanField(self):
+ lazy_func = lazy(lambda: True, bool)
+ self.assertIsInstance(
+ NullBooleanField().get_prep_value(lazy_func()),
+ bool)
+
+ def test_PositiveIntegerField(self):
+ lazy_func = lazy(lambda: 1, int)
+ self.assertIsInstance(
+ PositiveIntegerField().get_prep_value(lazy_func()),
+ int)
+
+ def test_PositiveSmallIntegerField(self):
+ lazy_func = lazy(lambda: 1, int)
+ self.assertIsInstance(
+ PositiveSmallIntegerField().get_prep_value(lazy_func()),
+ int)
+
+ def test_SlugField(self):
+ lazy_func = lazy(lambda: 'slug', six.text_type)
+ self.assertIsInstance(
+ SlugField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_SmallIntegerField(self):
+ lazy_func = lazy(lambda: 1, int)
+ self.assertIsInstance(
+ SmallIntegerField().get_prep_value(lazy_func()),
+ int)
+
+ def test_TextField(self):
+ lazy_func = lazy(lambda: 'Abc', six.text_type)
+ self.assertIsInstance(
+ TextField().get_prep_value(lazy_func()),
+ six.text_type)
+
+ def test_TimeField(self):
+ lazy_func = lazy(lambda: datetime.datetime.now().time(), datetime.time)
+ self.assertIsInstance(
+ TimeField().get_prep_value(lazy_func()),
+ datetime.time)
+
+ def test_URLField(self):
+ lazy_func = lazy(lambda: 'http://domain.com', six.text_type)
+ self.assertIsInstance(
+ URLField().get_prep_value(lazy_func()),
+ six.text_type)
Please sign in to comment.
Something went wrong with that request. Please try again.