diff --git a/learning_resources/serializers.py b/learning_resources/serializers.py index 96e5576a7b..401f4f52d2 100644 --- a/learning_resources/serializers.py +++ b/learning_resources/serializers.py @@ -888,3 +888,68 @@ class Meta: model = models.UserListRelationship extra_kwargs = {"position": {"required": False}} exclude = COMMON_IGNORED_FIELDS + + +class BaseRelationshipRequestSerializer(serializers.Serializer): + """ + Base class for validating requests that set relationships between + learning resources + """ + + learning_resource_id = serializers.IntegerField() + + def validate_learning_resource_id(self, learning_resource_id): + """Ensure that the learning resource exists""" + try: + models.LearningResource.objects.get(id=learning_resource_id) + except models.LearningResource.DoesNotExist as dne: + msg = f"Invalid learning resource id: {learning_resource_id}" + raise ValidationError(msg) from dne + return learning_resource_id + + +class SetLearningPathsRequestSerializer(BaseRelationshipRequestSerializer): + """ + Validate request parameters for setting learning paths for a learning resource + """ + + learning_path_ids = serializers.ListField( + child=serializers.IntegerField(), allow_empty=True + ) + + def validate_learning_path_ids(self, learning_path_ids): + """Ensure that the learning paths exist""" + valid_learning_path_ids = set( + models.LearningResource.objects.filter( + id__in=learning_path_ids, + resource_type=LearningResourceType.learning_path.name, + ).values_list("id", flat=True) + ) + missing = set(learning_path_ids).difference(valid_learning_path_ids) + if missing: + msg = f"Invalid learning path ids: {missing}" + raise ValidationError(msg) + return learning_path_ids + + +class SetUserListsRequestSerializer(BaseRelationshipRequestSerializer): + """ + Validate request parameters for setting userlist for a learning resource + """ + + userlist_ids = serializers.ListField( + child=serializers.IntegerField(), allow_empty=True + ) + + def validate_userlist_ids(self, userlist_ids): + """Ensure that the learning paths exist""" + valid_userlist_ids = set( + models.UserList.objects.filter( + id__in=userlist_ids, + ).values_list("id", flat=True) + ) + missing = set(userlist_ids).difference(valid_userlist_ids) + if missing: + msg = f"Invalid learning path ids: {missing}" + raise ValidationError(msg) + return userlist_ids diff --git a/learning_resources/serializers_test.py b/learning_resources/serializers_test.py index 960c0bc622..4c49eee142 100644 --- a/learning_resources/serializers_test.py +++ b/learning_resources/serializers_test.py @@ -550,3 +550,57 @@ def test_content_file_serializer(settings, expected_types, has_channels): ), }, ) + + +def test_set_learning_path_request_serializer(): + """Test serializer for setting learning path relationships""" + lists = factories.LearningPathFactory.create_batch(2) + resource = factories.LearningResourceFactory.create() + + serializer = serializers.SetLearningPathsRequestSerializer() + + data1 = { + "learning_path_ids": [ + str(lists[0].learning_resource.id), + lists[1].learning_resource.id, + ], + "learning_resource_id": str(resource.id), + } + assert serializer.to_internal_value(data1) == { + "learning_path_ids": [ + lists[0].learning_resource.id, + lists[1].learning_resource.id, + ], + "learning_resource_id": resource.id, + } + + invalid = serializers.SetLearningPathsRequestSerializer( + data={"learning_path_ids": [1, 2], "learning_resource_id": 3} + ) + assert invalid.is_valid() is False + assert "learning_path_ids" in invalid.errors + assert "learning_resource_id" in invalid.errors + + +def test_set_userlist_request_serializer(): + """Test serializer for setting userlist relationships""" + lists = factories.UserListFactory.create_batch(2) + resource = factories.LearningResourceFactory.create() + + serializer = serializers.SetUserListsRequestSerializer() + + data1 = { + "userlist_ids": [str(lists[0].id), lists[1].id], + "learning_resource_id": str(resource.id), + } + assert serializer.to_internal_value(data1) == { + "userlist_ids": [lists[0].id, lists[1].id], + "learning_resource_id": resource.id, + } + + invalid = serializers.SetUserListsRequestSerializer( + data={"userlist_ids": [1, 2], "learning_resource_id": 3} + ) + assert invalid.is_valid() is False + assert "userlist_ids" in invalid.errors + assert "learning_resource_id" in invalid.errors diff --git a/learning_resources/views.py b/learning_resources/views.py index 683742d110..9e1b699bc1 100644 --- a/learning_resources/views.py +++ b/learning_resources/views.py @@ -74,6 +74,8 @@ PodcastEpisodeResourceSerializer, PodcastResourceSerializer, ProgramResourceSerializer, + SetLearningPathsRequestSerializer, + SetUserListsRequestSerializer, UserListRelationshipSerializer, UserListSerializer, VideoPlaylistResourceSerializer, @@ -433,10 +435,16 @@ def userlists(self, request, *args, **kwargs): # noqa: ARG002 """ Set User List relationships for a given Learning Resource """ - learning_resource_id = kwargs.get("pk") - user_list_ids = request.query_params.getlist("userlist_id") + req_data = SetUserListsRequestSerializer().to_internal_value( + { + "userlist_ids": request.query_params.getlist("userlist_id"), + "learning_resource_id": kwargs.get("pk"), + } + ) + learning_resource_id = req_data["learning_resource_id"] + userlist_ids = req_data["userlist_ids"] if ( - UserList.objects.filter(pk__in=user_list_ids) + UserList.objects.filter(pk__in=userlist_ids) .exclude(author=request.user) .exists() ): @@ -445,9 +453,14 @@ def userlists(self, request, *args, **kwargs): # noqa: ARG002 current_relationships = UserListRelationship.objects.filter( parent__author=request.user, child_id=learning_resource_id ) - current_relationships.exclude(parent_id__in=user_list_ids).delete() - for userlist_id in user_list_ids: + + # Remove the resource from lists it WAS in before but is not in now + current_relationships.exclude(parent_id__in=userlist_ids).delete() + current_parent_lists = current_relationships.values_list("parent_id", flat=True) + + for userlist_id in userlist_ids: last_index = 0 + # re-number the positions for surviving items for index, relationship in enumerate( UserListRelationship.objects.filter( parent__author=request.user, parent__id=userlist_id @@ -456,11 +469,13 @@ def userlists(self, request, *args, **kwargs): # noqa: ARG002 relationship.position = index relationship.save() last_index = index - UserListRelationship.objects.create( - parent_id=userlist_id, - child_id=learning_resource_id, - position=last_index + 1, - ) + # Add new items as necessary + if userlist_id not in list(current_parent_lists): + UserListRelationship.objects.create( + parent_id=userlist_id, + child_id=learning_resource_id, + position=last_index + 1, + ) SerializerClass = self.get_serializer_class() serializer = SerializerClass(current_relationships, many=True) return Response(serializer.data) @@ -489,14 +504,25 @@ def learning_paths(self, request, *args, **kwargs): # noqa: ARG002 """ Set Learning Path relationships for a given Learning Resource """ - learning_resource_id = kwargs.get("pk") - learning_path_ids = request.query_params.getlist("learning_path_id") + req_data = SetLearningPathsRequestSerializer().to_internal_value( + { + "learning_path_ids": request.query_params.getlist("learning_path_id"), + "learning_resource_id": kwargs.get("pk"), + } + ) + learning_resource_id = req_data["learning_resource_id"] + learning_path_ids = req_data["learning_path_ids"] current_relationships = LearningResourceRelationship.objects.filter( child_id=learning_resource_id ) + # Remove the resource from lists it WAS in before but is not in now current_relationships.exclude(parent_id__in=learning_path_ids).delete() - for learning_path_id in learning_path_ids: + current_parent_lists = current_relationships.values_list("parent_id", flat=True) + + for learning_path_id_str in learning_path_ids: + learning_path_id = int(learning_path_id_str) last_index = 0 + # re-number the positions for surviving items for index, relationship in enumerate( LearningResourceRelationship.objects.filter( parent__id=learning_path_id @@ -505,12 +531,15 @@ def learning_paths(self, request, *args, **kwargs): # noqa: ARG002 relationship.position = index relationship.save() last_index = index - LearningResourceRelationship.objects.create( - parent_id=learning_path_id, - child_id=learning_resource_id, - relation_type=LearningResourceRelationTypes.LEARNING_PATH_ITEMS, - position=last_index + 1, - ) + + # Add new items as necessary + if learning_path_id not in list(current_parent_lists): + LearningResourceRelationship.objects.create( + parent_id=learning_path_id, + child_id=learning_resource_id, + relation_type=LearningResourceRelationTypes.LEARNING_PATH_ITEMS, + position=last_index + 1, + ) SerializerClass = self.get_serializer_class() serializer = SerializerClass(current_relationships, many=True) return Response(serializer.data) diff --git a/learning_resources/views_learningpath_test.py b/learning_resources/views_learningpath_test.py index 9d479771dc..51709bad99 100644 --- a/learning_resources/views_learningpath_test.py +++ b/learning_resources/views_learningpath_test.py @@ -425,3 +425,39 @@ def test_set_learning_path_relationships(client, staff_user): assert not course.learning_resource.learning_path_parents.filter( parent__id=previous_learning_path.learning_resource.id ).exists() + + +def test_adding_to_learning_path_not_effect_existing_membership(client, staff_user): + """ + Given L1 (existing parent), L2 (new parent), and R (resource), + test that adding R to L2 does not affect L1. + """ + course = factories.CourseFactory.create() + + existing_parent = factories.LearningPathFactory.create(author=staff_user) + factories.LearningPathRelationshipFactory.create( + parent=existing_parent.learning_resource, child=course.learning_resource + ) + new_additional_parent = factories.LearningPathFactory.create(author=staff_user) + + prev_parent_count = existing_parent.learning_resource.resources.count() + new_additional_parent_count = ( + new_additional_parent.learning_resource.resources.count() + ) + + url = reverse( + "lr:v1:learning_resource_relationships_api-learning-paths", + args=[course.learning_resource.id], + ) + client.force_login(staff_user) + lps = [existing_parent, new_additional_parent] + resp = client.patch( + f"{url}?{"".join([f"learning_path_id={lp.learning_resource.id}&" for lp in lps])}" + ) + + assert resp.status_code == 200 + assert prev_parent_count == existing_parent.learning_resource.resources.count() + assert ( + new_additional_parent_count + 1 + == new_additional_parent.learning_resource.resources.count() + ) diff --git a/learning_resources/views_userlist_test.py b/learning_resources/views_userlist_test.py index d404b19741..ff8679b09d 100644 --- a/learning_resources/views_userlist_test.py +++ b/learning_resources/views_userlist_test.py @@ -350,3 +350,34 @@ def assign_userlists(course, userlists): assert ( UserListRelationship.objects.filter(child=course.learning_resource).count() == 3 ) + + +def test_adding_to_userlist_not_effect_existing_membership(client, user): + """ + Given L1 (existing parent), L2 (new parent), and R (resource), + test that adding R to L2 does not affect L1. + """ + course = factories.CourseFactory.create() + + existing_parent = factories.UserListFactory.create(author=user) + factories.UserListRelationshipFactory.create( + parent=existing_parent, child=course.learning_resource + ) + new_additional_parent = factories.UserListFactory.create(author=user) + + prev_parent_count = existing_parent.resources.count() + new_additional_parent_count = new_additional_parent.resources.count() + + url = reverse( + "lr:v1:learning_resource_relationships_api-userlists", + args=[course.learning_resource.id], + ) + client.force_login(user) + lists = [existing_parent, new_additional_parent] + resp = client.patch( + f"{url}?{"".join([f"userlist_id={userlist.id}&" for userlist in lists])}" + ) + + assert resp.status_code == 200 + assert prev_parent_count == existing_parent.resources.count() + assert new_additional_parent_count + 1 == new_additional_parent.resources.count()