Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit #4349

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 35 additions & 11 deletions course_discovery/apps/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,7 @@ class MinimalCourseSerializer(FlexFieldsSerializerMixin, TimestampModelSerialize
url_slug = serializers.SerializerMethodField()
course_type = serializers.SerializerMethodField()
enterprise_subscription_inclusion = serializers.BooleanField(required=False)
course_run_statuses = serializers.ReadOnlyField()
course_run_statuses = serializers.SerializerMethodField()

@classmethod
def prefetch_queryset(cls, queryset=None, course_runs=None):
Expand All @@ -1207,6 +1207,10 @@ def get_course_type(self, obj):
def get_url_slug(self, obj): # pylint: disable=unused-argument
return None # this has been removed from the MinimalCourseSerializer, set to None to not break APIs

def get_course_run_statuses(self, course):
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')
return course.course_run_statuses(restriction_list=restriction_list)

@classmethod
def prefetch_course_runs(cls, serializer_class, course_runs=None):
"""Returns a Prefetch object for the course runs in a course."""
Expand Down Expand Up @@ -1317,7 +1321,6 @@ class CourseSerializer(TaggitSerializer, MinimalCourseSerializer):
url_slug = serializers.SlugField(read_only=True, source='active_url_slug')
url_slug_history = serializers.SlugRelatedField(slug_field='url_slug', read_only=True, many=True)
url_redirects = serializers.SlugRelatedField(slug_field='value', read_only=True, many=True)
course_run_statuses = serializers.ReadOnlyField()
editors = CourseEditorSerializer(many=True, read_only=True)
collaborators = SlugRelatedFieldWithReadSerializer(slug_field='uuid', required=False, many=True,
queryset=Collaborator.objects.all(),
Expand Down Expand Up @@ -1400,7 +1403,7 @@ class Meta(MinimalCourseSerializer.Meta):
'syllabus_raw', 'outcome', 'original_image', 'card_image_url', 'canonical_course_run_key',
'extra_description', 'additional_information', 'additional_metadata', 'faq', 'learner_testimonials',
'enrollment_count', 'recent_enrollment_count', 'topics', 'partner', 'key_for_reruns', 'url_slug',
'url_slug_history', 'url_redirects', 'course_run_statuses', 'editors', 'collaborators', 'skill_names',
'url_slug_history', 'url_redirects', 'editors', 'collaborators', 'skill_names',
'skills', 'organization_short_code_override', 'organization_logo_override_url',
'enterprise_subscription_inclusion', 'geolocation', 'location_restriction', 'in_year_value',
'product_source', 'data_modified_timestamp', 'excluded_from_search', 'excluded_from_seo', 'watchers',
Expand Down Expand Up @@ -1603,16 +1606,22 @@ def prefetch_queryset(cls, partner, queryset=None, course_runs=None, programs=No
)

def get_advertised_course_run_uuid(self, course):
if course.advertised_course_run:
return course.advertised_course_run.uuid
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')
advertised_run = course.advertised_course_run(restriction_list)
if advertised_run:
return advertised_run.uuid
return None

def get_course_run_keys(self, course):
return [course_run.key for course_run in course.course_runs.all()]
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')

return [course_run.key for course_run in CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list)]

def get_course_runs(self, course):
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')

return CourseRunSerializer(
course.course_runs,
CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list),
many=True,
context={
'request': self.context.get('request'),
Expand Down Expand Up @@ -1667,7 +1676,9 @@ def get_marketing_url(self, obj):
)

def get_course_run_keys(self, course):
return [course_run.key for course_run in course.course_runs.all()]
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')

return [course_run.key for course_run in CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list)]

class Meta(CourseSerializer.Meta):
model = Course
Expand Down Expand Up @@ -1710,8 +1721,10 @@ def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
)

def get_course_runs(self, course):
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')

