Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 79 additions & 13 deletions forum/models/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from bson import ObjectId

from forum.models.threads import CommentThread
from forum.models.contents import BaseContents
from forum.models.users import Users

Expand Down Expand Up @@ -96,10 +97,10 @@ def insert(
Returns:
str: The ID of the inserted document.
"""
parent_comment = parent_id and self.get(parent_id)
parent_child_count = parent_comment and parent_comment.get("child_count") or 0
if parent_comment and not comment_thread_id:
comment_thread_id = parent_comment.get("comment_thread_id")
if parent_id and not comment_thread_id:
parent_comment = self.get(parent_id)
if parent_comment:
comment_thread_id = parent_comment.get("comment_thread_id")

date = datetime.now()
comment_data = {
Expand All @@ -115,7 +116,6 @@ def insert(
"endorsed": False,
"anonymous": anonymous,
"anonymous_to_peers": anonymous_to_peers,
"parent_id": ObjectId(parent_id) if parent_id else None,
"author_id": author_id,
"comment_thread_id": ObjectId(comment_thread_id),
"child_count": 0,
Expand All @@ -124,9 +124,14 @@ def insert(
"created_at": date,
"updated_at": date,
}
if parent_id:
comment_data["parent_id"] = ObjectId(parent_id)

result = self._collection.insert_one(comment_data)
if parent_id:
self.update(parent_id, child_count=parent_child_count + 1)
self.update_child_count_in_parent_comment(parent_id, 1)
if comment_thread_id:
self.update_comment_count_in_comment_thread(comment_thread_id, 1)
return str(result.inserted_id)

def update(
Expand Down Expand Up @@ -238,16 +243,77 @@ def delete(self, _id: str) -> int:
"""
comment = self.get(_id)
parent_comment_id = comment and comment.get("parent_id")
parent_comment = parent_comment_id and self.get(parent_comment_id)
parent_comment_child_count = parent_comment and parent_comment.get(
"child_count",
)
child_comments_deleted_count = 0
if not parent_comment_id:
child_comments_deleted_count = self.delete_child_comments(_id)

result = self._collection.delete_one({"_id": ObjectId(_id)})
if parent_comment_id and parent_comment_child_count:
self.update(parent_comment_id, child_count=parent_comment_child_count - 1)
return result.deleted_count
if parent_comment_id:
self.update_child_count_in_parent_comment(parent_comment_id, -1)

no_of_comments_delete = result.deleted_count + child_comments_deleted_count
comment_thread_id = comment and comment.get("comment_thread_id")
if comment_thread_id:
self.update_comment_count_in_comment_thread(
comment_thread_id, -(int(no_of_comments_delete))
)
return no_of_comments_delete

def get_author_username(self, author_id: str) -> str | None:
"""Return username for the respective author_id(user_id)"""
user = Users().get(author_id)
return user.get("username") if user else None

def delete_child_comments(self, _id: str) -> int:
"""
Delete child comments from the database based on the id.

Args:
_id: The ID of the parent comment whose child comments will be deleted.

Returns:
The number of child comments deleted.
"""
child_comments_to_delete = self.find({"parent_id": ObjectId(_id)})
child_comment_ids_to_delete = [
child_comment.get("_id") for child_comment in child_comments_to_delete
]
child_comments_deleted = self._collection.delete_many(
{"_id": {"$in": child_comment_ids_to_delete}}
)
return child_comments_deleted.deleted_count

def update_child_count_in_parent_comment(self, parent_id: str, count: int) -> None:
"""
Update(increment/decrement) child_count in parent comment.

Args:
parent_id: The ID of the parent comment whose child_count will be updated.
count: It can be any number.
If positive, this function will increase child_count by the count.
If negative, this function will decrease child_count by the count.

Returns:
None.
"""
update_child_count_query = {"$inc": {"child_count": count}}
self.update_count(parent_id, update_child_count_query)

def update_comment_count_in_comment_thread(
self, comment_thread_id: str, count: int
) -> None:
"""
Update(increment/decrement) comment_count in comment thread.

Args:
comment_thread_id: The ID of the comment thread
whose comment_count will be updated.
count: It can be any number.
If positive, this function will increase comment_count by the count.
If negative, this function will decrease comment_count by the count.

Returns:
None.
"""
update_comment_count_query = {"$inc": {"comment_count": count}}
CommentThread().update_count(comment_thread_id, update_comment_count_query)
14 changes: 14 additions & 0 deletions forum/models/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def update_votes(self, content_id: str, votes: dict[str, Any]) -> int:
)
return result.modified_count

