Skip to content

Commit

Permalink
- Added RandomCharField - Prepopulates random character string.
Browse files Browse the repository at this point in the history
- Factored out query logic to ensure uniqueness to BaseUniqueField.
- The field value factories have been moved to generators.
  • Loading branch information
Derrick Petzold committed May 17, 2015
1 parent 78b52d1 commit c508e66
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 45 deletions.
217 changes: 174 additions & 43 deletions django_extensions/db/fields/__init__.py
@@ -1,8 +1,10 @@
"""
Django Extensions additional model fields
"""
import random
import re
import six
import string
import warnings

try:
Expand Down Expand Up @@ -34,7 +36,46 @@
from django.utils.encoding import force_text as force_unicode # NOQA


class AutoSlugField(SlugField):
MAX_UNIQUE_QUERY_ATTEMPTS = 100
random_sample = random.SystemRandom().sample


class BaseUniqueField(object):

def check_is_bool(self, attrname):
if not isinstance(getattr(self, attrname), bool):
raise ValueError("'{}' argument must be True or False".format(attrname))

def get_queryset(self, model_cls, slug_field):
for field, model in model_cls._meta.get_fields_with_model():
if model and field == slug_field:
return model._default_manager.all()
return model_cls._default_manager.all()

def find_unique(self, model_instance, field, iterator, *args):
# exclude the current model instance from the queryset used in finding
# next valid hash
queryset = self.get_queryset(model_instance.__class__, field)
if model_instance.pk:
queryset = queryset.exclude(pk=model_instance.pk)

# form a kwarg dict used to impliment any unique_together contraints
kwargs = {}
for params in model_instance._meta.unique_together:
if self.attname in params:
for param in params:
kwargs[param] = getattr(model_instance, param, None)

new = six.next(iterator)
kwargs[self.attname] = new
while not new or queryset.filter(**kwargs):
new = six.next(iterator)
kwargs[self.attname] = new
setattr(model_instance, self.attname, new)
return new


