Skip to content

Commit

Permalink
Make assert_max_length support array type
Browse files Browse the repository at this point in the history
  • Loading branch information
kvesteri committed Feb 3, 2015
1 parent 9e85029 commit 2700984
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.


0.29.5 (2015-02-03)
^^^^^^^^^^^^^^^^^^^

- Made assert_max_length support PostgreSQL array type


0.29.4 (2015-01-31)
^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from .models import Timestamp


__version__ = '0.29.4'
__version__ = '0.29.5'


__all__ = (
Expand Down
64 changes: 61 additions & 3 deletions sqlalchemy_utils/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class User(Base):
# raises AssertionError because the max length of email is 255
assert_max_length(user, 'email', 300)
"""
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.exc import DataError, IntegrityError


Expand Down Expand Up @@ -71,6 +73,27 @@ def _expect_failing_update(obj, field, value, expected_exc):
session.rollback()


def _repeated_value(type_):
if isinstance(type_, ARRAY):
if isinstance(type_.item_type, sa.Integer):
return [0]
elif isinstance(type_.item_type, sa.String):
return [u'a']
elif isinstance(type_.item_type, sa.Numeric):
return [Decimal('0')]
else:
raise TypeError('Unknown array item type')
else:
return u'a'


def _expected_exception(type_):
if isinstance(type_, ARRAY):
return IntegrityError
else:
return DataError


def assert_nullable(obj, column):
"""
Assert that given column is nullable. This is checked by running an SQL
Expand All @@ -95,14 +118,49 @@ def assert_non_nullable(obj, column):

def assert_max_length(obj, column, max_length):
"""
Assert that the given column is of given max length.
Assert that the given column is of given max length. This function supports
string typed columns as well as PostgreSQL array typed columns.
In the following example we add a check constraint that user can have a
maximum of 5 favorite colors and then test this.::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
__table_args__ = (
sa.CheckConstraint(
sa.func.array_length(favorite_colors, 1) <= 5
)
)
user = User(name='John Doe', favorite_colors=['red', 'blue'])
session.add(user)
session.commit()
assert_max_length(user, 'favorite_colors', 5)
:param obj: SQLAlchemy declarative model object
:param column: Name of the column
:param max_length: Maximum length of given column
"""
_expect_successful_update(obj, column, u'a' * max_length, DataError)
_expect_failing_update(obj, column, u'a' * (max_length + 1), DataError)
type_ = sa.inspect(obj.__class__).columns[column].type
_expect_successful_update(
obj,
column,
_repeated_value(type_) * max_length,
_expected_exception(type_)
)
_expect_failing_update(
obj,
column,
_repeated_value(type_) * (max_length + 1),
_expected_exception(type_)
)


def assert_min_value(obj, column, min_value):
Expand Down
32 changes: 31 additions & 1 deletion tests/test_asserts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sqlalchemy as sa
import pytest
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy_utils import (
assert_min_value,
assert_max_length,
Expand Down Expand Up @@ -34,21 +35,50 @@ class User(self.Base):
name = sa.Column(sa.String(20))
age = sa.Column(sa.Integer, nullable=False)
email = sa.Column(sa.String(200), nullable=False, unique=True)
fav_numbers = sa.Column(ARRAY(sa.Integer))

__table_args__ = (
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
sa.CheckConstraint(
sa.and_(
sa.func.array_length(fav_numbers, 1) <= 8
)
)
)

self.User = User

def setup_method(self, method):
TestCase.setup_method(self, method)
user = self.User(name='Someone', email='someone@example.com', age=15)
user = self.User(
name='Someone',
email='someone@example.com',
age=15,
fav_numbers=[1, 2, 3]
)
self.session.add(user)
self.session.commit()
self.user = user


class TestAssertMaxLengthWithArray(AssertionTestCase):
def test_with_max_length(self):
assert_max_length(self.user, 'fav_numbers', 8)
assert_max_length(self.user, 'fav_numbers', 8)

def test_smaller_than_max_length(self):
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7)
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7)

def test_bigger_than_max_length(self):
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9)
with raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9)


class TestAssertNonNullable(AssertionTestCase):
def test_non_nullable_column(self):
# Test everything twice so that session gets rolled back properly
Expand Down

0 comments on commit 2700984

Please sign in to comment.