Skip to content

Commit

Permalink
fixed bits comparision for postgres compability
Browse files Browse the repository at this point in the history
  • Loading branch information
m.semenov committed May 28, 2020
1 parent dcdcd79 commit 947d13a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
8 changes: 5 additions & 3 deletions protector/managers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from django.db import models
from django.db import models, connection
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
from django.contrib.auth import get_user_model
Expand Down Expand Up @@ -82,7 +82,8 @@ def by_ctype(self, group_type, roles=None):
user=self.instance
)
if roles is not None:
utg_qset = utg_qset.extra(where=["roles & %s"], params=[roles])
rule = "roles & %s" if not connection.vendor == 'postgresql' else "(roles & %s)::boolean"
utg_qset = utg_qset.extra(where=[rule], params=[roles])
return group_model.objects.filter(
pk__in=utg_qset.values_list('group_id')
)
Expand Down Expand Up @@ -140,7 +141,8 @@ def by_role(self, roles):
if roles is None:
links = links.filter(roles__isnull=True)
else:
links = links.extra(where=["roles & %s"], params=[roles])
rule = "roles & %s" if not connection.vendor == 'postgresql' else "(roles & %s)::boolean"
links = links.extra(where=[rule], params=[roles])
return get_user_model().objects.filter(id__in=links.values_list('user_id', flat=True))


Expand Down
9 changes: 5 additions & 4 deletions protector/querysets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import unicode_literals

from copy import deepcopy
from django.db import connection
from django.db.models.query import QuerySet
from django.db.models import F
from django.apps import apps
Expand All @@ -25,10 +26,10 @@ def filter_by_permission(self, user, permission):
class GenericUserToGroupQuerySet(QuerySet):
def by_role(self, roles):
utg_table_name = self.model._meta.db_table
return self.extra(
where=["{utg_table}.roles & %s".format(utg_table=utg_table_name)],
params=[roles]
)
rule = "{utg_table}.roles & %s".format(utg_table=utg_table_name)
if connection.vendor == 'postgresql':
rule = "({})::boolean".format(rule)
return self.extra(where=[rule], params=[roles])

@check_responsible_reason
def delete(self, **kwargs):
Expand Down

0 comments on commit 947d13a

Please sign in to comment.