return CourseRunSerializer(
course.course_runs,
CourseRun.get_exposed_runs(course.course_runs.all(), restriction_list),
many=True,
context=self.context
).data
Expand All @@ -1735,7 +1748,10 @@ class MinimalProgramCourseSerializer(MinimalCourseSerializer):
course_runs = serializers.SerializerMethodField()

def get_course_runs(self, course):
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')

course_runs = self.context['course_runs']
course_runs = CourseRun.get_exposed_runs(course_runs, restriction_list)
course_runs = [course_run for course_run in course_runs if course_run.course == course]

if self.context.get('published_course_runs_only'):
Expand Down Expand Up @@ -1979,7 +1995,7 @@ class MinimalProgramSerializer(TaggitSerializer, FlexFieldsSerializerMixin, Base
degree = DegreeSerializer()
curricula = CurriculumSerializer(many=True)
card_image_url = serializers.SerializerMethodField()
course_run_statuses = serializers.ReadOnlyField()
course_run_statuses = serializers.SerializerMethodField()
organization_short_code_override = serializers.CharField(required=False, allow_blank=True)
organization_logo_override_url = serializers.SerializerMethodField()
primary_subject_override = SubjectSerializer()
Expand All @@ -1995,6 +2011,10 @@ def get_organization_logo_override_url(self, obj):
return logo_image_override.url
return None

def get_course_run_statuses(self, program):
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we are using restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',') in many places in this file. Maybe add a unified util/method to avoid repeating this?

return program.course_run_statuses(restriction_list=restriction_list)

@classmethod
def prefetch_queryset(cls, partner, queryset=None):
# Explicitly check if the queryset is None before selecting related
Expand Down Expand Up @@ -2299,7 +2319,11 @@ class PathwaySerializer(BaseModelSerializer):
description = serializers.CharField()
destination_url = serializers.CharField()
pathway_type = serializers.CharField()
course_run_statuses = serializers.ReadOnlyField()
course_run_statuses = serializers.SerializerMethodField()

def get_course_run_statuses(self, pathway):
restriction_list = self.context['request'].query_params.get('restriction_list', '').split(',')
return pathway.course_run_statuses(restriction_list=restriction_list)

@classmethod
def prefetch_queryset(cls, partner):
Expand Down
22 changes: 16 additions & 6 deletions course_discovery/apps/api/v1/tests/test_views/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.contrib.auth import get_user_model
from django.core.management import call_command
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory

from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase, OAuth2Mixin, SerializationMixin
Expand All @@ -21,7 +22,7 @@
from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.course_metadata.models import Course, CourseType
from course_discovery.apps.course_metadata.tests.factories import (
CourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory
CourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory, RestrictedCourseRunFactory
)
from course_discovery.conftest import get_course_run_states

Expand Down Expand Up @@ -292,9 +293,10 @@ def test_courses_with_subjects_and_negative_query(self):
assert response.data['results'] == self.serialize_catalog_course(desired_courses, many=True)

