diff --git a/forum/models/comments.py b/forum/models/comments.py index bcbe4c2d..171a193f 100644 --- a/forum/models/comments.py +++ b/forum/models/comments.py @@ -5,6 +5,7 @@ from bson import ObjectId +from forum.models.threads import CommentThread from forum.models.contents import BaseContents from forum.models.users import Users @@ -96,10 +97,10 @@ def insert( Returns: str: The ID of the inserted document. """ - parent_comment = parent_id and self.get(parent_id) - parent_child_count = parent_comment and parent_comment.get("child_count") or 0 - if parent_comment and not comment_thread_id: - comment_thread_id = parent_comment.get("comment_thread_id") + if parent_id and not comment_thread_id: + parent_comment = self.get(parent_id) + if parent_comment: + comment_thread_id = parent_comment.get("comment_thread_id") date = datetime.now() comment_data = { @@ -115,7 +116,6 @@ def insert( "endorsed": False, "anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers, - "parent_id": ObjectId(parent_id) if parent_id else None, "author_id": author_id, "comment_thread_id": ObjectId(comment_thread_id), "child_count": 0, @@ -124,9 +124,14 @@ def insert( "created_at": date, "updated_at": date, } + if parent_id: + comment_data["parent_id"] = ObjectId(parent_id) + result = self._collection.insert_one(comment_data) if parent_id: - self.update(parent_id, child_count=parent_child_count + 1) + self.update_child_count_in_parent_comment(parent_id, 1) + if comment_thread_id: + self.update_comment_count_in_comment_thread(comment_thread_id, 1) return str(result.inserted_id) def update( @@ -238,16 +243,77 @@ def delete(self, _id: str) -> int: """ comment = self.get(_id) parent_comment_id = comment and comment.get("parent_id") - parent_comment = parent_comment_id and self.get(parent_comment_id) - parent_comment_child_count = parent_comment and parent_comment.get( - "child_count", - ) + child_comments_deleted_count = 0 + if not parent_comment_id: + child_comments_deleted_count = self.delete_child_comments(_id) + result = self._collection.delete_one({"_id": ObjectId(_id)}) - if parent_comment_id and parent_comment_child_count: - self.update(parent_comment_id, child_count=parent_comment_child_count - 1) - return result.deleted_count + if parent_comment_id: + self.update_child_count_in_parent_comment(parent_comment_id, -1) + + no_of_comments_delete = result.deleted_count + child_comments_deleted_count + comment_thread_id = comment and comment.get("comment_thread_id") + if comment_thread_id: + self.update_comment_count_in_comment_thread( + comment_thread_id, -(int(no_of_comments_delete)) + ) + return no_of_comments_delete def get_author_username(self, author_id: str) -> str | None: """Return username for the respective author_id(user_id)""" user = Users().get(author_id) return user.get("username") if user else None + + def delete_child_comments(self, _id: str) -> int: + """ + Delete child comments from the database based on the id. + + Args: + _id: The ID of the parent comment whose child comments will be deleted. + + Returns: + The number of child comments deleted. + """ + child_comments_to_delete = self.find({"parent_id": ObjectId(_id)}) + child_comment_ids_to_delete = [ + child_comment.get("_id") for child_comment in child_comments_to_delete + ] + child_comments_deleted = self._collection.delete_many( + {"_id": {"$in": child_comment_ids_to_delete}} + ) + return child_comments_deleted.deleted_count + + def update_child_count_in_parent_comment(self, parent_id: str, count: int) -> None: + """ + Update(increment/decrement) child_count in parent comment. + + Args: + parent_id: The ID of the parent comment whose child_count will be updated. + count: It can be any number. + If positive, this function will increase child_count by the count. + If negative, this function will decrease child_count by the count. + + Returns: + None. + """ + update_child_count_query = {"$inc": {"child_count": count}} + self.update_count(parent_id, update_child_count_query) + + def update_comment_count_in_comment_thread( + self, comment_thread_id: str, count: int + ) -> None: + """ + Update(increment/decrement) comment_count in comment thread. + + Args: + comment_thread_id: The ID of the comment thread + whose comment_count will be updated. + count: It can be any number. + If positive, this function will increase comment_count by the count. + If negative, this function will decrease comment_count by the count. + + Returns: + None. + """ + update_comment_count_query = {"$inc": {"comment_count": count}} + CommentThread().update_count(comment_thread_id, update_comment_count_query) diff --git a/forum/models/contents.py b/forum/models/contents.py index fda2d86a..94aefc55 100644 --- a/forum/models/contents.py +++ b/forum/models/contents.py @@ -97,6 +97,20 @@ def update_votes(self, content_id: str, votes: dict[str, Any]) -> int: ) return result.modified_count + def update_count(self, content_id: str, query: dict[str, Any]) -> int: + """ + Updates count of a field in the content document based on query. + + Args: + content_id (str): The id of the content(Commentthread id or Comment id) model. + query (dict[str, Any]): Query to update the count in a specific field. + """ + result = self._collection.update_one( + {"_id": ObjectId(content_id)}, + query, + ) + return result.modified_count + class Contents(BaseContents): """ diff --git a/forum/models/model_utils.py b/forum/models/model_utils.py index b850e8de..e2152864 100644 --- a/forum/models/model_utils.py +++ b/forum/models/model_utils.py @@ -202,14 +202,6 @@ def remove_vote(thread: dict[str, Any], user: dict[str, Any]) -> bool: return update_vote(thread, user, is_deleted=True) -def get_comments_count(thread_id: str) -> int: - """ - Returns that comments count in a perticular thread - """ - comments = list(Comment().list(comment_thread_id=ObjectId(thread_id))) - return len(comments) if comments else 0 - - def validate_thread_and_user( user_id: str, thread_id: str ) -> tuple[dict[str, Any], dict[str, Any]]: diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index 64bcc62b..15506a98 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -8,7 +8,6 @@ from forum.models.model_utils import ( get_abuse_flagged_count, - get_comments_count, get_endorsed, get_read_states, get_username_from_id, @@ -32,7 +31,7 @@ class ThreadSerializer(ContentSerializer): tags (list): A list of tags associated with the thread. group_id (str or None): The ID of the group associated with the thread, if any. pinned (bool): Whether the thread is pinned at the top of the list. - comments_count (int): The number of comments on the thread. + comment_count (int): The number of comments on the thread. This serializer extends the `ThreadSerializer` and customizes fields based on various context parameters. It manages fields related to read state, comment counts, endorsements, abuse flags, @@ -59,7 +58,7 @@ class ThreadSerializer(ContentSerializer): tags = serializers.ListField(default=[]) group_id = serializers.CharField(allow_null=True, default=None) pinned = serializers.BooleanField(default=False) - comments_count = serializers.SerializerMethodField() + comment_count = serializers.IntegerField(default=0) read = serializers.SerializerMethodField() unread_comments_count = serializers.SerializerMethodField() endorsed = serializers.SerializerMethodField() @@ -126,7 +125,7 @@ def get_read(self, obj: dict[str, Any]) -> Optional[bool]: course_id = obj["course_id"] thread_key = obj["id"] is_read, _ = get_read_states([obj], user_id, course_id).get( - thread_key, (False, obj["comments_count"]) + thread_key, (False, obj["comment_count"]) ) return is_read return None @@ -148,7 +147,7 @@ def get_unread_comments_count(self, obj: dict[str, Any]) -> Optional[int]: course_id = obj["course_id"] thread_key = obj["id"] _, unread_count = get_read_states([obj], user_id, course_id).get( - thread_key, (False, obj["comments_count"]) + thread_key, (False, obj["comment_count"]) ) return unread_count return None @@ -243,11 +242,6 @@ def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: """Raise NotImplementedError""" raise NotImplementedError - def get_comments_count(self, obj: dict[str, Any]) -> int: - """Retrieve the count of comments for the given thread.""" - count = get_comments_count(obj["_id"]) - return count - def get_closed_by(self, obj: dict[str, Any]) -> Optional[str]: """Retrieve the username of the person who closed the object.""" if closed_by_id := obj.get("closed_by_id"): diff --git a/forum/views/comments.py b/forum/views/comments.py index dfb93e40..9615ca4c 100644 --- a/forum/views/comments.py +++ b/forum/views/comments.py @@ -41,8 +41,8 @@ def create_comment( new_comment_id = Comment().insert( body=data["body"], course_id=data["course_id"], - anonymous=data.get("anonymous", False), - anonymous_to_peers=data.get("anonymous_to_peers", False), + anonymous=str_to_bool(data.get("anonymous", "False")), + anonymous_to_peers=str_to_bool(data.get("anonymous_to_peers", "False")), author_id=data["user_id"], comment_thread_id=thread_id, parent_id=parent_id, @@ -114,7 +114,7 @@ def get(self, request: Request, comment_id: str) -> Response: ) data = prepare_comment_api_response( comment, - exclude_fields=["sk"], + exclude_fields=["sk", "endorsement"], ) return Response(data, status=status.HTTP_200_OK) diff --git a/forum/views/subscriptions.py b/forum/views/subscriptions.py index 60f18b98..9a58ea6e 100644 --- a/forum/views/subscriptions.py +++ b/forum/views/subscriptions.py @@ -180,6 +180,7 @@ def _validate_params(self, params: dict[str, Any]) -> Response | None: "sort_key", "page", "per_page", + "request_id'", ] for key in params: