diff --git a/HISTORY.rst b/HISTORY.rst index 1fe4d21..a59741d 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,25 @@ History ------- +0.2.0 (2015-09-12) +~~~~~~~~~~~~~~~~~~ + +* Added `SQLAlchemy `_ support. +* ``FilterSet`` instances have much more useful ``__repr__`` which + shows all filters at a glance. For example:: + + >>> PlaceFilterSet() + PlaceFilterSet() + address = Filter(form_field=CharField, lookups=ALL, default_lookup="exact", is_default=False) + id = Filter(form_field=IntegerField, lookups=ALL, default_lookup="exact", is_default=True) + name = Filter(form_field=CharField, lookups=ALL, default_lookup="exact", is_default=False) + restaurant = RestaurantFilterSet() + serves_hot_dogs = Filter(form_field=BooleanField, lookups=ALL, default_lookup="exact", is_default=False) + serves_pizza = Filter(form_field=BooleanField, lookups=ALL, default_lookup="exact", is_default=False) + waiter = WaiterFilterSet() + id = Filter(form_field=IntegerField, lookups=ALL, default_lookup="exact", is_default=True) + name = Filter(form_field=CharField, lookups=ALL, default_lookup="exact", is_default=False) + 0.1.1 (2015-09-06) ~~~~~~~~~~~~~~~~~~ diff --git a/docs/api/url_filter.backends.base.rst b/docs/api/url_filter.backends.base.rst new file mode 100644 index 0000000..3683700 --- /dev/null +++ b/docs/api/url_filter.backends.base.rst @@ -0,0 +1,7 @@ +url_filter.backends.base module +=============================== + +.. automodule:: url_filter.backends.base + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/url_filter.backends.rst b/docs/api/url_filter.backends.rst index be8d6fd..0f9fccd 100644 --- a/docs/api/url_filter.backends.rst +++ b/docs/api/url_filter.backends.rst @@ -11,5 +11,7 @@ Submodules .. toctree:: + url_filter.backends.base url_filter.backends.django + url_filter.backends.sqlalchemy diff --git a/docs/api/url_filter.backends.sqlalchemy.rst b/docs/api/url_filter.backends.sqlalchemy.rst new file mode 100644 index 0000000..2e59373 --- /dev/null +++ b/docs/api/url_filter.backends.sqlalchemy.rst @@ -0,0 +1,7 @@ +url_filter.backends.sqlalchemy module +===================================== + +.. automodule:: url_filter.backends.sqlalchemy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/url_filter.filtersets.rst b/docs/api/url_filter.filtersets.rst index 8aa2f16..0c4a7a5 100644 --- a/docs/api/url_filter.filtersets.rst +++ b/docs/api/url_filter.filtersets.rst @@ -13,4 +13,5 @@ Submodules url_filter.filtersets.base url_filter.filtersets.django + url_filter.filtersets.sqlalchemy diff --git a/docs/api/url_filter.filtersets.sqlalchemy.rst b/docs/api/url_filter.filtersets.sqlalchemy.rst new file mode 100644 index 0000000..f27ed03 --- /dev/null +++ b/docs/api/url_filter.filtersets.sqlalchemy.rst @@ -0,0 +1,7 @@ +url_filter.filtersets.sqlalchemy module +======================================= + +.. automodule:: url_filter.filtersets.sqlalchemy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/big_picture.rst b/docs/big_picture.rst index cadbd22..33e8977 100644 --- a/docs/big_picture.rst +++ b/docs/big_picture.rst @@ -94,16 +94,16 @@ Filtering +++++++++ Since filtering is decoupled from the ``FilterSet``, the filtering honors -all full on a specified filter backend. The backend is very simple. +all go to a specified filter backend. The backend is very simple. It takes a list of filter specifications and a data to filter and its job is to filter that data as specified in the specifications. .. note:: - Currently we only support Django ORM filter backend but you can imagine - that any backend can be implemented. We plan to add support for SQLAlchemy - since, well, why not add it? Eventually filter backends can be added - for flat data-structures like filtering a vanilla Python lists or - filtering from exotic data-source like Mongo. + Currently we only support Django ORM and SQLAlchemy filter backends + but you can imagine that any backend can be implemented. + Eventually filter backends can be added for flat data-structures + like filtering a vanilla Python lists or even more exotic sources + like Mongo, Redis, etc. Steps ----- diff --git a/docs/history.rst b/docs/history.rst new file mode 100644 index 0000000..2506499 --- /dev/null +++ b/docs/history.rst @@ -0,0 +1 @@ +.. include:: ../HISTORY.rst diff --git a/docs/index.rst b/docs/index.rst index 3add778..e471e1d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ Contents usage big_picture + history api/modules .. include:: ../README.rst diff --git a/docs/usage.rst b/docs/usage.rst index be3615a..f079124 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -67,6 +67,35 @@ Notable things: model = User fields = ['username', 'email', 'joined', 'profile'] +SQLAlchemy +---------- + +`SQLAlchemy `_ works very similar to how Django +backend works. For example:: + + from django import forms + from url_filter.backend.sqlalchemy import SQLAlchemyFilterBackend + from url_filter.filtersets.sqlalchemy import SQLAlchemyModelFilterSet + + class UserFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = User # this model should be SQLAlchemy model + fields = ['username', 'email', 'joined', 'profile'] + + fs = UserFilterSet(data=QueryDict(), queryset=session.query(User)) + fs.filter() + +Notable things: + +* this works exactly same as ``ModelFitlerSet`` so refer above for some of + general options. +* ``filter_backend_class`` **must** be provided since otherwise + ``DjangoFilterBackend`` will be used which will obviously not work + with SQLAlchemy models. +* ``queryset`` given to the queryset should be SQLAlchemy query object. + Integrations ------------ diff --git a/requirements-dev.txt b/requirements-dev.txt index bff2f31..7d309d3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,4 @@ -r requirements.txt -Sphinx -Werkzeug coverage django-extensions djangorestframework @@ -10,7 +8,10 @@ mock pytest pytest-cov pytest-django +sphinx sphinx-autobuild sphinx-rtd-theme +sqlalchemy tox watchdog +werkzeug diff --git a/test_project/alchemy.py b/test_project/alchemy.py new file mode 100644 index 0000000..a3f54e9 --- /dev/null +++ b/test_project/alchemy.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals + +from sqlalchemy.ext.declarative import declarative_base + + +Base = declarative_base() diff --git a/test_project/many_to_many/alchemy.py b/test_project/many_to_many/alchemy.py new file mode 100644 index 0000000..8fcaa77 --- /dev/null +++ b/test_project/many_to_many/alchemy.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals + +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import backref, relationship +from sqlalchemy.schema import ForeignKey, Table + +from ..alchemy import Base + + +class Publication(Base): + __tablename__ = 'many_to_many_publication' + id = Column(Integer, primary_key=True) + title = Column(String(30)) + + @property + def pk(self): + return self.id + + +publication_article_association_table = Table( + 'many_to_many_article_publications', + Base.metadata, + Column('id', Integer), + Column('publication_id', Integer, ForeignKey('many_to_many_publication.id')), + Column('article_id', Integer, ForeignKey('many_to_many_article.id')), +) + + +class Article(Base): + __tablename__ = 'many_to_many_article' + id = Column(Integer, primary_key=True) + headline = Column(String(100)) + + publications = relationship( + Publication, + secondary=publication_article_association_table, + backref=backref('articles', uselist=True), + uselist=True, + ) + + @property + def pk(self): + return self.id diff --git a/test_project/many_to_many/api.py b/test_project/many_to_many/api.py index 2fe33e3..b958db6 100644 --- a/test_project/many_to_many/api.py +++ b/test_project/many_to_many/api.py @@ -2,10 +2,13 @@ from __future__ import print_function, unicode_literals from rest_framework.serializers import ModelSerializer -from rest_framework.viewsets import ModelViewSet +from rest_framework.viewsets import ReadOnlyModelViewSet +from url_filter.backends.sqlalchemy import SQLAlchemyFilterBackend from url_filter.filtersets import ModelFilterSet +from url_filter.filtersets.sqlalchemy import SQLAlchemyModelFilterSet +from . import alchemy from .models import Article, Publication @@ -39,10 +42,11 @@ class Meta(object): model = Publication -class PublicationViewSet(ModelViewSet): - queryset = Publication.objects.all() - serializer_class = PublicationNestedSerializer - filter_class = PublicationFilterSet +class SQLAlchemyPublicationFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Publication class ArticleFilterSet(ModelFilterSet): @@ -50,7 +54,36 @@ class Meta(object): model = Article -class ArticleViewSet(ModelViewSet): +class SQLAlchemyArticleFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Article + + +class PublicationViewSet(ReadOnlyModelViewSet): + queryset = Publication.objects.all() + serializer_class = PublicationNestedSerializer + filter_class = PublicationFilterSet + + +class SQLAlchemyPublicationViewSet(ReadOnlyModelViewSet): + serializer_class = PublicationNestedSerializer + filter_class = SQLAlchemyPublicationFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Publication) + + +class ArticleViewSet(ReadOnlyModelViewSet): queryset = Article.objects.all() serializer_class = ArticleNestedSerializer filter_class = ArticleFilterSet + + +class SQLAlchemyArticleViewSet(ReadOnlyModelViewSet): + serializer_class = ArticleNestedSerializer + filter_class = SQLAlchemyArticleFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Article) diff --git a/test_project/many_to_one/alchemy.py b/test_project/many_to_one/alchemy.py new file mode 100644 index 0000000..4a1b761 --- /dev/null +++ b/test_project/many_to_one/alchemy.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals + +from sqlalchemy import Column, Date, Integer, String +from sqlalchemy.orm import backref, relationship + +from ..alchemy import Base + + +class Reporter(Base): + __tablename__ = 'many_to_one_reporter' + id = Column(Integer, primary_key=True) + first_name = Column(String(30)) + last_name = Column(String(30)) + email = Column(String(254)) + + @property + def pk(self): + return self.id + + +class Article(Base): + __tablename__ = 'many_to_one_article' + id = Column(Integer, primary_key=True) + reporter_id = Column(Integer) + headline = Column(String(100)) + pub_date = Column(Date) + + reporter = relationship( + Reporter, + backref=backref('articles', uselist=True), + uselist=False, + primaryjoin='test_project.many_to_one.alchemy.Article.reporter_id == Reporter.id', + foreign_keys=reporter_id, + ) + + @property + def pk(self): + return self.id diff --git a/test_project/many_to_one/api.py b/test_project/many_to_one/api.py index f0ce341..429b757 100644 --- a/test_project/many_to_one/api.py +++ b/test_project/many_to_one/api.py @@ -2,10 +2,13 @@ from __future__ import print_function, unicode_literals from rest_framework.serializers import ModelSerializer -from rest_framework.viewsets import ModelViewSet +from rest_framework.viewsets import ReadOnlyModelViewSet +from url_filter.backends.sqlalchemy import SQLAlchemyFilterBackend from url_filter.filtersets import ModelFilterSet +from url_filter.filtersets.sqlalchemy import SQLAlchemyModelFilterSet +from . import alchemy from .models import Article, Reporter @@ -39,18 +42,48 @@ class Meta(object): model = Reporter +class SQLAlchemyReporterFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Reporter + + class ArticleFilterSet(ModelFilterSet): class Meta(object): model = Article -class ReporterViewSet(ModelViewSet): +class SQLAlchemyArticleFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Article + + +class ReporterViewSet(ReadOnlyModelViewSet): queryset = Reporter.objects.all() serializer_class = ReporterNestedSerializer filter_class = ReporterFilterSet -class ArticleViewSet(ModelViewSet): +class SQLAlchemyReporterViewSet(ReadOnlyModelViewSet): + serializer_class = ReporterNestedSerializer + filter_class = SQLAlchemyReporterFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Reporter) + + +class ArticleViewSet(ReadOnlyModelViewSet): queryset = Article.objects.all() serializer_class = ArticleNestedSerializer filter_class = ArticleFilterSet + + +class SQLAlchemyArticleViewSet(ReadOnlyModelViewSet): + serializer_class = ArticleNestedSerializer + filter_class = SQLAlchemyArticleFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Article) diff --git a/test_project/middleware.py b/test_project/middleware.py new file mode 100644 index 0000000..387095f --- /dev/null +++ b/test_project/middleware.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, unicode_literals + +from django.conf import settings +from sqlalchemy.orm import sessionmaker + + +Session = sessionmaker(bind=settings.SQLALCHEMY_ENGINE) + + +class SQLAlchemySessionMiddleware(object): + def process_request(self, request): + request.alchemy_session = Session() + + def process_response(self, request, response): + request.alchemy_session.close() + return response diff --git a/test_project/one_to_one/alchemy.py b/test_project/one_to_one/alchemy.py new file mode 100644 index 0000000..812c79e --- /dev/null +++ b/test_project/one_to_one/alchemy.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals + +from sqlalchemy import Boolean, Column, Integer, String +from sqlalchemy.orm import backref, relationship + +from ..alchemy import Base + + +class Place(Base): + __tablename__ = 'one_to_one_place' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + address = Column(String(80)) + + @property + def pk(self): + return self.id + + +class Restaurant(Base): + __tablename__ = 'one_to_one_restaurant' + place_id = Column(Integer, primary_key=True) + serves_hot_dogs = Column(Boolean) + serves_pizza = Column(Boolean) + + place = relationship( + Place, + backref=backref('restaurant', uselist=False), + uselist=False, + primaryjoin='Restaurant.place_id == Place.id', + foreign_keys=place_id, + ) + + @property + def pk(self): + return self.place_id + + +class Waiter(Base): + __tablename__ = 'one_to_one_waiter' + id = Column(Integer, primary_key=True) + restaurant_id = Column(Integer) + name = Column(String(50)) + + restaurant = relationship( + Restaurant, + backref=backref('waiter_set', uselist=True), + uselist=False, + primaryjoin='Waiter.restaurant_id == Restaurant.place_id', + foreign_keys=restaurant_id, + ) + + @property + def pk(self): + return self.id diff --git a/test_project/one_to_one/api.py b/test_project/one_to_one/api.py index eda74b0..e596095 100644 --- a/test_project/one_to_one/api.py +++ b/test_project/one_to_one/api.py @@ -2,10 +2,13 @@ from __future__ import print_function, unicode_literals from rest_framework.serializers import ModelSerializer -from rest_framework.viewsets import ModelViewSet +from rest_framework.viewsets import ReadOnlyModelViewSet +from url_filter.backends.sqlalchemy import SQLAlchemyFilterBackend from url_filter.filtersets import ModelFilterSet +from url_filter.filtersets.sqlalchemy import SQLAlchemyModelFilterSet +from . import alchemy from .models import Place, Restaurant, Waiter @@ -60,29 +63,74 @@ class Meta(object): model = Place +class SQLAlchemyPlaceFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Place + + class RestaurantFilterSet(ModelFilterSet): class Meta(object): model = Restaurant +class SQLAlchemyRestaurantFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Restaurant + + class WaiterFilterSet(ModelFilterSet): class Meta(object): model = Waiter -class PlaceViewSet(ModelViewSet): +class SQLAlchemyWaiterFilterSet(SQLAlchemyModelFilterSet): + filter_backend_class = SQLAlchemyFilterBackend + + class Meta(object): + model = alchemy.Waiter + + +class PlaceViewSet(ReadOnlyModelViewSet): queryset = Place.objects.all() serializer_class = PlaceNestedSerializer filter_class = PlaceFilterSet -class RestaurantViewSet(ModelViewSet): +class SQLAlchemyPlaceViewSet(ReadOnlyModelViewSet): + serializer_class = PlaceNestedSerializer + filter_class = SQLAlchemyPlaceFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Place) + + +class RestaurantViewSet(ReadOnlyModelViewSet): queryset = Restaurant.objects.all() serializer_class = RestaurantNestedSerializer filter_class = RestaurantFilterSet -class WaiterViewSet(ModelViewSet): +class SQLAlchemyRestaurantViewSet(ReadOnlyModelViewSet): + serializer_class = RestaurantNestedSerializer + filter_class = SQLAlchemyRestaurantFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Restaurant) + + +class WaiterViewSet(ReadOnlyModelViewSet): queryset = Waiter.objects.all() serializer_class = WaiterNestedSerializer filter_class = WaiterFilterSet + + +class SQLAlchemyWaiterViewSet(ReadOnlyModelViewSet): + serializer_class = WaiterNestedSerializer + filter_class = SQLAlchemyWaiterFilterSet + + def get_queryset(self): + return self.request.alchemy_session.query(alchemy.Waiter) diff --git a/test_project/settings.py b/test_project/settings.py index 60c176b..3c73bf7 100644 --- a/test_project/settings.py +++ b/test_project/settings.py @@ -1,7 +1,10 @@ # Bare ``settings.py`` for running tests for url_filter +from sqlalchemy import create_engine + DEBUG = True +SQLALCHEMY_ENGINE = create_engine('sqlite:///url_filter.sqlite', echo=True) DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', @@ -25,7 +28,9 @@ STATIC_URL = '/static/' SECRET_KEY = 'foo' -MIDDLEWARE_CLASSES = [] +MIDDLEWARE_CLASSES = [ + 'test_project.middleware.SQLAlchemySessionMiddleware', +] ROOT_URLCONF = 'test_project.urls' diff --git a/test_project/urls.py b/test_project/urls.py index 9fce939..2eec5f5 100644 --- a/test_project/urls.py +++ b/test_project/urls.py @@ -10,14 +10,21 @@ router = DefaultRouter() +router.register('one-to-one/places/alchemy', o2o_api.SQLAlchemyPlaceViewSet, 'one-to-one-alchemy:place') router.register('one-to-one/places', o2o_api.PlaceViewSet, 'one-to-one:place') +router.register('one-to-one/restaurants/alchemy', o2o_api.SQLAlchemyRestaurantViewSet, 'one-to-one-alchemy:restaurant') router.register('one-to-one/restaurants', o2o_api.RestaurantViewSet, 'one-to-one:restaurant') +router.register('one-to-one/waiters/alchemy', o2o_api.SQLAlchemyWaiterViewSet, 'one-to-one-alchemy:waiter') router.register('one-to-one/waiters', o2o_api.WaiterViewSet, 'one-to-one:waiter') -router.register('many-to-one/reporters', m2o_api.ReporterViewSet, 'many-to-many:reporter') -router.register('many-to-one/articles', m2o_api.ArticleViewSet, 'many-to-many:article') +router.register('many-to-one/reporters/alchemy', m2o_api.SQLAlchemyReporterViewSet, 'many-to-one-alchemy:reporter') +router.register('many-to-one/reporters', m2o_api.ReporterViewSet, 'many-to-one:reporter') +router.register('many-to-one/articles/alchemy', m2o_api.SQLAlchemyArticleViewSet, 'many-to-one-alchemy:article') +router.register('many-to-one/articles', m2o_api.ArticleViewSet, 'many-to-one:article') +router.register('many-to-many/publications/alchemy', m2m_api.SQLAlchemyPublicationViewSet, 'many-to-many-alchemy:publication') router.register('many-to-many/publications', m2m_api.PublicationViewSet, 'many-to-many:publication') +router.register('many-to-many/articles/alchemy', m2m_api.SQLAlchemyArticleViewSet, 'many-to-many-alchemy:article') router.register('many-to-many/articles', m2m_api.ArticleViewSet, 'many-to-many:article') urlpatterns = router.urls diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py new file mode 100644 index 0000000..ba25ec7 --- /dev/null +++ b/tests/backends/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals diff --git a/tests/backends/test_django.py b/tests/backends/test_django.py new file mode 100644 index 0000000..88273f4 --- /dev/null +++ b/tests/backends/test_django.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals + +import mock + +from test_project.one_to_one.models import Place +from url_filter.backends.django import DjangoFilterBackend +from url_filter.utils import FilterSpec + + +class TestDjangoFilterBackend(object): + def test_init(self): + backend = DjangoFilterBackend( + Place.objects.all(), + context={'context': 'here'}, + ) + + assert backend.model is Place + assert backend.context == {'context': 'here'} + + def test_get_model(self): + backend = DjangoFilterBackend(Place.objects.all()) + + assert backend.get_model() is Place + + def test_bind(self): + backend = DjangoFilterBackend(Place.objects.all()) + + assert backend.specs == [] + backend.bind([1, 2]) + assert backend.specs == [1, 2] + + def test_includes(self): + backend = DjangoFilterBackend(Place.objects.all()) + backend.bind([ + FilterSpec(['name'], 'exact', 'value', False), + FilterSpec(['address'], 'contains', 'value', True), + ]) + + assert list(backend.includes) == [ + FilterSpec(['name'], 'exact', 'value', False), + ] + + def test_excludes(self): + backend = DjangoFilterBackend(Place.objects.all()) + backend.bind([ + FilterSpec(['name'], 'exact', 'value', False), + FilterSpec(['address'], 'contains', 'value', True), + ]) + + assert list(backend.excludes) == [ + FilterSpec(['address'], 'contains', 'value', True), + ] + + def test_prepare_spec(self): + backend = DjangoFilterBackend(Place.objects.all()) + spec = backend.prepare_spec(FilterSpec(['name'], 'exact', 'value')) + + assert spec == 'name__exact' + + def test_filter(self): + qs = mock.Mock() + + backend = DjangoFilterBackend(qs) + backend.bind([ + FilterSpec(['name'], 'exact', 'value', False), + FilterSpec(['address'], 'contains', 'value', True), + ]) + + result = backend.filter() + + assert result == qs.filter.return_value.exclude.return_value + qs.filter.assert_called_once_with(name__exact='value') + qs.filter.return_value.exclude.assert_called_once_with(address__contains='value') diff --git a/tests/backends/test_sqlalchemy.py b/tests/backends/test_sqlalchemy.py new file mode 100644 index 0000000..ab18ad3 --- /dev/null +++ b/tests/backends/test_sqlalchemy.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals + +import pytest +import six +from sqlalchemy import func +from sqlalchemy.sql.elements import ClauseList, Grouping +from sqlalchemy.types import String + +from test_project.one_to_one.alchemy import Place, Restaurant, Waiter +from url_filter.backends.sqlalchemy import SQLAlchemyFilterBackend +from url_filter.utils import FilterSpec + + +def assert_alchemy_expressions_equal(exp1, exp2): + assert six.text_type(exp1) == six.text_type(exp2) + + if isinstance(exp1.right, Grouping): + values1 = list(i.value for i in exp1.right.element.clauses) + values2 = list(i.value for i in exp2.right.element.clauses) + assert values1 == values2 + + elif isinstance(exp1.right, ClauseList): + values1 = list(i.value for i in exp1.right.clauses) + values2 = list(i.value for i in exp2.right.clauses) + assert values1 == values2 + + elif hasattr(exp1.right, 'value'): + assert exp1.right.value == exp2.right.value + + +class TestSQLAlchemyFilterBackend(object): + def test_init(self, alchemy_db): + backend = SQLAlchemyFilterBackend( + alchemy_db.query(Place), + context={'context': 'here'}, + ) + + assert backend.model is Place + assert backend.context == {'context': 'here'} + + with pytest.raises(AssertionError): + SQLAlchemyFilterBackend( + alchemy_db.query(Place, Restaurant), + ) + + def test_get_model(self, alchemy_db): + backend = SQLAlchemyFilterBackend(alchemy_db.query(Place)) + + assert backend.get_model() is Place + + def test_filter_no_specs(self, alchemy_db): + qs = alchemy_db.query(Place) + + backend = SQLAlchemyFilterBackend(qs) + backend.bind([]) + + assert backend.filter() is qs + + def test_filter(self, alchemy_db): + backend = SQLAlchemyFilterBackend( + alchemy_db.query(Place), + ) + backend.bind([ + FilterSpec(['restaurant', 'waiter_set', 'name'], 'exact', 'John', False), + ]) + + filtered = backend.filter() + + assert six.text_type(filtered) == ( + 'SELECT one_to_one_place.id AS one_to_one_place_id, ' + 'one_to_one_place.name AS one_to_one_place_name, ' + 'one_to_one_place.address AS one_to_one_place_address \n' + 'FROM one_to_one_place ' + 'JOIN one_to_one_restaurant ' + 'ON one_to_one_restaurant.place_id = one_to_one_place.id ' + 'JOIN one_to_one_waiter ' + 'ON one_to_one_waiter.restaurant_id = one_to_one_restaurant.place_id ' + '\nWHERE one_to_one_waiter.name = :name_1' + ) + + def _test_build_clause(self, alchemy_db, name, lookup, value, expected, is_negated=False): + backend = SQLAlchemyFilterBackend( + alchemy_db.query(Place), + ) + + clause, to_join = backend.build_clause( + FilterSpec(['restaurant', 'waiter_set', name], lookup, value, is_negated) + ) + + assert to_join == [Place.restaurant, Restaurant.waiter_set] + assert_alchemy_expressions_equal(clause, expected) + + def test_build_clause_contains(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'contains', 'John', + Waiter.name.contains('John') + ) + + def test_build_clause_endswith(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'endswith', 'John', + Waiter.name.endswith('John') + ) + + def test_build_clause_exact(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'exact', 'John', + Waiter.name == 'John' + ) + + def test_build_clause_exact_negated(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'exact', 'John', + Waiter.name != 'John', + True + ) + + def test_build_clause_gt(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'id', 'gt', 1, + Waiter.id > 1 + ) + + def test_build_clause_gte(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'id', 'gte', 1, + Waiter.id >= 1 + ) + + def test_build_clause_icontains(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'icontains', 'Django', + func.lower(Waiter.name).contains('django') + ) + + def test_build_clause_icontains_cant_lower(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'icontains', 5, + func.lower(Waiter.name).contains(5) + ) + + def test_build_clause_iendswith(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'iendswith', 'Django', + func.lower(Waiter.name).endswith('django') + ) + + def test_build_clause_iexact(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'iexact', 'Django', + func.lower(Waiter.name) == 'django' + ) + + def test_build_clause_in(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'in', ['Django', 'rocks'], + Waiter.name.in_(['Django', 'rocks']) + ) + + def test_build_clause_isnull(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'isnull', True, + Waiter.name == None # noqa + ) + self._test_build_clause( + alchemy_db, 'name', 'isnull', False, + Waiter.name != None # noqa + ) + + def test_build_clause_istartswith(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'istartswith', 'Django', + func.lower(Waiter.name).startswith('django') + ) + + def test_build_clause_lt(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'id', 'lt', 1, + Waiter.id < 1 + ) + + def test_build_clause_lte(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'id', 'lte', 1, + Waiter.id <= 1 + ) + + def test_build_clause_range(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'id', 'range', [1, 5], + Waiter.id.between(1, 5) + ) + + def test_build_clause_startswith(self, alchemy_db): + self._test_build_clause( + alchemy_db, 'name', 'startswith', 'Django', + Waiter.name.startswith('Django') + ) + + def test__get_properties_for_model(self): + properties = SQLAlchemyFilterBackend._get_properties_for_model(Waiter) + + assert set(properties) == {'restaurant', 'id', 'restaurant_id', 'name'} + + def test__get_column_for_field(self): + properties = SQLAlchemyFilterBackend._get_properties_for_model(Waiter) + name = properties['name'] + column = SQLAlchemyFilterBackend._get_column_for_field(name) + + assert column.key == 'name' + assert isinstance(column.type, String) + assert column.table is Waiter.__table__ + + def test__get_attribute_for_field(self): + properties = SQLAlchemyFilterBackend._get_properties_for_model(Waiter) + name = properties['name'] + attr = SQLAlchemyFilterBackend._get_attribute_for_field(name) + + assert attr is Waiter.name + + def test__get_related_model_for_field(self): + properties = SQLAlchemyFilterBackend._get_properties_for_model(Waiter) + restaurant = properties['restaurant'] + model = SQLAlchemyFilterBackend._get_related_model_for_field(restaurant) + + assert model is Restaurant diff --git a/tests/conftest.py b/tests/conftest.py index bce3390..d9fb933 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,9 @@ from __future__ import print_function, unicode_literals import pytest +from django.conf import settings from django.core.management import call_command +from sqlalchemy.orm import sessionmaker @pytest.fixture @@ -18,3 +20,15 @@ def many_to_one(db): @pytest.fixture def many_to_many(db): call_command('loaddata', 'many_to_many.json') + + +@pytest.fixture +def alchemy_db(request): + session = sessionmaker(bind=settings.SQLALCHEMY_ENGINE)() + + def fin(): + session.close() + + request.addfinalizer(fin) + + return session diff --git a/tests/filtersets/test_base.py b/tests/filtersets/test_base.py index 10303b2..a41b118 100644 --- a/tests/filtersets/test_base.py +++ b/tests/filtersets/test_base.py @@ -26,6 +26,21 @@ def test_init(self): assert fs.context == {'context': 'here'} assert fs.strict_mode == StrictMode.fail + def test_repr(self): + class FooFilterSet(FilterSet): + foo = Filter(form_field=forms.CharField()) + + class BarFilterSet(FilterSet): + bar = Filter(form_field=forms.IntegerField()) + foo = FooFilterSet() + + assert repr(BarFilterSet()) == ( + 'BarFilterSet()\n' + ' bar = Filter(form_field=IntegerField, lookups=ALL, default_lookup="exact", is_default=False)\n' + ' foo = FooFilterSet()\n' + ' foo = Filter(form_field=CharField, lookups=ALL, default_lookup="exact", is_default=False)' + ) + def test_get_filters(self): class TestFilterSet(FilterSet): foo = Filter(form_field=forms.CharField()) diff --git a/tests/filtersets/test_sqlalchemy.py b/tests/filtersets/test_sqlalchemy.py new file mode 100644 index 0000000..8951db6 --- /dev/null +++ b/tests/filtersets/test_sqlalchemy.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, unicode_literals + +import pytest +from django import forms +from sqlalchemy.orm.properties import ColumnProperty +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.types import Integer, String + +from test_project.many_to_many.alchemy import Article as M2MArticle, Publication +from test_project.many_to_one.alchemy import Article as M2OArticle +from test_project.one_to_one.alchemy import Place, Restaurant +from url_filter.exceptions import SkipFilter +from url_filter.filters import Filter +from url_filter.filtersets.sqlalchemy import SQLAlchemyModelFilterSet + + +class TestSQLAlchemyModelFilterSet(object): + def test_get_filters_no_model(self): + class PlaceFilterSet(SQLAlchemyModelFilterSet): + pass + + with pytest.raises(AssertionError): + PlaceFilterSet().get_filters() + + def test_get_filters_no_relations_place(self): + class PlaceFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = Place + allow_related = False + + filters = PlaceFilterSet().get_filters() + + assert set(filters.keys()) == { + 'id', 'name', 'address', + } + + assert isinstance(filters['id'], Filter) + assert isinstance(filters['id'].form_field, forms.IntegerField) + assert isinstance(filters['name'], Filter) + assert isinstance(filters['name'].form_field, forms.CharField) + assert isinstance(filters['address'], Filter) + assert isinstance(filters['address'].form_field, forms.CharField) + + def test_get_filters_no_relations_restaurant(self): + class RestaurantFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = Restaurant + allow_related = False + + filters = RestaurantFilterSet().get_filters() + + assert set(filters.keys()) == { + 'serves_pizza', 'serves_hot_dogs', 'place_id', + } + + assert isinstance(filters['serves_pizza'], Filter) + assert isinstance(filters['serves_pizza'].form_field, forms.BooleanField) + assert isinstance(filters['place_id'], Filter) + assert isinstance(filters['place_id'].form_field, forms.IntegerField) + assert isinstance(filters['serves_hot_dogs'], Filter) + assert isinstance(filters['serves_hot_dogs'].form_field, forms.BooleanField) + + def test_get_filters_with_only_reverse_relations(self): + class PlaceFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = Place + + filters = PlaceFilterSet().get_filters() + + assert set(filters.keys()) == { + 'id', 'name', 'address', 'restaurant', + } + assert set(filters['restaurant'].filters.keys()) == { + 'serves_pizza', 'serves_hot_dogs', 'waiter_set', 'place_id' + } + + assert isinstance(filters['id'], Filter) + assert isinstance(filters['id'].form_field, forms.IntegerField) + assert isinstance(filters['name'], Filter) + assert isinstance(filters['name'].form_field, forms.CharField) + assert isinstance(filters['address'], Filter) + assert isinstance(filters['address'].form_field, forms.CharField) + assert isinstance(filters['restaurant'], SQLAlchemyModelFilterSet) + + def test_get_filters_with_both_reverse_and_direct_relations(self): + class RestaurantFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = Restaurant + + filters = RestaurantFilterSet().get_filters() + + assert set(filters.keys()) == { + 'place', 'place_id', 'waiter_set', 'serves_hot_dogs', 'serves_pizza', + } + assert set(filters['place'].filters.keys()) == { + 'id', 'name', 'address', + } + assert set(filters['waiter_set'].filters.keys()) == { + 'id', 'name', 'restaurant_id' + } + + assert isinstance(filters['serves_hot_dogs'], Filter) + assert isinstance(filters['serves_hot_dogs'].form_field, forms.BooleanField) + assert isinstance(filters['serves_pizza'], Filter) + assert isinstance(filters['serves_pizza'].form_field, forms.BooleanField) + assert isinstance(filters['place'], SQLAlchemyModelFilterSet) + assert isinstance(filters['waiter_set'], SQLAlchemyModelFilterSet) + + def test_get_filters_with_reverse_many_to_many_relations(self): + class PublicationFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = Publication + + filters = PublicationFilterSet().get_filters() + + assert set(filters.keys()) == { + 'id', 'title', 'articles', + } + assert set(filters['articles'].filters.keys()) == { + 'id', 'headline', + } + + assert isinstance(filters['id'], Filter) + assert isinstance(filters['id'].form_field, forms.IntegerField) + assert isinstance(filters['title'], Filter) + assert isinstance(filters['title'].form_field, forms.CharField) + assert isinstance(filters['articles'], SQLAlchemyModelFilterSet) + + def test_get_filters_with_many_to_many_relations(self): + class ArticleFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = M2MArticle + + filters = ArticleFilterSet().get_filters() + + assert set(filters.keys()) == { + 'id', 'headline', 'publications', + } + assert set(filters['publications'].filters.keys()) == { + 'id', 'title', + } + + assert isinstance(filters['id'], Filter) + assert isinstance(filters['id'].form_field, forms.IntegerField) + assert isinstance(filters['headline'], Filter) + assert isinstance(filters['headline'].form_field, forms.CharField) + assert isinstance(filters['publications'], SQLAlchemyModelFilterSet) + + def test_get_filters_with_many_to_one_relations(self): + class ArticleFilterSet(SQLAlchemyModelFilterSet): + class Meta(object): + model = M2OArticle + + filters = ArticleFilterSet().get_filters() + + assert set(filters.keys()) == { + 'id', 'headline', 'pub_date', 'reporter', 'reporter_id', + } + assert set(filters['reporter'].filters.keys()) == { + 'id', 'email', 'first_name', 'last_name', + } + + assert isinstance(filters['id'], Filter) + assert isinstance(filters['id'].form_field, forms.IntegerField) + assert isinstance(filters['headline'], Filter) + assert isinstance(filters['headline'].form_field, forms.CharField) + assert isinstance(filters['pub_date'], Filter) + assert isinstance(filters['pub_date'].form_field, forms.DateField) + assert isinstance(filters['reporter'], SQLAlchemyModelFilterSet) + + def test_get_form_field_for_field(self): + fs = SQLAlchemyModelFilterSet() + + assert isinstance( + fs.get_form_field_for_field(ColumnProperty(Column('name', String(50)))), + forms.CharField + ) + assert isinstance( + fs.get_form_field_for_field(ColumnProperty(Column('name', Integer))), + forms.IntegerField + ) + + with pytest.raises(SkipFilter): + fs.get_form_field_for_field(ColumnProperty(Column('name', TypeEngine))) diff --git a/tests/integrations/test_drf.py b/tests/integrations/test_drf.py index f93ecdd..4188a54 100644 --- a/tests/integrations/test_drf.py +++ b/tests/integrations/test_drf.py @@ -22,15 +22,6 @@ class View(object): assert filter_class is PlaceFilterSet - def test_get_filter_class_supplied_model_mismatch(self): - class View(object): - filter_class = PlaceFilterSet - - with pytest.raises(AssertionError): - DjangoFilterBackend().get_filter_class( - View(), Restaurant.objects.all() - ) - def test_get_filter_class_by_filter_fields(self): class View(object): filter_fields = ['name'] @@ -71,3 +62,19 @@ class View(object): ) assert filtered == mock_filter.return_value + + @mock.patch.object(FilterSet, 'filter') + def test_filter_queryset_supplied_model_mismatch(self, mock_filter, db, rf): + class View(object): + filter_class = PlaceFilterSet + filter_fields = ['name'] + + request = rf.get('/') + request.query_params = QueryDict() + + with pytest.raises(AssertionError): + DjangoFilterBackend().filter_queryset( + request=request, + queryset=Restaurant.objects.all(), + view=View() + ) diff --git a/tests/test_filters.py b/tests/test_filters.py index b854ce9..28e990c 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,14 +1,20 @@ # -*- coding: utf-8 -*- from __future__ import print_function, unicode_literals +from functools import partial +import mock import pytest from django import forms +from url_filter.backends.django import DjangoFilterBackend from url_filter.fields import MultipleValuesField -from url_filter.filters import Filter +from url_filter.filters import Filter as _Filter from url_filter.utils import FilterSpec, LookupConfig +Filter = partial(_Filter, lookups=DjangoFilterBackend.supported_lookups) + + class TestFilter(object): def test_init(self): f = Filter( @@ -21,12 +27,37 @@ def test_init(self): assert f.source == 'foo' assert isinstance(f.form_field, forms.CharField) - assert f.lookups == ['foo', 'bar'] + assert f.lookups == {'foo', 'bar'} assert f.default_lookup == 'foo' assert f.is_default is True assert f.parent is None assert f.name is None + def test_lookups(self): + assert Filter(form_field=None, lookups=['foo', 'bar']).lookups == {'foo', 'bar'} + assert Filter(form_field=None, lookups=None).lookups == set() + + f = Filter(form_field=None, lookups=None) + f.parent = mock.Mock() + f.parent.root = f.parent + f.parent.filter_backend.supported_lookups = DjangoFilterBackend.supported_lookups + + assert f.lookups == DjangoFilterBackend.supported_lookups + + def test_repr(self): + f = Filter( + source='foo', + lookups=None, + form_field=forms.CharField(), + default_lookup='foo', + is_default=True, + ) + + assert repr(f) == ( + 'Filter(form_field=CharField, lookups=ALL, ' + 'default_lookup="foo", is_default=True)' + ) + def test_source(self): f = Filter(source=None, form_field=forms.CharField()) f.name = 'bar' diff --git a/tests/test_utils.py b/tests/test_utils.py index 7e352c1..4d8be7d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -71,11 +71,14 @@ class Bar(Foo): pass mapping = SubClassDict({ + 'a': 'b', + 'z': 'b', Foo: 'foo', Klass: 'klass', 'key': 'value', }) + assert mapping.get('a') == 'b' assert mapping.get('key') == 'value' assert mapping.get(Klass) == 'klass' assert mapping.get(Foo) == 'foo' diff --git a/url_filter/__init__.py b/url_filter/__init__.py index 59b084e..870a863 100644 --- a/url_filter/__init__.py +++ b/url_filter/__init__.py @@ -4,4 +4,4 @@ __author__ = 'Miroslav Shubernetskiy' __email__ = 'miroslav@miki725.com' -__version__ = '0.1.1' +__version__ = '0.2.0' diff --git a/url_filter/backends/base.py b/url_filter/backends/base.py new file mode 100644 index 0000000..f7b6367 --- /dev/null +++ b/url_filter/backends/base.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals +import abc + +import six +from cached_property import cached_property + + +class BaseFilterBackend(six.with_metaclass(abc.ABCMeta, object)): + supported_lookups = set() + + def __init__(self, queryset, context=None): + self.queryset = queryset + self.context = context or {} + self.specs = [] + + @cached_property + def model(self): + return self.get_model() + + def bind(self, specs): + self.specs = specs + + @abc.abstractmethod + def get_model(self): + """ + Get the queryset model. + + .. note:: **MUST** be implemented by subclasses + """ + + @abc.abstractmethod + def filter(self): + """ + Main method for filtering queryset. + + .. note:: **MUST** be implemented by subclasses + """ diff --git a/url_filter/backends/django.py b/url_filter/backends/django.py index c813819..d900c25 100644 --- a/url_filter/backends/django.py +++ b/url_filter/backends/django.py @@ -3,14 +3,39 @@ from django.db.models.constants import LOOKUP_SEP +from .base import BaseFilterBackend -class DjangoFilterBackend(object): - def __init__(self, queryset, context=None): - self.queryset = queryset - self.context = context or {} - def bind(self, specs): - self.specs = specs +class DjangoFilterBackend(BaseFilterBackend): + supported_lookups = { + 'contains', + 'day', + 'endswith', + 'exact', + 'gt', + 'gte', + 'hour', + 'icontains', + 'iendswith', + 'iexact', + 'in', + 'iregex', + 'isnull', + 'istartswith', + 'lt', + 'lte', + 'minute', + 'month', + 'range', + 'regex', + 'second', + 'startswith', + 'week_day', + 'year', + } + + def get_model(self): + return self.queryset.model @property def includes(self): diff --git a/url_filter/backends/sqlalchemy.py b/url_filter/backends/sqlalchemy.py new file mode 100644 index 0000000..aaeeb99 --- /dev/null +++ b/url_filter/backends/sqlalchemy.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals +import itertools + +from sqlalchemy import func +from sqlalchemy.orm import class_mapper +from sqlalchemy.sql.expression import not_ + +from .base import BaseFilterBackend + + +def lower(value): + try: + return value.lower() + except AttributeError: + return value + + +class SQLAlchemyFilterBackend(BaseFilterBackend): + supported_lookups = { + 'contains', + 'endswith', + 'exact', + 'gt', + 'gte', + 'icontains', + 'iendswith', + 'iexact', + 'in', + 'isnull', + 'istartswith', + 'lt', + 'lte', + 'range', + 'startswith', + } + + def __init__(self, *args, **kwargs): + super(SQLAlchemyFilterBackend, self).__init__(*args, **kwargs) + + assert len(self.queryset._entities) == 1, ( + '{} does not support filtering when multiple entities ' + 'are being queried (e.g. session.query(Foo, Bar)).' + ''.format(self.__class__.__name__) + ) + + def get_model(self): + return self.queryset._primary_entity.entities[0] + + def filter(self): + if not self.specs: + return self.queryset + + clauses = [self.build_clause(spec) for spec in self.specs] + conditions, joins = zip(*clauses) + joins = list(itertools.chain(*joins)) + + qs = self.queryset + if joins: + qs = qs.join(*joins) + + return qs.filter(*conditions) + + def build_clause(self, spec): + to_join = [] + + model = self.model + for component in spec.components: + _field = getattr(model, component) + field = self._get_properties_for_model(model)[component] + try: + model = self._get_related_model_for_field(field) + except AttributeError: + break + else: + to_join.append(_field) + + builder = getattr(self, '_build_clause_{}'.format(spec.lookup)) + column = self._get_attribute_for_field(field) + clause = builder(spec, column) + + if spec.is_negated: + clause = not_(clause) + + return clause, to_join + + def _build_clause_contains(self, spec, column): + return column.contains(spec.value) + + def _build_clause_endswith(self, spec, column): + return column.endswith(spec.value) + + def _build_clause_exact(self, spec, column): + return column == spec.value + + def _build_clause_gt(self, spec, column): + return column > spec.value + + def _build_clause_gte(self, spec, column): + return column >= spec.value + + def _build_clause_icontains(self, spec, column): + return func.lower(column).contains(lower(spec.value)) + + def _build_clause_iendswith(self, spec, column): + return func.lower(column).endswith(lower(spec.value)) + + def _build_clause_iexact(self, spec, column): + return func.lower(column) == lower(spec.value) + + def _build_clause_in(self, spec, column): + return column.in_(spec.value) + + def _build_clause_isnull(self, spec, column): + if spec.value: + return column == None # noqa + else: + return column != None # noqa + + def _build_clause_istartswith(self, spec, column): + return func.lower(column).startswith(lower(spec.value)) + + def _build_clause_lt(self, spec, column): + return column < spec.value + + def _build_clause_lte(self, spec, column): + return column <= spec.value + + def _build_clause_range(self, spec, column): + return column.between(*spec.value) + + def _build_clause_startswith(self, spec, column): + return column.startswith(spec.value) + + @classmethod + def _get_properties_for_model(cls, model): + mapper = class_mapper(model) + return { + i.key: i + for i in mapper.iterate_properties + } + + @classmethod + def _get_column_for_field(cls, field): + return field.columns[0] + + @classmethod + def _get_attribute_for_field(cls, field): + return field.class_attribute + + @classmethod + def _get_related_model_for_field(self, field): + return field._dependency_processor.mapper.class_ diff --git a/url_filter/filters.py b/url_filter/filters.py index 5289428..b538594 100644 --- a/url_filter/filters.py +++ b/url_filter/filters.py @@ -2,9 +2,10 @@ from __future__ import absolute_import, print_function, unicode_literals from functools import partial +import six +from cached_property import cached_property from django import forms from django.core.exceptions import ValidationError -from django.db.models.sql.constants import QUERY_TERMS from .fields import MultipleValuesField from .utils import FilterSpec @@ -85,10 +86,38 @@ def __init__(self, source=None, *args, **kwargs): def _init(self, form_field, lookups=None, default_lookup='exact', is_default=False): self.form_field = form_field - self.lookups = lookups or list(QUERY_TERMS) + self._given_lookups = lookups self.default_lookup = default_lookup or self.default_lookup self.is_default = is_default + def repr(self, prefix=''): + return ( + '{name}(' + 'form_field={form_field}, ' + 'lookups={lookups}, ' + 'default_lookup="{default_lookup}", ' + 'is_default={is_default}' + ')' + ''.format(name=self.__class__.__name__, + form_field=self.form_field.__class__.__name__, + lookups=self._given_lookups or 'ALL', + default_lookup=self.default_lookup, + is_default=self.is_default) + ) + + def __repr__(self): + data = self.repr() + data = data if six.PY3 else data.encode('utf-8') + return data + + @cached_property + def lookups(self): + if self._given_lookups: + return set(self._given_lookups) + if hasattr(self.root, 'filter_backend'): + return self.root.filter_backend.supported_lookups + return set() + @property def source(self): """ diff --git a/url_filter/filtersets/base.py b/url_filter/filtersets/base.py index 8349cd5..96d13e1 100644 --- a/url_filter/filtersets/base.py +++ b/url_filter/filtersets/base.py @@ -155,6 +155,18 @@ def _init(self, data=None, queryset=None, context=None, self.context = context or {} self.strict_mode = strict_mode + def repr(self, prefix=''): + header = '{name}()'.format(name=self.__class__.__name__) + lines = [header] + [ + '{prefix}{key} = {value}'.format( + prefix=prefix + ' ', + key=k, + value=v.repr(prefix=prefix + ' '), + ) + for k, v in sorted(self.filters.items()) + ] + return '\n'.join(lines) + def get_filters(self): """ Get all filters defined in this filterset. @@ -216,6 +228,22 @@ def get_filter_backend(self): context=self.context, ) + @cached_property + def filter_backend(self): + """ + Property for getting instantiated filter backend. + + Primarily useful when accessing filter_backend outside + of the filterset such as leaf filters or integration + layers since backend has useful information for both of + those examples. + """ + assert self.data is not None, ( + 'Filter backend can only be used when data is provided ' + 'to filterset.' + ) + return self.get_filter_backend() + def filter(self): """ Main method which should be used on root ``FilterSet`` @@ -247,7 +275,6 @@ def filter(self): ) specs = self.get_specs() - self.filter_backend = self.get_filter_backend() self.filter_backend.bind(specs) return self.filter_backend.filter() diff --git a/url_filter/filtersets/django.py b/url_filter/filtersets/django.py index 1a934b0..54ac273 100644 --- a/url_filter/filtersets/django.py +++ b/url_filter/filtersets/django.py @@ -55,7 +55,6 @@ class ModelFilterSet(FilterSet): The filterset can be configured via ``Meta`` class attribute, very much like Django's ``ModelForm`` is configured. - """ filter_options_class = ModelFilterSetOptions @@ -116,7 +115,7 @@ def get_model_field_names(self): def get_form_field_for_field(self, field): """ - Get form field for the given Djagno model field. + Get form field for the given Django model field. By default ``Field.formfield()`` is used to get the form field unless an overwrite is present for the field. diff --git a/url_filter/filtersets/sqlalchemy.py b/url_filter/filtersets/sqlalchemy.py new file mode 100644 index 0000000..c54d25c --- /dev/null +++ b/url_filter/filtersets/sqlalchemy.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, unicode_literals +import inspect +from functools import partial + +from django import forms +from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty +from sqlalchemy.types import ( + BIGINT, + CHAR, + CLOB, + DATE, + DECIMAL, + INTEGER, + SMALLINT, + TIMESTAMP, + VARCHAR, + BigInteger, + Boolean, + Date, + DateTime, + Float, + Integer, + Numeric, + String, +) + +from ..backends.sqlalchemy import SQLAlchemyFilterBackend +from ..exceptions import SkipFilter +from ..filters import Filter +from ..utils import SubClassDict +from .base import FilterSet +from .django import ModelFilterSetOptions + + +__all__ = ['SQLAlchemyModelFilterSet'] + + +_STRING = lambda field, column: forms.CharField(max_length=column.type.length) + +SQLALCHEMY_FIELD_MAPPING = SubClassDict({ + BIGINT: forms.IntegerField, + BigInteger: forms.IntegerField, + Integer: forms.IntegerField, + Boolean: partial(forms.BooleanField, required=False), + CHAR: _STRING, + CLOB: _STRING, + DATE: forms.DateTimeField, + Date: forms.DateField, + DateTime: forms.DateTimeField, + DECIMAL: forms.DecimalField, + Float: forms.FloatField, + INTEGER: forms.IntegerField, + Numeric: forms.IntegerField, + SMALLINT: forms.IntegerField, + String: _STRING, + TIMESTAMP: forms.DateTimeField, + VARCHAR: _STRING, +}) + + +class SQLAlchemyModelFilterSet(FilterSet): + """ + ``FilterSet`` for SQLAlchemy models. + + The filterset can be configured via ``Meta`` class attribute, + very much like Django's ``ModelForm`` is configured. + """ + filter_options_class = ModelFilterSetOptions + + def get_filters(self): + """ + Get all filters defined in this filterset including + filters corresponding to Django model fields. + """ + filters = super(SQLAlchemyModelFilterSet, self).get_filters() + + assert self.Meta.model, ( + '{}.Meta.model is missing. Please specify the model ' + 'in order to use ModelFilterSet.' + ''.format(self.__class__.__name__) + ) + + if self.Meta.fields is None: + self.Meta.fields = self.get_model_field_names() + + fields = SQLAlchemyFilterBackend._get_properties_for_model(self.Meta.model) + + for name in self.Meta.fields: + if name in self.Meta.exclude: + continue + + field = fields[name] + + try: + _filter = None + + if isinstance(field, ColumnProperty): + _filter = self.build_filter_from_field(field) + elif isinstance(field, RelationshipProperty): + if not self.Meta.allow_related: + raise SkipFilter + _filter = self.build_filterset_from_related_field(field) + + except SkipFilter: + continue + + else: + if _filter is not None: + filters[name] = _filter + + return filters + + def get_model_field_names(self): + """ + Get a list of all model fields. + + This is used when ``Meta.fields`` is ``None`` + in which case this method returns all model fields. + """ + return list(SQLAlchemyFilterBackend._get_properties_for_model(self.Meta.model).keys()) + + def get_form_field_for_field(self, field): + """ + Get form field for the given SQLAlchemy model field. + """ + column = SQLAlchemyFilterBackend._get_column_for_field(field) + + form_field = SQLALCHEMY_FIELD_MAPPING.get( + column.type.__class__, None, + ) + + if form_field is None: + raise SkipFilter + + if inspect.isclass(form_field) or isinstance(form_field, partial): + return form_field() + else: + return form_field(field, column) + + def build_filter_from_field(self, field): + """ + Build ``Filter`` for a standard SQLAlchemy model field. + """ + column = SQLAlchemyFilterBackend._get_column_for_field(field) + + return Filter( + form_field=self.get_form_field_for_field(field), + is_default=column.primary_key, + ) + + def build_filterset_from_related_field(self, field): + m = SQLAlchemyFilterBackend._get_related_model_for_field(field) + meta = { + 'model': m, + 'exclude': [field.back_populates] + } + + meta = type(str('Meta'), (object,), meta) + + filterset = type( + str('{}FilterSet'.format(m.__name__)), + (SQLAlchemyModelFilterSet,), + { + 'Meta': meta, + '__module__': self.__module__, + } + ) + + return filterset() diff --git a/url_filter/integrations/drf.py b/url_filter/integrations/drf.py index 7bd23e2..d3da066 100644 --- a/url_filter/integrations/drf.py +++ b/url_filter/integrations/drf.py @@ -16,14 +16,6 @@ def get_filter_class(self, view, queryset=None): filter_fields = getattr(view, 'filter_fields', None) if filter_class: - filter_model = getattr(filter_class.Meta, 'model', None) - - if filter_model: - assert issubclass(queryset.model, filter_model), ( - 'FilterSet model {} does not match queryset model {}' - ''.format(filter_model, queryset.model) - ) - return filter_class if filter_fields: @@ -57,6 +49,15 @@ def filter_queryset(self, request, queryset, view): queryset=queryset, context=self.get_filter_context(request, view), ) + + filter_model = getattr(_filter.Meta, 'model', None) + if filter_model: + model = _filter.filter_backend.model + assert issubclass(model, filter_model), ( + 'FilterSet model {} does not match queryset model {}' + ''.format(filter_model, model) + ) + return _filter.filter() return queryset