@ddt.data(
*STATES()
[(st, res, lst) for st in STATES() for res in [True, False] for lst in ['', 'custom-b2c']]
)
def test_courses(self, state):
@ddt.unpack
def test_courses(self, state, restriction, restriction_list):
"""
Verify the endpoint returns the list of available courses contained in
the catalog, and that courses appearing in the response always have at
Expand All @@ -305,6 +307,8 @@ def test_courses(self, state):
Course.objects.all().delete()

course_run = CourseRunFactory(course__title='ABC Test Course')
if restriction:
RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c')
for function in state:
function(course_run)

Expand All @@ -325,10 +329,16 @@ def test_courses(self, state):
# Emulate prefetching behavior.
filtered_course_run.delete()

assert response.data['results'] == self.serialize_catalog_course([course], many=True)
mock_request = APIRequestFactory()
mock_request.query_params = {'restriction_list', restriction_list}
assert response.data['results'] == self.serialize_catalog_course([course], many=True, extra_context={'request': mock_request})

if not restriction:
# Any course appearing in the response must have at least one serialized run.
assert response.data['results'][0]['course_runs']
else:
assert not response.data['results'][0]['course_runs']

# Any course appearing in the response must have at least one serialized run.
assert response.data['results'][0]['course_runs']
else:
response = self.client.get(url)

Expand Down
38 changes: 38 additions & 0 deletions course_discovery/apps/api/v1/tests/test_views/test_course_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def setUp(self):
self.draft_course = CourseFactory(partner=self.partner, draft=True, product_source=self.product_source)
self.draft_course_run = CourseRunFactory(course=self.draft_course, draft=True)
self.draft_course_run.course.authoring_organizations.add(OrganizationFactory(key='course-id'))
self.restricted_run = CourseRunFactory(course__partner=self.partner)
RestrictedCourseRunFactory(course_run=self.restricted_run, restriction_type='custom-b2c')
self.course_run_type = CourseRunTypeFactory(tracks=[TrackFactory()])
self.verified_type = CourseRunType.objects.get(slug=CourseRunType.VERIFIED_AUDIT)
self.refresh_index()
Expand Down Expand Up @@ -1197,6 +1199,21 @@ def test_list(self):
response.data['results'],
self.serialize_course_run(CourseRun.objects.all().order_by(Lower('key')), many=True)
)
restrieved_keys = [r.key for r in response.data['results']]
assert self.restricted_run.key not in restrieved_keys


def test_list_include_restricted(self):
""" Verify the endpoint returns a list of all course runs. """
url = reverse('api:v1:course_run-list') + '?restriction_list=custom-b2c'

with self.assertNumQueries(14, threshold=3):
response = self.client.get(url)

assert response.status_code == 200
restrieved_keys = [r.key for r in response.data['results']]
assert self.restricted_run.key in restrieved_keys


def test_list_sorted_by_course_start_date(self):
""" Verify the endpoint returns a list of all course runs sorted by start date. """
Expand All @@ -1215,6 +1232,8 @@ def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner)
CourseRunFactory(title='non-matching name')
restricted_run = CourseRunFactory(title='Some random title', course__partner=self.partner)
RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c')
query = 'title:Some random title'
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)

Expand All @@ -1226,6 +1245,25 @@ def test_list_query(self):
key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted)


def test_list_query_include_restricted(self):
""" Verify the endpoint returns a filtered list of courses """
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner)
CourseRunFactory(title='non-matching name')
restricted_run = CourseRunFactory(title='Some random title', course__partner=self.partner)
RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c')
query = 'title:Some random title'
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)
url += '?restriction_list=custom-b2c,custom-b2b-enterprise'

with self.assertNumQueries(25, threshold=3):
response = self.client.get(url)

actual_sorted = sorted(response.data['results'], key=lambda course_run: course_run['key'])
expected_sorted = sorted(self.serialize_course_run([*course_runs, restricted_run], many=True),
key=lambda course_run: course_run['key'])
self.assertListEqual(actual_sorted, expected_sorted)

def assert_list_results(self, url, expected, extra_context=None):
expected = sorted(expected, key=lambda course_run: course_run.key.lower())
response = self.client.get(url)
Expand Down
37 changes: 36 additions & 1 deletion course_discovery/apps/api/v1/tests/test_views/test_courses.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
disconnect_course_data_modified_timestamp_signal_handlers, product_meta_taggable_changed
)
from course_discovery.apps.course_metadata.tests.factories import (
CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseLocationRestrictionFactory, CourseRunFactory,
CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseLocationRestrictionFactory, CourseRunFactory, RestrictedCourseRunFactory,
CourseTypeFactory, GeoLocationFactory, LevelTypeFactory, OrganizationFactory, ProductValueFactory, ProgramFactory,
SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory
)
Expand Down Expand Up @@ -278,6 +278,41 @@ def test_course_runs_are_ordered(self):
self.assertListEqual(response.data['course_run_keys'], expected_keys)
self.assertListEqual([run['key'] for run in response.data['course_runs']], expected_keys)


def test_course_runs_are_restricted_by_default(self):
run_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Published)
run_not_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Unpublished)
RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c')
SeatFactory(course_run=run_restricted)

url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(26, threshold=3):
response = self.client.get(url)
assert response.status_code == 200

self.assertListEqual(response.data['course_run_keys'], [run_not_restricted.key])
self.assertListEqual(response.data['course_run_statuses'], [run_not_restricted.status])
self.assertEqual(len(response.data['course_runs']), 1)
self.assertEqual(response.data['advertised_course_run_uuid'], None)

def test_course_runs_restriction_param(self):
run_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Published)
run_not_restricted = CourseRunFactory(course=self.course, start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC), status=CourseRunStatus.Unpublished)
RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c')
SeatFactory(course_run=run_restricted)

url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?restriction_list=custom-b2c'
with self.assertNumQueries(26, threshold=3):
response = self.client.get(url)
assert response.status_code == 200

self.assertListEqual(response.data['course_run_keys'], [run_not_restricted.key, run_restricted.key])
self.assertListEqual(response.data['course_run_statuses'], [run_not_restricted.status, run_restricted.statuss])
self.assertEqual(len(response.data['course_runs']), 2)
self.assertEqual(response.data['advertised_course_run_uuid'], run_restricted.uuid)


def test_list(self):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
Expand Down
25 changes: 21 additions & 4 deletions course_discovery/apps/api/v1/tests/test_views/test_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import CourseType, Program, ProgramType
from course_discovery.apps.course_metadata.tests.factories import (
CorporateEndorsementFactory, CourseFactory, CourseRunFactory, CurriculumCourseMembershipFactory, CurriculumFactory,
CurriculumProgramMembershipFactory, DegreeAdditionalMetadataFactory, DegreeFactory, EndorsementFactory,
ExpectedLearningItemFactory, JobOutlookItemFactory, OrganizationFactory, PersonFactory, ProgramFactory,
ProgramTypeFactory, VideoFactory
ProgramTypeFactory, RestrictedCourseRunFactory, VideoFactory
)


Expand Down Expand Up @@ -48,13 +48,16 @@ def setup(self, client, django_assert_num_queries, partner):
self.partner = partner
self.request = request

def create_program(self, courses=None, program_type=None):
def create_program(self, courses=None, program_type=None, include_restricted_run=False):
organizations = [OrganizationFactory(partner=self.partner)]
person = PersonFactory()

if courses is None:
courses = [CourseFactory(partner=self.partner)]
CourseRunFactory(course=courses[0], staff=[person])
cr = CourseRunFactory(course=courses[0], staff=[person])

if include_restricted_run:
RestrictedCourseRunFactory(course_run=cr, restriction_type='custom-b2c')

if program_type is None:
program_type = ProgramTypeFactory()
Expand Down Expand Up @@ -216,6 +219,20 @@ def test_list(self):

self.assert_list_results(self.list_path, expected, 26)

def test_list_exclude_restricted_by_default(self):
""" Verify the endpoint returns a list of all programs. """
self.create_program(include_restricted_run=True)
resp = self.client.get(self.list_path)
self.assertListEqual(resp.data['results'][0]['courses'][0]['course_runs'], [])
self.assertEqual(resp.data['results'][0]['course_run_statuses'], [])

def test_list_include_restricted(self):
""" Verify the endpoint returns a list of all programs. """
self.create_program(include_restricted_run=True)
resp = self.client.get(self.list_path + '?restriction_list=custom-b2c')
self.assertEqual(len(resp.data['results'][0]['courses'][0]['course_runs']), 1)
self.assertEqual(resp.data['results'][0]['course_run_statuses'], [CourseRunStatus.Published])

def test_extended_query_param_fields(self):
""" Verify that the `extended` query param will result in an extended amount of fields returned. """
for _ in range(3):
Expand Down