Skip to content

Commit

Permalink
Fix handling of UUIDField for mysql, mariadb and sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-k committed Apr 14, 2020
1 parent 9c98180 commit d816cfe
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions guardian/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,29 @@
"""
import warnings
from collections import defaultdict
from functools import partial
from itertools import groupby

from django.apps import apps
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group, Permission
from django.contrib.contenttypes.models import ContentType
from django.db import connection
from django.db.models import Count, Q, QuerySet
from django.shortcuts import _get_queryset
from django.db.models.functions import Cast
from django.db.models.expressions import Value
from django.db.models.functions import Cast, Replace
from django.db.models import (
IntegerField,
AutoField,
BigIntegerField,
CharField,
ForeignKey,
IntegerField,
PositiveIntegerField,
PositiveSmallIntegerField,
SmallIntegerField,
ForeignKey
UUIDField,
)
)
from guardian.core import ObjectPermissionChecker
from guardian.ctypes import get_content_type
from guardian.exceptions import MixedContentTypeError, WrongAppError, MultipleIdentityAndObjectError
Expand Down Expand Up @@ -624,26 +628,21 @@ def get_objects_for_user(user, perms, klass=None, use_groups=True, any_perm=Fals
field_pk = user_fields[0]
values = user_obj_perms_queryset

cast_pk_field = _cast_pk_field(queryset)

if cast_pk_field:
values = values.annotate(
obj_pk=Cast(field_pk, cast_pk_field())
)
handle_pk_field = _handle_pk_field(queryset)
if handle_pk_field is not None:
values = values.annotate(obj_pk=handle_pk_field(expression=field_pk))
field_pk = 'obj_pk'

values = values.values_list(field_pk, flat=True)
q = Q(pk__in=list(values))
q = Q(pk__in=values)
if use_groups:
field_pk = group_fields[0]
values = groups_obj_perms_queryset
if cast_pk_field:
values = values.annotate(
obj_pk=Cast(field_pk, cast_pk_field())
)
if handle_pk_field is not None:
values = values.annotate(obj_pk=handle_pk_field(expression=field_pk))
field_pk = 'obj_pk'
values = values.values_list(field_pk, flat=True)
q |= Q(pk__in=list(values))
q |= Q(pk__in=values)

return queryset.filter(q)

Expand Down Expand Up @@ -791,23 +790,20 @@ def get_objects_for_group(group, perms, klass=None, any_perm=False, accept_globa
field_pk = fields[0]
values = groups_obj_perms_queryset

cast_pk_field = _cast_pk_field(queryset)

if cast_pk_field:
values = values.annotate(
obj_pk=Cast(field_pk, cast_pk_field())
)
handle_pk_field = _handle_pk_field(queryset)
if handle_pk_field is not None:
values = values.annotate(obj_pk=handle_pk_field(expression=field_pk))
field_pk = 'obj_pk'

values = values.values_list(field_pk, flat=True)
return queryset.filter(pk__in=values)


def _cast_pk_field(queryset):
def _handle_pk_field(queryset):
pk = queryset.model._meta.pk

if isinstance(pk, ForeignKey):
return _cast_pk_field(pk.target_field)
return _handle_pk_field(pk.target_field)

if isinstance(
pk,
Expand All @@ -820,9 +816,16 @@ def _cast_pk_field(queryset):
SmallIntegerField,
),
):
return BigIntegerField
return partial(Cast, output_field=BigIntegerField())

if isinstance(pk, UUIDField):
return UUIDField
if connection.features.has_native_uuid_field:
return partial(Cast, output_field=UUIDField())
return partial(
Replace,
text=Value('-'),
replacement=Value(''),
output_field=CharField(),
)

return None

0 comments on commit d816cfe

Please sign in to comment.