class AutoSlugField(BaseUniqueField, SlugField):
""" AutoSlugField
By default, sets editable=False, blank=True.
Expand Down Expand Up @@ -68,11 +109,9 @@ def __init__(self, *args, **kwargs):
self.slugify_function = kwargs.pop('slugify_function', slugify)
self.separator = kwargs.pop('separator', six.u('-'))
self.overwrite = kwargs.pop('overwrite', False)
if not isinstance(self.overwrite, bool):
raise ValueError("'overwrite' argument must be True or False")
self.check_is_bool('overwrite')
self.allow_duplicates = kwargs.pop('allow_duplicates', False)
if not isinstance(self.allow_duplicates, bool):
raise ValueError("'allow_duplicates' argument must be True or False")
self.check_is_bool('allow_duplicates')
super(AutoSlugField, self).__init__(*args, **kwargs)

def _slug_strip(self, value):
Expand All @@ -87,17 +126,25 @@ def _slug_strip(self, value):
value = re.sub('%s+' % re_sep, self.separator, value)
return re.sub(r'^%s+|%s+$' % (re_sep, re_sep), '', value)

def get_queryset(self, model_cls, slug_field):
for field, model in model_cls._meta.get_fields_with_model():
if model and field == slug_field:
return model._default_manager.all()
return model_cls._default_manager.all()

def slugify_func(self, content):
if content:
return self.slugify_function(content)
return ''

def slug_generator(self, original_slug, start):
yield original_slug
for i in range(start, MAX_UNIQUE_QUERY_ATTEMPTS):
slug = original_slug
end = '%s%s' % (self.separator, i)
end_len = len(end)
if self.slug_len and len(slug) + end_len > self.slug_len:
slug = slug[:self.slug_len - end_len]
slug = self._slug_strip(slug)
slug = '%s%s' % (slug, end)
yield slug
raise RuntimeError('max slug attempts for %s exceeded (%s)' %
(original_slug, MAX_UNIQUE_QUERY_ATTEMPTS))

def create_slug(self, model_instance, add):
# get fields to populate from and slug field to set
if not isinstance(self._populate_from, (list, tuple)):
Expand All @@ -108,7 +155,7 @@ def create_slug(self, model_instance, add):
# slugify the original field content and set next step to 2
slug_for_field = lambda field: self.slugify_func(getattr(model_instance, field))
slug = self.separator.join(map(slug_for_field, self._populate_from))
next = 2
start = 2
else:
# get slug from the current model instance
slug = getattr(model_instance, self.attname)
Expand All @@ -118,46 +165,20 @@ def create_slug(self, model_instance, add):

# strip slug depending on max_length attribute of the slug field
# and clean-up
slug_len = slug_field.max_length
if slug_len:
slug = slug[:slug_len]
self.slug_len = slug_field.max_length
if self.slug_len:
slug = slug[:self.slug_len]
slug = self._slug_strip(slug)
original_slug = slug

if self.allow_duplicates:
return slug

# exclude the current model instance from the queryset used in finding
# the next valid slug
queryset = self.get_queryset(model_instance.__class__, slug_field)
if model_instance.pk:
queryset = queryset.exclude(pk=model_instance.pk)

# form a kwarg dict used to impliment any unique_together contraints
kwargs = {}
for params in model_instance._meta.unique_together:
if self.attname in params:
for param in params:
kwargs[param] = getattr(model_instance, param, None)
kwargs[self.attname] = slug

# increases the number while searching for the next valid slug
# depending on the given slug, clean-up
while not slug or queryset.filter(**kwargs):
slug = original_slug
end = '%s%s' % (self.separator, next)
end_len = len(end)
if slug_len and len(slug) + end_len > slug_len:
slug = slug[:slug_len - end_len]
slug = self._slug_strip(slug)
slug = '%s%s' % (slug, end)
kwargs[self.attname] = slug
next += 1
return slug
return super(AutoSlugField, self).find_unique(
model_instance, slug_field, self.slug_generator(original_slug, start))

def pre_save(self, model_instance, add):
value = force_unicode(self.create_slug(model_instance, add))
setattr(model_instance, self.attname, value)
return value

def get_internal_type(self):
Expand Down Expand Up @@ -190,6 +211,116 @@ def deconstruct(self):
return name, path, args, kwargs


class RandomCharField(BaseUniqueField, CharField):
""" RandomCharField
By default, sets editable=False, blank=True.
Required arguments:
length
Specifies the length of the field
Optional arguments:
lowercase
If set to True, lowercase the alpha characters (default: False)
include_alpha
If set to True, include alpha characters (default: True)
include_digits
If set to True, include digit characters (default: True)
include_punctuation
If set to True, include punctuation characters (default: True)
"""
def __init__(self, *args, **kwargs):
kwargs.setdefault('blank', True)
kwargs.setdefault('editable', False)

self.length = kwargs.pop('length', None)
if self.length is None:
raise ValueError("missing 'length' argument")
kwargs['max_length'] = self.length

self.lowercase = kwargs.pop('lowercase', False)
self.check_is_bool('lowercase')
self.include_digits = kwargs.pop('include_digits', True)
self.check_is_bool('include_digits')
self.include_alpha = kwargs.pop('include_alpha', True)
self.check_is_bool('include_alpha')
self.include_punctuation = kwargs.pop('include_punctuation', False)
self.check_is_bool('include_punctuation')

# Set db_index=True unless it's been set manually.
if 'db_index' not in kwargs:
kwargs['db_index'] = True

super(RandomCharField, self).__init__(*args, **kwargs)

def random_char_generator(self, chars):
for i in range(100):
yield ''.join(random_sample(chars, self.length))
raise RuntimeError('max random character attempts exceeded (%s)' %
MAX_UNIQUE_QUERY_ATTEMPTS)

def pre_save(self, model_instance, add):
if not add:
return getattr(model_instance, self.attname)

population = ''
if self.include_alpha:
if self.lowercase:
population += string.ascii_lowercase
else:
population += string.ascii_letters

if self.include_digits:
population += string.digits

if self.include_punctuation:
population += string.punctuation

return super(RandomCharField, self).find_unique(
model_instance,
model_instance._meta.get_field(self.attname),
self.random_char_generator(population),
)

def internal_type(self):
return "CharField"

def south_field_triple(self):
"Returns a suitable description of this field for South."
# We'll just introspect the _actual_ field.
from south.modelsinspector import introspector
field_class = '%s.RandomCharField' % self.__module__
args, kwargs = introspector(self)
kwargs.update({
'lowercase': repr(self.lowercase),
'include_digits': repr(self.include_digits),
'include_aphla': repr(self.include_alpha),
'include_punctuation': repr(self.include_punctuation),
'length': repr(self.length),
})
# That's our definition!
return (field_class, args, kwargs)

def deconstruct(self):
name, path, args, kwargs = super(RandomCharField, self).deconstruct()
kwargs['length'] = self.length
if self.lowercase is not True:
kwargs['lowercase'] = False
if self.include_alpha is not True:
kwargs['include_alpha'] = False
if self.include_digits is not True:
kwargs['include_digits'] = False
if self.include_punctuation is not True:
kwargs['include_punctuation'] = False
return name, path, args, kwargs


class CreationDateTimeField(DateTimeField):
""" CreationDateTimeField
Expand Down
18 changes: 18 additions & 0 deletions docs/field_extensions.rst
Expand Up @@ -11,6 +11,24 @@ Current Database Model Field Extensions
incrementing an appended number on the slug until it is unique. Inspired by
SmileyChris' Unique Slugify snippet.

