Skip to content

Commit

Permalink
Merge pull request #4700 from khalibloo/fix-attrs-resolver
Browse files Browse the repository at this point in the history
fix: attributes resolver inCategory and inCollection
  • Loading branch information
maarcingebala committed Sep 27, 2019
2 parents 0397b33 + fdd787a commit e44fcd5
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 74 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ All notable, unreleased changes to this project will be documented in this file.
- PaymentGatewayEnum removed from GraphQL schema as gateways now are dynamic plugins. Gateway names changed. - #4756 by @salwator
- Add support for webhooks - #4731 by @korycins
- Fixed the random failure of `populatedb` trying to create a new user with an existing email - #4769 by @NyanKiyoshi
- Fixed the inability of filtering attributes using `inCategory` and `inCollection` and deprecated those fields to use `filter { inCollection: ..., inCategory: ... }` instead - #4700 by @NyanKiyoshi & @khalibloo

## 2.8.0

Expand Down
1 change: 0 additions & 1 deletion saleor/graphql/core/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def get_output_fields(model, return_field_name):


def get_error_fields(error_type_class, error_type_field):
""" """
return {
error_type_field: graphene.Field(
graphene.List(
Expand Down
37 changes: 36 additions & 1 deletion saleor/graphql/product/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict

import django_filters
from django.db.models import Sum
from django.db.models import Q, Sum
from graphene_django.filter import GlobalIDFilter, GlobalIDMultipleChoiceFilter

from ...product.filters import (
Expand All @@ -13,6 +13,7 @@
from ..core.filters import EnumFilter, ListObjectTypeFilter, ObjectTypeFilter
from ..core.types import FilterInputObjectType
from ..core.types.common import PriceRangeInput
from ..core.utils import from_global_id_strict_type
from ..utils import filter_by_query_param, get_nodes
from . import types
from .enums import (
Expand Down Expand Up @@ -178,6 +179,37 @@ def filter_product_type(qs, _, value):
return qs


def filter_attributes_by_product_types(qs, field, value):
if not value:
return qs

if field == "in_category":
category_id = from_global_id_strict_type(
value, only_type="Category", field=field
)
category = Category.objects.filter(pk=category_id).first()

if category is None:
return qs.none()

tree = category.get_descendants(include_self=True)
product_qs = Product.objects.filter(category__in=tree)

elif field == "in_collection":
collection_id = from_global_id_strict_type(
value, only_type="Collection", field=field
)
product_qs = Product.objects.filter(collections__id=collection_id)

else:
raise NotImplementedError(f"Filtering by {field} is unsupported")

product_types = set(product_qs.values_list("product_type_id", flat=True))
return qs.filter(
Q(product_types__in=product_types) | Q(product_variant_types__in=product_types)
)


class ProductFilter(django_filters.FilterSet):
is_published = django_filters.BooleanFilter()
collections = GlobalIDMultipleChoiceFilter(method=filter_collections)
Expand Down Expand Up @@ -257,6 +289,9 @@ class AttributeFilter(django_filters.FilterSet):
)
ids = GlobalIDMultipleChoiceFilter(field_name="id")

in_collection = GlobalIDFilter(method=filter_attributes_by_product_types)
in_category = GlobalIDFilter(method=filter_attributes_by_product_types)

class Meta:
model = Attribute
fields = [
Expand Down
38 changes: 8 additions & 30 deletions saleor/graphql/product/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import TYPE_CHECKING, Optional

import graphene
import graphene_django_optimizer as gql_optimizer
from django.db.models import Q, Sum
from django.db.models import Sum
from graphql import GraphQLError
from graphql_relay import from_global_id

Expand All @@ -12,6 +11,7 @@
from ..utils import filter_by_period, filter_by_query_param, get_database_id, get_nodes
from .enums import AttributeSortField, OrderDirection
from .filters import (
filter_attributes_by_product_types,
filter_products_by_attributes,
filter_products_by_categories,
filter_products_by_collections,
Expand All @@ -31,45 +31,23 @@
ATTRIBUTES_SEARCH_FIELDS = ("name", "slug")


def _filter_attributes_by_product_types(attribute_qs, product_qs):
product_types = set(product_qs.values_list("product_type_id", flat=True))
return attribute_qs.filter(
Q(product_types__in=product_types) | Q(product_variant_types__in=product_types)
)


def resolve_attributes(
info,
qs=None,
category_id=None,
collection_id=None,
in_category=None,
in_collection=None,
query=None,
sort_by=None,
**_kwargs,
):
qs = qs or models.Attribute.objects.get_visible_to_user(info.context.user)
qs = filter_by_query_param(qs, query, ATTRIBUTES_SEARCH_FIELDS)

if category_id:
# Filter attributes by product types belonging to the given category.
category = graphene.Node.get_node_from_global_id(info, category_id, "Category")
if category:
tree = category.get_descendants(include_self=True)
product_qs = models.Product.objects.filter(category__in=tree)
qs = _filter_attributes_by_product_types(qs, product_qs)
else:
qs = qs.none()
if in_category:
qs = filter_attributes_by_product_types(qs, "in_category", in_category)

if collection_id:
# Filter attributes by product types belonging to the given collection.
collection = graphene.Node.get_node_from_global_id(
info, collection_id, "Collection"
)
if collection:
product_qs = collection.products.all()
qs = _filter_attributes_by_product_types(qs, product_qs)
else:
qs = qs.none()
if in_collection:
qs = filter_attributes_by_product_types(qs, "in_collection", in_collection)

if sort_by:
is_asc = sort_by["direction"] == OrderDirection.ASC.value
Expand Down
14 changes: 10 additions & 4 deletions saleor/graphql/product/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,19 @@ class ProductQueries(graphene.ObjectType):
query=graphene.String(description=DESCRIPTIONS["attributes"]),
in_category=graphene.Argument(
graphene.ID,
description="""Return attributes for products
belonging to the given category.""",
description=(
"Return attributes for products belonging to the given category. ",
"DEPRECATED: "
"Will be removed in Saleor 2.10, use the `filter` field instead.",
),
),
in_collection=graphene.Argument(
graphene.ID,
description="""Return attributes for products
belonging to the given collection.""",
description=(
"Return attributes for products belonging to the given collection. ",
"DEPRECATED: "
"Will be removed in Saleor 2.10, use the `filter` field instead.",
),
),
filter=AttributeFilterInput(description="Filtering options for attributes."),
sort_by=graphene.Argument(
Expand Down
2 changes: 2 additions & 0 deletions saleor/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ input AttributeFilterInput {
availableInGrid: Boolean
search: String
ids: [ID]
inCollection: ID
inCategory: ID
}

input AttributeInput {
Expand Down
8 changes: 5 additions & 3 deletions saleor/product/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ def sort_by_attribute(self, attribute_pk: Union[int, str], ascending: bool = Tru

# Retrieve all the products' attribute data IDs (assignments) and
# product types that have the given attribute associated to them
associated_values = AttributeProduct.objects.filter(
attribute_id=attribute_pk
).values_list("pk", "product_type_id")
associated_values = tuple(
AttributeProduct.objects.filter(attribute_id=attribute_pk).values_list(
"pk", "product_type_id"
)
)

if not associated_values:
if not ascending:
Expand Down
130 changes: 95 additions & 35 deletions tests/api/test_attributes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union
from unittest import mock

import graphene
import pytest
Expand All @@ -7,8 +8,10 @@
from django.template.defaultfilters import slugify
from graphene.utils.str_converters import to_camel_case

from saleor.core.taxes import zero_money
from saleor.graphql.core.utils import snake_to_camel_case
from saleor.graphql.product.enums import AttributeTypeEnum, AttributeValueType
from saleor.graphql.product.filters import filter_attributes_by_product_types
from saleor.graphql.product.mutations.attributes import validate_value_is_unique
from saleor.graphql.product.types.attributes import resolve_attribute_value_type
from saleor.product import AttributeInputType
Expand All @@ -19,6 +22,7 @@
AttributeValue,
AttributeVariant,
Category,
Collection,
Product,
ProductType,
ProductVariant,
Expand Down Expand Up @@ -311,41 +315,89 @@ def test_resolve_attribute_values_non_assigned_to_node(
assert variant_attributes[0]["value"] is None


def test_attributes_in_category_query(user_api_client, product):
category = Category.objects.first()
query = """
query {
attributes(inCategory: "%(category_id)s", first: 20) {
edges {
node {
id
name
slug
}
}
}
}
""" % {
"category_id": graphene.Node.to_global_id("Category", category.id)
}
response = user_api_client.post_graphql(query)
content = get_graphql_content(response)
attributes_data = content["data"]["attributes"]["edges"]
assert len(attributes_data) == Attribute.objects.count()
def test_attributes_filter_by_product_type_with_empty_value():
"""Ensure passing an empty or null value is ignored and the queryset is simply
returned without any modification.
"""

qs = Attribute.objects.all()

assert filter_attributes_by_product_types(qs, "...", "") is qs
assert filter_attributes_by_product_types(qs, "...", None) is qs


def test_attributes_filter_by_product_type_with_unsupported_field():
"""Ensure using an unknown field to filter attributes by raises a NotImplemented
exception.
"""

qs = Attribute.objects.all()

with pytest.raises(NotImplementedError) as exc:
filter_attributes_by_product_types(qs, "in_space", "a-value")

assert exc.value.args == ("Filtering by in_space is unsupported",)

def test_attributes_in_collection_query(user_api_client, collection):
product_types = set(
collection.products.all().values_list("product_type_id", flat=True)

def test_attributes_filter_by_non_existing_category_id():
"""Ensure using a non-existing category ID returns an empty query set."""

category_id = graphene.Node.to_global_id("Category", -1)
mocked_qs = mock.MagicMock()
qs = filter_attributes_by_product_types(mocked_qs, "in_category", category_id)
assert qs == mocked_qs.none.return_value


@pytest.mark.parametrize("test_deprecated_filter", [True, False])
@pytest.mark.parametrize("tested_field", ["inCategory", "inCollection"])
def test_attributes_in_collection_query(
user_api_client,
product_type,
category,
collection,
collection_with_products,
test_deprecated_filter,
tested_field,
):
if "Collection" in tested_field:
filtered_by_node_id = graphene.Node.to_global_id("Collection", collection.pk)
elif "Category" in tested_field:
filtered_by_node_id = graphene.Node.to_global_id("Category", category.pk)
else:
raise AssertionError(tested_field)
expected_qs = Attribute.objects.filter(
Q(attributeproduct__product_type_id=product_type.pk)
| Q(attributevariant__product_type_id=product_type.pk)
)

# Create another product type and attribute that shouldn't get matched
other_category = Category.objects.create(name="Other Category", slug="other-cat")
other_attribute = Attribute.objects.create(name="Other", slug="other")
other_product_type = ProductType.objects.create(
name="Other type", has_variants=True, is_shipping_required=True
)
expected_attrs = Attribute.objects.filter(
Q(attributeproduct__product_type_id__in=product_types)
| Q(attributevariant__product_type_id__in=product_types)
other_product_type.product_attributes.add(other_attribute)
other_product = Product.objects.create(
name=f"Another Product",
product_type=other_product_type,
category=other_category,
price=zero_money(),
is_published=True,
)

# Create another collection with products but shouldn't get matched
# as we don't look for this other collection
other_collection = Collection.objects.create(
name="Other Collection",
slug="other-collection",
is_published=True,
description="Description",
)
other_collection.products.add(other_product)

query = """
query {
attributes(inCollection: "%(collection_id)s", first: 20) {
query($nodeID: ID!) {
attributes(first: 20, %(filter_input)s) {
edges {
node {
id
Expand All @@ -355,13 +407,21 @@ def test_attributes_in_collection_query(user_api_client, collection):
}
}
}
""" % {
"collection_id": graphene.Node.to_global_id("Collection", collection.pk)
}
response = user_api_client.post_graphql(query)
content = get_graphql_content(response)
"""

if test_deprecated_filter:
query = query % {"filter_input": f"{tested_field}: $nodeID"}
else:
query = query % {"filter_input": "filter: { %s: $nodeID }" % tested_field}

variables = {"nodeID": filtered_by_node_id}
content = get_graphql_content(user_api_client.post_graphql(query, variables))
attributes_data = content["data"]["attributes"]["edges"]
assert len(attributes_data) == len(expected_attrs)

flat_attributes_data = [attr["node"]["slug"] for attr in attributes_data]
expected_flat_attributes_data = list(expected_qs.values_list("slug", flat=True))

assert flat_attributes_data == expected_flat_attributes_data


CREATE_ATTRIBUTES_QUERY = """
Expand Down

0 comments on commit e44fcd5

Please sign in to comment.