From 467098a24f10490976b703a042685fbddb6cb932 Mon Sep 17 00:00:00 2001 From: Ali Nawaz Date: Mon, 6 May 2024 13:49:05 +0500 Subject: [PATCH 1/6] initial commit --- course_discovery/apps/api/serializers.py | 46 ++++++++++++++----- .../apps/api/v1/views/course_runs.py | 6 ++- .../apps/course_metadata/algolia_models.py | 36 ++++++++------- .../apps/course_metadata/models.py | 35 ++++++++------ .../apps/course_metadata/utils.py | 13 ++++-- .../apps/learner_pathway/utils.py | 2 +- 6 files changed, 91 insertions(+), 47 deletions(-) diff --git a/course_discovery/apps/api/serializers.py b/course_discovery/apps/api/serializers.py index b0cff4420e..a51da92523 100644 --- a/course_discovery/apps/api/serializers.py +++ b/course_discovery/apps/api/serializers.py @@ -1183,7 +1183,7 @@ class MinimalCourseSerializer(FlexFieldsSerializerMixin, TimestampModelSerialize url_slug = serializers.SerializerMethodField() course_type = serializers.SerializerMethodField() enterprise_subscription_inclusion = serializers.BooleanField(required=False) - course_run_statuses = serializers.ReadOnlyField() + course_run_statuses = serializers.SerializerMethodField() @classmethod def prefetch_queryset(cls, queryset=None, course_runs=None): @@ -1207,6 +1207,10 @@ def get_course_type(self, obj): def get_url_slug(self, obj): # pylint: disable=unused-argument return None # this has been removed from the MinimalCourseSerializer, set to None to not break APIs + def get_course_run_statuses(self, course): + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + return course.course_run_statuses(restriction_list=restriction_list) + @classmethod def prefetch_course_runs(cls, serializer_class, course_runs=None): """Returns a Prefetch object for the course runs in a course.""" @@ -1317,7 +1321,6 @@ class CourseSerializer(TaggitSerializer, MinimalCourseSerializer): url_slug = serializers.SlugField(read_only=True, source='active_url_slug') url_slug_history = serializers.SlugRelatedField(slug_field='url_slug', read_only=True, many=True) url_redirects = serializers.SlugRelatedField(slug_field='value', read_only=True, many=True) - course_run_statuses = serializers.ReadOnlyField() editors = CourseEditorSerializer(many=True, read_only=True) collaborators = SlugRelatedFieldWithReadSerializer(slug_field='uuid', required=False, many=True, queryset=Collaborator.objects.all(), @@ -1400,7 +1403,7 @@ class Meta(MinimalCourseSerializer.Meta): 'syllabus_raw', 'outcome', 'original_image', 'card_image_url', 'canonical_course_run_key', 'extra_description', 'additional_information', 'additional_metadata', 'faq', 'learner_testimonials', 'enrollment_count', 'recent_enrollment_count', 'topics', 'partner', 'key_for_reruns', 'url_slug', - 'url_slug_history', 'url_redirects', 'course_run_statuses', 'editors', 'collaborators', 'skill_names', + 'url_slug_history', 'url_redirects', 'editors', 'collaborators', 'skill_names', 'skills', 'organization_short_code_override', 'organization_logo_override_url', 'enterprise_subscription_inclusion', 'geolocation', 'location_restriction', 'in_year_value', 'product_source', 'data_modified_timestamp', 'excluded_from_search', 'excluded_from_seo', 'watchers', @@ -1603,16 +1606,22 @@ def prefetch_queryset(cls, partner, queryset=None, course_runs=None, programs=No ) def get_advertised_course_run_uuid(self, course): - if course.advertised_course_run: - return course.advertised_course_run.uuid + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + advertised_run = course.advertised_course_run(restriction_list) + if advertised_run: + return advertised_run.uuid return None def get_course_run_keys(self, course): - return [course_run.key for course_run in course.course_runs.all()] + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + + return [course_run.key for course_run in CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list)] def get_course_runs(self, course): + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + return CourseRunSerializer( - course.course_runs, + CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list), many=True, context={ 'request': self.context.get('request'), @@ -1667,7 +1676,9 @@ def get_marketing_url(self, obj): ) def get_course_run_keys(self, course): - return [course_run.key for course_run in course.course_runs.all()] + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + + return [course_run.key for course_run in CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list)] class Meta(CourseSerializer.Meta): model = Course @@ -1710,8 +1721,10 @@ def prefetch_queryset(cls, partner, queryset=None, course_runs=None): ) def get_course_runs(self, course): + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + return CourseRunSerializer( - course.course_runs, + CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list), many=True, context=self.context ).data @@ -1735,7 +1748,10 @@ class MinimalProgramCourseSerializer(MinimalCourseSerializer): course_runs = serializers.SerializerMethodField() def get_course_runs(self, course): + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + course_runs = self.context['course_runs'] + course_runs = CourseRun.get_exposed_runs(course_runs, restriction_list) course_runs = [course_run for course_run in course_runs if course_run.course == course] if self.context.get('published_course_runs_only'): @@ -1979,7 +1995,7 @@ class MinimalProgramSerializer(TaggitSerializer, FlexFieldsSerializerMixin, Base degree = DegreeSerializer() curricula = CurriculumSerializer(many=True) card_image_url = serializers.SerializerMethodField() - course_run_statuses = serializers.ReadOnlyField() + course_run_statuses = serializers.SerializerMethodField() organization_short_code_override = serializers.CharField(required=False, allow_blank=True) organization_logo_override_url = serializers.SerializerMethodField() primary_subject_override = SubjectSerializer() @@ -1995,6 +2011,10 @@ def get_organization_logo_override_url(self, obj): return logo_image_override.url return None + def get_course_run_statuses(self, program): + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + return program.course_run_statuses(restriction_list=restriction_list) + @classmethod def prefetch_queryset(cls, partner, queryset=None): # Explicitly check if the queryset is None before selecting related @@ -2299,7 +2319,11 @@ class PathwaySerializer(BaseModelSerializer): description = serializers.CharField() destination_url = serializers.CharField() pathway_type = serializers.CharField() - course_run_statuses = serializers.ReadOnlyField() + course_run_statuses = serializers.SerializerMethodField() + + def get_course_run_statuses(self, pathway): + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + return pathway.course_run_statuses(restriction_list=restriction_list) @classmethod def prefetch_queryset(cls, partner): diff --git a/course_discovery/apps/api/v1/views/course_runs.py b/course_discovery/apps/api/v1/views/course_runs.py index cb8fe47bac..d77a53ceb4 100644 --- a/course_discovery/apps/api/v1/views/course_runs.py +++ b/course_discovery/apps/api/v1/views/course_runs.py @@ -1,6 +1,7 @@ import logging from django.db import models, transaction +from django.db.models import Q from django.db.models.functions import Lower from django.http.response import Http404 from django.utils.translation import gettext as _ @@ -81,6 +82,7 @@ def get_queryset(self): multiple: false """ q = self.request.query_params.get('q') + restriction_list = self.request.query_params.get('restriction_list', '').split(',') partner = self.request.site.partner edit_mode = get_query_param(self.request, 'editable') or self.request.method not in SAFE_METHODS @@ -95,7 +97,9 @@ def get_queryset(self): queryset = CourseEditor.editable_course_runs(self.request.user, queryset) else: queryset = self.queryset - + queryset = queryset.filter( + Q(restricted_run__isnull=True) | Q(restricted_run__restriction_type__in=restriction_list) + ) if q: queryset = SearchQuerySetWrapper( CourseRun.search(q).filter('term', partner=partner.short_code), diff --git a/course_discovery/apps/course_metadata/algolia_models.py b/course_discovery/apps/course_metadata/algolia_models.py index 51bb48ad8b..5506197b44 100644 --- a/course_discovery/apps/course_metadata/algolia_models.py +++ b/course_discovery/apps/course_metadata/algolia_models.py @@ -10,9 +10,9 @@ from taxonomy.choices import ProductTypes from taxonomy.utils import get_whitelisted_serialized_skills -from course_discovery.apps.course_metadata.choices import CourseRunStatus, ExternalProductStatus, ProgramStatus +from course_discovery.apps.course_metadata.choices import CourseRunStatus, ExternalProductStatus, ProgramStatus, CourseRunRestrictionType from course_discovery.apps.course_metadata.models import ( - AbstractLocationRestrictionModel, Course, CourseType, ProductValue, Program, ProgramType + AbstractLocationRestrictionModel, Course, CourseRun, CourseType, ProductValue, Program, ProgramType ) from course_discovery.apps.course_metadata.utils import transform_skills_data @@ -26,8 +26,8 @@ # Utility methods used by both courses and programs def get_active_language_tag(course): - if course.advertised_course_run and course.advertised_course_run.language: - return course.advertised_course_run.language + if course.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]) and course.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).language: + return course.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).language return None @@ -88,6 +88,7 @@ def _wrap(self, *args, **kwargs): def get_course_availability(course): all_runs = course.course_runs.filter(status=CourseRunStatus.Published) + all_runs = CourseRun.get_exposed_runs(all_runs, restriction_list=[CourseRunRestrictionType.CustomB2C.value]) availability = set() for course_run in all_runs: @@ -211,6 +212,7 @@ class AlgoliaProxyCourse(Course, AlgoliaBasicModelFieldsMixin): class Meta: proxy = True + @property def product_type(self): if self.type.slug == CourseType.EXECUTIVE_EDUCATION_2U: @@ -253,15 +255,15 @@ def active_languages(self): @property def active_run_key(self): - return getattr(self.advertised_course_run, 'key', None) + return getattr(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]), 'key', None) @property def active_run_start(self): - return getattr(self.advertised_course_run, 'start', None) + return getattr(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]), 'start', None) @property def active_run_type(self): - return getattr(self.advertised_course_run, 'type', None) + return getattr(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]), 'type', None) @property def availability_level(self): @@ -329,15 +331,15 @@ def product_card_image_url(self): @property def product_weeks_to_complete(self): - return getattr(self.advertised_course_run, 'weeks_to_complete', None) + return getattr(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]), 'weeks_to_complete', None) @property def product_min_effort(self): - return getattr(self.advertised_course_run, 'min_effort', None) + return getattr(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]), 'min_effort', None) @property def product_max_effort(self): - return getattr(self.advertised_course_run, 'max_effort', None) + return getattr(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]), 'max_effort', None) @property def owners(self): @@ -409,8 +411,8 @@ def should_index(self): self.active_url_slug and self.partner.name == 'edX' and self.availability_level and - bool(self.advertised_course_run) and - not self.advertised_course_run.hidden) + bool(self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value])) and + not self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).hidden) @property def should_index_spanish(self): @@ -430,15 +432,15 @@ def skills(self): @property def availability_rank(self): today_midnight = datetime.datetime.now(pytz.UTC).replace(hour=0, minute=0, second=0, microsecond=0) - if self.advertised_course_run: - if self.advertised_course_run.is_current_and_still_upgradeable(): + if self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]): + if self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).is_current_and_still_upgradeable(): return 1 - paid_seat_enrollment_end = self.advertised_course_run.get_paid_seat_enrollment_end() + paid_seat_enrollment_end = self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).get_paid_seat_enrollment_end() if paid_seat_enrollment_end and paid_seat_enrollment_end > today_midnight: return 2 - if datetime.datetime.now(pytz.UTC) >= self.advertised_course_run.start: + if datetime.datetime.now(pytz.UTC) >= self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).start: return 3 - return self.advertised_course_run.start.timestamp() + return self.advertised_course_run(restriction_list=[CourseRunRestrictionType.CustomB2C.value]).start.timestamp() return None # Algolia will deprioritize entries where a ranked field is empty @property diff --git a/course_discovery/apps/course_metadata/models.py b/course_discovery/apps/course_metadata/models.py index 314f557074..e4fb329dcb 100644 --- a/course_discovery/apps/course_metadata/models.py +++ b/course_discovery/apps/course_metadata/models.py @@ -1713,8 +1713,7 @@ def first_enrollable_paid_seat_price(self): return None - @property - def course_run_statuses(self): + def course_run_statuses(self, restriction_list=None): """ Returns all unique course run status values inside this course. @@ -1723,7 +1722,7 @@ def course_run_statuses(self): invalidates the prefetch on API level. """ statuses = set() - return sorted(list(get_course_run_statuses(statuses, self.course_runs.all()))) + return sorted(list(get_course_run_statuses(statuses, self.course_runs.all(), restriction_list))) @property def is_active(self): @@ -1873,8 +1872,7 @@ def set_active_url_slug(self, slug): other_slug.is_active_on_draft = False other_slug.save() - @cached_property - def advertised_course_run(self): + def advertised_course_run(self, restriction_list=None): now = datetime.datetime.now(pytz.UTC) min_date = datetime.datetime.min.replace(tzinfo=pytz.UTC) max_date = datetime.datetime.max.replace(tzinfo=pytz.UTC) @@ -1884,7 +1882,7 @@ def advertised_course_run(self): tier_three = [] marketable_course_runs = [course_run for course_run in self.course_runs.all() if course_run.is_marketable] - + marketable_course_runs = CourseRun.get_exposed_runs(marketable_course_runs, restriction_list) for course_run in marketable_course_runs: course_run_started = (not course_run.start) or (course_run.start and course_run.start < now) if course_run.is_current_and_still_upgradeable(): @@ -2587,6 +2585,17 @@ def search(cls, query): dsl_query = ESDSLQ('query_string', query=query, analyze_wildcard=True) return queryset.query(dsl_query) + @classmethod + def get_exposed_runs(cls, runs, restriction_list=None): + if restriction_list is None: + return runs + + return filter( + lambda cr: not hasattr(cr, "restricted_run") or + cr.restricted_run.restriction_type in restriction_list, + runs + ) + def __str__(self): return f'{self.key}: {self.title}' @@ -3439,8 +3448,7 @@ def canonical_course_runs(self): if canonical_course_run and canonical_course_run.id not in excluded_course_run_ids: yield canonical_course_run - @property - def course_run_statuses(self): + def course_run_statuses(self, restriction_list=None): """ Returns all unique course run status values inside the courses in this program. @@ -3450,7 +3458,7 @@ def course_run_statuses(self): """ statuses = set() for course in self.courses.all(): - get_course_run_statuses(statuses, course.course_runs.all()) + get_course_run_statuses(statuses, course.course_runs.all(), restriction_list) return sorted(list(statuses)) @property @@ -3654,9 +3662,9 @@ def start(self): @property def staff(self): - advertised_course_runs = [course.advertised_course_run for + advertised_course_runs = [course.advertised_course_run() for course in self.courses.all() if - course.advertised_course_run] + course.advertised_course_run()] staff = [advertised_course_run.staff.all() for advertised_course_run in advertised_course_runs] staff = itertools.chain.from_iterable(staff) return set(staff) @@ -4189,8 +4197,7 @@ def validate_partner_programs(cls, partner, programs): msg = _('These programs are for a different partner than the pathway itself: {}') raise ValidationError(msg.format(', '.join(bad_programs))) - @property - def course_run_statuses(self): + def course_run_statuses(self, restriction_list=None): """ Returns all unique course run status values inside the programs in this pathway. @@ -4201,7 +4208,7 @@ def course_run_statuses(self): statuses = set() for program in self.programs.all(): for course in program.courses.all(): - get_course_run_statuses(statuses, course.course_runs.all()) + get_course_run_statuses(statuses, course.course_runs.all(), restriction_list) return sorted(list(statuses)) diff --git a/course_discovery/apps/course_metadata/utils.py b/course_discovery/apps/course_metadata/utils.py index 4c3549f961..51a48b7a6c 100644 --- a/course_discovery/apps/course_metadata/utils.py +++ b/course_discovery/apps/course_metadata/utils.py @@ -473,10 +473,11 @@ def serialize_entitlement_for_ecommerce_api(entitlement): if IS_COURSE_RUN_VARIANT_ID_ECOMMERCE_CONSUMABLE.is_enabled(): course = entitlement.course - if course.advertised_course_run and course.advertised_course_run.variant_id: + advertised_run = course.advertised_course_run() + if advertised_run and advertised_run.variant_id: attribute_values_list.append({ 'name': 'variant_id', - 'value': str(course.advertised_course_run.variant_id), + 'value': str(advertised_run.variant_id), }) else: additional_metadata = entitlement.course.additional_metadata @@ -1252,7 +1253,7 @@ def get_product_skill_names(product_identifier, product_type): return list({product_skill['name'] for product_skill in product_skills}) -def get_course_run_statuses(statuses, course_runs): +def get_course_run_statuses(statuses, course_runs, restriction_list=None): """ Util method to get course run statuses based on the course_runs """ @@ -1260,6 +1261,12 @@ def get_course_run_statuses(statuses, course_runs): for course_run in course_runs: if course_run.hidden: continue + if ( + restriction_list is not None and + hasattr(course_run, "restricted_run") and + course_run.restricted_run.restriction_type not in restriction_list + ): + continue if course_run.end and course_run.end < now and course_run.status == CourseRunStatus.Unpublished: statuses.add('archived') else: diff --git a/course_discovery/apps/learner_pathway/utils.py b/course_discovery/apps/learner_pathway/utils.py index 0f00b1cca8..e99fb5e13e 100644 --- a/course_discovery/apps/learner_pathway/utils.py +++ b/course_discovery/apps/learner_pathway/utils.py @@ -7,7 +7,7 @@ def get_advertised_course_run_estimated_hours(course): active_course_runs = course.active_course_runs if course.advertised_course_run: - advertised_course_run_uuid = course.advertised_course_run.uuid + advertised_course_run_uuid = course.advertised_course_run().uuid for course_run in active_course_runs: if course_run.uuid == advertised_course_run_uuid: return get_course_run_estimated_hours(course_run) From 31eb260c3dc50c6978b64d7c014daf372476e8f2 Mon Sep 17 00:00:00 2001 From: Ali Nawaz Date: Mon, 6 May 2024 14:09:28 +0500 Subject: [PATCH 2/6] fix queryset for patch requests --- course_discovery/apps/api/v1/views/course_runs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/course_discovery/apps/api/v1/views/course_runs.py b/course_discovery/apps/api/v1/views/course_runs.py index d77a53ceb4..cd4854c094 100644 --- a/course_discovery/apps/api/v1/views/course_runs.py +++ b/course_discovery/apps/api/v1/views/course_runs.py @@ -97,9 +97,11 @@ def get_queryset(self): queryset = CourseEditor.editable_course_runs(self.request.user, queryset) else: queryset = self.queryset - queryset = queryset.filter( - Q(restricted_run__isnull=True) | Q(restricted_run__restriction_type__in=restriction_list) - ) + + if self.request.method == 'GET': + queryset = queryset.filter( + Q(restricted_run__isnull=True) | Q(restricted_run__restriction_type__in=restriction_list) + ) if q: queryset = SearchQuerySetWrapper( CourseRun.search(q).filter('term', partner=partner.short_code), From 0e739f69729cccf5428f09cff28cdf846e87c787 Mon Sep 17 00:00:00 2001 From: Ali Nawaz Date: Mon, 6 May 2024 23:46:26 +0500 Subject: [PATCH 3/6] add support for ES Apis --- .../apps/api/v1/views/course_runs.py | 5 +++-- course_discovery/apps/api/v1/views/search.py | 16 +++++++++++++++- .../search_indexes/documents/course_run.py | 5 +++++ .../search_indexes/serializers/course.py | 6 ++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/course_discovery/apps/api/v1/views/course_runs.py b/course_discovery/apps/api/v1/views/course_runs.py index cd4854c094..a3a5e82485 100644 --- a/course_discovery/apps/api/v1/views/course_runs.py +++ b/course_discovery/apps/api/v1/views/course_runs.py @@ -23,7 +23,7 @@ from course_discovery.apps.api.utils import StudioAPI, get_query_param, reviewable_data_has_changed from course_discovery.apps.api.v1.exceptions import EditableAndQUnsupported from course_discovery.apps.core.utils import SearchQuerySetWrapper -from course_discovery.apps.course_metadata.choices import CourseRunStatus +from course_discovery.apps.course_metadata.choices import CourseRunStatus, CourseRunRestrictionType from course_discovery.apps.course_metadata.constants import COURSE_RUN_ID_REGEX from course_discovery.apps.course_metadata.exceptions import EcommerceSiteAPIClientException from course_discovery.apps.course_metadata.models import Course, CourseEditor, CourseRun @@ -83,6 +83,7 @@ def get_queryset(self): """ q = self.request.query_params.get('q') restriction_list = self.request.query_params.get('restriction_list', '').split(',') + forbidden = list(set(CourseRunRestrictionType.values) - set(restriction_list)) partner = self.request.site.partner edit_mode = get_query_param(self.request, 'editable') or self.request.method not in SAFE_METHODS @@ -104,7 +105,7 @@ def get_queryset(self): ) if q: queryset = SearchQuerySetWrapper( - CourseRun.search(q).filter('term', partner=partner.short_code), + CourseRun.search(q).filter('term', partner=partner.short_code).exclude('terms', restriction_type=forbidden), model=queryset.model ) else: diff --git a/course_discovery/apps/api/v1/views/search.py b/course_discovery/apps/api/v1/views/search.py index f2fbd70fa2..b274f599f0 100644 --- a/course_discovery/apps/api/v1/views/search.py +++ b/course_discovery/apps/api/v1/views/search.py @@ -17,7 +17,7 @@ from course_discovery.apps.api import serializers from course_discovery.apps.api.utils import update_query_params_with_body_data -from course_discovery.apps.course_metadata.choices import ProgramStatus +from course_discovery.apps.course_metadata.choices import ProgramStatus, CourseRunRestrictionType from course_discovery.apps.course_metadata.models import Person from course_discovery.apps.course_metadata.search_indexes import documents as search_documents from course_discovery.apps.course_metadata.search_indexes import serializers as search_indexes_serializers @@ -132,6 +132,16 @@ class CourseRunSearchViewSet(FacetQueryFieldsMixin, BaseElasticsearchDocumentVie }, } + def get_queryset(self): + """Get queryset.""" + queryset = super().get_queryset() + + query_params = self.request.query_params + restriction_list = query_params.get('restriction_list', '').split(',') + forbidden = list(set(CourseRunRestrictionType.values) - set(restriction_list)) + queryset = queryset.exclude('terms', restriction_type=forbidden) + + return queryset class ProgramSearchViewSet(BaseElasticsearchDocumentViewSet): """ @@ -300,6 +310,10 @@ def get_queryset(self): if not query_params.get(LEARNER_PATHWAY_FEATURE_PARAM, 'false').lower() == 'true': queryset = queryset.exclude('term', content_type=LearnerPathway.__name__.lower()) + restriction_list = query_params.get('restriction_list', '').split(',') + forbidden = list(set(CourseRunRestrictionType.values) - set(restriction_list)) + queryset = queryset.exclude('terms', restriction_type=forbidden) + return queryset @update_query_params_with_body_data diff --git a/course_discovery/apps/course_metadata/search_indexes/documents/course_run.py b/course_discovery/apps/course_metadata/search_indexes/documents/course_run.py index e85df5f990..3758d77e0e 100644 --- a/course_discovery/apps/course_metadata/search_indexes/documents/course_run.py +++ b/course_discovery/apps/course_metadata/search_indexes/documents/course_run.py @@ -64,6 +64,7 @@ class CourseRunDocument(BaseCourseDocument): 'description': fields.TextField(), }) status = fields.KeywordField() + restriction_type = fields.KeywordField() start = fields.DateField() slug = fields.TextField() staff_uuids = fields.KeywordField(multi=True) @@ -125,6 +126,10 @@ def prepare_seat_types(self, obj): def prepare_skill_names(self, obj): return get_product_skill_names(obj.course.key, ProductTypes.Course) + def prepare_restriction_type(self, obj): + if hasattr(obj, "restricted_run"): + return obj.restricted_run.restriction_type + return None def prepare_skills(self, obj): return get_whitelisted_serialized_skills(obj.course.key, product_type=ProductTypes.Course) diff --git a/course_discovery/apps/course_metadata/search_indexes/serializers/course.py b/course_discovery/apps/course_metadata/search_indexes/serializers/course.py index 317441ef01..258bcb5bda 100644 --- a/course_discovery/apps/course_metadata/search_indexes/serializers/course.py +++ b/course_discovery/apps/course_metadata/search_indexes/serializers/course.py @@ -12,6 +12,7 @@ from course_discovery.apps.api import serializers as cd_serializers from course_discovery.apps.api.serializers import ContentTypeSerializer, CourseWithProgramsSerializer from course_discovery.apps.course_metadata.utils import get_course_run_estimated_hours, get_product_skill_names +from course_discovery.apps.course_metadata.choices import CourseRunRestrictionType from course_discovery.apps.edx_elasticsearch_dsl_extensions.serializers import BaseDjangoESDSLFacetSerializer from ..constants import BASE_SEARCH_INDEX_FIELDS, COMMON_IGNORED_FIELDS @@ -94,6 +95,9 @@ def get_course_runs(self, result): exclude_expired = request.POST.get('exclude_expired_course_run') or exclude_expired detail_fields = request.POST.get('detail_fields') or detail_fields + restriction_list = request.query_params.get('restriction_list', '').split(',') + forbidden = set(CourseRunRestrictionType.values) - set(restriction_list) + def should_include_course_run(course_run, params, exclude_expired): matches_parameter = False for key, values in params.items(): @@ -103,6 +107,8 @@ def should_include_course_run(course_run, params, exclude_expired): matches_parameter = True if matches_parameter: break + if hasattr(course_run, 'restricted_run') and course_run.restricted_run.restriction_type in forbidden: + return False return (not exclude_expired or matches_parameter or course_run.end is None or course_run.end > now) return [ From a37c452cc398ff949ceadeea62cb9fa01bd047ee Mon Sep 17 00:00:00 2001 From: Ali Nawaz Date: Tue, 7 May 2024 16:56:47 +0500 Subject: [PATCH 4/6] fix learner pathways in search/all --- course_discovery/apps/learner_pathway/api/serializers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/course_discovery/apps/learner_pathway/api/serializers.py b/course_discovery/apps/learner_pathway/api/serializers.py index b6c0169c48..fecfb48fdc 100644 --- a/course_discovery/apps/learner_pathway/api/serializers.py +++ b/course_discovery/apps/learner_pathway/api/serializers.py @@ -3,7 +3,7 @@ """ from rest_framework import serializers -from course_discovery.apps.course_metadata.choices import CourseRunStatus +from course_discovery.apps.course_metadata.choices import CourseRunStatus, CourseRunRestrictionType from course_discovery.apps.learner_pathway import models @@ -19,7 +19,10 @@ class Meta: fields = ('key', 'course_runs') def get_course_runs(self, obj): - return list(obj.course.course_runs.filter(status=CourseRunStatus.Published).values('key')) + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + forbidden = set(CourseRunRestrictionType.values) - set(restriction_list) + + return list(obj.course.course_runs.filter(status=CourseRunStatus.Published).exclude(restricted_run__restriction_type__in=forbidden).values('key')) class LearnerPathwayCourseSerializer(LearnerPathwayCourseMinimalSerializer): From 2bcd2cf970b634068e3e74212690ce8cb4aa477d Mon Sep 17 00:00:00 2001 From: Ali Nawaz Date: Wed, 8 May 2024 01:51:17 +0500 Subject: [PATCH 5/6] filtering fixes for learner pathway and search api --- .../apps/course_metadata/models.py | 12 ++++++------ .../search_indexes/serializers/course.py | 19 +++++++++++-------- .../apps/learner_pathway/api/serializers.py | 5 ++++- .../apps/learner_pathway/models.py | 4 ++-- .../apps/learner_pathway/utils.py | 2 +- 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/course_discovery/apps/course_metadata/models.py b/course_discovery/apps/course_metadata/models.py index e4fb329dcb..0b11374df5 100644 --- a/course_discovery/apps/course_metadata/models.py +++ b/course_discovery/apps/course_metadata/models.py @@ -1681,20 +1681,20 @@ def course_ends(self): else: return _('Past') - def languages(self, exclude_inactive_runs=False): + def languages(self, exclude_inactive_runs=False, allowed_restriction_types=[]): """ Returns a set of all languages used in this course. Arguments: exclude_inactive_runs (bool): whether to exclude inactive runs """ + runs = self.course_runs.all() if exclude_inactive_runs: - return list({ - serialize_language(course_run.language) for course_run in self.course_runs.all() - if course_run.is_active and course_run.language is not None - }) + runs = filter(lambda r: r.is_active, runs) + runs = filter(lambda r: not hasattr(r, "restricted_run") or r.restricted_run.restriction_type in allowed_restriction_types, runs) + return list({ - serialize_language(course_run.language) for course_run in self.course_runs.all() + serialize_language(course_run.language) for course_run in runs if course_run.language is not None }) diff --git a/course_discovery/apps/course_metadata/search_indexes/serializers/course.py b/course_discovery/apps/course_metadata/search_indexes/serializers/course.py index 258bcb5bda..f4c5d81714 100644 --- a/course_discovery/apps/course_metadata/search_indexes/serializers/course.py +++ b/course_discovery/apps/course_metadata/search_indexes/serializers/course.py @@ -127,7 +127,9 @@ def get_languages(self, result): if request.method == 'POST': exclude_non_active_languages = request.POST.get('exclude_expired_course_run', exclude_non_active_languages) - return result.object.languages(exclude_non_active_languages) + restriction_list = request.query_params.get('restriction_list', '').split(',') + + return result.object.languages(exclude_non_active_languages, allowed_restriction_types=restriction_list) def get_seat_types(self, result): now = datetime.datetime.now(pytz.UTC) @@ -135,14 +137,15 @@ def get_seat_types(self, result): exclude_expired = request.GET.get('exclude_expired_course_run') if request.method == 'POST': exclude_expired = request.POST.get('exclude_expired_course_run', exclude_expired) + + runs = result.object.course_runs.all() if exclude_expired: - # if course_run is active then add course_run.seat_types to seat_types - seat_types = [ - seat.slug for course_run in result.object.course_runs.all() - if course_run.end is None or course_run.end > now for seat in course_run.seat_types - ] - else: - seat_types = [seat.slug for course_run in result.object.course_runs.all() for seat in course_run.seat_types] + runs = filter(lambda r: r.end is None or r.end > now, runs) + + restriction_list = request.query_params.get('restriction_list', '').split(',') + runs = filter(lambda r: not hasattr(r, "restricted_run") or r.restricted_run.restriction_type in restriction_list, runs) + + seat_types = [seat.slug for course_run in runs for seat in course_run.seat_types] return list(set(seat_types)) def get_skill_names(self, result): diff --git a/course_discovery/apps/learner_pathway/api/serializers.py b/course_discovery/apps/learner_pathway/api/serializers.py index fecfb48fdc..fbead029de 100644 --- a/course_discovery/apps/learner_pathway/api/serializers.py +++ b/course_discovery/apps/learner_pathway/api/serializers.py @@ -84,7 +84,10 @@ def get_card_image_url(self, step): return program.card_image_url def get_courses(self, obj): - return obj.get_linked_courses_and_course_runs() + restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') + forbidden = set(CourseRunRestrictionType.values) - set(restriction_list) + + return obj.get_linked_courses_and_course_runs(forbidden_restriction_types=forbidden) class LearnerPathwayBlockSerializer(serializers.ModelSerializer): diff --git a/course_discovery/apps/learner_pathway/models.py b/course_discovery/apps/learner_pathway/models.py index 1278dbe7e1..bf96e43aa6 100644 --- a/course_discovery/apps/learner_pathway/models.py +++ b/course_discovery/apps/learner_pathway/models.py @@ -299,13 +299,13 @@ def get_skills(self) -> [str]: return program_skills - def get_linked_courses_and_course_runs(self) -> [dict]: + def get_linked_courses_and_course_runs(self, forbidden_restriction_types=[]) -> [dict]: """ Returns list of dict where each dict contains a course key linked with program and all its course runs """ courses = [] for course in self.program.courses.all(): - course_runs = list(course.course_runs.filter(status=CourseRunStatus.Published).values('key')) + course_runs = list(course.course_runs.filter(status=CourseRunStatus.Published).exclude(restricted_run__restriction_type__in=forbidden_restriction_types).values('key')) courses.append({"key": course.key, "course_runs": course_runs}) return courses diff --git a/course_discovery/apps/learner_pathway/utils.py b/course_discovery/apps/learner_pathway/utils.py index e99fb5e13e..e75bd56313 100644 --- a/course_discovery/apps/learner_pathway/utils.py +++ b/course_discovery/apps/learner_pathway/utils.py @@ -6,7 +6,7 @@ def get_advertised_course_run_estimated_hours(course): active_course_runs = course.active_course_runs - if course.advertised_course_run: + if course.advertised_course_run(): advertised_course_run_uuid = course.advertised_course_run().uuid for course_run in active_course_runs: if course_run.uuid == advertised_course_run_uuid: From 43fecbb2b1e0bd7b99a9c7e2023f4a87b2586b9f Mon Sep 17 00:00:00 2001 From: Ali Nawaz Date: Wed, 8 May 2024 14:03:46 +0500 Subject: [PATCH 6/6] add tests --- .../api/v1/tests/test_views/test_catalogs.py | 22 +++- .../v1/tests/test_views/test_course_runs.py | 38 ++++++ .../api/v1/tests/test_views/test_courses.py | 37 +++++- .../api/v1/tests/test_views/test_programs.py | 25 +++- .../api/v1/tests/test_views/test_search.py | 108 +++++++++++++++++- .../api/v1/tests/test_views.py | 18 ++- 6 files changed, 230 insertions(+), 18 deletions(-) diff --git a/course_discovery/apps/api/v1/tests/test_views/test_catalogs.py b/course_discovery/apps/api/v1/tests/test_views/test_catalogs.py index c294e1384c..ddb4274e01 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_catalogs.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_catalogs.py @@ -11,6 +11,7 @@ from django.contrib.auth import get_user_model from django.core.management import call_command from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin @@ -21,7 +22,7 @@ from course_discovery.apps.course_metadata.choices import CourseRunStatus from course_discovery.apps.course_metadata.models import Course, CourseType from course_discovery.apps.course_metadata.tests.factories import ( - CourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory + CourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory, RestrictedCourseRunFactory ) from course_discovery.conftest import get_course_run_states @@ -292,9 +293,10 @@ def test_courses_with_subjects_and_negative_query(self): assert response.data['results'] == self.serialize_catalog_course(desired_courses, many=True) @ddt.data( - *STATES() + [(st, res, lst) for st in STATES() for res in [True, False] for lst in ['', 'custom-b2c']] ) - def test_courses(self, state): + @ddt.unpack + def test_courses(self, state, restriction, restriction_list): """ Verify the endpoint returns the list of available courses contained in the catalog, and that courses appearing in the response always have at @@ -305,6 +307,8 @@ def test_courses(self, state): Course.objects.all().delete() course_run = CourseRunFactory(course__title='ABC Test Course') + if restriction: + RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c') for function in state: function(course_run) @@ -325,10 +329,16 @@ def test_courses(self, state): # Emulate prefetching behavior. filtered_course_run.delete() - assert response.data['results'] == self.serialize_catalog_course([course], many=True) + mock_request = APIRequestFactory() + mock_request.query_params = {'restriction_list', restriction_list} + assert response.data['results'] == self.serialize_catalog_course([course], many=True, extra_context={'request': mock_request}) + + if not restriction: + # Any course appearing in the response must have at least one serialized run. + assert response.data['results'][0]['course_runs'] + else: + assert not response.data['results'][0]['course_runs'] - # Any course appearing in the response must have at least one serialized run. - assert response.data['results'][0]['course_runs'] else: response = self.client.get(url) diff --git a/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py b/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py index f997d5b9c7..c088e5527a 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_course_runs.py @@ -53,6 +53,8 @@ def setUp(self): self.draft_course = CourseFactory(partner=self.partner, draft=True, product_source=self.product_source) self.draft_course_run = CourseRunFactory(course=self.draft_course, draft=True) self.draft_course_run.course.authoring_organizations.add(OrganizationFactory(key='course-id')) + self.restricted_run = CourseRunFactory(course__partner=self.partner) + RestrictedCourseRunFactory(course_run=self.restricted_run, restriction_type='custom-b2c') self.course_run_type = CourseRunTypeFactory(tracks=[TrackFactory()]) self.verified_type = CourseRunType.objects.get(slug=CourseRunType.VERIFIED_AUDIT) self.refresh_index() @@ -1197,6 +1199,21 @@ def test_list(self): response.data['results'], self.serialize_course_run(CourseRun.objects.all().order_by(Lower('key')), many=True) ) + restrieved_keys = [r.key for r in response.data['results']] + assert self.restricted_run.key not in restrieved_keys + + + def test_list_include_restricted(self): + """ Verify the endpoint returns a list of all course runs. """ + url = reverse('api:v1:course_run-list') + '?restriction_list=custom-b2c' + + with self.assertNumQueries(14, threshold=3): + response = self.client.get(url) + + assert response.status_code == 200 + restrieved_keys = [r.key for r in response.data['results']] + assert self.restricted_run.key in restrieved_keys + def test_list_sorted_by_course_start_date(self): """ Verify the endpoint returns a list of all course runs sorted by start date. """ @@ -1215,6 +1232,8 @@ def test_list_query(self): """ Verify the endpoint returns a filtered list of courses """ course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner) CourseRunFactory(title='non-matching name') + restricted_run = CourseRunFactory(title='Some random title', course__partner=self.partner) + RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c') query = 'title:Some random title' url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query) @@ -1226,6 +1245,25 @@ def test_list_query(self): key=lambda course_run: course_run['key']) self.assertListEqual(actual_sorted, expected_sorted) + + def test_list_query_include_restricted(self): + """ Verify the endpoint returns a filtered list of courses """ + course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner) + CourseRunFactory(title='non-matching name') + restricted_run = CourseRunFactory(title='Some random title', course__partner=self.partner) + RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c') + query = 'title:Some random title' + url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query) + url += '?restriction_list=custom-b2c,custom-b2b-enterprise' + + with self.assertNumQueries(25, threshold=3): + response = self.client.get(url) + + actual_sorted = sorted(response.data['results'], key=lambda course_run: course_run['key']) + expected_sorted = sorted(self.serialize_course_run([*course_runs, restricted_run], many=True), + key=lambda course_run: course_run['key']) + self.assertListEqual(actual_sorted, expected_sorted) + def assert_list_results(self, url, expected, extra_context=None): expected = sorted(expected, key=lambda course_run: course_run.key.lower()) response = self.client.get(url) diff --git a/course_discovery/apps/api/v1/tests/test_views/test_courses.py b/course_discovery/apps/api/v1/tests/test_views/test_courses.py index ae5a2018f6..bdba46ccc1 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_courses.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_courses.py @@ -33,7 +33,7 @@ disconnect_course_data_modified_timestamp_signal_handlers, product_meta_taggable_changed ) from course_discovery.apps.course_metadata.tests.factories import ( - CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseLocationRestrictionFactory, CourseRunFactory, + CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseLocationRestrictionFactory, CourseRunFactory, RestrictedCourseRunFactory, CourseTypeFactory, GeoLocationFactory, LevelTypeFactory, OrganizationFactory, ProductValueFactory, ProgramFactory, SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory ) @@ -278,6 +278,41 @@ def test_course_runs_are_ordered(self): self.assertListEqual(response.data['course_run_keys'], expected_keys) self.assertListEqual([run['key'] for run in response.data['course_runs']], expected_keys) + + def test_course_runs_are_restricted_by_default(self): + run_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Published) + run_not_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Unpublished) + RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c') + SeatFactory(course_run=run_restricted) + + url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) + with self.assertNumQueries(26, threshold=3): + response = self.client.get(url) + assert response.status_code == 200 + + self.assertListEqual(response.data['course_run_keys'], [run_not_restricted.key]) + self.assertListEqual(response.data['course_run_statuses'], [run_not_restricted.status]) + self.assertEqual(len(response.data['course_runs']), 1) + self.assertEqual(response.data['advertised_course_run_uuid'], None) + + def test_course_runs_restriction_param(self): + run_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Published) + run_not_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Unpublished) + RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c') + SeatFactory(course_run=run_restricted) + + url = reverse('api:v1:course-detail', kwargs={'key': self.course.key}) + url += '?restriction_list=custom-b2c' + with self.assertNumQueries(26, threshold=3): + response = self.client.get(url) + assert response.status_code == 200 + + self.assertListEqual(response.data['course_run_keys'], [run_not_restricted.key, run_restricted.key]) + self.assertListEqual(response.data['course_run_statuses'], [run_not_restricted.status, run_restricted.statuss]) + self.assertEqual(len(response.data['course_runs']), 2) + self.assertEqual(response.data['advertised_course_run_uuid'], run_restricted.uuid) + + def test_list(self): """ Verify the endpoint returns a list of all courses. """ url = reverse('api:v1:course-list') diff --git a/course_discovery/apps/api/v1/tests/test_views/test_programs.py b/course_discovery/apps/api/v1/tests/test_views/test_programs.py index a419013328..ec6c06cd22 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_programs.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_programs.py @@ -13,13 +13,13 @@ from course_discovery.apps.api.v1.views.programs import ProgramViewSet from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory from course_discovery.apps.core.tests.helpers import make_image_file -from course_discovery.apps.course_metadata.choices import ProgramStatus +from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus from course_discovery.apps.course_metadata.models import CourseType, Program, ProgramType from course_discovery.apps.course_metadata.tests.factories import ( CorporateEndorsementFactory, CourseFactory, CourseRunFactory, CurriculumCourseMembershipFactory, CurriculumFactory, CurriculumProgramMembershipFactory, DegreeAdditionalMetadataFactory, DegreeFactory, EndorsementFactory, ExpectedLearningItemFactory, JobOutlookItemFactory, OrganizationFactory, PersonFactory, ProgramFactory, - ProgramTypeFactory, VideoFactory + ProgramTypeFactory, RestrictedCourseRunFactory, VideoFactory ) @@ -48,13 +48,16 @@ def setup(self, client, django_assert_num_queries, partner): self.partner = partner self.request = request - def create_program(self, courses=None, program_type=None): + def create_program(self, courses=None, program_type=None, include_restricted_run=False): organizations = [OrganizationFactory(partner=self.partner)] person = PersonFactory() if courses is None: courses = [CourseFactory(partner=self.partner)] - CourseRunFactory(course=courses[0], staff=[person]) + cr = CourseRunFactory(course=courses[0], staff=[person]) + + if include_restricted_run: + RestrictedCourseRunFactory(course_run=cr, restriction_type='custom-b2c') if program_type is None: program_type = ProgramTypeFactory() @@ -216,6 +219,20 @@ def test_list(self): self.assert_list_results(self.list_path, expected, 26) + def test_list_exclude_restricted_by_default(self): + """ Verify the endpoint returns a list of all programs. """ + self.create_program(include_restricted_run=True) + resp = self.client.get(self.list_path) + self.assertListEqual(resp.data['results'][0]['courses'][0]['course_runs'], []) + self.assertEqual(resp.data['results'][0]['course_run_statuses'], []) + + def test_list_include_restricted(self): + """ Verify the endpoint returns a list of all programs. """ + self.create_program(include_restricted_run=True) + resp = self.client.get(self.list_path + '?restriction_list=custom-b2c') + self.assertEqual(len(resp.data['results'][0]['courses'][0]['course_runs']), 1) + self.assertEqual(resp.data['results'][0]['course_run_statuses'], [CourseRunStatus.Published]) + def test_extended_query_param_fields(self): """ Verify that the `extended` query param will result in an extended amount of fields returned. """ for _ in range(3): diff --git a/course_discovery/apps/api/v1/tests/test_views/test_search.py b/course_discovery/apps/api/v1/tests/test_views/test_search.py index 5e0213a20d..ceb21fa796 100644 --- a/course_discovery/apps/api/v1/tests/test_views/test_search.py +++ b/course_discovery/apps/api/v1/tests/test_views/test_search.py @@ -22,7 +22,7 @@ CourseRunSearchDocumentSerializer, CourseRunSearchModelSerializer, LimitedAggregateSearchSerializer ) from course_discovery.apps.course_metadata.tests.factories import ( - CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory, ProgramFactory, SeatFactory + CourseFactory, CourseRunFactory, OrganizationFactory, PersonFactory, PositionFactory, ProgramFactory, RestrictedCourseRunFactory, SeatFactory ) from course_discovery.apps.learner_pathway.models import LearnerPathway from course_discovery.apps.learner_pathway.tests.factories import LearnerPathwayStepFactory @@ -108,6 +108,40 @@ def test_search(self, path, serializer): """ Verify the view returns search results. """ self.assert_successful_search(path=path, serializer=serializer) + + + @ddt.data( + list_path, + detailed_path + ) + def test_search_restricted_default(self, path): + course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing', + status=CourseRunStatus.Published) + RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c') + + response = self.get_response('software', path=path) + + assert response.status_code == 200 + assert response.data['results'] == [] + assert response.data['count'] == 0 + + @ddt.data( + list_path, + detailed_path + ) + def test_search_restricted_param(self, path): + course_run = CourseRunFactory(course__partner=self.partner, course__title='Software Testing', + status=CourseRunStatus.Published) + RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c') + + path = path + '?restriction_list=custom-b2c' + response = self.get_response('software', path=self.list_path) + + assert response.status_code == 200 + assert response.data['count'] == 1 + + + def test_faceted_search(self): """ Verify the view returns results and facets. """ course_run, response_data = self.assert_successful_search(path=self.faceted_path) @@ -322,6 +356,78 @@ def test_results_only_include_specific_key_objects(self): self.serialize_course_search(course) ] + + def test_results_exclude_restricted(self): + + CourseFactory( + key=self.regular_key, + title='ABCs of Ͳҽʂէìղց', + partner=self.partner + ) + course = CourseFactory( + key=self.desired_key, + title='ABCs of Ͳҽʂէìղց', + partner=self.partner + ) + course_run = CourseRunFactory( + course__partner=self.partner, + course=course, + status=CourseRunStatus.Published, + key=self.desired_key, + type__is_marketable=True + ) + RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c') + CourseRunFactory( + course__partner=self.partner, + status=CourseRunStatus.Published, + key=self.regular_key, + type__is_marketable=True + ) + response = self.get_response(query={'key.raw': self.desired_key}, endpoint=self.list_path) + assert response.status_code == 200 + response_data = response.json() + assert response_data["results"] == [ + self.serialize_course_search(course) + ] + + assert response_data["results"][0]["course_runs"] == [] + + def test_results_restricted_param(self): + + CourseFactory( + key=self.regular_key, + title='ABCs of Ͳҽʂէìղց', + partner=self.partner + ) + course = CourseFactory( + key=self.desired_key, + title='ABCs of Ͳҽʂէìղց', + partner=self.partner + ) + course_run = CourseRunFactory( + course__partner=self.partner, + course=course, + status=CourseRunStatus.Published, + key=self.desired_key, + type__is_marketable=True + ) + RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c') + CourseRunFactory( + course__partner=self.partner, + status=CourseRunStatus.Published, + key=self.regular_key, + type__is_marketable=True + ) + response = self.get_response(query={'key.raw': self.desired_key}, endpoint=self.list_path + '?restriction_list=custom-b2c') + assert response.status_code == 200 + response_data = response.json() + assert response_data["results"] == [ + self.serialize_course_run_search(course_run), + self.serialize_course_search(course) + ] + + assert not response_data["results"][1]["course_runs"] == [] + def test_results_include_active_course_runs(self): """ Verify the search results include course runs that are active (means the course run is currently open for diff --git a/course_discovery/apps/learner_pathway/api/v1/tests/test_views.py b/course_discovery/apps/learner_pathway/api/v1/tests/test_views.py index 071e7d9518..473d541352 100644 --- a/course_discovery/apps/learner_pathway/api/v1/tests/test_views.py +++ b/course_discovery/apps/learner_pathway/api/v1/tests/test_views.py @@ -8,7 +8,7 @@ from course_discovery.apps import learner_pathway from course_discovery.apps.core.tests.factories import UserFactory -from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory +from course_discovery.apps.course_metadata.tests.factories import CourseRunFactory, RestrictedCourseRunFactory from course_discovery.apps.learner_pathway.choices import PathwayStatus from course_discovery.apps.learner_pathway.tests.factories import ( LearnerPathwayCourseFactory, LearnerPathwayFactory, LearnerPathwayProgramFactory, LearnerPathwayStepFactory @@ -90,7 +90,7 @@ def setUp(self): course__title=LEARNER_PATHWAY_DATA['steps'][0]['courses'][0]['title'], course__short_description=LEARNER_PATHWAY_DATA['steps'][0]['courses'][0]['short_description'], ) - __ = CourseRunFactory( + self.learner_pathway_course__course_run = CourseRunFactory( course=self.learner_pathway_course.course, key=LEARNER_PATHWAY_DATA['steps'][0]['courses'][0]['course_runs'][0]['key'], status='published', @@ -106,7 +106,7 @@ def setUp(self): self.view_url = '/api/v1/learner-pathway/{}/'.format(self.learner_pathway.uuid) # reverse('learner-pathway') - def _verify_learner_pathway_data(self, api_response, expected_data): + def _verify_learner_pathway_data(self, api_response, expected_data, restricted=False): """ Verify that learner pathway api response matches the expected data. """ @@ -132,7 +132,10 @@ def _verify_learner_pathway_data(self, api_response, expected_data): api_response_step_course = data['steps'][0]['courses'][0] expected_lerner_pathway_step_course = expected_data['steps'][0]['courses'][0] for key, value in expected_lerner_pathway_step_course.items(): - assert api_response_step_course[key] == value + if restricted and key=='course_runs': + assert api_response_step_course[key] == [] + else: + assert api_response_step_course[key] == value # course card_image_url should not be empty assert api_response_step_course['card_image_url'] @@ -146,12 +149,15 @@ def _verify_learner_pathway_data(self, api_response, expected_data): # program card_image_url should not be empty assert api_response_step_course['card_image_url'] - def test_learner_pathway_api(self): + @ddt.data([True, False]) + def test_learner_pathway_api(self, restricted_run): """ Verify that learner pathway api returns the expected response. """ + if restricted_run: + RestrictedCourseRunFactory(course_run=self.learner_pathway_course__course_run, restriction_type='custom-b2c') api_response = self.client.get(self.view_url) - self._verify_learner_pathway_data(api_response, LEARNER_PATHWAY_DATA) + self._verify_learner_pathway_data(api_response, LEARNER_PATHWAY_DATA, restricted=restricted_run) def test_learner_pathway_api_filtering(self): """