* *RandomCharField* - AutoRandomCharField will automatically create a
unique random character field with the specified length. By default
upper/lower case and digits are included as possible characters. Given
a length of 8 thats yields 3.4 million possible combinations. A 12
character field would yield about 2 billion. Below are some examples::

>>> RandomCharField(length=8)
BVm9GEaE

>>> RandomCharField(length=4, digits_only=True)
7097

>>> RandomCharField(length=12, include_punctuation=True)
k[ZS.TR,0LHO

>>> RandomCharField(length=12, lower=True, alpha_only=True)
pzolbemetmok

* *CreationDateTimeField* - DateTimeField that will automatically set its date
when the object is first saved to the database. Works in the same way as the
auto_now_add keyword.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -129,7 +129,7 @@ def fullsplit(path, result=None):
cmdclass=cmdclasses,
package_data=package_data,
install_requires=['six>=1.2'],
tests_require=['Django', 'shortuuid', 'python-dateutil', 'pytest', 'tox'],
tests_require=['Django', 'shortuuid', 'python-dateutil', 'pytest', 'tox', 'mock'],
classifiers=[
'Development Status :: 4 - Beta',
'Development Status :: 5 - Production/Stable',
Expand Down
78 changes: 78 additions & 0 deletions tests/test_randomchar_field.py
@@ -0,0 +1,78 @@
import mock
import string
import pytest

import django
from django.test import TestCase

from .testapp.models import (
RandomCharTestModel,
RandomCharTestModelLower,
RandomCharTestModelAlpha,
RandomCharTestModelDigits,
RandomCharTestModelPunctuation,
RandomCharTestModelLowerAlphaDigits,
)

if django.VERSION >= (1, 7):
from django.db import migrations # NOQA
from django.db.migrations.writer import MigrationWriter # NOQA
from django.utils import six # NOQA
import django_extensions # NOQA


class RandomCharFieldTest(TestCase):

def testRandomCharField(self):
m = RandomCharTestModel()
m.save()
assert len(m.random_char_field) == 8, m.random_char_field

def testRandomCharFieldLower(self):
m = RandomCharTestModelLower()
m.save()
for c in m.random_char_field:
assert c.islower(), m.random_char_field

def testRandomCharFieldAlpha(self):
m = RandomCharTestModelAlpha()
m.save()
for c in m.random_char_field:
assert c.isalpha(), m.random_char_field

def testRandomCharFieldDigits(self):
m = RandomCharTestModelDigits()
m.save()
for c in m.random_char_field:
assert c.isdigit(), m.random_char_field

def testRandomCharFieldPunctuation(self):
m = RandomCharTestModelPunctuation()
m.save()
for c in m.random_char_field:
assert c in string.punctuation, m.random_char_field

def testRandomCharTestModelLowerAlphaDigits(self):
m = RandomCharTestModelLowerAlphaDigits()
m.save()
for c in m.random_char_field:
assert c.isdigit() or (c.isalpha() and c.islower()), m.random_char_field

def testRandomCharTestModelDuplicate(self):
m = RandomCharTestModel()
m.save()
with mock.patch('django_extensions.db.fields.RandomCharField.random_char_generator') as func:
func.return_value = iter([m.random_char_field, 'aaa'])
m = RandomCharTestModel()
m.save()
assert m.random_char_field == 'aaa'

def testRandomCharTestModelAsserts(self):
with mock.patch('django_extensions.db.fields.random_sample') as mock_sample:
mock_sample.return_value = 'aaa'
m = RandomCharTestModel()
m.save()

m = RandomCharTestModel()
with pytest.raises(RuntimeError):
m.save()

0 comments on commit c508e66

Please sign in to comment.