def update_count(self, content_id: str, query: dict[str, Any]) -> int:
"""
Updates count of a field in the content document based on query.

Args:
content_id (str): The id of the content(Commentthread id or Comment id) model.
query (dict[str, Any]): Query to update the count in a specific field.
"""
result = self._collection.update_one(
{"_id": ObjectId(content_id)},
query,
)
return result.modified_count


class Contents(BaseContents):
"""
Expand Down
8 changes: 0 additions & 8 deletions forum/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,6 @@ def remove_vote(thread: dict[str, Any], user: dict[str, Any]) -> bool:
return update_vote(thread, user, is_deleted=True)


def get_comments_count(thread_id: str) -> int:
"""
Returns that comments count in a perticular thread
"""
comments = list(Comment().list(comment_thread_id=ObjectId(thread_id)))
return len(comments) if comments else 0


def validate_thread_and_user(
user_id: str, thread_id: str
) -> tuple[dict[str, Any], dict[str, Any]]:
Expand Down
14 changes: 4 additions & 10 deletions forum/serializers/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from forum.models.model_utils import (
get_abuse_flagged_count,
get_comments_count,
get_endorsed,
get_read_states,
get_username_from_id,
Expand All @@ -32,7 +31,7 @@ class ThreadSerializer(ContentSerializer):
tags (list): A list of tags associated with the thread.
group_id (str or None): The ID of the group associated with the thread, if any.
pinned (bool): Whether the thread is pinned at the top of the list.
comments_count (int): The number of comments on the thread.
comment_count (int): The number of comments on the thread.

This serializer extends the `ThreadSerializer` and customizes fields based on various context
parameters. It manages fields related to read state, comment counts, endorsements, abuse flags,
Expand All @@ -59,7 +58,7 @@ class ThreadSerializer(ContentSerializer):
tags = serializers.ListField(default=[])
group_id = serializers.CharField(allow_null=True, default=None)
pinned = serializers.BooleanField(default=False)
comments_count = serializers.SerializerMethodField()
comment_count = serializers.IntegerField(default=0)
read = serializers.SerializerMethodField()
unread_comments_count = serializers.SerializerMethodField()
endorsed = serializers.SerializerMethodField()
Expand Down Expand Up @@ -126,7 +125,7 @@ def get_read(self, obj: dict[str, Any]) -> Optional[bool]:
course_id = obj["course_id"]
thread_key = obj["id"]
is_read, _ = get_read_states([obj], user_id, course_id).get(
thread_key, (False, obj["comments_count"])
thread_key, (False, obj["comment_count"])
)
return is_read
return None
Expand All @@ -148,7 +147,7 @@ def get_unread_comments_count(self, obj: dict[str, Any]) -> Optional[int]:
course_id = obj["course_id"]
thread_key = obj["id"]
_, unread_count = get_read_states([obj], user_id, course_id).get(
thread_key, (False, obj["comments_count"])
thread_key, (False, obj["comment_count"])
)
return unread_count
return None
Expand Down Expand Up @@ -243,11 +242,6 @@ def update(self, instance: Any, validated_data: dict[str, Any]) -> Any:
"""Raise NotImplementedError"""
raise NotImplementedError

def get_comments_count(self, obj: dict[str, Any]) -> int:
"""Retrieve the count of comments for the given thread."""
count = get_comments_count(obj["_id"])
return count

def get_closed_by(self, obj: dict[str, Any]) -> Optional[str]:
"""Retrieve the username of the person who closed the object."""
if closed_by_id := obj.get("closed_by_id"):
Expand Down
6 changes: 3 additions & 3 deletions forum/views/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def create_comment(
new_comment_id = Comment().insert(
body=data["body"],
course_id=data["course_id"],
anonymous=data.get("anonymous", False),
anonymous_to_peers=data.get("anonymous_to_peers", False),
anonymous=str_to_bool(data.get("anonymous", "False")),
anonymous_to_peers=str_to_bool(data.get("anonymous_to_peers", "False")),
author_id=data["user_id"],
comment_thread_id=thread_id,
parent_id=parent_id,
Expand Down Expand Up @@ -114,7 +114,7 @@ def get(self, request: Request, comment_id: str) -> Response:
)
data = prepare_comment_api_response(
comment,
exclude_fields=["sk"],
exclude_fields=["sk", "endorsement"],
)
return Response(data, status=status.HTTP_200_OK)

Expand Down
1 change: 1 addition & 0 deletions forum/views/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _validate_params(self, params: dict[str, Any]) -> Response | None:
"sort_key",
"page",
"per_page",
"request_id'",
]

for key in params:
Expand Down